Skip to content

Commit

Permalink
Tighten checks for "bins" elements in width_bucket(x, bins).
Browse files Browse the repository at this point in the history
  • Loading branch information
spershin committed Nov 22, 2024
1 parent b19167e commit 72e5c57
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 12 deletions.
8 changes: 6 additions & 2 deletions presto-docs/src/main/sphinx/functions/math.rst
Original file line number Diff line number Diff line change
Expand Up @@ -168,8 +168,12 @@ Mathematical Functions
.. function:: width_bucket(x, bins) -> bigint

Returns the bin number of ``x`` according to the bins specified by the
array ``bins``. The ``bins`` parameter must be an array of doubles and is
assumed to be in sorted ascending order.
array ``bins``. The ``bins`` parameter must be an array of doubles, should not
contain ``null`` or non-finite elements, and is assumed to be in sorted ascending order.

Note: The function raises an error if it encounters a ``null`` or non-finite
element in ``bins``. However, because it uses binary search, it inspects only
some of the elements, so such an element may go unnoticed, in which case the
function still returns a result.

Probability Functions: cdf
--------------------------
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1572,17 +1572,23 @@ public static long widthBucket(@SqlType(StandardTypes.DOUBLE) double operand, @S

int index;
double bin;
double lowerBin;
double upperBin;

while (lower < upper) {
if (DOUBLE.getDouble(bins, lower) > DOUBLE.getDouble(bins, upper - 1)) {
throw new PrestoException(INVALID_FUNCTION_ARGUMENT, "Bin values are not sorted in ascending order");
index = (lower + upper) / 2;
if (bins.isNull(lower) || bins.isNull(index) || bins.isNull(upper - 1)) {
throw new PrestoException(INVALID_FUNCTION_ARGUMENT, "Bin values cannot be NULL");
}

index = (lower + upper) / 2;
bin = DOUBLE.getDouble(bins, index);

if (!isFinite(bin)) {
throw new PrestoException(INVALID_FUNCTION_ARGUMENT, "Bin value must be finite, got " + bin);
lowerBin = DOUBLE.getDouble(bins, lower);
upperBin = DOUBLE.getDouble(bins, upper - 1);
if (lowerBin > upperBin || lowerBin > bin || bin > upperBin) {
throw new PrestoException(INVALID_FUNCTION_ARGUMENT, "Bin values are not sorted in ascending order");
}
if (!isFinite(bin) || !isFinite(lowerBin) || !isFinite(upperBin)) {
throw new PrestoException(INVALID_FUNCTION_ARGUMENT, "Bin values must be finite");
}

if (operand < bin) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1315,15 +1315,26 @@ public void testWidthBucketArray()
// failure modes
assertInvalidFunction("width_bucket(3.14E0, array[])", "Bins cannot be an empty array");
assertInvalidFunction("width_bucket(nan(), array[1.0E0, 2.0E0, 3.0E0])", "Operand cannot be NaN");
assertInvalidFunction("width_bucket(3.14E0, array[0.0E0, infinity()])", "Bin value must be finite, got Infinity");
assertInvalidFunction("width_bucket(3.14E0, array[0.0E0, infinity()])", "Bin values must be finite");

// fail if we aren't sorted
assertInvalidFunction("width_bucket(3.14E0, array[0.0E0, infinity(), 10.0E0])", "Bin values are not sorted in ascending order");
assertInvalidFunction("width_bucket(3.145E0, array[1.0E0, 0.0E0])", "Bin values are not sorted in ascending order");
assertInvalidFunction("width_bucket(3.145E0, array[1.0E0, 0.0E0, -1.0E0])", "Bin values are not sorted in ascending order");
assertInvalidFunction("width_bucket(3.145E0, array[1.0E0, 0.3E0, 0.0E0, -1.0E0])", "Bin values are not sorted in ascending order");

// this is a case that we can't catch because we are using binary search to bisect the bins array
assertFunction("width_bucket(1.5E0, array[1.0E0, 2.3E0, 2.0E0])", BIGINT, 1L);
assertInvalidFunction("width_bucket(1.5E0, array[1.0E0, 2.3E0, 2.0E0])", "Bin values are not sorted in ascending order");

// Cases with nulls. When we hit a null element we throw.
assertInvalidFunction("width_bucket(3.14E0, array[cast(null as double)])", "Bin values cannot be NULL");
assertInvalidFunction("width_bucket(3.14E0, array[0.0E0, null, 4.0E0])", "Bin values cannot be NULL");
assertInvalidFunction("width_bucket(3.14E0, array[0.0E0, 2.0E0, 4.0E0, null])", "Bin values cannot be NULL");

// Cases we cannot catch due to the binary search algorithm.
// Has null elements.
assertFunction("width_bucket(3.14E0, array[0.0E0, null, 2.0E0, 4.0E0])", BIGINT, 3L);
assertFunction("width_bucket(3.14E0, array[0.0E0, null, 1.0E0, 2.0E0, 4.0E0])", BIGINT, 4L);
// Not properly sorted and has infinity.
assertFunction("width_bucket(3.14E0, array[0.0E0, infinity(), 1.0E0, 2.0E0, 4.0E0])", BIGINT, 4L);
}

@Test
Expand Down

0 comments on commit 72e5c57

Please sign in to comment.