Skip to content

Commit

Permalink
Add truncate(x,n) Presto function (facebookincubator#2892)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: facebookincubator#2892

Adding truncate(x,n) Presto function

Differential Revision: D40516665

fbshipit-source-id: e076863863fb8b54d6960eb479d7f1447c708486
  • Loading branch information
Gosh Arzumanyan authored and facebook-github-bot committed Nov 3, 2022
1 parent 4711235 commit 831af7b
Show file tree
Hide file tree
Showing 5 changed files with 81 additions and 2 deletions.
4 changes: 4 additions & 0 deletions velox/docs/functions/math.rst
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,10 @@ Mathematical Functions

Returns x rounded to integer by dropping digits after decimal point.

.. function:: truncate(x, n) -> double

Returns x truncated to n decimal places. n can be negative to truncate n digits left of the decimal point.

.. function:: width_bucket(x, bound1, bound2, n) -> bigint

Returns the bin number of ``x`` in an equi-width histogram with the
Expand Down
5 changes: 5 additions & 0 deletions velox/functions/prestosql/Arithmetic.h
Original file line number Diff line number Diff line change
Expand Up @@ -490,6 +490,11 @@ struct TruncateFunction {
FOLLY_ALWAYS_INLINE void call(TInput& result, TInput a) {
result = std::trunc(a);
}

template <typename TInput>
FOLLY_ALWAYS_INLINE void call(TInput& result, TInput a, int32_t n) {
result = truncate(a, n);
}
};

} // namespace
Expand Down
30 changes: 30 additions & 0 deletions velox/functions/prestosql/ArithmeticImpl.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
*/
#pragma once

#include <gmp.h>
#include <mpfr.h>
#include <algorithm>
#include <cmath>
#include <type_traits>
Expand Down Expand Up @@ -132,4 +134,32 @@ T ceil(const T& arg) {
return results;
}

template <typename TNum, typename TDecimals>
FOLLY_ALWAYS_INLINE TNum
truncate(const TNum& number, const TDecimals& decimals = 0) {
if (std::isnan(number) || std::isinf(number) || decimals == 0) {
return std::trunc(number);
}

mpfr_t factor, result, dec;

mpfr_init2(factor, 200);
mpfr_set_d(factor, 10, MPFR_RNDD);
mpfr_init2(result, 200);
mpfr_set_d(result, number, MPFR_RNDD);
mpfr_init2(dec, 200);
mpfr_set_d(dec, decimals, MPFR_RNDD);

mpfr_pow(factor, factor, dec, MPFR_RNDD);
mpfr_mul(result, result, factor, MPFR_RNDD);
mpfr_rint_trunc(result, result, MPFR_RNDD);
mpfr_div(result, result, factor, MPFR_RNDD);

const auto resultToReturn = mpfr_get_d(result, MPFR_RNDD);
mpfr_clear(factor);
mpfr_clear(result);
mpfr_clear(dec);
return resultToReturn;
}

} // namespace facebook::velox::functions
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ void registerSimpleFunctions() {
registerFunction<PiFunction, double>({"pi"});
registerFunction<EulerConstantFunction, double>({"e"});
registerUnaryNumeric<TruncateFunction>({"truncate"});
registerFunction<TruncateFunction, double, double, int32_t>({"truncate"});
}

} // namespace
Expand Down
43 changes: 41 additions & 2 deletions velox/functions/prestosql/tests/ArithmeticTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <cmath>
#include <optional>

#include <gmock/gmock.h>
Expand All @@ -33,6 +34,10 @@ MATCHER(IsNan, "is NaN") {
return arg && std::isnan(*arg);
}

MATCHER(IsInf, "is Infinity") {
return arg && std::isinf(*arg);
}

class ArithmeticTest : public functions::test::FunctionBaseTest {
protected:
template <typename T, typename TExpected = T>
Expand Down Expand Up @@ -631,15 +636,49 @@ TEST_F(ArithmeticTest, clamp) {
}

TEST_F(ArithmeticTest, truncate) {
const auto truncate = [&](std::optional<double> a) {
return evaluateOnce<double>("truncate(c0)", a);
const auto truncate = [&](std::optional<double> a,
std::optional<int32_t> n = 0) {
return evaluateOnce<double>("truncate(c0,c1)", a, n);
};

EXPECT_EQ(truncate(0), 0);
EXPECT_EQ(truncate(1.5), 1);
EXPECT_EQ(truncate(-1.5), -1);
EXPECT_EQ(truncate(std::nullopt), std::nullopt);
EXPECT_THAT(truncate(kNan), IsNan());
EXPECT_THAT(truncate(kInf), IsInf());

EXPECT_EQ(truncate(0, 0), 0);
EXPECT_EQ(truncate(1.5, 0), 1);
EXPECT_EQ(truncate(-1.5, 0), -1);
EXPECT_EQ(truncate(std::nullopt, 0), std::nullopt);
EXPECT_EQ(truncate(1.5, std::nullopt), std::nullopt);
EXPECT_THAT(truncate(kNan, 0), IsNan());
EXPECT_THAT(truncate(kNan, 1), IsNan());
EXPECT_THAT(truncate(kInf, 0), IsInf());
EXPECT_THAT(truncate(kInf, 1), IsInf());

EXPECT_DOUBLE_EQ(truncate(1.5678, 2).value(), 1.56);
EXPECT_DOUBLE_EQ(truncate(-1.5678, 2).value(), -1.56);
EXPECT_DOUBLE_EQ(truncate(1.333, -1).value(), 0);
EXPECT_DOUBLE_EQ(truncate(3.54555, 2).value(), 3.54);
EXPECT_DOUBLE_EQ(truncate(1234, 1).value(), 1234);
EXPECT_DOUBLE_EQ(truncate(1234, -1).value(), 1230);
EXPECT_DOUBLE_EQ(truncate(1234.56, 1).value(), 1234.5);
EXPECT_DOUBLE_EQ(truncate(1234.56, -1).value(), 1230.0);
EXPECT_DOUBLE_EQ(truncate(1239.999, 2).value(), 1239.99);
EXPECT_DOUBLE_EQ(truncate(1239.999, -2).value(), 1200.0);
EXPECT_DOUBLE_EQ(
truncate(123456789012345678901.23, 3).value(), 123456789012345678901.23);
EXPECT_DOUBLE_EQ(
truncate(-123456789012345678901.23, 3).value(),
-123456789012345678901.23);
EXPECT_DOUBLE_EQ(
truncate(123456789123456.999, 2).value(), 123456789123456.99);
EXPECT_DOUBLE_EQ(truncate(123456789012345678901.0, -21).value(), 0.0);
EXPECT_DOUBLE_EQ(truncate(123456789012345678901.23, -21).value(), 0.0);
EXPECT_DOUBLE_EQ(truncate(123456789012345678901.0, -21).value(), 0.0);
EXPECT_DOUBLE_EQ(truncate(123456789012345678901.23, -21).value(), 0.0);
}

} // namespace
Expand Down

0 comments on commit 831af7b

Please sign in to comment.