From 82c08dd7f25eb349c47a816ebab2eb4d6157ff79 Mon Sep 17 00:00:00 2001 From: panguixin Date: Wed, 27 Nov 2024 21:32:54 +0800 Subject: [PATCH] Fix multi-value sort for unsigned long Signed-off-by: panguixin --- .../UnsignedLongMultiValueMode.java | 400 ++++++++++++++++++ .../UnsignedLongValuesComparatorSource.java | 6 +- .../UnsignedLongMultiValueModeTests.java | 271 ++++++++++++ 3 files changed, 675 insertions(+), 2 deletions(-) create mode 100644 server/src/main/java/org/opensearch/index/fielddata/fieldcomparator/UnsignedLongMultiValueMode.java create mode 100644 server/src/test/java/org/opensearch/index/fielddata/fieldcomparator/UnsignedLongMultiValueModeTests.java diff --git a/server/src/main/java/org/opensearch/index/fielddata/fieldcomparator/UnsignedLongMultiValueMode.java b/server/src/main/java/org/opensearch/index/fielddata/fieldcomparator/UnsignedLongMultiValueMode.java new file mode 100644 index 0000000000000..4e20460f5c820 --- /dev/null +++ b/server/src/main/java/org/opensearch/index/fielddata/fieldcomparator/UnsignedLongMultiValueMode.java @@ -0,0 +1,400 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.index.fielddata.fieldcomparator; + +import org.apache.lucene.index.DocValues; +import org.apache.lucene.index.NumericDocValues; +import org.apache.lucene.index.SortedNumericDocValues; +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.util.BitSet; +import org.opensearch.common.Numbers; +import org.opensearch.index.fielddata.AbstractNumericDocValues; +import org.opensearch.index.fielddata.FieldData; +import org.opensearch.search.MultiValueMode; + +import java.io.IOException; +import java.util.Locale; + +/** + * Defines what values to pick in the case a document contains multiple values for an unsigned long field. + * + * @opensearch.internal + */ +enum UnsignedLongMultiValueMode { + /** + * Pick the sum of all the values. + */ + SUM { + @Override + protected long pick(SortedNumericDocValues values) throws IOException { + final int count = values.docValueCount(); + long total = 0; + for (int index = 0; index < count; ++index) { + total += values.nextValue(); + } + return total; + } + + @Override + protected long pick( + SortedNumericDocValues values, + long missingValue, + DocIdSetIterator docItr, + int startDoc, + int endDoc, + int maxChildren + ) throws IOException { + int totalCount = 0; + long totalValue = 0; + int count = 0; + for (int doc = startDoc; doc < endDoc; doc = docItr.nextDoc()) { + if (values.advanceExact(doc)) { + if (++count > maxChildren) { + break; + } + + final int docCount = values.docValueCount(); + for (int index = 0; index < docCount; ++index) { + totalValue += values.nextValue(); + } + totalCount += docCount; + } + } + return totalCount > 0 ? totalValue : missingValue; + } + }, + + /** + * Pick the average of all the values. + */ + AVG { + @Override + protected long pick(SortedNumericDocValues values) throws IOException { + final int count = values.docValueCount(); + long total = 0; + for (int index = 0; index < count; ++index) { + total += values.nextValue(); + } + return count > 1 ? divideUnsignedAndRoundUp(total, count) : total; + } + + @Override + protected long pick( + SortedNumericDocValues values, + long missingValue, + DocIdSetIterator docItr, + int startDoc, + int endDoc, + int maxChildren + ) throws IOException { + int totalCount = 0; + long totalValue = 0; + int count = 0; + for (int doc = startDoc; doc < endDoc; doc = docItr.nextDoc()) { + if (values.advanceExact(doc)) { + if (++count > maxChildren) { + break; + } + final int docCount = values.docValueCount(); + for (int index = 0; index < docCount; ++index) { + totalValue += values.nextValue(); + } + totalCount += docCount; + } + } + if (totalCount < 1) { + return missingValue; + } + return totalCount > 1 ? divideUnsignedAndRoundUp(totalValue, totalCount) : totalValue; + } + }, + + /** + * Pick the median of the values. + */ + MEDIAN { + @Override + protected long pick(SortedNumericDocValues values) throws IOException { + int count = values.docValueCount(); + long firstValue = values.nextValue(); + if (count == 1) { + return firstValue; + } else if (count == 2) { + long total = firstValue + values.nextValue(); + return (total >>> 1) + (total & 1); + } else if (firstValue >= 0) { + for (int i = 1; i < (count - 1) / 2; ++i) { + values.nextValue(); + } + if (count % 2 == 0) { + long total = values.nextValue() + values.nextValue(); + return (total >>> 1) + (total & 1); + } else { + return values.nextValue(); + } + } + + final long[] docValues = new long[count]; + docValues[0] = firstValue; + int firstPositiveIndex = 0; + for (int i = 1; i < count; ++i) { + docValues[i] = values.nextValue(); + if (docValues[i] >= 0 && firstPositiveIndex == 0) { + firstPositiveIndex = i; + } + } + final int mid = ((count - 1) / 2 + firstPositiveIndex) % count; + if (count % 2 == 0) { + long total = docValues[mid] + docValues[(mid + 1) % count]; + return (total >>> 1) + (total & 1); + } else { + return docValues[mid]; + } + } + }, + + /** + * Pick the lowest value. + */ + MIN { + @Override + protected long pick(SortedNumericDocValues values) throws IOException { + final int count = values.docValueCount(); + final long min = values.nextValue(); + if (count == 1 || min > 0) { + return min; + } + for (int i = 1; i < count; ++i) { + long val = values.nextValue(); + if (val >= 0) { + return val; + } + } + return min; + } + + @Override + protected long pick( + SortedNumericDocValues values, + long missingValue, + DocIdSetIterator docItr, + int startDoc, + int endDoc, + int maxChildren + ) throws IOException { + boolean hasValue = false; + long minValue = Numbers.MAX_UNSIGNED_LONG_VALUE_AS_LONG; + int count = 0; + for (int doc = startDoc; doc < endDoc; doc = docItr.nextDoc()) { + if (values.advanceExact(doc)) { + if (++count > maxChildren) { + break; + } + final long docMin = pick(values); + minValue = Long.compareUnsigned(docMin, minValue) < 0 ? docMin : minValue; + hasValue = true; + } + } + return hasValue ? minValue : missingValue; + } + }, + + /** + * Pick the highest value. + */ + MAX { + @Override + protected long pick(SortedNumericDocValues values) throws IOException { + final int count = values.docValueCount(); + long max = values.nextValue(); + long val; + for (int i = 1; i < count; ++i) { + val = values.nextValue(); + if (max < 0 && val >= 0) { + return max; + } + max = val; + } + return max; + } + + @Override + protected long pick( + SortedNumericDocValues values, + long missingValue, + DocIdSetIterator docItr, + int startDoc, + int endDoc, + int maxChildren + ) throws IOException { + boolean hasValue = false; + long maxValue = Numbers.MIN_UNSIGNED_LONG_VALUE_AS_LONG; + int count = 0; + for (int doc = startDoc; doc < endDoc; doc = docItr.nextDoc()) { + if (values.advanceExact(doc)) { + if (++count > maxChildren) { + break; + } + final long docMax = pick(values); + maxValue = Long.compareUnsigned(maxValue, docMax) < 0 ? docMax : maxValue; + hasValue = true; + } + } + return hasValue ? maxValue : missingValue; + } + }; + + /** + * A case insensitive version of {@link #valueOf(String)} + * + * @throws IllegalArgumentException if the given string doesn't match a sort mode or is null. + */ + private static UnsignedLongMultiValueMode fromString(String sortMode) { + try { + return valueOf(sortMode.toUpperCase(Locale.ROOT)); + } catch (Exception e) { + throw new IllegalArgumentException("Illegal sort mode: " + sortMode); + } + } + + /** + * Convert a {@link MultiValueMode} to a {@link UnsignedLongMultiValueMode}. + */ + public static UnsignedLongMultiValueMode toUnsignedSortMode(MultiValueMode sortMode) { + return fromString(sortMode.name()); + } + + /** + * Return a {@link NumericDocValues} instance that can be used to sort documents + * with this mode and the provided values. When a document has no value, + * missingValue is returned. + *

+ * Allowed Modes: SUM, AVG, MEDIAN, MIN, MAX + */ + public NumericDocValues select(final SortedNumericDocValues values) { + final NumericDocValues singleton = DocValues.unwrapSingleton(values); + if (singleton != null) { + return singleton; + } else { + return new AbstractNumericDocValues() { + + private long value; + + @Override + public boolean advanceExact(int target) throws IOException { + if (values.advanceExact(target)) { + value = pick(values); + return true; + } + return false; + } + + @Override + public int docID() { + return values.docID(); + } + + @Override + public long longValue() throws IOException { + return value; + } + }; + } + } + + protected long pick(SortedNumericDocValues values) throws IOException { + throw new IllegalArgumentException("Unsupported sort mode: " + this); + } + + /** + * Return a {@link NumericDocValues} instance that can be used to sort root documents + * with this mode, the provided values and filters for root/inner documents. + *

+ * For every root document, the values of its inner documents will be aggregated. + * If none of the inner documents has a value, then missingValue is returned. + *

+ * Allowed Modes: SUM, AVG, MIN, MAX + *

+ * NOTE: Calling the returned instance on docs that are not root docs is illegal + * The returned instance can only be evaluate the current and upcoming docs + */ + public NumericDocValues select( + final SortedNumericDocValues values, + final long missingValue, + final BitSet parentDocs, + final DocIdSetIterator childDocs, + int maxDoc, + int maxChildren + ) throws IOException { + if (parentDocs == null || childDocs == null) { + return FieldData.replaceMissing(DocValues.emptyNumeric(), missingValue); + } + + return new AbstractNumericDocValues() { + + int lastSeenParentDoc = -1; + long lastEmittedValue = missingValue; + + @Override + public boolean advanceExact(int parentDoc) throws IOException { + assert parentDoc >= lastSeenParentDoc : "can only evaluate current and upcoming parent docs"; + if (parentDoc == lastSeenParentDoc) { + return true; + } else if (parentDoc == 0) { + lastEmittedValue = missingValue; + return true; + } + final int prevParentDoc = parentDocs.prevSetBit(parentDoc - 1); + final int firstChildDoc; + if (childDocs.docID() > prevParentDoc) { + firstChildDoc = childDocs.docID(); + } else { + firstChildDoc = childDocs.advance(prevParentDoc + 1); + } + + lastSeenParentDoc = parentDoc; + lastEmittedValue = pick(values, missingValue, childDocs, firstChildDoc, parentDoc, maxChildren); + return true; + } + + @Override + public int docID() { + return lastSeenParentDoc; + } + + @Override + public long longValue() { + return lastEmittedValue; + } + }; + } + + protected long pick( + SortedNumericDocValues values, + long missingValue, + DocIdSetIterator docItr, + int startDoc, + int endDoc, + int maxChildren + ) throws IOException { + throw new IllegalArgumentException("Unsupported sort mode: " + this); + } + + /** + * Copied from {@link Long#divideUnsigned(long, long)} and {@link Long#remainderUnsigned(long, long)} + */ + private static long divideUnsignedAndRoundUp(long dividend, long divisor) { + assert divisor > 0; + final long q = (dividend >>> 1) / divisor << 1; + final long r = dividend - q * divisor; + final long quotient = q + ((r | ~(r - divisor)) >>> (Long.SIZE - 1)); + final long rem = r - ((~(r - divisor) >> (Long.SIZE - 1)) & divisor); + return quotient + Math.round((double) rem / divisor); + } +} diff --git a/server/src/main/java/org/opensearch/index/fielddata/fieldcomparator/UnsignedLongValuesComparatorSource.java b/server/src/main/java/org/opensearch/index/fielddata/fieldcomparator/UnsignedLongValuesComparatorSource.java index 9db5817450cd0..c7b9c9f445484 100644 --- a/server/src/main/java/org/opensearch/index/fielddata/fieldcomparator/UnsignedLongValuesComparatorSource.java +++ b/server/src/main/java/org/opensearch/index/fielddata/fieldcomparator/UnsignedLongValuesComparatorSource.java @@ -41,6 +41,7 @@ public class UnsignedLongValuesComparatorSource extends IndexFieldData.XFieldComparatorSource { private final IndexNumericFieldData indexFieldData; + private final UnsignedLongMultiValueMode unsignedSortMode; public UnsignedLongValuesComparatorSource( IndexNumericFieldData indexFieldData, @@ -50,6 +51,7 @@ public UnsignedLongValuesComparatorSource( ) { super(missingValue, sortMode, nested); this.indexFieldData = indexFieldData; + this.unsignedSortMode = UnsignedLongMultiValueMode.toUnsignedSortMode(sortMode); } @Override @@ -66,12 +68,12 @@ private SortedNumericDocValues loadDocValues(LeafReaderContext context) { private NumericDocValues getNumericDocValues(LeafReaderContext context, BigInteger missingValue) throws IOException { final SortedNumericDocValues values = loadDocValues(context); if (nested == null) { - return FieldData.replaceMissing(sortMode.select(values), missingValue); + return FieldData.replaceMissing(unsignedSortMode.select(values), missingValue); } final BitSet rootDocs = nested.rootDocs(context); final DocIdSetIterator innerDocs = nested.innerDocs(context); final int maxChildren = nested.getNestedSort() != null ? nested.getNestedSort().getMaxChildren() : Integer.MAX_VALUE; - return sortMode.select(values, missingValue.longValue(), rootDocs, innerDocs, context.reader().maxDoc(), maxChildren); + return unsignedSortMode.select(values, missingValue.longValue(), rootDocs, innerDocs, context.reader().maxDoc(), maxChildren); } @Override diff --git a/server/src/test/java/org/opensearch/index/fielddata/fieldcomparator/UnsignedLongMultiValueModeTests.java b/server/src/test/java/org/opensearch/index/fielddata/fieldcomparator/UnsignedLongMultiValueModeTests.java new file mode 100644 index 0000000000000..afa562862cfe9 --- /dev/null +++ b/server/src/test/java/org/opensearch/index/fielddata/fieldcomparator/UnsignedLongMultiValueModeTests.java @@ -0,0 +1,271 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.index.fielddata.fieldcomparator; + +import org.apache.lucene.index.DocValues; +import org.apache.lucene.index.NumericDocValues; +import org.apache.lucene.index.SortedNumericDocValues; +import org.apache.lucene.util.BitSetIterator; +import org.apache.lucene.util.FixedBitSet; +import org.opensearch.common.Numbers; +import org.opensearch.index.fielddata.AbstractNumericDocValues; +import org.opensearch.index.fielddata.AbstractSortedNumericDocValues; +import org.opensearch.test.OpenSearchTestCase; + +import java.io.IOException; +import java.math.BigDecimal; +import java.math.BigInteger; +import java.math.RoundingMode; +import java.util.Arrays; +import java.util.function.Supplier; + +public class UnsignedLongMultiValueModeTests extends OpenSearchTestCase { + private static FixedBitSet randomRootDocs(int maxDoc) { + FixedBitSet set = new FixedBitSet(maxDoc); + for (int i = 0; i < maxDoc; ++i) { + if (randomBoolean()) { + set.set(i); + } + } + // the last doc must be a root doc + set.set(maxDoc - 1); + return set; + } + + private static FixedBitSet randomInnerDocs(FixedBitSet rootDocs) { + FixedBitSet innerDocs = new FixedBitSet(rootDocs.length()); + for (int i = 0; i < innerDocs.length(); ++i) { + if (!rootDocs.get(i) && randomBoolean()) { + innerDocs.set(i); + } + } + return innerDocs; + } + + public void testSingleValuedLongs() throws Exception { + final int numDocs = scaledRandomIntBetween(1, 100); + final long[] array = new long[numDocs]; + final FixedBitSet docsWithValue = randomBoolean() ? null : new FixedBitSet(numDocs); + for (int i = 0; i < array.length; ++i) { + if (randomBoolean()) { + array[i] = randomUnsignedLong().longValue(); + if (docsWithValue != null) { + docsWithValue.set(i); + } + } else if (docsWithValue != null && randomBoolean()) { + docsWithValue.set(i); + } + } + + final Supplier multiValues = () -> DocValues.singleton(new AbstractNumericDocValues() { + int docId = -1; + + @Override + public boolean advanceExact(int target) throws IOException { + this.docId = target; + return docsWithValue == null || docsWithValue.get(docId); + } + + @Override + public int docID() { + return docId; + } + + @Override + public long longValue() { + return array[docId]; + } + }); + verifySortedNumeric(multiValues, numDocs); + final FixedBitSet rootDocs = randomRootDocs(numDocs); + final FixedBitSet innerDocs = randomInnerDocs(rootDocs); + verifySortedNumeric(multiValues, numDocs, rootDocs, innerDocs, Integer.MAX_VALUE); + verifySortedNumeric(multiValues, numDocs, rootDocs, innerDocs, randomIntBetween(1, numDocs)); + } + + public void testMultiValuedLongs() throws Exception { + final int numDocs = scaledRandomIntBetween(1, 100); + final long[][] array = new long[numDocs][]; + for (int i = 0; i < numDocs; ++i) { + final long[] values = new long[randomInt(4)]; + for (int j = 0; j < values.length; ++j) { + values[j] = randomUnsignedLong().longValue(); + } + Arrays.sort(values); + array[i] = values; + } + final Supplier multiValues = () -> new AbstractSortedNumericDocValues() { + int doc; + int i; + + @Override + public long nextValue() { + return array[doc][i++]; + } + + @Override + public boolean advanceExact(int doc) { + this.doc = doc; + i = 0; + return array[doc].length > 0; + } + + @Override + public int docValueCount() { + return array[doc].length; + } + }; + verifySortedNumeric(multiValues, numDocs); + final FixedBitSet rootDocs = randomRootDocs(numDocs); + final FixedBitSet innerDocs = randomInnerDocs(rootDocs); + verifySortedNumeric(multiValues, numDocs, rootDocs, innerDocs, Integer.MAX_VALUE); + verifySortedNumeric(multiValues, numDocs, rootDocs, innerDocs, randomIntBetween(1, numDocs)); + } + + private void verifySortedNumeric(Supplier supplier, int maxDoc) throws IOException { + for (UnsignedLongMultiValueMode mode : UnsignedLongMultiValueMode.values()) { + SortedNumericDocValues values = supplier.get(); + final NumericDocValues selected = mode.select(values); + for (int i = 0; i < maxDoc; ++i) { + Long actual = null; + if (selected.advanceExact(i)) { + actual = selected.longValue(); + verifyLongValueCanCalledMoreThanOnce(selected, actual); + } + + BigInteger expected = null; + if (values.advanceExact(i)) { + int numValues = values.docValueCount(); + if (mode == UnsignedLongMultiValueMode.MAX) { + expected = Numbers.MIN_UNSIGNED_LONG_VALUE; + } else if (mode == UnsignedLongMultiValueMode.MIN) { + expected = Numbers.MAX_UNSIGNED_LONG_VALUE; + } else { + expected = BigInteger.ZERO; + } + for (int j = 0; j < numValues; ++j) { + if (mode == UnsignedLongMultiValueMode.SUM || mode == UnsignedLongMultiValueMode.AVG) { + expected = expected.add(Numbers.toUnsignedBigInteger(values.nextValue())); + } else if (mode == UnsignedLongMultiValueMode.MIN) { + expected = expected.min(Numbers.toUnsignedBigInteger(values.nextValue())); + } else if (mode == UnsignedLongMultiValueMode.MAX) { + expected = expected.max(Numbers.toUnsignedBigInteger(values.nextValue())); + } + } + if (mode == UnsignedLongMultiValueMode.AVG) { + expected = Numbers.toUnsignedBigInteger(expected.longValue()); + expected = numValues > 1 + ? new BigDecimal(expected).divide(new BigDecimal(numValues), RoundingMode.HALF_UP).toBigInteger() + : expected; + } else if (mode == UnsignedLongMultiValueMode.MEDIAN) { + final Long[] docValues = new Long[numValues]; + for (int j = 0; j < numValues; ++j) { + docValues[j] = values.nextValue(); + } + Arrays.sort(docValues, Long::compareUnsigned); + int value = numValues / 2; + if (numValues % 2 == 0) { + expected = Numbers.toUnsignedBigInteger(docValues[value - 1]) + .add(Numbers.toUnsignedBigInteger(docValues[value])); + expected = Numbers.toUnsignedBigInteger(expected.longValue()); + expected = new BigDecimal(expected).divide(new BigDecimal(2), RoundingMode.HALF_UP).toBigInteger(); + } else { + expected = Numbers.toUnsignedBigInteger(docValues[value]); + } + } + } + + final Long expectedLong = expected == null ? null : expected.longValue(); + assertEquals(mode.toString() + " docId=" + i, expectedLong, actual); + } + } + } + + private void verifyLongValueCanCalledMoreThanOnce(NumericDocValues values, long expected) throws IOException { + for (int j = 0, numCall = randomIntBetween(1, 10); j < numCall; j++) { + assertEquals(expected, values.longValue()); + } + } + + private void verifySortedNumeric( + Supplier supplier, + int maxDoc, + FixedBitSet rootDocs, + FixedBitSet innerDocs, + int maxChildren + ) throws IOException { + for (long missingValue : new long[] { 0, randomUnsignedLong().longValue() }) { + for (UnsignedLongMultiValueMode mode : new UnsignedLongMultiValueMode[] { + UnsignedLongMultiValueMode.MIN, + UnsignedLongMultiValueMode.MAX, + UnsignedLongMultiValueMode.SUM, + UnsignedLongMultiValueMode.AVG }) { + SortedNumericDocValues values = supplier.get(); + final NumericDocValues selected = mode.select( + values, + missingValue, + rootDocs, + new BitSetIterator(innerDocs, 0L), + maxDoc, + maxChildren + ); + int prevRoot = -1; + for (int root = rootDocs.nextSetBit(0); root != -1; root = root + 1 < maxDoc ? rootDocs.nextSetBit(root + 1) : -1) { + assertTrue(selected.advanceExact(root)); + final long actual = selected.longValue(); + verifyLongValueCanCalledMoreThanOnce(selected, actual); + + BigInteger expected = BigInteger.ZERO; + if (mode == UnsignedLongMultiValueMode.MAX) { + expected = Numbers.MIN_UNSIGNED_LONG_VALUE; + } else if (mode == UnsignedLongMultiValueMode.MIN) { + expected = Numbers.MAX_UNSIGNED_LONG_VALUE; + } + int numValues = 0; + int count = 0; + for (int child = innerDocs.nextSetBit(prevRoot + 1); child != -1 && child < root; child = innerDocs.nextSetBit( + child + 1 + )) { + if (values.advanceExact(child)) { + if (++count > maxChildren) { + break; + } + for (int j = 0; j < values.docValueCount(); ++j) { + if (mode == UnsignedLongMultiValueMode.SUM || mode == UnsignedLongMultiValueMode.AVG) { + expected = expected.add(Numbers.toUnsignedBigInteger(values.nextValue())); + } else if (mode == UnsignedLongMultiValueMode.MIN) { + expected = expected.min(Numbers.toUnsignedBigInteger(values.nextValue())); + } else if (mode == UnsignedLongMultiValueMode.MAX) { + expected = expected.max(Numbers.toUnsignedBigInteger(values.nextValue())); + } + ++numValues; + } + } + } + final long expectedLong; + if (numValues == 0) { + expectedLong = missingValue; + } else if (mode == UnsignedLongMultiValueMode.AVG) { + expected = Numbers.toUnsignedBigInteger(expected.longValue()); + expected = numValues > 1 + ? new BigDecimal(expected).divide(new BigDecimal(numValues), RoundingMode.HALF_UP).toBigInteger() + : expected; + expectedLong = expected.longValue(); + } else { + expectedLong = expected.longValue(); + } + + assertEquals(mode.toString() + " docId=" + root, expectedLong, actual); + + prevRoot = root; + } + } + } + } +}