diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt
index 89d645fa9bf9..77cc97287bcf 100644
--- a/lucene/CHANGES.txt
+++ b/lucene/CHANGES.txt
@@ -36,6 +36,8 @@ Improvements
 
 * GITHUB#12305: Minor cleanup and improvements to DaciukMihovAutomatonBuilder. (Greg Miller)
 
+* GITHUB#12325: Parallelize AbstractKnnVectorQuery rewrite across slices rather than segments. (Luca Cavanna)
+
 Optimizations
 ---------------------
 
diff --git a/lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java b/lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java
index d6b9e04b542f..01c12e349e5b 100644
--- a/lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java
+++ b/lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java
@@ -19,14 +19,13 @@
 import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
 
 import java.io.IOException;
+import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Comparator;
 import java.util.List;
 import java.util.Objects;
 import java.util.concurrent.ExecutionException;
-import java.util.concurrent.Executor;
 import java.util.concurrent.FutureTask;
-import java.util.stream.Collectors;
 import org.apache.lucene.codecs.KnnVectorsReader;
 import org.apache.lucene.index.FieldInfo;
 import org.apache.lucene.index.IndexReader;
@@ -82,11 +81,12 @@ public Query rewrite(IndexSearcher indexSearcher) throws IOException {
       filterWeight = null;
     }
 
-    Executor executor = indexSearcher.getExecutor();
+    SliceExecutor sliceExecutor = indexSearcher.getSliceExecutor();
+    // in case of parallel execution, the leaf results are not ordered by leaf context's ordinal
     TopDocs[] perLeafResults =
-        (executor == null)
+        (sliceExecutor == null)
             ? sequentialSearch(reader.leaves(), filterWeight)
-            : parallelSearch(reader.leaves(), filterWeight, executor);
+            : parallelSearch(indexSearcher.getSlices(), filterWeight, sliceExecutor);
 
     // Merge sort the results
     TopDocs topK = TopDocs.merge(k, perLeafResults);
@@ -110,27 +110,40 @@ private TopDocs[] sequentialSearch(
   }
 
   private TopDocs[] parallelSearch(
-      List<LeafReaderContext> leafReaderContexts, Weight filterWeight, Executor executor) {
-    List<FutureTask<TopDocs>> tasks =
-        leafReaderContexts.stream()
-            .map(ctx -> new FutureTask<>(() -> searchLeaf(ctx, filterWeight)))
-            .collect(Collectors.toList());
+      IndexSearcher.LeafSlice[] slices, Weight filterWeight, SliceExecutor sliceExecutor) {
+
+    List<FutureTask<TopDocs[]>> tasks = new ArrayList<>(slices.length);
+    int segmentsCount = 0;
+    for (IndexSearcher.LeafSlice slice : slices) {
+      segmentsCount += slice.leaves.length;
+      tasks.add(
+          new FutureTask<>(
+              () -> {
+                TopDocs[] results = new TopDocs[slice.leaves.length];
+                int i = 0;
+                for (LeafReaderContext context : slice.leaves) {
+                  results[i++] = searchLeaf(context, filterWeight);
+                }
+                return results;
+              }));
+    }
 
-    SliceExecutor sliceExecutor = new SliceExecutor(executor);
     sliceExecutor.invokeAll(tasks);
-    return tasks.stream()
-        .map(
-            task -> {
-              try {
-                return task.get();
-              } catch (ExecutionException e) {
-                throw new RuntimeException(e.getCause());
-              } catch (InterruptedException e) {
-                throw new ThreadInterruptedException(e);
-              }
-            })
-        .toArray(TopDocs[]::new);
+
+    TopDocs[] topDocs = new TopDocs[segmentsCount];
+    int i = 0;
+    for (FutureTask<TopDocs[]> task : tasks) {
+      try {
+        for (TopDocs docs : task.get()) {
+          topDocs[i++] = docs;
+        }
+      } catch (ExecutionException e) {
+        throw new RuntimeException(e.getCause());
+      } catch (InterruptedException e) {
+        throw new ThreadInterruptedException(e);
+      }
+    }
+    return topDocs;
   }
 
   private TopDocs searchLeaf(LeafReaderContext ctx, Weight filterWeight) throws IOException {
diff --git a/lucene/core/src/java/org/apache/lucene/search/IndexSearcher.java b/lucene/core/src/java/org/apache/lucene/search/IndexSearcher.java
index ce93bd9cef83..f167b7161123 100644
--- a/lucene/core/src/java/org/apache/lucene/search/IndexSearcher.java
+++ b/lucene/core/src/java/org/apache/lucene/search/IndexSearcher.java
@@ -998,6 +998,10 @@ public Executor getExecutor() {
     return executor;
   }
 
+  SliceExecutor getSliceExecutor() {
+    return sliceExecutor;
+  }
+
   /**
    * Thrown when an attempt is made to add more than {@link #getMaxClauseCount()} clauses. This
    * typically happens if a PrefixQuery, FuzzyQuery, WildcardQuery, or TermRangeQuery is expanded to
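
Note (not part of the patch): the core of the change is batching the per-segment KNN searches into one task per slice instead of one task per segment, then flattening the per-slice results back into a single per-segment array. The self-contained sketch below reproduces that pattern in plain Java under stated assumptions: Slice, Result, searchLeaf, and parallelSearch are hypothetical stand-ins for IndexSearcher.LeafSlice, TopDocs, and the query's per-leaf search, and a plain ExecutorService replaces Lucene's package-private SliceExecutor.

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.FutureTask;

public class SliceBatchingSketch {

  // Hypothetical stand-ins for IndexSearcher.LeafSlice and per-leaf TopDocs.
  record Slice(int[] leaves) {}

  record Result(int leaf) {}

  static Result searchLeaf(int leaf) {
    return new Result(leaf); // placeholder for the real per-segment search
  }

  static Result[] parallelSearch(Slice[] slices, ExecutorService executor)
      throws InterruptedException, ExecutionException {
    List<FutureTask<Result[]>> tasks = new ArrayList<>(slices.length);
    int segmentsCount = 0;
    for (Slice slice : slices) {
      segmentsCount += slice.leaves().length;
      // One task per slice rather than per segment: fewer tasks to schedule
      // when segments are small, at the cost of coarser-grained parallelism.
      tasks.add(
          new FutureTask<>(
              () -> {
                Result[] results = new Result[slice.leaves().length];
                int i = 0;
                for (int leaf : slice.leaves()) {
                  results[i++] = searchLeaf(leaf);
                }
                return results;
              }));
    }
    for (FutureTask<Result[]> task : tasks) {
      executor.execute(task); // FutureTask is a Runnable
    }
    // Flatten per-slice results into one array, in slice order. Slice order
    // is generally not leaf-ordinal order, which is why the patch adds the
    // "not ordered by leaf context's ordinal" comment in rewrite().
    Result[] flat = new Result[segmentsCount];
    int i = 0;
    for (FutureTask<Result[]> task : tasks) {
      for (Result r : task.get()) {
        flat[i++] = r;
      }
    }
    return flat;
  }

  public static void main(String[] args) throws Exception {
    ExecutorService executor = Executors.newFixedThreadPool(2);
    try {
      Slice[] slices = {new Slice(new int[] {0, 1}), new Slice(new int[] {2})};
      Result[] results = parallelSearch(slices, executor);
      System.out.println(results.length); // 3: one result per segment
    } finally {
      executor.shutdown();
    }
  }
}

Because the flattened array is ordered by slice rather than by leaf ordinal, the caller must not assume positional alignment with reader.leaves(); in the patch this is harmless since rewrite() immediately merge-sorts the per-leaf results with TopDocs.merge.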