Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

LUCENE-10456: Implement Weight#count for MultiRangeQuery #731

Merged
merged 13 commits into from
Apr 5, 2022
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,15 @@
import java.util.Comparator;
import java.util.List;
import java.util.Objects;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.PointValues;
import org.apache.lucene.search.ConstantScoreScorer;
import org.apache.lucene.search.ConstantScoreWeight;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.PointRangeQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.QueryVisitor;
import org.apache.lucene.search.ScoreMode;
Expand Down Expand Up @@ -163,6 +165,76 @@ public void visit(QueryVisitor visitor) {
}
}

/**
* Merges the overlapping ranges and returns unconnected ranges by calling {@link
* #mergeOverlappingRanges}
*/
@Override
public Query rewrite(IndexReader reader) throws IOException {
if (numDims != 1) {
return this;
}
List<RangeClause> mergedRanges = mergeOverlappingRanges(rangeClauses, bytesPerDim);
if (mergedRanges != rangeClauses) {
return new MultiRangeQuery(field, numDims, bytesPerDim, mergedRanges) {
@Override
protected String toString(int dimension, byte[] value) {
return MultiRangeQuery.this.toString(dimension, value);
}
};
} else {
return this;
}
}

/**
* merge overlapping ranges to some unconnected ranges
*
* @param rangeClauses some overlapping ranges
* @param bytesPerDim bytes per Dimension of the point value
* @return unconnected ranges
*/
public static List<RangeClause> mergeOverlappingRanges(
List<RangeClause> rangeClauses, int bytesPerDim) {
if (rangeClauses.size() <= 1) {
return rangeClauses;
}
List<RangeClause> originRangeClause = new ArrayList<>(rangeClauses);
final ArrayUtil.ByteArrayComparator comparator = ArrayUtil.getUnsignedComparator(bytesPerDim);
originRangeClause.sort(
new Comparator<RangeClause>() {
@Override
public int compare(RangeClause o1, RangeClause o2) {
int result = comparator.compare(o1.lowerValue, 0, o2.lowerValue, 0);
if (result == 0) {
return comparator.compare(o1.upperValue, 0, o2.upperValue, 0);
} else {
return result;
}
}
});
List<RangeClause> finalRangeClause = new ArrayList<>();
RangeClause current = originRangeClause.get(0);
for (int i = 1; i < originRangeClause.size(); i++) {
RangeClause nextClause = originRangeClause.get(i);
if (comparator.compare(nextClause.lowerValue, 0, current.upperValue, 0) > 0) {
finalRangeClause.add(current);
current = nextClause;
} else {
if (comparator.compare(nextClause.upperValue, 0, current.upperValue, 0) > 0) {
current = new RangeClause(current.lowerValue, nextClause.upperValue);
}
}
}
finalRangeClause.add(current);
/** saves us from creating an extra MultiRangeQuery object in {@link #rewrite} */
if (finalRangeClause.size() != rangeClauses.size()) {
return finalRangeClause;
} else {
return rangeClauses;
}
}

/*
* TODO: Organize ranges similar to how EdgeTree does, to avoid linear scan of ranges
*/
Expand Down Expand Up @@ -314,6 +386,32 @@ public Scorer scorer(LeafReaderContext context) throws IOException {
public boolean isCacheable(LeafReaderContext ctx) {
return true;
}

@Override
public int count(LeafReaderContext context) throws IOException {
if (numDims != 1 || context.reader().hasDeletions() == true) {
return super.count(context);
}
List<RangeClause> mergeRangeClause = mergeOverlappingRanges(rangeClauses, bytesPerDim);
int total = 0;
for (RangeClause rangeClause : mergeRangeClause) {
PointRangeQuery pointRangeQuery =
new PointRangeQuery(field, rangeClause.lowerValue, rangeClause.upperValue, numDims) {
@Override
protected String toString(int dimension, byte[] value) {
return MultiRangeQuery.this.toString(dimension, value);
}
};
int count = pointRangeQuery.createWeight(searcher, scoreMode, boost).count(context);

if (count != -1) {
total += count;
} else {
return super.count(context);
}
}
return total;
}
};
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import com.carrotsearch.randomizedtesting.generators.RandomNumbers;
import java.io.IOException;
import java.util.Random;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.DoublePoint;
import org.apache.lucene.document.FloatPoint;
Expand All @@ -33,6 +34,7 @@
import org.apache.lucene.search.BooleanQuery;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.Sort;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.store.Directory;
Expand Down Expand Up @@ -761,4 +763,93 @@ public void testEqualsAndHashCode() {
assertNotEquals(query1.hashCode(), query3.hashCode());
}
}

private void addRandomDocs(RandomIndexWriter w) throws IOException {
Random random = random();
for (int i = 0; i < random.nextInt(100, 500); i++) {
int numPoints = RandomNumbers.randomIntBetween(random(), 1, 200);
long value = RandomNumbers.randomLongBetween(random(), 0, 2000);
for (int j = 0; j < numPoints; j++) {
Document doc = new Document();
doc.add(new LongPoint("point", value));
w.addDocument(doc);
}
}
w.flush();
w.forceMerge(1);
}

/**
* test rewrite query. the hit doc count of the rewritten query should be the same as origin
* query's.
*/
public void testRandomRewrite() throws IOException {
Directory dir = newDirectory();
RandomIndexWriter w = new RandomIndexWriter(random(), dir);
int dims = 1;
addRandomDocs(w);
Random random = random();

IndexReader reader = w.getReader();
IndexSearcher searcher = newSearcher(reader);
int numIters = random.nextInt(200);

for (int n = 0; n < numIters; n++) {
int numRanges = RandomNumbers.randomIntBetween(random(), 1, 20);
LongPointMultiRangeBuilder builder1 = new LongPointMultiRangeBuilder("point", dims);
for (int i = 0; i < numRanges; i++) {
long[] lower = new long[dims];
long[] upper = new long[dims];
for (int j = 0; j < dims; j++) {
lower[j] = RandomNumbers.randomLongBetween(random(), 0, 2000);
upper[j] = lower[j] + RandomNumbers.randomLongBetween(random(), 0, 2000);
}
builder1.add(lower, upper);
}

MultiRangeQuery multiRangeQuery = builder1.build();
MultiRangeQuery rewriteMultiRangeQuery = (MultiRangeQuery) multiRangeQuery.rewrite(reader);
int count = searcher.count(multiRangeQuery);
int rewriteCount = searcher.count(rewriteMultiRangeQuery);
assertEquals(rewriteCount, count);
}
IOUtils.close(reader, w, dir);
}

public void testOneDimensionCount() throws IOException {
Directory dir = newDirectory();
RandomIndexWriter w = new RandomIndexWriter(random(), dir);
int dims = 1;
addRandomDocs(w);

IndexReader reader = w.getReader();
IndexSearcher searcher = newSearcher(reader);
Random random = random();
int numIters = random.nextInt(200);
for (int n = 0; n < numIters; n++) {
int numRanges = RandomNumbers.randomIntBetween(random(), 1, 20);
LongPointMultiRangeBuilder builder1 = new LongPointMultiRangeBuilder("point", dims);
BooleanQuery.Builder builder2 = new BooleanQuery.Builder();
for (int i = 0; i < numRanges; i++) {
long[] lower = new long[dims];
long[] upper = new long[dims];
for (int j = 0; j < dims; j++) {
lower[j] = RandomNumbers.randomLongBetween(random(), 0, 2000);
upper[j] = lower[j] + RandomNumbers.randomLongBetween(random(), 0, 2000);
}
builder1.add(lower, upper);
builder2.add(LongPoint.newRangeQuery("point", lower, upper), BooleanClause.Occur.SHOULD);
}

MultiRangeQuery multiRangeQuery = builder1.build();
BooleanQuery booleanQuery = builder2.build();
int count =
multiRangeQuery
.createWeight(searcher, ScoreMode.COMPLETE, 1.0f)
.count(searcher.getLeafContexts().get(0));
int booleanCount = searcher.count(booleanQuery);
assertEquals(booleanCount, count);
}
IOUtils.close(reader, w, dir);
}
}