Skip to content

Commit

Permalink
Enable boosts on JoinUtil queries (#12388)
Browse files Browse the repository at this point in the history
Boosts should not be ignored by queries returned from JoinUtil
  • Loading branch information
romseygeek committed Jun 26, 2023
1 parent ddcaa59 commit 3ff2c9b
Show file tree
Hide file tree
Showing 6 changed files with 65 additions and 22 deletions.
3 changes: 2 additions & 1 deletion lucene/CHANGES.txt
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@ Optimizations

Bug Fixes
---------------------
(No changes)

* GITHUB#12388: JoinUtil queries were ignoring boosts. (Alan Woodward)

Other
---------------------
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,19 +27,21 @@ abstract class BaseGlobalOrdinalScorer extends Scorer {

final SortedDocValues values;
final DocIdSetIterator approximation;
final float boost;

float score;

public BaseGlobalOrdinalScorer(
Weight weight, SortedDocValues values, DocIdSetIterator approximationScorer) {
Weight weight, SortedDocValues values, DocIdSetIterator approximationScorer, float boost) {
super(weight);
this.values = values;
this.approximation = approximationScorer;
this.boost = boost;
}

@Override
public float score() throws IOException {
return score;
return score * boost;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ public OrdinalMapScorer(
SortedDocValues values,
DocIdSetIterator approximationScorer,
LongValues segmentOrdToGlobalOrdLookup) {
super(weight, values, approximationScorer);
super(weight, values, approximationScorer, 1);
this.score = score;
this.foundOrds = foundOrds;
this.segmentOrdToGlobalOrdLookup = segmentOrdToGlobalOrdLookup;
Expand Down Expand Up @@ -255,7 +255,7 @@ public SegmentOrdinalScorer(
LongBitSet foundOrds,
SortedDocValues values,
DocIdSetIterator approximationScorer) {
super(weight, values, approximationScorer);
super(weight, values, approximationScorer, 1);
this.score = score;
this.foundOrds = foundOrds;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
package org.apache.lucene.search.join;

import java.io.IOException;
import org.apache.lucene.index.DocValues;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.OrdinalMap;
import org.apache.lucene.index.SortedDocValues;
Expand Down Expand Up @@ -117,7 +116,8 @@ public Weight createWeight(
}
return new W(
this,
toQuery.createWeight(searcher, org.apache.lucene.search.ScoreMode.COMPLETE_NO_SCORES, 1f));
toQuery.createWeight(searcher, org.apache.lucene.search.ScoreMode.COMPLETE_NO_SCORES, 1f),
boost);
}

@Override
Expand Down Expand Up @@ -169,13 +169,16 @@ public long ramBytesUsed() {

final class W extends FilterWeight {

W(Query query, Weight approximationWeight) {
final float boost;

W(Query query, Weight approximationWeight, float boost) {
super(query, approximationWeight);
this.boost = boost;
}

@Override
public Explanation explain(LeafReaderContext context, int doc) throws IOException {
SortedDocValues values = DocValues.getSorted(context.reader(), joinField);
SortedDocValues values = context.reader().getSortedDocValues(joinField);
if (values == null) {
return Explanation.noMatch("Not a match");
}
Expand All @@ -197,12 +200,16 @@ public Explanation explain(LeafReaderContext context, int doc) throws IOExceptio
}

float score = collector.score(ord);
return Explanation.match(score, "A match, join value " + Term.toString(joinValue));
if (boost == 1.0f) {
return Explanation.match(score, "A match, join value " + Term.toString(joinValue));
}
return Explanation.match(
score * boost, "A match, join value " + Term.toString(joinValue) + "^" + boost);
}

@Override
public Scorer scorer(LeafReaderContext context) throws IOException {
SortedDocValues values = DocValues.getSorted(context.reader(), joinField);
SortedDocValues values = context.reader().getSortedDocValues(joinField);
if (values == null) {
return null;
}
Expand All @@ -214,11 +221,13 @@ public Scorer scorer(LeafReaderContext context) throws IOException {
return new OrdinalMapScorer(
this,
collector,
boost,
values,
approximationScorer.iterator(),
globalOrds.getGlobalOrds(context.ord));
} else {
return new SegmentOrdinalScorer(this, collector, values, approximationScorer.iterator());
return new SegmentOrdinalScorer(
this, collector, values, boost, approximationScorer.iterator());
}
}

Expand All @@ -239,10 +248,11 @@ static final class OrdinalMapScorer extends BaseGlobalOrdinalScorer {
public OrdinalMapScorer(
Weight weight,
GlobalOrdinalsWithScoreCollector collector,
float boost,
SortedDocValues values,
DocIdSetIterator approximation,
LongValues segmentOrdToGlobalOrdLookup) {
super(weight, values, approximation);
super(weight, values, approximation, boost);
this.segmentOrdToGlobalOrdLookup = segmentOrdToGlobalOrdLookup;
this.collector = collector;
}
Expand Down Expand Up @@ -280,8 +290,9 @@ public SegmentOrdinalScorer(
Weight weight,
GlobalOrdinalsWithScoreCollector collector,
SortedDocValues values,
float boost,
DocIdSetIterator approximation) {
super(weight, values, approximation);
super(weight, values, approximation, boost);
this.collector = collector;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -151,8 +151,17 @@ public Explanation explain(LeafReaderContext context, int doc) throws IOExceptio
postingsEnum = segmentTermsEnum.postings(postingsEnum, PostingsEnum.NONE);
if (postingsEnum.advance(doc) == doc) {
final float score = TermsIncludingScoreQuery.this.scores[ords[i]];
return Explanation.match(
score, "Score based on join value " + segmentTermsEnum.term().utf8ToString());
if (boost == 1.0f) {
return Explanation.match(
score, "Score based on join value " + segmentTermsEnum.term().utf8ToString());
} else {
return Explanation.match(
score * boost,
"Score based on join value "
+ segmentTermsEnum.term().utf8ToString()
+ "^"
+ boost);
}
}
}
}
Expand All @@ -172,9 +181,11 @@ public Scorer scorer(LeafReaderContext context) throws IOException {

TermsEnum segmentTermsEnum = terms.iterator();
if (multipleValuesPerDocument) {
return new MVInOrderScorer(this, segmentTermsEnum, context.reader().maxDoc(), cost);
return new MVInOrderScorer(
this, segmentTermsEnum, context.reader().maxDoc(), cost, boost);
} else {
return new SVInOrderScorer(this, segmentTermsEnum, context.reader().maxDoc(), cost);
return new SVInOrderScorer(
this, segmentTermsEnum, context.reader().maxDoc(), cost, boost);
}
}

Expand All @@ -190,14 +201,17 @@ class SVInOrderScorer extends Scorer {
final DocIdSetIterator matchingDocsIterator;
final float[] scores;
final long cost;
final float boost;

SVInOrderScorer(Weight weight, TermsEnum termsEnum, int maxDoc, long cost) throws IOException {
SVInOrderScorer(Weight weight, TermsEnum termsEnum, int maxDoc, long cost, float boost)
throws IOException {
super(weight);
FixedBitSet matchingDocs = new FixedBitSet(maxDoc);
this.scores = new float[maxDoc];
fillDocsAndScores(matchingDocs, termsEnum);
this.matchingDocsIterator = new BitSetIterator(matchingDocs, cost);
this.cost = cost;
this.boost = boost;
}

protected void fillDocsAndScores(FixedBitSet matchingDocs, TermsEnum termsEnum)
Expand All @@ -223,7 +237,7 @@ protected void fillDocsAndScores(FixedBitSet matchingDocs, TermsEnum termsEnum)

@Override
public float score() throws IOException {
return scores[docID()];
return scores[docID()] * boost;
}

@Override
Expand All @@ -246,8 +260,9 @@ public DocIdSetIterator iterator() {
// related documents.
class MVInOrderScorer extends SVInOrderScorer {

MVInOrderScorer(Weight weight, TermsEnum termsEnum, int maxDoc, long cost) throws IOException {
super(weight, termsEnum, maxDoc, cost);
MVInOrderScorer(Weight weight, TermsEnum termsEnum, int maxDoc, long cost, float boost)
throws IOException {
super(weight, termsEnum, maxDoc, cost, boost);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@
import org.apache.lucene.tests.analysis.MockAnalyzer;
import org.apache.lucene.tests.analysis.MockTokenizer;
import org.apache.lucene.tests.index.RandomIndexWriter;
import org.apache.lucene.tests.search.QueryUtils;
import org.apache.lucene.tests.util.LuceneTestCase;
import org.apache.lucene.tests.util.TestUtil;
import org.apache.lucene.util.BitSet;
Expand Down Expand Up @@ -689,6 +690,7 @@ public void testMinMaxDocs() throws Exception {
}
}
assertEquals(expectedCount, totalHits);
checkBoost(joinQuery, searcher);
}
searcher.getIndexReader().close();
dir.close();
Expand Down Expand Up @@ -997,6 +999,7 @@ public void testSimpleWithScoring() throws Exception {
assertEquals(2, result.totalHits.value);
assertEquals(0, result.scoreDocs[0].doc);
assertEquals(3, result.scoreDocs[1].doc);
checkBoost(joinQuery, indexSearcher);

// Score mode max.
joinQuery =
Expand All @@ -1011,6 +1014,7 @@ public void testSimpleWithScoring() throws Exception {
assertEquals(2, result.totalHits.value);
assertEquals(3, result.scoreDocs[0].doc);
assertEquals(0, result.scoreDocs[1].doc);
checkBoost(joinQuery, indexSearcher);

// Score mode total
joinQuery =
Expand All @@ -1025,6 +1029,7 @@ public void testSimpleWithScoring() throws Exception {
assertEquals(2, result.totalHits.value);
assertEquals(0, result.scoreDocs[0].doc);
assertEquals(3, result.scoreDocs[1].doc);
checkBoost(joinQuery, indexSearcher);

// Score mode avg
joinQuery =
Expand All @@ -1039,11 +1044,20 @@ public void testSimpleWithScoring() throws Exception {
assertEquals(2, result.totalHits.value);
assertEquals(3, result.scoreDocs[0].doc);
assertEquals(0, result.scoreDocs[1].doc);
checkBoost(joinQuery, indexSearcher);

indexSearcher.getIndexReader().close();
dir.close();
}

private void checkBoost(Query query, IndexSearcher searcher) throws IOException {
TopDocs result = searcher.search(query, 10);
Query boostedQuery = new BoostQuery(query, 10);
TopDocs boostedResult = searcher.search(boostedQuery, 10);
assertEquals(result.scoreDocs[0].score * 10, boostedResult.scoreDocs[0].score, 0.000001f);
QueryUtils.checkExplanations(boostedQuery, searcher);
}

public void testEquals() throws Exception {
final int numDocs = atLeast(random(), 50);
try (final Directory dir = newDirectory()) {
Expand Down

0 comments on commit 3ff2c9b

Please sign in to comment.