Skip to content

Commit

Permalink
Add support for Decimal Type Sum Aggreggation
Browse files Browse the repository at this point in the history
  • Loading branch information
majetideepak committed Aug 22, 2022
1 parent f831b91 commit 2d914d0
Show file tree
Hide file tree
Showing 14 changed files with 390 additions and 26 deletions.
58 changes: 58 additions & 0 deletions velox/common/base/CheckedArithmetic.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
#include <string>
#include "folly/Likely.h"
#include "velox/common/base/Exceptions.h"
#include "velox/type/UnscaledLongDecimal.h"
#include "velox/type/UnscaledShortDecimal.h"

namespace facebook::velox {

Expand All @@ -33,6 +35,34 @@ T checkedPlus(const T& a, const T& b) {
return result;
}

template <>
inline UnscaledShortDecimal checkedPlus(
const UnscaledShortDecimal& a,
const UnscaledShortDecimal& b) {
int64_t result;
bool overflow =
__builtin_add_overflow(a.unscaledValue(), b.unscaledValue(), &result);
if (UNLIKELY(overflow)) {
VELOX_ARITHMETIC_ERROR(
"integer overflow: {} + {}", a.unscaledValue(), b.unscaledValue());
}
return UnscaledShortDecimal(result);
}

template <>
inline UnscaledLongDecimal checkedPlus(
const UnscaledLongDecimal& a,
const UnscaledLongDecimal& b) {
int128_t result;
bool overflow =
__builtin_add_overflow(a.unscaledValue(), b.unscaledValue(), &result);
if (UNLIKELY(overflow)) {
VELOX_ARITHMETIC_ERROR(
"integer overflow: {} + {}", a.unscaledValue(), b.unscaledValue());
}
return UnscaledLongDecimal(result);
}

template <typename T>
T checkedMinus(const T& a, const T& b) {
T result;
Expand All @@ -53,6 +83,34 @@ T checkedMultiply(const T& a, const T& b) {
return result;
}

template <>
inline UnscaledShortDecimal checkedMultiply(
const UnscaledShortDecimal& a,
const UnscaledShortDecimal& b) {
int64_t result;
bool overflow =
__builtin_mul_overflow(a.unscaledValue(), b.unscaledValue(), &result);
if (UNLIKELY(overflow)) {
VELOX_ARITHMETIC_ERROR(
"integer overflow: {} * {}", a.unscaledValue(), b.unscaledValue());
}
return UnscaledShortDecimal(result);
}

template <>
inline UnscaledLongDecimal checkedMultiply(
const UnscaledLongDecimal& a,
const UnscaledLongDecimal& b) {
int128_t result;
bool overflow =
__builtin_mul_overflow(a.unscaledValue(), b.unscaledValue(), &result);
if (UNLIKELY(overflow)) {
VELOX_ARITHMETIC_ERROR(
"integer overflow: {} * {}", a.unscaledValue(), b.unscaledValue());
}
return UnscaledLongDecimal(result);
}

template <typename T>
T checkedDivide(const T& a, const T& b) {
if (b == 0) {
Expand Down
6 changes: 6 additions & 0 deletions velox/exec/Aggregate.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,12 @@ class Aggregate {
// width part of the state from the fixed part.
virtual int32_t accumulatorFixedWidthSize() const = 0;

/// Some types such as int128_t require aligned access.
/// Returns the alignment size of the accumulator.
virtual int32_t accumulatorAlignmentSize() const {
return 0;
};

// Return true if accumulator is allocated from external memory, e.g. memory
// not managed by Velox.
virtual bool accumulatorUsesExternalMemory() const {
Expand Down
4 changes: 3 additions & 1 deletion velox/exec/GroupingSet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -278,7 +278,7 @@ void GroupingSet::initializeGlobalAggregation() {

// Row layout is:
// - null flags - one bit per aggregate,
// - uint32_t row size,
// - uint32_t row size,
// - fixed-width accumulators - one per aggregate
//
// Here we always make space for a row size since we only have one
Expand All @@ -289,6 +289,8 @@ void GroupingSet::initializeGlobalAggregation() {

for (auto& aggregate : aggregates_) {
aggregate->setAllocator(&stringAllocator_);
// Accumulator offset must be aligned by their alignment size.
offset = bits::roundUp(offset, aggregate->accumulatorAlignmentSize());
aggregate->setOffsets(
offset,
RowContainer::nullByte(nullOffset),
Expand Down
2 changes: 2 additions & 0 deletions velox/exec/RowContainer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,8 @@ RowContainer::RowContainer(
int32_t firstAggregate = offsets_.size();
int32_t firstAggregateOffset = offset;
for (auto& aggregate : aggregates) {
// Accumulator offset must be aligned by their alignment size.
offset = bits::roundUp(offset, aggregate->accumulatorAlignmentSize());
offsets_.push_back(offset);
offset += aggregate->accumulatorFixedWidthSize();
nullOffsets_.push_back(nullOffset);
Expand Down
75 changes: 75 additions & 0 deletions velox/exec/tests/utils/QueryAssertions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,9 @@ std::string makeCreateTableSql(
} else if (child->isMap()) {
sql << "MAP(" << child->asMap().keyType()->kindName() << ", "
<< child->asMap().valueType()->kindName() << ")";
} else if (child->isShortDecimal() || child->isLongDecimal()) {
const auto& [precision, scale] = getDecimalPrecisionScale(*child);
sql << "decimal(" << precision << ", " << scale << ")";
} else {
sql << child->kindName();
}
Expand Down Expand Up @@ -85,6 +88,31 @@ ::duckdb::Value duckValueAt<TypeKind::DATE>(
vector->as<SimpleVector<T>>()->valueAt(index).days()));
}

template <>
::duckdb::Value duckValueAt<TypeKind::SHORT_DECIMAL>(
const VectorPtr& vector,
vector_size_t index) {
using T = typename KindToFlatVector<TypeKind::SHORT_DECIMAL>::WrapperType;
auto type = vector->type()->asShortDecimal();
return ::duckdb::Value::DECIMAL(
vector->as<SimpleVector<T>>()->valueAt(index).unscaledValue(),
type.precision(),
type.scale());
}

template <>
::duckdb::Value duckValueAt<TypeKind::LONG_DECIMAL>(
const VectorPtr& vector,
vector_size_t index) {
using T = typename KindToFlatVector<TypeKind::LONG_DECIMAL>::WrapperType;
auto type = vector->type()->asLongDecimal();
auto val = vector->as<SimpleVector<T>>()->valueAt(index).unscaledValue();
auto duckVal = ::duckdb::hugeint_t();
duckVal.lower = (val << 64) >> 64;
duckVal.upper = (val >> 64);
return ::duckdb::Value::DECIMAL(duckVal, type.precision(), type.scale());
}

template <>
::duckdb::Value duckValueAt<TypeKind::ARRAY>(
const VectorPtr& vector,
Expand Down Expand Up @@ -196,6 +224,20 @@ velox::variant variantAt(const ::duckdb::Value& value) {
return velox::variant(value.GetValue<T>());
}

velox::variant decimalVariantAt(const ::duckdb::Value& value) {
uint8_t precision;
uint8_t scale;
value.type().GetDecimalProperties(precision, scale);
auto duckType = DECIMAL(precision, scale);
if (duckType->isShortDecimal()) {
return velox::variant::shortDecimal(value.GetValue<int64_t>(), duckType);
} else {
auto val = value.GetValueUnsafe<::duckdb::hugeint_t>();
return velox::variant::longDecimal(
buildInt128(val.upper, val.lower), duckType);
}
}

velox::variant rowVariantAt(
const ::duckdb::Value& vector,
const TypePtr& rowType) {
Expand Down Expand Up @@ -264,6 +306,8 @@ std::vector<MaterializedRow> materialize(
} else if (typeKind == TypeKind::ROW) {
row.push_back(
rowVariantAt(dataChunk->GetValue(j, i), rowType->childAt(j)));
} else if (isDecimalKind(typeKind)) {
row.push_back(decimalVariantAt(dataChunk->GetValue(j, i)));
} else {
auto value = VELOX_DYNAMIC_SCALAR_TYPE_DISPATCH(
variantAt, typeKind, dataChunk, i, j);
Expand Down Expand Up @@ -291,6 +335,26 @@ velox::variant variantAt<TypeKind::TIMESTAMP>(VectorPtr vector, int32_t row) {
veloxTimestampToDuckDB(vector->as<SimpleVector<T>>()->valueAt(row))));
}

template <>
velox::variant variantAt<TypeKind::SHORT_DECIMAL>(
VectorPtr vector,
int32_t row) {
using T = typename KindToFlatVector<TypeKind::SHORT_DECIMAL>::WrapperType;
return velox::variant::shortDecimal(
vector->as<SimpleVector<T>>()->valueAt(row).unscaledValue(),
vector->type());
}

template <>
velox::variant variantAt<TypeKind::LONG_DECIMAL>(
VectorPtr vector,
int32_t row) {
using T = typename KindToFlatVector<TypeKind::LONG_DECIMAL>::WrapperType;
return velox::variant::longDecimal(
vector->as<SimpleVector<T>>()->valueAt(row).unscaledValue(),
vector->type());
}

velox::variant arrayVariantAt(const VectorPtr& vector, vector_size_t row) {
auto arrayVector = vector->as<ArrayVector>();
auto& elements = arrayVector->elements();
Expand Down Expand Up @@ -379,6 +443,11 @@ std::vector<MaterializedRow> materialize(const RowVectorPtr& vector) {
row.push_back(arrayVariantAt(vector->childAt(j), i));
} else if (typeKind == TypeKind::MAP) {
row.push_back(mapVariantAt(vector->childAt(j), i));
} else if (typeKind == TypeKind::SHORT_DECIMAL) {
row.push_back(
variantAt<TypeKind::SHORT_DECIMAL>(vector->childAt(j), i));
} else if (typeKind == TypeKind::LONG_DECIMAL) {
row.push_back(variantAt<TypeKind::LONG_DECIMAL>(vector->childAt(j), i));
} else {
auto value = VELOX_DYNAMIC_SCALAR_TYPE_DISPATCH(
variantAt, typeKind, vector->childAt(j), i);
Expand Down Expand Up @@ -494,6 +563,12 @@ void DuckDbQueryRunner::createTable(
appender.Append(duckValueAt<TypeKind::ARRAY>(columnVector, row));
} else if (rowType.childAt(column)->isMap()) {
appender.Append(duckValueAt<TypeKind::MAP>(columnVector, row));
} else if (rowType.childAt(column)->isShortDecimal()) {
appender.Append(
duckValueAt<TypeKind::SHORT_DECIMAL>(columnVector, row));
} else if (rowType.childAt(column)->isLongDecimal()) {
appender.Append(
duckValueAt<TypeKind::LONG_DECIMAL>(columnVector, row));
} else {
auto value = VELOX_DYNAMIC_SCALAR_TYPE_DISPATCH(
duckValueAt, rowType.childAt(column)->kind(), columnVector, row);
Expand Down
68 changes: 65 additions & 3 deletions velox/functions/prestosql/aggregates/SumAggregate.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,10 @@ class SumAggregate
return sizeof(TAccumulator);
}

int32_t accumulatorAlignmentSize() const override {
return 0;
}

void initializeNewGroups(
char** groups,
folly::Range<const vector_size_t*> indices) override {
Expand Down Expand Up @@ -87,7 +91,7 @@ class SumAggregate
&updateSingleValue<TAccumulator>,
&updateDuplicateValues<TAccumulator>,
mayPushdown,
0);
TInput(0));
}

void addSingleGroupIntermediateResults(
Expand All @@ -102,7 +106,7 @@ class SumAggregate
&updateSingleValue<ResultType>,
&updateDuplicateValues<ResultType>,
mayPushdown,
0);
TInput(0));
}

protected:
Expand Down Expand Up @@ -152,7 +156,7 @@ class SumAggregate
result += n * value;
} else {
result = functions::checkedPlus<TData>(
result, functions::checkedMultiply<TData>(n, value));
result, functions::checkedMultiply<TData>(TData(n), value));
}
}
};
Expand Down Expand Up @@ -183,6 +187,48 @@ inline void SumAggregate<float, double, float>::extractValues(
});
}

/// Override 'initializeNewGroups' for decimal values to call set method to
/// initialize the decimal value properly.
template <>
inline void
SumAggregate<UnscaledShortDecimal, UnscaledLongDecimal, UnscaledLongDecimal>::
initializeNewGroups(
char** groups,
folly::Range<const vector_size_t*> indices) {
exec::Aggregate::setAllNulls(groups, indices);
for (auto i : indices) {
exec::Aggregate::value<UnscaledLongDecimal>(groups[i])->setUnscaledValue(0);
}
}

template <>
inline void
SumAggregate<UnscaledLongDecimal, UnscaledLongDecimal, UnscaledLongDecimal>::
initializeNewGroups(
char** groups,
folly::Range<const vector_size_t*> indices) {
exec::Aggregate::setAllNulls(groups, indices);
for (auto i : indices) {
exec::Aggregate::value<UnscaledLongDecimal>(groups[i])->setUnscaledValue(0);
}
}

/// Override 'accumulatorAlignmentSize' for UnscaledLongDecimal values as it
/// uses int128_t type.
template <>
inline int32_t
SumAggregate<UnscaledShortDecimal, UnscaledLongDecimal, UnscaledLongDecimal>::
accumulatorAlignmentSize() const {
return static_cast<int32_t>(sizeof(UnscaledLongDecimal));
}

template <>
inline int32_t
SumAggregate<UnscaledLongDecimal, UnscaledLongDecimal, UnscaledLongDecimal>::
accumulatorAlignmentSize() const {
return static_cast<int32_t>(sizeof(UnscaledLongDecimal));
}

template <template <typename U, typename V, typename W> class T>
bool registerSumAggregate(const std::string& name) {
std::vector<std::shared_ptr<exec::AggregateFunctionSignature>> signatures{
Expand All @@ -196,6 +242,12 @@ bool registerSumAggregate(const std::string& name) {
.intermediateType("double")
.argumentType("double")
.build(),
exec::AggregateFunctionSignatureBuilder()
.argumentType("DECIMAL(a_precision, a_scale)")
.intermediateType("DECIMAL(38, r_scale)")
.variableConstraint("r_scale", "a_scale")
.returnType("DECIMAL(38, r_scale)")
.build(),
};

for (const auto& inputType : {"tinyint", "smallint", "integer", "bigint"}) {
Expand Down Expand Up @@ -234,6 +286,16 @@ bool registerSumAggregate(const std::string& name) {
return std::make_unique<T<double, double, float>>(resultType);
}
return std::make_unique<T<double, double, double>>(DOUBLE());
case TypeKind::SHORT_DECIMAL:
return std::make_unique<
T<UnscaledShortDecimal,
UnscaledLongDecimal,
UnscaledLongDecimal>>(resultType);
case TypeKind::LONG_DECIMAL:
return std::make_unique<
T<UnscaledLongDecimal,
UnscaledLongDecimal,
UnscaledLongDecimal>>(resultType);
default:
VELOX_CHECK(
false,
Expand Down
Loading

0 comments on commit 2d914d0

Please sign in to comment.