From 4d26cb2219cedeaad5b47e440a293e78a1d1d740 Mon Sep 17 00:00:00 2001 From: Adrien Grand Date: Fri, 11 Aug 2023 22:37:37 +0200 Subject: [PATCH] Optimize disjunction counts. (#12415) This introduces `LeafCollector#collect(DocIdStream)` to enable collectors to collect batches of doc IDs at once. `BooleanScorer` takes advantage of this by creating a `DocIdStream` whose `count()` method counts the number of bits that are set in the bit set of matches in the current window, instead of naively iterating over all matches. On wikimedium10m, this yields a ~20% speedup when counting hits for the `title OR 12` query (2.9M hits). Relates #12358 --- lucene/CHANGES.txt | 2 + .../apache/lucene/search/BooleanScorer.java | 81 ++++++++++++------- .../lucene/search/CheckedIntConsumer.java | 31 +++++++ .../org/apache/lucene/search/DocIdStream.java | 45 +++++++++++ .../apache/lucene/search/LeafCollector.java | 23 ++++++ .../apache/lucene/search/MultiCollector.java | 1 + .../lucene/search/TotalHitCountCollector.java | 5 ++ .../tests/search/AssertingLeafCollector.java | 39 +++++++++ .../lucene/tests/search/QueryUtils.java | 69 ++++++++++++++++ 9 files changed, 266 insertions(+), 30 deletions(-) create mode 100644 lucene/core/src/java/org/apache/lucene/search/CheckedIntConsumer.java create mode 100644 lucene/core/src/java/org/apache/lucene/search/DocIdStream.java diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt index 13fd71a2b881..cb5c1173042e 100644 --- a/lucene/CHANGES.txt +++ b/lucene/CHANGES.txt @@ -91,6 +91,8 @@ Optimizations * GITHUB#12408: Lazy initialization improvements for Facets implementations when there are segments with no hits to count. (Greg Miller) +* GITHUB#12415: Optimized counts on disjunctive queries. 
(Adrien Grand) + Bug Fixes --------------------- diff --git a/lucene/core/src/java/org/apache/lucene/search/BooleanScorer.java b/lucene/core/src/java/org/apache/lucene/search/BooleanScorer.java index 1648ecc5f8a2..6ab767f4ee81 100644 --- a/lucene/core/src/java/org/apache/lucene/search/BooleanScorer.java +++ b/lucene/core/src/java/org/apache/lucene/search/BooleanScorer.java @@ -109,6 +109,8 @@ public BulkScorerAndDoc get(int i) { } } + // One bucket per doc ID in the window, non-null if scores are needed or if frequencies need to be + // counted final Bucket[] buckets; // This is basically an inlined FixedBitSet... seems to help with bound checks final long[] matching = new long[SET_SIZE]; @@ -146,6 +148,52 @@ public void collect(int doc) throws IOException { final OrCollector orCollector = new OrCollector(); + final class DocIdStreamView extends DocIdStream { + + int base; + + @Override + public void forEach(CheckedIntConsumer<IOException> consumer) throws IOException { + long[] matching = BooleanScorer.this.matching; + Bucket[] buckets = BooleanScorer.this.buckets; + int base = this.base; + for (int idx = 0; idx < matching.length; idx++) { + long bits = matching[idx]; + while (bits != 0L) { + int ntz = Long.numberOfTrailingZeros(bits); + if (buckets != null) { + final int indexInWindow = (idx << 6) | ntz; + final Bucket bucket = buckets[indexInWindow]; + if (bucket.freq >= minShouldMatch) { + score.score = (float) bucket.score; + consumer.accept(base | indexInWindow); + } + bucket.freq = 0; + bucket.score = 0; + } else { + consumer.accept(base | (idx << 6) | ntz); + } + bits ^= 1L << ntz; + } + } + } + + @Override + public int count() throws IOException { + if (minShouldMatch > 1) { + // We can't just count bits in that case + return super.count(); + } + int count = 0; + for (long l : matching) { + count += Long.bitCount(l); + } + return count; + } + } + + private final DocIdStreamView docIdStreamView = new DocIdStreamView(); + BooleanScorer( BooleanWeight weight, Collection 
scorers, @@ -186,35 +234,6 @@ public long cost() { return cost; } - private void scoreDocument(LeafCollector collector, int base, int i) throws IOException { - if (buckets != null) { - final Score score = this.score; - final Bucket bucket = buckets[i]; - if (bucket.freq >= minShouldMatch) { - score.score = (float) bucket.score; - final int doc = base | i; - collector.collect(doc); - } - bucket.freq = 0; - bucket.score = 0; - } else { - collector.collect(base | i); - } - } - - private void scoreMatches(LeafCollector collector, int base) throws IOException { - long[] matching = this.matching; - for (int idx = 0; idx < matching.length; idx++) { - long bits = matching[idx]; - while (bits != 0L) { - int ntz = Long.numberOfTrailingZeros(bits); - int doc = idx << 6 | ntz; - scoreDocument(collector, base, doc); - bits ^= 1L << ntz; - } - } - } - private void scoreWindowIntoBitSetAndReplay( LeafCollector collector, Bits acceptDocs, @@ -230,7 +249,9 @@ private void scoreWindowIntoBitSetAndReplay( scorer.score(orCollector, acceptDocs, min, max); } - scoreMatches(collector, base); + docIdStreamView.base = base; + collector.collect(docIdStreamView); + Arrays.fill(matching, 0L); } diff --git a/lucene/core/src/java/org/apache/lucene/search/CheckedIntConsumer.java b/lucene/core/src/java/org/apache/lucene/search/CheckedIntConsumer.java new file mode 100644 index 000000000000..0e2b67084cf8 --- /dev/null +++ b/lucene/core/src/java/org/apache/lucene/search/CheckedIntConsumer.java @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.search; + +import java.util.function.IntConsumer; + +/** Like {@link IntConsumer}, but may throw checked exceptions. */ +@FunctionalInterface +public interface CheckedIntConsumer<T extends Exception> { + + /** + * Process the given value. + * + * @see IntConsumer#accept(int) + */ + void accept(int value) throws T; +} diff --git a/lucene/core/src/java/org/apache/lucene/search/DocIdStream.java b/lucene/core/src/java/org/apache/lucene/search/DocIdStream.java new file mode 100644 index 000000000000..9cb4d4ed33bd --- /dev/null +++ b/lucene/core/src/java/org/apache/lucene/search/DocIdStream.java @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.search; + +import java.io.IOException; + +/** + * A stream of doc IDs. 
Most methods on {@link DocIdStream}s are terminal, meaning that the {@link + * DocIdStream} may not be further used. + * + * @see LeafCollector#collect(DocIdStream) + * @lucene.experimental + */ +public abstract class DocIdStream { + + /** Sole constructor, for invocation by sub classes. */ + protected DocIdStream() {} + + /** + * Iterate over doc IDs contained in this stream in order, calling the given {@link + * CheckedIntConsumer} on them. This is a terminal operation. + */ + public abstract void forEach(CheckedIntConsumer<IOException> consumer) throws IOException; + + /** Count the number of entries in this stream. This is a terminal operation. */ + public int count() throws IOException { + int[] count = new int[1]; + forEach(doc -> count[0]++); + return count[0]; + } +} diff --git a/lucene/core/src/java/org/apache/lucene/search/LeafCollector.java b/lucene/core/src/java/org/apache/lucene/search/LeafCollector.java index 334afc798ca6..ec387ca0f0d3 100644 --- a/lucene/core/src/java/org/apache/lucene/search/LeafCollector.java +++ b/lucene/core/src/java/org/apache/lucene/search/LeafCollector.java @@ -83,6 +83,29 @@ public interface LeafCollector { */ void collect(int doc) throws IOException; + /** + * Bulk-collect doc IDs. + * + *

<p>Note: The provided {@link DocIdStream} may be reused across calls and should be consumed + * immediately. + * + *

<p>Note: The provided {@link DocIdStream} typically only holds a small subset of query matches. + * This method may be called multiple times per segment. + * + *

<p>Like {@link #collect(int)}, it is guaranteed that doc IDs get collected in order, ie. doc + * IDs are collected in order within a {@link DocIdStream}, and if called twice, all doc IDs from + * the second {@link DocIdStream} will be greater than all doc IDs from the first {@link + * DocIdStream}. + * + *

<p>It is legal for callers to mix calls to {@link #collect(DocIdStream)} and {@link + * #collect(int)}. + * + *

<p>The default implementation calls {@code stream.forEach(this::collect)}. + */ + default void collect(DocIdStream stream) throws IOException { + stream.forEach(this::collect); + } + /** * Optionally returns an iterator over competitive documents. * diff --git a/lucene/core/src/java/org/apache/lucene/search/MultiCollector.java b/lucene/core/src/java/org/apache/lucene/search/MultiCollector.java index 72418594b6c2..c5372f3170a4 100644 --- a/lucene/core/src/java/org/apache/lucene/search/MultiCollector.java +++ b/lucene/core/src/java/org/apache/lucene/search/MultiCollector.java @@ -211,6 +211,7 @@ public void setMinCompetitiveScore(float minScore) throws IOException { } } + // NOTE: not propagating collect(DocIdStream) since DocIdStreams may only be consumed once. @Override public void collect(int doc) throws IOException { for (int i = 0; i < collectors.length; i++) { diff --git a/lucene/core/src/java/org/apache/lucene/search/TotalHitCountCollector.java b/lucene/core/src/java/org/apache/lucene/search/TotalHitCountCollector.java index 30d0659f2cd8..0dcf7af5d01f 100644 --- a/lucene/core/src/java/org/apache/lucene/search/TotalHitCountCollector.java +++ b/lucene/core/src/java/org/apache/lucene/search/TotalHitCountCollector.java @@ -59,6 +59,11 @@ public void setScorer(Scorable scorer) throws IOException {} public void collect(int doc) throws IOException { totalHits++; } + + @Override + public void collect(DocIdStream stream) throws IOException { + totalHits += stream.count(); + } }; } } diff --git a/lucene/test-framework/src/java/org/apache/lucene/tests/search/AssertingLeafCollector.java b/lucene/test-framework/src/java/org/apache/lucene/tests/search/AssertingLeafCollector.java index a649b4a3869c..cdf170599a00 100644 --- a/lucene/test-framework/src/java/org/apache/lucene/tests/search/AssertingLeafCollector.java +++ b/lucene/test-framework/src/java/org/apache/lucene/tests/search/AssertingLeafCollector.java @@ -17,7 +17,9 @@ package org.apache.lucene.tests.search; import 
java.io.IOException; +import org.apache.lucene.search.CheckedIntConsumer; import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.search.DocIdStream; import org.apache.lucene.search.FilterLeafCollector; import org.apache.lucene.search.LeafCollector; import org.apache.lucene.search.Scorable; @@ -42,6 +44,11 @@ public void setScorer(Scorable scorer) throws IOException { super.setScorer(AssertingScorable.wrap(scorer)); } + @Override + public void collect(DocIdStream stream) throws IOException { + in.collect(new AssertingDocIdStream(stream)); + } + @Override public void collect(int doc) throws IOException { assert doc > lastCollected : "Out of order : " + lastCollected + " " + doc; @@ -91,4 +98,36 @@ public void finish() throws IOException { finishCalled = true; super.finish(); } + + private class AssertingDocIdStream extends DocIdStream { + + private final DocIdStream stream; + private boolean consumed; + + AssertingDocIdStream(DocIdStream stream) { + this.stream = stream; + } + + @Override + public void forEach(CheckedIntConsumer<IOException> consumer) throws IOException { + assert consumed == false : "A terminal operation has already been called"; + stream.forEach( + doc -> { + assert doc > lastCollected : "Out of order : " + lastCollected + " " + doc; + assert doc >= min : "Out of range: " + doc + " < " + min; + assert doc < max : "Out of range: " + doc + " >= " + max; + consumer.accept(doc); + lastCollected = doc; + }); + consumed = true; + } + + @Override + public int count() throws IOException { + assert consumed == false : "A terminal operation has already been called"; + int count = stream.count(); + consumed = true; + return count; + } + } } diff --git a/lucene/test-framework/src/java/org/apache/lucene/tests/search/QueryUtils.java b/lucene/test-framework/src/java/org/apache/lucene/tests/search/QueryUtils.java index fe5507eda718..0cd75f6b06c5 100644 --- a/lucene/test-framework/src/java/org/apache/lucene/tests/search/QueryUtils.java +++ 
b/lucene/test-framework/src/java/org/apache/lucene/tests/search/QueryUtils.java @@ -43,6 +43,7 @@ import org.apache.lucene.index.Terms; import org.apache.lucene.search.BulkScorer; import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.search.DocIdStream; import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.KnnCollector; import org.apache.lucene.search.LeafCollector; @@ -136,6 +137,7 @@ public static void check(Random random, Query q1, IndexSearcher s, boolean wrap) checkFirstSkipTo(q1, s); checkSkipTo(q1, s); checkBulkScorerSkipTo(random, q1, s); + checkCount(q1, s); if (wrap) { check(random, q1, wrapUnderlyingReader(random, s, -1), false); check(random, q1, wrapUnderlyingReader(random, s, 0), false); @@ -743,4 +745,71 @@ public void collect(int doc) throws IOException { } } } + + /** + * Check that counting hits through {@link DocIdStream#count()} yield the same result as counting + * naively. + */ + public static void checkCount(Query query, final IndexSearcher searcher) throws IOException { + query = searcher.rewrite(query); + Weight weight = searcher.createWeight(query, ScoreMode.COMPLETE_NO_SCORES, 1); + for (LeafReaderContext context : searcher.getIndexReader().leaves()) { + BulkScorer scorer = weight.bulkScorer(context); + if (scorer == null) { + continue; + } + int[] expectedCount = {0}; + boolean[] docIdStream = {false}; + scorer.score( + new LeafCollector() { + @Override + public void collect(DocIdStream stream) throws IOException { + // Don't use DocIdStream#count, we want to count the slow way here. 
+ docIdStream[0] = true; + LeafCollector.super.collect(stream); + } + + @Override + public void collect(int doc) throws IOException { + expectedCount[0]++; + } + + @Override + public void setScorer(Scorable scorer) throws IOException {} + }, + context.reader().getLiveDocs(), + 0, + DocIdSetIterator.NO_MORE_DOCS); + if (docIdStream[0] == false) { + // Don't spend cycles running the query one more time, it doesn't use the DocIdStream + // optimization. + continue; + } + scorer = weight.bulkScorer(context); + if (scorer == null) { + assertEquals(0, expectedCount[0]); + continue; + } + int[] actualCount = {0}; + scorer.score( + new LeafCollector() { + @Override + public void collect(DocIdStream stream) throws IOException { + actualCount[0] += stream.count(); + } + + @Override + public void collect(int doc) throws IOException { + actualCount[0]++; + } + + @Override + public void setScorer(Scorable scorer) throws IOException {} + }, + context.reader().getLiveDocs(), + 0, + DocIdSetIterator.NO_MORE_DOCS); + assertEquals(expectedCount[0], actualCount[0]); + } + } }