From e8425fc975352d120dd4e4710c308b835dd50316 Mon Sep 17 00:00:00 2001 From: Chaitanya Gohel <104654647+gashutos@users.noreply.github.com> Date: Tue, 28 Mar 2023 04:13:49 +0530 Subject: [PATCH] Enable numeric sort optimization support for all numeric types (#6424) * Adding numeric optimization support for all numeric types Signed-off-by: gashutos * modifying CHANGELOG.md Signed-off-by: gashutos * Handling multi-cluster scenario where SortField serialization was failing Signed-off-by: gashutos * Fixing javadoc errors Signed-off-by: gashutos * Fixing nested sort integ tests Signed-off-by: gashutos * Stremlining behaviour of custom comparator tests too Signed-off-by: gashutos * Adding more integ tests for IntValuesComparatorSource & fixing few ITs Signed-off-by: gashutos * Fixing few more integ tests Signed-off-by: gashutos * Streamlining applySortWidening method with CreateSort and avoid modifying cteated objects of sort Signed-off-by: gashutos * Correcting licence header Signed-off-by: gashutos --------- Signed-off-by: gashutos Co-authored-by: Daniel (dB.) Doubrovkine --- CHANGELOG.md | 1 + .../org/opensearch/index/IndexSortIT.java | 4 +- .../search/searchafter/SearchAfterIT.java | 14 +- .../action/search/SearchPhaseController.java | 51 +++- .../org/opensearch/common/lucene/Lucene.java | 5 +- .../fielddata/IndexNumericFieldData.java | 118 ++------- .../IntValuesComparatorSource.java | 126 ++++++++++ .../opensearch/search/sort/BucketedSort.java | 88 +++++++ .../sort/SortedWiderNumericSortField.java | 87 +++++++ .../search/SearchPhaseControllerTests.java | 226 ++++++++++++++++-- .../opensearch/common/lucene/LuceneTests.java | 12 +- .../search/nested/NestedSortingTests.java | 60 ++--- .../search/sort/BucketedSortForIntsTests.java | 77 ++++++ .../search/sort/FieldSortBuilderTests.java | 6 +- 14 files changed, 703 insertions(+), 172 deletions(-) create mode 100644 server/src/main/java/org/opensearch/index/fielddata/fieldcomparator/IntValuesComparatorSource.java create mode 100644 server/src/main/java/org/opensearch/search/sort/SortedWiderNumericSortField.java create mode 100644 server/src/test/java/org/opensearch/search/sort/BucketedSortForIntsTests.java diff --git a/CHANGELOG.md b/CHANGELOG.md index 6cccdc7e6b5f3..6ee485b50020f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -87,6 +87,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), - [Segment Replication] Apply backpressure when replicas fall behind ([#6563](https://github.com/opensearch-project/OpenSearch/pull/6563)) - [Remote Store] Integrate remote segment store in peer recovery flow ([#6664](https://github.com/opensearch-project/OpenSearch/pull/6664)) - [Segment Replication] Add new cluster setting to set replication strategy by default for all indices in cluster. ([#6791](https://github.com/opensearch-project/OpenSearch/pull/6791)) +- Enable sort optimization for all NumericTypes ([#6464](https://github.com/opensearch-project/OpenSearch/pull/6464) ### Dependencies - Bump `org.apache.logging.log4j:log4j-core` from 2.18.0 to 2.20.0 ([#6490](https://github.com/opensearch-project/OpenSearch/pull/6490)) diff --git a/server/src/internalClusterTest/java/org/opensearch/index/IndexSortIT.java b/server/src/internalClusterTest/java/org/opensearch/index/IndexSortIT.java index e06a2f91ae91a..d547ded8152dd 100644 --- a/server/src/internalClusterTest/java/org/opensearch/index/IndexSortIT.java +++ b/server/src/internalClusterTest/java/org/opensearch/index/IndexSortIT.java @@ -81,8 +81,8 @@ private static XContentBuilder createTestMapping() { public void testIndexSort() { SortField dateSort = new SortedNumericSortField("date", SortField.Type.LONG, false); dateSort.setMissingValue(Long.MAX_VALUE); - SortField numericSort = new SortedNumericSortField("numeric_dv", SortField.Type.LONG, false); - numericSort.setMissingValue(Long.MAX_VALUE); + SortField numericSort = new SortedNumericSortField("numeric_dv", SortField.Type.INT, false); + numericSort.setMissingValue(Integer.MAX_VALUE); SortField keywordSort = new SortedSetSortField("keyword_dv", false); keywordSort.setMissingValue(SortField.STRING_LAST); Sort indexSort = new Sort(dateSort, numericSort, keywordSort); diff --git a/server/src/internalClusterTest/java/org/opensearch/search/searchafter/SearchAfterIT.java b/server/src/internalClusterTest/java/org/opensearch/search/searchafter/SearchAfterIT.java index 2c98154115bb9..2a662c9dda088 100644 --- a/server/src/internalClusterTest/java/org/opensearch/search/searchafter/SearchAfterIT.java +++ b/server/src/internalClusterTest/java/org/opensearch/search/searchafter/SearchAfterIT.java @@ -382,24 +382,22 @@ private void createIndexMappingsFromObjectType(String indexName, List ty ensureGreen(); } - // Convert Integer, Short, Byte and Boolean to Long in order to match the conversion done + // Convert Integer, Short, Byte and Boolean to Int in order to match the conversion done // by the internal hits when populating the sort values. private List convertSortValues(List sortValues) { List converted = new ArrayList<>(); for (int i = 0; i < sortValues.size(); i++) { Object from = sortValues.get(i); - if (from instanceof Integer) { - converted.add(((Integer) from).longValue()); - } else if (from instanceof Short) { - converted.add(((Short) from).longValue()); + if (from instanceof Short) { + converted.add(((Short) from).intValue()); } else if (from instanceof Byte) { - converted.add(((Byte) from).longValue()); + converted.add(((Byte) from).intValue()); } else if (from instanceof Boolean) { boolean b = (boolean) from; if (b) { - converted.add(1L); + converted.add(1); } else { - converted.add(0L); + converted.add(0); } } else { converted.add(from); diff --git a/server/src/main/java/org/opensearch/action/search/SearchPhaseController.java b/server/src/main/java/org/opensearch/action/search/SearchPhaseController.java index d32e7753cd153..3f4f7c2b92512 100644 --- a/server/src/main/java/org/opensearch/action/search/SearchPhaseController.java +++ b/server/src/main/java/org/opensearch/action/search/SearchPhaseController.java @@ -41,6 +41,7 @@ import org.apache.lucene.search.ScoreDoc; import org.apache.lucene.search.Sort; import org.apache.lucene.search.SortField; +import org.apache.lucene.search.SortedNumericSortField; import org.apache.lucene.search.TermStatistics; import org.apache.lucene.search.TopDocs; import org.apache.lucene.search.TopFieldDocs; @@ -68,6 +69,7 @@ import org.opensearch.search.profile.ProfileShardResult; import org.opensearch.search.profile.SearchProfileShardResults; import org.opensearch.search.query.QuerySearchResult; +import org.opensearch.search.sort.SortedWiderNumericSortField; import org.opensearch.search.suggest.Suggest; import org.opensearch.search.suggest.Suggest.Suggestion; import org.opensearch.search.suggest.completion.CompletionSuggestion; @@ -235,14 +237,12 @@ static TopDocs mergeTopDocs(Collection results, int topN, int from) { if (numShards == 1 && from == 0) { // only one shard and no pagination we can just return the topDocs as we got them. return topDocs; } else if (topDocs instanceof CollapseTopFieldDocs) { - CollapseTopFieldDocs firstTopDocs = (CollapseTopFieldDocs) topDocs; - final Sort sort = new Sort(firstTopDocs.fields); final CollapseTopFieldDocs[] shardTopDocs = results.toArray(new CollapseTopFieldDocs[numShards]); + final Sort sort = createSort(shardTopDocs); mergedTopDocs = CollapseTopFieldDocs.merge(sort, from, topN, shardTopDocs, false); } else if (topDocs instanceof TopFieldDocs) { - TopFieldDocs firstTopDocs = (TopFieldDocs) topDocs; - final Sort sort = new Sort(firstTopDocs.fields); final TopFieldDocs[] shardTopDocs = results.toArray(new TopFieldDocs[numShards]); + final Sort sort = createSort(shardTopDocs); mergedTopDocs = TopDocs.merge(sort, from, topN, shardTopDocs); } else { final TopDocs[] shardTopDocs = results.toArray(new TopDocs[numShards]); @@ -600,6 +600,49 @@ private static void validateMergeSortValueFormats(Collection newComparator(String fieldname, int numHits, boolean enableSkipping, boolean reversed) { + assert indexFieldData == null || fieldname.equals(indexFieldData.getFieldName()); + + final int iMissingValue = (Integer) missingObject(missingValue, reversed); + // NOTE: it's important to pass null as a missing value in the constructor so that + // the comparator doesn't check docsWithField since we replace missing values in select() + return new IntComparator(numHits, null, null, reversed, false) { + @Override + public LeafFieldComparator getLeafComparator(LeafReaderContext context) throws IOException { + return new IntLeafComparator(context) { + @Override + protected NumericDocValues getNumericDocValues(LeafReaderContext context, String field) throws IOException { + return IntValuesComparatorSource.this.getNumericDocValues(context, iMissingValue); + } + }; + } + }; + } + + @Override + public BucketedSort newBucketedSort( + BigArrays bigArrays, + SortOrder sortOrder, + DocValueFormat format, + int bucketSize, + BucketedSort.ExtraData extra + ) { + return new BucketedSort.ForInts(bigArrays, sortOrder, format, bucketSize, extra) { + private final int iMissingValue = (Integer) missingObject(missingValue, sortOrder == SortOrder.DESC); + + @Override + public Leaf forLeaf(LeafReaderContext ctx) throws IOException { + return new Leaf(ctx) { + private final NumericDocValues docValues = getNumericDocValues(ctx, iMissingValue); + private int docValue; + + @Override + protected boolean advanceExact(int doc) throws IOException { + if (docValues.advanceExact(doc)) { + docValue = (int) docValues.longValue(); + return true; + } + return false; + } + + @Override + protected int docValue() { + return docValue; + } + }; + } + }; + } +} diff --git a/server/src/main/java/org/opensearch/search/sort/BucketedSort.java b/server/src/main/java/org/opensearch/search/sort/BucketedSort.java index 9266469db2b05..a075b6567fe2d 100644 --- a/server/src/main/java/org/opensearch/search/sort/BucketedSort.java +++ b/server/src/main/java/org/opensearch/search/sort/BucketedSort.java @@ -42,6 +42,7 @@ import org.opensearch.common.util.BitArray; import org.opensearch.common.util.DoubleArray; import org.opensearch.common.util.FloatArray; +import org.opensearch.common.util.IntArray; import org.opensearch.common.util.LongArray; import org.opensearch.search.DocValueFormat; @@ -756,4 +757,91 @@ protected final boolean docBetterThan(long index) { } } } + + /** + * Superclass for implementations of {@linkplain BucketedSort} for {@code int} keys. + */ + public abstract static class ForInts extends BucketedSort { + private IntArray values = bigArrays.newIntArray(1, false); + + public ForInts(BigArrays bigArrays, SortOrder sortOrder, DocValueFormat format, int bucketSize, ExtraData extra) { + super(bigArrays, sortOrder, format, bucketSize, extra); + initGatherOffsets(); + } + + @Override + public final boolean needsScores() { + return false; + } + + @Override + protected final BigArray values() { + return values; + } + + @Override + protected final void growValues(long minSize) { + values = bigArrays.grow(values, minSize); + } + + @Override + protected final int getNextGatherOffset(long rootIndex) { + return values.get(rootIndex); + } + + @Override + protected final void setNextGatherOffset(long rootIndex, int offset) { + values.set(rootIndex, offset); + } + + @Override + protected final SortValue getValue(long index) { + return SortValue.from(values.get(index)); + } + + @Override + protected final boolean betterThan(long lhs, long rhs) { + return getOrder().reverseMul() * Integer.compare(values.get(lhs), values.get(rhs)) < 0; + } + + @Override + protected final void swap(long lhs, long rhs) { + int tmp = values.get(lhs); + values.set(lhs, values.get(rhs)); + values.set(rhs, tmp); + } + + /** + * Leaf for bucketed sort + * + * @opensearch.internal + */ + protected abstract class Leaf extends BucketedSort.Leaf { + protected Leaf(LeafReaderContext ctx) { + super(ctx); + } + + /** + * Return the value for of this sort for the document to which + * we just {@link #advanceExact(int) moved}. This should be fast + * because it is called twice per competitive hit when in heap + * mode, once for {@link #docBetterThan(long)} and once + * for {@link #setIndexToDocValue(long)}. + */ + protected abstract int docValue(); + + @Override + public final void setScorer(Scorable scorer) {} + + @Override + protected final void setIndexToDocValue(long index) { + values.set(index, docValue()); + } + + @Override + protected final boolean docBetterThan(long index) { + return getOrder().reverseMul() * Integer.compare(docValue(), values.get(index)) < 0; + } + } + } } diff --git a/server/src/main/java/org/opensearch/search/sort/SortedWiderNumericSortField.java b/server/src/main/java/org/opensearch/search/sort/SortedWiderNumericSortField.java new file mode 100644 index 0000000000000..2caacf33fcdcf --- /dev/null +++ b/server/src/main/java/org/opensearch/search/sort/SortedWiderNumericSortField.java @@ -0,0 +1,87 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +/* + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.search.sort; + +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.search.FieldComparator; +import org.apache.lucene.search.LeafFieldComparator; +import org.apache.lucene.search.SortedNumericSelector; +import org.apache.lucene.search.SortedNumericSortField; +import org.apache.lucene.search.comparators.NumericComparator; + +import java.io.IOException; + +/** + * Sorted numeric field for wider sort types, + * to help sorting two different numeric types. + * + * @opensearch.internal + */ +public class SortedWiderNumericSortField extends SortedNumericSortField { + /** + * Creates a sort, possibly in reverse, specifying how the sort value from the document's set is + * selected. + * + * @param field Name of field to sort by. Must not be null. + * @param type Type of values + * @param reverse True if natural order should be reversed. + * @param selector custom selector type for choosing the sort value from the set. + */ + public SortedWiderNumericSortField(String field, Type type, boolean reverse, SortedNumericSelector.Type selector) { + super(field, type, reverse, selector); + } + + /** + * Creates and return a comparator, which always converts Numeric to double + * and compare to support multi type comparison between numeric values + * @param numHits number of top hits the queue will store + * @param enableSkipping true if the comparator can skip documents via {@link + * LeafFieldComparator#competitiveIterator()} + * @return NumericComparator + */ + @Override + public FieldComparator getComparator(int numHits, boolean enableSkipping) { + return new NumericComparator(getField(), (Number) getMissingValue(), getReverse(), enableSkipping, Double.BYTES) { + @Override + public int compare(int slot1, int slot2) { + throw new UnsupportedOperationException(); + } + + @Override + public Number value(int slot) { + throw new UnsupportedOperationException(); + } + + @Override + public LeafFieldComparator getLeafComparator(LeafReaderContext context) throws IOException { + throw new UnsupportedOperationException(); + } + + @Override + public int compareValues(Number first, Number second) { + if (first == null) { + if (second == null) { + return 0; + } else { + return -1; + } + } else if (second == null) { + return 1; + } else { + return Double.compare(first.doubleValue(), second.doubleValue()); + } + } + }; + } +} diff --git a/server/src/test/java/org/opensearch/action/search/SearchPhaseControllerTests.java b/server/src/test/java/org/opensearch/action/search/SearchPhaseControllerTests.java index 2a6d6ee7e45bb..abcf1efe56122 100644 --- a/server/src/test/java/org/opensearch/action/search/SearchPhaseControllerTests.java +++ b/server/src/test/java/org/opensearch/action/search/SearchPhaseControllerTests.java @@ -37,6 +37,7 @@ import org.apache.lucene.search.FieldDoc; import org.apache.lucene.search.ScoreDoc; import org.apache.lucene.search.SortField; +import org.apache.lucene.search.SortedNumericSortField; import org.apache.lucene.search.TopDocs; import org.apache.lucene.search.TopFieldDocs; import org.apache.lucene.search.TotalHits; @@ -173,27 +174,70 @@ public void testSortDocs() { int nShards = randomIntBetween(1, 20); int queryResultSize = randomBoolean() ? 0 : randomIntBetween(1, nShards * 2); AtomicArray results = generateQueryResults(nShards, suggestions, queryResultSize, false); - Optional first = results.asList().stream().findFirst(); - int from = 0, size = 0; - if (first.isPresent()) { - from = first.get().queryResult().from(); - size = first.get().queryResult().size(); + performSortDocs(results, queryResultSize); + } + + /** + * Test to verify merge shard results with SortField.Type.Int, document type Integer + */ + public void testSortIntFieldDocsMerge() { + List suggestions = new ArrayList<>(); + for (int i = 0; i < randomIntBetween(1, 5); i++) { + suggestions.add(new CompletionSuggestion(randomAlphaOfLength(randomIntBetween(1, 5)), randomIntBetween(1, 20), false)); } - int accumulatedLength = Math.min(queryResultSize, getTotalQueryHits(results)); - List reducedCompletionSuggestions = reducedSuggest(results); - for (Suggest.Suggestion suggestion : reducedCompletionSuggestions) { - int suggestionSize = suggestion.getEntries().get(0).getOptions().size(); - accumulatedLength += suggestionSize; + int nShards = randomIntBetween(1, 20); + int queryResultSize = randomBoolean() ? 0 : randomIntBetween(1, nShards * 2); + AtomicArray results = generateQueryResultsWithIntSortedField(nShards, suggestions, queryResultSize, false); + performSortDocs(results, queryResultSize); + } + + /** + * Test to verify merge shard results with different SortField.Type. + * Few shards with Int and few shards with Long + */ + public void testSortIntLongFieldDocsMerge() { + List suggestions = new ArrayList<>(); + for (int i = 0; i < randomIntBetween(1, 5); i++) { + suggestions.add(new CompletionSuggestion(randomAlphaOfLength(randomIntBetween(1, 5)), randomIntBetween(1, 20), false)); } - List topDocsList = new ArrayList<>(); - for (SearchPhaseResult result : results.asList()) { - QuerySearchResult queryResult = result.queryResult(); - TopDocs topDocs = queryResult.consumeTopDocs().topDocs; - SearchPhaseController.setShardIndex(topDocs, result.getShardIndex()); - topDocsList.add(topDocs); + int nShards = randomIntBetween(1, 20); + int queryResultSize = randomBoolean() ? 0 : randomIntBetween(1, nShards * 2); + AtomicArray results = generateQueryResultsWithIntLongSortedField(nShards, suggestions, queryResultSize, false); + performSortDocs(results, queryResultSize); + } + + /** + * Test to verify merge shard results with SortField.Type.Float, document type Float + */ + public void testSortFloatFieldDocsMerge() { + List suggestions = new ArrayList<>(); + for (int i = 0; i < randomIntBetween(1, 5); i++) { + suggestions.add(new CompletionSuggestion(randomAlphaOfLength(randomIntBetween(1, 5)), randomIntBetween(1, 20), false)); } - ScoreDoc[] sortedDocs = SearchPhaseController.sortDocs(true, topDocsList, from, size, reducedCompletionSuggestions).scoreDocs; - assertThat(sortedDocs.length, equalTo(accumulatedLength)); + int nShards = randomIntBetween(1, 20); + int queryResultSize = randomBoolean() ? 0 : randomIntBetween(1, nShards * 2); + AtomicArray results = generateQueryResultsWithFloatSortedField(nShards, suggestions, queryResultSize, false); + performSortDocs(results, queryResultSize); + } + + /** + * Test to verify merge shard results with different SortField.Type. + * Few shards with Float and few shards with Double + */ + public void testSortIntFloatDoubleFieldDocsMerge() { + List suggestions = new ArrayList<>(); + for (int i = 0; i < randomIntBetween(1, 5); i++) { + suggestions.add(new CompletionSuggestion(randomAlphaOfLength(randomIntBetween(1, 5)), randomIntBetween(1, 20), false)); + } + int nShards = randomIntBetween(1, 20); + int queryResultSize = randomBoolean() ? 0 : randomIntBetween(1, nShards * 2); + AtomicArray results = generateQueryResultsWithFloatDoubleSortedField( + nShards, + suggestions, + queryResultSize, + false + ); + performSortDocs(results, queryResultSize); } public void testSortDocsIsIdempotent() throws Exception { @@ -241,6 +285,30 @@ public void testSortDocsIsIdempotent() throws Exception { } } + private static void performSortDocs(AtomicArray results, int queryResultSize) { + Optional first = results.asList().stream().findFirst(); + int from = 0, size = 0; + if (first.isPresent()) { + from = first.get().queryResult().from(); + size = first.get().queryResult().size(); + } + int accumulatedLength = Math.min(queryResultSize, getTotalQueryHits(results)); + List reducedCompletionSuggestions = reducedSuggest(results); + for (Suggest.Suggestion suggestion : reducedCompletionSuggestions) { + int suggestionSize = suggestion.getEntries().get(0).getOptions().size(); + accumulatedLength += suggestionSize; + } + List topDocsList = new ArrayList<>(); + for (SearchPhaseResult result : results.asList()) { + QuerySearchResult queryResult = result.queryResult(); + TopDocs topDocs = queryResult.consumeTopDocs().topDocs; + SearchPhaseController.setShardIndex(topDocs, result.getShardIndex()); + topDocsList.add(topDocs); + } + ScoreDoc[] sortedDocs = SearchPhaseController.sortDocs(true, topDocsList, from, size, reducedCompletionSuggestions).scoreDocs; + assertThat(sortedDocs.length, equalTo(accumulatedLength)); + } + private AtomicArray generateSeededQueryResults( long seed, int nShards, @@ -389,6 +457,128 @@ private static AtomicArray generateQueryResults( return queryResults; } + private static AtomicArray generateQueryResultsWithIntSortedField( + int nShards, + List suggestions, + int searchHitsSize, + boolean useConstantScore + ) { + AtomicArray results = generateQueryResults(nShards, suggestions, searchHitsSize, false); + for (int i = 0; i < results.length(); i++) { + int nDocs = randomIntBetween(0, searchHitsSize); + float maxScore = 0; + final TopDocs topDocs = getIntTopFieldDocs(nDocs, useConstantScore); + results.get(i).queryResult().topDocs(new TopDocsAndMaxScore(topDocs, maxScore), new DocValueFormat[1]); + } + return results; + } + + private static AtomicArray generateQueryResultsWithFloatSortedField( + int nShards, + List suggestions, + int searchHitsSize, + boolean useConstantScore + ) { + AtomicArray results = generateQueryResults(nShards, suggestions, searchHitsSize, false); + for (int i = 0; i < results.length(); i++) { + int nDocs = randomIntBetween(0, searchHitsSize); + float maxScore = 0; + final TopDocs topDocs = getFloatTopFieldDocs(nDocs, useConstantScore); + results.get(i).queryResult().topDocs(new TopDocsAndMaxScore(topDocs, maxScore), new DocValueFormat[1]); + } + return results; + } + + private static AtomicArray generateQueryResultsWithIntLongSortedField( + int nShards, + List suggestions, + int searchHitsSize, + boolean useConstantScore + ) { + AtomicArray results = generateQueryResults(nShards, suggestions, searchHitsSize, false); + for (int i = 0; i < results.length(); i++) { + int nDocs = randomIntBetween(0, searchHitsSize); + float maxScore = 0; + final TopDocs topDocs; + if (i % 2 == 0) { + topDocs = getLongTopFieldDocs(nDocs, useConstantScore); + } else { + topDocs = getIntTopFieldDocs(nDocs, useConstantScore); + } + results.get(i).queryResult().topDocs(new TopDocsAndMaxScore(topDocs, maxScore), new DocValueFormat[1]); + } + return results; + } + + private static AtomicArray generateQueryResultsWithFloatDoubleSortedField( + int nShards, + List suggestions, + int searchHitsSize, + boolean useConstantScore + ) { + AtomicArray results = generateQueryResults(nShards, suggestions, searchHitsSize, false); + for (int i = 0; i < results.length(); i++) { + int nDocs = randomIntBetween(0, searchHitsSize); + float maxScore = 0; + final TopDocs topDocs; + if (i % 2 == 0) { + topDocs = getFloatTopFieldDocs(nDocs, useConstantScore); + } else { + topDocs = getDoubleTopFieldDocs(nDocs, useConstantScore); + } + results.get(i).queryResult().topDocs(new TopDocsAndMaxScore(topDocs, maxScore), new DocValueFormat[1]); + } + return results; + } + + private static TopFieldDocs getLongTopFieldDocs(int nDocs, boolean useConstantScore) { + FieldDoc[] fieldDocs = new FieldDoc[nDocs]; + SortField[] sortFields = { new SortedNumericSortField("field", SortField.Type.LONG, true) }; + float maxScore = 0; + for (int i = 0; i < nDocs; i++) { + float score = useConstantScore ? 1.0F : Math.abs(randomFloat()); + fieldDocs[i] = new FieldDoc(i, score, new Long[] { randomLong() }); + maxScore = Math.max(score, maxScore); + } + return new TopFieldDocs(new TotalHits(fieldDocs.length, TotalHits.Relation.EQUAL_TO), fieldDocs, sortFields); + } + + private static TopFieldDocs getFloatTopFieldDocs(int nDocs, boolean useConstantScore) { + FieldDoc[] fieldDocs = new FieldDoc[nDocs]; + SortField[] sortFields = { new SortedNumericSortField("field", SortField.Type.FLOAT, true) }; + float maxScore = 0; + for (int i = 0; i < nDocs; i++) { + float score = useConstantScore ? 1.0F : Math.abs(randomFloat()); + fieldDocs[i] = new FieldDoc(i, score, new Float[] { randomFloat() }); + maxScore = Math.max(score, maxScore); + } + return new TopFieldDocs(new TotalHits(fieldDocs.length, TotalHits.Relation.EQUAL_TO), fieldDocs, sortFields); + } + + private static TopFieldDocs getDoubleTopFieldDocs(int nDocs, boolean useConstantScore) { + FieldDoc[] fieldDocs = new FieldDoc[nDocs]; + SortField[] sortFields = { new SortedNumericSortField("field", SortField.Type.DOUBLE, true) }; + float maxScore = 0; + for (int i = 0; i < nDocs; i++) { + float score = useConstantScore ? 1.0F : Math.abs(randomFloat()); + fieldDocs[i] = new FieldDoc(i, score, new Double[] { randomDouble() }); + maxScore = Math.max(score, maxScore); + } + return new TopFieldDocs(new TotalHits(fieldDocs.length, TotalHits.Relation.EQUAL_TO), fieldDocs, sortFields); + } + + private static TopFieldDocs getIntTopFieldDocs(int nDocs, boolean useConstantScore) { + FieldDoc[] fieldDocs = new FieldDoc[nDocs]; + SortField[] sortFields = { new SortedNumericSortField("field", SortField.Type.INT, true) }; + float maxScore = 0; + for (int i = 0; i < nDocs; i++) { + float score = useConstantScore ? 1.0F : Math.abs(randomFloat()); + fieldDocs[i] = new FieldDoc(i, score, new Integer[] { randomInt() }); + maxScore = Math.max(score, maxScore); + } + return new TopFieldDocs(new TotalHits(fieldDocs.length, TotalHits.Relation.EQUAL_TO), fieldDocs, sortFields); + } + private static int getTotalQueryHits(AtomicArray results) { int resultCount = 0; for (SearchPhaseResult shardResult : results.asList()) { diff --git a/server/src/test/java/org/opensearch/common/lucene/LuceneTests.java b/server/src/test/java/org/opensearch/common/lucene/LuceneTests.java index 97c192ecd9660..e7756cbd96734 100644 --- a/server/src/test/java/org/opensearch/common/lucene/LuceneTests.java +++ b/server/src/test/java/org/opensearch/common/lucene/LuceneTests.java @@ -87,6 +87,7 @@ import org.opensearch.index.fielddata.fieldcomparator.BytesRefFieldComparatorSource; import org.opensearch.index.fielddata.fieldcomparator.DoubleValuesComparatorSource; import org.opensearch.index.fielddata.fieldcomparator.FloatValuesComparatorSource; +import org.opensearch.index.fielddata.fieldcomparator.IntValuesComparatorSource; import org.opensearch.index.fielddata.fieldcomparator.LongValuesComparatorSource; import org.opensearch.search.MultiValueMode; import org.opensearch.test.OpenSearchTestCase; @@ -753,7 +754,7 @@ private static Tuple randomSortFieldCustomComparatorSource IndexFieldData.XFieldComparatorSource comparatorSource; boolean reverse = randomBoolean(); Object missingValue = null; - switch (randomIntBetween(0, 3)) { + switch (randomIntBetween(0, 4)) { case 0: comparatorSource = new LongValuesComparatorSource( null, @@ -787,6 +788,15 @@ private static Tuple randomSortFieldCustomComparatorSource ); missingValue = comparatorSource.missingValue(reverse); break; + case 4: + comparatorSource = new IntValuesComparatorSource( + null, + randomBoolean() ? randomInt() : null, + randomFrom(MultiValueMode.values()), + null + ); + missingValue = comparatorSource.missingValue(reverse); + break; default: throw new UnsupportedOperationException(); } diff --git a/server/src/test/java/org/opensearch/index/search/nested/NestedSortingTests.java b/server/src/test/java/org/opensearch/index/search/nested/NestedSortingTests.java index 79c0044a0ceb0..7f130680f83e5 100644 --- a/server/src/test/java/org/opensearch/index/search/nested/NestedSortingTests.java +++ b/server/src/test/java/org/opensearch/index/search/nested/NestedSortingTests.java @@ -645,15 +645,15 @@ public void testMultiLevelNestedSorting() throws IOException { TopFieldDocs topFields = search(queryBuilder, sortBuilder, queryShardContext, searcher); assertThat(topFields.totalHits.value, equalTo(5L)); assertThat(searcher.doc(topFields.scoreDocs[0].doc).get("_id"), equalTo("2")); - assertThat(((FieldDoc) topFields.scoreDocs[0]).fields[0], equalTo(76L)); + assertThat(((FieldDoc) topFields.scoreDocs[0]).fields[0], equalTo(76)); assertThat(searcher.doc(topFields.scoreDocs[1].doc).get("_id"), equalTo("4")); - assertThat(((FieldDoc) topFields.scoreDocs[1]).fields[0], equalTo(87L)); + assertThat(((FieldDoc) topFields.scoreDocs[1]).fields[0], equalTo(87)); assertThat(searcher.doc(topFields.scoreDocs[2].doc).get("_id"), equalTo("1")); - assertThat(((FieldDoc) topFields.scoreDocs[2]).fields[0], equalTo(234L)); + assertThat(((FieldDoc) topFields.scoreDocs[2]).fields[0], equalTo(234)); assertThat(searcher.doc(topFields.scoreDocs[3].doc).get("_id"), equalTo("3")); - assertThat(((FieldDoc) topFields.scoreDocs[3]).fields[0], equalTo(976L)); + assertThat(((FieldDoc) topFields.scoreDocs[3]).fields[0], equalTo(976)); assertThat(searcher.doc(topFields.scoreDocs[4].doc).get("_id"), equalTo("5")); - assertThat(((FieldDoc) topFields.scoreDocs[4]).fields[0], equalTo(Long.MAX_VALUE)); + assertThat(((FieldDoc) topFields.scoreDocs[4]).fields[0], equalTo(Integer.MAX_VALUE)); // Specific genre { @@ -661,25 +661,25 @@ public void testMultiLevelNestedSorting() throws IOException { topFields = search(queryBuilder, sortBuilder, queryShardContext, searcher); assertThat(topFields.totalHits.value, equalTo(1L)); assertThat(searcher.doc(topFields.scoreDocs[0].doc).get("_id"), equalTo("2")); - assertThat(((FieldDoc) topFields.scoreDocs[0]).fields[0], equalTo(76L)); + assertThat(((FieldDoc) topFields.scoreDocs[0]).fields[0], equalTo(76)); queryBuilder = new TermQueryBuilder("genre", "science fiction"); topFields = search(queryBuilder, sortBuilder, queryShardContext, searcher); assertThat(topFields.totalHits.value, equalTo(1L)); assertThat(searcher.doc(topFields.scoreDocs[0].doc).get("_id"), equalTo("1")); - assertThat(((FieldDoc) topFields.scoreDocs[0]).fields[0], equalTo(234L)); + assertThat(((FieldDoc) topFields.scoreDocs[0]).fields[0], equalTo(234)); queryBuilder = new TermQueryBuilder("genre", "horror"); topFields = search(queryBuilder, sortBuilder, queryShardContext, searcher); assertThat(topFields.totalHits.value, equalTo(1L)); assertThat(searcher.doc(topFields.scoreDocs[0].doc).get("_id"), equalTo("3")); - assertThat(((FieldDoc) topFields.scoreDocs[0]).fields[0], equalTo(976L)); + assertThat(((FieldDoc) topFields.scoreDocs[0]).fields[0], equalTo(976)); queryBuilder = new TermQueryBuilder("genre", "cooking"); topFields = search(queryBuilder, sortBuilder, queryShardContext, searcher); assertThat(topFields.totalHits.value, equalTo(1L)); assertThat(searcher.doc(topFields.scoreDocs[0].doc).get("_id"), equalTo("4")); - assertThat(((FieldDoc) topFields.scoreDocs[0]).fields[0], equalTo(87L)); + assertThat(((FieldDoc) topFields.scoreDocs[0]).fields[0], equalTo(87)); } // reverse sort order @@ -689,15 +689,15 @@ public void testMultiLevelNestedSorting() throws IOException { topFields = search(queryBuilder, sortBuilder, queryShardContext, searcher); assertThat(topFields.totalHits.value, equalTo(5L)); assertThat(searcher.doc(topFields.scoreDocs[0].doc).get("_id"), equalTo("3")); - assertThat(((FieldDoc) topFields.scoreDocs[0]).fields[0], equalTo(976L)); + assertThat(((FieldDoc) topFields.scoreDocs[0]).fields[0], equalTo(976)); assertThat(searcher.doc(topFields.scoreDocs[1].doc).get("_id"), equalTo("1")); - assertThat(((FieldDoc) topFields.scoreDocs[1]).fields[0], equalTo(849L)); + assertThat(((FieldDoc) topFields.scoreDocs[1]).fields[0], equalTo(849)); assertThat(searcher.doc(topFields.scoreDocs[2].doc).get("_id"), equalTo("4")); - assertThat(((FieldDoc) topFields.scoreDocs[2]).fields[0], equalTo(180L)); + assertThat(((FieldDoc) topFields.scoreDocs[2]).fields[0], equalTo(180)); assertThat(searcher.doc(topFields.scoreDocs[3].doc).get("_id"), equalTo("2")); - assertThat(((FieldDoc) topFields.scoreDocs[3]).fields[0], equalTo(76L)); + assertThat(((FieldDoc) topFields.scoreDocs[3]).fields[0], equalTo(76)); assertThat(searcher.doc(topFields.scoreDocs[4].doc).get("_id"), equalTo("5")); - assertThat(((FieldDoc) topFields.scoreDocs[4]).fields[0], equalTo(Long.MIN_VALUE)); + assertThat(((FieldDoc) topFields.scoreDocs[4]).fields[0], equalTo(Integer.MIN_VALUE)); } // Specific genre and reverse sort order @@ -706,25 +706,25 @@ public void testMultiLevelNestedSorting() throws IOException { topFields = search(queryBuilder, sortBuilder, queryShardContext, searcher); assertThat(topFields.totalHits.value, equalTo(1L)); assertThat(searcher.doc(topFields.scoreDocs[0].doc).get("_id"), equalTo("2")); - assertThat(((FieldDoc) topFields.scoreDocs[0]).fields[0], equalTo(76L)); + assertThat(((FieldDoc) topFields.scoreDocs[0]).fields[0], equalTo(76)); queryBuilder = new TermQueryBuilder("genre", "science fiction"); topFields = search(queryBuilder, sortBuilder, queryShardContext, searcher); assertThat(topFields.totalHits.value, equalTo(1L)); assertThat(searcher.doc(topFields.scoreDocs[0].doc).get("_id"), equalTo("1")); - assertThat(((FieldDoc) topFields.scoreDocs[0]).fields[0], equalTo(849L)); + assertThat(((FieldDoc) topFields.scoreDocs[0]).fields[0], equalTo(849)); queryBuilder = new TermQueryBuilder("genre", "horror"); topFields = search(queryBuilder, sortBuilder, queryShardContext, searcher); assertThat(topFields.totalHits.value, equalTo(1L)); assertThat(searcher.doc(topFields.scoreDocs[0].doc).get("_id"), equalTo("3")); - assertThat(((FieldDoc) topFields.scoreDocs[0]).fields[0], equalTo(976L)); + assertThat(((FieldDoc) topFields.scoreDocs[0]).fields[0], equalTo(976)); queryBuilder = new TermQueryBuilder("genre", "cooking"); topFields = search(queryBuilder, sortBuilder, queryShardContext, searcher); assertThat(topFields.totalHits.value, equalTo(1L)); assertThat(searcher.doc(topFields.scoreDocs[0].doc).get("_id"), equalTo("4")); - assertThat(((FieldDoc) topFields.scoreDocs[0]).fields[0], equalTo(180L)); + assertThat(((FieldDoc) topFields.scoreDocs[0]).fields[0], equalTo(180)); } // Nested filter + query @@ -737,17 +737,17 @@ public void testMultiLevelNestedSorting() throws IOException { topFields = search(new NestedQueryBuilder("chapters", queryBuilder, ScoreMode.None), sortBuilder, queryShardContext, searcher); assertThat(topFields.totalHits.value, equalTo(2L)); assertThat(searcher.doc(topFields.scoreDocs[0].doc).get("_id"), equalTo("2")); - assertThat(((FieldDoc) topFields.scoreDocs[0]).fields[0], equalTo(76L)); + assertThat(((FieldDoc) topFields.scoreDocs[0]).fields[0], equalTo(76)); assertThat(searcher.doc(topFields.scoreDocs[1].doc).get("_id"), equalTo("4")); - assertThat(((FieldDoc) topFields.scoreDocs[1]).fields[0], equalTo(87L)); + assertThat(((FieldDoc) topFields.scoreDocs[1]).fields[0], equalTo(87)); sortBuilder.order(SortOrder.DESC); topFields = search(new NestedQueryBuilder("chapters", queryBuilder, ScoreMode.None), sortBuilder, queryShardContext, searcher); assertThat(topFields.totalHits.value, equalTo(2L)); assertThat(searcher.doc(topFields.scoreDocs[0].doc).get("_id"), equalTo("4")); - assertThat(((FieldDoc) topFields.scoreDocs[0]).fields[0], equalTo(87L)); + assertThat(((FieldDoc) topFields.scoreDocs[0]).fields[0], equalTo(87)); assertThat(searcher.doc(topFields.scoreDocs[1].doc).get("_id"), equalTo("2")); - assertThat(((FieldDoc) topFields.scoreDocs[1]).fields[0], equalTo(76L)); + assertThat(((FieldDoc) topFields.scoreDocs[1]).fields[0], equalTo(76)); } // Multiple Nested filters + query @@ -765,17 +765,17 @@ public void testMultiLevelNestedSorting() throws IOException { topFields = search(new NestedQueryBuilder("chapters", queryBuilder, ScoreMode.None), sortBuilder, queryShardContext, searcher); assertThat(topFields.totalHits.value, equalTo(2L)); assertThat(searcher.doc(topFields.scoreDocs[0].doc).get("_id"), equalTo("4")); - assertThat(((FieldDoc) topFields.scoreDocs[0]).fields[0], equalTo(87L)); + assertThat(((FieldDoc) topFields.scoreDocs[0]).fields[0], equalTo(87)); assertThat(searcher.doc(topFields.scoreDocs[1].doc).get("_id"), equalTo("2")); - assertThat(((FieldDoc) topFields.scoreDocs[1]).fields[0], equalTo(Long.MAX_VALUE)); + assertThat(((FieldDoc) topFields.scoreDocs[1]).fields[0], equalTo(Integer.MAX_VALUE)); sortBuilder.order(SortOrder.DESC); topFields = search(new NestedQueryBuilder("chapters", queryBuilder, ScoreMode.None), sortBuilder, queryShardContext, searcher); assertThat(topFields.totalHits.value, equalTo(2L)); assertThat(searcher.doc(topFields.scoreDocs[0].doc).get("_id"), equalTo("4")); - assertThat(((FieldDoc) topFields.scoreDocs[0]).fields[0], equalTo(87L)); + assertThat(((FieldDoc) topFields.scoreDocs[0]).fields[0], equalTo(87)); assertThat(searcher.doc(topFields.scoreDocs[1].doc).get("_id"), equalTo("2")); - assertThat(((FieldDoc) topFields.scoreDocs[1]).fields[0], equalTo(Long.MIN_VALUE)); + assertThat(((FieldDoc) topFields.scoreDocs[1]).fields[0], equalTo(Integer.MIN_VALUE)); } // Nested filter + Specific genre @@ -790,25 +790,25 @@ public void testMultiLevelNestedSorting() throws IOException { topFields = search(queryBuilder, sortBuilder, queryShardContext, searcher); assertThat(topFields.totalHits.value, equalTo(1L)); assertThat(searcher.doc(topFields.scoreDocs[0].doc).get("_id"), equalTo("2")); - assertThat(((FieldDoc) topFields.scoreDocs[0]).fields[0], equalTo(76L)); + assertThat(((FieldDoc) topFields.scoreDocs[0]).fields[0], equalTo(76)); queryBuilder = new TermQueryBuilder("genre", "science fiction"); topFields = search(queryBuilder, sortBuilder, queryShardContext, searcher); assertThat(topFields.totalHits.value, equalTo(1L)); assertThat(searcher.doc(topFields.scoreDocs[0].doc).get("_id"), equalTo("1")); - assertThat(((FieldDoc) topFields.scoreDocs[0]).fields[0], equalTo(Long.MAX_VALUE)); + assertThat(((FieldDoc) topFields.scoreDocs[0]).fields[0], equalTo(Integer.MAX_VALUE)); queryBuilder = new TermQueryBuilder("genre", "horror"); topFields = search(queryBuilder, sortBuilder, queryShardContext, searcher); assertThat(topFields.totalHits.value, equalTo(1L)); assertThat(searcher.doc(topFields.scoreDocs[0].doc).get("_id"), equalTo("3")); - assertThat(((FieldDoc) topFields.scoreDocs[0]).fields[0], equalTo(Long.MAX_VALUE)); + assertThat(((FieldDoc) topFields.scoreDocs[0]).fields[0], equalTo(Integer.MAX_VALUE)); queryBuilder = new TermQueryBuilder("genre", "cooking"); topFields = search(queryBuilder, sortBuilder, queryShardContext, searcher); assertThat(topFields.totalHits.value, equalTo(1L)); assertThat(searcher.doc(topFields.scoreDocs[0].doc).get("_id"), equalTo("4")); - assertThat(((FieldDoc) topFields.scoreDocs[0]).fields[0], equalTo(87L)); + assertThat(((FieldDoc) topFields.scoreDocs[0]).fields[0], equalTo(87)); } searcher.getIndexReader().close(); diff --git a/server/src/test/java/org/opensearch/search/sort/BucketedSortForIntsTests.java b/server/src/test/java/org/opensearch/search/sort/BucketedSortForIntsTests.java new file mode 100644 index 0000000000000..c3e0475685001 --- /dev/null +++ b/server/src/test/java/org/opensearch/search/sort/BucketedSortForIntsTests.java @@ -0,0 +1,77 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/* + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.search.sort; + +import org.apache.lucene.index.LeafReaderContext; +import org.opensearch.search.DocValueFormat; + +public class BucketedSortForIntsTests extends BucketedSortTestCase { + @Override + public BucketedSort.ForInts build( + SortOrder sortOrder, + DocValueFormat format, + int bucketSize, + BucketedSort.ExtraData extra, + double[] values + ) { + return new BucketedSort.ForInts(bigArrays(), sortOrder, format, bucketSize, extra) { + @Override + public Leaf forLeaf(LeafReaderContext ctx) { + return new Leaf(ctx) { + int index = -1; + + @Override + protected boolean advanceExact(int doc) { + index = doc; + return doc < values.length; + } + + @Override + protected int docValue() { + return (int) values[index]; + } + }; + } + }; + } + + @Override + protected SortValue expectedSortValue(double v) { + return SortValue.from((long) v); + } + + @Override + protected double randomValue() { + return randomIntBetween(Integer.MIN_VALUE, Integer.MAX_VALUE); + } +} diff --git a/server/src/test/java/org/opensearch/search/sort/FieldSortBuilderTests.java b/server/src/test/java/org/opensearch/search/sort/FieldSortBuilderTests.java index 1d422f740c555..5bf60c50a1ab2 100644 --- a/server/src/test/java/org/opensearch/search/sort/FieldSortBuilderTests.java +++ b/server/src/test/java/org/opensearch/search/sort/FieldSortBuilderTests.java @@ -504,7 +504,7 @@ public void testGetMaxNumericSortValue() throws IOException { case INTEGER: int v2 = randomInt(); - values[i] = (long) v2; + values[i] = (int) v2; doc.add(new IntPoint(fieldName, v2)); break; @@ -528,13 +528,13 @@ public void testGetMaxNumericSortValue() throws IOException { case BYTE: byte v6 = randomByte(); - values[i] = (long) v6; + values[i] = (int) v6; doc.add(new IntPoint(fieldName, v6)); break; case SHORT: short v7 = randomShort(); - values[i] = (long) v7; + values[i] = (int) v7; doc.add(new IntPoint(fieldName, v7)); break;