From 89721c83ae3782dae2062f567afad62fb600924c Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Iv=C3=A1n=20Cea=20Fontenla?= <ivancea96@outlook.com>
Date: Wed, 16 Oct 2024 13:19:56 +0200
Subject: [PATCH] ESQL: Fix MvPercentileTests precision issues (#114844)

Fixes https://github.com/elastic/elasticsearch/issues/114588
Fixes https://github.com/elastic/elasticsearch/issues/114587
Fixes https://github.com/elastic/elasticsearch/issues/114586
Fixes https://github.com/elastic/elasticsearch/issues/114585
Fixes https://github.com/elastic/elasticsearch/issues/113008
Fixes https://github.com/elastic/elasticsearch/issues/113007
Fixes https://github.com/elastic/elasticsearch/issues/113006
Fixes https://github.com/elastic/elasticsearch/issues/113005

Fixed the precision issue in the `long` test cases by accepting results within +/-1 of the expected value.

Also did a minor refactor so that each value type can use its own matcher.
---
 .../scalar/multivalue/MvPercentileTests.java  | 47 ++++++++++++-------
 1 file changed, 30 insertions(+), 17 deletions(-)

diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvPercentileTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvPercentileTests.java
index 29cc959e6a943..0a419d44e3448 100644
--- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvPercentileTests.java
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvPercentileTests.java
@@ -17,6 +17,7 @@
 import org.elasticsearch.xpack.esql.expression.function.AbstractScalarFunctionTestCase;
 import org.elasticsearch.xpack.esql.expression.function.MultivalueTestCaseSupplier;
 import org.elasticsearch.xpack.esql.expression.function.TestCaseSupplier;
+import org.hamcrest.Matcher;
 
 import java.math.BigDecimal;
 import java.util.ArrayList;
@@ -28,6 +29,7 @@
 import static org.elasticsearch.xpack.esql.core.type.DataType.DOUBLE;
 import static org.elasticsearch.xpack.esql.core.type.DataType.INTEGER;
 import static org.elasticsearch.xpack.esql.core.type.DataType.LONG;
+import static org.hamcrest.Matchers.anyOf;
 import static org.hamcrest.Matchers.closeTo;
 import static org.hamcrest.Matchers.equalTo;
 import static org.hamcrest.Matchers.nullValue;
@@ -375,27 +377,25 @@ private static TestCaseSupplier makeSupplier(
                 var values = (List<Number>) fieldTypedData.data();
                 var percentile = ((Number) percentileTypedData.data()).doubleValue();
 
-                var expected = calculatePercentile(values, percentile);
+                var expectedMatcher = makePercentileMatcher(values, percentile);
 
                 return new TestCaseSupplier.TestCase(
                     List.of(fieldTypedData, percentileTypedData),
                     evaluatorString(fieldSupplier.type(), percentileSupplier.type()),
                     fieldSupplier.type(),
-                    expected instanceof Double expectedDouble
-                        ? closeTo(expectedDouble, Math.abs(expectedDouble * 0.0000001))
-                        : equalTo(expected)
+                    expectedMatcher
                 );
             }
         );
     }
 
-    private static Number calculatePercentile(List<Number> rawValues, double percentile) {
+    private static Matcher<?> makePercentileMatcher(List<Number> rawValues, double percentile) {
         if (rawValues.isEmpty() || percentile < 0 || percentile > 100) {
-            return null;
+            return nullValue();
         }
 
         if (rawValues.size() == 1) {
-            return rawValues.get(0);
+            return equalTo(rawValues.get(0));
         }
 
         int valueCount = rawValues.size();
@@ -407,49 +407,62 @@ private static Number calculatePercentile(List<Number> rawValues, double percent
 
         if (rawValues.get(0) instanceof Integer) {
             var values = rawValues.stream().mapToInt(Number::intValue).sorted().toArray();
+            int expected;
 
             if (percentile == 0) {
-                return values[0];
+                expected = values[0];
             } else if (percentile == 100) {
-                return values[valueCount - 1];
+                expected = values[valueCount - 1];
             } else {
                 assert lowerIndex >= 0 && upperIndex < valueCount;
                 var difference = (long) values[upperIndex] - values[lowerIndex];
-                return values[lowerIndex] + (int) (fraction * difference);
+                expected = values[lowerIndex] + (int) (fraction * difference);
             }
+
+            return equalTo(expected);
         }
 
         if (rawValues.get(0) instanceof Long) {
             var values = rawValues.stream().mapToLong(Number::longValue).sorted().toArray();
+            long expected;
 
             if (percentile == 0) {
-                return values[0];
+                expected = values[0];
             } else if (percentile == 100) {
-                return values[valueCount - 1];
+                expected = values[valueCount - 1];
             } else {
                 assert lowerIndex >= 0 && upperIndex < valueCount;
-                return calculatePercentile(fraction, new BigDecimal(values[lowerIndex]), new BigDecimal(values[upperIndex])).longValue();
+                expected = calculatePercentile(fraction, BigDecimal.valueOf(values[lowerIndex]), BigDecimal.valueOf(values[upperIndex]))
+                    .longValue();
             }
+
+            // Double*bigLong may lose precision, we allow a small range
+            return anyOf(equalTo(Math.min(expected, expected - 1)), equalTo(expected), equalTo(Math.max(expected, expected + 1)));
         }
 
         if (rawValues.get(0) instanceof Double) {
             var values = rawValues.stream().mapToDouble(Number::doubleValue).sorted().toArray();
+            double expected;
 
             if (percentile == 0) {
-                return values[0];
+                expected = values[0];
             } else if (percentile == 100) {
-                return values[valueCount - 1];
+                expected = values[valueCount - 1];
             } else {
                 assert lowerIndex >= 0 && upperIndex < valueCount;
-                return calculatePercentile(fraction, new BigDecimal(values[lowerIndex]), new BigDecimal(values[upperIndex])).doubleValue();
+                expected = calculatePercentile(fraction, new BigDecimal(values[lowerIndex]), new BigDecimal(values[upperIndex]))
+                    .doubleValue();
             }
+
+            return closeTo(expected, Math.abs(expected * 0.0000001));
         }
 
         throw new IllegalArgumentException("Unsupported type: " + rawValues.get(0).getClass());
     }
 
     private static BigDecimal calculatePercentile(double fraction, BigDecimal lowerValue, BigDecimal upperValue) {
-        return lowerValue.add(new BigDecimal(fraction).multiply(upperValue.subtract(lowerValue)));
+        var difference = upperValue.subtract(lowerValue);
+        return lowerValue.add(new BigDecimal(fraction).multiply(difference));
     }
 
     private static TestCaseSupplier.TypedData percentileWithType(Number value, DataType type) {