From 1039bb48e094613707ffca95f26c8feb35335743 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Christoph=20B=C3=BCscher?=
Date: Fri, 25 Oct 2019 13:34:44 +0200
Subject: [PATCH] BlendedTermQuery's equals method should consider boosts
(#48193)
This changes the queries equals() method so that the boost factors for each term
are considered for the equality calculation. This means queries are only equal
if both their terms and associated boosts match. The ordering of the terms
doesn't matter as before, which is why we internally need to sort the terms and
boost for comparison on the first equals() call like before. Boosts that are
`null` are considered equal to boosts of 1.0f because topLevelQuery() will only
wrap into BoostQuery if boost is not null and different from 1f.
Closes #48184
---
.../lucene/queries/BlendedTermQuery.java | 77 +++++++++++++++----
.../lucene/queries/BlendedTermQueryTests.java | 71 +++++++++++++++++
2 files changed, 132 insertions(+), 16 deletions(-)
diff --git a/server/src/main/java/org/apache/lucene/queries/BlendedTermQuery.java b/server/src/main/java/org/apache/lucene/queries/BlendedTermQuery.java
index 5f00631ad6028..4f724017740be 100644
--- a/server/src/main/java/org/apache/lucene/queries/BlendedTermQuery.java
+++ b/server/src/main/java/org/apache/lucene/queries/BlendedTermQuery.java
@@ -57,7 +57,6 @@
* which is the minimum number of documents the terms occurs in.
*
*/
-// TODO maybe contribute to Lucene
public abstract class BlendedTermQuery extends Query {
private final Term[] terms;
@@ -243,36 +242,82 @@ public String toString(String field) {
return builder.toString();
}
- private volatile Term[] equalTerms = null;
+ private class TermAndBoost implements Comparable {
+ protected final Term term;
+ protected float boost;
- private Term[] equalsTerms() {
- if (terms.length == 1) {
- return terms;
+ protected TermAndBoost(Term term, float boost) {
+ this.term = term;
+ this.boost = boost;
+ }
+
+ @Override
+ public int compareTo(TermAndBoost other) {
+ int compareTo = term.compareTo(other.term);
+ if (compareTo == 0) {
+ compareTo = Float.compare(boost, other.boost);
+ }
+ return compareTo;
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ if (this == o) {
+ return true;
+ }
+ if (o instanceof TermAndBoost == false) {
+ return false;
+ }
+
+ TermAndBoost that = (TermAndBoost) o;
+ return term.equals(that.term) && (Float.compare(boost, that.boost) == 0);
+ }
+
+ @Override
+ public int hashCode() {
+ return 31 * term.hashCode() + Float.hashCode(boost);
}
- if (equalTerms == null) {
+ }
+
+ private volatile TermAndBoost[] equalTermsAndBoosts = null;
+
+ private TermAndBoost[] equalsTermsAndBoosts() {
+ if (equalTermsAndBoosts != null) {
+ return equalTermsAndBoosts;
+ }
+ if (terms.length == 1) {
+ float boost = (boosts != null ? boosts[0] : 1f);
+ equalTermsAndBoosts = new TermAndBoost[] {new TermAndBoost(terms[0], boost)};
+ } else {
// sort the terms to make sure equals and hashCode are consistent
// this should be a very small cost and equivalent to a HashSet but less object creation
- final Term[] t = new Term[terms.length];
- System.arraycopy(terms, 0, t, 0, terms.length);
- ArrayUtil.timSort(t);
- equalTerms = t;
+ equalTermsAndBoosts = new TermAndBoost[terms.length];
+ for (int i = 0; i < terms.length; i++) {
+ float boost = (boosts != null ? boosts[i] : 1f);
+ equalTermsAndBoosts[i] = new TermAndBoost(terms[i], boost);
+ }
+ ArrayUtil.timSort(equalTermsAndBoosts);
}
- return equalTerms;
-
+ return equalTermsAndBoosts;
}
@Override
public boolean equals(Object o) {
- if (this == o) return true;
- if (sameClassAs(o) == false) return false;
+ if (this == o) {
+ return true;
+ }
+ if (sameClassAs(o) == false) {
+ return false;
+ }
BlendedTermQuery that = (BlendedTermQuery) o;
- return Arrays.equals(equalsTerms(), that.equalsTerms());
+ return Arrays.equals(equalsTermsAndBoosts(), that.equalsTermsAndBoosts());
+
}
@Override
public int hashCode() {
- return Objects.hash(classHash(), Arrays.hashCode(equalsTerms()));
+ return Objects.hash(classHash(), Arrays.hashCode(equalsTermsAndBoosts()));
}
public static BlendedTermQuery dismaxBlendedQuery(Term[] terms, final float tieBreakerMultiplier) {
diff --git a/server/src/test/java/org/apache/lucene/queries/BlendedTermQueryTests.java b/server/src/test/java/org/apache/lucene/queries/BlendedTermQueryTests.java
index b7b2107320e39..1513817e37a5f 100644
--- a/server/src/test/java/org/apache/lucene/queries/BlendedTermQueryTests.java
+++ b/server/src/test/java/org/apache/lucene/queries/BlendedTermQueryTests.java
@@ -42,6 +42,9 @@
import org.apache.lucene.search.similarities.BM25Similarity;
import org.apache.lucene.store.Directory;
import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.test.EqualsHashCodeTestUtils;
+import org.elasticsearch.test.EqualsHashCodeTestUtils.CopyFunction;
+import org.elasticsearch.test.EqualsHashCodeTestUtils.MutateFunction;
import java.io.IOException;
import java.util.Arrays;
@@ -254,4 +257,72 @@ public void testMinTTF() throws IOException {
w.close();
dir.close();
}
+
+ public void testEqualsAndHash() {
+ String[] fields = new String[1 + random().nextInt(10)];
+ for (int i = 0; i < fields.length; i++) {
+ fields[i] = randomRealisticUnicodeOfLengthBetween(1, 10);
+ }
+ String term = randomRealisticUnicodeOfLengthBetween(1, 10);
+ Term[] terms = toTerms(fields, term);
+ float tieBreaker = randomFloat();
+ final float[] boosts;
+ if (randomBoolean()) {
+ boosts = new float[terms.length];
+ for (int i = 0; i < terms.length; i++) {
+ boosts[i] = randomFloat();
+ }
+ } else {
+ boosts = null;
+ }
+
+ BlendedTermQuery original = BlendedTermQuery.dismaxBlendedQuery(terms, boosts, tieBreaker);
+ CopyFunction copyFunction = org -> {
+ Term[] termsCopy = new Term[terms.length];
+ System.arraycopy(terms, 0, termsCopy, 0, terms.length);
+
+ float[] boostsCopy = null;
+ if (boosts != null) {
+ boostsCopy = new float[boosts.length];
+ System.arraycopy(boosts, 0, boostsCopy, 0, terms.length);
+ }
+ if (randomBoolean() && terms.length > 1) {
+ // if we swap two elements, the resulting query should still be regarded as equal
+ int swapPos = randomIntBetween(1, terms.length - 1);
+
+ Term swpTerm = termsCopy[0];
+ termsCopy[0] = termsCopy[swapPos];
+ termsCopy[swapPos] = swpTerm;
+
+ if (boosts != null) {
+ float swpBoost = boostsCopy[0];
+ boostsCopy[0] = boostsCopy[swapPos];
+ boostsCopy[swapPos] = swpBoost;
+ }
+ }
+ return BlendedTermQuery.dismaxBlendedQuery(termsCopy, boostsCopy, tieBreaker);
+ };
+ MutateFunction mutateFunction = org -> {
+ if (randomBoolean()) {
+ Term[] termsCopy = new Term[terms.length];
+ System.arraycopy(terms, 0, termsCopy, 0, terms.length);
+ termsCopy[randomIntBetween(0, terms.length - 1)] = new Term(randomAlphaOfLength(10), randomAlphaOfLength(10));
+ return BlendedTermQuery.dismaxBlendedQuery(termsCopy, boosts, tieBreaker);
+ } else {
+ float[] boostsCopy = null;
+ if (boosts != null) {
+ boostsCopy = new float[boosts.length];
+ System.arraycopy(boosts, 0, boostsCopy, 0, terms.length);
+ boostsCopy[randomIntBetween(0, terms.length - 1)] = randomFloat();
+ } else {
+ boostsCopy = new float[terms.length];
+ for (int i = 0; i < terms.length; i++) {
+ boostsCopy[i] = randomFloat();
+ }
+ }
+ return BlendedTermQuery.dismaxBlendedQuery(terms, boostsCopy, tieBreaker);
+ }
+ };
+ EqualsHashCodeTestUtils.checkEqualsAndHashCode(original, copyFunction, mutateFunction );
+ }
}