Concurrent rewrite for KnnVectorQuery #12160

Merged
merged 5 commits into from
Mar 4, 2023
Changes from 2 commits
@@ -19,9 +19,14 @@
 import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;

 import java.io.IOException;
+import java.io.UncheckedIOException;
 import java.util.Arrays;
 import java.util.Comparator;
+import java.util.List;
 import java.util.Objects;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.Executor;
+import java.util.concurrent.FutureTask;
 import org.apache.lucene.codecs.KnnVectorsReader;
 import org.apache.lucene.index.FieldInfo;
 import org.apache.lucene.index.IndexReader;
@@ -62,9 +67,8 @@ public AbstractKnnVectorQuery(String field, int k, Query filter) {
   @Override
   public Query rewrite(IndexSearcher indexSearcher) throws IOException {
     IndexReader reader = indexSearcher.getIndexReader();
-    TopDocs[] perLeafResults = new TopDocs[reader.leaves().size()];

-    Weight filterWeight = null;
+    final Weight filterWeight;
     if (filter != null) {
       BooleanQuery booleanQuery =
           new BooleanQuery.Builder()
@@ -73,17 +77,48 @@ public Query rewrite(IndexSearcher indexSearcher) throws IOException {
               .build();
       Query rewritten = indexSearcher.rewrite(booleanQuery);
       filterWeight = indexSearcher.createWeight(rewritten, ScoreMode.COMPLETE_NO_SCORES, 1f);
+    } else {
+      filterWeight = null;
     }

-    for (LeafReaderContext ctx : reader.leaves()) {
-      TopDocs results = searchLeaf(ctx, filterWeight);
-      if (ctx.docBase > 0) {
-        for (ScoreDoc scoreDoc : results.scoreDocs) {
-          scoreDoc.doc += ctx.docBase;
-        }
-      }
-      perLeafResults[ctx.ord] = results;
-    }
+    List<FutureTask<TopDocs>> tasks =
+        reader.leaves().stream()
+            .map(
+                ctx ->
+                    new FutureTask<>(
+                        () -> {
+                          try {
+                            TopDocs results = searchLeaf(ctx, filterWeight);
+                            if (ctx.docBase > 0) {
+                              for (ScoreDoc scoreDoc : results.scoreDocs) {
+                                scoreDoc.doc += ctx.docBase;
+                              }
+                            }
+                            return results;
+                          } catch (IOException e) {
+                            throw new UncheckedIOException(e);
+                          }
+                        }))
+            .toList();
+
+    Executor executor = Objects.requireNonNullElse(indexSearcher.getExecutor(), Runnable::run);
+    SliceExecutor sliceExecutor = new SliceExecutor(executor);
+    sliceExecutor.invokeAll(tasks);

Member

I really like this change, but it seems to me we have a simple optimization opportunity here.

We shouldn't bother with any parallelism if indexSearcher.getExecutor() == null || reader.leaves().size() <= 1. It's a simple if branch that removes all the overhead associated with parallel rewrites when no parallelism can be achieved.

Contributor

I wonder if this is right. Thinking out loud, I would assume that users who leverage IndexSearcher concurrency generally have two thread pools: one that they pass to the IndexSearcher constructor and expect to do the heavy work, and another one, where IndexSearcher#search is called, that mostly handles coordination and lightweight work such as merging top hits coming from different shards but generally spends most of its time waiting for work to complete in the other threadpool. Your optimization suggestion boils down to running some heavy work (a vector search) in the coordinating threadpool when there is a single segment. If heavy work may happen in either threadpool, this makes sizing these threadpools complicated: either you allocate num_cores threads to the threadpool that does heavy work, but then you may end up with more than num_cores threads doing heavy work because some heavy work also happens in the coordinating threadpool, or you allocate fewer than num_cores threads, but then you might not use all your hardware.
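
To make the two-pool concern concrete, here is a minimal JDK-only sketch (no Lucene types; the class `WhereWorkRuns`, its methods, and the thread name "heavy" are all hypothetical). It reports which thread actually runs the per-segment work: with a same-thread executor the work lands on the caller, which plays the role of the coordinating thread, while with a dedicated pool it lands on the pool's threads.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Executor;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.FutureTask;

class WhereWorkRuns {
  /** Submits one task per hypothetical segment and reports which thread ran each one. */
  static List<String> threadsUsed(int segments, Executor executor) {
    List<FutureTask<String>> tasks = new ArrayList<>();
    for (int i = 0; i < segments; i++) {
      FutureTask<String> task =
          new FutureTask<String>(() -> Thread.currentThread().getName());
      tasks.add(task);
      executor.execute(task);
    }
    List<String> names = new ArrayList<>();
    for (FutureTask<String> task : tasks) {
      try {
        names.add(task.get());
      } catch (ExecutionException e) {
        throw new RuntimeException(e.getCause());
      } catch (InterruptedException e) {
        Thread.currentThread().interrupt();
        throw new RuntimeException(e);
      }
    }
    return names;
  }

  /** Same as above, but on a dedicated two-thread pool whose threads are named "heavy". */
  static List<String> threadsUsedOnHeavyPool(int segments) {
    ExecutorService heavy = Executors.newFixedThreadPool(2, r -> new Thread(r, "heavy"));
    try {
      return threadsUsed(segments, heavy);
    } finally {
      heavy.shutdown();
    }
  }

  public static void main(String[] args) {
    System.out.println("caller:    " + Thread.currentThread().getName());
    // Same-thread executor: the single "segment" is searched on the caller.
    System.out.println("no pool:   " + threadsUsed(1, Runnable::run));
    // Dedicated pool: the same work is handed off even for a single segment.
    System.out.println("with pool: " + threadsUsedOnHeavyPool(1));
  }
}
```

In this sketch, skipping the executor for a single segment corresponds to the "no pool" line: the heavy work is charged to whichever thread called in.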

That said, your suggestion aligns with how IndexSearcher currently works, so maybe we should apply it for now and discuss in a follow-up issue whether we should also delegate to the executor when there is a single segment.

Member

That said, your suggestion aligns with how IndexSearcher currently works, so maybe we should apply it for now and discuss in a follow-up issue whether we should also delegate to the executor when there is a single segment.

I am fine with that.

I think also that we could only check indexSearcher.getExecutor() == null instead of making a decision for the caller regarding the number of leaves.

So, I would say for now only check if indexSearcher.getExecutor() == null and if it is, do it the old way.

Member

@kaivalnp could you (instead of using Runnable::run) just do the regular loop as it was previously if indexSearcher.getExecutor() == null?

If getExecutor() is not null, we should assume the caller wants it used. @jpountz is correct there.

Contributor Author
@kaivalnp, Feb 28, 2023

We shouldn't bother with any parallelism if indexSearcher.getExecutor() == null || reader.leaves().size() <= 1. It's a simple if branch that removes all the overhead associated with parallel rewrites when no parallelism can be achieved.

I would totally agree with you here! We shouldn't add overhead for non-concurrent executions.
If I understand correctly, you are suggesting to add an if block with the condition:

I have tried to implement the changes here. I ran some benchmarks for these (with the executor as null):

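enwiki (topK = 100, segment count = 10, executor = null)

| recall | Sequential | SliceExecutor | ReduceOverhead | nDoc | fanout | maxConn | beamWidth |
|--------|------------|---------------|----------------|------|--------|---------|-----------|
| 0.995 | 0.95 | 0.96 | 0.95 | 10000 | 0 | 16 | 32 |
| 0.998 | 1.26 | 1.30 | 1.29 | 10000 | 50 | 16 | 32 |
| 0.998 | 1.05 | 1.07 | 1.07 | 10000 | 0 | 16 | 64 |
| 0.999 | 1.41 | 1.43 | 1.43 | 10000 | 50 | 16 | 64 |
| 0.995 | 0.98 | 0.99 | 0.98 | 10000 | 0 | 32 | 32 |
| 0.998 | 1.31 | 1.33 | 1.34 | 10000 | 50 | 32 | 32 |
| 0.998 | 0.99 | 1.01 | 1.01 | 10000 | 0 | 32 | 64 |
| 0.999 | 1.33 | 1.36 | 1.36 | 10000 | 50 | 32 | 64 |
| 0.987 | 1.70 | 1.70 | 1.71 | 100000 | 0 | 16 | 32 |
| 0.992 | 2.30 | 2.30 | 2.31 | 100000 | 50 | 16 | 32 |
| 0.993 | 1.92 | 1.89 | 1.94 | 100000 | 0 | 16 | 64 |
| 0.996 | 2.63 | 2.65 | 2.64 | 100000 | 50 | 16 | 64 |
| 0.987 | 1.73 | 1.70 | 1.74 | 100000 | 0 | 32 | 32 |
| 0.992 | 2.34 | 2.30 | 2.37 | 100000 | 50 | 32 | 32 |
| 0.994 | 1.96 | 1.92 | 1.98 | 100000 | 0 | 32 | 64 |
| 0.997 | 2.66 | 2.61 | 2.69 | 100000 | 50 | 32 | 64 |
| 0.971 | 2.72 | 2.70 | 2.74 | 1000000 | 0 | 16 | 32 |
| 0.982 | 3.77 | 3.79 | 3.78 | 1000000 | 50 | 16 | 32 |
| 0.985 | 3.13 | 3.19 | 3.19 | 1000000 | 0 | 16 | 64 |
| 0.991 | 4.34 | 4.37 | 4.36 | 1000000 | 50 | 16 | 64 |
| 0.973 | 2.86 | 2.94 | 2.94 | 1000000 | 0 | 32 | 32 |
| 0.983 | 3.94 | 3.98 | 3.97 | 1000000 | 50 | 32 | 32 |
| 0.986 | 3.38 | 3.37 | 3.38 | 1000000 | 0 | 32 | 64 |
| 0.992 | 4.63 | 4.66 | 4.67 | 1000000 | 50 | 32 | 64 |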

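enwiki (topK = 100, segment count = 5, executor = null)

| recall | Sequential | SliceExecutor | ReduceOverhead | nDoc | fanout | maxConn | beamWidth |
|--------|------------|---------------|----------------|------|--------|---------|-----------|
| 0.991 | 0.59 | 0.61 | 0.59 | 10000 | 0 | 16 | 32 |
| 0.996 | 0.82 | 0.83 | 0.81 | 10000 | 50 | 16 | 32 |
| 0.997 | 0.61 | 0.62 | 0.60 | 10000 | 0 | 16 | 64 |
| 0.999 | 0.88 | 0.88 | 0.86 | 10000 | 50 | 16 | 64 |
| 0.991 | 0.59 | 0.59 | 0.58 | 10000 | 0 | 32 | 32 |
| 0.995 | 0.80 | 0.81 | 0.80 | 10000 | 50 | 32 | 32 |
| 0.997 | 0.64 | 0.64 | 0.62 | 10000 | 0 | 32 | 64 |
| 0.999 | 0.87 | 0.88 | 0.89 | 10000 | 50 | 32 | 64 |
| 0.978 | 1.09 | 1.08 | 1.08 | 100000 | 0 | 16 | 32 |
| 0.987 | 1.29 | 1.32 | 1.34 | 100000 | 50 | 16 | 32 |
| 0.989 | 1.10 | 1.09 | 1.10 | 100000 | 0 | 16 | 64 |
| 0.994 | 1.48 | 1.49 | 1.46 | 100000 | 50 | 16 | 64 |
| 0.977 | 0.98 | 0.99 | 0.98 | 100000 | 0 | 32 | 32 |
| 0.987 | 1.33 | 1.35 | 1.34 | 100000 | 50 | 32 | 32 |
| 0.989 | 1.13 | 1.14 | 1.13 | 100000 | 0 | 32 | 64 |
| 0.994 | 1.55 | 1.55 | 1.53 | 100000 | 50 | 32 | 64 |
| 0.957 | 1.48 | 1.52 | 1.49 | 1000000 | 0 | 16 | 32 |
| 0.972 | 2.03 | 2.08 | 2.04 | 1000000 | 50 | 16 | 32 |
| 0.976 | 1.70 | 1.73 | 1.71 | 1000000 | 0 | 16 | 64 |
| 0.985 | 2.42 | 2.45 | 2.47 | 1000000 | 50 | 16 | 64 |
| 0.959 | 1.67 | 1.65 | 1.66 | 1000000 | 0 | 32 | 32 |
| 0.974 | 2.13 | 2.15 | 2.16 | 1000000 | 50 | 32 | 32 |
| 0.978 | 1.89 | 1.84 | 1.89 | 1000000 | 0 | 32 | 64 |
| 0.987 | 2.52 | 2.53 | 2.55 | 1000000 | 50 | 32 | 64 |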

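enwiki (topK = 100, segment count = 1, executor = null)

| recall | Sequential | SliceExecutor | ReduceOverhead | nDoc | fanout | maxConn | beamWidth |
|--------|------------|---------------|----------------|------|--------|---------|-----------|
| 0.941 | 0.22 | 0.21 | 0.24 | 10000 | 0 | 16 | 32 |
| 0.970 | 0.24 | 0.24 | 0.25 | 10000 | 50 | 16 | 32 |
| 0.965 | 0.20 | 0.19 | 0.20 | 10000 | 0 | 16 | 64 |
| 0.984 | 0.28 | 0.27 | 0.28 | 10000 | 50 | 16 | 64 |
| 0.941 | 0.18 | 0.17 | 0.18 | 10000 | 0 | 32 | 32 |
| 0.970 | 0.24 | 0.23 | 0.23 | 10000 | 50 | 32 | 32 |
| 0.966 | 0.20 | 0.20 | 0.20 | 10000 | 0 | 32 | 64 |
| 0.985 | 0.28 | 0.27 | 0.26 | 10000 | 50 | 32 | 64 |
| 0.909 | 0.27 | 0.27 | 0.27 | 100000 | 0 | 16 | 32 |
| 0.945 | 0.38 | 0.36 | 0.37 | 100000 | 50 | 16 | 32 |
| 0.944 | 0.32 | 0.30 | 0.30 | 100000 | 0 | 16 | 64 |
| 0.969 | 0.43 | 0.41 | 0.42 | 100000 | 50 | 16 | 64 |
| 0.914 | 0.28 | 0.28 | 0.29 | 100000 | 0 | 32 | 32 |
| 0.948 | 0.39 | 0.38 | 0.38 | 100000 | 50 | 32 | 32 |
| 0.949 | 0.30 | 0.30 | 0.32 | 100000 | 0 | 32 | 64 |
| 0.972 | 0.44 | 0.41 | 0.40 | 100000 | 50 | 32 | 64 |
| 0.870 | 0.35 | 0.34 | 0.35 | 1000000 | 0 | 16 | 32 |
| 0.911 | 0.49 | 0.48 | 0.47 | 1000000 | 50 | 16 | 32 |
| 0.913 | 0.40 | 0.40 | 0.41 | 1000000 | 0 | 16 | 64 |
| 0.945 | 0.55 | 0.55 | 0.56 | 1000000 | 50 | 16 | 64 |
| 0.881 | 0.38 | 0.39 | 0.38 | 1000000 | 0 | 32 | 32 |
| 0.919 | 0.52 | 0.52 | 0.52 | 1000000 | 50 | 32 | 32 |
| 0.923 | 0.45 | 0.45 | 0.46 | 1000000 | 0 | 32 | 64 |
| 0.954 | 0.62 | 0.62 | 0.61 | 1000000 | 50 | 32 | 64 |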

There are a few places where it gives some speedup, but the gain seems too small to matter. (Note that there is also some logic duplication here and here, which we would want to avoid, maybe by wrapping it in a callable. I tried that out locally and it performed similarly or worse.)

In the absence of an executor, we are setting it to Runnable::run, which performs the same tasks sequentially. My guess is that its overhead is much lower than the cost of the search tasks, and IMO the readability of the earlier version outweighs the benefit of a separate if block.
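
As a small illustration of the Runnable::run fallback, the following JDK-only sketch (the class `DirectExecutorDemo` and its method are hypothetical, not part of the PR) shows that a same-thread Executor runs each FutureTask inline, so tasks complete one after another in submission order. It still allocates a FutureTask per unit of work, which is the overhead being discussed.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.Executor;
import java.util.concurrent.FutureTask;

class DirectExecutorDemo {
  /** Runs n tasks through a same-thread executor and returns the order in which they ran. */
  static List<Integer> executionOrder(int n) {
    Executor direct = Runnable::run; // execute(task) simply calls task.run() on the caller
    List<Integer> order = new ArrayList<>();
    for (int i = 0; i < n; i++) {
      final int id = i;
      FutureTask<Integer> task =
          new FutureTask<Integer>(
              () -> {
                order.add(id);
                return id;
              });
      direct.execute(task);
      // Because the task ran inline, it is already complete when execute() returns.
      if (!task.isDone()) {
        throw new IllegalStateException("task " + id + " should have run inline");
      }
    }
    return order;
  }

  public static void main(String[] args) {
    System.out.println(executionOrder(3)); // prints [0, 1, 2]
  }
}
```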

Please let me know what you feel / if you had something else in mind?

Edit: Sorry, links in this comment are now broken because they pointed to specific lines at the time of writing. Now that the underlying branch is updated, links point to unrelated places

Member

Note that there is also some logic duplication

Duplication of logic right next to each other is fine (IMO). I would keep it simple and duplicate those 4 lines.

I would also change the if statement to only be if(executor == null).

I think the minor if statement is worth it. It creates fewer objects and is a simpler function. It might be more readable if you broke the results gathering into individual private methods.

TopDocs[] gatherPerLeafResults(List<LeafReaderContext>,Weight)

TopDocs[] gatherPerLeafResults(List<LeafReaderContext>,Weight,Executor)

Contributor Author

For the logic duplication, it just updated the doc ids (by adding ctx.docBase to get index-level doc ids), and I put it in a separate function.
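
For readers unfamiliar with the doc id adjustment, here is a minimal sketch using plain int arrays instead of Lucene's ScoreDoc and LeafReaderContext (the class `DocBaseDemo` and the method `toIndexLevel` are hypothetical names). Each segment numbers its documents from 0, so adding the segment's docBase turns a segment-local id into an index-level id.

```java
import java.util.Arrays;

class DocBaseDemo {
  /** Converts segment-local doc ids to index-level ids by adding the segment's docBase. */
  static int[] toIndexLevel(int docBase, int[] segmentDocIds) {
    int[] out = new int[segmentDocIds.length];
    for (int i = 0; i < segmentDocIds.length; i++) {
      out[i] = segmentDocIds[i] + docBase;
    }
    return out;
  }

  public static void main(String[] args) {
    // First segment: docBase is 0, so ids are unchanged.
    System.out.println(Arrays.toString(toIndexLevel(0, new int[] {0, 3, 7}))); // [0, 3, 7]
    // A later segment that starts at index-level doc 100.
    System.out.println(Arrays.toString(toIndexLevel(100, new int[] {0, 3, 7}))); // [100, 103, 107]
  }
}
```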

I think the minor if statement is worth it. It creates fewer objects and is a simpler function. It might be more readable if you broke the results gathering into individual private methods.

Here are the sample changes; please let me know if these look good, and I'll commit them in this PR.

Note that I had to wrap the sequential execution in a try-catch and wrap exceptions in a RuntimeException, for consistency with exceptions thrown during parallel execution (and to pass the test cases).

Member
@benwtrent, Mar 1, 2023

For the logic duplication

I wouldn't worry about that. That makes things even more difficult to reason about. I would much rather have a method that takes in the filter weight and leaf contexts and one that takes the same parameters but with an added Executor.

One called where indexSearcher.getExecutor() == null and the other when the executor is provided.
Two methods like this:

  private TopDocs[] gatherLeafResults(
      List<LeafReaderContext> leafReaderContexts, Weight filterWeight) throws IOException {
    TopDocs[] perLeafResults = new TopDocs[leafReaderContexts.size()];
    for (LeafReaderContext ctx : leafReaderContexts) {
      TopDocs results = searchLeaf(ctx, filterWeight);
      if (ctx.docBase > 0) {
        for (ScoreDoc scoreDoc : results.scoreDocs) {
          scoreDoc.doc += ctx.docBase;
        }
      }
      perLeafResults[ctx.ord] = results;
    }
    return perLeafResults;
  }

  private TopDocs[] gatherLeafResults(
      List<LeafReaderContext> leafReaderContexts, Weight filterWeight, Executor executor) {
    List<FutureTask<TopDocs>> tasks =
        leafReaderContexts.stream()
            .map(
                ctx ->
                    new FutureTask<>(
                        () -> {
                          TopDocs results = searchLeaf(ctx, filterWeight);
                          if (ctx.docBase > 0) {
                            for (ScoreDoc scoreDoc : results.scoreDocs) {
                              scoreDoc.doc += ctx.docBase;
                            }
                          }
                          return results;
                        }))
            .toList();

    SliceExecutor sliceExecutor = new SliceExecutor(executor);
    sliceExecutor.invokeAll(tasks);

    return tasks.stream()
        .map(
            task -> {
              try {
                return task.get();
              } catch (ExecutionException e) {
                throw new RuntimeException(e.getCause());
              } catch (InterruptedException e) {
                throw new ThreadInterruptedException(e);
              }
            })
        .toArray(TopDocs[]::new);
  }
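
The call site that picks between these two overloads is not shown in this thread. A rough JDK-only sketch of the idea follows, with Callable<Integer> standing in for the per-leaf search and all names (`GatherDispatchDemo`, `gather`, `gatherAll`) hypothetical; it also submits tasks directly with executor.execute rather than through Lucene's SliceExecutor. The sequential path wraps failures in a RuntimeException so that both paths surface errors in the same shape, as described earlier in the thread.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Executor;
import java.util.concurrent.FutureTask;

class GatherDispatchDemo {
  /** Sequential path: a plain loop, with no FutureTask allocation. */
  static List<Integer> gather(List<Callable<Integer>> leaves) {
    List<Integer> results = new ArrayList<>();
    for (Callable<Integer> leaf : leaves) {
      try {
        results.add(leaf.call());
      } catch (Exception e) {
        throw new RuntimeException(e); // same outer type as the parallel path
      }
    }
    return results;
  }

  /** Parallel path: one FutureTask per leaf, results collected in leaf order. */
  static List<Integer> gather(List<Callable<Integer>> leaves, Executor executor) {
    List<FutureTask<Integer>> tasks = new ArrayList<>();
    for (Callable<Integer> leaf : leaves) {
      FutureTask<Integer> task = new FutureTask<Integer>(leaf);
      tasks.add(task);
      executor.execute(task);
    }
    List<Integer> results = new ArrayList<>();
    for (FutureTask<Integer> task : tasks) {
      try {
        results.add(task.get());
      } catch (ExecutionException e) {
        throw new RuntimeException(e.getCause());
      } catch (InterruptedException e) {
        Thread.currentThread().interrupt();
        throw new RuntimeException(e);
      }
    }
    return results;
  }

  /** Chooses the path based only on whether an executor was supplied. */
  static List<Integer> gatherAll(List<Callable<Integer>> leaves, Executor executorOrNull) {
    return executorOrNull == null ? gather(leaves) : gather(leaves, executorOrNull);
  }
}
```

Note that the branch looks only at the executor and not at the number of leaves, which matches the conclusion reached above: if an executor is provided, it is used even for a single segment.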

Contributor Author

Thanks for the input!
This was really helpful in reducing overhead for non-concurrent search, and improving readability!


+    TopDocs[] perLeafResults =
+        tasks.stream()
+            .map(
+                task -> {
+                  try {
+                    return task.get();
+                  } catch (ExecutionException e) {
+                    throw new RuntimeException(e.getCause());

Contributor

I'm not confident it's safe to swallow the root exception and only report the cause? Would it work to throw new RuntimeException(e)?

Contributor Author

I think any exception thrown during the thread's execution is (always?) wrapped in an ExecutionException

I mainly used getCause for two reasons:

  • We would always have to throw a RuntimeException(ExecutionException(actual Throwable)), and the ExecutionException might be redundant there
  • LuceneTestCase (currently) allows checking at most two wrapped exceptions, and the one above would have three

However, I don't have any strong opinions on this and can write a function to check for three nested exceptions as well. Please let me know what you feel, and I'll update it accordingly
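
The nesting depth argument can be checked with a short JDK-only sketch (the class `UnwrapDemo` and its helpers are hypothetical). A failure inside a FutureTask surfaces from get() as an ExecutionException whose cause is the original throwable; rethrowing `new RuntimeException(e.getCause())` yields a two-level chain, while `new RuntimeException(e)` yields three levels.

```java
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.FutureTask;

class UnwrapDemo {
  /** Runs the work in a FutureTask and returns the exception a caller would rethrow. */
  static RuntimeException failureOf(Callable<?> work, boolean unwrap) {
    FutureTask<Object> task = new FutureTask<Object>(() -> work.call());
    task.run(); // run inline; any failure is captured inside the task
    try {
      task.get();
      return null; // the work succeeded
    } catch (ExecutionException e) {
      return unwrap ? new RuntimeException(e.getCause()) : new RuntimeException(e);
    } catch (InterruptedException e) {
      Thread.currentThread().interrupt();
      return new RuntimeException(e);
    }
  }

  /** Counts the exceptions in a cause chain, including the outermost one. */
  static int depth(Throwable t) {
    int levels = 0;
    for (Throwable cur = t; cur != null; cur = cur.getCause()) {
      levels++;
    }
    return levels;
  }

  public static void main(String[] args) {
    Callable<Object> bad =
        () -> {
          throw new IllegalArgumentException("bad dimension");
        };
    System.out.println(depth(failureOf(bad, true))); // 2: RuntimeException -> IAE
    System.out.println(depth(failureOf(bad, false))); // 3: RuntimeException -> ExecutionException -> IAE
  }
}
```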

Contributor

Thanks for pointing to that javadoc, it makes sense to me. Maybe we should consider doing the same in IndexSearcher where we also catch ExecutionException (in a follow-up).

+                  } catch (InterruptedException e) {
+                    throw new RuntimeException(e);

Contributor

Can you throw ThreadInterruptedException instead?

Contributor Author

Makes sense! This is more suitable

+                  }
+                })
+            .toArray(TopDocs[]::new);

     // Merge sort the results
     TopDocs topK = TopDocs.merge(k, perLeafResults);
     if (topK.scoreDocs.length == 0) {

@@ -210,7 +210,10 @@ public void testDimensionMismatch() throws IOException {
       IndexSearcher searcher = newSearcher(reader);
       AbstractKnnVectorQuery kvq = getKnnVectorQuery("field", new float[] {0}, 10);
       IllegalArgumentException e =
-          expectThrows(IllegalArgumentException.class, () -> searcher.search(kvq, 10));
+          expectThrows(
+              RuntimeException.class,
+              IllegalArgumentException.class,

Contributor

Since IAE extends RuntimeException, it should be good to just do expectThrows(RuntimeException.class, runnable)?

Contributor Author

I wanted to preserve the original functionality of the testcase: Checking for illegal arguments
If we only check for the outer class, it may be possible that some other exception was thrown inside (maybe RuntimeException(NullPointerException)), but the test still passed?

Contributor

Thank you, that makes sense, I had made a wrong assumption about what expectThrows does with two exception types!
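
To spell out the difference between the one-type and two-type checks, here is a self-contained stand-in written for illustration only. The class `ExpectWrappedDemo`, the method `expectWrapped`, and the nested `ThrowingRunnable` interface are hypothetical and are not LuceneTestCase's implementation. The helper requires both the outer type and the type of its cause to match, then returns the cause so the test can still assert on its message.

```java
class ExpectWrappedDemo {
  /** Minimal functional interface for a body that may throw anything. */
  interface ThrowingRunnable {
    void run() throws Throwable;
  }

  /**
   * Passes only if the body throws an `outer` whose cause is a `wrapped`; returns that cause.
   * Any other outcome is reported as an AssertionError.
   */
  static <O extends Throwable, W extends Throwable> W expectWrapped(
      Class<O> outer, Class<W> wrapped, ThrowingRunnable body) {
    try {
      body.run();
    } catch (Throwable t) {
      if (outer.isInstance(t) && wrapped.isInstance(t.getCause())) {
        return wrapped.cast(t.getCause());
      }
      throw new AssertionError("unexpected exception shape: " + t, t);
    }
    throw new AssertionError("expected " + outer.getSimpleName() + " but nothing was thrown");
  }

  public static void main(String[] args) {
    IllegalArgumentException e =
        expectWrapped(
            RuntimeException.class,
            IllegalArgumentException.class,
            () -> {
              throw new RuntimeException(new IllegalArgumentException("dimension mismatch"));
            });
    System.out.println(e.getMessage()); // prints "dimension mismatch"
  }
}
```

With only the outer type checked, a RuntimeException wrapping a NullPointerException would also pass; the two-type form rejects it, which is the point made in the reply above.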


+              () -> searcher.search(kvq, 10));
       assertEquals("vector query dimension: 1 differs from field dimension: 2", e.getMessage());
     }
   }
@@ -495,6 +498,7 @@ public void testRandomWithFilter() throws IOException {
         assertEquals(9, results.totalHits.value);
         assertEquals(results.totalHits.value, results.scoreDocs.length);
         expectThrows(
+            RuntimeException.class,
             UnsupportedOperationException.class,
             () ->
                 searcher.search(
@@ -509,6 +513,7 @@ public void testRandomWithFilter() throws IOException {
         assertEquals(5, results.totalHits.value);
         assertEquals(results.totalHits.value, results.scoreDocs.length);
         expectThrows(
+            RuntimeException.class,
             UnsupportedOperationException.class,
             () ->
                 searcher.search(
@@ -536,6 +541,7 @@ public void testRandomWithFilter() throws IOException {
         // Test a filter that exhausts visitedLimit in upper levels, and switches to exact search
         Query filter4 = IntPoint.newRangeQuery("tag", lower, lower + 2);
         expectThrows(
+            RuntimeException.class,
             UnsupportedOperationException.class,
             () ->
                 searcher.search(
@@ -708,6 +714,7 @@ public void testBitSetQuery() throws IOException {

       Query filter = new ThrowingBitSetQuery(new FixedBitSet(numDocs));
       expectThrows(
+          RuntimeException.class,
           UnsupportedOperationException.class,
           () ->
               searcher.search(