Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimize ConjunctionDISI.createConjunction #12328

Merged
merged 4 commits into from
May 26, 2023
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions lucene/CHANGES.txt
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,10 @@ Optimizations

* GITHUB#11857, GITHUB#11859, GITHUB#11893, GITHUB#11909: Hunspell: improved suggestion performance (Peter Gromov)

* GITHUB#12160: Concurrent rewrite for AbstractKnnVectorQuery. (Kaival Parikh)
javanna marked this conversation as resolved.
Show resolved Hide resolved

* GITHUB#12328: Optimize ConjunctionDISI.createConjunction (Armin Braun)

Bug Fixes
---------------------

Expand Down
47 changes: 20 additions & 27 deletions lucene/core/src/java/org/apache/lucene/search/ConjunctionDISI.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import org.apache.lucene.util.ArrayUtil;
import org.apache.lucene.util.BitSet;
Expand Down Expand Up @@ -99,24 +98,26 @@ static DocIdSetIterator createConjunction(
allIterators.size() > 0
? allIterators.get(0).docID()
: twoPhaseIterators.get(0).approximation.docID();
boolean iteratorsOnTheSameDoc = allIterators.stream().allMatch(it -> it.docID() == curDoc);
iteratorsOnTheSameDoc =
iteratorsOnTheSameDoc
&& twoPhaseIterators.stream().allMatch(it -> it.approximation().docID() == curDoc);
if (iteratorsOnTheSameDoc == false) {
throw new IllegalArgumentException(
"Sub-iterators of ConjunctionDISI are not on the same document!");
long minCost = Long.MAX_VALUE;
for (DocIdSetIterator allIterator : allIterators) {
if (allIterator.docID() != curDoc) {
throwSubIteratorsNotOnSameDocument();
}
minCost = Math.min(allIterator.cost(), minCost);
}
for (TwoPhaseIterator it : twoPhaseIterators) {
if (it.approximation().docID() != curDoc) {
throwSubIteratorsNotOnSameDocument();
}
}

long minCost = allIterators.stream().mapToLong(DocIdSetIterator::cost).min().getAsLong();
List<BitSetIterator> bitSetIterators = new ArrayList<>();
List<DocIdSetIterator> iterators = new ArrayList<>();
for (DocIdSetIterator iterator : allIterators) {
if (iterator.cost() > minCost && iterator instanceof BitSetIterator) {
if (iterator instanceof BitSetIterator bitSetIterator && bitSetIterator.cost() > minCost) {
// we put all bitset iterators into bitSetIterators
// except if they have the minimum cost, since we need
// them to lead the iteration in that case
bitSetIterators.add((BitSetIterator) iterator);
bitSetIterators.add(bitSetIterator);
} else {
iterators.add(iterator);
}
Expand All @@ -142,6 +143,11 @@ static DocIdSetIterator createConjunction(
return disi;
}

private static void throwSubIteratorsNotOnSameDocument() {
throw new IllegalArgumentException(
"Sub-iterators of ConjunctionDISI are not on the same document!");
}

final DocIdSetIterator lead1, lead2;
final DocIdSetIterator[] others;

Expand All @@ -150,14 +156,7 @@ private ConjunctionDISI(List<? extends DocIdSetIterator> iterators) {

// Sort the array the first time to allow the least frequent DocsEnum to
// lead the matching.
CollectionUtil.timSort(
iterators,
new Comparator<DocIdSetIterator>() {
@Override
public int compare(DocIdSetIterator o1, DocIdSetIterator o2) {
return Long.compare(o1.cost(), o2.cost());
}
});
CollectionUtil.timSort(iterators, (o1, o2) -> Long.compare(o1.cost(), o2.cost()));
lead1 = iterators.get(0);
lead2 = iterators.get(1);
others = iterators.subList(2, iterators.size()).toArray(new DocIdSetIterator[0]);
Expand Down Expand Up @@ -326,13 +325,7 @@ private ConjunctionTwoPhaseIterator(
assert twoPhaseIterators.size() > 0;

CollectionUtil.timSort(
twoPhaseIterators,
new Comparator<TwoPhaseIterator>() {
@Override
public int compare(TwoPhaseIterator o1, TwoPhaseIterator o2) {
return Float.compare(o1.matchCost(), o2.matchCost());
}
});
twoPhaseIterators, (o1, o2) -> Float.compare(o1.matchCost(), o2.matchCost()));

this.twoPhaseIterators =
twoPhaseIterators.toArray(new TwoPhaseIterator[twoPhaseIterators.size()]);
Expand Down