Skip to content

Commit

Permalink
Added knuth variance.
Browse files Browse the repository at this point in the history
  • Loading branch information
Nong Li committed Nov 19, 2013
1 parent 3fce648 commit 69a6d65
Show file tree
Hide file tree
Showing 3 changed files with 115 additions and 2 deletions.
48 changes: 48 additions & 0 deletions uda-sample-test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
// limitations under the License.

#include <iostream>
#include <math.h>

#include <impala_udf/uda-test-harness.h>
#include "uda-sample.h"
Expand Down Expand Up @@ -88,11 +89,58 @@ bool TestStringConcat() {
return true;
}

// Algorithms that operate on floating point values may not produce bit-identical
// results because of floating point imprecision. The test harness accepts a custom
// equality comparator; this one tolerates a small absolute error.
// Two NULLs compare equal, and a NULL never equals a non-NULL value.
bool FuzzyCompare(const DoubleVal& x, const DoubleVal& y) {
  // If either side is NULL, the values match only when both are NULL.
  if (x.is_null || y.is_null) return x.is_null && y.is_null;
  const double difference = x.val - y.val;
  return fabs(difference) < 0.00001;
}

// Runs both variance UDAs (the simple sum / sum-of-squares version and the Knuth
// online version) over the values 0..1000 and compares each result, using
// FuzzyCompare, against a sample variance computed here with the two-pass formula.
// Returns true if both UDAs match; otherwise prints the harness error message to
// stderr and returns false.
bool TestVariance() {
  UdaTestHarness<DoubleVal, StringVal, DoubleVal> simple_variance(
      VarianceInit, VarianceUpdate, VarianceMerge, NULL, VarianceFinalize);
  simple_variance.SetResultComparator(FuzzyCompare);

  UdaTestHarness<DoubleVal, StringVal, DoubleVal> knuth_variance(
      KnuthVarianceInit, KnuthVarianceUpdate, KnuthVarianceMerge, NULL,
      KnuthVarianceFinalize);
  knuth_variance.SetResultComparator(FuzzyCompare);

  // Build the input values and accumulate their sum for the mean.
  const int kNumValues = 1001;
  vector<DoubleVal> vals;
  vals.reserve(kNumValues);
  double sum = 0;
  for (int i = 0; i < kNumValues; ++i) {
    vals.push_back(DoubleVal(i));
    sum += i;
  }

  // Second pass: sum of squared deviations from the mean, divided by (n - 1) to
  // get the sample variance.
  const double mean = sum / vals.size();
  double expected_variance = 0;
  for (size_t i = 0; i < vals.size(); ++i) {
    const double d = mean - vals[i].val;
    expected_variance += d * d;
  }
  expected_variance /= (vals.size() - 1);

  if (!simple_variance.Execute(vals, DoubleVal(expected_variance))) {
    cerr << "Simple variance: " << simple_variance.GetErrorMsg() << endl;
    return false;
  }
  if (!knuth_variance.Execute(vals, DoubleVal(expected_variance))) {
    cerr << "Knuth variance: " << knuth_variance.GetErrorMsg() << endl;
    return false;
  }

  return true;
}

// Runs every UDA sample test and reports the overall outcome on stderr.
// The process exit status is 0 when all tests pass and 1 otherwise, so scripts and
// build systems that invoke this binary can detect a failure.
int main(int argc, char** argv) {
  bool passed = true;
  // '&=' rather than '&&': every test still runs after an earlier one has failed.
  passed &= TestCount();
  passed &= TestAvg();
  passed &= TestStringConcat();
  passed &= TestVariance();
  cerr << (passed ? "Tests passed." : "Tests failed.") << endl;
  return passed ? 0 : 1;
}
11 changes: 11 additions & 0 deletions uda-sample.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,4 +43,15 @@ void StringConcatUpdate(FunctionContext* context, const StringVal& arg1,
void StringConcatMerge(FunctionContext* context, const StringVal& src, StringVal* dst);
StringVal StringConcatFinalize(FunctionContext* context, const StringVal& val);

// This is an example of the variance aggregate function, implemented as a simple
// single-pass algorithm. The intermediate state is carried in a StringVal.
// Finalize returns NULL when fewer than two values were aggregated.
void VarianceInit(FunctionContext* context, StringVal* val);
void VarianceUpdate(FunctionContext* context, const DoubleVal& input, StringVal* val);
void VarianceMerge(FunctionContext* context, const StringVal& src, StringVal* dst);
DoubleVal VarianceFinalize(FunctionContext* context, const StringVal& val);

// The same aggregate implemented with the Knuth online variance algorithm, which is
// also single pass and more numerically stable. The intermediate state is carried in
// a StringVal. Finalize returns NULL when fewer than two values were aggregated.
void KnuthVarianceInit(FunctionContext* context, StringVal* val);
void KnuthVarianceUpdate(FunctionContext* context, const DoubleVal& input, StringVal* val);
void KnuthVarianceMerge(FunctionContext* context, const StringVal& src, StringVal* dst);
DoubleVal KnuthVarianceFinalize(FunctionContext* context, const StringVal& val);

#endif
58 changes: 56 additions & 2 deletions variance-uda.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,16 @@
#include <iostream>
#include <impala_udf/udf.h>

#include "uda-sample.h"

using namespace std;
using namespace impala_udf;

// An implementation of a simple single pass variance algorithm. A standard UDA must
// be single pass (i.e. does not scan the table more than once), so the most canonical
// two pass approach is not practical.
// This algorithm suffers from numerical precision issues when the input values are
// large, due to floating point rounding.
struct VarianceState {
// Sum of all input values.
double sum;
Expand Down Expand Up @@ -56,9 +63,56 @@ void VarianceMerge(FunctionContext* ctx, const StringVal& src, StringVal* dst) {

// Produces the final result of the simple variance UDA: the sample variance
// (normalized by count - 1) computed from the accumulated sum and sum of squares.
// Returns NULL when fewer than two values were aggregated, since the sample
// variance is undefined in that case.
DoubleVal VarianceFinalize(FunctionContext* ctx, const StringVal& src) {
  VarianceState* state = reinterpret_cast<VarianceState*>(src.ptr);
  if (state->count == 0 || state->count == 1) return DoubleVal::null();
  // Sum of squared deviations from the mean, expanded so it needs only the running
  // totals: sum(x^2) - (sum(x))^2 / n.
  const double sum_sq_dev = state->sum_squared - state->sum * state->sum / state->count;
  return DoubleVal(sum_sq_dev / (state->count - 1));
}

// An implementation of the Knuth online variance algorithm, which is also single pass
// and more numerically stable.
// Intermediate state of the Knuth variance UDA. It is stored as the raw bytes of a
// StringVal buffer and starts out zero-filled (see KnuthVarianceInit).
struct KnuthVarianceState {
// Number of non-NULL input values accumulated so far.
int64_t count;
// Running mean of the accumulated values.
double mean;
// Running sum of squared deviations from the mean; the finalize step divides it
// down to obtain the sample variance.
double m2;
};

// Sets up the intermediate value for the Knuth variance UDA: 'dst' is pointed at a
// buffer obtained from ctx->Allocate() that is sized to hold one KnuthVarianceState,
// and the buffer is zero-filled so count, mean and m2 all start at 0.
void KnuthVarianceInit(FunctionContext* ctx, StringVal* dst) {
  // Size and allocate the state buffer.
  dst->len = sizeof(KnuthVarianceState);
  dst->ptr = ctx->Allocate(dst->len);
  // All-zero bytes represent the empty aggregate.
  memset(dst->ptr, 0, dst->len);
  dst->is_null = false;
}

// Folds one input value into the running Knuth variance state held in 'dst'.
// NULL inputs are ignored. The count is kept as an integer throughout instead of
// being round-tripped through a double.
void KnuthVarianceUpdate(FunctionContext* ctx, const DoubleVal& src, StringVal* dst) {
  if (src.is_null) return;
  KnuthVarianceState* state = reinterpret_cast<KnuthVarianceState*>(dst->ptr);
  const int64_t new_count = state->count + 1;
  const double delta = src.val - state->mean;
  // Amount by which the mean moves once this value is included.
  const double r = delta / new_count;
  state->mean += r;
  // Uses the count *before* this value was added:
  // m2 += (n - 1) * delta * delta / n.
  state->m2 += state->count * delta * r;
  state->count = new_count;
}

// Combines the partial Knuth variance state in 'src' into 'dst', so that 'dst'
// afterwards describes the union of the values both states had accumulated.
// An empty 'src' leaves 'dst' untouched.
void KnuthVarianceMerge(FunctionContext* ctx, const StringVal& src, StringVal* dst) {
  const KnuthVarianceState* src_state =
      reinterpret_cast<const KnuthVarianceState*>(src.ptr);
  KnuthVarianceState* dst_state = reinterpret_cast<KnuthVarianceState*>(dst->ptr);
  if (src_state->count == 0) return;

  const int64_t total_count = dst_state->count + src_state->count;
  // The counts are converted to double before being multiplied below: the product
  // of two large int64 counts could otherwise overflow.
  const double src_count = static_cast<double>(src_state->count);
  const double dst_count = static_cast<double>(dst_state->count);
  const double delta = dst_state->mean - src_state->mean;

  // Count-weighted average of the two means.
  dst_state->mean = src_state->mean + delta * (dst_count / total_count);
  // Sum of both m2 terms plus a correction for the distance between the two means.
  dst_state->m2 = src_state->m2 + dst_state->m2 +
      (delta * delta) * (src_count * dst_count / total_count);
  dst_state->count = total_count;
}

// Produces the final result of the Knuth variance UDA: the sample variance, i.e.
// m2 divided by (count - 1). Returns NULL when fewer than two values were
// aggregated, since the sample variance is undefined in that case.
// NOTE(review): this function does not release the state buffer that
// KnuthVarianceInit obtained from ctx->Allocate() -- confirm who owns that memory.
DoubleVal KnuthVarianceFinalize(FunctionContext* ctx, const StringVal& src) {
  const KnuthVarianceState* state =
      reinterpret_cast<const KnuthVarianceState*>(src.ptr);
  if (state->count < 2) return DoubleVal::null();
  // Divide by (count - 1) directly; first dividing by count and then multiplying
  // it back in only adds rounding steps.
  return DoubleVal(state->m2 / (state->count - 1));
}

0 comments on commit 69a6d65

Please sign in to comment.