From 89721c83ae3782dae2062f567afad62fb600924c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Iv=C3=A1n=20Cea=20Fontenla?= <ivancea96@outlook.com> Date: Wed, 16 Oct 2024 13:19:56 +0200 Subject: [PATCH] ESQL: Fix MvPercentileTests precision issues (#114844) Fixes https://github.com/elastic/elasticsearch/issues/114588 Fixes https://github.com/elastic/elasticsearch/issues/114587 Fixes https://github.com/elastic/elasticsearch/issues/114586 Fixes https://github.com/elastic/elasticsearch/issues/114585 Fixes https://github.com/elastic/elasticsearch/issues/113008 Fixes https://github.com/elastic/elasticsearch/issues/113007 Fixes https://github.com/elastic/elasticsearch/issues/113006 Fixes https://github.com/elastic/elasticsearch/issues/113005 Fixed the long precision issue by allowing a +/-1 range. Also made a minor refactor to simplify using different matchers for different types. --- .../scalar/multivalue/MvPercentileTests.java | 47 ++++++++++++------- 1 file changed, 30 insertions(+), 17 deletions(-) diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvPercentileTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvPercentileTests.java index 29cc959e6a943..0a419d44e3448 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvPercentileTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvPercentileTests.java @@ -17,6 +17,7 @@ import org.elasticsearch.xpack.esql.expression.function.AbstractScalarFunctionTestCase; import org.elasticsearch.xpack.esql.expression.function.MultivalueTestCaseSupplier; import org.elasticsearch.xpack.esql.expression.function.TestCaseSupplier; +import org.hamcrest.Matcher; import java.math.BigDecimal; import java.util.ArrayList; @@ -28,6 +29,7 @@ import static org.elasticsearch.xpack.esql.core.type.DataType.DOUBLE; import static org.elasticsearch.xpack.esql.core.type.DataType.INTEGER; import static org.elasticsearch.xpack.esql.core.type.DataType.LONG; +import static org.hamcrest.Matchers.anyOf; import static org.hamcrest.Matchers.closeTo; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.nullValue; @@ -375,27 +377,25 @@ private static TestCaseSupplier makeSupplier( var values = (List<Number>) fieldTypedData.data(); var percentile = ((Number) percentileTypedData.data()).doubleValue(); - var expected = calculatePercentile(values, percentile); + var expectedMatcher = makePercentileMatcher(values, percentile); return new TestCaseSupplier.TestCase( List.of(fieldTypedData, percentileTypedData), evaluatorString(fieldSupplier.type(), percentileSupplier.type()), fieldSupplier.type(), - expected instanceof Double expectedDouble - ? closeTo(expectedDouble, Math.abs(expectedDouble * 0.0000001)) - : equalTo(expected) + expectedMatcher ); } ); } - private static Number calculatePercentile(List<Number> rawValues, double percentile) { + private static Matcher<?> makePercentileMatcher(List<Number> rawValues, double percentile) { if (rawValues.isEmpty() || percentile < 0 || percentile > 100) { - return null; + return nullValue(); } if (rawValues.size() == 1) { - return rawValues.get(0); + return equalTo(rawValues.get(0)); } int valueCount = rawValues.size(); @@ -407,49 +407,62 @@ private static Number calculatePercentile(List<Number> rawValues, double percent if (rawValues.get(0) instanceof Integer) { var values = rawValues.stream().mapToInt(Number::intValue).sorted().toArray(); + int expected; if (percentile == 0) { - return values[0]; + expected = values[0]; } else if (percentile == 100) { - return values[valueCount - 1]; + expected = values[valueCount - 1]; } else { assert lowerIndex >= 0 && upperIndex < valueCount; var difference = (long) values[upperIndex] - values[lowerIndex]; - return values[lowerIndex] + (int) (fraction * difference); + expected = values[lowerIndex] + (int) (fraction * difference); } + + return equalTo(expected); } if (rawValues.get(0) instanceof Long) { var values = rawValues.stream().mapToLong(Number::longValue).sorted().toArray(); + long expected; if (percentile == 0) { - return values[0]; + expected = values[0]; } else if (percentile == 100) { - return values[valueCount - 1]; + expected = values[valueCount - 1]; } else { assert lowerIndex >= 0 && upperIndex < valueCount; - return calculatePercentile(fraction, new BigDecimal(values[lowerIndex]), new BigDecimal(values[upperIndex])).longValue(); + expected = calculatePercentile(fraction, BigDecimal.valueOf(values[lowerIndex]), BigDecimal.valueOf(values[upperIndex])) + .longValue(); } + + // Double*bigLong may lose precision, we allow a small range + return anyOf(equalTo(Math.min(expected, expected - 1)), equalTo(expected), equalTo(Math.max(expected, expected + 1))); } if (rawValues.get(0) instanceof Double) { var values = rawValues.stream().mapToDouble(Number::doubleValue).sorted().toArray(); + double expected; if (percentile == 0) { - return values[0]; + expected = values[0]; } else if (percentile == 100) { - return values[valueCount - 1]; + expected = values[valueCount - 1]; } else { assert lowerIndex >= 0 && upperIndex < valueCount; - return calculatePercentile(fraction, new BigDecimal(values[lowerIndex]), new BigDecimal(values[upperIndex])).doubleValue(); + expected = calculatePercentile(fraction, new BigDecimal(values[lowerIndex]), new BigDecimal(values[upperIndex])) + .doubleValue(); } + + return closeTo(expected, Math.abs(expected * 0.0000001)); } throw new IllegalArgumentException("Unsupported type: " + rawValues.get(0).getClass()); } private static BigDecimal calculatePercentile(double fraction, BigDecimal lowerValue, BigDecimal upperValue) { - return lowerValue.add(new BigDecimal(fraction).multiply(upperValue.subtract(lowerValue))); + var difference = upperValue.subtract(lowerValue); + return lowerValue.add(new BigDecimal(fraction).multiply(difference)); } private static TestCaseSupplier.TypedData percentileWithType(Number value, DataType type) {