diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt index 2f3a572a1bf9..8875a3f09532 100644 --- a/lucene/CHANGES.txt +++ b/lucene/CHANGES.txt @@ -25,6 +25,8 @@ Optimizations * GITHUB#12270 Don't generate stacktrace in CollectionTerminatedException. (Armin Braun) +* GITHUB#12160: Concurrent rewrite for AbstractKnnVectorQuery. (Kaival Parikh) + Bug Fixes --------------------- (No changes) diff --git a/lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java b/lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java index 9403a6413e24..d6b9e04b542f 100644 --- a/lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java +++ b/lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java @@ -21,7 +21,12 @@ import java.io.IOException; import java.util.Arrays; import java.util.Comparator; +import java.util.List; import java.util.Objects; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.Executor; +import java.util.concurrent.FutureTask; +import java.util.stream.Collectors; import org.apache.lucene.codecs.KnnVectorsReader; import org.apache.lucene.index.FieldInfo; import org.apache.lucene.index.IndexReader; @@ -29,6 +34,7 @@ import org.apache.lucene.util.BitSet; import org.apache.lucene.util.BitSetIterator; import org.apache.lucene.util.Bits; +import org.apache.lucene.util.ThreadInterruptedException; /** * Uses {@link KnnVectorsReader#search} to perform nearest neighbour search. @@ -62,9 +68,8 @@ public AbstractKnnVectorQuery(String field, int k, Query filter) { @Override public Query rewrite(IndexSearcher indexSearcher) throws IOException { IndexReader reader = indexSearcher.getIndexReader(); - TopDocs[] perLeafResults = new TopDocs[reader.leaves().size()]; - Weight filterWeight = null; + final Weight filterWeight; if (filter != null) { BooleanQuery booleanQuery = new BooleanQuery.Builder() @@ -73,17 +78,16 @@ public Query rewrite(IndexSearcher indexSearcher) throws IOException { .build(); Query rewritten = indexSearcher.rewrite(booleanQuery); filterWeight = indexSearcher.createWeight(rewritten, ScoreMode.COMPLETE_NO_SCORES, 1f); + } else { + filterWeight = null; } - for (LeafReaderContext ctx : reader.leaves()) { - TopDocs results = searchLeaf(ctx, filterWeight); - if (ctx.docBase > 0) { - for (ScoreDoc scoreDoc : results.scoreDocs) { - scoreDoc.doc += ctx.docBase; - } - } - perLeafResults[ctx.ord] = results; - } + Executor executor = indexSearcher.getExecutor(); + TopDocs[] perLeafResults = + (executor == null) + ? sequentialSearch(reader.leaves(), filterWeight) + : parallelSearch(reader.leaves(), filterWeight, executor); + // Merge sort the results TopDocs topK = TopDocs.merge(k, perLeafResults); if (topK.scoreDocs.length == 0) { @@ -92,7 +96,54 @@ public Query rewrite(IndexSearcher indexSearcher) throws IOException { return createRewrittenQuery(reader, topK); } + private TopDocs[] sequentialSearch( + List leafReaderContexts, Weight filterWeight) { + try { + TopDocs[] perLeafResults = new TopDocs[leafReaderContexts.size()]; + for (LeafReaderContext ctx : leafReaderContexts) { + perLeafResults[ctx.ord] = searchLeaf(ctx, filterWeight); + } + return perLeafResults; + } catch (Exception e) { + throw new RuntimeException(e); + } + } + + private TopDocs[] parallelSearch( + List leafReaderContexts, Weight filterWeight, Executor executor) { + List> tasks = + leafReaderContexts.stream() + .map(ctx -> new FutureTask<>(() -> searchLeaf(ctx, filterWeight))) + .collect(Collectors.toList()); + + SliceExecutor sliceExecutor = new SliceExecutor(executor); + sliceExecutor.invokeAll(tasks); + + return tasks.stream() + .map( + task -> { + try { + return task.get(); + } catch (ExecutionException e) { + throw new RuntimeException(e.getCause()); + } catch (InterruptedException e) { + throw new ThreadInterruptedException(e); + } + }) + .toArray(TopDocs[]::new); + } + private TopDocs searchLeaf(LeafReaderContext ctx, Weight filterWeight) throws IOException { + TopDocs results = getLeafResults(ctx, filterWeight); + if (ctx.docBase > 0) { + for (ScoreDoc scoreDoc : results.scoreDocs) { + scoreDoc.doc += ctx.docBase; + } + } + return results; + } + + private TopDocs getLeafResults(LeafReaderContext ctx, Weight filterWeight) throws IOException { Bits liveDocs = ctx.reader().getLiveDocs(); int maxDoc = ctx.reader().maxDoc(); diff --git a/lucene/core/src/test/org/apache/lucene/search/BaseKnnVectorQueryTestCase.java b/lucene/core/src/test/org/apache/lucene/search/BaseKnnVectorQueryTestCase.java index af91b19444b6..24de1a463c73 100644 --- a/lucene/core/src/test/org/apache/lucene/search/BaseKnnVectorQueryTestCase.java +++ b/lucene/core/src/test/org/apache/lucene/search/BaseKnnVectorQueryTestCase.java @@ -210,7 +210,10 @@ public void testDimensionMismatch() throws IOException { IndexSearcher searcher = newSearcher(reader); AbstractKnnVectorQuery kvq = getKnnVectorQuery("field", new float[] {0}, 10); IllegalArgumentException e = - expectThrows(IllegalArgumentException.class, () -> searcher.search(kvq, 10)); + expectThrows( + RuntimeException.class, + IllegalArgumentException.class, + () -> searcher.search(kvq, 10)); assertEquals("vector query dimension: 1 differs from field dimension: 2", e.getMessage()); } } @@ -529,6 +532,7 @@ public void testRandomWithFilter() throws IOException { assertEquals(9, results.totalHits.value); assertEquals(results.totalHits.value, results.scoreDocs.length); expectThrows( + RuntimeException.class, UnsupportedOperationException.class, () -> searcher.search( @@ -543,6 +547,7 @@ public void testRandomWithFilter() throws IOException { assertEquals(5, results.totalHits.value); assertEquals(results.totalHits.value, results.scoreDocs.length); expectThrows( + RuntimeException.class, UnsupportedOperationException.class, () -> searcher.search( @@ -570,6 +575,7 @@ public void testRandomWithFilter() throws IOException { // Test a filter that exhausts visitedLimit in upper levels, and switches to exact search Query filter4 = IntPoint.newRangeQuery("tag", lower, lower + 2); expectThrows( + RuntimeException.class, UnsupportedOperationException.class, () -> searcher.search( @@ -742,6 +748,7 @@ public void testBitSetQuery() throws IOException { Query filter = new ThrowingBitSetQuery(new FixedBitSet(numDocs)); expectThrows( + RuntimeException.class, UnsupportedOperationException.class, () -> searcher.search(