From 8795cef3d13854d6219849ef7479bd19920896aa Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Wed, 4 Sep 2024 22:08:11 +0000 Subject: [PATCH] Introduce ApproximateRangeQuery and ApproximateQuery (#13788) (#15586) This introduces a basic "approximation" framework that improves query performance by modifying the query in a way that should be functionally equivalent. To start, we can reduce the bounds of a range query in order to satisfy the `track_total_hits` value (which defaults to 10,000). --------- Signed-off-by: Harsha Vamsi Kalluri Signed-off-by: Michael Froh Co-authored-by: Michael Froh (cherry picked from commit 2e9db40a50735eacc95a4fc8926e8bb7042a696a) (cherry picked from commit 3ddb199a77b73364cce725a8dcf594ab572b3d2a) Signed-off-by: github-actions[bot] --- CHANGELOG.md | 1 + .../test/search/370_approximate_range.yml | 72 +++ .../opensearch/common/util/FeatureFlags.java | 10 + .../index/mapper/DateFieldMapper.java | 32 +- .../bucket/filterrewrite/Helper.java | 2 + .../ApproximateIndexOrDocValuesQuery.java | 62 +++ .../ApproximatePointRangeQuery.java | 515 ++++++++++++++++++ .../search/approximate/ApproximateQuery.java | 21 + .../approximate/ApproximateScoreQuery.java | 90 +++ .../search/approximate/package-info.java | 12 + .../search/internal/ContextIndexSearcher.java | 4 + .../index/mapper/DateFieldTypeTests.java | 84 ++- ...angeFieldQueryStringQueryBuilderTests.java | 39 +- .../index/mapper/RangeFieldTypeTests.java | 11 +- .../query/MatchPhraseQueryBuilderTests.java | 2 + .../query/QueryStringQueryBuilderTests.java | 34 +- .../index/query/RangeQueryBuilderTests.java | 147 +++-- ...ApproximateIndexOrDocValuesQueryTests.java | 113 ++++ .../ApproximatePointRangeQueryTests.java | 346 ++++++++++++ .../ApproximateScoreQueryTests.java | 83 +++ 20 files changed, 1622 insertions(+), 58 deletions(-) create mode 100644 rest-api-spec/src/main/resources/rest-api-spec/test/search/370_approximate_range.yml create mode 100644 server/src/main/java/org/opensearch/search/approximate/ApproximateIndexOrDocValuesQuery.java create mode 100644 server/src/main/java/org/opensearch/search/approximate/ApproximatePointRangeQuery.java create mode 100644 server/src/main/java/org/opensearch/search/approximate/ApproximateQuery.java create mode 100644 server/src/main/java/org/opensearch/search/approximate/ApproximateScoreQuery.java create mode 100644 server/src/main/java/org/opensearch/search/approximate/package-info.java create mode 100644 server/src/test/java/org/opensearch/search/approximate/ApproximateIndexOrDocValuesQueryTests.java create mode 100644 server/src/test/java/org/opensearch/search/approximate/ApproximatePointRangeQueryTests.java create mode 100644 server/src/test/java/org/opensearch/search/approximate/ApproximateScoreQueryTests.java diff --git a/CHANGELOG.md b/CHANGELOG.md index 18ab4c09c56a3..3101e5553bfa4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -36,6 +36,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), - Add runAs to Subject interface and introduce IdentityAwarePlugin extension point ([#14630](https://github.com/opensearch-project/OpenSearch/pull/14630)) - [Workload Management] Add rejection logic for co-ordinator and shard level requests ([#15428](https://github.com/opensearch-project/OpenSearch/pull/15428))) - Adding translog durability validation in index templates ([#15494](https://github.com/opensearch-project/OpenSearch/pull/15494)) +- [Range Queries] Add new approximateable query framework to short-circuit range queries ([#13788](https://github.com/opensearch-project/OpenSearch/pull/13788)) - [Workload Management] Add query group level failure tracking ([#15227](https://github.com/opensearch-project/OpenSearch/pull/15527)) - [Reader Writer Separation] Add searchOnly replica routing configuration ([#15410](https://github.com/opensearch-project/OpenSearch/pull/15410)) - Add index creation using the context field ([#15290](https://github.com/opensearch-project/OpenSearch/pull/15290)) diff --git a/rest-api-spec/src/main/resources/rest-api-spec/test/search/370_approximate_range.yml b/rest-api-spec/src/main/resources/rest-api-spec/test/search/370_approximate_range.yml new file mode 100644 index 0000000000000..ba896dfcad506 --- /dev/null +++ b/rest-api-spec/src/main/resources/rest-api-spec/test/search/370_approximate_range.yml @@ -0,0 +1,72 @@ +--- +"search with approximate range": + - do: + indices.create: + index: test + body: + mappings: + properties: + date: + type: date + index: true + doc_values: true + + - do: + bulk: + index: test + refresh: true + body: + - '{"index": {"_index": "test", "_id": "1" }}' + - '{ "date": "2018-10-29T12:12:12.987Z" }' + - '{ "index": { "_index": "test", "_id": "2" }}' + - '{ "date": "2020-10-29T12:12:12.987Z" }' + - '{ "index": { "_index": "test", "_id": "3" } }' + - '{ "date": "2024-10-29T12:12:12.987Z" }' + + - do: + search: + rest_total_hits_as_int: true + index: test + body: + query: + range: { + date: { + gte: "2018-10-29T12:12:12.987Z" + }, + } + + - match: { hits.total: 3 } + + - do: + search: + rest_total_hits_as_int: true + index: test + body: + sort: [{ date: asc }] + query: + range: { + date: { + gte: "2018-10-29T12:12:12.987Z" + }, + } + + + - match: { hits.total: 3 } + - match: { hits.hits.0._id: "1" } + + - do: + search: + rest_total_hits_as_int: true + index: test + body: + sort: [{ date: desc }] + query: + range: { + date: { + gte: "2018-10-29T12:12:12.987Z", + lte: "2020-10-29T12:12:12.987Z" + }, + } + + - match: { hits.total: 2 } + - match: { hits.hits.0._id: "2" } diff --git a/server/src/main/java/org/opensearch/common/util/FeatureFlags.java b/server/src/main/java/org/opensearch/common/util/FeatureFlags.java index f391547b2055e..0ef2e773a690b 100644 --- a/server/src/main/java/org/opensearch/common/util/FeatureFlags.java +++ b/server/src/main/java/org/opensearch/common/util/FeatureFlags.java @@ -125,6 +125,16 @@ public class FeatureFlags { public static final String STAR_TREE_INDEX = "opensearch.experimental.feature.composite_index.star_tree.enabled"; public static final Setting STAR_TREE_INDEX_SETTING = Setting.boolSetting(STAR_TREE_INDEX, false, Property.NodeScope); + /** + * Gates the functionality of ApproximatePointRangeQuery where we approximate query results. + */ + public static final String APPROXIMATE_POINT_RANGE_QUERY = "opensearch.experimental.feature.approximate_point_range_query.enabled"; + public static final Setting APPROXIMATE_POINT_RANGE_QUERY_SETTING = Setting.boolSetting( + APPROXIMATE_POINT_RANGE_QUERY, + false, + Property.NodeScope + ); + private static final List> ALL_FEATURE_FLAG_SETTINGS = List.of( REMOTE_STORE_MIGRATION_EXPERIMENTAL_SETTING, EXTENSIONS_SETTING, diff --git a/server/src/main/java/org/opensearch/index/mapper/DateFieldMapper.java b/server/src/main/java/org/opensearch/index/mapper/DateFieldMapper.java index ce21c739e8c98..5a1c2baa12086 100644 --- a/server/src/main/java/org/opensearch/index/mapper/DateFieldMapper.java +++ b/server/src/main/java/org/opensearch/index/mapper/DateFieldMapper.java @@ -62,6 +62,8 @@ import org.opensearch.index.query.QueryRewriteContext; import org.opensearch.index.query.QueryShardContext; import org.opensearch.search.DocValueFormat; +import org.opensearch.search.approximate.ApproximateIndexOrDocValuesQuery; +import org.opensearch.search.approximate.ApproximatePointRangeQuery; import org.opensearch.search.lookup.SearchLookup; import java.io.IOException; @@ -81,6 +83,7 @@ import java.util.function.Supplier; import static org.opensearch.common.time.DateUtils.toLong; +import static org.apache.lucene.document.LongPoint.pack; /** * A {@link FieldMapper} for dates. @@ -109,6 +112,21 @@ public static DateFormatter getDefaultDateTimeFormatter() { : LEGACY_DEFAULT_DATE_TIME_FORMATTER; } + public static Query getDefaultQuery(Query pointRangeQuery, Query dvQuery, String name, long l, long u) { + return FeatureFlags.isEnabled(FeatureFlags.APPROXIMATE_POINT_RANGE_QUERY_SETTING) + ? new ApproximateIndexOrDocValuesQuery( + pointRangeQuery, + new ApproximatePointRangeQuery(name, pack(new long[] { l }).bytes, pack(new long[] { u }).bytes, new long[] { l }.length) { + @Override + protected String toString(int dimension, byte[] value) { + return Long.toString(LongPoint.decodeDimension(value, 0)); + } + }, + dvQuery + ) + : new IndexOrDocValuesQuery(pointRangeQuery, dvQuery); + } + /** * Resolution of the date time * @@ -468,24 +486,22 @@ public Query rangeQuery( } DateMathParser parser = forcedDateParser == null ? dateMathParser : forcedDateParser; return dateRangeQuery(lowerTerm, upperTerm, includeLower, includeUpper, timeZone, parser, context, resolution, (l, u) -> { + Query pointRangeQuery = isSearchable() ? LongPoint.newRangeQuery(name(), l, u) : null; + Query dvQuery = hasDocValues() ? SortedNumericDocValuesField.newSlowRangeQuery(name(), l, u) : null; if (isSearchable() && hasDocValues()) { - Query query = LongPoint.newRangeQuery(name(), l, u); - Query dvQuery = SortedNumericDocValuesField.newSlowRangeQuery(name(), l, u); - query = new IndexOrDocValuesQuery(query, dvQuery); - + Query query = getDefaultQuery(pointRangeQuery, dvQuery, name(), l, u); if (context.indexSortedOnField(name())) { query = new IndexSortSortedNumericDocValuesRangeQuery(name(), l, u, query); } return query; } if (hasDocValues()) { - Query query = SortedNumericDocValuesField.newSlowRangeQuery(name(), l, u); if (context.indexSortedOnField(name())) { - query = new IndexSortSortedNumericDocValuesRangeQuery(name(), l, u, query); + dvQuery = new IndexSortSortedNumericDocValuesRangeQuery(name(), l, u, dvQuery); } - return query; + return dvQuery; } - return LongPoint.newRangeQuery(name(), l, u); + return pointRangeQuery; }); } diff --git a/server/src/main/java/org/opensearch/search/aggregations/bucket/filterrewrite/Helper.java b/server/src/main/java/org/opensearch/search/aggregations/bucket/filterrewrite/Helper.java index 7493754d8efa2..17da7e5712be8 100644 --- a/server/src/main/java/org/opensearch/search/aggregations/bucket/filterrewrite/Helper.java +++ b/server/src/main/java/org/opensearch/search/aggregations/bucket/filterrewrite/Helper.java @@ -23,6 +23,7 @@ import org.opensearch.common.lucene.search.function.FunctionScoreQuery; import org.opensearch.index.mapper.DateFieldMapper; import org.opensearch.index.query.DateRangeIncludingNowQuery; +import org.opensearch.search.approximate.ApproximateIndexOrDocValuesQuery; import org.opensearch.search.internal.SearchContext; import java.io.IOException; @@ -54,6 +55,7 @@ private Helper() {} queryWrappers.put(FunctionScoreQuery.class, q -> ((FunctionScoreQuery) q).getSubQuery()); queryWrappers.put(DateRangeIncludingNowQuery.class, q -> ((DateRangeIncludingNowQuery) q).getQuery()); queryWrappers.put(IndexOrDocValuesQuery.class, q -> ((IndexOrDocValuesQuery) q).getIndexQuery()); + queryWrappers.put(ApproximateIndexOrDocValuesQuery.class, q -> ((ApproximateIndexOrDocValuesQuery) q).getOriginalQuery()); } /** diff --git a/server/src/main/java/org/opensearch/search/approximate/ApproximateIndexOrDocValuesQuery.java b/server/src/main/java/org/opensearch/search/approximate/ApproximateIndexOrDocValuesQuery.java new file mode 100644 index 0000000000000..b99e0a0cbf808 --- /dev/null +++ b/server/src/main/java/org/opensearch/search/approximate/ApproximateIndexOrDocValuesQuery.java @@ -0,0 +1,62 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.search.approximate; + +import org.apache.lucene.search.IndexOrDocValuesQuery; +import org.apache.lucene.search.Query; +import org.apache.lucene.search.QueryVisitor; + +/** + * A wrapper around {@link IndexOrDocValuesQuery} that can be used to run approximate queries. + * It delegates to either {@link ApproximateQuery} or {@link IndexOrDocValuesQuery} based on whether the query can be approximated or not. + * @see ApproximateQuery + */ +public final class ApproximateIndexOrDocValuesQuery extends ApproximateScoreQuery { + + private final ApproximateQuery approximateIndexQuery; + private final IndexOrDocValuesQuery indexOrDocValuesQuery; + + public ApproximateIndexOrDocValuesQuery(Query indexQuery, ApproximateQuery approximateIndexQuery, Query dvQuery) { + super(new IndexOrDocValuesQuery(indexQuery, dvQuery), approximateIndexQuery); + this.approximateIndexQuery = approximateIndexQuery; + this.indexOrDocValuesQuery = new IndexOrDocValuesQuery(indexQuery, dvQuery); + } + + @Override + public String toString(String field) { + return "ApproximateIndexOrDocValuesQuery(indexQuery=" + + indexOrDocValuesQuery.getIndexQuery().toString(field) + + ", approximateIndexQuery=" + + approximateIndexQuery.toString(field) + + ", dvQuery=" + + indexOrDocValuesQuery.getRandomAccessQuery().toString(field) + + ")"; + } + + @Override + public void visit(QueryVisitor visitor) { + indexOrDocValuesQuery.visit(visitor); + } + + @Override + public boolean equals(Object obj) { + if (sameClassAs(obj) == false) { + return false; + } + return true; + } + + @Override + public int hashCode() { + int h = classHash(); + h = 31 * h + indexOrDocValuesQuery.getIndexQuery().hashCode(); + h = 31 * h + indexOrDocValuesQuery.getRandomAccessQuery().hashCode(); + return h; + } +} diff --git a/server/src/main/java/org/opensearch/search/approximate/ApproximatePointRangeQuery.java b/server/src/main/java/org/opensearch/search/approximate/ApproximatePointRangeQuery.java new file mode 100644 index 0000000000000..cee8bc43d7ffd --- /dev/null +++ b/server/src/main/java/org/opensearch/search/approximate/ApproximatePointRangeQuery.java @@ -0,0 +1,515 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.search.approximate; + +import org.apache.lucene.index.LeafReader; +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.index.PointValues; +import org.apache.lucene.search.ConstantScoreScorer; +import org.apache.lucene.search.ConstantScoreWeight; +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.PointRangeQuery; +import org.apache.lucene.search.QueryVisitor; +import org.apache.lucene.search.ScoreMode; +import org.apache.lucene.search.Scorer; +import org.apache.lucene.search.ScorerSupplier; +import org.apache.lucene.search.Weight; +import org.apache.lucene.util.ArrayUtil; +import org.apache.lucene.util.DocIdSetBuilder; +import org.apache.lucene.util.IntsRef; +import org.opensearch.index.query.RangeQueryBuilder; +import org.opensearch.search.internal.SearchContext; +import org.opensearch.search.sort.FieldSortBuilder; +import org.opensearch.search.sort.SortOrder; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Objects; + +/** + * An approximate-able version of {@link PointRangeQuery}. It creates an instance of {@link PointRangeQuery} but short-circuits the intersect logic + * after {@code size} is hit + */ +public abstract class ApproximatePointRangeQuery extends ApproximateQuery { + private int size; + + private SortOrder sortOrder; + + public final PointRangeQuery pointRangeQuery; + + protected ApproximatePointRangeQuery(String field, byte[] lowerPoint, byte[] upperPoint, int numDims) { + this(field, lowerPoint, upperPoint, numDims, 10_000, null); + } + + protected ApproximatePointRangeQuery(String field, byte[] lowerPoint, byte[] upperPoint, int numDims, int size) { + this(field, lowerPoint, upperPoint, numDims, size, null); + } + + protected ApproximatePointRangeQuery(String field, byte[] lowerPoint, byte[] upperPoint, int numDims, int size, SortOrder sortOrder) { + this.size = size; + this.sortOrder = sortOrder; + this.pointRangeQuery = new PointRangeQuery(field, lowerPoint, upperPoint, numDims) { + @Override + protected String toString(int dimension, byte[] value) { + return super.toString(field); + } + }; + } + + public int getSize() { + return this.size; + } + + public void setSize(int size) { + this.size = size; + } + + public SortOrder getSortOrder() { + return this.sortOrder; + } + + public void setSortOrder(SortOrder sortOrder) { + this.sortOrder = sortOrder; + } + + @Override + public void visit(QueryVisitor visitor) { + pointRangeQuery.visit(visitor); + } + + @Override + public final ConstantScoreWeight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) throws IOException { + Weight pointRangeQueryWeight = pointRangeQuery.createWeight(searcher, scoreMode, boost); + + return new ConstantScoreWeight(this, boost) { + + private final ArrayUtil.ByteArrayComparator comparator = ArrayUtil.getUnsignedComparator(pointRangeQuery.getBytesPerDim()); + + // we pull this from PointRangeQuery since it is final + private boolean matches(byte[] packedValue) { + for (int dim = 0; dim < pointRangeQuery.getNumDims(); dim++) { + int offset = dim * pointRangeQuery.getBytesPerDim(); + if (comparator.compare(packedValue, offset, pointRangeQuery.getLowerPoint(), offset) < 0) { + // Doc's value is too low, in this dimension + return false; + } + if (comparator.compare(packedValue, offset, pointRangeQuery.getUpperPoint(), offset) > 0) { + // Doc's value is too high, in this dimension + return false; + } + } + return true; + } + + // we pull this from PointRangeQuery since it is final + private PointValues.Relation relate(byte[] minPackedValue, byte[] maxPackedValue) { + + boolean crosses = false; + + for (int dim = 0; dim < pointRangeQuery.getNumDims(); dim++) { + int offset = dim * pointRangeQuery.getBytesPerDim(); + + if (comparator.compare(minPackedValue, offset, pointRangeQuery.getUpperPoint(), offset) > 0 + || comparator.compare(maxPackedValue, offset, pointRangeQuery.getLowerPoint(), offset) < 0) { + return PointValues.Relation.CELL_OUTSIDE_QUERY; + } + + crosses |= comparator.compare(minPackedValue, offset, pointRangeQuery.getLowerPoint(), offset) < 0 + || comparator.compare(maxPackedValue, offset, pointRangeQuery.getUpperPoint(), offset) > 0; + } + + if (crosses) { + return PointValues.Relation.CELL_CROSSES_QUERY; + } else { + return PointValues.Relation.CELL_INSIDE_QUERY; + } + } + + public PointValues.IntersectVisitor getIntersectVisitor(DocIdSetBuilder result, long[] docCount) { + return new PointValues.IntersectVisitor() { + + DocIdSetBuilder.BulkAdder adder; + + @Override + public void grow(int count) { + adder = result.grow(count); + } + + @Override + public void visit(int docID) { + // it is possible that size < 1024 and docCount < size but we will continue to count through all the 1024 docs + // and collect less, but it won't hurt performance + if (docCount[0] < size) { + adder.add(docID); + docCount[0]++; + } + } + + @Override + public void visit(DocIdSetIterator iterator) throws IOException { + adder.add(iterator); + } + + @Override + public void visit(IntsRef ref) { + for (int i = 0; i < ref.length; i++) { + adder.add(ref.ints[ref.offset + i]); + } + } + + @Override + public void visit(int docID, byte[] packedValue) { + if (matches(packedValue)) { + visit(docID); + } + } + + @Override + public void visit(DocIdSetIterator iterator, byte[] packedValue) throws IOException { + if (matches(packedValue)) { + adder.add(iterator); + } + } + + @Override + public PointValues.Relation compare(byte[] minPackedValue, byte[] maxPackedValue) { + return relate(minPackedValue, maxPackedValue); + } + }; + } + + // we pull this from PointRangeQuery since it is final + private boolean checkValidPointValues(PointValues values) throws IOException { + if (values == null) { + // No docs in this segment/field indexed any points + return false; + } + + if (values.getNumIndexDimensions() != pointRangeQuery.getNumDims()) { + throw new IllegalArgumentException( + "field=\"" + + pointRangeQuery.getField() + + "\" was indexed with numIndexDimensions=" + + values.getNumIndexDimensions() + + " but this query has numDims=" + + pointRangeQuery.getNumDims() + ); + } + if (pointRangeQuery.getBytesPerDim() != values.getBytesPerDimension()) { + throw new IllegalArgumentException( + "field=\"" + + pointRangeQuery.getField() + + "\" was indexed with bytesPerDim=" + + values.getBytesPerDimension() + + " but this query has bytesPerDim=" + + pointRangeQuery.getBytesPerDim() + ); + } + return true; + } + + private void intersectLeft(PointValues.PointTree pointTree, PointValues.IntersectVisitor visitor, long[] docCount) + throws IOException { + intersectLeft(visitor, pointTree, docCount); + assert pointTree.moveToParent() == false; + } + + private void intersectRight(PointValues.PointTree pointTree, PointValues.IntersectVisitor visitor, long[] docCount) + throws IOException { + intersectRight(visitor, pointTree, docCount); + assert pointTree.moveToParent() == false; + } + + // custom intersect visitor to walk the left of the tree + public void intersectLeft(PointValues.IntersectVisitor visitor, PointValues.PointTree pointTree, long[] docCount) + throws IOException { + PointValues.Relation r = visitor.compare(pointTree.getMinPackedValue(), pointTree.getMaxPackedValue()); + if (docCount[0] > size) { + return; + } + switch (r) { + case CELL_OUTSIDE_QUERY: + // This cell is fully outside the query shape: stop recursing + break; + case CELL_INSIDE_QUERY: + // If the cell is fully inside, we keep moving to child until we reach a point where we can no longer move or when + // we have sufficient doc count. We first move down and then move to the left child + if (pointTree.moveToChild() && docCount[0] < size) { + do { + intersectLeft(visitor, pointTree, docCount); + } while (pointTree.moveToSibling() && docCount[0] < size); + pointTree.moveToParent(); + } else { + // we're at the leaf node, if we're under the size, visit all the docIds in this node. + if (docCount[0] < size) { + pointTree.visitDocIDs(visitor); + } + } + break; + case CELL_CROSSES_QUERY: + // The cell crosses the shape boundary, or the cell fully contains the query, so we fall + // through and do full filtering: + if (pointTree.moveToChild() && docCount[0] < size) { + do { + intersectLeft(visitor, pointTree, docCount); + } while (pointTree.moveToSibling() && docCount[0] < size); + pointTree.moveToParent(); + } else { + // TODO: we can assert that the first value here in fact matches what the pointTree + // claimed? + // Leaf node; scan and filter all points in this block: + if (docCount[0] < size) { + pointTree.visitDocValues(visitor); + } + } + break; + default: + throw new IllegalArgumentException("Unreachable code"); + } + } + + // custom intersect visitor to walk the right of tree + public void intersectRight(PointValues.IntersectVisitor visitor, PointValues.PointTree pointTree, long[] docCount) + throws IOException { + PointValues.Relation r = visitor.compare(pointTree.getMinPackedValue(), pointTree.getMaxPackedValue()); + if (docCount[0] > size) { + return; + } + switch (r) { + case CELL_OUTSIDE_QUERY: + // This cell is fully outside the query shape: stop recursing + break; + + case CELL_INSIDE_QUERY: + // If the cell is fully inside, we keep moving right as long as the point tree size is over our size requirement + if (pointTree.size() > size && docCount[0] < size && moveRight(pointTree)) { + intersectRight(visitor, pointTree, docCount); + pointTree.moveToParent(); + } + // if point tree size is no longer over, we have to go back one level where it still was over and the intersect left + else if (pointTree.size() <= size && docCount[0] < size) { + pointTree.moveToParent(); + intersectLeft(visitor, pointTree, docCount); + } + // if we've reached leaf, it means out size is under the size of the leaf, we can just collect all docIDs + else { + // Leaf node; scan and filter all points in this block: + if (docCount[0] < size) { + pointTree.visitDocIDs(visitor); + } + } + break; + case CELL_CROSSES_QUERY: + // If the cell is fully inside, we keep moving right as long as the point tree size is over our size requirement + if (pointTree.size() > size && docCount[0] < size && moveRight(pointTree)) { + intersectRight(visitor, pointTree, docCount); + pointTree.moveToParent(); + } + // if point tree size is no longer over, we have to go back one level where it still was over and the intersect left + else if (pointTree.size() <= size && docCount[0] < size) { + pointTree.moveToParent(); + intersectLeft(visitor, pointTree, docCount); + } + // if we've reached leaf, it means out size is under the size of the leaf, we can just collect all doc values + else { + // Leaf node; scan and filter all points in this block: + if (docCount[0] < size) { + pointTree.visitDocValues(visitor); + } + } + break; + default: + throw new IllegalArgumentException("Unreachable code"); + } + } + + public boolean moveRight(PointValues.PointTree pointTree) throws IOException { + return pointTree.moveToChild() && pointTree.moveToSibling(); + } + + @Override + public ScorerSupplier scorerSupplier(LeafReaderContext context) throws IOException { + LeafReader reader = context.reader(); + long[] docCount = { 0 }; + + PointValues values = reader.getPointValues(pointRangeQuery.getField()); + if (checkValidPointValues(values) == false) { + return null; + } + final Weight weight = this; + if (size > values.size()) { + return pointRangeQueryWeight.scorerSupplier(context); + } else { + if (sortOrder == null || sortOrder.equals(SortOrder.ASC)) { + return new ScorerSupplier() { + + final DocIdSetBuilder result = new DocIdSetBuilder(reader.maxDoc(), values, pointRangeQuery.getField()); + final PointValues.IntersectVisitor visitor = getIntersectVisitor(result, docCount); + long cost = -1; + + @Override + public Scorer get(long leadCost) throws IOException { + intersectLeft(values.getPointTree(), visitor, docCount); + DocIdSetIterator iterator = result.build().iterator(); + return new ConstantScoreScorer(weight, score(), scoreMode, iterator); + } + + @Override + public long cost() { + if (cost == -1) { + // Computing the cost may be expensive, so only do it if necessary + cost = values.estimateDocCount(visitor); + assert cost >= 0; + } + return cost; + } + }; + } else { + // we need to fetch size + deleted docs since the collector will prune away deleted docs resulting in fewer results + // than expected + final int deletedDocs = reader.numDeletedDocs(); + size += deletedDocs; + return new ScorerSupplier() { + + final DocIdSetBuilder result = new DocIdSetBuilder(reader.maxDoc(), values, pointRangeQuery.getField()); + final PointValues.IntersectVisitor visitor = getIntersectVisitor(result, docCount); + long cost = -1; + + @Override + public Scorer get(long leadCost) throws IOException { + intersectRight(values.getPointTree(), visitor, docCount); + DocIdSetIterator iterator = result.build().iterator(); + return new ConstantScoreScorer(weight, score(), scoreMode, iterator); + } + + @Override + public long cost() { + if (cost == -1) { + // Computing the cost may be expensive, so only do it if necessary + cost = values.estimateDocCount(visitor); + assert cost >= 0; + } + return cost; + } + }; + } + } + } + + @Override + public Scorer scorer(LeafReaderContext context) throws IOException { + ScorerSupplier scorerSupplier = scorerSupplier(context); + if (scorerSupplier == null) { + return null; + } + return scorerSupplier.get(Long.MAX_VALUE); + } + + @Override + public int count(LeafReaderContext context) throws IOException { + return pointRangeQueryWeight.count(context); + } + + @Override + public boolean isCacheable(LeafReaderContext ctx) { + return false; + } + }; + } + + @Override + public boolean canApproximate(SearchContext context) { + if (context == null) { + return false; + } + if (context.aggregations() != null) { + return false; + } + if (!(context.query() instanceof ApproximateIndexOrDocValuesQuery)) { + return false; + } + this.setSize(Math.max(context.from() + context.size(), context.trackTotalHitsUpTo())); + if (context.request() != null && context.request().source() != null) { + FieldSortBuilder primarySortField = FieldSortBuilder.getPrimaryFieldSortOrNull(context.request().source()); + if (primarySortField != null + && primarySortField.missing() == null + && primarySortField.getFieldName().equals(((RangeQueryBuilder) context.request().source().query()).fieldName())) { + if (primarySortField.order() == SortOrder.DESC) { + this.setSortOrder(SortOrder.DESC); + } + } + } + return true; + } + + @Override + public final int hashCode() { + return pointRangeQuery.hashCode(); + } + + @Override + public final boolean equals(Object o) { + return sameClassAs(o) && equalsTo(getClass().cast(o)); + } + + private boolean equalsTo(ApproximatePointRangeQuery other) { + return Objects.equals(pointRangeQuery.getField(), other.pointRangeQuery.getField()) + && pointRangeQuery.getNumDims() == other.pointRangeQuery.getNumDims() + && pointRangeQuery.getBytesPerDim() == other.pointRangeQuery.getBytesPerDim() + && Arrays.equals(pointRangeQuery.getLowerPoint(), other.pointRangeQuery.getLowerPoint()) + && Arrays.equals(pointRangeQuery.getUpperPoint(), other.pointRangeQuery.getUpperPoint()); + } + + @Override + public final String toString(String field) { + final StringBuilder sb = new StringBuilder(); + if (pointRangeQuery.getField().equals(field) == false) { + sb.append(pointRangeQuery.getField()); + sb.append(':'); + } + + // print ourselves as "range per dimension" + for (int i = 0; i < pointRangeQuery.getNumDims(); i++) { + if (i > 0) { + sb.append(','); + } + + int startOffset = pointRangeQuery.getBytesPerDim() * i; + + sb.append('['); + sb.append( + toString( + i, + ArrayUtil.copyOfSubArray(pointRangeQuery.getLowerPoint(), startOffset, startOffset + pointRangeQuery.getBytesPerDim()) + ) + ); + sb.append(" TO "); + sb.append( + toString( + i, + ArrayUtil.copyOfSubArray(pointRangeQuery.getUpperPoint(), startOffset, startOffset + pointRangeQuery.getBytesPerDim()) + ) + ); + sb.append(']'); + } + + return sb.toString(); + } + + /** + * Returns a string of a single value in a human-readable format for debugging. This is used by + * {@link #toString()}. + * + * @param dimension dimension of the particular value + * @param value single value, never null + * @return human readable value for debugging + */ + protected abstract String toString(int dimension, byte[] value); +} diff --git a/server/src/main/java/org/opensearch/search/approximate/ApproximateQuery.java b/server/src/main/java/org/opensearch/search/approximate/ApproximateQuery.java new file mode 100644 index 0000000000000..0e6faf396b671 --- /dev/null +++ b/server/src/main/java/org/opensearch/search/approximate/ApproximateQuery.java @@ -0,0 +1,21 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.search.approximate; + +import org.apache.lucene.search.Query; +import org.opensearch.search.internal.SearchContext; + +/** + * Abstract class that can be inherited by queries that can be approximated. Queries should implement {@link #canApproximate(SearchContext)} to specify conditions on when they can be approximated +*/ +public abstract class ApproximateQuery extends Query { + + protected abstract boolean canApproximate(SearchContext context); + +} diff --git a/server/src/main/java/org/opensearch/search/approximate/ApproximateScoreQuery.java b/server/src/main/java/org/opensearch/search/approximate/ApproximateScoreQuery.java new file mode 100644 index 0000000000000..d1dd32b239f28 --- /dev/null +++ b/server/src/main/java/org/opensearch/search/approximate/ApproximateScoreQuery.java @@ -0,0 +1,90 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.search.approximate; + +import org.apache.lucene.search.BooleanClause; +import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.Query; +import org.apache.lucene.search.QueryVisitor; +import org.apache.lucene.search.Weight; +import org.opensearch.search.internal.SearchContext; + +import java.io.IOException; + +/** + * Entry-point for the approximation framework. + * This class is heavily inspired by {@link org.apache.lucene.search.IndexOrDocValuesQuery}. It acts as a wrapper that consumer two queries, a regular query and an approximate version of the same. By default, it executes the regular query and returns {@link Weight#scorer} for the original query. At run-time, depending on certain constraints, we can re-write the {@code Weight} to use the approximate weight instead. + */ +public class ApproximateScoreQuery extends Query { + + private final Query originalQuery; + private final ApproximateQuery approximationQuery; + + protected Query resolvedQuery; + + public ApproximateScoreQuery(Query originalQuery, ApproximateQuery approximationQuery) { + this.originalQuery = originalQuery; + this.approximationQuery = approximationQuery; + } + + public Query getOriginalQuery() { + return originalQuery; + } + + public ApproximateQuery getApproximationQuery() { + return approximationQuery; + } + + @Override + public final Query rewrite(IndexSearcher indexSearcher) throws IOException { + if (resolvedQuery == null) { + throw new IllegalStateException("Cannot rewrite resolved query without setContext being called"); + } + return resolvedQuery.rewrite(indexSearcher); + } + + public void setContext(SearchContext context) { + if (resolvedQuery != null) { + throw new IllegalStateException("Query already resolved, duplicate call to setContext"); + } + resolvedQuery = approximationQuery.canApproximate(context) ? approximationQuery : originalQuery; + }; + + @Override + public String toString(String s) { + return "ApproximateScoreQuery(originalQuery=" + + originalQuery.toString() + + ", approximationQuery=" + + approximationQuery.toString() + + ")"; + } + + @Override + public void visit(QueryVisitor queryVisitor) { + QueryVisitor v = queryVisitor.getSubVisitor(BooleanClause.Occur.MUST, this); + originalQuery.visit(v); + approximationQuery.visit(v); + } + + @Override + public boolean equals(Object o) { + if (!sameClassAs(o)) { + return false; + } + return true; + } + + @Override + public int hashCode() { + int h = classHash(); + h = 31 * h + originalQuery.hashCode(); + h = 31 * h + approximationQuery.hashCode(); + return h; + } +} diff --git a/server/src/main/java/org/opensearch/search/approximate/package-info.java b/server/src/main/java/org/opensearch/search/approximate/package-info.java new file mode 100644 index 0000000000000..1a09183c7d9fa --- /dev/null +++ b/server/src/main/java/org/opensearch/search/approximate/package-info.java @@ -0,0 +1,12 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +/** + * Approximation query framework to approximate commonly used queries + */ +package org.opensearch.search.approximate; diff --git a/server/src/main/java/org/opensearch/search/internal/ContextIndexSearcher.java b/server/src/main/java/org/opensearch/search/internal/ContextIndexSearcher.java index fa00ace378df1..f118e4106db83 100644 --- a/server/src/main/java/org/opensearch/search/internal/ContextIndexSearcher.java +++ b/server/src/main/java/org/opensearch/search/internal/ContextIndexSearcher.java @@ -69,6 +69,7 @@ import org.opensearch.common.lucene.search.TopDocsAndMaxScore; import org.opensearch.search.DocValueFormat; import org.opensearch.search.SearchService; +import org.opensearch.search.approximate.ApproximateScoreQuery; import org.opensearch.search.dfs.AggregatedDfs; import org.opensearch.search.profile.ContextualProfileBreakdown; import org.opensearch.search.profile.Timer; @@ -218,6 +219,9 @@ public Weight createWeight(Query query, ScoreMode scoreMode, float boost) throws profiler.pollLastElement(); } return new ProfileWeight(query, weight, profile); + } else if (query instanceof ApproximateScoreQuery) { + ((ApproximateScoreQuery) query).setContext(searchContext); + return super.createWeight(query, scoreMode, boost); } else { return super.createWeight(query, scoreMode, boost); } diff --git a/server/src/test/java/org/opensearch/index/mapper/DateFieldTypeTests.java b/server/src/test/java/org/opensearch/index/mapper/DateFieldTypeTests.java index 7ed2ca2d150d6..df800067143e4 100644 --- a/server/src/test/java/org/opensearch/index/mapper/DateFieldTypeTests.java +++ b/server/src/test/java/org/opensearch/index/mapper/DateFieldTypeTests.java @@ -41,7 +41,6 @@ import org.apache.lucene.index.MultiReader; import org.apache.lucene.index.SortedNumericDocValues; import org.apache.lucene.search.DocIdSetIterator; -import org.apache.lucene.search.IndexOrDocValuesQuery; import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.IndexSortSortedNumericDocValuesRangeQuery; import org.apache.lucene.search.Query; @@ -53,6 +52,7 @@ import org.opensearch.common.time.DateFormatters; import org.opensearch.common.time.DateMathParser; import org.opensearch.common.util.BigArrays; +import org.opensearch.common.util.FeatureFlags; import org.opensearch.common.util.io.IOUtils; import org.opensearch.index.IndexSettings; import org.opensearch.index.fielddata.IndexNumericFieldData; @@ -65,6 +65,8 @@ import org.opensearch.index.query.DateRangeIncludingNowQuery; import org.opensearch.index.query.QueryRewriteContext; import org.opensearch.index.query.QueryShardContext; +import org.opensearch.search.approximate.ApproximateIndexOrDocValuesQuery; +import org.opensearch.search.approximate.ApproximatePointRangeQuery; import org.joda.time.DateTimeZone; import java.io.IOException; @@ -72,6 +74,10 @@ import java.time.ZoneOffset; import java.util.Collections; +import static org.hamcrest.CoreMatchers.is; +import static org.apache.lucene.document.LongPoint.pack; +import static org.junit.Assume.assumeThat; + public class DateFieldTypeTests extends FieldTypeTestCase { private static final long nowInMillis = 0; @@ -207,10 +213,26 @@ public void testTermQuery() { MappedFieldType ft = new DateFieldType("field"); String date = "2015-10-12T14:10:55"; long instant = DateFormatters.from(DateFieldMapper.getDefaultDateTimeFormatter().parse(date)).toInstant().toEpochMilli(); - Query expected = new IndexOrDocValuesQuery( + Query expected = new ApproximateIndexOrDocValuesQuery( LongPoint.newRangeQuery("field", instant, instant + 999), + new ApproximatePointRangeQuery( + "field", + pack(new long[] { instant }).bytes, + pack(new long[] { instant + 999 }).bytes, + new long[] { instant }.length + ) { + @Override + protected String toString(int dimension, byte[] value) { + return Long.toString(LongPoint.decodeDimension(value, 0)); + } + }, SortedNumericDocValuesField.newSlowRangeQuery("field", instant, instant + 999) ); + assumeThat( + "Using Approximate Range Query as default", + FeatureFlags.isEnabled(FeatureFlags.APPROXIMATE_POINT_RANGE_QUERY), + is(true) + ); assertEquals(expected, ft.termQuery(date, context)); MappedFieldType unsearchable = new DateFieldType( @@ -257,10 +279,26 @@ public void testRangeQuery() throws IOException { String date2 = "2016-04-28T11:33:52"; long instant1 = DateFormatters.from(DateFieldMapper.getDefaultDateTimeFormatter().parse(date1)).toInstant().toEpochMilli(); long instant2 = DateFormatters.from(DateFieldMapper.getDefaultDateTimeFormatter().parse(date2)).toInstant().toEpochMilli() + 999; - Query expected = new IndexOrDocValuesQuery( + Query expected = new ApproximateIndexOrDocValuesQuery( LongPoint.newRangeQuery("field", instant1, instant2), + new ApproximatePointRangeQuery( + "field", + pack(new long[] { instant1 }).bytes, + pack(new long[] { instant2 }).bytes, + new long[] { instant1 }.length + ) { + @Override + protected String toString(int dimension, byte[] value) { + return Long.toString(LongPoint.decodeDimension(value, 0)); + } + }, SortedNumericDocValuesField.newSlowRangeQuery("field", instant1, instant2) ); + assumeThat( + "Using Approximate Range Query as default", + FeatureFlags.isEnabled(FeatureFlags.APPROXIMATE_POINT_RANGE_QUERY), + is(true) + ); assertEquals( expected, ft.rangeQuery(date1, date2, true, true, null, null, null, context).rewrite(new IndexSearcher(new MultiReader())) @@ -269,11 +307,27 @@ public void testRangeQuery() throws IOException { instant1 = nowInMillis; instant2 = instant1 + 100; expected = new DateRangeIncludingNowQuery( - new IndexOrDocValuesQuery( + new ApproximateIndexOrDocValuesQuery( LongPoint.newRangeQuery("field", instant1, instant2), + new ApproximatePointRangeQuery( + "field", + pack(new long[] { instant1 }).bytes, + pack(new long[] { instant2 }).bytes, + new long[] { instant1 }.length + ) { + @Override + protected String toString(int dimension, byte[] value) { + return Long.toString(LongPoint.decodeDimension(value, 0)); + } + }, SortedNumericDocValuesField.newSlowRangeQuery("field", instant1, instant2) ) ); + assumeThat( + "Using Approximate Range Query as default", + FeatureFlags.isEnabled(FeatureFlags.APPROXIMATE_POINT_RANGE_QUERY), + is(true) + ); assertEquals(expected, ft.rangeQuery("now", instant2, true, true, null, null, null, context)); MappedFieldType unsearchable = new DateFieldType( @@ -330,13 +384,31 @@ public void testRangeQueryWithIndexSort() { long instant1 = DateFormatters.from(DateFieldMapper.getDefaultDateTimeFormatter().parse(date1)).toInstant().toEpochMilli(); long instant2 = DateFormatters.from(DateFieldMapper.getDefaultDateTimeFormatter().parse(date2)).toInstant().toEpochMilli() + 999; - Query pointQuery = LongPoint.newRangeQuery("field", instant1, instant2); Query dvQuery = SortedNumericDocValuesField.newSlowRangeQuery("field", instant1, instant2); Query expected = new IndexSortSortedNumericDocValuesRangeQuery( "field", instant1, instant2, - new IndexOrDocValuesQuery(pointQuery, dvQuery) + new ApproximateIndexOrDocValuesQuery( + LongPoint.newRangeQuery("field", instant1, instant2), + new ApproximatePointRangeQuery( + "field", + pack(new long[] { instant1 }).bytes, + pack(new long[] { instant2 }).bytes, + new long[] { instant1 }.length + ) { + @Override + protected String toString(int dimension, byte[] value) { + return Long.toString(LongPoint.decodeDimension(value, 0)); + } + }, + dvQuery + ) + ); + assumeThat( + "Using Approximate Range Query as default", + FeatureFlags.isEnabled(FeatureFlags.APPROXIMATE_POINT_RANGE_QUERY), + is(true) ); assertEquals(expected, ft.rangeQuery(date1, date2, true, true, null, null, null, context)); } diff --git a/server/src/test/java/org/opensearch/index/mapper/RangeFieldQueryStringQueryBuilderTests.java b/server/src/test/java/org/opensearch/index/mapper/RangeFieldQueryStringQueryBuilderTests.java index 9dea7e13ac45e..7a8ac829bdd97 100644 --- a/server/src/test/java/org/opensearch/index/mapper/RangeFieldQueryStringQueryBuilderTests.java +++ b/server/src/test/java/org/opensearch/index/mapper/RangeFieldQueryStringQueryBuilderTests.java @@ -47,15 +47,21 @@ import org.opensearch.common.compress.CompressedXContent; import org.opensearch.common.network.InetAddresses; import org.opensearch.common.time.DateMathParser; +import org.opensearch.common.util.FeatureFlags; import org.opensearch.index.query.QueryShardContext; import org.opensearch.index.query.QueryStringQueryBuilder; +import org.opensearch.search.approximate.ApproximateIndexOrDocValuesQuery; +import org.opensearch.search.approximate.ApproximatePointRangeQuery; import org.opensearch.test.AbstractQueryTestCase; import java.io.IOException; import java.net.InetAddress; +import static org.hamcrest.CoreMatchers.is; import static org.hamcrest.Matchers.either; import static org.hamcrest.core.IsInstanceOf.instanceOf; +import static org.apache.lucene.document.LongPoint.pack; +import static org.junit.Assume.assumeThat; public class RangeFieldQueryStringQueryBuilderTests extends AbstractQueryTestCase { @@ -173,18 +179,39 @@ public void testDateRangeQuery() throws Exception { DateFieldMapper.DateFieldType dateType = (DateFieldMapper.DateFieldType) context.fieldMapper(DATE_FIELD_NAME); parser = dateType.dateMathParser; Query queryOnDateField = new QueryStringQueryBuilder(DATE_FIELD_NAME + ":[2010-01-01 TO 2018-01-01]").toQuery(createShardContext()); - Query controlQuery = LongPoint.newRangeQuery( - DATE_FIELD_NAME, - new long[] { parser.parse(lowerBoundExact, () -> 0).toEpochMilli() }, - new long[] { parser.parse(upperBoundExact, () -> 0).toEpochMilli() } - ); Query controlDv = SortedNumericDocValuesField.newSlowRangeQuery( DATE_FIELD_NAME, parser.parse(lowerBoundExact, () -> 0).toEpochMilli(), parser.parse(upperBoundExact, () -> 0).toEpochMilli() ); - assertEquals(new IndexOrDocValuesQuery(controlQuery, controlDv), queryOnDateField); + assumeThat( + "Using Approximate Range Query as default", + FeatureFlags.isEnabled(FeatureFlags.APPROXIMATE_POINT_RANGE_QUERY), + is(true) + ); + assertEquals( + new ApproximateIndexOrDocValuesQuery( + LongPoint.newRangeQuery( + DATE_FIELD_NAME, + parser.parse(lowerBoundExact, () -> 0).toEpochMilli(), + parser.parse(upperBoundExact, () -> 0).toEpochMilli() + ), + new ApproximatePointRangeQuery( + DATE_FIELD_NAME, + pack(new long[] { parser.parse(lowerBoundExact, () -> 0).toEpochMilli() }).bytes, + pack(new long[] { parser.parse(upperBoundExact, () -> 0).toEpochMilli() }).bytes, + new long[] { parser.parse(lowerBoundExact, () -> 0).toEpochMilli() }.length + ) { + @Override + protected String toString(int dimension, byte[] value) { + return Long.toString(LongPoint.decodeDimension(value, 0)); + } + }, + controlDv + ), + queryOnDateField + ); } public void testIPRangeQuery() throws Exception { diff --git a/server/src/test/java/org/opensearch/index/mapper/RangeFieldTypeTests.java b/server/src/test/java/org/opensearch/index/mapper/RangeFieldTypeTests.java index 49bf227e5073c..b157c43e45451 100644 --- a/server/src/test/java/org/opensearch/index/mapper/RangeFieldTypeTests.java +++ b/server/src/test/java/org/opensearch/index/mapper/RangeFieldTypeTests.java @@ -57,6 +57,7 @@ import org.opensearch.index.mapper.RangeFieldMapper.RangeFieldType; import org.opensearch.index.query.QueryShardContext; import org.opensearch.index.query.QueryShardException; +import org.opensearch.search.approximate.ApproximateIndexOrDocValuesQuery; import org.opensearch.test.IndexSettingsModule; import org.joda.time.DateTime; import org.junit.Before; @@ -285,7 +286,15 @@ public void testDateRangeQueryUsingMappingFormatLegacy() { // compare lower and upper bounds with what we would get on a `date` field DateFieldType dateFieldType = new DateFieldType("field", DateFieldMapper.Resolution.MILLISECONDS, formatter); final Query queryOnDateField = dateFieldType.rangeQuery(from, to, true, true, relation, null, fieldType.dateMathParser(), context); - assertEquals("field:[1465975790000 TO 1466062190999]", ((IndexOrDocValuesQuery) queryOnDateField).getIndexQuery().toString()); + assumeThat( + "Using Approximate Range Query as default", + FeatureFlags.isEnabled(FeatureFlags.APPROXIMATE_POINT_RANGE_QUERY), + is(true) + ); + assertEquals( + "field:[1465975790000 TO 1466062190999]", + ((IndexOrDocValuesQuery) ((ApproximateIndexOrDocValuesQuery) queryOnDateField).getOriginalQuery()).getIndexQuery().toString() + ); } public void testDateRangeQueryUsingMappingFormat() { diff --git a/server/src/test/java/org/opensearch/index/query/MatchPhraseQueryBuilderTests.java b/server/src/test/java/org/opensearch/index/query/MatchPhraseQueryBuilderTests.java index f8d5d2ce3d062..ddf58073a5206 100644 --- a/server/src/test/java/org/opensearch/index/query/MatchPhraseQueryBuilderTests.java +++ b/server/src/test/java/org/opensearch/index/query/MatchPhraseQueryBuilderTests.java @@ -42,6 +42,7 @@ import org.apache.lucene.search.TermQuery; import org.opensearch.core.common.ParsingException; import org.opensearch.index.search.MatchQuery.ZeroTermsQuery; +import org.opensearch.search.approximate.ApproximateIndexOrDocValuesQuery; import org.opensearch.test.AbstractQueryTestCase; import java.io.IOException; @@ -130,6 +131,7 @@ protected void doAssertLuceneQuery(MatchPhraseQueryBuilder queryBuilder, Query q .or(instanceOf(PointRangeQuery.class)) .or(instanceOf(IndexOrDocValuesQuery.class)) .or(instanceOf(MatchNoDocsQuery.class)) + .or(instanceOf(ApproximateIndexOrDocValuesQuery.class)) ); } diff --git a/server/src/test/java/org/opensearch/index/query/QueryStringQueryBuilderTests.java b/server/src/test/java/org/opensearch/index/query/QueryStringQueryBuilderTests.java index af4a34aa98116..5b030df20e889 100644 --- a/server/src/test/java/org/opensearch/index/query/QueryStringQueryBuilderTests.java +++ b/server/src/test/java/org/opensearch/index/query/QueryStringQueryBuilderTests.java @@ -47,7 +47,6 @@ import org.apache.lucene.search.ConstantScoreQuery; import org.apache.lucene.search.DisjunctionMaxQuery; import org.apache.lucene.search.FuzzyQuery; -import org.apache.lucene.search.IndexOrDocValuesQuery; import org.apache.lucene.search.MatchAllDocsQuery; import org.apache.lucene.search.MatchNoDocsQuery; import org.apache.lucene.search.MultiTermQuery; @@ -71,11 +70,14 @@ import org.opensearch.common.compress.CompressedXContent; import org.opensearch.common.settings.Settings; import org.opensearch.common.unit.Fuzziness; +import org.opensearch.common.util.FeatureFlags; import org.opensearch.common.xcontent.json.JsonXContent; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.index.mapper.FieldNamesFieldMapper; import org.opensearch.index.mapper.MapperService; import org.opensearch.index.search.QueryStringQueryParser; +import org.opensearch.search.approximate.ApproximateIndexOrDocValuesQuery; +import org.opensearch.search.approximate.ApproximatePointRangeQuery; import org.opensearch.test.AbstractQueryTestCase; import org.hamcrest.CoreMatchers; import org.hamcrest.Matchers; @@ -98,6 +100,9 @@ import static org.hamcrest.CoreMatchers.hasItems; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.instanceOf; +import static org.hamcrest.Matchers.is; +import static org.apache.lucene.document.LongPoint.pack; +import static org.junit.Assume.assumeThat; public class QueryStringQueryBuilderTests extends AbstractQueryTestCase { @@ -853,7 +858,12 @@ public void testToQueryDateWithTimeZone() throws Exception { QueryStringQueryBuilder qsq = queryStringQuery(DATE_FIELD_NAME + ":1970-01-01"); QueryShardContext context = createShardContext(); Query query = qsq.toQuery(context); - assertThat(query, instanceOf(IndexOrDocValuesQuery.class)); + assumeThat( + "Using Approximate Range Query as default", + FeatureFlags.isEnabled(FeatureFlags.APPROXIMATE_POINT_RANGE_QUERY), + is(true) + ); + assertThat(query, instanceOf(ApproximateIndexOrDocValuesQuery.class)); long lower = 0; // 1970-01-01T00:00:00.999 UTC long upper = 86399999; // 1970-01-01T23:59:59.999 UTC assertEquals(calculateExpectedDateQuery(lower, upper), query); @@ -862,10 +872,22 @@ public void testToQueryDateWithTimeZone() throws Exception { assertEquals(calculateExpectedDateQuery(lower + msPerHour, upper + msPerHour), qsq.timeZone("-01:00").toQuery(context)); } - private IndexOrDocValuesQuery calculateExpectedDateQuery(long lower, long upper) { - Query query = LongPoint.newRangeQuery(DATE_FIELD_NAME, lower, upper); - Query dv = SortedNumericDocValuesField.newSlowRangeQuery(DATE_FIELD_NAME, lower, upper); - return new IndexOrDocValuesQuery(query, dv); + private ApproximateIndexOrDocValuesQuery calculateExpectedDateQuery(long lower, long upper) { + return new ApproximateIndexOrDocValuesQuery( + LongPoint.newRangeQuery(DATE_FIELD_NAME, lower, upper), + new ApproximatePointRangeQuery( + DATE_FIELD_NAME, + pack(new long[] { lower }).bytes, + pack(new long[] { upper }).bytes, + new long[] { lower }.length + ) { + @Override + protected String toString(int dimension, byte[] value) { + return Long.toString(LongPoint.decodeDimension(value, 0)); + } + }, + SortedNumericDocValuesField.newSlowRangeQuery(DATE_FIELD_NAME, lower, upper) + ); } public void testFuzzyNumeric() throws Exception { diff --git a/server/src/test/java/org/opensearch/index/query/RangeQueryBuilderTests.java b/server/src/test/java/org/opensearch/index/query/RangeQueryBuilderTests.java index 64b3eea029bd1..601a66d229131 100644 --- a/server/src/test/java/org/opensearch/index/query/RangeQueryBuilderTests.java +++ b/server/src/test/java/org/opensearch/index/query/RangeQueryBuilderTests.java @@ -34,6 +34,7 @@ import org.apache.lucene.document.IntPoint; import org.apache.lucene.document.LongPoint; +import org.apache.lucene.document.SortedNumericDocValuesField; import org.apache.lucene.index.Term; import org.apache.lucene.search.ConstantScoreQuery; import org.apache.lucene.search.DocValuesFieldExistsQuery; @@ -47,12 +48,16 @@ import org.opensearch.OpenSearchParseException; import org.opensearch.common.geo.ShapeRelation; import org.opensearch.common.lucene.BytesRefs; +import org.opensearch.common.util.FeatureFlags; import org.opensearch.core.common.ParsingException; import org.opensearch.index.mapper.DateFieldMapper; import org.opensearch.index.mapper.FieldNamesFieldMapper; import org.opensearch.index.mapper.MappedFieldType; import org.opensearch.index.mapper.MappedFieldType.Relation; import org.opensearch.index.mapper.MapperService; +import org.opensearch.search.approximate.ApproximateIndexOrDocValuesQuery; +import org.opensearch.search.approximate.ApproximatePointRangeQuery; +import org.opensearch.search.approximate.ApproximateQuery; import org.opensearch.test.AbstractQueryTestCase; import org.joda.time.DateTime; import org.joda.time.chrono.ISOChronology; @@ -65,9 +70,12 @@ import java.util.Map; import static org.opensearch.index.query.QueryBuilders.rangeQuery; +import static org.hamcrest.CoreMatchers.is; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.instanceOf; import static org.hamcrest.Matchers.sameInstance; +import static org.apache.lucene.document.LongPoint.pack; +import static org.junit.Assume.assumeThat; public class RangeQueryBuilderTests extends AbstractQueryTestCase { @Override @@ -183,9 +191,16 @@ protected void doAssertLuceneQuery(RangeQueryBuilder queryBuilder, Query query, assertThat(termRangeQuery.includesLower(), equalTo(queryBuilder.includeLower())); assertThat(termRangeQuery.includesUpper(), equalTo(queryBuilder.includeUpper())); } else if (expectedFieldName.equals(DATE_FIELD_NAME)) { - assertThat(query, instanceOf(IndexOrDocValuesQuery.class)); - query = ((IndexOrDocValuesQuery) query).getIndexQuery(); - assertThat(query, instanceOf(PointRangeQuery.class)); + assumeThat( + "Using Approximate Range Query as default", + FeatureFlags.isEnabled(FeatureFlags.APPROXIMATE_POINT_RANGE_QUERY), + is(true) + ); + assertThat(query, instanceOf(ApproximateIndexOrDocValuesQuery.class)); + Query approximationQuery = ((ApproximateIndexOrDocValuesQuery) query).getApproximationQuery(); + assertThat(approximationQuery, instanceOf(ApproximateQuery.class)); + Query originalQuery = ((ApproximateIndexOrDocValuesQuery) query).getOriginalQuery(); + assertThat(originalQuery, instanceOf(IndexOrDocValuesQuery.class)); MapperService mapperService = context.getMapperService(); MappedFieldType mappedFieldType = mapperService.fieldType(expectedFieldName); final Long fromInMillis; @@ -234,7 +249,24 @@ protected void doAssertLuceneQuery(RangeQueryBuilder queryBuilder, Query query, maxLong--; } } - assertEquals(LongPoint.newRangeQuery(DATE_FIELD_NAME, minLong, maxLong), query); + assertEquals( + new ApproximateIndexOrDocValuesQuery( + LongPoint.newRangeQuery(DATE_FIELD_NAME, minLong, maxLong), + new ApproximatePointRangeQuery( + DATE_FIELD_NAME, + pack(new long[] { minLong }).bytes, + pack(new long[] { maxLong }).bytes, + new long[] { minLong }.length + ) { + @Override + protected String toString(int dimension, byte[] value) { + return Long.toString(LongPoint.decodeDimension(value, 0)); + } + }, + SortedNumericDocValuesField.newSlowRangeQuery(DATE_FIELD_NAME, minLong, maxLong) + ), + query + ); } else if (expectedFieldName.equals(INT_FIELD_NAME)) { assertThat(query, instanceOf(IndexOrDocValuesQuery.class)); query = ((IndexOrDocValuesQuery) query).getIndexQuery(); @@ -299,15 +331,33 @@ public void testDateRangeQueryFormat() throws IOException { + " }\n" + "}"; Query parsedQuery = parseQuery(query).toQuery(createShardContext()); - assertThat(parsedQuery, instanceOf(IndexOrDocValuesQuery.class)); - parsedQuery = ((IndexOrDocValuesQuery) parsedQuery).getIndexQuery(); - assertThat(parsedQuery, instanceOf(PointRangeQuery.class)); - + assumeThat( + "Using Approximate Range Query as default", + FeatureFlags.isEnabled(FeatureFlags.APPROXIMATE_POINT_RANGE_QUERY), + is(true) + ); + assertThat(parsedQuery, instanceOf(ApproximateIndexOrDocValuesQuery.class)); + Query approximationQuery = ((ApproximateIndexOrDocValuesQuery) parsedQuery).getApproximationQuery(); + assertThat(approximationQuery, instanceOf(ApproximateQuery.class)); + Query originalQuery = ((ApproximateIndexOrDocValuesQuery) parsedQuery).getOriginalQuery(); + assertThat(originalQuery, instanceOf(IndexOrDocValuesQuery.class)); + long lower = DateTime.parse("2012-01-01T00:00:00.000+00").getMillis(); + long upper = DateTime.parse("2030-01-01T00:00:00.000+00").getMillis() - 1; assertEquals( - LongPoint.newRangeQuery( - DATE_FIELD_NAME, - DateTime.parse("2012-01-01T00:00:00.000+00").getMillis(), - DateTime.parse("2030-01-01T00:00:00.000+00").getMillis() - 1 + new ApproximateIndexOrDocValuesQuery( + LongPoint.newRangeQuery(DATE_FIELD_NAME, lower, upper), + new ApproximatePointRangeQuery( + DATE_FIELD_NAME, + pack(new long[] { lower }).bytes, + pack(new long[] { upper }).bytes, + new long[] { lower }.length + ) { + @Override + protected String toString(int dimension, byte[] value) { + return Long.toString(LongPoint.decodeDimension(value, 0)); + } + }, + SortedNumericDocValuesField.newSlowRangeQuery(DATE_FIELD_NAME, lower, upper) ), parsedQuery ); @@ -339,15 +389,33 @@ public void testDateRangeBoundaries() throws IOException { + " }\n" + "}\n"; Query parsedQuery = parseQuery(query).toQuery(createShardContext()); - assertThat(parsedQuery, instanceOf(IndexOrDocValuesQuery.class)); - parsedQuery = ((IndexOrDocValuesQuery) parsedQuery).getIndexQuery(); - assertThat(parsedQuery, instanceOf(PointRangeQuery.class)); + assumeThat( + "Using Approximate Range Query as default", + FeatureFlags.isEnabled(FeatureFlags.APPROXIMATE_POINT_RANGE_QUERY), + is(true) + ); + assertThat(parsedQuery, instanceOf(ApproximateIndexOrDocValuesQuery.class)); + + long lower = DateTime.parse("2014-11-01T00:00:00.000+00").getMillis(); + long upper = DateTime.parse("2014-12-08T23:59:59.999+00").getMillis(); assertEquals( - LongPoint.newRangeQuery( - DATE_FIELD_NAME, - DateTime.parse("2014-11-01T00:00:00.000+00").getMillis(), - DateTime.parse("2014-12-08T23:59:59.999+00").getMillis() - ), + new ApproximateIndexOrDocValuesQuery( + LongPoint.newRangeQuery(DATE_FIELD_NAME, lower, upper), + new ApproximatePointRangeQuery( + DATE_FIELD_NAME, + pack(new long[] { lower }).bytes, + pack(new long[] { upper }).bytes, + new long[] { lower }.length + ) { + @Override + protected String toString(int dimension, byte[] value) { + return Long.toString(LongPoint.decodeDimension(value, 0)); + } + }, + SortedNumericDocValuesField.newSlowRangeQuery(DATE_FIELD_NAME, lower, upper) + ) + + , parsedQuery ); @@ -362,15 +430,27 @@ public void testDateRangeBoundaries() throws IOException { + " }\n" + "}"; parsedQuery = parseQuery(query).toQuery(createShardContext()); - assertThat(parsedQuery, instanceOf(IndexOrDocValuesQuery.class)); - parsedQuery = ((IndexOrDocValuesQuery) parsedQuery).getIndexQuery(); - assertThat(parsedQuery, instanceOf(PointRangeQuery.class)); + assertThat(parsedQuery, instanceOf(ApproximateIndexOrDocValuesQuery.class)); + lower = DateTime.parse("2014-11-30T23:59:59.999+00").getMillis() + 1; + upper = DateTime.parse("2014-12-08T00:00:00.000+00").getMillis() - 1; assertEquals( - LongPoint.newRangeQuery( - DATE_FIELD_NAME, - DateTime.parse("2014-11-30T23:59:59.999+00").getMillis() + 1, - DateTime.parse("2014-12-08T00:00:00.000+00").getMillis() - 1 - ), + new ApproximateIndexOrDocValuesQuery( + LongPoint.newRangeQuery(DATE_FIELD_NAME, lower, upper), + new ApproximatePointRangeQuery( + DATE_FIELD_NAME, + pack(new long[] { lower }).bytes, + pack(new long[] { upper }).bytes, + new long[] { lower }.length + ) { + @Override + protected String toString(int dimension, byte[] value) { + return Long.toString(LongPoint.decodeDimension(value, 0)); + } + }, + SortedNumericDocValuesField.newSlowRangeQuery(DATE_FIELD_NAME, lower, upper) + ) + + , parsedQuery ); } @@ -391,9 +471,14 @@ public void testDateRangeQueryTimezone() throws IOException { Query parsedQuery = parseQuery(query).toQuery(context); assertThat(parsedQuery, instanceOf(DateRangeIncludingNowQuery.class)); parsedQuery = ((DateRangeIncludingNowQuery) parsedQuery).getQuery(); - assertThat(parsedQuery, instanceOf(IndexOrDocValuesQuery.class)); - parsedQuery = ((IndexOrDocValuesQuery) parsedQuery).getIndexQuery(); - assertThat(parsedQuery, instanceOf(PointRangeQuery.class)); + assumeThat( + "Using Approximate Range Query as default", + FeatureFlags.isEnabled(FeatureFlags.APPROXIMATE_POINT_RANGE_QUERY), + is(true) + ); + assertThat(parsedQuery, instanceOf(ApproximateIndexOrDocValuesQuery.class)); + parsedQuery = ((ApproximateIndexOrDocValuesQuery) parsedQuery).getApproximationQuery(); + assertThat(parsedQuery, instanceOf(ApproximateQuery.class)); // TODO what else can we assert query = "{\n" diff --git a/server/src/test/java/org/opensearch/search/approximate/ApproximateIndexOrDocValuesQueryTests.java b/server/src/test/java/org/opensearch/search/approximate/ApproximateIndexOrDocValuesQueryTests.java new file mode 100644 index 0000000000000..47f87c6abf629 --- /dev/null +++ b/server/src/test/java/org/opensearch/search/approximate/ApproximateIndexOrDocValuesQueryTests.java @@ -0,0 +1,113 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.search.approximate; + +import org.apache.lucene.document.LongPoint; +import org.apache.lucene.document.SortedNumericDocValuesField; +import org.apache.lucene.index.DirectoryReader; +import org.apache.lucene.index.IndexWriter; +import org.apache.lucene.search.ConstantScoreWeight; +import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.Query; +import org.apache.lucene.search.ScoreMode; +import org.apache.lucene.search.Weight; +import org.apache.lucene.store.Directory; +import org.opensearch.search.internal.SearchContext; +import org.opensearch.test.OpenSearchTestCase; +import org.junit.After; +import org.junit.Before; + +import java.io.IOException; + +import static org.apache.lucene.document.LongPoint.pack; + +public class ApproximateIndexOrDocValuesQueryTests extends OpenSearchTestCase { + private Directory dir; + private IndexWriter w; + private DirectoryReader reader; + private IndexSearcher searcher; + + @Before + public void initSearcher() throws IOException { + dir = newDirectory(); + w = new IndexWriter(dir, newIndexWriterConfig()); + } + + @After + public void closeAllTheReaders() throws IOException { + reader.close(); + w.close(); + dir.close(); + } + + public void testApproximateIndexOrDocValuesQueryWeight() throws IOException { + + long l = Long.MIN_VALUE; + long u = Long.MAX_VALUE; + Query indexQuery = LongPoint.newRangeQuery("test-index", l, u); + + ApproximateQuery approximateIndexQuery = new ApproximatePointRangeQuery( + "test-index", + pack(new long[] { l }).bytes, + pack(new long[] { u }).bytes, + new long[] { l }.length + ) { + protected String toString(int dimension, byte[] value) { + return Long.toString(LongPoint.decodeDimension(value, 0)); + } + }; + + Query dvQuery = SortedNumericDocValuesField.newSlowRangeQuery("test-index", l, u); + + ApproximateIndexOrDocValuesQuery approximateIndexOrDocValuesQuery = new ApproximateIndexOrDocValuesQuery( + indexQuery, + approximateIndexQuery, + dvQuery + ); + + reader = DirectoryReader.open(w); + searcher = newSearcher(reader); + + approximateIndexOrDocValuesQuery.resolvedQuery = indexQuery; + + Weight weight = approximateIndexOrDocValuesQuery.rewrite(searcher).createWeight(searcher, ScoreMode.COMPLETE, 1f); + + assertTrue(weight instanceof ConstantScoreWeight); + + ApproximateQuery approximateIndexQueryCanApproximate = new ApproximatePointRangeQuery( + "test-index", + pack(new long[] { l }).bytes, + pack(new long[] { u }).bytes, + new long[] { l }.length + ) { + protected String toString(int dimension, byte[] value) { + return Long.toString(LongPoint.decodeDimension(value, 0)); + } + + public boolean canApproximate(SearchContext context) { + return true; + } + }; + + ApproximateIndexOrDocValuesQuery approximateIndexOrDocValuesQueryCanApproximate = new ApproximateIndexOrDocValuesQuery( + indexQuery, + approximateIndexQueryCanApproximate, + dvQuery + ); + + approximateIndexOrDocValuesQueryCanApproximate.resolvedQuery = approximateIndexQueryCanApproximate; + + Weight approximateIndexOrDocValuesQueryCanApproximateWeight = approximateIndexOrDocValuesQueryCanApproximate.rewrite(searcher) + .createWeight(searcher, ScoreMode.COMPLETE, 1f); + + // we get ConstantScoreWeight since we're expecting to call ApproximatePointRangeQuery + assertTrue(approximateIndexOrDocValuesQueryCanApproximateWeight instanceof ConstantScoreWeight); + + } +} diff --git a/server/src/test/java/org/opensearch/search/approximate/ApproximatePointRangeQueryTests.java b/server/src/test/java/org/opensearch/search/approximate/ApproximatePointRangeQueryTests.java new file mode 100644 index 0000000000000..dd683d28f00f7 --- /dev/null +++ b/server/src/test/java/org/opensearch/search/approximate/ApproximatePointRangeQueryTests.java @@ -0,0 +1,346 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.search.approximate; + +import com.carrotsearch.randomizedtesting.generators.RandomNumbers; + +import org.apache.lucene.analysis.core.WhitespaceAnalyzer; +import org.apache.lucene.document.Document; +import org.apache.lucene.document.LongPoint; +import org.apache.lucene.index.IndexReader; +import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.Query; +import org.apache.lucene.search.TopDocs; +import org.apache.lucene.search.TotalHits; +import org.apache.lucene.store.Directory; +import org.apache.lucene.tests.index.RandomIndexWriter; +import org.opensearch.search.internal.SearchContext; +import org.opensearch.search.sort.SortOrder; +import org.opensearch.test.OpenSearchTestCase; + +import java.io.IOException; + +import static org.apache.lucene.document.LongPoint.pack; +import static org.mockito.Mockito.mock; + +public class ApproximatePointRangeQueryTests extends OpenSearchTestCase { + + protected static final String DATE_FIELD_NAME = "mapped_date"; + + public void testApproximateRangeEqualsActualRange() throws IOException { + try (Directory directory = newDirectory()) { + try (RandomIndexWriter iw = new RandomIndexWriter(random(), directory, new WhitespaceAnalyzer())) { + int dims = 1; + + long[] scratch = new long[dims]; + for (int i = 0; i < 100; i++) { + int numPoints = RandomNumbers.randomIntBetween(random(), 1, 10); + Document doc = new Document(); + for (int j = 0; j < numPoints; j++) { + for (int v = 0; v < dims; v++) { + scratch[v] = RandomNumbers.randomLongBetween(random(), 0, 100); + } + doc.add(new LongPoint("point", scratch)); + } + iw.addDocument(doc); + } + iw.flush(); + try (IndexReader reader = iw.getReader()) { + try { + long lower = RandomNumbers.randomLongBetween(random(), -100, 200); + long upper = lower + RandomNumbers.randomLongBetween(random(), 0, 100); + Query approximateQuery = new ApproximatePointRangeQuery("point", pack(lower).bytes, pack(upper).bytes, dims) { + protected String toString(int dimension, byte[] value) { + return Long.toString(LongPoint.decodeDimension(value, 0)); + } + }; + Query query = LongPoint.newRangeQuery("point", lower, upper); + IndexSearcher searcher = new IndexSearcher(reader); + TopDocs topDocs = searcher.search(approximateQuery, 10); + TopDocs topDocs1 = searcher.search(query, 10); + assertEquals(topDocs.totalHits, topDocs1.totalHits); + } catch (IOException e) { + throw new RuntimeException(e); + } + + } + } + } + } + + public void testApproximateRangeWithDefaultSize() throws IOException { + try (Directory directory = newDirectory()) { + try (RandomIndexWriter iw = new RandomIndexWriter(random(), directory, new WhitespaceAnalyzer())) { + int dims = 1; + + long[] scratch = new long[dims]; + int numPoints = 1000; + for (int i = 0; i < numPoints; i++) { + Document doc = new Document(); + for (int v = 0; v < dims; v++) { + scratch[v] = i; + } + doc.add(new LongPoint("point", scratch)); + iw.addDocument(doc); + if (i % 15 == 0) iw.flush(); + } + iw.flush(); + try (IndexReader reader = iw.getReader()) { + try { + long lower = 0; + long upper = 1000; + Query approximateQuery = new ApproximatePointRangeQuery("point", pack(lower).bytes, pack(upper).bytes, dims) { + protected String toString(int dimension, byte[] value) { + return Long.toString(LongPoint.decodeDimension(value, 0)); + } + }; + IndexSearcher searcher = new IndexSearcher(reader); + TopDocs topDocs = searcher.search(approximateQuery, 10); + assertEquals(topDocs.totalHits, new TotalHits(1000, TotalHits.Relation.EQUAL_TO)); + } catch (IOException e) { + throw new RuntimeException(e); + } + + } + } + } + } + + public void testApproximateRangeWithSizeUnderDefault() throws IOException { + try (Directory directory = newDirectory()) { + try (RandomIndexWriter iw = new RandomIndexWriter(random(), directory, new WhitespaceAnalyzer())) { + int dims = 1; + + long[] scratch = new long[dims]; + int numPoints = 1000; + for (int i = 0; i < numPoints; i++) { + Document doc = new Document(); + for (int v = 0; v < dims; v++) { + scratch[v] = i; + } + doc.add(new LongPoint("point", scratch)); + iw.addDocument(doc); + if (i % 15 == 0) iw.flush(); + } + iw.flush(); + try (IndexReader reader = iw.getReader()) { + try { + long lower = 0; + long upper = 45; + Query approximateQuery = new ApproximatePointRangeQuery("point", pack(lower).bytes, pack(upper).bytes, dims, 10) { + protected String toString(int dimension, byte[] value) { + return Long.toString(LongPoint.decodeDimension(value, 0)); + } + }; + IndexSearcher searcher = new IndexSearcher(reader); + TopDocs topDocs = searcher.search(approximateQuery, 10); + assertEquals(topDocs.totalHits, new TotalHits(10, TotalHits.Relation.EQUAL_TO)); + } catch (IOException e) { + throw new RuntimeException(e); + } + + } + } + } + } + + public void testApproximateRangeWithSizeOverDefault() throws IOException { + try (Directory directory = newDirectory()) { + try (RandomIndexWriter iw = new RandomIndexWriter(random(), directory, new WhitespaceAnalyzer())) { + int dims = 1; + + long[] scratch = new long[dims]; + int numPoints = 15000; + for (int i = 0; i < numPoints; i++) { + Document doc = new Document(); + for (int v = 0; v < dims; v++) { + scratch[v] = i; + } + doc.add(new LongPoint("point", scratch)); + iw.addDocument(doc); + } + iw.flush(); + try (IndexReader reader = iw.getReader()) { + try { + long lower = 0; + long upper = 12000; + Query approximateQuery = new ApproximatePointRangeQuery( + "point", + pack(lower).bytes, + pack(upper).bytes, + dims, + 11_000 + ) { + protected String toString(int dimension, byte[] value) { + return Long.toString(LongPoint.decodeDimension(value, 0)); + } + }; + IndexSearcher searcher = new IndexSearcher(reader); + TopDocs topDocs = searcher.search(approximateQuery, 11000); + assertEquals(topDocs.totalHits, new TotalHits(11001, TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO)); + } catch (IOException e) { + throw new RuntimeException(e); + } + + } + } + } + } + + public void testApproximateRangeShortCircuit() throws IOException { + try (Directory directory = newDirectory()) { + try (RandomIndexWriter iw = new RandomIndexWriter(random(), directory, new WhitespaceAnalyzer())) { + int dims = 1; + + long[] scratch = new long[dims]; + int numPoints = 1000; + for (int i = 0; i < numPoints; i++) { + Document doc = new Document(); + for (int v = 0; v < dims; v++) { + scratch[v] = i; + } + doc.add(new LongPoint("point", scratch)); + iw.addDocument(doc); + if (i % 10 == 0) iw.flush(); + } + iw.flush(); + try (IndexReader reader = iw.getReader()) { + try { + long lower = 0; + long upper = 100; + Query approximateQuery = new ApproximatePointRangeQuery("point", pack(lower).bytes, pack(upper).bytes, dims, 10) { + protected String toString(int dimension, byte[] value) { + return Long.toString(LongPoint.decodeDimension(value, 0)); + } + }; + Query query = LongPoint.newRangeQuery("point", lower, upper); + ; + IndexSearcher searcher = new IndexSearcher(reader); + TopDocs topDocs = searcher.search(approximateQuery, 10); + TopDocs topDocs1 = searcher.search(query, 10); + + // since we short-circuit from the approx range at the end of size these will not be equal + assertNotEquals(topDocs.totalHits, topDocs1.totalHits); + assertEquals(topDocs.totalHits, new TotalHits(10, TotalHits.Relation.EQUAL_TO)); + assertEquals(topDocs1.totalHits, new TotalHits(101, TotalHits.Relation.EQUAL_TO)); + + } catch (IOException e) { + throw new RuntimeException(e); + } + + } + } + } + } + + public void testApproximateRangeShortCircuitAscSort() throws IOException { + try (Directory directory = newDirectory()) { + try (RandomIndexWriter iw = new RandomIndexWriter(random(), directory, new WhitespaceAnalyzer())) { + int dims = 1; + + long[] scratch = new long[dims]; + int numPoints = 1000; + for (int i = 0; i < numPoints; i++) { + Document doc = new Document(); + for (int v = 0; v < dims; v++) { + scratch[v] = i; + } + doc.add(new LongPoint("point", scratch)); + iw.addDocument(doc); + } + iw.flush(); + try (IndexReader reader = iw.getReader()) { + try { + long lower = 0; + long upper = 20; + Query approximateQuery = new ApproximatePointRangeQuery( + "point", + pack(lower).bytes, + pack(upper).bytes, + dims, + 10, + SortOrder.ASC + ) { + protected String toString(int dimension, byte[] value) { + return Long.toString(LongPoint.decodeDimension(value, 0)); + } + }; + Query query = LongPoint.newRangeQuery("point", lower, upper); + ; + IndexSearcher searcher = new IndexSearcher(reader); + TopDocs topDocs = searcher.search(approximateQuery, 10); + TopDocs topDocs1 = searcher.search(query, 10); + + // since we short-circuit from the approx range at the end of size these will not be equal + assertNotEquals(topDocs.totalHits, topDocs1.totalHits); + assertEquals(topDocs.totalHits, new TotalHits(10, TotalHits.Relation.EQUAL_TO)); + assertEquals(topDocs1.totalHits, new TotalHits(21, TotalHits.Relation.EQUAL_TO)); + assertEquals(topDocs.scoreDocs[0].doc, 0); + assertEquals(topDocs.scoreDocs[1].doc, 1); + assertEquals(topDocs.scoreDocs[2].doc, 2); + assertEquals(topDocs.scoreDocs[3].doc, 3); + assertEquals(topDocs.scoreDocs[4].doc, 4); + assertEquals(topDocs.scoreDocs[5].doc, 5); + + } catch (IOException e) { + throw new RuntimeException(e); + } + + } + } + } + } + + public void testSize() { + ApproximatePointRangeQuery query = new ApproximatePointRangeQuery("point", pack(0).bytes, pack(20).bytes, 1) { + protected String toString(int dimension, byte[] value) { + return Long.toString(LongPoint.decodeDimension(value, 0)); + } + }; + assertEquals(query.getSize(), 10_000); + + query.setSize(100); + assertEquals(query.getSize(), 100); + + } + + public void testSortOrder() { + ApproximatePointRangeQuery query = new ApproximatePointRangeQuery("point", pack(0).bytes, pack(20).bytes, 1) { + protected String toString(int dimension, byte[] value) { + return Long.toString(LongPoint.decodeDimension(value, 0)); + } + }; + assertNull(query.getSortOrder()); + + query.setSortOrder(SortOrder.ASC); + assertEquals(query.getSortOrder(), SortOrder.ASC); + } + + public void testCanApproximate() { + ApproximatePointRangeQuery query = new ApproximatePointRangeQuery("point", pack(0).bytes, pack(20).bytes, 1) { + protected String toString(int dimension, byte[] value) { + return Long.toString(LongPoint.decodeDimension(value, 0)); + } + }; + + assertFalse(query.canApproximate(null)); + + ApproximatePointRangeQuery queryCanApproximate = new ApproximatePointRangeQuery("point", pack(0).bytes, pack(20).bytes, 1) { + protected String toString(int dimension, byte[] value) { + return Long.toString(LongPoint.decodeDimension(value, 0)); + } + + public boolean canApproximate(SearchContext context) { + return true; + } + }; + SearchContext searchContext = mock(SearchContext.class); + assertTrue(queryCanApproximate.canApproximate(searchContext)); + } +} diff --git a/server/src/test/java/org/opensearch/search/approximate/ApproximateScoreQueryTests.java b/server/src/test/java/org/opensearch/search/approximate/ApproximateScoreQueryTests.java new file mode 100644 index 0000000000000..aa45ea6744227 --- /dev/null +++ b/server/src/test/java/org/opensearch/search/approximate/ApproximateScoreQueryTests.java @@ -0,0 +1,83 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.search.approximate; + +import org.apache.lucene.analysis.core.WhitespaceAnalyzer; +import org.apache.lucene.document.Document; +import org.apache.lucene.document.LongPoint; +import org.apache.lucene.index.IndexReader; +import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.PointRangeQuery; +import org.apache.lucene.search.Query; +import org.apache.lucene.search.ScoreMode; +import org.apache.lucene.search.Scorer; +import org.apache.lucene.search.Weight; +import org.apache.lucene.store.Directory; +import org.apache.lucene.tests.index.RandomIndexWriter; +import org.opensearch.test.OpenSearchTestCase; + +import java.io.IOException; + +import static org.apache.lucene.document.LongPoint.pack; + +public class ApproximateScoreQueryTests extends OpenSearchTestCase { + + public void testApproximationScoreSupplier() throws IOException { + long l = Long.MIN_VALUE; + long u = Long.MAX_VALUE; + Query originalQuery = new PointRangeQuery( + "test-index", + pack(new long[] { l }).bytes, + pack(new long[] { u }).bytes, + new long[] { l }.length + ) { + protected String toString(int dimension, byte[] value) { + return Long.toString(LongPoint.decodeDimension(value, 0)); + } + }; + + ApproximateQuery approximateQuery = new ApproximatePointRangeQuery( + "test-index", + pack(new long[] { l }).bytes, + pack(new long[] { u }).bytes, + new long[] { l }.length + ) { + protected String toString(int dimension, byte[] value) { + return Long.toString(LongPoint.decodeDimension(value, 0)); + } + }; + + ApproximateScoreQuery query = new ApproximateScoreQuery(originalQuery, approximateQuery); + query.resolvedQuery = approximateQuery; + + try (Directory directory = newDirectory()) { + try (RandomIndexWriter iw = new RandomIndexWriter(random(), directory, new WhitespaceAnalyzer())) { + Document document = new Document(); + document.add(new LongPoint("testPoint", Long.MIN_VALUE)); + iw.addDocument(document); + iw.flush(); + try (IndexReader reader = iw.getReader()) { + try { + IndexSearcher searcher = new IndexSearcher(reader); + searcher.search(query, 10); + Weight weight = query.rewrite(searcher).createWeight(searcher, ScoreMode.TOP_SCORES, 1.0F); + Scorer scorer = weight.scorer(reader.leaves().get(0)); + assertEquals( + scorer, + originalQuery.createWeight(searcher, ScoreMode.TOP_SCORES, 1.0F).scorer(searcher.getLeafContexts().get(0)) + ); + } catch (IOException e) { + throw new RuntimeException(e); + } + + } + } + } + } +}