Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Backport: Concurrent rewrite for KnnVectorQuery (#12160) #12288

Merged
merged 2 commits into from
May 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions lucene/CHANGES.txt
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ Optimizations

* GITHUB#12270 Don't generate stacktrace in CollectionTerminatedException. (Armin Braun)

* GITHUB#12160: Concurrent rewrite for AbstractKnnVectorQuery. (Kaival Parikh)

Bug Fixes
---------------------
(No changes)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,20 @@
import java.io.IOException;
import java.util.Arrays;
import java.util.Comparator;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Executor;
import java.util.concurrent.FutureTask;
import java.util.stream.Collectors;
import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.util.BitSet;
import org.apache.lucene.util.BitSetIterator;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.ThreadInterruptedException;

/**
* Uses {@link KnnVectorsReader#search} to perform nearest neighbour search.
Expand Down Expand Up @@ -62,9 +68,8 @@ public AbstractKnnVectorQuery(String field, int k, Query filter) {
@Override
public Query rewrite(IndexSearcher indexSearcher) throws IOException {
IndexReader reader = indexSearcher.getIndexReader();
TopDocs[] perLeafResults = new TopDocs[reader.leaves().size()];

Weight filterWeight = null;
final Weight filterWeight;
if (filter != null) {
BooleanQuery booleanQuery =
new BooleanQuery.Builder()
Expand All @@ -73,17 +78,16 @@ public Query rewrite(IndexSearcher indexSearcher) throws IOException {
.build();
Query rewritten = indexSearcher.rewrite(booleanQuery);
filterWeight = indexSearcher.createWeight(rewritten, ScoreMode.COMPLETE_NO_SCORES, 1f);
} else {
filterWeight = null;
}

for (LeafReaderContext ctx : reader.leaves()) {
TopDocs results = searchLeaf(ctx, filterWeight);
if (ctx.docBase > 0) {
for (ScoreDoc scoreDoc : results.scoreDocs) {
scoreDoc.doc += ctx.docBase;
}
}
perLeafResults[ctx.ord] = results;
}
Executor executor = indexSearcher.getExecutor();
TopDocs[] perLeafResults =
(executor == null)
? sequentialSearch(reader.leaves(), filterWeight)
: parallelSearch(reader.leaves(), filterWeight, executor);

// Merge sort the results
TopDocs topK = TopDocs.merge(k, perLeafResults);
if (topK.scoreDocs.length == 0) {
Expand All @@ -92,7 +96,54 @@ public Query rewrite(IndexSearcher indexSearcher) throws IOException {
return createRewrittenQuery(reader, topK);
}

/**
 * Searches every leaf one after another on the calling thread.
 *
 * <p>Each leaf's hits are stored at the slot given by that leaf's {@code ord}. Any exception
 * raised while allocating the result array or searching a leaf is rethrown wrapped in a
 * {@link RuntimeException}, with the original exception as its cause.
 *
 * @param leafReaderContexts the leaves to search
 * @param filterWeight weight passed through to each per-leaf search; the caller passes {@code
 *     null} when the query has no filter
 * @return per-leaf results, indexed by leaf ordinal
 */
private TopDocs[] sequentialSearch(
    List<LeafReaderContext> leafReaderContexts, Weight filterWeight) {
  final TopDocs[] hitsByLeaf;
  try {
    hitsByLeaf = new TopDocs[leafReaderContexts.size()];
    for (LeafReaderContext leaf : leafReaderContexts) {
      TopDocs leafHits = searchLeaf(leaf, filterWeight);
      hitsByLeaf[leaf.ord] = leafHits;
    }
  } catch (Exception failure) {
    // Checked and unchecked failures alike surface as a RuntimeException wrapping the original.
    throw new RuntimeException(failure);
  }
  return hitsByLeaf;
}

/**
 * Searches every leaf through the supplied executor and gathers the per-leaf results.
 *
 * <p>One {@link FutureTask} is created per leaf and the whole batch is handed to a {@code
 * SliceExecutor} built on {@code executor}. Results are then collected in the same order as
 * {@code leafReaderContexts} (by list position, not by leaf ordinal).
 *
 * <p>If a task failed, the cause of its {@link ExecutionException} is rethrown wrapped in a
 * {@link RuntimeException}. If the waiting thread is interrupted, a {@link
 * ThreadInterruptedException} is thrown.
 *
 * @param leafReaderContexts the leaves to search
 * @param filterWeight weight passed through to each per-leaf search; the caller passes {@code
 *     null} when the query has no filter
 * @param executor executor used to back the {@code SliceExecutor}
 * @return per-leaf results, in the iteration order of {@code leafReaderContexts}
 */
private TopDocs[] parallelSearch(
    List<LeafReaderContext> leafReaderContexts, Weight filterWeight, Executor executor) {
  // Wrap each leaf search in its own task so its result or failure can be retrieved later.
  List<FutureTask<TopDocs>> leafTasks =
      leafReaderContexts.stream()
          .map(leaf -> new FutureTask<TopDocs>(() -> searchLeaf(leaf, filterWeight)))
          .collect(Collectors.toList());

  new SliceExecutor(executor).invokeAll(leafTasks);

  final int taskCount = leafTasks.size();
  final TopDocs[] hitsPerTask = new TopDocs[taskCount];
  for (int i = 0; i < taskCount; i++) {
    try {
      // Blocks until the i-th task has finished.
      hitsPerTask[i] = leafTasks.get(i).get();
    } catch (ExecutionException failed) {
      // Unwrap so the RuntimeException's cause is the exception thrown by the leaf search.
      throw new RuntimeException(failed.getCause());
    } catch (InterruptedException interrupted) {
      throw new ThreadInterruptedException(interrupted);
    }
  }
  return hitsPerTask;
}

/**
 * Runs the search on a single leaf and rebases the returned doc ids by the leaf's {@code
 * docBase}.
 *
 * <p>The {@link ScoreDoc} entries returned by {@code getLeafResults} are modified in place;
 * nothing is changed when {@code docBase} is not positive.
 *
 * @param ctx the leaf to search
 * @param filterWeight weight passed through to {@code getLeafResults}
 * @return the leaf's results, with each {@code doc} offset by {@code ctx.docBase}
 * @throws IOException if {@code getLeafResults} throws it
 */
private TopDocs searchLeaf(LeafReaderContext ctx, Weight filterWeight) throws IOException {
  final TopDocs leafHits = getLeafResults(ctx, filterWeight);
  final int base = ctx.docBase;
  if (base <= 0) {
    // No offset to apply for this leaf.
    return leafHits;
  }
  for (ScoreDoc hit : leafHits.scoreDocs) {
    hit.doc += base;
  }
  return leafHits;
}

private TopDocs getLeafResults(LeafReaderContext ctx, Weight filterWeight) throws IOException {
Bits liveDocs = ctx.reader().getLiveDocs();
int maxDoc = ctx.reader().maxDoc();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,10 @@ public void testDimensionMismatch() throws IOException {
IndexSearcher searcher = newSearcher(reader);
AbstractKnnVectorQuery kvq = getKnnVectorQuery("field", new float[] {0}, 10);
IllegalArgumentException e =
expectThrows(IllegalArgumentException.class, () -> searcher.search(kvq, 10));
expectThrows(
RuntimeException.class,
IllegalArgumentException.class,
() -> searcher.search(kvq, 10));
assertEquals("vector query dimension: 1 differs from field dimension: 2", e.getMessage());
}
}
Expand Down Expand Up @@ -529,6 +532,7 @@ public void testRandomWithFilter() throws IOException {
assertEquals(9, results.totalHits.value);
assertEquals(results.totalHits.value, results.scoreDocs.length);
expectThrows(
RuntimeException.class,
UnsupportedOperationException.class,
() ->
searcher.search(
Expand All @@ -543,6 +547,7 @@ public void testRandomWithFilter() throws IOException {
assertEquals(5, results.totalHits.value);
assertEquals(results.totalHits.value, results.scoreDocs.length);
expectThrows(
RuntimeException.class,
UnsupportedOperationException.class,
() ->
searcher.search(
Expand Down Expand Up @@ -570,6 +575,7 @@ public void testRandomWithFilter() throws IOException {
// Test a filter that exhausts visitedLimit in upper levels, and switches to exact search
Query filter4 = IntPoint.newRangeQuery("tag", lower, lower + 2);
expectThrows(
RuntimeException.class,
UnsupportedOperationException.class,
() ->
searcher.search(
Expand Down Expand Up @@ -742,6 +748,7 @@ public void testBitSetQuery() throws IOException {

Query filter = new ThrowingBitSetQuery(new FixedBitSet(numDocs));
expectThrows(
RuntimeException.class,
UnsupportedOperationException.class,
() ->
searcher.search(
Expand Down