Skip to content

Commit

Permalink
Add support for Decimal Type Sum Aggregation (#2163)
Browse files Browse the repository at this point in the history
Summary:
Improve Decimal Type Function signature to re-use variables.

Pull Request resolved: #2163

Reviewed By: kgpai

Differential Revision: D38837924

Pulled By: xiaoxmeng

fbshipit-source-id: 59f3f5570050472102b89189496feb8c06c2040f
  • Loading branch information
majetideepak authored and facebook-github-bot committed Aug 29, 2022
1 parent 015a43c commit bc0fc81
Show file tree
Hide file tree
Showing 21 changed files with 460 additions and 63 deletions.
66 changes: 66 additions & 0 deletions velox/common/base/CheckedArithmetic.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
#include <string>
#include "folly/Likely.h"
#include "velox/common/base/Exceptions.h"
#include "velox/type/UnscaledLongDecimal.h"
#include "velox/type/UnscaledShortDecimal.h"

namespace facebook::velox {

Expand All @@ -33,6 +35,38 @@ T checkedPlus(const T& a, const T& b) {
return result;
}

/// Specialization of checkedPlus for short decimals.
/// Adds the unscaled 64-bit values of 'a' and 'b' and wraps the sum in a new
/// UnscaledShortDecimal. Raises an error through VELOX_ARITHMETIC_ERROR when
/// the sum does not fit in int64_t.
template <>
inline UnscaledShortDecimal checkedPlus(
    const UnscaledShortDecimal& a,
    const UnscaledShortDecimal& b) {
  const auto lhs = a.unscaledValue();
  const auto rhs = b.unscaledValue();
  int64_t sum;
  // The builtin reports true when the exact sum is not representable in 'sum'.
  if (UNLIKELY(__builtin_add_overflow(lhs, rhs, &sum))) {
    VELOX_ARITHMETIC_ERROR("short decimal plus overflow: {} + {}", lhs, rhs);
  }
  return UnscaledShortDecimal(sum);
}

/// Specialization of checkedPlus for long decimals.
/// Adds the unscaled values of 'a' and 'b' and wraps the sum in a new
/// UnscaledLongDecimal. Raises an error through VELOX_ARITHMETIC_ERROR when
/// the sum does not fit in int128_t. Only overflow of the 128-bit
/// representation is detected here.
template <>
inline UnscaledLongDecimal checkedPlus(
    const UnscaledLongDecimal& a,
    const UnscaledLongDecimal& b) {
  const auto lhs = a.unscaledValue();
  const auto rhs = b.unscaledValue();
  int128_t sum;
  // The builtin reports true when the exact sum is not representable in 'sum'.
  if (UNLIKELY(__builtin_add_overflow(lhs, rhs, &sum))) {
    VELOX_ARITHMETIC_ERROR("long decimal plus overflow: {} + {}", lhs, rhs);
  }
  return UnscaledLongDecimal(sum);
}

template <typename T>
T checkedMinus(const T& a, const T& b) {
T result;
Expand All @@ -53,6 +87,38 @@ T checkedMultiply(const T& a, const T& b) {
return result;
}

/// Specialization of checkedMultiply for short decimals.
/// Multiplies the unscaled 64-bit values of 'a' and 'b' and wraps the product
/// in a new UnscaledShortDecimal. Raises an error through
/// VELOX_ARITHMETIC_ERROR when the product does not fit in int64_t.
template <>
inline UnscaledShortDecimal checkedMultiply(
    const UnscaledShortDecimal& a,
    const UnscaledShortDecimal& b) {
  const auto lhs = a.unscaledValue();
  const auto rhs = b.unscaledValue();
  int64_t product;
  // The builtin reports true when the exact product is not representable in
  // 'product'.
  if (UNLIKELY(__builtin_mul_overflow(lhs, rhs, &product))) {
    VELOX_ARITHMETIC_ERROR(
        "short decimal multiply overflow: {} * {}", lhs, rhs);
  }
  return UnscaledShortDecimal(product);
}

/// Specialization of checkedMultiply for long decimals.
/// Multiplies the unscaled values of 'a' and 'b' and wraps the product in a
/// new UnscaledLongDecimal. Raises an error through VELOX_ARITHMETIC_ERROR
/// when the product does not fit in int128_t. Only overflow of the 128-bit
/// representation is detected here.
template <>
inline UnscaledLongDecimal checkedMultiply(
    const UnscaledLongDecimal& a,
    const UnscaledLongDecimal& b) {
  const auto lhs = a.unscaledValue();
  const auto rhs = b.unscaledValue();
  int128_t product;
  // The builtin reports true when the exact product is not representable in
  // 'product'.
  if (UNLIKELY(__builtin_mul_overflow(lhs, rhs, &product))) {
    VELOX_ARITHMETIC_ERROR(
        "long decimal multiply overflow: {} * {}", lhs, rhs);
  }
  return UnscaledLongDecimal(product);
}

template <typename T>
T checkedDivide(const T& a, const T& b) {
if (b == 0) {
Expand Down
6 changes: 6 additions & 0 deletions velox/exec/Aggregate.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,12 @@ class Aggregate {
// width part of the state from the fixed part.
virtual int32_t accumulatorFixedWidthSize() const = 0;

/// Alignment required for this aggregate's accumulator.
/// Aggregates whose accumulator holds a type that needs aligned access
/// (int128_t, for example) can override this; the default is 1.
virtual int32_t accumulatorAlignmentSize() const {
  constexpr int32_t kDefaultAlignment = 1;
  return kDefaultAlignment;
}

// Return true if accumulator is allocated from external memory, e.g. memory
// not managed by Velox.
virtual bool accumulatorUsesExternalMemory() const {
Expand Down
12 changes: 6 additions & 6 deletions velox/exec/AggregationHook.h
Original file line number Diff line number Diff line change
Expand Up @@ -97,13 +97,13 @@ class AggregationHook : public ValueHook {
};

namespace {
template <typename TOutput, typename TInput>
inline void updateSingleValue(TOutput& result, TInput value) {
template <typename TValue>
inline void updateSingleValue(TValue& result, TValue value) {
if constexpr (
std::is_same_v<TOutput, double> || std::is_same_v<TOutput, float>) {
std::is_same_v<TValue, double> || std::is_same_v<TValue, float>) {
result += value;
} else {
result = checkedPlus<TOutput>(result, value);
result = checkedPlus<TValue>(result, value);
}
}
} // namespace
Expand Down Expand Up @@ -143,7 +143,7 @@ class SumHook final : public AggregationHook {
clearNull(group);
updateSingleValue(
*reinterpret_cast<TAggregate*>(group + offset_),
*reinterpret_cast<const TValue*>(value));
TAggregate(*reinterpret_cast<const TValue*>(value)));
}
};

Expand All @@ -169,7 +169,7 @@ class SimpleCallableHook final : public AggregationHook {
clearNull(group);
updateSingleValue_(
*reinterpret_cast<TAggregate*>(group + offset_),
*reinterpret_cast<const TValue*>(value));
TAggregate(*reinterpret_cast<const TValue*>(value)));
}

private:
Expand Down
4 changes: 3 additions & 1 deletion velox/exec/GroupingSet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -283,7 +283,7 @@ void GroupingSet::initializeGlobalAggregation() {

// Row layout is:
// - null flags - one bit per aggregate,
// - uint32_t row size,
// - uint32_t row size,
// - fixed-width accumulators - one per aggregate
//
// Here we always make space for a row size since we only have one
Expand All @@ -294,6 +294,8 @@ void GroupingSet::initializeGlobalAggregation() {

for (auto& aggregate : aggregates_) {
aggregate->setAllocator(&stringAllocator_);
// Each accumulator's offset must be aligned to its alignment size.
offset = bits::roundUp(offset, aggregate->accumulatorAlignmentSize());
aggregate->setOffsets(
offset,
RowContainer::nullByte(nullOffset),
Expand Down
2 changes: 2 additions & 0 deletions velox/exec/RowContainer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,8 @@ RowContainer::RowContainer(
int32_t firstAggregate = offsets_.size();
int32_t firstAggregateOffset = offset;
for (auto& aggregate : aggregates) {
// Each accumulator's offset must be aligned to its alignment size.
offset = bits::roundUp(offset, aggregate->accumulatorAlignmentSize());
offsets_.push_back(offset);
offset += aggregate->accumulatorFixedWidthSize();
nullOffsets_.push_back(nullOffset);
Expand Down
74 changes: 74 additions & 0 deletions velox/exec/tests/utils/QueryAssertions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,9 @@ std::string makeCreateTableSql(
} else if (child->isMap()) {
sql << "MAP(" << child->asMap().keyType()->kindName() << ", "
<< child->asMap().valueType()->kindName() << ")";
} else if (child->isShortDecimal() || child->isLongDecimal()) {
const auto& [precision, scale] = getDecimalPrecisionScale(*child);
sql << "DECIMAL(" << precision << ", " << scale << ")";
} else {
sql << child->kindName();
}
Expand Down Expand Up @@ -85,6 +88,31 @@ ::duckdb::Value duckValueAt<TypeKind::DATE>(
vector->as<SimpleVector<T>>()->valueAt(index).days()));
}

/// Builds a DuckDB DECIMAL value from the short decimal stored at 'index' of
/// 'vector'. Precision and scale are taken from the vector's type; the
/// unscaled value is passed through unchanged.
template <>
::duckdb::Value duckValueAt<TypeKind::SHORT_DECIMAL>(
    const VectorPtr& vector,
    vector_size_t index) {
  using T = typename KindToFlatVector<TypeKind::SHORT_DECIMAL>::WrapperType;
  const auto& decimalType = vector->type()->asShortDecimal();
  const auto unscaled =
      vector->as<SimpleVector<T>>()->valueAt(index).unscaledValue();
  return ::duckdb::Value::DECIMAL(
      unscaled, decimalType.precision(), decimalType.scale());
}

/// Builds a DuckDB DECIMAL value from the long decimal stored at 'index' of
/// 'vector'. The 128-bit unscaled value is split into the 'lower' and 'upper'
/// halves of a ::duckdb::hugeint_t; precision and scale are taken from the
/// vector's type.
template <>
::duckdb::Value duckValueAt<TypeKind::LONG_DECIMAL>(
    const VectorPtr& vector,
    vector_size_t index) {
  using T = typename KindToFlatVector<TypeKind::LONG_DECIMAL>::WrapperType;
  auto type = vector->type()->asLongDecimal();
  const auto val =
      vector->as<SimpleVector<T>>()->valueAt(index).unscaledValue();
  auto duckVal = ::duckdb::hugeint_t();
  // Take the low 64 bits with a truncating conversion. The previous
  // '(val << 64) >> 64' left-shifted a signed value, which is undefined
  // behaviour before C++20 for negative or large inputs.
  duckVal.lower = static_cast<uint64_t>(val);
  // The high 64 bits carry the sign. Right-shifting a negative signed value
  // is an arithmetic shift on GCC/Clang, the same as the original code
  // relied on.
  duckVal.upper = static_cast<int64_t>(val >> 64);
  return ::duckdb::Value::DECIMAL(duckVal, type.precision(), type.scale());
}

template <>
::duckdb::Value duckValueAt<TypeKind::ARRAY>(
const VectorPtr& vector,
Expand Down Expand Up @@ -196,6 +224,19 @@ velox::variant variantAt(const ::duckdb::Value& value) {
return velox::variant(value.GetValue<T>());
}

/// Converts a DuckDB DECIMAL 'value' into a Velox decimal variant.
/// The Velox type is built from the precision and scale reported by the
/// DuckDB value's type; a short or long decimal variant is produced depending
/// on which kind DECIMAL(precision, scale) resolves to.
velox::variant decimalVariantAt(const ::duckdb::Value& value) {
  uint8_t precision;
  uint8_t scale;
  value.type().GetDecimalProperties(precision, scale);
  auto type = DECIMAL(precision, scale);

  if (!type->isShortDecimal()) {
    // Long decimal: reassemble the 128-bit value from the hugeint halves.
    const auto huge = value.GetValueUnsafe<::duckdb::hugeint_t>();
    return velox::variant::longDecimal(
        buildInt128(huge.upper, huge.lower), type);
  }
  return velox::variant::shortDecimal(value.GetValue<int64_t>(), type);
}

velox::variant rowVariantAt(
const ::duckdb::Value& vector,
const TypePtr& rowType) {
Expand Down Expand Up @@ -295,6 +336,8 @@ std::vector<MaterializedRow> materialize(
} else if (typeKind == TypeKind::ROW) {
row.push_back(
rowVariantAt(dataChunk->GetValue(j, i), rowType->childAt(j)));
} else if (isDecimalKind(typeKind)) {
row.push_back(decimalVariantAt(dataChunk->GetValue(j, i)));
} else {
auto value = VELOX_DYNAMIC_SCALAR_TYPE_DISPATCH(
variantAt, typeKind, dataChunk, i, j);
Expand Down Expand Up @@ -322,6 +365,26 @@ velox::variant variantAt<TypeKind::TIMESTAMP>(VectorPtr vector, int32_t row) {
veloxTimestampToDuckDB(vector->as<SimpleVector<T>>()->valueAt(row))));
}

/// Reads the short decimal at 'row' of 'vector' and returns it as a short
/// decimal variant that carries the vector's type.
template <>
velox::variant variantAt<TypeKind::SHORT_DECIMAL>(
    VectorPtr vector,
    int32_t row) {
  using T = typename KindToFlatVector<TypeKind::SHORT_DECIMAL>::WrapperType;
  const auto unscaled =
      vector->as<SimpleVector<T>>()->valueAt(row).unscaledValue();
  return velox::variant::shortDecimal(unscaled, vector->type());
}

/// Reads the long decimal at 'row' of 'vector' and returns it as a long
/// decimal variant that carries the vector's type.
template <>
velox::variant variantAt<TypeKind::LONG_DECIMAL>(
    VectorPtr vector,
    int32_t row) {
  using T = typename KindToFlatVector<TypeKind::LONG_DECIMAL>::WrapperType;
  const auto unscaled =
      vector->as<SimpleVector<T>>()->valueAt(row).unscaledValue();
  return velox::variant::longDecimal(unscaled, vector->type());
}

velox::variant arrayVariantAt(const VectorPtr& vector, vector_size_t row) {
auto arrayVector = vector->as<ArrayVector>();
auto& elements = arrayVector->elements();
Expand Down Expand Up @@ -410,6 +473,11 @@ std::vector<MaterializedRow> materialize(const RowVectorPtr& vector) {
row.push_back(arrayVariantAt(vector->childAt(j), i));
} else if (typeKind == TypeKind::MAP) {
row.push_back(mapVariantAt(vector->childAt(j), i));
} else if (typeKind == TypeKind::SHORT_DECIMAL) {
row.push_back(
variantAt<TypeKind::SHORT_DECIMAL>(vector->childAt(j), i));
} else if (typeKind == TypeKind::LONG_DECIMAL) {
row.push_back(variantAt<TypeKind::LONG_DECIMAL>(vector->childAt(j), i));
} else {
auto value = VELOX_DYNAMIC_SCALAR_TYPE_DISPATCH(
variantAt, typeKind, vector->childAt(j), i);
Expand Down Expand Up @@ -525,6 +593,12 @@ void DuckDbQueryRunner::createTable(
appender.Append(duckValueAt<TypeKind::ARRAY>(columnVector, row));
} else if (rowType.childAt(column)->isMap()) {
appender.Append(duckValueAt<TypeKind::MAP>(columnVector, row));
} else if (rowType.childAt(column)->isShortDecimal()) {
appender.Append(
duckValueAt<TypeKind::SHORT_DECIMAL>(columnVector, row));
} else if (rowType.childAt(column)->isLongDecimal()) {
appender.Append(
duckValueAt<TypeKind::LONG_DECIMAL>(columnVector, row));
} else {
auto value = VELOX_DYNAMIC_SCALAR_TYPE_DISPATCH(
duckValueAt, rowType.childAt(column)->kind(), columnVector, row);
Expand Down
18 changes: 12 additions & 6 deletions velox/expression/SignatureBinder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,17 +45,18 @@ TypePtr inferDecimalType(
const std::unordered_map<std::string, std::string>& constraints) {
const auto& precisionVar = typeSignature.variables()[0];
const auto& scaleVar = typeSignature.variables()[1];
// check for constraints, else set defaults.
const auto& precisionConstraint = constraints.find(precisionVar);
const auto& scaleConstraint = constraints.find(scaleVar);

int precision = 0;
int scale = 0;
// Determine precision.
// Handle constant.
if (isPositiveInteger(precisionVar)) {
// Handle constant.
precision = atoi(precisionVar.c_str());
} else if (variables.find(precisionVar) != variables.end()) {
// Check if variable is already computed.
precision = variables[precisionVar];
} else {
// Check constraints and evaluate.
const auto& precisionConstraint = constraints.find(precisionVar);
VELOX_CHECK(
precisionConstraint != constraints.end(),
"Missing constraint for variable {}",
Expand All @@ -66,10 +67,15 @@ TypePtr inferDecimalType(
precision = variables[precisionVar];
}
// Determine scale.
// Handle constant.
if (isPositiveInteger(scaleVar)) {
// Handle constant.
scale = atoi(scaleVar.c_str());
} else if (variables.find(scaleVar) != variables.end()) {
// Check if variable is already computed.
scale = variables[scaleVar];
} else {
// Check constraints and evaluate.
const auto& scaleConstraint = constraints.find(scaleVar);
VELOX_CHECK(
scaleConstraint != constraints.end(),
"Missing constraint for variable {}",
Expand Down
1 change: 0 additions & 1 deletion velox/expression/SignatureBinder.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@ class SignatureBinder {
if (!variable.constraint().empty()) {
constraints_.insert({variable.name(), variable.constraint()});
}
variables_.insert({variable.name(), -1});
}
}

Expand Down
6 changes: 2 additions & 4 deletions velox/expression/tests/SignatureBinderTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -115,10 +115,8 @@ TEST(SignatureBinderTest, decimals) {
{
auto signature = exec::AggregateFunctionSignatureBuilder()
.argumentType("DECIMAL(a_precision, a_scale)")
.intermediateType("DECIMAL(38, i_scale)")
.variableConstraint("i_scale", "a_scale")
.returnType("DECIMAL(38, r_scale)")
.variableConstraint("r_scale", "a_scale")
.intermediateType("DECIMAL(38, a_scale)")
.returnType("DECIMAL(38, a_scale)")
.build();

const std::vector<TypePtr> actualTypes{DECIMAL(10, 4)};
Expand Down
Loading

0 comments on commit bc0fc81

Please sign in to comment.