Concurrent rewrite for KnnVectorQuery (apache#12160)
- Reduce overhead of non-concurrent search by preserving original execution
- Improve readability by factoring into separate functions

---------

Co-authored-by: Kaival Parikh <kaivalp2000@gmail.com>
2 people authored and benwtrent committed May 12, 2023
1 parent ba91ef9 commit a2f70d3
Showing 3 changed files with 71 additions and 12 deletions.
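For context, a minimal usage sketch (not part of this commit) of what the change enables: when the IndexSearcher is constructed with an Executor, the kNN query's rewrite fans the per-leaf vector searches out to that executor; without one, it keeps the original sequential loop. The index path, field name, vector values, and pool size below are hypothetical; KnnFloatVectorQuery is used as one concrete AbstractKnnVectorQuery subclass.

import java.nio.file.Paths;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import org.apache.lucene.index.DirectoryReader;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.KnnFloatVectorQuery;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.store.Directory;
import org.apache.lucene.store.FSDirectory;

public class ConcurrentKnnRewriteSketch {
  public static void main(String[] args) throws Exception {
    // Hypothetical pool size; any Executor works.
    ExecutorService executor = Executors.newFixedThreadPool(4);
    try (Directory dir = FSDirectory.open(Paths.get("/tmp/knn-index")); // hypothetical index path
        DirectoryReader reader = DirectoryReader.open(dir)) {
      // Passing an executor lets the kNN rewrite search each segment concurrently;
      // new IndexSearcher(reader) without an executor keeps the sequential path.
      IndexSearcher searcher = new IndexSearcher(reader, executor);
      float[] target = {0.1f, 0.2f, 0.3f}; // hypothetical query vector for a hypothetical "vector" field
      TopDocs hits = searcher.search(new KnnFloatVectorQuery("vector", target, 10), 10);
      System.out.println("total hits: " + hits.totalHits);
    } finally {
      executor.shutdown();
    }
  }
}

Because parallelSearch re-throws worker failures wrapped (RuntimeException for execution failures, ThreadInterruptedException for interrupts), the tests below switch to the expectThrows overload that also accepts the expected wrapped exception type.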
2 changes: 2 additions & 0 deletions lucene/CHANGES.txt
@@ -25,6 +25,8 @@ Optimizations
 
 * GITHUB#12270 Don't generate stacktrace in CollectionTerminatedException. (Armin Braun)
 
+* GITHUB#12160: Concurrent rewrite for AbstractKnnVectorQuery. (Kaival Parikh)
+
 Bug Fixes
 ---------------------
 (No changes)
@@ -21,14 +21,19 @@
 import java.io.IOException;
 import java.util.Arrays;
 import java.util.Comparator;
+import java.util.List;
 import java.util.Objects;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.Executor;
+import java.util.concurrent.FutureTask;
 import org.apache.lucene.codecs.KnnVectorsReader;
 import org.apache.lucene.index.FieldInfo;
 import org.apache.lucene.index.IndexReader;
 import org.apache.lucene.index.LeafReaderContext;
 import org.apache.lucene.util.BitSet;
 import org.apache.lucene.util.BitSetIterator;
 import org.apache.lucene.util.Bits;
+import org.apache.lucene.util.ThreadInterruptedException;
 
 /**
  * Uses {@link KnnVectorsReader#search} to perform nearest neighbour search.
@@ -62,9 +67,8 @@ public AbstractKnnVectorQuery(String field, int k, Query filter) {
   @Override
   public Query rewrite(IndexSearcher indexSearcher) throws IOException {
     IndexReader reader = indexSearcher.getIndexReader();
-    TopDocs[] perLeafResults = new TopDocs[reader.leaves().size()];
 
-    Weight filterWeight = null;
+    final Weight filterWeight;
     if (filter != null) {
       BooleanQuery booleanQuery =
           new BooleanQuery.Builder()
@@ -73,17 +77,16 @@ public Query rewrite(IndexSearcher indexSearcher) throws IOException {
               .build();
       Query rewritten = indexSearcher.rewrite(booleanQuery);
       filterWeight = indexSearcher.createWeight(rewritten, ScoreMode.COMPLETE_NO_SCORES, 1f);
+    } else {
+      filterWeight = null;
     }
 
-    for (LeafReaderContext ctx : reader.leaves()) {
-      TopDocs results = searchLeaf(ctx, filterWeight);
-      if (ctx.docBase > 0) {
-        for (ScoreDoc scoreDoc : results.scoreDocs) {
-          scoreDoc.doc += ctx.docBase;
-        }
-      }
-      perLeafResults[ctx.ord] = results;
-    }
+    Executor executor = indexSearcher.getExecutor();
+    TopDocs[] perLeafResults =
+        (executor == null)
+            ? sequentialSearch(reader.leaves(), filterWeight)
+            : parallelSearch(reader.leaves(), filterWeight, executor);
+
     // Merge sort the results
     TopDocs topK = TopDocs.merge(k, perLeafResults);
     if (topK.scoreDocs.length == 0) {
@@ -92,7 +95,54 @@ public Query rewrite(IndexSearcher indexSearcher) throws IOException {
     return createRewrittenQuery(reader, topK);
   }
 
+  private TopDocs[] sequentialSearch(
+      List<LeafReaderContext> leafReaderContexts, Weight filterWeight) {
+    try {
+      TopDocs[] perLeafResults = new TopDocs[leafReaderContexts.size()];
+      for (LeafReaderContext ctx : leafReaderContexts) {
+        perLeafResults[ctx.ord] = searchLeaf(ctx, filterWeight);
+      }
+      return perLeafResults;
+    } catch (Exception e) {
+      throw new RuntimeException(e);
+    }
+  }
+
+  private TopDocs[] parallelSearch(
+      List<LeafReaderContext> leafReaderContexts, Weight filterWeight, Executor executor) {
+    List<FutureTask<TopDocs>> tasks =
+        leafReaderContexts.stream()
+            .map(ctx -> new FutureTask<>(() -> searchLeaf(ctx, filterWeight)))
+            .toList();
+
+    SliceExecutor sliceExecutor = new SliceExecutor(executor);
+    sliceExecutor.invokeAll(tasks);
+
+    return tasks.stream()
+        .map(
+            task -> {
+              try {
+                return task.get();
+              } catch (ExecutionException e) {
+                throw new RuntimeException(e.getCause());
+              } catch (InterruptedException e) {
+                throw new ThreadInterruptedException(e);
+              }
+            })
+        .toArray(TopDocs[]::new);
+  }
+
   private TopDocs searchLeaf(LeafReaderContext ctx, Weight filterWeight) throws IOException {
+    TopDocs results = getLeafResults(ctx, filterWeight);
+    if (ctx.docBase > 0) {
+      for (ScoreDoc scoreDoc : results.scoreDocs) {
+        scoreDoc.doc += ctx.docBase;
+      }
+    }
+    return results;
+  }
+
+  private TopDocs getLeafResults(LeafReaderContext ctx, Weight filterWeight) throws IOException {
     Bits liveDocs = ctx.reader().getLiveDocs();
     int maxDoc = ctx.reader().maxDoc();
 
@@ -210,7 +210,10 @@ public void testDimensionMismatch() throws IOException {
       IndexSearcher searcher = newSearcher(reader);
       AbstractKnnVectorQuery kvq = getKnnVectorQuery("field", new float[] {0}, 10);
       IllegalArgumentException e =
-          expectThrows(IllegalArgumentException.class, () -> searcher.search(kvq, 10));
+          expectThrows(
+              RuntimeException.class,
+              IllegalArgumentException.class,
+              () -> searcher.search(kvq, 10));
       assertEquals("vector query dimension: 1 differs from field dimension: 2", e.getMessage());
     }
   }
@@ -529,6 +532,7 @@ public void testRandomWithFilter() throws IOException {
       assertEquals(9, results.totalHits.value);
       assertEquals(results.totalHits.value, results.scoreDocs.length);
       expectThrows(
+          RuntimeException.class,
           UnsupportedOperationException.class,
           () ->
               searcher.search(
@@ -543,6 +547,7 @@ public void testRandomWithFilter() throws IOException {
       assertEquals(5, results.totalHits.value);
       assertEquals(results.totalHits.value, results.scoreDocs.length);
       expectThrows(
+          RuntimeException.class,
           UnsupportedOperationException.class,
           () ->
               searcher.search(
@@ -570,6 +575,7 @@ public void testRandomWithFilter() throws IOException {
       // Test a filter that exhausts visitedLimit in upper levels, and switches to exact search
       Query filter4 = IntPoint.newRangeQuery("tag", lower, lower + 2);
       expectThrows(
+          RuntimeException.class,
           UnsupportedOperationException.class,
           () ->
               searcher.search(
@@ -742,6 +748,7 @@ public void testBitSetQuery() throws IOException {
 
       Query filter = new ThrowingBitSetQuery(new FixedBitSet(numDocs));
       expectThrows(
+          RuntimeException.class,
           UnsupportedOperationException.class,
           () ->
               searcher.search(