Skip to content

Commit

Permalink
Add support for Decimal Type Sum Aggreggation
Browse files Browse the repository at this point in the history
  • Loading branch information
majetideepak committed Aug 1, 2022
1 parent 1c6ad90 commit 15884ca
Show file tree
Hide file tree
Showing 8 changed files with 182 additions and 6 deletions.
5 changes: 3 additions & 2 deletions velox/exec/GroupingSet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -277,13 +277,14 @@ void GroupingSet::initializeGlobalAggregation() {

// Row layout is:
// - null flags - one bit per aggregate,
// - uint32_t row size,
// - uint32_t row size,
// - fixed-width accumulators - one per aggregate
//
// Here we always make space for a row size since we only have one
// row and no RowContainer.
int32_t rowSizeOffset = bits::nbytes(aggregates_.size());
int32_t offset = rowSizeOffset + sizeof(int32_t);
// Fixed-width accumulators must be word aligned.
int32_t offset = bits::roundUp(rowSizeOffset + sizeof(int32_t), 16);
int32_t nullOffset = 0;

for (auto& aggregate : aggregates_) {
Expand Down
77 changes: 77 additions & 0 deletions velox/exec/tests/utils/QueryAssertions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,11 @@ std::string makeCreateTableSql(
} else if (child->isMap()) {
sql << "MAP(" << child->asMap().keyType()->kindName() << ", "
<< child->asMap().valueType()->kindName() << ")";
} else if (child->isShortDecimal() || child->isLongDecimal()) {
int precision;
int scale;
getDecimalPrecisionScale(*child, precision, scale);
sql << "decimal(" << precision << ", " << scale << ")";
} else {
sql << child->kindName();
}
Expand Down Expand Up @@ -85,6 +90,31 @@ ::duckdb::Value duckValueAt<TypeKind::DATE>(
vector->as<SimpleVector<T>>()->valueAt(index).days()));
}

template <>
::duckdb::Value duckValueAt<TypeKind::SHORT_DECIMAL>(
const VectorPtr& vector,
vector_size_t index) {
using T = typename KindToFlatVector<TypeKind::SHORT_DECIMAL>::WrapperType;
auto type = vector->type()->asShortDecimal();
return ::duckdb::Value::DECIMAL(
vector->as<SimpleVector<T>>()->valueAt(index).unscaledValue(),
type.precision(),
type.scale());
}

template <>
::duckdb::Value duckValueAt<TypeKind::LONG_DECIMAL>(
const VectorPtr& vector,
vector_size_t index) {
using T = typename KindToFlatVector<TypeKind::LONG_DECIMAL>::WrapperType;
auto type = vector->type()->asLongDecimal();
auto val = vector->as<SimpleVector<T>>()->valueAt(index).unscaledValue();
auto duckVal = ::duckdb::hugeint_t();
duckVal.lower = (val << 64) >> 64;
duckVal.upper = (val >> 64);
return ::duckdb::Value::DECIMAL(duckVal, type.precision(), type.scale());
}

template <>
::duckdb::Value duckValueAt<TypeKind::ARRAY>(
const VectorPtr& vector,
Expand Down Expand Up @@ -196,6 +226,20 @@ velox::variant variantAt(const ::duckdb::Value& value) {
return velox::variant(value.GetValue<T>());
}

velox::variant decimalVariantAt(const ::duckdb::Value& value) {
uint8_t precision;
uint8_t scale;
value.type().GetDecimalProperties(precision, scale);
auto duckType = DECIMAL(precision, scale);
if (duckType->isShortDecimal()) {
return velox::variant::shortDecimal(value.GetValue<int64_t>(), duckType);
} else {
auto val = value.GetValueUnsafe<::duckdb::hugeint_t>();
return velox::variant::longDecimal(
buildInt128(val.upper, val.lower), duckType);
}
}

velox::variant rowVariantAt(
const ::duckdb::Value& vector,
const TypePtr& rowType) {
Expand Down Expand Up @@ -264,6 +308,8 @@ std::vector<MaterializedRow> materialize(
} else if (typeKind == TypeKind::ROW) {
row.push_back(
rowVariantAt(dataChunk->GetValue(j, i), rowType->childAt(j)));
} else if (isDecimalKind(typeKind)) {
row.push_back(decimalVariantAt(dataChunk->GetValue(j, i)));
} else {
auto value = VELOX_DYNAMIC_SCALAR_TYPE_DISPATCH(
variantAt, typeKind, dataChunk, i, j);
Expand Down Expand Up @@ -291,6 +337,26 @@ velox::variant variantAt<TypeKind::TIMESTAMP>(VectorPtr vector, int32_t row) {
veloxTimestampToDuckDB(vector->as<SimpleVector<T>>()->valueAt(row))));
}

template <>
velox::variant variantAt<TypeKind::SHORT_DECIMAL>(
VectorPtr vector,
int32_t row) {
using T = typename KindToFlatVector<TypeKind::SHORT_DECIMAL>::WrapperType;
return velox::variant::shortDecimal(
vector->as<SimpleVector<T>>()->valueAt(row).unscaledValue(),
vector->type());
}

template <>
velox::variant variantAt<TypeKind::LONG_DECIMAL>(
VectorPtr vector,
int32_t row) {
using T = typename KindToFlatVector<TypeKind::LONG_DECIMAL>::WrapperType;
return velox::variant::longDecimal(
vector->as<SimpleVector<T>>()->valueAt(row).unscaledValue(),
vector->type());
}

velox::variant arrayVariantAt(const VectorPtr& vector, vector_size_t row) {
auto arrayVector = vector->as<ArrayVector>();
auto& elements = arrayVector->elements();
Expand Down Expand Up @@ -379,6 +445,11 @@ std::vector<MaterializedRow> materialize(const RowVectorPtr& vector) {
row.push_back(arrayVariantAt(vector->childAt(j), i));
} else if (typeKind == TypeKind::MAP) {
row.push_back(mapVariantAt(vector->childAt(j), i));
} else if (typeKind == TypeKind::SHORT_DECIMAL) {
row.push_back(
variantAt<TypeKind::SHORT_DECIMAL>(vector->childAt(j), i));
} else if (typeKind == TypeKind::LONG_DECIMAL) {
row.push_back(variantAt<TypeKind::LONG_DECIMAL>(vector->childAt(j), i));
} else {
auto value = VELOX_DYNAMIC_SCALAR_TYPE_DISPATCH(
variantAt, typeKind, vector->childAt(j), i);
Expand Down Expand Up @@ -494,6 +565,12 @@ void DuckDbQueryRunner::createTable(
appender.Append(duckValueAt<TypeKind::ARRAY>(columnVector, row));
} else if (rowType.childAt(column)->isMap()) {
appender.Append(duckValueAt<TypeKind::MAP>(columnVector, row));
} else if (rowType.childAt(column)->isShortDecimal()) {
appender.Append(
duckValueAt<TypeKind::SHORT_DECIMAL>(columnVector, row));
} else if (rowType.childAt(column)->isLongDecimal()) {
appender.Append(
duckValueAt<TypeKind::LONG_DECIMAL>(columnVector, row));
} else {
auto value = VELOX_DYNAMIC_SCALAR_TYPE_DISPATCH(
duckValueAt, rowType.childAt(column)->kind(), columnVector, row);
Expand Down
16 changes: 15 additions & 1 deletion velox/functions/prestosql/aggregates/SimpleNumericAggregate.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

#include "velox/exec/Aggregate.h"
#include "velox/exec/AggregationHook.h"
#include "velox/type/LongDecimal.h"
#include "velox/vector/DecodedVector.h"
#include "velox/vector/FlatVector.h"
#include "velox/vector/LazyVector.h"
Expand Down Expand Up @@ -146,7 +147,20 @@ class SimpleNumericAggregate : public exec::Aggregate {
if (!decoded.isNullAt(0)) {
updateDuplicateValues(
initialValue, decoded.valueAt<TInput>(0), rows.countSelected());
updateNonNullValue<true, TData>(group, initialValue, updateSingleValue);
// Some DECIMAL type aggregations requires conversion from LongDecimal
// to ShortDecimal. However, this conversion is not desired as it
// requires a safety check. Specialize this case instead.
if constexpr (
std::is_same<TInput, ShortDecimal>::value &&
std::is_same<TData, LongDecimal>::value) {
updateNonNullValue<true, TData>(
group,
static_cast<LongDecimal>(initialValue).unscaledValue(),
updateSingleValue);
} else {
updateNonNullValue<true, TData>(
group, initialValue, updateSingleValue);
}
}
} else if (decoded.mayHaveNulls()) {
rows.applyToSelected([&](vector_size_t i) {
Expand Down
13 changes: 13 additions & 0 deletions velox/functions/prestosql/aggregates/SumAggregate.h
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,13 @@ bool registerSumAggregate(const std::string& name) {
.intermediateType("double")
.argumentType("double")
.build(),
exec::AggregateFunctionSignatureBuilder()
.argumentType("DECIMAL(a_precision, a_scale)")
.intermediateType("DECIMAL(38, i_scale)")
.variableConstraint("i_scale", "a_scale")
.returnType("DECIMAL(38, r_scale)")
.variableConstraint("r_scale", "a_scale")
.build(),
};

for (const auto& inputType : {"tinyint", "smallint", "integer", "bigint"}) {
Expand Down Expand Up @@ -190,6 +197,12 @@ bool registerSumAggregate(const std::string& name) {
return std::make_unique<T<double, double, float>>(resultType);
}
return std::make_unique<T<double, double, double>>(DOUBLE());
case TypeKind::SHORT_DECIMAL:
return std::make_unique<T<ShortDecimal, LongDecimal, LongDecimal>>(
resultType);
case TypeKind::LONG_DECIMAL:
return std::make_unique<T<LongDecimal, LongDecimal, LongDecimal>>(
resultType);
default:
VELOX_CHECK(
false,
Expand Down
29 changes: 28 additions & 1 deletion velox/functions/prestosql/aggregates/tests/SumTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,34 @@ TEST_F(SumTest, sumDouble) {
assertQuery(agg, "SELECT sum(c0), sum(c1) FROM tmp");
}

TEST_F(SumTest, sumDecimal) {
auto input = makeRowVector(
{makeNullableShortDecimalFlatVector(
{1000, 2000, 3000, 4000, 5000, std::nullopt}, DECIMAL(10, 1)),
makeNullableLongDecimalFlatVector(
{buildInt128(10, 100),
buildInt128(10, 200),
buildInt128(10, 300),
buildInt128(10, 400),
buildInt128(10, 500),
std::nullopt},
DECIMAL(23, 4))});

auto expected = makeRowVector(
{makeLongDecimalFlatVector({15000}, DECIMAL(38, 1)),
makeLongDecimalFlatVector({buildInt128(50, 1500)}, DECIMAL(38, 4))});

// Global final aggregation.
auto agg = PlanBuilder()
.values({input})
.partialAggregation({}, {"sum(c0)", "sum(c1)"})
.intermediateAggregation()
.finalAggregation()
.planNode();
ASSERT_TRUE(expected->type()->equivalent(*agg->outputType()));
exec::test::assertQuery(agg, {expected});
}

TEST_F(SumTest, sumWithMask) {
auto rowType =
ROW({"c0", "c1", "c2", "c3", "c4"},
Expand Down Expand Up @@ -200,7 +228,6 @@ TEST_F(SumTest, sumWithMask) {
// Test aggregation over boolean key
TEST_F(SumTest, boolKey) {
vector_size_t size = 1'000;
auto rowType = ROW({"c0", "c1"}, {BOOLEAN(), INTEGER()});
auto vector = makeRowVector(
{makeFlatVector<bool>(size, [](auto row) { return row % 3 == 0; }),
makeFlatVector<int32_t>(size, [](auto row) { return row; })});
Expand Down
20 changes: 19 additions & 1 deletion velox/type/LongDecimal.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include <string>
#include "velox/common/base/BitUtil.h"
#include "velox/common/base/Exceptions.h"
#include "velox/type/ShortDecimal.h"
#include "velox/type/StringView.h"

#pragma once
Expand All @@ -35,7 +36,9 @@ struct LongDecimal {
public:
// Default required for creating vector with NULL values.
LongDecimal() = default;
constexpr explicit LongDecimal(int128_t value) : unscaledValue_(value) {}
constexpr LongDecimal(int128_t value) : unscaledValue_(value) {}
constexpr LongDecimal(ShortDecimal value)
: unscaledValue_(value.unscaledValue()) {}

int128_t unscaledValue() const {
return unscaledValue_;
Expand All @@ -57,9 +60,24 @@ struct LongDecimal {
return unscaledValue_ <= other.unscaledValue_;
}

LongDecimal& operator+=(const LongDecimal& other) {
unscaledValue_ += other.unscaledValue_;
return *this;
}

LongDecimal& operator+=(const ShortDecimal& other) {
unscaledValue_ += other.unscaledValue();
return *this;
}

private:
int128_t unscaledValue_;
}; // struct LongDecimal

static inline LongDecimal operator*(const int value, const LongDecimal& other) {
return LongDecimal(value * other.unscaledValue());
}

} // namespace facebook::velox

namespace folly {
Expand Down
14 changes: 13 additions & 1 deletion velox/type/ShortDecimal.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ struct ShortDecimal {
public:
// Default required for creating vector with NULL values.
ShortDecimal() = default;
constexpr explicit ShortDecimal(int64_t value) : unscaledValue_(value) {}
constexpr ShortDecimal(int64_t value) : unscaledValue_(value) {}

int64_t unscaledValue() const {
return unscaledValue_;
Expand Down Expand Up @@ -57,9 +57,21 @@ struct ShortDecimal {
return unscaledValue_ >= other.unscaledValue_;
}

ShortDecimal& operator+=(const ShortDecimal& other) {
unscaledValue_ += other.unscaledValue_;
return *this;
}

private:
int64_t unscaledValue_;
};

static inline ShortDecimal operator*(
const int value,
const ShortDecimal& other) {
return ShortDecimal(value * other.unscaledValue());
}

} // namespace facebook::velox

namespace folly {
Expand Down
14 changes: 14 additions & 0 deletions velox/vector/VectorTypeUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,20 @@ struct KindToFlatVector<TypeKind::DATE> {
using HashRowType = Date;
};

template <>
struct KindToFlatVector<TypeKind::SHORT_DECIMAL> {
using type = FlatVector<ShortDecimal>;
using WrapperType = ShortDecimal;
using HashRowType = ShortDecimal;
};

template <>
struct KindToFlatVector<TypeKind::LONG_DECIMAL> {
using type = FlatVector<LongDecimal>;
using WrapperType = LongDecimal;
using HashRowType = LongDecimal;
};

template <>
struct KindToFlatVector<TypeKind::INTERVAL_DAY_TIME> {
using type = FlatVector<IntervalDayTime>;
Expand Down

0 comments on commit 15884ca

Please sign in to comment.