From 4841bbd886ede94cd1e6cb427c8f33fb1ba5fa57 Mon Sep 17 00:00:00 2001 From: Deepak Majeti Date: Mon, 1 Aug 2022 11:13:55 -0400 Subject: [PATCH] add LongDecimal->ShortDecimal conversion --- .../prestosql/aggregates/SimpleNumericAggregate.h | 15 +-------------- velox/type/LongDecimal.cpp | 1 - velox/type/LongDecimal.h | 9 +++++++-- velox/type/ShortDecimal.h | 3 +++ 4 files changed, 11 insertions(+), 17 deletions(-) diff --git a/velox/functions/prestosql/aggregates/SimpleNumericAggregate.h b/velox/functions/prestosql/aggregates/SimpleNumericAggregate.h index 302ad24e6808e..253731a2e09d3 100644 --- a/velox/functions/prestosql/aggregates/SimpleNumericAggregate.h +++ b/velox/functions/prestosql/aggregates/SimpleNumericAggregate.h @@ -147,20 +147,7 @@ class SimpleNumericAggregate : public exec::Aggregate { if (!decoded.isNullAt(0)) { updateDuplicateValues( initialValue, decoded.valueAt(0), rows.countSelected()); - // Some DECIMAL type aggregations requires conversion from LongDecimal - // to ShortDecimal. However, this conversion is not desired as it - // requires a safety check. Specialize this case instead. - if constexpr ( - std::is_same::value && - std::is_same::value) { - updateNonNullValue( - group, - static_cast(initialValue).unscaledValue(), - updateSingleValue); - } else { - updateNonNullValue( - group, initialValue, updateSingleValue); - } + updateNonNullValue(group, initialValue, updateSingleValue); } } else if (decoded.mayHaveNulls()) { rows.applyToSelected([&](vector_size_t i) { diff --git a/velox/type/LongDecimal.cpp b/velox/type/LongDecimal.cpp index 020aa7cb1e364..74ab08990f66a 100644 --- a/velox/type/LongDecimal.cpp +++ b/velox/type/LongDecimal.cpp @@ -34,5 +34,4 @@ string to_string(facebook::velox::int128_t x) { reverse(ans.begin(), ans.end()); return ans; } - } // namespace std diff --git a/velox/type/LongDecimal.h b/velox/type/LongDecimal.h index ff643ec213f46..ff97aa96f61f5 100644 --- a/velox/type/LongDecimal.h +++ b/velox/type/LongDecimal.h @@ -37,8 +37,7 @@ struct LongDecimal { // Default required for creating vector with NULL values. LongDecimal() = default; constexpr LongDecimal(int128_t value) : unscaledValue_(value) {} - constexpr LongDecimal(ShortDecimal value) - : unscaledValue_(value.unscaledValue()) {} + LongDecimal(ShortDecimal& value) : unscaledValue_(value.unscaledValue()) {} int128_t unscaledValue() const { return unscaledValue_; @@ -78,6 +77,12 @@ static inline LongDecimal operator*(const int value, const LongDecimal& other) { return LongDecimal(value * other.unscaledValue()); } +inline ShortDecimal::ShortDecimal(LongDecimal& other) { + // Ensure the value fits. + VELOX_DCHECK((other.unscaledValue() >> 64) == 0); + unscaledValue_ = other.unscaledValue(); +} + } // namespace facebook::velox namespace folly { diff --git a/velox/type/ShortDecimal.h b/velox/type/ShortDecimal.h index 79b613456bb5a..6e47caf4543d0 100644 --- a/velox/type/ShortDecimal.h +++ b/velox/type/ShortDecimal.h @@ -23,11 +23,14 @@ namespace facebook::velox { +struct LongDecimal; + struct ShortDecimal { public: // Default required for creating vector with NULL values. ShortDecimal() = default; constexpr ShortDecimal(int64_t value) : unscaledValue_(value) {} + ShortDecimal(LongDecimal& value); int64_t unscaledValue() const { return unscaledValue_;