From e0d92eef988cfa9a2f1031432b673bbd1bcc15d5 Mon Sep 17 00:00:00 2001
From: Kaival Parikh <46070017+kaivalnp@users.noreply.github.com>
Date: Sat, 4 Mar 2023 14:42:11 +0530
Subject: [PATCH] Concurrent rewrite for KnnVectorQuery (#12160)

- Reduce overhead of non-concurrent search by preserving original execution
- Improve readability by factoring into separate functions

---------

Co-authored-by: Kaival Parikh <kaivalp2000@gmail.com>
---
 lucene/CHANGES.txt                            |  2 +
 .../lucene/search/AbstractKnnVectorQuery.java | 72 ++++++++++++++++---
 .../search/BaseKnnVectorQueryTestCase.java    |  9 ++-
 3 files changed, 71 insertions(+), 12 deletions(-)

diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt
index 799b4e8e4d28..4e86d5aeff71 100644
--- a/lucene/CHANGES.txt
+++ b/lucene/CHANGES.txt
@@ -70,6 +70,8 @@ Optimizations
 
 * GITHUB#11857, GITHUB#11859, GITHUB#11893, GITHUB#11909: Hunspell: improved suggestion performance (Peter Gromov)
 
+* GITHUB#12160: Concurrent rewrite for AbstractKnnVectorQuery. (Kaival Parikh)
+
 Bug Fixes
 ---------------------
 
diff --git a/lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java b/lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java
index 39c2d34ffd2c..04309561cbdd 100644
--- a/lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java
+++ b/lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java
@@ -21,7 +21,11 @@
 import java.io.IOException;
 import java.util.Arrays;
 import java.util.Comparator;
+import java.util.List;
 import java.util.Objects;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.Executor;
+import java.util.concurrent.FutureTask;
 import org.apache.lucene.codecs.KnnVectorsReader;
 import org.apache.lucene.index.FieldInfo;
 import org.apache.lucene.index.IndexReader;
@@ -29,6 +33,7 @@
 import org.apache.lucene.util.BitSet;
 import org.apache.lucene.util.BitSetIterator;
 import org.apache.lucene.util.Bits;
+import org.apache.lucene.util.ThreadInterruptedException;
 
 /**
  * Uses {@link KnnVectorsReader#search} to perform nearest neighbour search.
@@ -62,9 +67,8 @@ public AbstractKnnVectorQuery(String field, int k, Query filter) {
   @Override
   public Query rewrite(IndexSearcher indexSearcher) throws IOException {
     IndexReader reader = indexSearcher.getIndexReader();
-    TopDocs[] perLeafResults = new TopDocs[reader.leaves().size()];
 
-    Weight filterWeight = null;
+    final Weight filterWeight;
     if (filter != null) {
       BooleanQuery booleanQuery =
           new BooleanQuery.Builder()
@@ -73,17 +77,16 @@ public Query rewrite(IndexSearcher indexSearcher) throws IOException {
               .build();
       Query rewritten = indexSearcher.rewrite(booleanQuery);
       filterWeight = indexSearcher.createWeight(rewritten, ScoreMode.COMPLETE_NO_SCORES, 1f);
+    } else {
+      filterWeight = null;
     }
 
-    for (LeafReaderContext ctx : reader.leaves()) {
-      TopDocs results = searchLeaf(ctx, filterWeight);
-      if (ctx.docBase > 0) {
-        for (ScoreDoc scoreDoc : results.scoreDocs) {
-          scoreDoc.doc += ctx.docBase;
-        }
-      }
-      perLeafResults[ctx.ord] = results;
-    }
+    Executor executor = indexSearcher.getExecutor();
+    TopDocs[] perLeafResults =
+        (executor == null)
+            ? sequentialSearch(reader.leaves(), filterWeight)
+            : parallelSearch(reader.leaves(), filterWeight, executor);
+
     // Merge sort the results
     TopDocs topK = TopDocs.merge(k, perLeafResults);
     if (topK.scoreDocs.length == 0) {
@@ -92,7 +95,54 @@ public Query rewrite(IndexSearcher indexSearcher) throws IOException {
     return createRewrittenQuery(reader, topK);
   }
 
+  private TopDocs[] sequentialSearch(
+      List<LeafReaderContext> leafReaderContexts, Weight filterWeight) {
+    try {
+      TopDocs[] perLeafResults = new TopDocs[leafReaderContexts.size()];
+      for (LeafReaderContext ctx : leafReaderContexts) {
+        perLeafResults[ctx.ord] = searchLeaf(ctx, filterWeight);
+      }
+      return perLeafResults;
+    } catch (Exception e) {
+      throw new RuntimeException(e);
+    }
+  }
+
+  private TopDocs[] parallelSearch(
+      List<LeafReaderContext> leafReaderContexts, Weight filterWeight, Executor executor) {
+    List<FutureTask<TopDocs>> tasks =
+        leafReaderContexts.stream()
+            .map(ctx -> new FutureTask<>(() -> searchLeaf(ctx, filterWeight)))
+            .toList();
+
+    SliceExecutor sliceExecutor = new SliceExecutor(executor);
+    sliceExecutor.invokeAll(tasks);
+
+    return tasks.stream()
+        .map(
+            task -> {
+              try {
+                return task.get();
+              } catch (ExecutionException e) {
+                throw new RuntimeException(e.getCause());
+              } catch (InterruptedException e) {
+                throw new ThreadInterruptedException(e);
+              }
+            })
+        .toArray(TopDocs[]::new);
+  }
+
   private TopDocs searchLeaf(LeafReaderContext ctx, Weight filterWeight) throws IOException {
+    TopDocs results = getLeafResults(ctx, filterWeight);
+    if (ctx.docBase > 0) {
+      for (ScoreDoc scoreDoc : results.scoreDocs) {
+        scoreDoc.doc += ctx.docBase;
+      }
+    }
+    return results;
+  }
+
+  private TopDocs getLeafResults(LeafReaderContext ctx, Weight filterWeight) throws IOException {
     Bits liveDocs = ctx.reader().getLiveDocs();
     int maxDoc = ctx.reader().maxDoc();
 
diff --git a/lucene/core/src/test/org/apache/lucene/search/BaseKnnVectorQueryTestCase.java b/lucene/core/src/test/org/apache/lucene/search/BaseKnnVectorQueryTestCase.java
index dbb13d9f058f..f9356ae0bc16 100644
--- a/lucene/core/src/test/org/apache/lucene/search/BaseKnnVectorQueryTestCase.java
+++ b/lucene/core/src/test/org/apache/lucene/search/BaseKnnVectorQueryTestCase.java
@@ -210,7 +210,10 @@ public void testDimensionMismatch() throws IOException {
       IndexSearcher searcher = newSearcher(reader);
       AbstractKnnVectorQuery kvq = getKnnVectorQuery("field", new float[] {0}, 10);
       IllegalArgumentException e =
-          expectThrows(IllegalArgumentException.class, () -> searcher.search(kvq, 10));
+          expectThrows(
+              RuntimeException.class,
+              IllegalArgumentException.class,
+              () -> searcher.search(kvq, 10));
       assertEquals("vector query dimension: 1 differs from field dimension: 2", e.getMessage());
     }
   }
@@ -495,6 +498,7 @@ public void testRandomWithFilter() throws IOException {
           assertEquals(9, results.totalHits.value);
           assertEquals(results.totalHits.value, results.scoreDocs.length);
           expectThrows(
+              RuntimeException.class,
               UnsupportedOperationException.class,
               () ->
                   searcher.search(
@@ -509,6 +513,7 @@ public void testRandomWithFilter() throws IOException {
           assertEquals(5, results.totalHits.value);
           assertEquals(results.totalHits.value, results.scoreDocs.length);
           expectThrows(
+              RuntimeException.class,
               UnsupportedOperationException.class,
               () ->
                   searcher.search(
@@ -536,6 +541,7 @@ public void testRandomWithFilter() throws IOException {
           // Test a filter that exhausts visitedLimit in upper levels, and switches to exact search
           Query filter4 = IntPoint.newRangeQuery("tag", lower, lower + 2);
           expectThrows(
+              RuntimeException.class,
               UnsupportedOperationException.class,
               () ->
                   searcher.search(
@@ -708,6 +714,7 @@ public void testBitSetQuery() throws IOException {
 
         Query filter = new ThrowingBitSetQuery(new FixedBitSet(numDocs));
         expectThrows(
+            RuntimeException.class,
             UnsupportedOperationException.class,
             () ->
                 searcher.search(