diff --git a/presto-docs/src/main/sphinx/functions/math.rst b/presto-docs/src/main/sphinx/functions/math.rst index ceab1579c5ead..e1353fa892800 100644 --- a/presto-docs/src/main/sphinx/functions/math.rst +++ b/presto-docs/src/main/sphinx/functions/math.rst @@ -168,8 +168,11 @@ Mathematical Functions .. function:: width_bucket(x, bins) -> bigint Returns the bin number of ``x`` according to the bins specified by the - array ``bins``. The ``bins`` parameter must be an array of doubles and is - assumed to be in sorted ascending order. + array ``bins``. The ``bins`` parameter must be an array of doubles, should not + contain ``null`` or non-finite elements, and is assumed to be in sorted ascending order. + The function will generally throw an error if it encounters a ``null`` or non-finite + element in ``bins``. However, because it uses binary search, some such elements + might go unnoticed, in which case the function will still return a result. Probability Functions: cdf -------------------------- diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/MathFunctions.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/MathFunctions.java index 231e8db1b0aa8..707de7d28cd0f 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/scalar/MathFunctions.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/MathFunctions.java @@ -1572,17 +1572,23 @@ public static long widthBucket(@SqlType(StandardTypes.DOUBLE) double operand, @S int index; double bin; + double lowerBin; + double upperBin; while (lower < upper) { - if (DOUBLE.getDouble(bins, lower) > DOUBLE.getDouble(bins, upper - 1)) { - throw new PrestoException(INVALID_FUNCTION_ARGUMENT, "Bin values are not sorted in ascending order"); + index = (lower + upper) / 2; + if (bins.isNull(lower) || bins.isNull(index) || bins.isNull(upper - 1)) { + throw new PrestoException(INVALID_FUNCTION_ARGUMENT, "Bin values cannot be NULL"); } - index = (lower + upper) / 2; bin = 
DOUBLE.getDouble(bins, index); - - if (!isFinite(bin)) { - throw new PrestoException(INVALID_FUNCTION_ARGUMENT, "Bin value must be finite, got " + bin); + lowerBin = DOUBLE.getDouble(bins, lower); + upperBin = DOUBLE.getDouble(bins, upper - 1); + if (lowerBin > upperBin || lowerBin > bin || bin > upperBin) { + throw new PrestoException(INVALID_FUNCTION_ARGUMENT, "Bin values are not sorted in ascending order"); + } + if (!isFinite(bin) || !isFinite(lowerBin) || !isFinite(upperBin)) { + throw new PrestoException(INVALID_FUNCTION_ARGUMENT, "Bin values must be finite"); } if (operand < bin) { diff --git a/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestMathFunctions.java b/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestMathFunctions.java index a5ab98de85fa3..bb63a91410bc7 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestMathFunctions.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestMathFunctions.java @@ -1315,15 +1315,26 @@ public void testWidthBucketArray() // failure modes assertInvalidFunction("width_bucket(3.14E0, array[])", "Bins cannot be an empty array"); assertInvalidFunction("width_bucket(nan(), array[1.0E0, 2.0E0, 3.0E0])", "Operand cannot be NaN"); - assertInvalidFunction("width_bucket(3.14E0, array[0.0E0, infinity()])", "Bin value must be finite, got Infinity"); + assertInvalidFunction("width_bucket(3.14E0, array[0.0E0, infinity()])", "Bin values must be finite"); // fail if we aren't sorted + assertInvalidFunction("width_bucket(3.14E0, array[0.0E0, infinity(), 10.0E0])", "Bin values are not sorted in ascending order"); assertInvalidFunction("width_bucket(3.145E0, array[1.0E0, 0.0E0])", "Bin values are not sorted in ascending order"); assertInvalidFunction("width_bucket(3.145E0, array[1.0E0, 0.0E0, -1.0E0])", "Bin values are not sorted in ascending order"); assertInvalidFunction("width_bucket(3.145E0, array[1.0E0, 0.3E0, 0.0E0, -1.0E0])", "Bin values are not 
sorted in ascending order"); + assertInvalidFunction("width_bucket(1.5E0, array[1.0E0, 2.3E0, 2.0E0])", "Bin values are not sorted in ascending order"); - // this is a case that we can't catch because we are using binary search to bisect the bins array - assertFunction("width_bucket(1.5E0, array[1.0E0, 2.3E0, 2.0E0])", BIGINT, 1L); + // Cases with nulls. When we hit a null element we throw. + assertInvalidFunction("width_bucket(3.14E0, array[cast(null as double)])", "Bin values cannot be NULL"); + assertInvalidFunction("width_bucket(3.14E0, array[0.0E0, null, 4.0E0])", "Bin values cannot be NULL"); + assertInvalidFunction("width_bucket(3.14E0, array[0.0E0, 2.0E0, 4.0E0, null])", "Bin values cannot be NULL"); + + // Cases we cannot catch due to the binary search algorithm. + // Has null elements. + assertFunction("width_bucket(3.14E0, array[0.0E0, null, 2.0E0, 4.0E0])", BIGINT, 3L); + assertFunction("width_bucket(3.14E0, array[0.0E0, null, 1.0E0, 2.0E0, 4.0E0])", BIGINT, 4L); + // Not properly sorted and has infinity. + assertFunction("width_bucket(3.14E0, array[0.0E0, infinity(), 1.0E0, 2.0E0, 4.0E0])", BIGINT, 4L); } @Test @@ -1335,7 +1346,7 @@ public void testCosineSimilarity() assertFunction("cosine_similarity(map(array ['a', 'b', 'c'], array [1.0E0, 2.0E0, -1.0E0]), map(array ['c', 'b'], array [1.0E0, 3.0E0]))", DOUBLE, - (2 * 3 + (-1) * 1) / (Math.sqrt(1 + 4 + 1) * Math.sqrt(1 + 9))); + (2 * 3 + (-1)) / (Math.sqrt(1 + 4 + 1) * Math.sqrt(1 + 9))); assertFunction("cosine_similarity(map(array ['a', 'b', 'c'], array [1.0E0, 2.0E0, -1.0E0]), map(array ['d', 'e'], array [1.0E0, 3.0E0]))", DOUBLE,