Skip to content

Commit

Permalink
Fix UDA samples to work with Impala 1.x
Browse files Browse the repository at this point in the history
  • Loading branch information
skye committed Dec 17, 2013
1 parent 25952a0 commit 9d9a187
Show file tree
Hide file tree
Showing 4 changed files with 198 additions and 49 deletions.
86 changes: 73 additions & 13 deletions uda-sample-test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,13 @@ bool TestCount() {
UdaTestHarness<BigIntVal, BigIntVal, IntVal> test(
CountInit, CountUpdate, CountMerge, NULL, CountFinalize);

// Run the UDA over empty input
vector<IntVal> empty;
if (!test.Execute(empty, BigIntVal(0))) {
cerr << test.GetErrorMsg() << endl;
return false;
}

// Run the UDA over 10000 non-null values
vector<IntVal> no_nulls;
no_nulls.resize(10000);
Expand All @@ -52,15 +59,23 @@ bool TestCount() {
}

bool TestAvg() {
UdaTestHarness<DoubleVal, BufferVal, DoubleVal> test(
UdaTestHarness<StringVal, StringVal, DoubleVal> test(
AvgInit, AvgUpdate, AvgMerge, NULL, AvgFinalize);
test.SetIntermediateSize(16);

vector<DoubleVal> vals;

// Test empty input
if (!test.Execute<DoubleVal>(vals, StringVal::null())) {
cerr << test.GetErrorMsg() << endl;
return false;
}

// Test values
for (int i = 0; i < 1001; ++i) {
vals.push_back(DoubleVal(i));
}
if (!test.Execute<DoubleVal>(vals, DoubleVal(500))) {
if (!test.Execute<DoubleVal>(vals, StringVal("500"))) {
cerr << test.GetErrorMsg() << endl;
return false;
}
Expand All @@ -74,10 +89,18 @@ bool TestStringConcat() {
StringConcatFinalize);

vector<StringVal> values;
vector<StringVal> separators;

// Test empty input
if (!test.Execute(values, separators, StringVal::null())) {
cerr << test.GetErrorMsg() << endl;
return false;
}

// Test values
values.push_back("Hello");
values.push_back("World");

vector<StringVal> separators;
for(int i = 0; i < values.size(); ++i) {
separators.push_back(",");
}
Expand All @@ -101,24 +124,51 @@ bool FuzzyCompare(const DoubleVal& x, const DoubleVal& y) {
return fabs(x.val - y.val) < 0.00001;
}

// Reimplementation of FuzzyCompare that parses doubles encoded as StringVals.
// TODO: This can be removed when separate intermediate types are supported in Impala 2.0
bool FuzzyCompareStrings(const StringVal& x, const StringVal& y) {
if (x.is_null && y.is_null) return true;
if (x.is_null || y.is_null) return false;
// Note that atof expects null-terminated strings, which is not guaranteed by
// StringVals. However, since our UDAs serialize double to StringVals via stringstream,
// we know the serialized StringVals will be null-terminated in this case.
double x_val = atof(reinterpret_cast<char*>(x.ptr));
double y_val = atof(reinterpret_cast<char*>(y.ptr));
return fabs(x_val - y_val) < 0.00001;
}

bool TestVariance() {
// Setup the test UDAs.
UdaTestHarness<DoubleVal, StringVal, DoubleVal> simple_variance(
UdaTestHarness<StringVal, StringVal, DoubleVal> simple_variance(
VarianceInit, VarianceUpdate, VarianceMerge, NULL, VarianceFinalize);
simple_variance.SetResultComparator(FuzzyCompare);
simple_variance.SetResultComparator(FuzzyCompareStrings);

UdaTestHarness<DoubleVal, StringVal, DoubleVal> knuth_variance(
UdaTestHarness<StringVal, StringVal, DoubleVal> knuth_variance(
KnuthVarianceInit, KnuthVarianceUpdate, KnuthVarianceMerge, NULL,
KnuthVarianceFinalize);
knuth_variance.SetResultComparator(FuzzyCompare);
knuth_variance.SetResultComparator(FuzzyCompareStrings);

UdaTestHarness<DoubleVal, StringVal, DoubleVal> stddev(
UdaTestHarness<StringVal, StringVal, DoubleVal> stddev(
KnuthVarianceInit, KnuthVarianceUpdate, KnuthVarianceMerge, NULL,
StdDevFinalize);
stddev.SetResultComparator(FuzzyCompare);
stddev.SetResultComparator(FuzzyCompareStrings);

// Initialize the test values.
// Test empty input
vector<DoubleVal> vals;
if (!simple_variance.Execute(vals, StringVal::null())) {
cerr << "Simple variance: " << simple_variance.GetErrorMsg() << endl;
return false;
}
if (!knuth_variance.Execute(vals, StringVal::null())) {
cerr << "Knuth variance: " << knuth_variance.GetErrorMsg() << endl;
return false;
}
if (!stddev.Execute(vals, StringVal::null())) {
cerr << "Stddev: " << stddev.GetErrorMsg() << endl;
return false;
}

// Initialize the test values.
double sum = 0;
for (int i = 0; i < 1001; ++i) {
vals.push_back(DoubleVal(i));
Expand All @@ -133,16 +183,26 @@ bool TestVariance() {
expected_variance /= (vals.size() - 1);
double expected_stddev = sqrt(expected_variance);

stringstream expected_variance_ss;
expected_variance_ss << expected_variance;
string expected_variance_str = expected_variance_ss.str();
StringVal expected_variance_sv(expected_variance_str.c_str());

stringstream expected_stddev_ss;
expected_stddev_ss << expected_stddev;
string expected_stddev_str = expected_stddev_ss.str();
StringVal expected_stddev_sv(expected_stddev_str.c_str());

// Run the tests
if (!simple_variance.Execute(vals, DoubleVal(expected_variance))) {
if (!simple_variance.Execute(vals, expected_variance_sv)) {
cerr << "Simple variance: " << simple_variance.GetErrorMsg() << endl;
return false;
}
if (!knuth_variance.Execute(vals, DoubleVal(expected_variance))) {
if (!knuth_variance.Execute(vals, expected_variance_sv)) {
cerr << "Knuth variance: " << knuth_variance.GetErrorMsg() << endl;
return false;
}
if (!stddev.Execute(vals, DoubleVal(expected_stddev))) {
if (!stddev.Execute(vals, expected_stddev_sv)) {
cerr << "Stddev: " << stddev.GetErrorMsg() << endl;
return false;
}
Expand Down
55 changes: 40 additions & 15 deletions uda-sample.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,26 @@

#include "uda-sample.h"
#include <assert.h>
#include <sstream>

using namespace impala_udf;
using namespace std;

// Serializes 'val' to its textual form, as produced by operator<< on a stringstream
// with the stream's default formatting, and returns it as a StringVal constructed
// from 'context'. Exactly text.size() bytes are copied, so the resulting buffer is
// not null-terminated.
template <typename T>
StringVal ToStringVal(FunctionContext* context, const T& val) {
  stringstream formatter;
  formatter << val;
  const string text = formatter.str();
  StringVal result(context, text.size());
  memcpy(result.ptr, text.data(), text.size());
  return result;
}

// Specialization for DoubleVal: a null input maps to a null StringVal, otherwise the
// wrapped double is serialized through the generic ToStringVal().
template <>
StringVal ToStringVal<DoubleVal>(FunctionContext* context, const DoubleVal& val) {
  return val.is_null ? StringVal::null() : ToStringVal(context, val.val);
}

// ---------------------------------------------------------------------------
// This is a sample of implementing a COUNT aggregate function.
Expand Down Expand Up @@ -46,30 +64,37 @@ struct AvgStruct {
int64_t count;
};

void AvgInit(FunctionContext* context, BufferVal* val) {
assert(sizeof(AvgStruct) == 16);
memset(*val, 0, sizeof(AvgStruct));
// Sets up the StringVal intermediate as a non-null buffer of sizeof(AvgStruct) bytes,
// obtained from context->Allocate(), with every byte cleared to zero.
void AvgInit(FunctionContext* context, StringVal* val) {
  val->len = sizeof(AvgStruct);
  val->ptr = context->Allocate(val->len);
  memset(val->ptr, 0, val->len);
  val->is_null = false;
}

void AvgUpdate(FunctionContext* context, const DoubleVal& input, BufferVal* val) {
void AvgUpdate(FunctionContext* context, const DoubleVal& input, StringVal* val) {
if (input.is_null) return;
AvgStruct* avg = reinterpret_cast<AvgStruct*>(*val);
assert(!val->is_null);
assert(val->len == sizeof(AvgStruct));
AvgStruct* avg = reinterpret_cast<AvgStruct*>(val->ptr);
avg->sum += input.val;
++avg->count;
}

void AvgMerge(FunctionContext* context, const BufferVal& src, BufferVal* dst) {
if (src == NULL) return;
const AvgStruct* src_struct = reinterpret_cast<const AvgStruct*>(src);
AvgStruct* dst_struct = reinterpret_cast<AvgStruct*>(*dst);
dst_struct->sum += src_struct->sum;
dst_struct->count += src_struct->count;
// Folds the partial average in 'src' into 'dst' by adding both the running sum and
// the running count. A null 'src' leaves 'dst' untouched.
void AvgMerge(FunctionContext* context, const StringVal& src, StringVal* dst) {
  if (!src.is_null) {
    const AvgStruct* incoming = reinterpret_cast<const AvgStruct*>(src.ptr);
    AvgStruct* accumulated = reinterpret_cast<AvgStruct*>(dst->ptr);
    accumulated->sum += incoming->sum;
    accumulated->count += incoming->count;
  }
}

DoubleVal AvgFinalize(FunctionContext* context, const BufferVal& val) {
if (val == NULL) return DoubleVal::null();
AvgStruct* val_struct = reinterpret_cast<AvgStruct*>(val);
return DoubleVal(val_struct->sum / val_struct->count);
// Produces the final average as a StringVal holding the textual form of sum / count.
// Returns a null StringVal when no values were accumulated (count == 0). The
// intermediate must be non-null and exactly sizeof(AvgStruct) bytes long.
// NOTE(review): the buffer obtained via context->Allocate() in AvgInit is not released
// here -- confirm whether the framework reclaims it.
StringVal AvgFinalize(FunctionContext* context, const StringVal& val) {
  assert(!val.is_null);
  assert(val.len == sizeof(AvgStruct));
  const AvgStruct* state = reinterpret_cast<const AvgStruct*>(val.ptr);
  return state->count == 0 ? StringVal::null()
                           : ToStringVal(context, state->sum / state->count);
}

// ---------------------------------------------------------------------------
Expand Down
80 changes: 70 additions & 10 deletions uda-sample.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,27 @@

using namespace impala_udf;

// Note: As of Impala 1.2, UDAs must have the same intermediate and result types (see the
// udf.h header for the full Impala UDA specification, which can be found at
// https://github.com/cloudera/impala/blob/master/be/src/udf/udf.h). Some UDAs naturally
// conform to this limitation, such as Count and StringConcat. However, other UDAs return
// a numeric value but use a custom intermediate struct type that must be stored in a
// StringVal or BufferVal, such as Variance.
//
// As a workaround for now, these UDAs that require an intermediate buffer use StringVal
// for the intermediate and result type. In the UDAs' finalize functions, the numeric
// result is serialized to an ASCII string (see the ToStringVal() utility function
// provided with these samples). The returned StringVal is then cast back to the correct
// numeric type (see the Usage examples below).
//
// This restriction will be lifted in Impala 2.0.


// This is an example of the COUNT aggregate function.
//
// Usage: > create aggregate function my_count(int) returns bigint
// location '/user/cloudera/libudasample.so' update_fn='CountUpdate';
// > select my_count(col) from tbl;
void CountInit(FunctionContext* context, BigIntVal* val);
void CountUpdate(FunctionContext* context, const IntVal& input, BigIntVal* val);
void CountMerge(FunctionContext* context, const BigIntVal& src, BigIntVal* dst);
Expand All @@ -30,32 +50,72 @@ BigIntVal CountFinalize(FunctionContext* context, const BigIntVal& val);
// maintain two pieces of state, the current sum and the count. We do this using
// the BufferVal intermediate type. When this UDA is registered, it would specify
// 16 bytes (8 byte sum + 8 byte count) as the size for this buffer.
void AvgInit(FunctionContext* context, BufferVal* val);
void AvgUpdate(FunctionContext* context, const DoubleVal& input, BufferVal* val);
void AvgMerge(FunctionContext* context, const BufferVal& src, BufferVal* dst);
DoubleVal AvgFinalize(FunctionContext* context, const BufferVal& val);
//
// Usage: > create aggregate function my_avg(double) returns string
// location '/user/cloudera/libudasample.so' update_fn='AvgUpdate';
// > select cast(my_avg(col) as double) from tbl;
//
// TODO: The StringVal intermediate type should be replaced by a preallocated BufferVal
// and the return type changed to DoubleVal in Impala 2.0
void AvgInit(FunctionContext* context, StringVal* val);
void AvgUpdate(FunctionContext* context, const DoubleVal& input, StringVal* val);
void AvgMerge(FunctionContext* context, const StringVal& src, StringVal* dst);
StringVal AvgFinalize(FunctionContext* context, const StringVal& val);

// This is a sample of implementing the STRING_CONCAT aggregate function.
// Example: select string_concat(string_col, ",") from table
//
// Usage: > create aggregate function string_concat(string, string) returns string
// location '/user/cloudera/libudasample.so' update_fn='StringConcatUpdate';
// > select string_concat(string_col, ",") from table;
void StringConcatInit(FunctionContext* context, StringVal* val);
void StringConcatUpdate(FunctionContext* context, const StringVal& arg1,
const StringVal& arg2, StringVal* val);
void StringConcatMerge(FunctionContext* context, const StringVal& src, StringVal* dst);
StringVal StringConcatFinalize(FunctionContext* context, const StringVal& val);

// This is an example of the variance aggregate function.
//
// Usage: > create aggregate function var(double) returns string
// location '/user/cloudera/libudasample.so' update_fn='VarianceUpdate';
// > select cast(var(col) as double) from tbl;
//
// TODO: The StringVal intermediate type should be replaced by a preallocated BufferVal
// and the return type changed to DoubleVal in Impala 2.0
void VarianceInit(FunctionContext* context, StringVal* val);
void VarianceUpdate(FunctionContext* context, const DoubleVal& input, StringVal* val);
void VarianceMerge(FunctionContext* context, const StringVal& src, StringVal* dst);
DoubleVal VarianceFinalize(FunctionContext* context, const StringVal& val);
StringVal VarianceFinalize(FunctionContext* context, const StringVal& val);

// An implementation of the Knuth online variance algorithm, which is also single pass and
// more numerically stable.
//
// Usage: > create aggregate function knuth_var(double) returns string
// location '/user/cloudera/libudasample.so' update_fn='KnuthVarianceUpdate';
// > select cast(knuth_var(col) as double) from tbl;
//
// TODO: The StringVal intermediate type should be replaced by a preallocated BufferVal
// and the return type changed to DoubleVal in Impala 2.0
void KnuthVarianceInit(FunctionContext* context, StringVal* val);
void KnuthVarianceUpdate(FunctionContext* context, const DoubleVal& input, StringVal* val);
void KnuthVarianceMerge(FunctionContext* context, const StringVal& src, StringVal* dst);
DoubleVal KnuthVarianceFinalize(FunctionContext* context, const StringVal& val);
StringVal KnuthVarianceFinalize(FunctionContext* context, const StringVal& val);

// The different steps of the UDA are composable. In this case, the UDA will use the
// other steps from the Knuth variance computation.
//
// Usage: > create aggregate function stddev(double) returns string
// location '/user/cloudera/libudasample.so' update_fn='KnuthVarianceUpdate'
// finalize_fn="StdDevFinalize";
// > select cast(stddev(col) as double) from tbl;
//
// TODO: The StringVal intermediate type should be replaced by a preallocated BufferVal
// and the return type changed to DoubleVal in Impala 2.0
StringVal StdDevFinalize(FunctionContext* context, const StringVal& val);

// The different steps of the UDA are composable. In this case, we'll the UDA will
// use the other steps from the variance computation.
DoubleVal StdDevFinalize(FunctionContext* context, const StringVal& val);
// Utility function for serialization to StringVal
// TODO: this will be unnecessary in Impala 2.0, when we will no longer have to serialize
// results to StringVals in order to match the intermediate type
template <typename T>
StringVal ToStringVal(FunctionContext* context, const T& val);

#endif
26 changes: 15 additions & 11 deletions variance-uda.cc
Original file line number Diff line number Diff line change
Expand Up @@ -61,17 +61,15 @@ void VarianceMerge(FunctionContext* ctx, const StringVal& src, StringVal* dst) {
dst_state->count += src_state->count;
}

DoubleVal VarianceFinalize(FunctionContext* ctx, const StringVal& src) {
StringVal VarianceFinalize(FunctionContext* ctx, const StringVal& src) {
VarianceState* state = reinterpret_cast<VarianceState*>(src.ptr);
if (state->count == 0 || state->count == 1) return DoubleVal::null();
if (state->count == 0 || state->count == 1) return StringVal::null();
double mean = state->sum / state->count;
double variance =
(state->sum_squared - state->sum * state->sum / state->count) / (state->count - 1);
return DoubleVal(variance);
return ToStringVal(ctx, variance);
}

// An implementation of the Knuth online variance algorithm, which is also single pass
// and more numerically stable.
struct KnuthVarianceState {
int64_t count;
double mean;
Expand Down Expand Up @@ -108,17 +106,23 @@ void KnuthVarianceMerge(FunctionContext* ctx, const StringVal& src, StringVal* d
dst_state->count = sum_count;
}

DoubleVal KnuthVarianceFinalize(FunctionContext* ctx, const StringVal& src) {
KnuthVarianceState* state = reinterpret_cast<KnuthVarianceState*>(src.ptr);
// Computes the sample variance from a serialized KnuthVarianceState. Returns a null
// DoubleVal when the state holds zero or one values.
// TODO: this can be used as the actual variance finalize function once the return type
// doesn't need to match the intermediate type in Impala 2.0.
DoubleVal KnuthVarianceFinalize(const StringVal& state_sv) {
  const KnuthVarianceState* knuth =
      reinterpret_cast<const KnuthVarianceState*>(state_sv.ptr);
  const bool too_few_values = (knuth->count == 0 || knuth->count == 1);
  if (too_few_values) return DoubleVal::null();
  // Population variance first, then rescaled by n / (n - 1) to get the sample variance.
  const double population_variance = knuth->m2 / knuth->count;
  const double sample_variance =
      population_variance * knuth->count / (knuth->count - 1);
  return DoubleVal(sample_variance);
}

DoubleVal StdDevFinalize(FunctionContext* ctx, const StringVal& src) {
DoubleVal variance = KnuthVarianceFinalize(ctx, src);
if (variance.is_null) return variance;
return DoubleVal(sqrt(variance.val));
// UDA finalize entry point: computes the Knuth variance from 'src' and serializes the
// resulting DoubleVal to a StringVal via ToStringVal().
StringVal KnuthVarianceFinalize(FunctionContext* ctx, const StringVal& src) {
  const DoubleVal variance = KnuthVarianceFinalize(src);
  return ToStringVal(ctx, variance);
}

// UDA finalize entry point for standard deviation: takes the square root of the Knuth
// variance computed from 'src' and serializes it to a StringVal. A null variance
// yields a null StringVal.
StringVal StdDevFinalize(FunctionContext* ctx, const StringVal& src) {
  const DoubleVal variance = KnuthVarianceFinalize(src);
  if (!variance.is_null) {
    const double stddev = sqrt(variance.val);
    return ToStringVal(ctx, stddev);
  }
  return StringVal::null();
}

0 comments on commit 9d9a187

Please sign in to comment.