Concurrent rewrite for KnnVectorQuery #12160

Merged
merged 5 commits into from
Mar 4, 2023
@@ -22,6 +22,10 @@
 import java.util.Arrays;
 import java.util.Comparator;
 import java.util.Objects;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.CompletionException;
+import java.util.concurrent.Executor;
+import java.util.function.Supplier;
 import org.apache.lucene.codecs.KnnVectorsReader;
 import org.apache.lucene.index.FieldInfo;
 import org.apache.lucene.index.IndexReader;
@@ -62,9 +66,8 @@ public AbstractKnnVectorQuery(String field, int k, Query filter) {
   @Override
   public Query rewrite(IndexSearcher indexSearcher) throws IOException {
     IndexReader reader = indexSearcher.getIndexReader();
-    TopDocs[] perLeafResults = new TopDocs[reader.leaves().size()];

-    Weight filterWeight = null;
+    final Weight filterWeight;
     if (filter != null) {
       BooleanQuery booleanQuery =
           new BooleanQuery.Builder()
@@ -73,17 +76,41 @@ public Query rewrite(IndexSearcher indexSearcher) throws IOException {
               .build();
       Query rewritten = indexSearcher.rewrite(booleanQuery);
       filterWeight = indexSearcher.createWeight(rewritten, ScoreMode.COMPLETE_NO_SCORES, 1f);
+    } else {
+      filterWeight = null;
     }

-    for (LeafReaderContext ctx : reader.leaves()) {
-      TopDocs results = searchLeaf(ctx, filterWeight);
-      if (ctx.docBase > 0) {
-        for (ScoreDoc scoreDoc : results.scoreDocs) {
-          scoreDoc.doc += ctx.docBase;
-        }
-      }
-      perLeafResults[ctx.ord] = results;
-    }
+    TopDocs[] perLeafResults =
+        reader.leaves().stream()
+            .map(
+                ctx -> {
+                  Supplier<TopDocs> supplier =
+                      () -> {
+                        try {
+                          TopDocs results = searchLeaf(ctx, filterWeight);
+                          if (ctx.docBase > 0) {
+                            for (ScoreDoc scoreDoc : results.scoreDocs) {
+                              scoreDoc.doc += ctx.docBase;
+                            }
+                          }
+                          return results;
+                        } catch (Exception e) {
+                          throw new CompletionException(e);
+                        }
+                      };
+
+                  Executor executor = indexSearcher.getExecutor();
+                  if (executor == null) {
+                    return CompletableFuture.completedFuture(supplier.get());
+                  } else {
+                    return CompletableFuture.supplyAsync(supplier, executor);
Contributor

In IndexSearcher we use SliceExecutor so that the main thread also does some of the work instead of only waiting to join.
Could we replicate the same logic here? KNN search is likely to be slow, so the main thread should probably do some work as well.

Maybe we could just use the SliceExecutor from IndexSearcher, which might also help with the load-balancing problem?

Contributor Author

Thanks for the input! This helped further reduce the latency from thread switching.
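
The idea in this thread can be illustrated outside Lucene. The following is a minimal sketch using only JDK classes. `CallerRunsLast` and `runAll` are hypothetical names chosen for illustration, and this is not Lucene's `SliceExecutor`; it only shows the general pattern of handing all but the last task to the executor while the calling thread runs the last task itself instead of idling in `join()`.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Executor;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.function.Supplier;

// Hypothetical helper, not part of Lucene.
final class CallerRunsLast {

  /**
   * Submits every task except the last one to the executor and runs the last
   * task on the calling thread. With a null executor, all tasks run inline.
   * Results are returned in task order.
   */
  static <T> List<T> runAll(List<Supplier<T>> tasks, Executor executor) {
    List<CompletableFuture<T>> futures = new ArrayList<>(tasks.size());
    for (int i = 0; i < tasks.size(); i++) {
      Supplier<T> task = tasks.get(i);
      boolean isLast = i == tasks.size() - 1;
      if (executor == null || isLast) {
        // Run on the calling thread; earlier tasks are already in flight.
        futures.add(CompletableFuture.completedFuture(task.get()));
      } else {
        futures.add(CompletableFuture.supplyAsync(task, executor));
      }
    }
    List<T> results = new ArrayList<>(futures.size());
    for (CompletableFuture<T> future : futures) {
      // join() rethrows task failures wrapped in CompletionException.
      results.add(future.join());
    }
    return results;
  }

  public static void main(String[] args) {
    ExecutorService pool = Executors.newFixedThreadPool(2);
    try {
      List<Supplier<String>> tasks =
          List.of(
              () -> "task 0 ran on " + Thread.currentThread().getName(),
              () -> "task 1 ran on " + Thread.currentThread().getName(),
              () -> "task 2 ran on " + Thread.currentThread().getName());
      runAll(tasks, pool).forEach(System.out::println);
    } finally {
      pool.shutdown();
    }
  }
}
```

Because the inline task is started only after the other tasks have been submitted, the calling thread overlaps its own work with the pool's work; whether this actually helps depends on how evenly the tasks are sized.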

+                  }
+                })
+            .toList()
+            .stream()
+            .map(CompletableFuture::join)
+            .toArray(TopDocs[]::new);
+
     // Merge sort the results
     TopDocs topK = TopDocs.merge(k, perLeafResults);
     if (topK.scoreDocs.length == 0) {
@@ -23,6 +23,7 @@
 import java.io.IOException;
 import java.util.HashSet;
 import java.util.Set;
+import java.util.concurrent.CompletionException;
 import org.apache.lucene.document.Document;
 import org.apache.lucene.document.Field;
 import org.apache.lucene.document.IntPoint;
@@ -210,7 +211,10 @@ public void testDimensionMismatch() throws IOException {
       IndexSearcher searcher = newSearcher(reader);
       AbstractKnnVectorQuery kvq = getKnnVectorQuery("field", new float[] {0}, 10);
       IllegalArgumentException e =
-          expectThrows(IllegalArgumentException.class, () -> searcher.search(kvq, 10));
+          expectThrows(
+              CompletionException.class,
+              IllegalArgumentException.class,
Contributor

Since IAE extends RuntimeException, wouldn't it be enough to just do expectThrows(RuntimeException.class, runnable)?

Contributor Author

I wanted to preserve the original purpose of the test case: checking for illegal arguments.
If we only check the outer class, some other exception could be thrown inside (for example RuntimeException(NullPointerException)) and the test would still pass.

Contributor

Thank you, that makes sense, I had made a wrong assumption about what expectThrows does with two exception types!
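
The point made in this thread, that checking only the outer type can hide an unrelated inner failure, can be shown with a small JDK-only sketch. `WrappedExceptionCheck` and `expectWrapped` are hypothetical names written for illustration; they are a stand-in for the idea of a two-type check, not the test framework's `expectThrows`.

```java
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException;

// Hypothetical helper, written only to illustrate the two-type check.
final class WrappedExceptionCheck {

  /**
   * Runs the action and requires that it throws an exception of outerType
   * whose direct cause is of wrappedType. Returns the cause so that callers
   * can make further assertions on it, such as checking its message.
   */
  static <O extends Throwable, W extends Throwable> W expectWrapped(
      Class<O> outerType, Class<W> wrappedType, Runnable action) {
    try {
      action.run();
    } catch (Throwable thrown) {
      if (!outerType.isInstance(thrown)) {
        throw new AssertionError("unexpected outer type: " + thrown.getClass().getName(), thrown);
      }
      Throwable cause = thrown.getCause();
      if (!wrappedType.isInstance(cause)) {
        throw new AssertionError("unexpected wrapped type: " + cause, thrown);
      }
      return wrappedType.cast(cause);
    }
    throw new AssertionError("expected " + outerType.getName() + " but nothing was thrown");
  }

  public static void main(String[] args) {
    // A failure inside supplyAsync surfaces from join() as a CompletionException.
    IllegalArgumentException e =
        expectWrapped(
            CompletionException.class,
            IllegalArgumentException.class,
            () ->
                CompletableFuture.supplyAsync(
                        () -> {
                          throw new IllegalArgumentException("bad dimension");
                        })
                    .join());
    System.out.println(e.getMessage());
  }
}
```

With this shape, a `CompletionException` that wraps a `NullPointerException` fails the check, whereas a check on the outer type alone would accept it.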

+              () -> searcher.search(kvq, 10));
       assertEquals("vector query dimension: 1 differs from field dimension: 2", e.getMessage());
     }
   }
@@ -495,6 +499,7 @@ public void testRandomWithFilter() throws IOException {
         assertEquals(9, results.totalHits.value);
         assertEquals(results.totalHits.value, results.scoreDocs.length);
         expectThrows(
+            CompletionException.class,
             UnsupportedOperationException.class,
             () ->
                 searcher.search(
@@ -509,6 +514,7 @@ public void testRandomWithFilter() throws IOException {
         assertEquals(5, results.totalHits.value);
         assertEquals(results.totalHits.value, results.scoreDocs.length);
         expectThrows(
+            CompletionException.class,
             UnsupportedOperationException.class,
             () ->
                 searcher.search(
@@ -536,6 +542,7 @@ public void testRandomWithFilter() throws IOException {
         // Test a filter that exhausts visitedLimit in upper levels, and switches to exact search
         Query filter4 = IntPoint.newRangeQuery("tag", lower, lower + 2);
         expectThrows(
+            CompletionException.class,
             UnsupportedOperationException.class,
             () ->
                 searcher.search(
@@ -708,6 +715,7 @@ public void testBitSetQuery() throws IOException {

         Query filter = new ThrowingBitSetQuery(new FixedBitSet(numDocs));
         expectThrows(
+            CompletionException.class,
             UnsupportedOperationException.class,
             () ->
                 searcher.search(