Skip to content

Commit

Permalink
Optimize disjunction counts. (#12415)
Browse files Browse the repository at this point in the history
This introduces `LeafCollector#collect(DocIdStream)` to enable collectors to
collect batches of doc IDs at once. `BooleanScorer` takes advantage of this by
creating a `DocIdStream` whose `count()` method counts the number of bits that
are set in the bit set of matches in the current window, instead of naively
iterating over all matches.

On wikimedium10m, this yields a ~20% speedup when counting hits for the `title
OR 12` query (2.9M hits).

Relates #12358
  • Loading branch information
jpountz authored Aug 11, 2023
1 parent df8745e commit 4d26cb2
Show file tree
Hide file tree
Showing 9 changed files with 266 additions and 30 deletions.
2 changes: 2 additions & 0 deletions lucene/CHANGES.txt
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,8 @@ Optimizations
* GITHUB#12408: Lazy initialization improvements for Facets implementations when there are segments with no hits
to count. (Greg Miller)

* GITHUB#12415: Optimized counts on disjunctive queries. (Adrien Grand)

Bug Fixes
---------------------

Expand Down
81 changes: 51 additions & 30 deletions lucene/core/src/java/org/apache/lucene/search/BooleanScorer.java
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,8 @@ public BulkScorerAndDoc get(int i) {
}
}

// One bucket per doc ID in the window, non-null if scores are needed or if frequencies need to be
// counted
final Bucket[] buckets;
// This is basically an inlined FixedBitSet... seems to help with bounds checks
final long[] matching = new long[SET_SIZE];
Expand Down Expand Up @@ -146,6 +148,52 @@ public void collect(int doc) throws IOException {

final OrCollector orCollector = new OrCollector();

/**
 * A {@link DocIdStream} backed by the {@code matching} bit set of the enclosing {@link
 * BooleanScorer}. Every set bit is turned into a doc ID by OR-ing its position in the bit set with
 * {@link #base}.
 */
final class DocIdStreamView extends DocIdStream {

  /** Value OR-ed into each in-window index to build the doc ID passed to consumers. */
  int base;

  /**
   * Visits set bits from the lowest to the highest. When buckets are in use, a doc is only passed
   * to the consumer if its bucket's frequency reaches {@code minShouldMatch}; the bucket is reset
   * afterwards in either case.
   */
  @Override
  public void forEach(CheckedIntConsumer<IOException> consumer) throws IOException {
    final long[] words = BooleanScorer.this.matching;
    final Bucket[] windowBuckets = BooleanScorer.this.buckets;
    final int docBase = this.base;
    for (int wordIndex = 0; wordIndex < words.length; wordIndex++) {
      // "remaining &= remaining - 1" clears the lowest set bit, i.e. the one just visited.
      for (long remaining = words[wordIndex]; remaining != 0L; remaining &= remaining - 1) {
        final int indexInWindow = (wordIndex << 6) | Long.numberOfTrailingZeros(remaining);
        if (windowBuckets == null) {
          // No per-doc state to check or reset: every set bit is a match.
          consumer.accept(docBase | indexInWindow);
          continue;
        }
        final Bucket bucket = windowBuckets[indexInWindow];
        if (bucket.freq >= minShouldMatch) {
          // Expose the accumulated score before handing the doc over.
          score.score = (float) bucket.score;
          consumer.accept(docBase | indexInWindow);
        }
        // Clear the bucket once its doc has been handled.
        bucket.freq = 0;
        bucket.score = 0;
      }
    }
  }

  /**
   * Counts matches by summing the population count of every word of the bit set, unless {@code
   * minShouldMatch} is greater than 1.
   */
  @Override
  public int count() throws IOException {
    if (minShouldMatch > 1) {
      // A set bit is not necessarily a match here, since forEach additionally filters on the
      // bucket frequency. Fall back to counting via iteration.
      return super.count();
    }
    final long[] words = matching;
    int total = 0;
    for (int i = 0; i < words.length; i++) {
      total += Long.bitCount(words[i]);
    }
    return total;
  }
}

private final DocIdStreamView docIdStreamView = new DocIdStreamView();

BooleanScorer(
BooleanWeight weight,
Collection<BulkScorer> scorers,
Expand Down Expand Up @@ -186,35 +234,6 @@ public long cost() {
return cost;
}

/**
 * Collects the document at index {@code i} of the current window, whose doc ID is {@code base |
 * i}. When buckets are in use, the document is only collected if its bucket's frequency reaches
 * {@code minShouldMatch}, and the bucket is reset afterwards in either case.
 */
private void scoreDocument(LeafCollector collector, int base, int i) throws IOException {
  final Bucket[] windowBuckets = buckets;
  if (windowBuckets == null) {
    // No per-doc state: the document is collected unconditionally.
    collector.collect(base | i);
    return;
  }
  final Bucket bucket = windowBuckets[i];
  if (bucket.freq >= minShouldMatch) {
    // Expose the accumulated score before collecting.
    this.score.score = (float) bucket.score;
    collector.collect(base | i);
  }
  bucket.freq = 0;
  bucket.score = 0;
}

/**
 * Walks the {@code matching} bit set from the lowest to the highest set bit and passes the
 * position of each set bit to {@link #scoreDocument}.
 */
private void scoreMatches(LeafCollector collector, int base) throws IOException {
  final long[] words = this.matching;
  for (int wordIndex = 0; wordIndex < words.length; wordIndex++) {
    // "remaining &= remaining - 1" clears the lowest set bit, i.e. the one just handled.
    for (long remaining = words[wordIndex]; remaining != 0L; remaining &= remaining - 1) {
      final int indexInWindow = (wordIndex << 6) | Long.numberOfTrailingZeros(remaining);
      scoreDocument(collector, base, indexInWindow);
    }
  }
}

private void scoreWindowIntoBitSetAndReplay(
LeafCollector collector,
Bits acceptDocs,
Expand All @@ -230,7 +249,9 @@ private void scoreWindowIntoBitSetAndReplay(
scorer.score(orCollector, acceptDocs, min, max);
}

scoreMatches(collector, base);
docIdStreamView.base = base;
collector.collect(docIdStreamView);

Arrays.fill(matching, 0L);
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.search;

import java.util.function.IntConsumer;

/**
 * Like {@link IntConsumer}, but may throw checked exceptions.
 *
 * @param <T> the type of exception that {@link #accept(int)} is declared to throw
 */
@FunctionalInterface
public interface CheckedIntConsumer<T extends Exception> {

/**
* Process the given value.
*
* @param value the value to process
* @throws T at the discretion of the implementation
* @see IntConsumer#accept(int)
*/
void accept(int value) throws T;
}
45 changes: 45 additions & 0 deletions lucene/core/src/java/org/apache/lucene/search/DocIdStream.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.search;

import java.io.IOException;

/**
* A stream of doc IDs. Most methods on {@link DocIdStream}s are terminal, meaning that the {@link
* DocIdStream} may not be further used.
*
* @see LeafCollector#collect(DocIdStream)
* @lucene.experimental
*/
public abstract class DocIdStream {

/** Sole constructor, for invocation by sub classes. */
protected DocIdStream() {}

/**
* Iterate over doc IDs contained in this stream in order, calling the given {@link
* CheckedIntConsumer} on them. This is a terminal operation.
*/
public abstract void forEach(CheckedIntConsumer<IOException> consumer) throws IOException;

/** Count the number of entries in this stream. This is a terminal operation. */
public int count() throws IOException {
int[] count = new int[1];
forEach(doc -> count[0]++);
return count[0];
}
}
23 changes: 23 additions & 0 deletions lucene/core/src/java/org/apache/lucene/search/LeafCollector.java
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,29 @@ public interface LeafCollector {
*/
void collect(int doc) throws IOException;

/**
 * Collect a batch of doc IDs at once.
 *
 * <p>Note: The given {@link DocIdStream} may be reused across calls, so it should be consumed
 * immediately.
 *
 * <p>Note: The given {@link DocIdStream} typically only holds a small subset of query matches.
 * This method may be called multiple times per segment.
 *
 * <p>Like {@link #collect(int)}, it is guaranteed that doc IDs get collected in order, i.e. doc
 * IDs are collected in order within a {@link DocIdStream}, and if this method is called twice, all
 * doc IDs from the second {@link DocIdStream} will be greater than all doc IDs from the first
 * {@link DocIdStream}.
 *
 * <p>It is legal for callers to mix calls to {@link #collect(DocIdStream)} and {@link
 * #collect(int)}.
 *
 * <p>The default implementation forwards every doc ID of the stream to {@link #collect(int)}.
 */
default void collect(DocIdStream stream) throws IOException {
  stream.forEach(doc -> collect(doc));
}

/**
* Optionally returns an iterator over competitive documents.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,7 @@ public void setMinCompetitiveScore(float minScore) throws IOException {
}
}

// NOTE: not propagating collect(DocIdStream) since DocIdStreams may only be consumed once.
@Override
public void collect(int doc) throws IOException {
for (int i = 0; i < collectors.length; i++) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,11 @@ public void setScorer(Scorable scorer) throws IOException {}
public void collect(int doc) throws IOException {
totalHits++;
}

@Override
public void collect(DocIdStream stream) throws IOException {
  // Ask the stream for its size rather than visiting every doc ID one by one.
  final int batchSize = stream.count();
  totalHits += batchSize;
}
};
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@
package org.apache.lucene.tests.search;

import java.io.IOException;
import org.apache.lucene.search.CheckedIntConsumer;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.DocIdStream;
import org.apache.lucene.search.FilterLeafCollector;
import org.apache.lucene.search.LeafCollector;
import org.apache.lucene.search.Scorable;
Expand All @@ -42,6 +44,11 @@ public void setScorer(Scorable scorer) throws IOException {
super.setScorer(AssertingScorable.wrap(scorer));
}

@Override
public void collect(DocIdStream stream) throws IOException {
  // Wrap the stream so that the checks of AssertingDocIdStream apply to whatever the wrapped
  // collector does with it.
  final DocIdStream checkedStream = new AssertingDocIdStream(stream);
  in.collect(checkedStream);
}

@Override
public void collect(int doc) throws IOException {
assert doc > lastCollected : "Out of order : " + lastCollected + " " + doc;
Expand Down Expand Up @@ -91,4 +98,36 @@ public void finish() throws IOException {
finishCalled = true;
super.finish();
}

/**
 * Wraps a {@link DocIdStream} and asserts that at most one terminal operation is called on it, and
 * that doc IDs seen through {@link #forEach} are in order and within the {@code [min, max)} range
 * of the enclosing collector.
 */
private class AssertingDocIdStream extends DocIdStream {

  private final DocIdStream delegate;
  private boolean terminalOperationCalled;

  AssertingDocIdStream(DocIdStream stream) {
    this.delegate = stream;
  }

  @Override
  public void forEach(CheckedIntConsumer<IOException> consumer) throws IOException {
    assert terminalOperationCalled == false : "A terminal operation has already been called";
    delegate.forEach(doc -> verifyAndForward(doc, consumer));
    terminalOperationCalled = true;
  }

  /**
   * Checks ordering and range of {@code doc}, hands it to {@code consumer}, then records it as the
   * last collected doc.
   */
  private void verifyAndForward(int doc, CheckedIntConsumer<IOException> consumer)
      throws IOException {
    assert doc > lastCollected : "Out of order : " + lastCollected + " " + doc;
    assert doc >= min : "Out of range: " + doc + " < " + min;
    assert doc < max : "Out of range: " + doc + " >= " + max;
    consumer.accept(doc);
    lastCollected = doc;
  }

  @Override
  public int count() throws IOException {
    assert terminalOperationCalled == false : "A terminal operation has already been called";
    final int total = delegate.count();
    terminalOperationCalled = true;
    return total;
  }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
import org.apache.lucene.index.Terms;
import org.apache.lucene.search.BulkScorer;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.DocIdStream;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.KnnCollector;
import org.apache.lucene.search.LeafCollector;
Expand Down Expand Up @@ -136,6 +137,7 @@ public static void check(Random random, Query q1, IndexSearcher s, boolean wrap)
checkFirstSkipTo(q1, s);
checkSkipTo(q1, s);
checkBulkScorerSkipTo(random, q1, s);
checkCount(q1, s);
if (wrap) {
check(random, q1, wrapUnderlyingReader(random, s, -1), false);
check(random, q1, wrapUnderlyingReader(random, s, 0), false);
Expand Down Expand Up @@ -743,4 +745,71 @@ public void collect(int doc) throws IOException {
}
}
}

/**
 * Check that counting hits through {@link DocIdStream#count()} yields the same result as counting
 * them naively, one doc at a time.
 *
 * <p>Each leaf is scored a first time with a collector that counts one doc at a time. If that
 * collector received at least one {@link DocIdStream}, the leaf is scored a second time with a
 * collector that relies on {@link DocIdStream#count()}, and both counts are compared.
 */
public static void checkCount(Query query, final IndexSearcher searcher) throws IOException {
  final Query rewritten = searcher.rewrite(query);
  final Weight weight = searcher.createWeight(rewritten, ScoreMode.COMPLETE_NO_SCORES, 1);
  for (LeafReaderContext leaf : searcher.getIndexReader().leaves()) {
    final BulkScorer naiveScorer = weight.bulkScorer(leaf);
    if (naiveScorer == null) {
      continue;
    }

    // First pass: count one doc at a time, and record whether a DocIdStream was ever provided.
    final int[] naiveCount = new int[1];
    final boolean[] sawDocIdStream = new boolean[1];
    final LeafCollector naiveCollector =
        new LeafCollector() {
          @Override
          public void setScorer(Scorable scorer) throws IOException {}

          @Override
          public void collect(int doc) throws IOException {
            naiveCount[0]++;
          }

          @Override
          public void collect(DocIdStream stream) throws IOException {
            // Don't use DocIdStream#count, we want to count the slow way here.
            sawDocIdStream[0] = true;
            LeafCollector.super.collect(stream);
          }
        };
    naiveScorer.score(
        naiveCollector, leaf.reader().getLiveDocs(), 0, DocIdSetIterator.NO_MORE_DOCS);

    if (sawDocIdStream[0] == false) {
      // Don't spend cycles running the query one more time, it doesn't use the DocIdStream
      // optimization.
      continue;
    }

    // Second pass: count through DocIdStream#count whenever a stream is provided.
    final BulkScorer streamScorer = weight.bulkScorer(leaf);
    if (streamScorer == null) {
      assertEquals(0, naiveCount[0]);
      continue;
    }
    final int[] streamCount = new int[1];
    final LeafCollector streamCollector =
        new LeafCollector() {
          @Override
          public void setScorer(Scorable scorer) throws IOException {}

          @Override
          public void collect(int doc) throws IOException {
            streamCount[0]++;
          }

          @Override
          public void collect(DocIdStream stream) throws IOException {
            streamCount[0] += stream.count();
          }
        };
    streamScorer.score(
        streamCollector, leaf.reader().getLiveDocs(), 0, DocIdSetIterator.NO_MORE_DOCS);

    assertEquals(naiveCount[0], streamCount[0]);
  }
}
}

0 comments on commit 4d26cb2

Please sign in to comment.