Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Concurrent rewrite for KnnVectorQuery #12160

Merged
merged 5 commits into from
Mar 4, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions lucene/CHANGES.txt
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,8 @@ Optimizations

* GITHUB#11857, GITHUB#11859, GITHUB#11893, GITHUB#11909: Hunspell: improved suggestion performance (Peter Gromov)

* GITHUB#12160: Concurrent rewrite for AbstractKnnVectorQuery. (Kaival Parikh)

Bug Fixes
---------------------

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,19 @@
import java.io.IOException;
import java.util.Arrays;
import java.util.Comparator;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Executor;
import java.util.concurrent.FutureTask;
import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.util.BitSet;
import org.apache.lucene.util.BitSetIterator;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.ThreadInterruptedException;

/**
* Uses {@link KnnVectorsReader#search} to perform nearest neighbour search.
Expand Down Expand Up @@ -62,9 +67,8 @@ public AbstractKnnVectorQuery(String field, int k, Query filter) {
@Override
public Query rewrite(IndexSearcher indexSearcher) throws IOException {
IndexReader reader = indexSearcher.getIndexReader();
TopDocs[] perLeafResults = new TopDocs[reader.leaves().size()];

Weight filterWeight = null;
final Weight filterWeight;
if (filter != null) {
BooleanQuery booleanQuery =
new BooleanQuery.Builder()
Expand All @@ -73,17 +77,16 @@ public Query rewrite(IndexSearcher indexSearcher) throws IOException {
.build();
Query rewritten = indexSearcher.rewrite(booleanQuery);
filterWeight = indexSearcher.createWeight(rewritten, ScoreMode.COMPLETE_NO_SCORES, 1f);
} else {
filterWeight = null;
}

for (LeafReaderContext ctx : reader.leaves()) {
TopDocs results = searchLeaf(ctx, filterWeight);
if (ctx.docBase > 0) {
for (ScoreDoc scoreDoc : results.scoreDocs) {
scoreDoc.doc += ctx.docBase;
}
}
perLeafResults[ctx.ord] = results;
}
Executor executor = indexSearcher.getExecutor();
TopDocs[] perLeafResults =
(executor == null)
? sequentialSearch(reader.leaves(), filterWeight)
: parallelSearch(reader.leaves(), filterWeight, executor);

// Merge sort the results
TopDocs topK = TopDocs.merge(k, perLeafResults);
if (topK.scoreDocs.length == 0) {
Expand All @@ -92,7 +95,54 @@ public Query rewrite(IndexSearcher indexSearcher) throws IOException {
return createRewrittenQuery(reader, topK);
}

private TopDocs[] sequentialSearch(
List<LeafReaderContext> leafReaderContexts, Weight filterWeight) {
try {
TopDocs[] perLeafResults = new TopDocs[leafReaderContexts.size()];
for (LeafReaderContext ctx : leafReaderContexts) {
perLeafResults[ctx.ord] = searchLeaf(ctx, filterWeight);
}
return perLeafResults;
} catch (Exception e) {
throw new RuntimeException(e);
}
}

private TopDocs[] parallelSearch(
List<LeafReaderContext> leafReaderContexts, Weight filterWeight, Executor executor) {
List<FutureTask<TopDocs>> tasks =
leafReaderContexts.stream()
.map(ctx -> new FutureTask<>(() -> searchLeaf(ctx, filterWeight)))
.toList();

SliceExecutor sliceExecutor = new SliceExecutor(executor);
sliceExecutor.invokeAll(tasks);

return tasks.stream()
.map(
task -> {
try {
return task.get();
} catch (ExecutionException e) {
throw new RuntimeException(e.getCause());
} catch (InterruptedException e) {
throw new ThreadInterruptedException(e);
}
})
.toArray(TopDocs[]::new);
}

private TopDocs searchLeaf(LeafReaderContext ctx, Weight filterWeight) throws IOException {
TopDocs results = getLeafResults(ctx, filterWeight);
if (ctx.docBase > 0) {
for (ScoreDoc scoreDoc : results.scoreDocs) {
scoreDoc.doc += ctx.docBase;
}
}
return results;
}

private TopDocs getLeafResults(LeafReaderContext ctx, Weight filterWeight) throws IOException {
Bits liveDocs = ctx.reader().getLiveDocs();
int maxDoc = ctx.reader().maxDoc();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,10 @@ public void testDimensionMismatch() throws IOException {
IndexSearcher searcher = newSearcher(reader);
AbstractKnnVectorQuery kvq = getKnnVectorQuery("field", new float[] {0}, 10);
IllegalArgumentException e =
expectThrows(IllegalArgumentException.class, () -> searcher.search(kvq, 10));
expectThrows(
RuntimeException.class,
IllegalArgumentException.class,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since IAE extends RuntimeExeption, it should be good to just do expectThrows(RuntimeException.class, runnable)?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wanted to preserve the original functionality of the testcase: Checking for illegal arguments
If we only check for the outer class, it may be possible that some other exception was thrown inside (maybe RuntimeException(NullPointerException)), but the test still passed?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you, that makes sense, I had made a wrong assumption about what expectThrows does with two exception types!

() -> searcher.search(kvq, 10));
assertEquals("vector query dimension: 1 differs from field dimension: 2", e.getMessage());
}
}
Expand Down Expand Up @@ -495,6 +498,7 @@ public void testRandomWithFilter() throws IOException {
assertEquals(9, results.totalHits.value);
assertEquals(results.totalHits.value, results.scoreDocs.length);
expectThrows(
RuntimeException.class,
UnsupportedOperationException.class,
() ->
searcher.search(
Expand All @@ -509,6 +513,7 @@ public void testRandomWithFilter() throws IOException {
assertEquals(5, results.totalHits.value);
assertEquals(results.totalHits.value, results.scoreDocs.length);
expectThrows(
RuntimeException.class,
UnsupportedOperationException.class,
() ->
searcher.search(
Expand Down Expand Up @@ -536,6 +541,7 @@ public void testRandomWithFilter() throws IOException {
// Test a filter that exhausts visitedLimit in upper levels, and switches to exact search
Query filter4 = IntPoint.newRangeQuery("tag", lower, lower + 2);
expectThrows(
RuntimeException.class,
UnsupportedOperationException.class,
() ->
searcher.search(
Expand Down Expand Up @@ -708,6 +714,7 @@ public void testBitSetQuery() throws IOException {

Query filter = new ThrowingBitSetQuery(new FixedBitSet(numDocs));
expectThrows(
RuntimeException.class,
UnsupportedOperationException.class,
() ->
searcher.search(
Expand Down