Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

LUCENE-10456: Implement Weight#count for MultiRangeQuery #731

Merged
merged 13 commits into from
Apr 5, 2022
3 changes: 3 additions & 0 deletions lucene/CHANGES.txt
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,9 @@ New Features
implementation. `Monitor` can be created with a readonly `QueryIndex` in order to
have readonly `Monitor` instances. (Niko Usai)

* LUCENE-10456: Implement rewrite and Weight#count for MultiRangeQuery
by merging overlapping ranges . (Jianping Weng)

Improvements
---------------------
(No changes)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,15 @@
import java.util.Comparator;
import java.util.List;
import java.util.Objects;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.PointValues;
import org.apache.lucene.search.ConstantScoreScorer;
import org.apache.lucene.search.ConstantScoreWeight;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.PointRangeQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.QueryVisitor;
import org.apache.lucene.search.ScoreMode;
Expand All @@ -41,12 +43,11 @@

/**
* Abstract class for range queries involving multiple ranges against physical points such as {@code
* IntPoints} All ranges are logically ORed together TODO: Add capability for handling overlapping
* ranges at rewrite time
* IntPoints} All ranges are logically ORed together
*
* @lucene.experimental
*/
public abstract class MultiRangeQuery extends Query {
public abstract class MultiRangeQuery extends Query implements Cloneable {
/** Representation of a single clause in a MultiRangeQuery */
public static final class RangeClause {
byte[] lowerValue;
Expand Down Expand Up @@ -140,7 +141,7 @@ private void checkArgs(Object lowerPoint, Object upperPoint) {
final String field;
final int numDims;
final int bytesPerDim;
final List<RangeClause> rangeClauses;
List<RangeClause> rangeClauses;
/**
* Expert: create a multidimensional range query with multiple connected ranges
*
Expand All @@ -163,6 +164,79 @@ public void visit(QueryVisitor visitor) {
}
}

/**
* Merges the overlapping ranges and returns unconnected ranges by calling {@link
* #mergeOverlappingRanges}
*/
@Override
public Query rewrite(IndexReader reader) throws IOException {
if (numDims != 1) {
return this;
}
List<RangeClause> mergedRanges = mergeOverlappingRanges(rangeClauses, bytesPerDim);
if (mergedRanges != rangeClauses) {
try {
MultiRangeQuery clone = (MultiRangeQuery) super.clone();
clone.rangeClauses = mergedRanges;
return clone;
} catch (CloneNotSupportedException e) {
throw new AssertionError(e);
}
} else {
return this;
}
}

/**
* Merges overlapping ranges and returns unconnected ranges
*
* @param rangeClauses some overlapping ranges
* @param bytesPerDim bytes per Dimension of the point value
* @return unconnected ranges
*/
static List<RangeClause> mergeOverlappingRanges(List<RangeClause> rangeClauses, int bytesPerDim) {
if (rangeClauses.size() <= 1) {
return rangeClauses;
}
List<RangeClause> originRangeClause = new ArrayList<>(rangeClauses);
final ArrayUtil.ByteArrayComparator comparator = ArrayUtil.getUnsignedComparator(bytesPerDim);
originRangeClause.sort(
new Comparator<RangeClause>() {
@Override
public int compare(RangeClause o1, RangeClause o2) {
int result = comparator.compare(o1.lowerValue, 0, o2.lowerValue, 0);
if (result == 0) {
return comparator.compare(o1.upperValue, 0, o2.upperValue, 0);
} else {
return result;
}
}
});
List<RangeClause> finalRangeClause = new ArrayList<>();
RangeClause current = originRangeClause.get(0);
for (int i = 1; i < originRangeClause.size(); i++) {
RangeClause nextClause = originRangeClause.get(i);
if (comparator.compare(nextClause.lowerValue, 0, current.upperValue, 0) > 0) {
finalRangeClause.add(current);
current = nextClause;
} else {
if (comparator.compare(nextClause.upperValue, 0, current.upperValue, 0) > 0) {
current = new RangeClause(current.lowerValue, nextClause.upperValue);
}
}
}
finalRangeClause.add(current);
/**
* in {@link #rewrite} it compares the returned rangeClauses with origin rangeClauses to decide
* if rewrite should return a new query or the origin query
*/
if (finalRangeClause.size() != rangeClauses.size()) {
return finalRangeClause;
} else {
return rangeClauses;
}
}

/*
* TODO: Organize ranges similar to how EdgeTree does, to avoid linear scan of ranges
*/
Expand Down Expand Up @@ -314,6 +388,38 @@ public Scorer scorer(LeafReaderContext context) throws IOException {
public boolean isCacheable(LeafReaderContext ctx) {
return true;
}

@Override
public int count(LeafReaderContext context) throws IOException {
if (numDims != 1 || context.reader().hasDeletions() == true) {
return super.count(context);
}
PointValues pointValues = context.reader().getPointValues(field);
if (pointValues == null || pointValues.size() != pointValues.getDocCount()) {
return super.count(context);
}
int total = 0;
for (RangeClause rangeClause : rangeClauses) {
PointRangeQuery pointRangeQuery =
new PointRangeQuery(field, rangeClause.lowerValue, rangeClause.upperValue, numDims) {
@Override
protected String toString(int dimension, byte[] value) {
return MultiRangeQuery.this.toString(dimension, value);
}
};
int count =
pointRangeQuery
.createWeight(searcher, ScoreMode.COMPLETE_NO_SCORES, 1f)
.count(context);

if (count != -1) {
total += count;
} else {
return super.count(context);
}
}
return total;
}
};
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import com.carrotsearch.randomizedtesting.generators.RandomNumbers;
import java.io.IOException;
import java.util.Random;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.DoublePoint;
import org.apache.lucene.document.FloatPoint;
Expand All @@ -33,8 +34,10 @@
import org.apache.lucene.search.BooleanQuery;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.Sort;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TotalHitCountCollectorManager;
import org.apache.lucene.store.Directory;
import org.apache.lucene.tests.index.RandomIndexWriter;
import org.apache.lucene.tests.search.QueryUtils;
Expand Down Expand Up @@ -761,4 +764,90 @@ public void testEqualsAndHashCode() {
assertNotEquals(query1.hashCode(), query3.hashCode());
}
}

private void addRandomDocs(RandomIndexWriter w) throws IOException {
Random random = random();
for (int i = 0, end = random.nextInt(100, 500); i < end; i++) {
int numPoints = RandomNumbers.randomIntBetween(random(), 1, 200);
long value = RandomNumbers.randomLongBetween(random(), 0, 2000);
for (int j = 0; j < numPoints; j++) {
Document doc = new Document();
doc.add(new LongPoint("point", value));
w.addDocument(doc);
}
}
w.flush();
w.forceMerge(1);
}

/** The hit doc count of the rewritten query should be the same as origin query's */
public void testRandomRewrite() throws IOException {
Directory dir = newDirectory();
RandomIndexWriter w = new RandomIndexWriter(random(), dir);
int dims = 1;
addRandomDocs(w);

IndexReader reader = w.getReader();
IndexSearcher searcher = newSearcher(reader);
int numIters = atLeast(100);

for (int n = 0; n < numIters; n++) {
int numRanges = RandomNumbers.randomIntBetween(random(), 1, 20);
LongPointMultiRangeBuilder builder1 = new LongPointMultiRangeBuilder("point", dims);
BooleanQuery.Builder builder2 = new BooleanQuery.Builder();
for (int i = 0; i < numRanges; i++) {
long[] lower = new long[dims];
long[] upper = new long[dims];
for (int j = 0; j < dims; j++) {
lower[j] = RandomNumbers.randomLongBetween(random(), 0, 2000);
upper[j] = lower[j] + RandomNumbers.randomLongBetween(random(), 0, 2000);
}
builder1.add(lower, upper);
builder2.add(LongPoint.newRangeQuery("point", lower, upper), BooleanClause.Occur.SHOULD);
}

MultiRangeQuery multiRangeQuery = (MultiRangeQuery) builder1.build().rewrite(reader);
BooleanQuery booleanQuery = builder2.build();
int count = searcher.search(multiRangeQuery, new TotalHitCountCollectorManager());
int booleanCount = searcher.search(booleanQuery, new TotalHitCountCollectorManager());
assertEquals(booleanCount, count);
}
IOUtils.close(reader, w, dir);
}

public void testOneDimensionCount() throws IOException {
Directory dir = newDirectory();
RandomIndexWriter w = new RandomIndexWriter(random(), dir);
int dims = 1;
addRandomDocs(w);

IndexReader reader = w.getReader();
IndexSearcher searcher = newSearcher(reader);
int numIters = atLeast(100);
for (int n = 0; n < numIters; n++) {
int numRanges = RandomNumbers.randomIntBetween(random(), 1, 20);
LongPointMultiRangeBuilder builder1 = new LongPointMultiRangeBuilder("point", dims);
BooleanQuery.Builder builder2 = new BooleanQuery.Builder();
for (int i = 0; i < numRanges; i++) {
long[] lower = new long[dims];
long[] upper = new long[dims];
for (int j = 0; j < dims; j++) {
lower[j] = RandomNumbers.randomLongBetween(random(), 0, 2000);
upper[j] = lower[j] + RandomNumbers.randomLongBetween(random(), 0, 2000);
}
builder1.add(lower, upper);
builder2.add(LongPoint.newRangeQuery("point", lower, upper), BooleanClause.Occur.SHOULD);
}

MultiRangeQuery multiRangeQuery = (MultiRangeQuery) builder1.build().rewrite(reader);
BooleanQuery booleanQuery = builder2.build();
int count =
multiRangeQuery
.createWeight(searcher, ScoreMode.COMPLETE, 1.0f)
.count(searcher.getLeafContexts().get(0));
int booleanCount = searcher.count(booleanQuery);
assertEquals(booleanCount, count);
}
IOUtils.close(reader, w, dir);
}
}