diff --git a/velox/docs/functions/presto/math.rst b/velox/docs/functions/presto/math.rst index c394182c6615..41c9852ccb82 100644 --- a/velox/docs/functions/presto/math.rst +++ b/velox/docs/functions/presto/math.rst @@ -192,12 +192,16 @@ Mathematical Functions :noindex: Returns the zero-based bin number of ``x`` according to the bins specified - by the array ``bins``. The ``bins`` parameter must be an array of doubles and - is assumed to be in sorted ascending order. + by the array ``bins``. The ``bins`` parameter must be an array of doubles, should not + contain ``null`` or non-finite elements, and is assumed to be in sorted ascending order. For example, if ``bins`` is ``ARRAY[0, 2, 4]``, then we have four bins: ``(-infinity(), 0)``, ``[0, 2)``, ``[2, 4)`` and ``[4, infinity())``. + Note: The function returns an error if it encounters a ``null`` or non-finite + element in ``bins``, but due to the binary search algorithm some such elements + might go unnoticed and the function will return a result. 
+ ==================================== Trigonometric Functions diff --git a/velox/functions/prestosql/WidthBucketArray.cpp b/velox/functions/prestosql/WidthBucketArray.cpp index a8c3a7fbb478..5ad91e56f21d 100644 --- a/velox/functions/prestosql/WidthBucketArray.cpp +++ b/velox/functions/prestosql/WidthBucketArray.cpp @@ -34,15 +34,24 @@ int64_t widthBucket( int lower = 0; int upper = binCount; while (lower < upper) { - VELOX_USER_CHECK_LE( - elementsHolder.valueAt(offset + lower), - elementsHolder.valueAt(offset + upper - 1), - "Bin values are not sorted in ascending order"); - - int index = (lower + upper) / 2; - auto bin = elementsHolder.valueAt(offset + index); + const int index = (lower + upper) / 2; + // Check nulls at the same 'offset'-based positions that valueAt() reads. + VELOX_USER_CHECK( + !elementsHolder.isNullAt(offset + lower) && + !elementsHolder.isNullAt(offset + index) && + !elementsHolder.isNullAt(offset + upper - 1), + "Bin values cannot be NULL"); - VELOX_USER_CHECK(std::isfinite(bin), "Bin value must be finite"); + const auto bin = elementsHolder.valueAt(offset + index); + const auto lowerBin = elementsHolder.valueAt(offset + lower); + const auto upperBin = elementsHolder.valueAt(offset + upper - 1); + VELOX_USER_CHECK( + lowerBin <= bin && bin <= upperBin, + "Bin values are not sorted in ascending order"); + VELOX_USER_CHECK( + std::isfinite(bin) && std::isfinite(lowerBin) && + std::isfinite(upperBin), + "Bin values must be finite"); if (operand < bin) { upper = index; @@ -162,9 +171,9 @@ std::vector toBinValues( for (int i = 0; i < size; i++) { VELOX_USER_CHECK( - !simpleVector->isNullAt(offset + i), "Bin value cannot be null"); + !simpleVector->isNullAt(offset + i), "Bin values cannot be null"); auto value = simpleVector->valueAt(offset + i); - VELOX_USER_CHECK(std::isfinite(value), "Bin value must be finite"); + VELOX_USER_CHECK(std::isfinite(value), "Bin values must be finite"); if (i > 0) { VELOX_USER_CHECK_GT( value, diff --git a/velox/functions/prestosql/tests/WidthBucketArrayTest.cpp 
b/velox/functions/prestosql/tests/WidthBucketArrayTest.cpp index 8bc923f3bc9f..07f410c4d982 100644 --- a/velox/functions/prestosql/tests/WidthBucketArrayTest.cpp +++ b/velox/functions/prestosql/tests/WidthBucketArrayTest.cpp @@ -61,39 +61,51 @@ TEST_F(WidthBucketArrayTest, success) { assertEqualVectors(dictExpected, dictResult); }; - { - binsVector = makeArrayVector({{0.0, 2.0, 4.0}, {0.0}}); - testWidthBucketArray(3.14, {2, 1}); - testWidthBucketArray(kInf, {3, 1}); - testWidthBucketArray(-1, {0, 0}); - } - - { - binsVector = makeArrayVector({{0, 2, 4}, {0}}); - testWidthBucketArray(3.14, {2, 1}); - testWidthBucketArray(kInf, {3, 1}); - testWidthBucketArray(-1, {0, 0}); - } + binsVector = makeArrayVector({{0.0, 2.0, 4.0}, {0.0}}); + testWidthBucketArray(3.14, {2, 1}); + testWidthBucketArray(kInf, {3, 1}); + testWidthBucketArray(-1, {0, 0}); + + binsVector = makeArrayVector({{0, 2, 4}, {0}}); + testWidthBucketArray(3.14, {2, 1}); + testWidthBucketArray(kInf, {3, 1}); + testWidthBucketArray(-1, {0, 0}); + + // Cases we cannot catch due to the binary search algorithm. 
+ binsVector = makeNullableArrayVector( + {{0.0, std::nullopt, 2.0, 4.0}, + {0.0, std::nullopt, 1.0, 2.0, 4.0}, + {0.0, kInf, 1.0, 2.0, 4.0}}); + testWidthBucketArray(3.14, {3, 4, 4}); } TEST_F(WidthBucketArrayTest, failure) { - auto testFailure = [&](const double operand, - const std::vector>& bins, - const std::string& expected_message) { - auto binsVector = makeArrayVector(bins); - VELOX_ASSERT_THROW( - evaluate>( - "width_bucket(c0, c1)", - makeRowVector( - {makeConstant(operand, binsVector->size()), binsVector})), - expected_message); - }; + auto testFailure = + [&](const double operand, + const std::vector>>& bins, + const std::string& expected_message) { + auto binsVector = makeNullableArrayVector(bins); + VELOX_ASSERT_THROW( + evaluate>( + "width_bucket(c0, c1)", + makeRowVector( + {makeConstant(operand, binsVector->size()), binsVector})), + expected_message); + }; testFailure(0, {{}}, "Bins cannot be an empty array"); testFailure(kNan, {{0}}, "Operand cannot be NaN"); - testFailure(1, {{0, kInf}}, "Bin value must be finite"); + testFailure(1, {{0, kInf}}, "Bin values must be finite"); testFailure(1, {{0, kNan}}, "Bin values are not sorted in ascending order"); testFailure(2, {{1, 0}}, "Bin values are not sorted in ascending order"); + testFailure( + 3.14, {{0, kInf, 10}}, "Bin values are not sorted in ascending order"); + testFailure( + 1.5, {{1.0, 2, 3, 2, 0}}, "Bin values are not sorted in ascending order"); + testFailure(3.14, {{std::nullopt}}, "Bin values cannot be NULL"); + testFailure(3.14, {{0.0, std::nullopt, 4.0}}, "Bin values cannot be NULL"); + testFailure( + 3.14, {{0.0, 2.0, 4.0, std::nullopt}}, "Bin values cannot be NULL"); } TEST_F(WidthBucketArrayTest, successForConstantArray) { @@ -120,9 +132,16 @@ TEST_F(WidthBucketArrayTest, successForConstantArray) { testWidthBucketArray(3.14, "ARRAY[0.0]", 1); testWidthBucketArray(kInf, "ARRAY[0.0]", 1); testWidthBucketArray(-1, "ARRAY[0.0]", 0); + + // Cases we cannot catch due to the binary search 
algorithm. + // If the 'bins' vector has issues we simply fall back to the non-constant + // case. + testWidthBucketArray(3.14, "ARRAY[0.0, NULL, 2.0, 4.0]", 3); + testWidthBucketArray(3.14, "ARRAY[0.0, NULL, 1.0, 2.0, 4.0]", 4); + testWidthBucketArray(3.14, "ARRAY[0.0, Infinity(), 1.0, 2.0, 4.0]", 4); } -TEST_F(WidthBucketArrayTest, failureForConstant) { +TEST_F(WidthBucketArrayTest, failureForConstantArray) { auto testFailure = [&](const double operand, const std::string& bins, const std::string& expected_message) { @@ -133,12 +152,20 @@ TEST_F(WidthBucketArrayTest, failureForConstant) { expected_message); }; - // TODO: Add tests for empty bin and bins that contains infinity(), nan() - // once corresponding casting and non-constant array literal element is - // supported. testFailure(kNan, "ARRAY[0.0]", "Operand cannot be NaN"); testFailure( 2, "ARRAY[1.0, 0.0]", "Bin values are not sorted in ascending order"); + testFailure( + 3.14, + "ARRAY[0.0, Infinity(), 10.0]", + "Bin values are not sorted in ascending order"); + testFailure( + 1.5, + "ARRAY[1.0, 2.0, 3.0, 2.0, 0.0]", + "Bin values are not sorted in ascending order"); + testFailure(3.14, "ARRAY[cast(NULL as double)]", "Bin values cannot be NULL"); + testFailure(3.14, "ARRAY[0.0, NULL, 4.0]", "Bin values cannot be NULL"); + testFailure(3.14, "ARRAY[0.0, 2.0, 4.0, NULL]", "Bin values cannot be NULL"); } TEST_F(WidthBucketArrayTest, makeWidthBucketArrayNoThrow) {