From 34f0453283a45acac366539d08d47a5e8939204a Mon Sep 17 00:00:00 2001 From: Benjamin Trent Date: Wed, 15 Jan 2025 09:08:29 -0500 Subject: [PATCH] Add two new "Seeded" Knn queries for seeded vector search (#14084) ### Description In some vector search cases, users may already know some documents that are likely related to a query. Let's support seeding HNSW's scoring stage with these documents, rather than using HNSW's hierarchical stage. An example use case is hybrid search, where both a traditional and vector search are performed. The top results from the traditional search are likely reasonable seeds for the vector search. Even when not performing hybrid search, traditional matching can often be faster than traversing the hierarchy, which can be used to speed up the vector search process (up to 2x faster for the same effectiveness), as was demonstrated in [this article](https://arxiv.org/abs/2307.16779) (full disclosure: seanmacavaney is an author of the article). The main changes are: - A new "seeded" focused knn collector and collector manager - Two new basic knn queries that expose using these specialized collectors for seeded entrypoint - `HnswGraphSearcher`, which bypasses the `findBestEntryPoint` step if seeds are provided. 
//cc @seanmacavaney Co-authored-by: Sean MacAvaney Co-authored-by: Sean MacAvaney Co-authored-by: Christine Poerschke --- lucene/CHANGES.txt | 6 +- .../lucene/search/KnnByteVectorQuery.java | 2 +- .../apache/lucene/search/KnnCollector.java | 54 +++++ .../lucene/search/KnnFloatVectorQuery.java | 2 +- .../search/SeededKnnByteVectorQuery.java | 97 +++++++++ .../search/SeededKnnFloatVectorQuery.java | 97 +++++++++ .../TimeLimitingKnnCollectorManager.java | 42 +--- .../lucene/search/knn/EntryPointProvider.java | 28 +++ .../lucene/search/knn/SeededKnnCollector.java | 48 ++++ .../search/knn/SeededKnnCollectorManager.java | 177 +++++++++++++++ .../lucene/util/hnsw/HnswGraphSearcher.java | 31 ++- .../hnsw/OrdinalTranslatedKnnCollector.java | 42 +--- .../lucene/document/TestManyKnnDocs.java | 136 +++++++++++- .../lucene/search/TestKnnByteVectorQuery.java | 4 +- .../search/TestKnnFloatVectorQuery.java | 2 +- .../search/TestSeededKnnByteVectorQuery.java | 205 ++++++++++++++++++ .../search/TestSeededKnnFloatVectorQuery.java | 191 ++++++++++++++++ 17 files changed, 1075 insertions(+), 89 deletions(-) create mode 100644 lucene/core/src/java/org/apache/lucene/search/SeededKnnByteVectorQuery.java create mode 100644 lucene/core/src/java/org/apache/lucene/search/SeededKnnFloatVectorQuery.java create mode 100644 lucene/core/src/java/org/apache/lucene/search/knn/EntryPointProvider.java create mode 100644 lucene/core/src/java/org/apache/lucene/search/knn/SeededKnnCollector.java create mode 100644 lucene/core/src/java/org/apache/lucene/search/knn/SeededKnnCollectorManager.java create mode 100644 lucene/core/src/test/org/apache/lucene/search/TestSeededKnnByteVectorQuery.java create mode 100644 lucene/core/src/test/org/apache/lucene/search/TestSeededKnnFloatVectorQuery.java diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt index c0a73e1bbedb..5084c25f3560 100644 --- a/lucene/CHANGES.txt +++ b/lucene/CHANGES.txt @@ -46,7 +46,11 @@ API Changes New Features --------------------- -(No 
changes) + +* GITHUB#14084, GITHUB#13635, GITHUB#13634: Adds new `SeededKnnByteVectorQuery` and `SeededKnnFloatVectorQuery` + queries. These queries allow for the vector search entry points to be initialized via a `seed` query. This follows + the research provided via https://arxiv.org/abs/2307.16779. (Sean MacAvaney, Ben Trent). + Improvements --------------------- diff --git a/lucene/core/src/java/org/apache/lucene/search/KnnByteVectorQuery.java b/lucene/core/src/java/org/apache/lucene/search/KnnByteVectorQuery.java index 35144055830c..05157ab65cb5 100644 --- a/lucene/core/src/java/org/apache/lucene/search/KnnByteVectorQuery.java +++ b/lucene/core/src/java/org/apache/lucene/search/KnnByteVectorQuery.java @@ -46,7 +46,7 @@ public class KnnByteVectorQuery extends AbstractKnnVectorQuery { private static final TopDocs NO_RESULTS = TopDocsCollector.EMPTY_TOPDOCS; - private final byte[] target; + protected final byte[] target; /** * Find the k nearest documents to the target vector according to the vectors in the diff --git a/lucene/core/src/java/org/apache/lucene/search/KnnCollector.java b/lucene/core/src/java/org/apache/lucene/search/KnnCollector.java index 43bac9fbc309..f694d8f7085c 100644 --- a/lucene/core/src/java/org/apache/lucene/search/KnnCollector.java +++ b/lucene/core/src/java/org/apache/lucene/search/KnnCollector.java @@ -85,4 +85,58 @@ public interface KnnCollector { * @return The collected top documents */ TopDocs topDocs(); + + /** + * KnnCollector.Decorator is the base class for decorators of KnnCollector objects, which extend + * the object with new behaviors. 
+ * + * @lucene.experimental + */ + abstract class Decorator implements KnnCollector { + private final KnnCollector collector; + + public Decorator(KnnCollector collector) { + this.collector = collector; + } + + @Override + public boolean earlyTerminated() { + return collector.earlyTerminated(); + } + + @Override + public void incVisitedCount(int count) { + collector.incVisitedCount(count); + } + + @Override + public long visitedCount() { + return collector.visitedCount(); + } + + @Override + public long visitLimit() { + return collector.visitLimit(); + } + + @Override + public int k() { + return collector.k(); + } + + @Override + public boolean collect(int docId, float similarity) { + return collector.collect(docId, similarity); + } + + @Override + public float minCompetitiveSimilarity() { + return collector.minCompetitiveSimilarity(); + } + + @Override + public TopDocs topDocs() { + return collector.topDocs(); + } + } } diff --git a/lucene/core/src/java/org/apache/lucene/search/KnnFloatVectorQuery.java b/lucene/core/src/java/org/apache/lucene/search/KnnFloatVectorQuery.java index d2aaf4296eda..c7d6fdb3608d 100644 --- a/lucene/core/src/java/org/apache/lucene/search/KnnFloatVectorQuery.java +++ b/lucene/core/src/java/org/apache/lucene/search/KnnFloatVectorQuery.java @@ -47,7 +47,7 @@ public class KnnFloatVectorQuery extends AbstractKnnVectorQuery { private static final TopDocs NO_RESULTS = TopDocsCollector.EMPTY_TOPDOCS; - private final float[] target; + protected final float[] target; /** * Find the k nearest documents to the target vector according to the vectors in the diff --git a/lucene/core/src/java/org/apache/lucene/search/SeededKnnByteVectorQuery.java b/lucene/core/src/java/org/apache/lucene/search/SeededKnnByteVectorQuery.java new file mode 100644 index 000000000000..980b6869c34f --- /dev/null +++ b/lucene/core/src/java/org/apache/lucene/search/SeededKnnByteVectorQuery.java @@ -0,0 +1,97 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one 
or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.search; + +import java.io.IOException; +import java.util.Objects; +import org.apache.lucene.index.ByteVectorValues; +import org.apache.lucene.search.knn.KnnCollectorManager; +import org.apache.lucene.search.knn.SeededKnnCollectorManager; + +/** + * This is a version of knn byte vector query that provides a query seed to initiate the vector + * search. NOTE: The underlying format is free to ignore the provided seed. + * + *

See "Lexically-Accelerated Dense + * Retrieval" (Kulkarni, Hrishikesh and MacAvaney, Sean and Goharian, Nazli and Frieder, Ophir). + * In SIGIR '23: Proceedings of the 46th International ACM SIGIR Conference on Research and + * Development in Information Retrieval Pages 152 - 162 + * + * @lucene.experimental + */ +public class SeededKnnByteVectorQuery extends KnnByteVectorQuery { + final Query seed; + final Weight seedWeight; + + /** + * Construct a new SeededKnnByteVectorQuery instance + * + * @param field knn byte vector field to query + * @param target the query vector + * @param k number of neighbors to return + * @param filter a filter on the neighbors to return + * @param seed a query seed to initiate the vector format search + */ + public SeededKnnByteVectorQuery(String field, byte[] target, int k, Query filter, Query seed) { + super(field, target, k, filter); + this.seed = Objects.requireNonNull(seed); + this.seedWeight = null; + } + + SeededKnnByteVectorQuery(String field, byte[] target, int k, Query filter, Weight seedWeight) { + super(field, target, k, filter); + this.seed = null; + this.seedWeight = Objects.requireNonNull(seedWeight); + } + + @Override + public Query rewrite(IndexSearcher indexSearcher) throws IOException { + if (seedWeight != null) { + return super.rewrite(indexSearcher); + } + BooleanQuery.Builder booleanSeedQueryBuilder = + new BooleanQuery.Builder() + .add(seed, BooleanClause.Occur.MUST) + .add(new FieldExistsQuery(field), BooleanClause.Occur.FILTER); + if (filter != null) { + booleanSeedQueryBuilder.add(filter, BooleanClause.Occur.FILTER); + } + Query seedRewritten = indexSearcher.rewrite(booleanSeedQueryBuilder.build()); + Weight seedWeight = indexSearcher.createWeight(seedRewritten, ScoreMode.TOP_SCORES, 1f); + SeededKnnByteVectorQuery rewritten = + new SeededKnnByteVectorQuery(field, target, k, filter, seedWeight); + return rewritten.rewrite(indexSearcher); + } + + @Override + protected KnnCollectorManager 
getKnnCollectorManager(int k, IndexSearcher searcher) { + if (seedWeight == null) { + throw new UnsupportedOperationException("must be rewritten before constructing manager"); + } + return new SeededKnnCollectorManager( + super.getKnnCollectorManager(k, searcher), + seedWeight, + k, + leaf -> { + ByteVectorValues vv = leaf.getByteVectorValues(field); + if (vv == null) { + ByteVectorValues.checkField(leaf.getContext().reader(), field); + } + return vv; + }); + } +} diff --git a/lucene/core/src/java/org/apache/lucene/search/SeededKnnFloatVectorQuery.java b/lucene/core/src/java/org/apache/lucene/search/SeededKnnFloatVectorQuery.java new file mode 100644 index 000000000000..02a33bdcdef7 --- /dev/null +++ b/lucene/core/src/java/org/apache/lucene/search/SeededKnnFloatVectorQuery.java @@ -0,0 +1,97 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.lucene.search; + +import java.io.IOException; +import java.util.Objects; +import org.apache.lucene.index.FloatVectorValues; +import org.apache.lucene.search.knn.KnnCollectorManager; +import org.apache.lucene.search.knn.SeededKnnCollectorManager; + +/** + * This is a version of knn float vector query that provides a query seed to initiate the vector + * search. NOTE: The underlying format is free to ignore the provided seed. + * + *

See "Lexically-Accelerated Dense + * Retrieval" (Kulkarni, Hrishikesh and MacAvaney, Sean and Goharian, Nazli and Frieder, Ophir). + * In SIGIR '23: Proceedings of the 46th International ACM SIGIR Conference on Research and + * Development in Information Retrieval Pages 152 - 162 + * + * @lucene.experimental + */ +public class SeededKnnFloatVectorQuery extends KnnFloatVectorQuery { + final Query seed; + final Weight seedWeight; + + /** + * Construct a new SeededKnnFloatVectorQuery instance + * + * @param field knn float vector field to query + * @param target the query vector + * @param k number of neighbors to return + * @param filter a filter on the neighbors to return + * @param seed a query seed to initiate the vector format search + */ + public SeededKnnFloatVectorQuery(String field, float[] target, int k, Query filter, Query seed) { + super(field, target, k, filter); + this.seed = Objects.requireNonNull(seed); + this.seedWeight = null; + } + + SeededKnnFloatVectorQuery(String field, float[] target, int k, Query filter, Weight seedWeight) { + super(field, target, k, filter); + this.seed = null; + this.seedWeight = Objects.requireNonNull(seedWeight); + } + + @Override + public Query rewrite(IndexSearcher indexSearcher) throws IOException { + if (seedWeight != null) { + return super.rewrite(indexSearcher); + } + BooleanQuery.Builder booleanSeedQueryBuilder = + new BooleanQuery.Builder() + .add(seed, BooleanClause.Occur.MUST) + .add(new FieldExistsQuery(field), BooleanClause.Occur.FILTER); + if (filter != null) { + booleanSeedQueryBuilder.add(filter, BooleanClause.Occur.FILTER); + } + Query seedRewritten = indexSearcher.rewrite(booleanSeedQueryBuilder.build()); + Weight seedWeight = indexSearcher.createWeight(seedRewritten, ScoreMode.TOP_SCORES, 1f); + SeededKnnFloatVectorQuery rewritten = + new SeededKnnFloatVectorQuery(field, target, k, filter, seedWeight); + return rewritten.rewrite(indexSearcher); + } + + @Override + protected KnnCollectorManager 
getKnnCollectorManager(int k, IndexSearcher searcher) { + if (seedWeight == null) { + throw new UnsupportedOperationException("must be rewritten before constructing manager"); + } + return new SeededKnnCollectorManager( + super.getKnnCollectorManager(k, searcher), + seedWeight, + k, + leaf -> { + FloatVectorValues vv = leaf.getFloatVectorValues(field); + if (vv == null) { + FloatVectorValues.checkField(leaf.getContext().reader(), field); + } + return vv; + }); + } +} diff --git a/lucene/core/src/java/org/apache/lucene/search/TimeLimitingKnnCollectorManager.java b/lucene/core/src/java/org/apache/lucene/search/TimeLimitingKnnCollectorManager.java index 2a1f312fbc58..2dc2f035b90f 100644 --- a/lucene/core/src/java/org/apache/lucene/search/TimeLimitingKnnCollectorManager.java +++ b/lucene/core/src/java/org/apache/lucene/search/TimeLimitingKnnCollectorManager.java @@ -45,51 +45,19 @@ public KnnCollector newCollector(int visitedLimit, LeafReaderContext context) th return new TimeLimitingKnnCollector(collector); } - class TimeLimitingKnnCollector implements KnnCollector { - private final KnnCollector collector; - - TimeLimitingKnnCollector(KnnCollector collector) { - this.collector = collector; + class TimeLimitingKnnCollector extends KnnCollector.Decorator { + public TimeLimitingKnnCollector(KnnCollector collector) { + super(collector); } @Override public boolean earlyTerminated() { - return queryTimeout.shouldExit() || collector.earlyTerminated(); - } - - @Override - public void incVisitedCount(int count) { - collector.incVisitedCount(count); - } - - @Override - public long visitedCount() { - return collector.visitedCount(); - } - - @Override - public long visitLimit() { - return collector.visitLimit(); - } - - @Override - public int k() { - return collector.k(); - } - - @Override - public boolean collect(int docId, float similarity) { - return collector.collect(docId, similarity); - } - - @Override - public float minCompetitiveSimilarity() { - return 
collector.minCompetitiveSimilarity(); + return queryTimeout.shouldExit() || super.earlyTerminated(); } @Override public TopDocs topDocs() { - TopDocs docs = collector.topDocs(); + TopDocs docs = super.topDocs(); // Mark results as partial if timeout is met TotalHits.Relation relation = diff --git a/lucene/core/src/java/org/apache/lucene/search/knn/EntryPointProvider.java b/lucene/core/src/java/org/apache/lucene/search/knn/EntryPointProvider.java new file mode 100644 index 000000000000..9e7b44b571df --- /dev/null +++ b/lucene/core/src/java/org/apache/lucene/search/knn/EntryPointProvider.java @@ -0,0 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.lucene.search.knn; + +import org.apache.lucene.search.DocIdSetIterator; + +/** Provides entry points for the kNN search */ +public interface EntryPointProvider { + /** Iterator of valid entry points for the kNN search */ + DocIdSetIterator entryPoints(); + + /** Number of valid entry points for the kNN search */ + int numberOfEntryPoints(); +} diff --git a/lucene/core/src/java/org/apache/lucene/search/knn/SeededKnnCollector.java b/lucene/core/src/java/org/apache/lucene/search/knn/SeededKnnCollector.java new file mode 100644 index 000000000000..c3c4f62901ee --- /dev/null +++ b/lucene/core/src/java/org/apache/lucene/search/knn/SeededKnnCollector.java @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.search.knn; + +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.search.KnnCollector; + +/** + * A {@link KnnCollector} that provides seeded knn collection. See usage in {@link + * SeededKnnCollectorManager}. 
+ * + * @lucene.experimental + */ +class SeededKnnCollector extends KnnCollector.Decorator implements EntryPointProvider { + private final DocIdSetIterator entryPoints; + private final int numberOfEntryPoints; + + SeededKnnCollector( + KnnCollector collector, DocIdSetIterator entryPoints, int numberOfEntryPoints) { + super(collector); + this.entryPoints = entryPoints; + this.numberOfEntryPoints = numberOfEntryPoints; + } + + @Override + public DocIdSetIterator entryPoints() { + return entryPoints; + } + + @Override + public int numberOfEntryPoints() { + return numberOfEntryPoints; + } +} diff --git a/lucene/core/src/java/org/apache/lucene/search/knn/SeededKnnCollectorManager.java b/lucene/core/src/java/org/apache/lucene/search/knn/SeededKnnCollectorManager.java new file mode 100644 index 000000000000..7631db6e3022 --- /dev/null +++ b/lucene/core/src/java/org/apache/lucene/search/knn/SeededKnnCollectorManager.java @@ -0,0 +1,177 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.lucene.search.knn; + +import java.io.IOException; +import java.util.Arrays; +import org.apache.lucene.index.KnnVectorValues; +import org.apache.lucene.index.LeafReader; +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.search.BulkScorer; +import org.apache.lucene.search.CollectionTerminatedException; +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.search.KnnCollector; +import org.apache.lucene.search.LeafCollector; +import org.apache.lucene.search.TopDocs; +import org.apache.lucene.search.TopScoreDocCollector; +import org.apache.lucene.search.TopScoreDocCollectorManager; +import org.apache.lucene.search.Weight; +import org.apache.lucene.util.IOFunction; + +/** + * A {@link KnnCollectorManager} that provides seeded knn collection. See usage in {@link + * org.apache.lucene.search.SeededKnnFloatVectorQuery} and {@link + * org.apache.lucene.search.SeededKnnByteVectorQuery}. + */ +public class SeededKnnCollectorManager implements KnnCollectorManager { + private final KnnCollectorManager delegate; + private final Weight seedWeight; + private final int k; + private final IOFunction vectorValuesSupplier; + + public SeededKnnCollectorManager( + KnnCollectorManager delegate, + Weight seedWeight, + int k, + IOFunction vectorValuesSupplier) { + this.delegate = delegate; + this.seedWeight = seedWeight; + this.k = k; + this.vectorValuesSupplier = vectorValuesSupplier; + } + + @Override + public KnnCollector newCollector(int visitedLimit, LeafReaderContext ctx) throws IOException { + // Execute the seed query + TopScoreDocCollector seedCollector = + new TopScoreDocCollectorManager(k, null, Integer.MAX_VALUE).newCollector(); + final LeafReader leafReader = ctx.reader(); + final LeafCollector leafCollector = seedCollector.getLeafCollector(ctx); + if (leafCollector != null) { + try { + BulkScorer scorer = seedWeight.bulkScorer(ctx); + if (scorer != null) { + scorer.score( + leafCollector, + 
leafReader.getLiveDocs(), + 0 /* min */, + DocIdSetIterator.NO_MORE_DOCS /* max */); + } + } catch ( + @SuppressWarnings("unused") + CollectionTerminatedException e) { + } + leafCollector.finish(); + } + + TopDocs seedTopDocs = seedCollector.topDocs(); + KnnVectorValues vectorValues = vectorValuesSupplier.apply(leafReader); + final KnnCollector delegateCollector = delegate.newCollector(visitedLimit, ctx); + if (seedTopDocs.totalHits.value() == 0 || vectorValues == null) { + return delegateCollector; + } + KnnVectorValues.DocIndexIterator indexIterator = vectorValues.iterator(); + DocIdSetIterator seedDocs = new MappedDISI(indexIterator, new TopDocsDISI(seedTopDocs)); + return new SeededKnnCollector(delegateCollector, seedDocs, seedTopDocs.scoreDocs.length); + } + + private static class MappedDISI extends DocIdSetIterator { + KnnVectorValues.DocIndexIterator indexedDISI; + DocIdSetIterator sourceDISI; + + private MappedDISI(KnnVectorValues.DocIndexIterator indexedDISI, DocIdSetIterator sourceDISI) { + this.indexedDISI = indexedDISI; + this.sourceDISI = sourceDISI; + } + + /** + * Advances the source iterator to the first document number that is greater than or equal to + * the provided target and returns the corresponding index. + */ + @Override + public int advance(int target) throws IOException { + int newTarget = sourceDISI.advance(target); + if (newTarget != NO_MORE_DOCS) { + indexedDISI.advance(newTarget); + } + return docID(); + } + + @Override + public long cost() { + return sourceDISI.cost(); + } + + @Override + public int docID() { + if (indexedDISI.docID() == NO_MORE_DOCS || sourceDISI.docID() == NO_MORE_DOCS) { + return NO_MORE_DOCS; + } + return indexedDISI.index(); + } + + /** Advances to the next document in the source iterator and returns the corresponding index. 
*/ + @Override + public int nextDoc() throws IOException { + int newTarget = sourceDISI.nextDoc(); + if (newTarget != NO_MORE_DOCS) { + indexedDISI.advance(newTarget); + } + return docID(); + } + } + + private static class TopDocsDISI extends DocIdSetIterator { + private final int[] sortedDocIds; + private int idx = -1; + + private TopDocsDISI(TopDocs topDocs) { + sortedDocIds = new int[topDocs.scoreDocs.length]; + for (int i = 0; i < topDocs.scoreDocs.length; i++) { + sortedDocIds[i] = topDocs.scoreDocs[i].doc; + } + Arrays.sort(sortedDocIds); + } + + @Override + public int advance(int target) throws IOException { + return slowAdvance(target); + } + + @Override + public long cost() { + return sortedDocIds.length; + } + + @Override + public int docID() { + if (idx == -1) { + return -1; + } else if (idx >= sortedDocIds.length) { + return DocIdSetIterator.NO_MORE_DOCS; + } else { + return sortedDocIds[idx]; + } + } + + @Override + public int nextDoc() { + idx += 1; + return docID(); + } + } +} diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java index 46d6c93d52c3..e8f0d316fd81 100644 --- a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java +++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java @@ -20,8 +20,10 @@ import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; import java.io.IOException; +import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.KnnCollector; import org.apache.lucene.search.TopKnnCollector; +import org.apache.lucene.search.knn.EntryPointProvider; import org.apache.lucene.util.BitSet; import org.apache.lucene.util.Bits; import org.apache.lucene.util.FixedBitSet; @@ -52,7 +54,9 @@ public HnswGraphSearcher(NeighborQueue candidates, BitSet visited) { } /** - * Searches HNSW graph for the nearest neighbors of a query vector. 
+ * Searches the HNSW graph for the nearest neighbors of a query vector. If entry points are + * directly provided via the knnCollector, then the search will be initialized at those points. + * Otherwise, the search will discover the best entry point per the normal HNSW search algorithm. * * @param scorer the scorer to compare the query with the nodes * @param knnCollector a collector of top knn results to be returned @@ -67,7 +71,30 @@ public static void search( HnswGraphSearcher graphSearcher = new HnswGraphSearcher( new NeighborQueue(knnCollector.k(), true), new SparseFixedBitSet(getGraphSize(graph))); - search(scorer, knnCollector, graph, graphSearcher, acceptOrds); + final int[] entryPoints; + if (knnCollector instanceof EntryPointProvider epp) { + if (epp.numberOfEntryPoints() <= 0) { + throw new IllegalArgumentException("The number of entry points must be > 0"); + } + DocIdSetIterator eps = epp.entryPoints(); + entryPoints = new int[epp.numberOfEntryPoints()]; + int idx = 0; + while (idx < entryPoints.length) { + int entryPointOrdInt = eps.nextDoc(); + if (entryPointOrdInt == NO_MORE_DOCS) { + throw new IllegalArgumentException( + "The number of entry points provided is less than the number of entry points requested"); + } + assert entryPointOrdInt < getGraphSize(graph); + entryPoints[idx++] = entryPointOrdInt; + } + // This is an invalid case, but we should check it + assert entryPoints.length > 0; + // We use provided entry point ordinals to search the complete graph (level 0) + graphSearcher.searchLevel(knnCollector, scorer, 0, entryPoints, graph, acceptOrds); + } else { + search(scorer, knnCollector, graph, graphSearcher, acceptOrds); + } } /** diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/OrdinalTranslatedKnnCollector.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/OrdinalTranslatedKnnCollector.java index ed1a5ffb59fa..5225fe700ab9 100644 --- a/lucene/core/src/java/org/apache/lucene/util/hnsw/OrdinalTranslatedKnnCollector.java 
+++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/OrdinalTranslatedKnnCollector.java @@ -24,54 +24,24 @@ /** * Wraps a provided KnnCollector object, translating the provided vectorId ordinal to a documentId */ -public final class OrdinalTranslatedKnnCollector implements KnnCollector { +public final class OrdinalTranslatedKnnCollector extends KnnCollector.Decorator { - private final KnnCollector in; private final IntToIntFunction vectorOrdinalToDocId; - public OrdinalTranslatedKnnCollector(KnnCollector in, IntToIntFunction vectorOrdinalToDocId) { - this.in = in; + public OrdinalTranslatedKnnCollector( + KnnCollector collector, IntToIntFunction vectorOrdinalToDocId) { + super(collector); this.vectorOrdinalToDocId = vectorOrdinalToDocId; } - @Override - public boolean earlyTerminated() { - return in.earlyTerminated(); - } - - @Override - public void incVisitedCount(int count) { - in.incVisitedCount(count); - } - - @Override - public long visitedCount() { - return in.visitedCount(); - } - - @Override - public long visitLimit() { - return in.visitLimit(); - } - - @Override - public int k() { - return in.k(); - } - @Override public boolean collect(int vectorId, float similarity) { - return in.collect(vectorOrdinalToDocId.apply(vectorId), similarity); - } - - @Override - public float minCompetitiveSimilarity() { - return in.minCompetitiveSimilarity(); + return super.collect(vectorOrdinalToDocId.apply(vectorId), similarity); } @Override public TopDocs topDocs() { - TopDocs td = in.topDocs(); + TopDocs td = super.topDocs(); return new TopDocs( new TotalHits( visitedCount(), diff --git a/lucene/core/src/test/org/apache/lucene/document/TestManyKnnDocs.java b/lucene/core/src/test/org/apache/lucene/document/TestManyKnnDocs.java index 2023ee73391d..1e485515a62b 100644 --- a/lucene/core/src/test/org/apache/lucene/document/TestManyKnnDocs.java +++ b/lucene/core/src/test/org/apache/lucene/document/TestManyKnnDocs.java @@ -17,6 +17,7 @@ package org.apache.lucene.document; 
import com.carrotsearch.randomizedtesting.annotations.TimeoutSuite; +import java.nio.file.Path; import org.apache.lucene.index.DirectoryReader; import org.apache.lucene.index.IndexWriter; import org.apache.lucene.index.IndexWriterConfig; @@ -24,19 +25,27 @@ import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.KnnFloatVectorQuery; +import org.apache.lucene.search.MatchAllDocsQuery; +import org.apache.lucene.search.MatchNoDocsQuery; +import org.apache.lucene.search.Query; +import org.apache.lucene.search.SeededKnnFloatVectorQuery; import org.apache.lucene.search.TopDocs; import org.apache.lucene.store.Directory; import org.apache.lucene.store.FSDirectory; import org.apache.lucene.tests.codecs.vector.ConfigurableMCodec; import org.apache.lucene.tests.util.LuceneTestCase; import org.apache.lucene.tests.util.LuceneTestCase.Monster; +import org.junit.BeforeClass; @TimeoutSuite(millis = 86_400_000) // 24 hour timeout @Monster("takes ~10 minutes and needs extra heap, disk space, file handles") public class TestManyKnnDocs extends LuceneTestCase { // gradlew -p lucene/core test --tests TestManyKnnDocs -Ptests.heapsize=16g -Dtests.monster=true - public void testLargeSegment() throws Exception { + private static Path testDir; + + @BeforeClass + public static void init_index() throws Exception { IndexWriterConfig iwc = new IndexWriterConfig(); iwc.setCodec( new ConfigurableMCodec( @@ -46,27 +55,138 @@ public void testLargeSegment() throws Exception { mp.setMaxMergeAtOnce(256); // avoid intermediate merges (waste of time with HNSW?) 
mp.setSegmentsPerTier(256); // only merge once at the end when we ask iwc.setMergePolicy(mp); - String fieldName = "field"; VectorSimilarityFunction similarityFunction = VectorSimilarityFunction.DOT_PRODUCT; - try (Directory dir = FSDirectory.open(createTempDir("ManyKnnVectorDocs")); + try (Directory dir = FSDirectory.open(testDir = createTempDir("ManyKnnVectorDocs")); IndexWriter iw = new IndexWriter(dir, iwc)) { int numVectors = 2088992; - float[] vector = new float[1]; - Document doc = new Document(); - doc.add(new KnnFloatVectorField(fieldName, vector, similarityFunction)); for (int i = 0; i < numVectors; i++) { + float[] vector = new float[1]; + Document doc = new Document(); vector[0] = (i % 256); + doc.add(new KnnFloatVectorField("field", vector, similarityFunction)); + doc.add(new KeywordField("int", "" + i, org.apache.lucene.document.Field.Store.YES)); + doc.add(new StoredField("intValue", i)); iw.addDocument(doc); } // merge to single segment and then verify iw.forceMerge(1); iw.commit(); + } + } + + public void testLargeSegmentKnn() throws Exception { + try (Directory dir = FSDirectory.open(testDir)) { IndexSearcher searcher = new IndexSearcher(DirectoryReader.open(dir)); - TopDocs docs = searcher.search(new KnnFloatVectorQuery("field", new float[] {120}, 10), 5); - assertEquals(5, docs.scoreDocs.length); + for (int i = 0; i < 256; i++) { + Query filterQuery = new MatchAllDocsQuery(); + float[] vector = new float[128]; + vector[0] = i; + vector[1] = 1; + TopDocs docs = + searcher.search(new KnnFloatVectorQuery("field", vector, 10, filterQuery), 5); + assertEquals(5, docs.scoreDocs.length); + Document d = searcher.storedFields().document(docs.scoreDocs[0].doc); + String s = ""; + for (int j = 0; j < docs.scoreDocs.length - 1; j++) { + s += docs.scoreDocs[j].doc + " " + docs.scoreDocs[j].score + "\n"; + } + assertEquals(s, i + 256, d.getField("intValue").numericValue()); + } + } + } + + public void testLargeSegmentSeededExact() throws Exception { + try 
(Directory dir = FSDirectory.open(testDir)) { + IndexSearcher searcher = new IndexSearcher(DirectoryReader.open(dir)); + for (int i = 0; i < 256; i++) { + Query seedQuery = KeywordField.newExactQuery("int", "" + (i + 256)); + Query filterQuery = new MatchAllDocsQuery(); + float[] vector = new float[128]; + vector[0] = i; + vector[1] = 1; + TopDocs docs = + searcher.search( + new SeededKnnFloatVectorQuery("field", vector, 10, filterQuery, seedQuery), 5); + assertEquals(5, docs.scoreDocs.length); + String s = ""; + for (int j = 0; j < docs.scoreDocs.length - 1; j++) { + s += docs.scoreDocs[j].doc + " " + docs.scoreDocs[j].score + "\n"; + } + Document d = searcher.storedFields().document(docs.scoreDocs[0].doc); + assertEquals(s, i + 256, d.getField("intValue").numericValue()); + } + } + } + + public void testLargeSegmentSeededNearby() throws Exception { + try (Directory dir = FSDirectory.open(testDir)) { + IndexSearcher searcher = new IndexSearcher(DirectoryReader.open(dir)); + for (int i = 0; i < 256; i++) { + Query seedQuery = KeywordField.newExactQuery("int", "" + i); + Query filterQuery = new MatchAllDocsQuery(); + float[] vector = new float[128]; + vector[0] = i; + vector[1] = 1; + TopDocs docs = + searcher.search( + new SeededKnnFloatVectorQuery("field", vector, 10, filterQuery, seedQuery), 5); + assertEquals(5, docs.scoreDocs.length); + String s = ""; + for (int j = 0; j < docs.scoreDocs.length - 1; j++) { + s += docs.scoreDocs[j].doc + " " + docs.scoreDocs[j].score + "\n"; + } + Document d = searcher.storedFields().document(docs.scoreDocs[0].doc); + assertEquals(s, i + 256, d.getField("intValue").numericValue()); + } + } + } + + public void testLargeSegmentSeededDistant() throws Exception { + try (Directory dir = FSDirectory.open(testDir)) { + IndexSearcher searcher = new IndexSearcher(DirectoryReader.open(dir)); + for (int i = 0; i < 256; i++) { + Query seedQuery = KeywordField.newExactQuery("int", "" + (i + 128)); + Query filterQuery = new 
MatchAllDocsQuery(); + float[] vector = new float[128]; + vector[0] = i; + vector[1] = 1; + TopDocs docs = + searcher.search( + new SeededKnnFloatVectorQuery("field", vector, 10, filterQuery, seedQuery), 5); + assertEquals(5, docs.scoreDocs.length); + Document d = searcher.storedFields().document(docs.scoreDocs[0].doc); + String s = ""; + for (int j = 0; j < docs.scoreDocs.length - 1; j++) { + s += docs.scoreDocs[j].doc + " " + docs.scoreDocs[j].score + "\n"; + } + assertEquals(s, i + 256, d.getField("intValue").numericValue()); + } + } + } + + public void testLargeSegmentSeededNone() throws Exception { + try (Directory dir = FSDirectory.open(testDir)) { + IndexSearcher searcher = new IndexSearcher(DirectoryReader.open(dir)); + for (int i = 0; i < 256; i++) { + Query seedQuery = new MatchNoDocsQuery(); + Query filterQuery = new MatchAllDocsQuery(); + float[] vector = new float[128]; + vector[0] = i; + vector[1] = 1; + TopDocs docs = + searcher.search( + new SeededKnnFloatVectorQuery("field", vector, 10, filterQuery, seedQuery), 5); + assertEquals(5, docs.scoreDocs.length); + Document d = searcher.storedFields().document(docs.scoreDocs[0].doc); + String s = ""; + for (int j = 0; j < docs.scoreDocs.length - 1; j++) { + s += docs.scoreDocs[j].doc + " " + docs.scoreDocs[j].score + "\n"; + } + assertEquals(s, i + 256, d.getField("intValue").numericValue()); + } } } } diff --git a/lucene/core/src/test/org/apache/lucene/search/TestKnnByteVectorQuery.java b/lucene/core/src/test/org/apache/lucene/search/TestKnnByteVectorQuery.java index b45d6e8fb641..21219e0e1d99 100644 --- a/lucene/core/src/test/org/apache/lucene/search/TestKnnByteVectorQuery.java +++ b/lucene/core/src/test/org/apache/lucene/search/TestKnnByteVectorQuery.java @@ -61,7 +61,7 @@ Field getKnnVectorField(String name, float[] vector) { return new KnnByteVectorField(name, floatToBytes(vector), VectorSimilarityFunction.EUCLIDEAN); } - private static byte[] floatToBytes(float[] query) { + static byte[] 
floatToBytes(float[] query) { byte[] bytes = new byte[query.length]; for (int i = 0; i < query.length; i++) { assert query[i] <= Byte.MAX_VALUE && query[i] >= Byte.MIN_VALUE && (query[i] % 1) == 0 @@ -109,7 +109,7 @@ public void testVectorEncodingMismatch() throws IOException { } } - private static class ThrowingKnnVectorQuery extends KnnByteVectorQuery { + static class ThrowingKnnVectorQuery extends KnnByteVectorQuery { public ThrowingKnnVectorQuery(String field, byte[] target, int k, Query filter) { super(field, target, k, filter); diff --git a/lucene/core/src/test/org/apache/lucene/search/TestKnnFloatVectorQuery.java b/lucene/core/src/test/org/apache/lucene/search/TestKnnFloatVectorQuery.java index 5dcb6f97df93..ece2b385654e 100644 --- a/lucene/core/src/test/org/apache/lucene/search/TestKnnFloatVectorQuery.java +++ b/lucene/core/src/test/org/apache/lucene/search/TestKnnFloatVectorQuery.java @@ -259,7 +259,7 @@ public void testDocAndScoreQueryBasics() throws IOException { } } - private static class ThrowingKnnVectorQuery extends KnnFloatVectorQuery { + static class ThrowingKnnVectorQuery extends KnnFloatVectorQuery { public ThrowingKnnVectorQuery(String field, float[] target, int k, Query filter) { super(field, target, k, filter); diff --git a/lucene/core/src/test/org/apache/lucene/search/TestSeededKnnByteVectorQuery.java b/lucene/core/src/test/org/apache/lucene/search/TestSeededKnnByteVectorQuery.java new file mode 100644 index 000000000000..d0fb8c95e035 --- /dev/null +++ b/lucene/core/src/test/org/apache/lucene/search/TestSeededKnnByteVectorQuery.java @@ -0,0 +1,205 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.search; + +import static org.apache.lucene.search.TestKnnByteVectorQuery.floatToBytes; + +import java.io.IOException; +import org.apache.lucene.document.Document; +import org.apache.lucene.document.Field; +import org.apache.lucene.document.IntPoint; +import org.apache.lucene.document.KnnByteVectorField; +import org.apache.lucene.document.NumericDocValuesField; +import org.apache.lucene.index.DirectoryReader; +import org.apache.lucene.index.IndexReader; +import org.apache.lucene.index.IndexWriterConfig; +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.index.QueryTimeout; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.store.Directory; +import org.apache.lucene.tests.index.RandomIndexWriter; +import org.apache.lucene.tests.util.TestUtil; +import org.apache.lucene.util.TestVectorUtil; + +public class TestSeededKnnByteVectorQuery extends BaseKnnVectorQueryTestCase { + + private static final Query MATCH_NONE = new MatchNoDocsQuery(); + + @Override + AbstractKnnVectorQuery getKnnVectorQuery(String field, float[] query, int k, Query queryFilter) { + return new SeededKnnByteVectorQuery(field, floatToBytes(query), k, queryFilter, MATCH_NONE); + } + + @Override + AbstractKnnVectorQuery getThrowingKnnVectorQuery(String field, float[] vec, int k, Query query) { + return new ThrowingKnnVectorQuery(field, floatToBytes(vec), k, query, MATCH_NONE); + } + + @Override + float[] randomVector(int dim) { + byte[] b = TestVectorUtil.randomVectorBytes(dim); + float[] v = new 
float[b.length]; + int vi = 0; + for (int i = 0; i < v.length; i++) { + v[vi++] = b[i]; + } + return v; + } + + @Override + Field getKnnVectorField( + String name, float[] vector, VectorSimilarityFunction similarityFunction) { + return new KnnByteVectorField(name, floatToBytes(vector), similarityFunction); + } + + @Override + Field getKnnVectorField(String name, float[] vector) { + return new KnnByteVectorField(name, floatToBytes(vector), VectorSimilarityFunction.EUCLIDEAN); + } + + /** Tests with random vectors and a random seed. Uses RandomIndexWriter. */ + public void testRandomWithSeed() throws IOException { + int numDocs = 1000; + int dimension = atLeast(5); + int numIters = atLeast(10); + int numDocsWithVector = 0; + try (Directory d = newDirectoryForTest()) { + // Always use the default kNN format to have predictable behavior around when it hits + // visitedLimit. This is fine since the test targets AbstractKnnVectorQuery logic, not the kNN + // format + // implementation. + IndexWriterConfig iwc = new IndexWriterConfig().setCodec(TestUtil.getDefaultCodec()); + RandomIndexWriter w = new RandomIndexWriter(random(), d, iwc); + for (int i = 0; i < numDocs; i++) { + Document doc = new Document(); + if (random().nextBoolean()) { + // Randomly skip some vectors to test the mapping from docid to ordinals + doc.add(getKnnVectorField("field", randomVector(dimension))); + numDocsWithVector += 1; + } + doc.add(new NumericDocValuesField("tag", i)); + doc.add(new IntPoint("tag", i)); + w.addDocument(doc); + } + w.forceMerge(1); + w.close(); + + try (IndexReader reader = DirectoryReader.open(d)) { + IndexSearcher searcher = newSearcher(reader); + for (int i = 0; i < numIters; i++) { + int k = random().nextInt(80) + 1; + int n = random().nextInt(100) + 1; + // we may get fewer results than requested if there are deletions, but this test doesn't + // check that + assert reader.hasDeletions() == false; + + // All documents as seeds + Query seed1 = new MatchAllDocsQuery(); + 
Query filter = random().nextBoolean() ? null : new MatchAllDocsQuery(); + SeededKnnByteVectorQuery query = + new SeededKnnByteVectorQuery( + "field", floatToBytes(randomVector(dimension)), k, filter, seed1); + TopDocs results = searcher.search(query, n); + int expected = Math.min(Math.min(n, k), numDocsWithVector); + + assertEquals(expected, results.scoreDocs.length); + assertTrue(results.totalHits.value() >= results.scoreDocs.length); + // verify the results are in descending score order + float last = Float.MAX_VALUE; + for (ScoreDoc scoreDoc : results.scoreDocs) { + assertTrue(scoreDoc.score <= last); + last = scoreDoc.score; + } + + // Restrictive seed query -- 6 documents + Query seed2 = IntPoint.newRangeQuery("tag", 1, 6); + query = + new SeededKnnByteVectorQuery( + "field", floatToBytes(randomVector(dimension)), k, null, seed2); + results = searcher.search(query, n); + expected = Math.min(Math.min(n, k), reader.numDocs()); + assertEquals(expected, results.scoreDocs.length); + assertTrue(results.totalHits.value() >= results.scoreDocs.length); + // verify the results are in descending score order + last = Float.MAX_VALUE; + for (ScoreDoc scoreDoc : results.scoreDocs) { + assertTrue(scoreDoc.score <= last); + last = scoreDoc.score; + } + + // No seed documents -- falls back on full approx search + Query seed3 = new MatchNoDocsQuery(); + query = + new SeededKnnByteVectorQuery( + "field", floatToBytes(randomVector(dimension)), k, null, seed3); + results = searcher.search(query, n); + expected = Math.min(Math.min(n, k), reader.numDocs()); + assertEquals(expected, results.scoreDocs.length); + assertTrue(results.totalHits.value() >= results.scoreDocs.length); + // verify the results are in descending score order + last = Float.MAX_VALUE; + for (ScoreDoc scoreDoc : results.scoreDocs) { + assertTrue(scoreDoc.score <= last); + last = scoreDoc.score; + } + } + } + } + } + + private static class ThrowingKnnVectorQuery extends SeededKnnByteVectorQuery { + + public 
ThrowingKnnVectorQuery(String field, byte[] target, int k, Query filter, Query seed) { + super(field, target, k, filter, seed); + } + + private ThrowingKnnVectorQuery( + String field, byte[] target, int k, Query filter, Weight seedWeight) { + super(field, target, k, filter, seedWeight); + } + + @Override + // This is test only and we need to override the inner rewrite to throw + public Query rewrite(IndexSearcher indexSearcher) throws IOException { + if (seedWeight != null) { + return super.rewrite(indexSearcher); + } + BooleanQuery.Builder booleanSeedQueryBuilder = + new BooleanQuery.Builder() + .add(seed, BooleanClause.Occur.MUST) + .add(new FieldExistsQuery(field), BooleanClause.Occur.FILTER); + if (filter != null) { + booleanSeedQueryBuilder.add(filter, BooleanClause.Occur.FILTER); + } + Query seedRewritten = indexSearcher.rewrite(booleanSeedQueryBuilder.build()); + Weight seedWeight = indexSearcher.createWeight(seedRewritten, ScoreMode.TOP_SCORES, 1f); + return new ThrowingKnnVectorQuery(field, target, k, filter, seedWeight) + .rewrite(indexSearcher); + } + + @Override + protected TopDocs exactSearch( + LeafReaderContext context, DocIdSetIterator acceptIterator, QueryTimeout queryTimeout) { + throw new UnsupportedOperationException("exact search is not supported"); + } + + @Override + public String toString(String field) { + return null; + } + } +} diff --git a/lucene/core/src/test/org/apache/lucene/search/TestSeededKnnFloatVectorQuery.java b/lucene/core/src/test/org/apache/lucene/search/TestSeededKnnFloatVectorQuery.java new file mode 100644 index 000000000000..d5630037ef74 --- /dev/null +++ b/lucene/core/src/test/org/apache/lucene/search/TestSeededKnnFloatVectorQuery.java @@ -0,0 +1,191 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.search; + +import java.io.IOException; +import org.apache.lucene.document.Document; +import org.apache.lucene.document.Field; +import org.apache.lucene.document.IntPoint; +import org.apache.lucene.document.KnnFloatVectorField; +import org.apache.lucene.document.NumericDocValuesField; +import org.apache.lucene.index.DirectoryReader; +import org.apache.lucene.index.IndexReader; +import org.apache.lucene.index.IndexWriterConfig; +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.index.QueryTimeout; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.store.Directory; +import org.apache.lucene.tests.index.RandomIndexWriter; +import org.apache.lucene.tests.util.TestUtil; +import org.apache.lucene.util.TestVectorUtil; + +public class TestSeededKnnFloatVectorQuery extends BaseKnnVectorQueryTestCase { + private static final Query MATCH_NONE = new MatchNoDocsQuery(); + + @Override + KnnFloatVectorQuery getKnnVectorQuery(String field, float[] query, int k, Query queryFilter) { + return new SeededKnnFloatVectorQuery(field, query, k, queryFilter, MATCH_NONE); + } + + @Override + AbstractKnnVectorQuery getThrowingKnnVectorQuery(String field, float[] vec, int k, Query query) { + return new ThrowingKnnVectorQuery(field, vec, k, query, MATCH_NONE); + } + + @Override + float[] randomVector(int dim) { + return 
TestVectorUtil.randomVector(dim); + } + + @Override + Field getKnnVectorField( + String name, float[] vector, VectorSimilarityFunction similarityFunction) { + return new KnnFloatVectorField(name, vector, similarityFunction); + } + + @Override + Field getKnnVectorField(String name, float[] vector) { + return new KnnFloatVectorField(name, vector); + } + + /** Tests with random vectors and a random seed. Uses RandomIndexWriter. */ + public void testRandomWithSeed() throws IOException { + int numDocs = 1000; + int dimension = atLeast(5); + int numIters = atLeast(10); + int numDocsWithVector = 0; + try (Directory d = newDirectoryForTest()) { + // Always use the default kNN format to have predictable behavior around when it hits + // visitedLimit. This is fine since the test targets AbstractKnnVectorQuery logic, not the kNN + // format + // implementation. + IndexWriterConfig iwc = new IndexWriterConfig().setCodec(TestUtil.getDefaultCodec()); + RandomIndexWriter w = new RandomIndexWriter(random(), d, iwc); + for (int i = 0; i < numDocs; i++) { + Document doc = new Document(); + if (random().nextBoolean()) { + // Randomly skip some vectors to test the mapping from docid to ordinals + doc.add(getKnnVectorField("field", randomVector(dimension))); + numDocsWithVector += 1; + } + doc.add(new NumericDocValuesField("tag", i)); + doc.add(new IntPoint("tag", i)); + w.addDocument(doc); + } + w.forceMerge(1); + w.close(); + + try (IndexReader reader = DirectoryReader.open(d)) { + IndexSearcher searcher = newSearcher(reader); + for (int i = 0; i < numIters; i++) { + int k = random().nextInt(80) + 1; + int n = random().nextInt(100) + 1; + // we may get fewer results than requested if there are deletions, but this test doesn't + // check that + assert reader.hasDeletions() == false; + + // All documents as seeds + Query seed1 = new MatchAllDocsQuery(); + Query filter = random().nextBoolean() ? 
null : new MatchAllDocsQuery(); + AbstractKnnVectorQuery query = + new SeededKnnFloatVectorQuery("field", randomVector(dimension), k, filter, seed1); + TopDocs results = searcher.search(query, n); + int expected = Math.min(Math.min(n, k), numDocsWithVector); + + assertEquals(expected, results.scoreDocs.length); + assertTrue(results.totalHits.value() >= results.scoreDocs.length); + // verify the results are in descending score order + float last = Float.MAX_VALUE; + for (ScoreDoc scoreDoc : results.scoreDocs) { + assertTrue(scoreDoc.score <= last); + last = scoreDoc.score; + } + + // Restrictive seed query -- 6 documents + Query seed2 = IntPoint.newRangeQuery("tag", 1, 6); + query = new SeededKnnFloatVectorQuery("field", randomVector(dimension), k, null, seed2); + results = searcher.search(query, n); + expected = Math.min(Math.min(n, k), reader.numDocs()); + assertEquals(expected, results.scoreDocs.length); + assertTrue(results.totalHits.value() >= results.scoreDocs.length); + // verify the results are in descending score order + last = Float.MAX_VALUE; + for (ScoreDoc scoreDoc : results.scoreDocs) { + assertTrue(scoreDoc.score <= last); + last = scoreDoc.score; + } + + // No seed documents -- falls back on full approx search + Query seed3 = new MatchNoDocsQuery(); + query = new SeededKnnFloatVectorQuery("field", randomVector(dimension), k, null, seed3); + results = searcher.search(query, n); + expected = Math.min(Math.min(n, k), reader.numDocs()); + assertEquals(expected, results.scoreDocs.length); + assertTrue(results.totalHits.value() >= results.scoreDocs.length); + // verify the results are in descending score order + last = Float.MAX_VALUE; + for (ScoreDoc scoreDoc : results.scoreDocs) { + assertTrue(scoreDoc.score <= last); + last = scoreDoc.score; + } + } + } + } + } + + private static class ThrowingKnnVectorQuery extends SeededKnnFloatVectorQuery { + + private ThrowingKnnVectorQuery(String field, float[] target, int k, Query filter, Query seed) { + 
super(field, target, k, filter, seed); + } + + private ThrowingKnnVectorQuery( + String field, float[] target, int k, Query filter, Weight seedWeight) { + super(field, target, k, filter, seedWeight); + } + + @Override + // This is test only and we need to override the inner rewrite to throw + public Query rewrite(IndexSearcher indexSearcher) throws IOException { + if (seedWeight != null) { + return super.rewrite(indexSearcher); + } + BooleanQuery.Builder booleanSeedQueryBuilder = + new BooleanQuery.Builder() + .add(seed, BooleanClause.Occur.MUST) + .add(new FieldExistsQuery(field), BooleanClause.Occur.FILTER); + if (filter != null) { + booleanSeedQueryBuilder.add(filter, BooleanClause.Occur.FILTER); + } + Query seedRewritten = indexSearcher.rewrite(booleanSeedQueryBuilder.build()); + Weight seedWeight = indexSearcher.createWeight(seedRewritten, ScoreMode.TOP_SCORES, 1f); + return new ThrowingKnnVectorQuery(field, target, k, filter, seedWeight) + .rewrite(indexSearcher); + } + + @Override + protected TopDocs exactSearch( + LeafReaderContext context, DocIdSetIterator acceptIterator, QueryTimeout queryTimeout) { + throw new UnsupportedOperationException("exact search is not supported"); + } + + @Override + public String toString(String field) { + return null; + } + } +}