leaves = context.searcher().getIndexReader().leaves();
- long min = Long.MAX_VALUE, max = Long.MIN_VALUE;
- for (LeafReaderContext leaf : leaves) {
- final PointValues values = leaf.reader().getPointValues(fieldName);
- if (values != null) {
- min = Math.min(min, NumericUtils.sortableBytesToLong(values.getMinPackedValue(), 0));
- max = Math.max(max, NumericUtils.sortableBytesToLong(values.getMaxPackedValue(), 0));
- }
- }
-
- if (min == Long.MAX_VALUE || max == Long.MIN_VALUE) {
- return null;
- }
- return new long[] { min, max };
- }
-
- /**
- * Finds the min and max bounds of the field for the segment
- *
- * @return null if the field is empty or not indexed
- */
- private static long[] getSegmentBounds(final LeafReaderContext context, final String fieldName) throws IOException {
- long min = Long.MAX_VALUE, max = Long.MIN_VALUE;
- final PointValues values = context.reader().getPointValues(fieldName);
- if (values != null) {
- min = Math.min(min, NumericUtils.sortableBytesToLong(values.getMinPackedValue(), 0));
- max = Math.max(max, NumericUtils.sortableBytesToLong(values.getMaxPackedValue(), 0));
- }
-
- if (min == Long.MAX_VALUE || max == Long.MIN_VALUE) {
- return null;
- }
- return new long[] { min, max };
- }
-
- /**
- * Gets the min and max bounds of the field for the shard search
- * Depending on the query part, the bounds are computed differently
- *
- * @return null if the processed query not supported by the optimization
- */
- public static long[] getDateHistoAggBounds(final SearchContext context, final String fieldName) throws IOException {
- final Query cq = unwrapIntoConcreteQuery(context.query());
- if (cq instanceof PointRangeQuery) {
- final PointRangeQuery prq = (PointRangeQuery) cq;
- final long[] indexBounds = getShardBounds(context, fieldName);
- if (indexBounds == null) return null;
- return getBoundsWithRangeQuery(prq, fieldName, indexBounds);
- } else if (cq instanceof MatchAllDocsQuery) {
- return getShardBounds(context, fieldName);
- } else if (cq instanceof FieldExistsQuery) {
- // when a range query covers all values of a shard, it will be rewrite field exists query
- if (((FieldExistsQuery) cq).getField().equals(fieldName)) {
- return getShardBounds(context, fieldName);
- }
- }
-
- return null;
- }
-
- private static long[] getBoundsWithRangeQuery(PointRangeQuery prq, String fieldName, long[] indexBounds) {
- // Ensure that the query and aggregation are on the same field
- if (prq.getField().equals(fieldName)) {
- // Minimum bound for aggregation is the max between query and global
- long lower = Math.max(NumericUtils.sortableBytesToLong(prq.getLowerPoint(), 0), indexBounds[0]);
- // Maximum bound for aggregation is the min between query and global
- long upper = Math.min(NumericUtils.sortableBytesToLong(prq.getUpperPoint(), 0), indexBounds[1]);
- if (lower > upper) {
- return null;
- }
- return new long[] { lower, upper };
- }
-
- return null;
- }
-
- /**
- * Context object for fast filter optimization
- *
- * Usage: first set aggregation type, then check isRewriteable, then buildFastFilter
- */
- public static class FastFilterContext {
- private boolean rewriteable = false;
- private boolean rangesBuiltAtShardLevel = false;
-
- private AggregationType aggregationType;
- private final SearchContext context;
-
- private MappedFieldType fieldType;
- private Ranges ranges;
-
- // debug info related fields
- public int leaf;
- public int inner;
- public int segments;
- public int optimizedSegments;
-
- public FastFilterContext(SearchContext context) {
- this.context = context;
- }
-
- public FastFilterContext(SearchContext context, AggregationType aggregationType) {
- this.context = context;
- this.aggregationType = aggregationType;
- }
-
- public AggregationType getAggregationType() {
- return aggregationType;
- }
-
- public void setAggregationType(AggregationType aggregationType) {
- this.aggregationType = aggregationType;
- }
-
- public boolean isRewriteable(final Object parent, final int subAggLength) {
- if (context.maxAggRewriteFilters() == 0) return false;
-
- boolean rewriteable = aggregationType.isRewriteable(parent, subAggLength);
- logger.debug("Fast filter rewriteable: {} for shard {}", rewriteable, context.indexShard().shardId());
- this.rewriteable = rewriteable;
- return rewriteable;
- }
-
- public void buildRanges(MappedFieldType fieldType) throws IOException {
- assert ranges == null : "Ranges should only be built once at shard level, but they are already built";
- this.fieldType = fieldType;
- this.ranges = this.aggregationType.buildRanges(context, fieldType);
- if (ranges != null) {
- logger.debug("Ranges built for shard {}", context.indexShard().shardId());
- rangesBuiltAtShardLevel = true;
- }
- }
-
- private Ranges buildRanges(LeafReaderContext leaf) throws IOException {
- Ranges ranges = this.aggregationType.buildRanges(leaf, context, fieldType);
- if (ranges != null) {
- logger.debug("Ranges built for shard {} segment {}", context.indexShard().shardId(), leaf.ord);
- }
- return ranges;
- }
-
- /**
- * Try to populate the bucket doc counts for aggregation
- *
- * Usage: invoked at segment level — in getLeafCollector of aggregator
- *
- * @param bucketOrd bucket ordinal producer
- * @param incrementDocCount consume the doc_count results for certain ordinal
- */
- public boolean tryFastFilterAggregation(
- final LeafReaderContext ctx,
- final BiConsumer incrementDocCount,
- final Function bucketOrd
- ) throws IOException {
- this.segments++;
- if (!this.rewriteable) {
- return false;
- }
-
- if (ctx.reader().hasDeletions()) return false;
-
- PointValues values = ctx.reader().getPointValues(this.fieldType.name());
- if (values == null) return false;
- // only proceed if every document corresponds to exactly one point
- if (values.getDocCount() != values.size()) return false;
-
- NumericDocValues docCountValues = DocValues.getNumeric(ctx.reader(), DocCountFieldMapper.NAME);
- if (docCountValues.nextDoc() != NO_MORE_DOCS) {
- logger.debug(
- "Shard {} segment {} has at least one document with _doc_count field, skip fast filter optimization",
- this.context.indexShard().shardId(),
- ctx.ord
- );
- return false;
- }
-
- // even if no ranges built at shard level, we can still perform the optimization
- // when functionally match-all at segment level
- if (!this.rangesBuiltAtShardLevel && !segmentMatchAll(this.context, ctx)) {
- return false;
- }
-
- Ranges ranges = this.ranges;
- if (ranges == null) {
- logger.debug(
- "Shard {} segment {} functionally match all documents. Build the fast filter",
- this.context.indexShard().shardId(),
- ctx.ord
- );
- ranges = this.buildRanges(ctx);
- if (ranges == null) {
- return false;
- }
- }
-
- DebugInfo debugInfo = this.aggregationType.tryFastFilterAggregation(values, ranges, incrementDocCount, bucketOrd);
- this.consumeDebugInfo(debugInfo);
-
- this.optimizedSegments++;
- logger.debug("Fast filter optimization applied to shard {} segment {}", this.context.indexShard().shardId(), ctx.ord);
- logger.debug("crossed leaf nodes: {}, inner nodes: {}", this.leaf, this.inner);
- return true;
- }
-
- private void consumeDebugInfo(DebugInfo debug) {
- leaf += debug.leaf;
- inner += debug.inner;
- }
- }
-
- /**
- * Different types have different pre-conditions, filter building logic, etc.
- */
- interface AggregationType {
- boolean isRewriteable(Object parent, int subAggLength);
-
- Ranges buildRanges(SearchContext ctx, MappedFieldType fieldType) throws IOException;
-
- Ranges buildRanges(LeafReaderContext leaf, SearchContext ctx, MappedFieldType fieldType) throws IOException;
-
- DebugInfo tryFastFilterAggregation(
- PointValues values,
- Ranges ranges,
- BiConsumer incrementDocCount,
- Function bucketOrd
- ) throws IOException;
- }
-
- /**
- * For date histogram aggregation
- */
- public static abstract class AbstractDateHistogramAggregationType implements AggregationType {
- private final MappedFieldType fieldType;
- private final boolean missing;
- private final boolean hasScript;
- private LongBounds hardBounds;
-
- public AbstractDateHistogramAggregationType(MappedFieldType fieldType, boolean missing, boolean hasScript) {
- this.fieldType = fieldType;
- this.missing = missing;
- this.hasScript = hasScript;
- }
-
- public AbstractDateHistogramAggregationType(MappedFieldType fieldType, boolean missing, boolean hasScript, LongBounds hardBounds) {
- this(fieldType, missing, hasScript);
- this.hardBounds = hardBounds;
- }
-
- @Override
- public boolean isRewriteable(Object parent, int subAggLength) {
- if (parent == null && subAggLength == 0 && !missing && !hasScript) {
- if (fieldType != null && fieldType instanceof DateFieldMapper.DateFieldType) {
- return fieldType.isSearchable();
- }
- }
- return false;
- }
-
- @Override
- public Ranges buildRanges(SearchContext context, MappedFieldType fieldType) throws IOException {
- long[] bounds = getDateHistoAggBounds(context, fieldType.name());
- logger.debug("Bounds are {} for shard {}", bounds, context.indexShard().shardId());
- return buildRanges(context, bounds);
- }
-
- @Override
- public Ranges buildRanges(LeafReaderContext leaf, SearchContext context, MappedFieldType fieldType) throws IOException {
- long[] bounds = getSegmentBounds(leaf, fieldType.name());
- logger.debug("Bounds are {} for shard {} segment {}", bounds, context.indexShard().shardId(), leaf.ord);
- return buildRanges(context, bounds);
- }
-
- private Ranges buildRanges(SearchContext context, long[] bounds) throws IOException {
- bounds = processHardBounds(bounds);
- if (bounds == null) {
- return null;
- }
- assert bounds[0] <= bounds[1] : "Low bound should be less than high bound";
-
- final Rounding rounding = getRounding(bounds[0], bounds[1]);
- final OptionalLong intervalOpt = Rounding.getInterval(rounding);
- if (intervalOpt.isEmpty()) {
- return null;
- }
- final long interval = intervalOpt.getAsLong();
-
- // process the after key of composite agg
- processAfterKey(bounds, interval);
-
- return FastFilterRewriteHelper.createRangesFromAgg(
- context,
- (DateFieldMapper.DateFieldType) fieldType,
- interval,
- getRoundingPrepared(),
- bounds[0],
- bounds[1]
- );
- }
-
- protected abstract Rounding getRounding(final long low, final long high);
-
- protected abstract Rounding.Prepared getRoundingPrepared();
-
- protected void processAfterKey(long[] bound, long interval) {}
-
- protected long[] processHardBounds(long[] bounds) {
- if (bounds != null) {
- // Update min/max limit if user specified any hard bounds
- if (hardBounds != null) {
- if (hardBounds.getMin() > bounds[0]) {
- bounds[0] = hardBounds.getMin();
- }
- if (hardBounds.getMax() - 1 < bounds[1]) {
- bounds[1] = hardBounds.getMax() - 1; // hard bounds max is exclusive
- }
- if (bounds[0] > bounds[1]) {
- return null;
- }
- }
- }
- return bounds;
- }
-
- public DateFieldMapper.DateFieldType getFieldType() {
- assert fieldType instanceof DateFieldMapper.DateFieldType;
- return (DateFieldMapper.DateFieldType) fieldType;
- }
-
- @Override
- public DebugInfo tryFastFilterAggregation(
- PointValues values,
- Ranges ranges,
- BiConsumer incrementDocCount,
- Function bucketOrd
- ) throws IOException {
- int size = Integer.MAX_VALUE;
- if (this instanceof CompositeAggregator.CompositeAggregationType) {
- size = ((CompositeAggregator.CompositeAggregationType) this).getSize();
- }
-
- DateFieldMapper.DateFieldType fieldType = getFieldType();
- BiConsumer incrementFunc = (activeIndex, docCount) -> {
- long rangeStart = LongPoint.decodeDimension(ranges.lowers[activeIndex], 0);
- rangeStart = fieldType.convertNanosToMillis(rangeStart);
- long ord = getBucketOrd(bucketOrd.apply(rangeStart));
- incrementDocCount.accept(ord, (long) docCount);
- };
-
- return multiRangesTraverse(values.getPointTree(), ranges, incrementFunc, size);
- }
-
- private static long getBucketOrd(long bucketOrd) {
- if (bucketOrd < 0) { // already seen
- bucketOrd = -1 - bucketOrd;
- }
-
- return bucketOrd;
- }
- }
-
- /**
- * For range aggregation
- */
- public static class RangeAggregationType implements AggregationType {
-
- private final ValuesSourceConfig config;
- private final Range[] ranges;
-
- public RangeAggregationType(ValuesSourceConfig config, Range[] ranges) {
- this.config = config;
- this.ranges = ranges;
- }
-
- @Override
- public boolean isRewriteable(Object parent, int subAggLength) {
- if (config.fieldType() == null) return false;
- MappedFieldType fieldType = config.fieldType();
- if (fieldType.isSearchable() == false || !(fieldType instanceof NumericPointEncoder)) return false;
-
- if (parent == null && subAggLength == 0 && config.script() == null && config.missing() == null) {
- if (config.getValuesSource() instanceof ValuesSource.Numeric.FieldData) {
- // ranges are already sorted by from and then to
- // we want ranges not overlapping with each other
- double prevTo = ranges[0].getTo();
- for (int i = 1; i < ranges.length; i++) {
- if (prevTo > ranges[i].getFrom()) {
- return false;
- }
- prevTo = ranges[i].getTo();
- }
- return true;
- }
- }
- return false;
- }
-
- @Override
- public Ranges buildRanges(SearchContext context, MappedFieldType fieldType) {
- assert fieldType instanceof NumericPointEncoder;
- NumericPointEncoder numericPointEncoder = (NumericPointEncoder) fieldType;
- byte[][] lowers = new byte[ranges.length][];
- byte[][] uppers = new byte[ranges.length][];
- for (int i = 0; i < ranges.length; i++) {
- double rangeMin = ranges[i].getFrom();
- double rangeMax = ranges[i].getTo();
- byte[] lower = numericPointEncoder.encodePoint(rangeMin);
- byte[] upper = numericPointEncoder.encodePoint(rangeMax);
- lowers[i] = lower;
- uppers[i] = upper;
- }
-
- return new Ranges(lowers, uppers);
- }
-
- @Override
- public Ranges buildRanges(LeafReaderContext leaf, SearchContext ctx, MappedFieldType fieldType) {
- throw new UnsupportedOperationException("Range aggregation should not build ranges at segment level");
- }
-
- @Override
- public DebugInfo tryFastFilterAggregation(
- PointValues values,
- Ranges ranges,
- BiConsumer incrementDocCount,
- Function bucketOrd
- ) throws IOException {
- int size = Integer.MAX_VALUE;
-
- BiConsumer incrementFunc = (activeIndex, docCount) -> {
- long ord = bucketOrd.apply(activeIndex);
- incrementDocCount.accept(ord, (long) docCount);
- };
-
- return multiRangesTraverse(values.getPointTree(), ranges, incrementFunc, size);
- }
- }
-
- public static boolean isCompositeAggRewriteable(CompositeValuesSourceConfig[] sourceConfigs) {
- return sourceConfigs.length == 1 && sourceConfigs[0].valuesSource() instanceof RoundingValuesSource;
- }
-
- private static boolean segmentMatchAll(SearchContext ctx, LeafReaderContext leafCtx) throws IOException {
- Weight weight = ctx.searcher().createWeight(ctx.query(), ScoreMode.COMPLETE_NO_SCORES, 1f);
- return weight != null && weight.count(leafCtx) == leafCtx.reader().numDocs();
- }
-
- /**
- * Creates the date ranges from date histo aggregations using its interval,
- * and min/max boundaries
- */
- private static Ranges createRangesFromAgg(
- final SearchContext context,
- final DateFieldMapper.DateFieldType fieldType,
- final long interval,
- final Rounding.Prepared preparedRounding,
- long low,
- final long high
- ) {
- // Calculate the number of buckets using range and interval
- long roundedLow = preparedRounding.round(fieldType.convertNanosToMillis(low));
- long prevRounded = roundedLow;
- int bucketCount = 0;
- while (roundedLow <= fieldType.convertNanosToMillis(high)) {
- bucketCount++;
- int maxNumFilterBuckets = context.maxAggRewriteFilters();
- if (bucketCount > maxNumFilterBuckets) {
- logger.debug("Max number of filters reached [{}], skip the fast filter optimization", maxNumFilterBuckets);
- return null;
- }
- // Below rounding is needed as the interval could return in
- // non-rounded values for something like calendar month
- roundedLow = preparedRounding.round(roundedLow + interval);
- if (prevRounded == roundedLow) break; // prevents getting into an infinite loop
- prevRounded = roundedLow;
- }
-
- long[][] ranges = new long[bucketCount][2];
- if (bucketCount > 0) {
- roundedLow = preparedRounding.round(fieldType.convertNanosToMillis(low));
-
- int i = 0;
- while (i < bucketCount) {
- // Calculate the lower bucket bound
- long lower = i == 0 ? low : fieldType.convertRoundedMillisToNanos(roundedLow);
- roundedLow = preparedRounding.round(roundedLow + interval);
-
- // plus one on high value because upper bound is exclusive, but high value exists
- long upper = i + 1 == bucketCount ? high + 1 : fieldType.convertRoundedMillisToNanos(roundedLow);
-
- ranges[i][0] = lower;
- ranges[i][1] = upper;
- i++;
- }
- }
-
- byte[][] lowers = new byte[ranges.length][];
- byte[][] uppers = new byte[ranges.length][];
- for (int i = 0; i < ranges.length; i++) {
- byte[] lower = LONG.encodePoint(ranges[i][0]);
- byte[] max = LONG.encodePoint(ranges[i][1]);
- lowers[i] = lower;
- uppers[i] = max;
- }
-
- return new Ranges(lowers, uppers);
- }
-
- /**
- * @param maxNumNonZeroRanges the number of non-zero ranges to collect
- */
- private static DebugInfo multiRangesTraverse(
- final PointValues.PointTree tree,
- final Ranges ranges,
- final BiConsumer incrementDocCount,
- final int maxNumNonZeroRanges
- ) throws IOException {
- DebugInfo debugInfo = new DebugInfo();
- int activeIndex = ranges.firstRangeIndex(tree.getMinPackedValue(), tree.getMaxPackedValue());
- if (activeIndex < 0) {
- logger.debug("No ranges match the query, skip the fast filter optimization");
- return debugInfo;
- }
- RangeCollectorForPointTree collector = new RangeCollectorForPointTree(incrementDocCount, maxNumNonZeroRanges, ranges, activeIndex);
- PointValues.IntersectVisitor visitor = getIntersectVisitor(collector);
- try {
- intersectWithRanges(visitor, tree, collector, debugInfo);
- } catch (CollectionTerminatedException e) {
- logger.debug("Early terminate since no more range to collect");
- }
- collector.finalizePreviousRange();
-
- return debugInfo;
- }
-
- private static class Ranges {
- byte[][] lowers; // inclusive
- byte[][] uppers; // exclusive
- int size;
- int byteLen;
- static ArrayUtil.ByteArrayComparator comparator;
-
- Ranges(byte[][] lowers, byte[][] uppers) {
- this.lowers = lowers;
- this.uppers = uppers;
- assert lowers.length == uppers.length;
- this.size = lowers.length;
- this.byteLen = lowers[0].length;
- comparator = ArrayUtil.getUnsignedComparator(byteLen);
- }
-
- public int firstRangeIndex(byte[] globalMin, byte[] globalMax) {
- if (compareByteValue(lowers[0], globalMax) > 0) {
- return -1;
- }
- int i = 0;
- while (compareByteValue(uppers[i], globalMin) <= 0) {
- i++;
- if (i >= size) {
- return -1;
- }
- }
- return i;
- }
-
- public static int compareByteValue(byte[] value1, byte[] value2) {
- return comparator.compare(value1, 0, value2, 0);
- }
-
- public static boolean withinLowerBound(byte[] value, byte[] lowerBound) {
- return compareByteValue(value, lowerBound) >= 0;
- }
-
- public static boolean withinUpperBound(byte[] value, byte[] upperBound) {
- return compareByteValue(value, upperBound) < 0;
- }
- }
-
- private static void intersectWithRanges(
- PointValues.IntersectVisitor visitor,
- PointValues.PointTree pointTree,
- RangeCollectorForPointTree collector,
- DebugInfo debug
- ) throws IOException {
- PointValues.Relation r = visitor.compare(pointTree.getMinPackedValue(), pointTree.getMaxPackedValue());
-
- switch (r) {
- case CELL_INSIDE_QUERY:
- collector.countNode((int) pointTree.size());
- debug.visitInner();
- break;
- case CELL_CROSSES_QUERY:
- if (pointTree.moveToChild()) {
- do {
- intersectWithRanges(visitor, pointTree, collector, debug);
- } while (pointTree.moveToSibling());
- pointTree.moveToParent();
- } else {
- pointTree.visitDocValues(visitor);
- debug.visitLeaf();
- }
- break;
- case CELL_OUTSIDE_QUERY:
- }
- }
-
- private static PointValues.IntersectVisitor getIntersectVisitor(RangeCollectorForPointTree collector) {
- return new PointValues.IntersectVisitor() {
- @Override
- public void visit(int docID) throws IOException {
- // this branch should be unreachable
- throw new UnsupportedOperationException(
- "This IntersectVisitor does not perform any actions on a " + "docID=" + docID + " node being visited"
- );
- }
-
- @Override
- public void visit(int docID, byte[] packedValue) throws IOException {
- visitPoints(packedValue, collector::count);
- }
-
- @Override
- public void visit(DocIdSetIterator iterator, byte[] packedValue) throws IOException {
- visitPoints(packedValue, () -> {
- for (int doc = iterator.nextDoc(); doc != NO_MORE_DOCS; doc = iterator.nextDoc()) {
- collector.count();
- }
- });
- }
-
- private void visitPoints(byte[] packedValue, CheckedRunnable collect) throws IOException {
- if (!collector.withinUpperBound(packedValue)) {
- collector.finalizePreviousRange();
- if (collector.iterateRangeEnd(packedValue)) {
- throw new CollectionTerminatedException();
- }
- }
-
- if (collector.withinRange(packedValue)) {
- collect.run();
- }
- }
-
- @Override
- public PointValues.Relation compare(byte[] minPackedValue, byte[] maxPackedValue) {
- // try to find the first range that may collect values from this cell
- if (!collector.withinUpperBound(minPackedValue)) {
- collector.finalizePreviousRange();
- if (collector.iterateRangeEnd(minPackedValue)) {
- throw new CollectionTerminatedException();
- }
- }
- // after the loop, min < upper
- // cell could be outside [min max] lower
- if (!collector.withinLowerBound(maxPackedValue)) {
- return PointValues.Relation.CELL_OUTSIDE_QUERY;
- }
- if (collector.withinRange(minPackedValue) && collector.withinRange(maxPackedValue)) {
- return PointValues.Relation.CELL_INSIDE_QUERY;
- }
- return PointValues.Relation.CELL_CROSSES_QUERY;
- }
- };
- }
-
- private static class RangeCollectorForPointTree {
- private final BiConsumer incrementRangeDocCount;
- private int counter = 0;
-
- private final Ranges ranges;
- private int activeIndex;
-
- private int visitedRange = 0;
- private final int maxNumNonZeroRange;
-
- public RangeCollectorForPointTree(
- BiConsumer incrementRangeDocCount,
- int maxNumNonZeroRange,
- Ranges ranges,
- int activeIndex
- ) {
- this.incrementRangeDocCount = incrementRangeDocCount;
- this.maxNumNonZeroRange = maxNumNonZeroRange;
- this.ranges = ranges;
- this.activeIndex = activeIndex;
- }
-
- private void count() {
- counter++;
- }
-
- private void countNode(int count) {
- counter += count;
- }
-
- private void finalizePreviousRange() {
- if (counter > 0) {
- incrementRangeDocCount.accept(activeIndex, counter);
- counter = 0;
- }
- }
-
- /**
- * @return true when iterator exhausted or collect enough non-zero ranges
- */
- private boolean iterateRangeEnd(byte[] value) {
- // the new value may not be contiguous to the previous one
- // so try to find the first next range that cross the new value
- while (!withinUpperBound(value)) {
- if (++activeIndex >= ranges.size) {
- return true;
- }
- }
- visitedRange++;
- return visitedRange > maxNumNonZeroRange;
- }
-
- private boolean withinLowerBound(byte[] value) {
- return Ranges.withinLowerBound(value, ranges.lowers[activeIndex]);
- }
-
- private boolean withinUpperBound(byte[] value) {
- return Ranges.withinUpperBound(value, ranges.uppers[activeIndex]);
- }
-
- private boolean withinRange(byte[] value) {
- return withinLowerBound(value) && withinUpperBound(value);
- }
- }
-
- /**
- * Contains debug info of BKD traversal to show in profile
- */
- private static class DebugInfo {
- private int leaf = 0; // leaf node visited
- private int inner = 0; // inner node visited
-
- private void visitLeaf() {
- leaf++;
- }
-
- private void visitInner() {
- inner++;
- }
- }
-}
diff --git a/server/src/main/java/org/opensearch/search/aggregations/bucket/composite/CompositeAggregator.java b/server/src/main/java/org/opensearch/search/aggregations/bucket/composite/CompositeAggregator.java
index bfb484dcf478d..cfe716eb57ca8 100644
--- a/server/src/main/java/org/opensearch/search/aggregations/bucket/composite/CompositeAggregator.java
+++ b/server/src/main/java/org/opensearch/search/aggregations/bucket/composite/CompositeAggregator.java
@@ -73,8 +73,8 @@
import org.opensearch.search.aggregations.MultiBucketCollector;
import org.opensearch.search.aggregations.MultiBucketConsumerService;
import org.opensearch.search.aggregations.bucket.BucketsAggregator;
-import org.opensearch.search.aggregations.bucket.FastFilterRewriteHelper;
-import org.opensearch.search.aggregations.bucket.FastFilterRewriteHelper.AbstractDateHistogramAggregationType;
+import org.opensearch.search.aggregations.bucket.filterrewrite.CompositeAggregatorBridge;
+import org.opensearch.search.aggregations.bucket.filterrewrite.FilterRewriteOptimizationContext;
import org.opensearch.search.aggregations.bucket.missing.MissingOrder;
import org.opensearch.search.aggregations.bucket.terms.LongKeyedBucketOrds;
import org.opensearch.search.internal.SearchContext;
@@ -89,13 +89,15 @@
import java.util.List;
import java.util.Map;
import java.util.function.BiConsumer;
+import java.util.function.Function;
import java.util.function.LongUnaryOperator;
import java.util.stream.Collectors;
import static org.opensearch.search.aggregations.MultiBucketConsumerService.MAX_BUCKET_SETTING;
+import static org.opensearch.search.aggregations.bucket.filterrewrite.DateHistogramAggregatorBridge.segmentMatchAll;
/**
- * Main aggregator that aggregates docs from mulitple aggregations
+ * Main aggregator that aggregates docs from multiple aggregations
*
* @opensearch.internal
*/
@@ -118,9 +120,8 @@ public final class CompositeAggregator extends BucketsAggregator {
private boolean earlyTerminated;
- private final FastFilterRewriteHelper.FastFilterContext fastFilterContext;
- private LongKeyedBucketOrds bucketOrds = null;
- private Rounding.Prepared preparedRounding = null;
+ private final FilterRewriteOptimizationContext filterRewriteOptimizationContext;
+ private LongKeyedBucketOrds bucketOrds;
CompositeAggregator(
String name,
@@ -166,57 +167,62 @@ public final class CompositeAggregator extends BucketsAggregator {
this.queue = new CompositeValuesCollectorQueue(context.bigArrays(), sources, size, rawAfterKey);
this.rawAfterKey = rawAfterKey;
- fastFilterContext = new FastFilterRewriteHelper.FastFilterContext(context);
- if (!FastFilterRewriteHelper.isCompositeAggRewriteable(sourceConfigs)) {
- return;
- }
- fastFilterContext.setAggregationType(new CompositeAggregationType());
- if (fastFilterContext.isRewriteable(parent, subAggregators.length)) {
- // bucketOrds is used for saving date histogram results
- bucketOrds = LongKeyedBucketOrds.build(context.bigArrays(), CardinalityUpperBound.ONE);
- preparedRounding = ((CompositeAggregationType) fastFilterContext.getAggregationType()).getRoundingPrepared();
- fastFilterContext.buildRanges(sourceConfigs[0].fieldType());
- }
- }
+ CompositeAggregatorBridge bridge = new CompositeAggregatorBridge() {
+ private RoundingValuesSource valuesSource;
+ private long afterKey = -1L;
- /**
- * Currently the filter rewrite is only supported for date histograms
- */
- public class CompositeAggregationType extends AbstractDateHistogramAggregationType {
- private final RoundingValuesSource valuesSource;
- private long afterKey = -1L;
-
- public CompositeAggregationType() {
- super(sourceConfigs[0].fieldType(), sourceConfigs[0].missingBucket(), sourceConfigs[0].hasScript());
- this.valuesSource = (RoundingValuesSource) sourceConfigs[0].valuesSource();
- if (rawAfterKey != null) {
- assert rawAfterKey.size() == 1 && formats.size() == 1;
- this.afterKey = formats.get(0).parseLong(rawAfterKey.get(0).toString(), false, () -> {
- throw new IllegalArgumentException("now() is not supported in [after] key");
- });
+ @Override
+ protected boolean canOptimize() {
+ if (canOptimize(sourceConfigs)) {
+ this.valuesSource = (RoundingValuesSource) sourceConfigs[0].valuesSource();
+ if (rawAfterKey != null) {
+ assert rawAfterKey.size() == 1 && formats.size() == 1;
+ this.afterKey = formats.get(0).parseLong(rawAfterKey.get(0).toString(), false, () -> {
+ throw new IllegalArgumentException("now() is not supported in [after] key");
+ });
+ }
+
+ // bucketOrds stores the date histogram results obtained from the optimization path
+ bucketOrds = LongKeyedBucketOrds.build(context.bigArrays(), CardinalityUpperBound.ONE);
+ return true;
+ }
+ return false;
}
- }
- public Rounding getRounding(final long low, final long high) {
- return valuesSource.getRounding();
- }
+ @Override
+ protected void prepare() throws IOException {
+ buildRanges(context);
+ }
- public Rounding.Prepared getRoundingPrepared() {
- return valuesSource.getPreparedRounding();
- }
+ protected Rounding getRounding(final long low, final long high) {
+ return valuesSource.getRounding();
+ }
- @Override
- protected void processAfterKey(long[] bound, long interval) {
- // afterKey is the last bucket key in previous response, and the bucket key
- // is the minimum of all values in the bucket, so need to add the interval
- if (afterKey != -1L) {
- bound[0] = afterKey + interval;
+ protected Rounding.Prepared getRoundingPrepared() {
+ return valuesSource.getPreparedRounding();
}
- }
- public int getSize() {
- return size;
- }
+ @Override
+ protected long[] processAfterKey(long[] bounds, long interval) {
+ // afterKey is the last bucket key in previous response, and the bucket key
+ // is the minimum of all values in the bucket, so need to add the interval
+ if (afterKey != -1L) {
+ bounds[0] = afterKey + interval;
+ }
+ return bounds;
+ }
+
+ @Override
+ protected int getSize() {
+ return size;
+ }
+
+ @Override
+ protected Function bucketOrdProducer() {
+ return (key) -> bucketOrds.add(0, getRoundingPrepared().round((long) key));
+ }
+ };
+ filterRewriteOptimizationContext = new FilterRewriteOptimizationContext(bridge, parent, subAggregators.length, context);
}
@Override
@@ -368,7 +374,7 @@ private boolean isMaybeMultivalued(LeafReaderContext context, SortField sortFiel
return v2 != null && DocValues.unwrapSingleton(v2) == null;
default:
- // we have no clue whether the field is multi-valued or not so we assume it is.
+ // we have no clue whether the field is multivalued or not so we assume it is.
return true;
}
}
@@ -551,11 +557,7 @@ private void processLeafFromQuery(LeafReaderContext ctx, Sort indexSortPrefix) t
@Override
protected LeafBucketCollector getLeafCollector(LeafReaderContext ctx, LeafBucketCollector sub) throws IOException {
- boolean optimized = fastFilterContext.tryFastFilterAggregation(
- ctx,
- this::incrementBucketDocCount,
- (key) -> bucketOrds.add(0, preparedRounding.round((long) key))
- );
+ boolean optimized = filterRewriteOptimizationContext.tryOptimize(ctx, this::incrementBucketDocCount, segmentMatchAll(context, ctx));
if (optimized) throw new CollectionTerminatedException();
finishLeaf();
@@ -709,11 +711,6 @@ private static class Entry {
@Override
public void collectDebugInfo(BiConsumer add) {
- if (fastFilterContext.optimizedSegments > 0) {
- add.accept("optimized_segments", fastFilterContext.optimizedSegments);
- add.accept("unoptimized_segments", fastFilterContext.segments - fastFilterContext.optimizedSegments);
- add.accept("leaf_visited", fastFilterContext.leaf);
- add.accept("inner_visited", fastFilterContext.inner);
- }
+ filterRewriteOptimizationContext.populateDebugInfo(add);
}
}
diff --git a/server/src/main/java/org/opensearch/search/aggregations/bucket/filterrewrite/AggregatorBridge.java b/server/src/main/java/org/opensearch/search/aggregations/bucket/filterrewrite/AggregatorBridge.java
new file mode 100644
index 0000000000000..6b34582b259ea
--- /dev/null
+++ b/server/src/main/java/org/opensearch/search/aggregations/bucket/filterrewrite/AggregatorBridge.java
@@ -0,0 +1,84 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * The OpenSearch Contributors require contributions made to
+ * this file be licensed under the Apache-2.0 license or a
+ * compatible open source license.
+ */
+
+package org.opensearch.search.aggregations.bucket.filterrewrite;
+
+import org.apache.lucene.index.LeafReaderContext;
+import org.apache.lucene.index.PointValues;
+import org.opensearch.index.mapper.MappedFieldType;
+
+import java.io.IOException;
+import java.util.function.BiConsumer;
+import java.util.function.Consumer;
+
+/**
+ * This class provides a bridge between an aggregator and the optimization context, allowing
+ * the aggregator to provide data and optimize the aggregation process.
+ *
+ * The main purpose of this class is to encapsulate the aggregator-specific optimization
+ * logic and provide access to the data in Aggregator that is required for optimization, while keeping the optimization
+ * business logic separate from the aggregator implementation.
+ *
+ * <p>To use this class to optimize an aggregator, you should subclass it in this package
+ * and put any specific optimization business logic in it. Then implement this subclass in the aggregator
+ * to provide the data that is needed for doing the optimization.
+ *
+ * @opensearch.internal
+ */
+public abstract class AggregatorBridge {
+
+ /**
+ * The field type associated with this aggregator bridge.
+ */
+ MappedFieldType fieldType;
+
+ /**
+ * Sink that receives the ranges built by the bridge; installed through {@link #setRangesConsumer(Consumer)}.
+ */
+ Consumer setRanges;
+
+ /**
+ * Installs the consumer that receives the ranges built by this bridge.
+ *
+ * @param setRanges the consumer to hand built ranges to
+ */
+ void setRangesConsumer(Consumer setRanges) {
+ this.setRanges = setRanges;
+ }
+
+ /**
+ * Checks whether the aggregator can be optimized.
+ *
+ * This method is supposed to be implemented in a specific aggregator to take in fields from there
+ *
+ * @return {@code true} if the aggregator can be optimized, {@code false} otherwise.
+ * The result will be saved in the optimization context.
+ */
+ protected abstract boolean canOptimize();
+
+ /**
+ * Prepares the optimization at shard level after checking aggregator is optimizable.
+ *
+ * For example, figure out what are the ranges from the aggregation to do the optimization later
+ *
+ * This method is supposed to be implemented in a specific aggregator to take in fields from there
+ *
+ * @throws IOException if preparing the shard-level data fails
+ */
+ protected abstract void prepare() throws IOException;
+
+ /**
+ * Prepares the optimization for a specific segment when the segment is functionally matching all docs
+ *
+ * @param leaf the leaf reader context for the segment
+ * @return the ranges built for the segment; implementations may return {@code null} when no ranges can be built,
+ * in which case the segment is not optimized
+ * @throws IOException if reading the segment fails
+ */
+ abstract Ranges tryBuildRangesFromSegment(LeafReaderContext leaf) throws IOException;
+
+ /**
+ * Attempts to build aggregation results for a segment
+ *
+ * @param values the point values (index structure for numeric values) for a segment
+ * @param incrementDocCount a consumer to increment the document count for a range bucket. As invoked by
+ * {@link DateHistogramAggregatorBridge}, the first argument is the bucket ordinal and the second is the document count to add
+ * @param ranges the ranges to collect document counts for
+ * @return debug information about the point tree traversal
+ * @throws IOException if reading the point values fails
+ */
+ abstract FilterRewriteOptimizationContext.DebugInfo tryOptimize(
+ PointValues values,
+ BiConsumer incrementDocCount,
+ Ranges ranges
+ ) throws IOException;
+}
diff --git a/server/src/main/java/org/opensearch/search/aggregations/bucket/filterrewrite/CompositeAggregatorBridge.java b/server/src/main/java/org/opensearch/search/aggregations/bucket/filterrewrite/CompositeAggregatorBridge.java
new file mode 100644
index 0000000000000..e122e7bda0b6a
--- /dev/null
+++ b/server/src/main/java/org/opensearch/search/aggregations/bucket/filterrewrite/CompositeAggregatorBridge.java
@@ -0,0 +1,36 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * The OpenSearch Contributors require contributions made to
+ * this file be licensed under the Apache-2.0 license or a
+ * compatible open source license.
+ */
+
+package org.opensearch.search.aggregations.bucket.filterrewrite;
+
+import org.opensearch.index.mapper.DateFieldMapper;
+import org.opensearch.index.mapper.MappedFieldType;
+import org.opensearch.search.aggregations.bucket.composite.CompositeValuesSourceConfig;
+import org.opensearch.search.aggregations.bucket.composite.RoundingValuesSource;
+
+/**
+ * For composite aggregation to do optimization when it only has a single date histogram source
+ */
+public abstract class CompositeAggregatorBridge extends DateHistogramAggregatorBridge {
+    /**
+     * Checks whether a composite aggregation is eligible: it must have exactly one source,
+     * that source must be backed by a {@link RoundingValuesSource}, and the source itself
+     * must pass the field-level checks (no missing bucket, no script, searchable date field).
+     *
+     * @param sourceConfigs the sources of the composite aggregation
+     * @return {@code true} when the optimization may be applied
+     */
+    protected boolean canOptimize(CompositeValuesSourceConfig[] sourceConfigs) {
+        if (sourceConfigs.length != 1) {
+            return false;
+        }
+        final CompositeValuesSourceConfig onlySource = sourceConfigs[0];
+        if (!(onlySource.valuesSource() instanceof RoundingValuesSource)) {
+            return false;
+        }
+        return canOptimize(onlySource.missingBucket(), onlySource.hasScript(), onlySource.fieldType());
+    }
+
+    /**
+     * Field-level eligibility check; records the field type on the bridge when it is accepted.
+     */
+    private boolean canOptimize(boolean missing, boolean hasScript, MappedFieldType fieldType) {
+        if (missing || hasScript) {
+            return false;
+        }
+        final boolean searchableDateField = fieldType instanceof DateFieldMapper.DateFieldType && fieldType.isSearchable();
+        if (searchableDateField) {
+            this.fieldType = fieldType;
+        }
+        return searchableDateField;
+    }
+}
diff --git a/server/src/main/java/org/opensearch/search/aggregations/bucket/filterrewrite/DateHistogramAggregatorBridge.java b/server/src/main/java/org/opensearch/search/aggregations/bucket/filterrewrite/DateHistogramAggregatorBridge.java
new file mode 100644
index 0000000000000..8bff3fdc5fefb
--- /dev/null
+++ b/server/src/main/java/org/opensearch/search/aggregations/bucket/filterrewrite/DateHistogramAggregatorBridge.java
@@ -0,0 +1,174 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * The OpenSearch Contributors require contributions made to
+ * this file be licensed under the Apache-2.0 license or a
+ * compatible open source license.
+ */
+
+package org.opensearch.search.aggregations.bucket.filterrewrite;
+
+import org.apache.lucene.document.LongPoint;
+import org.apache.lucene.index.LeafReaderContext;
+import org.apache.lucene.index.PointValues;
+import org.apache.lucene.search.ScoreMode;
+import org.apache.lucene.search.Weight;
+import org.opensearch.common.Rounding;
+import org.opensearch.index.mapper.DateFieldMapper;
+import org.opensearch.index.mapper.MappedFieldType;
+import org.opensearch.search.aggregations.bucket.histogram.LongBounds;
+import org.opensearch.search.aggregations.support.ValuesSourceConfig;
+import org.opensearch.search.internal.SearchContext;
+
+import java.io.IOException;
+import java.util.OptionalLong;
+import java.util.function.BiConsumer;
+import java.util.function.Function;
+
+import static org.opensearch.search.aggregations.bucket.filterrewrite.PointTreeTraversal.multiRangesTraverse;
+
+/**
+ * For date histogram aggregation
+ */
+public abstract class DateHistogramAggregatorBridge extends AggregatorBridge {
+
+ int maxRewriteFilters;
+
+ protected boolean canOptimize(ValuesSourceConfig config) {
+ if (config.script() == null && config.missing() == null) {
+ MappedFieldType fieldType = config.fieldType();
+ if (fieldType instanceof DateFieldMapper.DateFieldType) {
+ if (fieldType.isSearchable()) {
+ this.fieldType = fieldType;
+ return true;
+ }
+ }
+ }
+ return false;
+ }
+
+ protected void buildRanges(SearchContext context) throws IOException {
+ long[] bounds = Helper.getDateHistoAggBounds(context, fieldType.name());
+ this.maxRewriteFilters = context.maxAggRewriteFilters();
+ setRanges.accept(buildRanges(bounds, maxRewriteFilters));
+ }
+
+ @Override
+ final Ranges tryBuildRangesFromSegment(LeafReaderContext leaf) throws IOException {
+ long[] bounds = Helper.getSegmentBounds(leaf, fieldType.name());
+ return buildRanges(bounds, maxRewriteFilters);
+ }
+
+    /**
+     * Turns min/max bounds into bucket ranges: applies the hard bounds, resolves the rounding
+     * interval, lets subclasses adjust for an after key, then delegates to
+     * {@link Helper#createRangesFromAgg}.
+     *
+     * @return the ranges, or {@code null} when the bounds are unavailable or the rounding has no interval
+     */
+    private Ranges buildRanges(long[] bounds, int maxRewriteFilters) {
+        final long[] clamped = processHardBounds(bounds);
+        if (clamped == null) {
+            return null;
+        }
+        assert clamped[0] <= clamped[1] : "Low bound should be less than high bound";
+
+        final OptionalLong intervalOpt = Rounding.getInterval(getRounding(clamped[0], clamped[1]));
+        if (intervalOpt.isEmpty()) {
+            return null;
+        }
+        final long interval = intervalOpt.getAsLong();
+
+        // composite aggregations may narrow the bounds based on their after key
+        final long[] effective = processAfterKey(clamped, interval);
+
+        final DateFieldMapper.DateFieldType dateFieldType = (DateFieldMapper.DateFieldType) fieldType;
+        return Helper.createRangesFromAgg(dateFieldType, interval, getRoundingPrepared(), effective[0], effective[1], maxRewriteFilters);
+    }
+
+ protected abstract Rounding getRounding(final long low, final long high);
+
+ protected abstract Rounding.Prepared getRoundingPrepared();
+
+ protected long[] processAfterKey(long[] bounds, long interval) {
+ return bounds;
+ }
+
+ protected long[] processHardBounds(long[] bounds) {
+ return processHardBounds(bounds, null);
+ }
+
+ protected long[] processHardBounds(long[] bounds, LongBounds hardBounds) {
+ if (bounds != null) {
+ // Update min/max limit if user specified any hard bounds
+ if (hardBounds != null) {
+ if (hardBounds.getMin() > bounds[0]) {
+ bounds[0] = hardBounds.getMin();
+ }
+ if (hardBounds.getMax() - 1 < bounds[1]) {
+ bounds[1] = hardBounds.getMax() - 1; // hard bounds max is exclusive
+ }
+ if (bounds[0] > bounds[1]) {
+ return null;
+ }
+ }
+ }
+ return bounds;
+ }
+
+ private DateFieldMapper.DateFieldType getFieldType() {
+ assert fieldType instanceof DateFieldMapper.DateFieldType;
+ return (DateFieldMapper.DateFieldType) fieldType;
+ }
+
+ protected int getSize() {
+ return Integer.MAX_VALUE;
+ }
+
+    /**
+     * Traverses the segment's point tree over the given ranges and reports each range's
+     * document count against the bucket ordinal derived from the range's lower bound.
+     */
+    @Override
+    final FilterRewriteOptimizationContext.DebugInfo tryOptimize(
+        PointValues values,
+        BiConsumer incrementDocCount,
+        Ranges ranges
+    ) throws IOException {
+        final int maxBuckets = getSize();
+        final DateFieldMapper.DateFieldType dateFieldType = getFieldType();
+
+        BiConsumer incrementFunc = (activeIndex, docCount) -> {
+            final long lowerBound = LongPoint.decodeDimension(ranges.lowers[activeIndex], 0);
+            final long lowerBoundMillis = dateFieldType.convertNanosToMillis(lowerBound);
+            final long ordinal = getBucketOrd(bucketOrdProducer().apply(lowerBoundMillis));
+            incrementDocCount.accept(ordinal, (long) docCount);
+        };
+
+        return multiRangesTraverse(values.getPointTree(), ranges, incrementFunc, maxBuckets);
+    }
+
+ private static long getBucketOrd(long bucketOrd) {
+ if (bucketOrd < 0) { // already seen
+ bucketOrd = -1 - bucketOrd;
+ }
+
+ return bucketOrd;
+ }
+
+ /**
+ * Provides a function to produce bucket ordinals from the lower bound of the range
+ */
+ protected abstract Function bucketOrdProducer();
+
+    /**
+     * Tells whether the top level query matches every document of the segment.
+     *
+     * A weight is created from the rewritten query of the search context, and its count for the
+     * segment is compared with the number of documents in the leaf reader.
+     *
+     * @param ctx the search context
+     * @param leafCtx the leaf reader context for the segment
+     * @return {@code true} if the segment matches all documents, {@code false} otherwise
+     * @throws IOException if rewriting the query or counting matches fails
+     */
+    public static boolean segmentMatchAll(SearchContext ctx, LeafReaderContext leafCtx) throws IOException {
+        Weight weight = ctx.query().rewrite(ctx.searcher()).createWeight(ctx.searcher(), ScoreMode.COMPLETE_NO_SCORES, 1f);
+        if (weight == null) {
+            return false;
+        }
+        final int matchedDocs = weight.count(leafCtx);
+        return matchedDocs == leafCtx.reader().numDocs();
+    }
+}
diff --git a/server/src/main/java/org/opensearch/search/aggregations/bucket/filterrewrite/FilterRewriteOptimizationContext.java b/server/src/main/java/org/opensearch/search/aggregations/bucket/filterrewrite/FilterRewriteOptimizationContext.java
new file mode 100644
index 0000000000000..87faafe4526de
--- /dev/null
+++ b/server/src/main/java/org/opensearch/search/aggregations/bucket/filterrewrite/FilterRewriteOptimizationContext.java
@@ -0,0 +1,189 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * The OpenSearch Contributors require contributions made to
+ * this file be licensed under the Apache-2.0 license or a
+ * compatible open source license.
+ */
+
+package org.opensearch.search.aggregations.bucket.filterrewrite;
+
+import org.apache.logging.log4j.LogManager;
+import org.apache.logging.log4j.Logger;
+import org.apache.lucene.index.DocValues;
+import org.apache.lucene.index.LeafReaderContext;
+import org.apache.lucene.index.NumericDocValues;
+import org.apache.lucene.index.PointValues;
+import org.opensearch.index.mapper.DocCountFieldMapper;
+import org.opensearch.search.internal.SearchContext;
+
+import java.io.IOException;
+import java.util.concurrent.atomic.AtomicInteger;
+import java.util.function.BiConsumer;
+
+import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
+
+/**
+ * Context object for doing the filter rewrite optimization in ranges type aggregation
+ *
+ * This holds the common business logic and delegate aggregator-specific logic to {@link AggregatorBridge}
+ *
+ * @opensearch.internal
+ */
+public final class FilterRewriteOptimizationContext {
+
+ private static final Logger logger = LogManager.getLogger(Helper.loggerName);
+
+ private final boolean canOptimize;
+ private boolean preparedAtShardLevel = false;
+
+ private final AggregatorBridge aggregatorBridge;
+ private String shardId;
+
+ private Ranges ranges; // built at shard level
+
+ // debug info related fields
+ private final AtomicInteger leafNodeVisited = new AtomicInteger();
+ private final AtomicInteger innerNodeVisited = new AtomicInteger();
+ private final AtomicInteger segments = new AtomicInteger();
+ private final AtomicInteger optimizedSegments = new AtomicInteger();
+
+ public FilterRewriteOptimizationContext(
+ AggregatorBridge aggregatorBridge,
+ final Object parent,
+ final int subAggLength,
+ SearchContext context
+ ) throws IOException {
+ this.aggregatorBridge = aggregatorBridge;
+ this.canOptimize = this.canOptimize(parent, subAggLength, context);
+ }
+
+ /**
+ * common logic for checking whether the optimization can be applied and prepare at shard level
+ * if the aggregation has any special logic, it should be done using {@link AggregatorBridge}
+ */
+ private boolean canOptimize(final Object parent, final int subAggLength, SearchContext context) throws IOException {
+ if (context.maxAggRewriteFilters() == 0) return false;
+
+ if (parent != null || subAggLength != 0) return false;
+
+ boolean canOptimize = aggregatorBridge.canOptimize();
+ if (canOptimize) {
+ aggregatorBridge.setRangesConsumer(this::setRanges);
+
+ this.shardId = context.indexShard().shardId().toString();
+
+ assert ranges == null : "Ranges should only be built once at shard level, but they are already built";
+ aggregatorBridge.prepare();
+ if (ranges != null) {
+ preparedAtShardLevel = true;
+ }
+ }
+ logger.debug("Fast filter rewriteable: {} for shard {}", canOptimize, shardId);
+
+ return canOptimize;
+ }
+
+ void setRanges(Ranges ranges) {
+ this.ranges = ranges;
+ }
+
+    /**
+     * Tries to fill in the bucket doc counts of the aggregation for one segment.
+     *
+     * Usage: invoked at segment level, in getLeafCollector of the aggregator
+     *
+     * @param leafCtx the segment to process
+     * @param incrementDocCount consumes the doc_count results for certain ordinal
+     * @param segmentMatchAll whether the segment is functionally match-all; used to decide whether ranges
+     *                        may be built from the segment when they were not prepared at shard level
+     * @return {@code true} when the segment was handled by the optimization
+     */
+    public boolean tryOptimize(final LeafReaderContext leafCtx, final BiConsumer incrementDocCount, boolean segmentMatchAll)
+        throws IOException {
+        segments.incrementAndGet();
+        if (!canOptimize || leafCtx.reader().hasDeletions()) {
+            return false;
+        }
+
+        final PointValues pointValues = leafCtx.reader().getPointValues(aggregatorBridge.fieldType.name());
+        // only proceed if every document corresponds to exactly one point
+        if (pointValues == null || pointValues.getDocCount() != pointValues.size()) {
+            return false;
+        }
+
+        final NumericDocValues docCountValues = DocValues.getNumeric(leafCtx.reader(), DocCountFieldMapper.NAME);
+        final boolean hasDocCountField = docCountValues.nextDoc() != NO_MORE_DOCS;
+        if (hasDocCountField) {
+            logger.debug(
+                "Shard {} segment {} has at least one document with _doc_count field, skip fast filter optimization",
+                shardId,
+                leafCtx.ord
+            );
+            return false;
+        }
+
+        final Ranges effectiveRanges = getRanges(leafCtx, segmentMatchAll);
+        if (effectiveRanges == null) {
+            return false;
+        }
+
+        final DebugInfo traversalInfo = aggregatorBridge.tryOptimize(pointValues, incrementDocCount, effectiveRanges);
+        consumeDebugInfo(traversalInfo);
+
+        optimizedSegments.incrementAndGet();
+        logger.debug("Fast filter optimization applied to shard {} segment {}", shardId, leafCtx.ord);
+        logger.debug("Crossed leaf nodes: {}, inner nodes: {}", leafNodeVisited, innerNodeVisited);
+
+        return true;
+    }
+
+    /**
+     * Returns the shard-level ranges when they were prepared, otherwise tries to build ranges
+     * for this segment. A failure while building segment ranges is logged and reported as {@code null}.
+     */
+    Ranges getRanges(LeafReaderContext leafCtx, boolean segmentMatchAll) {
+        if (preparedAtShardLevel) {
+            return ranges;
+        }
+        try {
+            return getRangesFromSegment(leafCtx, segmentMatchAll);
+        } catch (IOException e) {
+            logger.warn("Failed to build ranges from segment.", e);
+            return null;
+        }
+    }
+
+ /**
+ * Even when ranges cannot be built at shard level, we can still build ranges
+ * at segment level when it's functionally match-all at segment level
+ */
+ private Ranges getRangesFromSegment(LeafReaderContext leafCtx, boolean segmentMatchAll) throws IOException {
+ if (!segmentMatchAll) {
+ return null;
+ }
+
+ logger.debug("Shard {} segment {} functionally match all documents. Build the fast filter", shardId, leafCtx.ord);
+ return aggregatorBridge.tryBuildRangesFromSegment(leafCtx);
+ }
+
+ /**
+ * Contains debug info of BKD traversal to show in profile
+ */
+ static class DebugInfo {
+ private final AtomicInteger leafNodeVisited = new AtomicInteger(); // leaf node visited
+ private final AtomicInteger innerNodeVisited = new AtomicInteger(); // inner node visited
+
+ void visitLeaf() {
+ leafNodeVisited.incrementAndGet();
+ }
+
+ void visitInner() {
+ innerNodeVisited.incrementAndGet();
+ }
+ }
+
+ void consumeDebugInfo(DebugInfo debug) {
+ leafNodeVisited.addAndGet(debug.leafNodeVisited.get());
+ innerNodeVisited.addAndGet(debug.innerNodeVisited.get());
+ }
+
+    /**
+     * Reports segment and node-visit counters to the given sink; nothing is reported
+     * when no segment was optimized.
+     */
+    public void populateDebugInfo(BiConsumer add) {
+        final int optimized = optimizedSegments.get();
+        if (optimized <= 0) {
+            return;
+        }
+        add.accept("optimized_segments", optimized);
+        add.accept("unoptimized_segments", segments.get() - optimizedSegments.get());
+        add.accept("leaf_visited", leafNodeVisited.get());
+        add.accept("inner_visited", innerNodeVisited.get());
+    }
+}
diff --git a/server/src/main/java/org/opensearch/search/aggregations/bucket/filterrewrite/Helper.java b/server/src/main/java/org/opensearch/search/aggregations/bucket/filterrewrite/Helper.java
new file mode 100644
index 0000000000000..7493754d8efa2
--- /dev/null
+++ b/server/src/main/java/org/opensearch/search/aggregations/bucket/filterrewrite/Helper.java
@@ -0,0 +1,213 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * The OpenSearch Contributors require contributions made to
+ * this file be licensed under the Apache-2.0 license or a
+ * compatible open source license.
+ */
+
+package org.opensearch.search.aggregations.bucket.filterrewrite;
+
+import org.apache.logging.log4j.LogManager;
+import org.apache.logging.log4j.Logger;
+import org.apache.lucene.index.LeafReaderContext;
+import org.apache.lucene.index.PointValues;
+import org.apache.lucene.search.ConstantScoreQuery;
+import org.apache.lucene.search.FieldExistsQuery;
+import org.apache.lucene.search.IndexOrDocValuesQuery;
+import org.apache.lucene.search.MatchAllDocsQuery;
+import org.apache.lucene.search.PointRangeQuery;
+import org.apache.lucene.search.Query;
+import org.apache.lucene.util.NumericUtils;
+import org.opensearch.common.Rounding;
+import org.opensearch.common.lucene.search.function.FunctionScoreQuery;
+import org.opensearch.index.mapper.DateFieldMapper;
+import org.opensearch.index.query.DateRangeIncludingNowQuery;
+import org.opensearch.search.internal.SearchContext;
+
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.function.Function;
+
+import static org.opensearch.index.mapper.NumberFieldMapper.NumberType.LONG;
+
+/**
+ * Utility class to help range filters rewrite optimization
+ *
+ * @opensearch.internal
+ */
+final class Helper {
+
+ private Helper() {}
+
+ static final String loggerName = Helper.class.getPackageName();
+ private static final Logger logger = LogManager.getLogger(loggerName);
+
+ private static final Map, Function> queryWrappers;
+
+ // Initialize the wrapper map for unwrapping the query
+ static {
+ queryWrappers = new HashMap<>();
+ queryWrappers.put(ConstantScoreQuery.class, q -> ((ConstantScoreQuery) q).getQuery());
+ queryWrappers.put(FunctionScoreQuery.class, q -> ((FunctionScoreQuery) q).getSubQuery());
+ queryWrappers.put(DateRangeIncludingNowQuery.class, q -> ((DateRangeIncludingNowQuery) q).getQuery());
+ queryWrappers.put(IndexOrDocValuesQuery.class, q -> ((IndexOrDocValuesQuery) q).getIndexQuery());
+ }
+
+    /**
+     * Peels known wrapper queries off the given query until a query of an unregistered
+     * class is reached, so the optimization can inspect the concrete query.
+     */
+    private static Query unwrapIntoConcreteQuery(Query query) {
+        Query current = query;
+        Function<Query, Query> unwrapper = queryWrappers.get(current.getClass());
+        while (unwrapper != null) {
+            current = unwrapper.apply(current);
+            unwrapper = queryWrappers.get(current.getClass());
+        }
+        return current;
+    }
+
+    /**
+     * Computes the min and max values of the field over all segments of the shard.
+     *
+     * @return null if the field is empty or not indexed
+     */
+    private static long[] getShardBounds(final List leaves, final String fieldName) throws IOException {
+        long shardMin = Long.MAX_VALUE;
+        long shardMax = Long.MIN_VALUE;
+        for (LeafReaderContext leaf : leaves) {
+            final PointValues values = leaf.reader().getPointValues(fieldName);
+            if (values == null) {
+                continue; // this segment has no points for the field
+            }
+            final long segmentMin = NumericUtils.sortableBytesToLong(values.getMinPackedValue(), 0);
+            final long segmentMax = NumericUtils.sortableBytesToLong(values.getMaxPackedValue(), 0);
+            shardMin = Math.min(shardMin, segmentMin);
+            shardMax = Math.max(shardMax, segmentMax);
+        }
+
+        // sentinels still in place means no segment contributed a value
+        final boolean untouched = shardMin == Long.MAX_VALUE || shardMax == Long.MIN_VALUE;
+        return untouched ? null : new long[] { shardMin, shardMax };
+    }
+
+    /**
+     * Computes the min and max values of the field within a single segment.
+     *
+     * @return null if the field is empty or not indexed
+     */
+    static long[] getSegmentBounds(final LeafReaderContext context, final String fieldName) throws IOException {
+        final PointValues values = context.reader().getPointValues(fieldName);
+        if (values == null) {
+            return null;
+        }
+        final long segmentMin = NumericUtils.sortableBytesToLong(values.getMinPackedValue(), 0);
+        final long segmentMax = NumericUtils.sortableBytesToLong(values.getMaxPackedValue(), 0);
+        // the sentinel values are treated as "no value", same as when the field has no points
+        if (segmentMin == Long.MAX_VALUE || segmentMax == Long.MIN_VALUE) {
+            return null;
+        }
+        return new long[] { segmentMin, segmentMax };
+    }
+
+    /**
+     * Computes the min and max bounds of the field for the shard search.
+     * How the bounds are derived depends on the unwrapped top level query:
+     * a point range query intersects its range with the shard bounds, while a match-all query
+     * or an exists query on the same field uses the shard bounds directly.
+     *
+     * @return null if the processed query not supported by the optimization
+     */
+    public static long[] getDateHistoAggBounds(final SearchContext context, final String fieldName) throws IOException {
+        final Query cq = unwrapIntoConcreteQuery(context.query());
+        final List leaves = context.searcher().getIndexReader().leaves();
+
+        if (cq instanceof PointRangeQuery) {
+            final long[] indexBounds = getShardBounds(leaves, fieldName);
+            return indexBounds == null ? null : getBoundsWithRangeQuery((PointRangeQuery) cq, fieldName, indexBounds);
+        }
+        if (cq instanceof MatchAllDocsQuery) {
+            return getShardBounds(leaves, fieldName);
+        }
+        // when a range query covers all values of a shard, it will be rewrite field exists query
+        final boolean existsOnSameField = cq instanceof FieldExistsQuery && ((FieldExistsQuery) cq).getField().equals(fieldName);
+        return existsOnSameField ? getShardBounds(leaves, fieldName) : null;
+    }
+
+ private static long[] getBoundsWithRangeQuery(PointRangeQuery prq, String fieldName, long[] indexBounds) {
+ // Ensure that the query and aggregation are on the same field
+ if (prq.getField().equals(fieldName)) {
+ // Minimum bound for aggregation is the max between query and global
+ long lower = Math.max(NumericUtils.sortableBytesToLong(prq.getLowerPoint(), 0), indexBounds[0]);
+ // Maximum bound for aggregation is the min between query and global
+ long upper = Math.min(NumericUtils.sortableBytesToLong(prq.getUpperPoint(), 0), indexBounds[1]);
+ if (lower > upper) {
+ return null;
+ }
+ return new long[] { lower, upper };
+ }
+
+ return null;
+ }
+
+ /**
+ * Creates the date ranges from date histo aggregations using its interval,
+ * and min/max boundaries
+ */
+ static Ranges createRangesFromAgg(
+ final DateFieldMapper.DateFieldType fieldType,
+ final long interval,
+ final Rounding.Prepared preparedRounding,
+ long low,
+ final long high,
+ final int maxAggRewriteFilters
+ ) {
+ // Calculate the number of buckets using range and interval
+ long roundedLow = preparedRounding.round(fieldType.convertNanosToMillis(low));
+ long prevRounded = roundedLow;
+ int bucketCount = 0;
+ while (roundedLow <= fieldType.convertNanosToMillis(high)) {
+ bucketCount++;
+ if (bucketCount > maxAggRewriteFilters) {
+ logger.debug("Max number of range filters reached [{}], skip the optimization", maxAggRewriteFilters);
+ return null;
+ }
+ // Below rounding is needed as the interval could return in
+ // non-rounded values for something like calendar month
+ roundedLow = preparedRounding.round(roundedLow + interval);
+ if (prevRounded == roundedLow) break; // prevents getting into an infinite loop
+ prevRounded = roundedLow;
+ }
+
+ long[][] ranges = new long[bucketCount][2];
+ if (bucketCount > 0) {
+ roundedLow = preparedRounding.round(fieldType.convertNanosToMillis(low));
+
+ int i = 0;
+ while (i < bucketCount) {
+ // Calculate the lower bucket bound
+ long lower = i == 0 ? low : fieldType.convertRoundedMillisToNanos(roundedLow);
+ roundedLow = preparedRounding.round(roundedLow + interval);
+
+ // plus one on high value because upper bound is exclusive, but high value exists
+ long upper = i + 1 == bucketCount ? high + 1 : fieldType.convertRoundedMillisToNanos(roundedLow);
+
+ ranges[i][0] = lower;
+ ranges[i][1] = upper;
+ i++;
+ }
+ }
+
+ byte[][] lowers = new byte[ranges.length][];
+ byte[][] uppers = new byte[ranges.length][];
+ for (int i = 0; i < ranges.length; i++) {
+ byte[] lower = LONG.encodePoint(ranges[i][0]);
+ byte[] max = LONG.encodePoint(ranges[i][1]);
+ lowers[i] = lower;
+ uppers[i] = max;
+ }
+
+ return new Ranges(lowers, uppers);
+ }
+}
diff --git a/server/src/main/java/org/opensearch/search/aggregations/bucket/filterrewrite/PointTreeTraversal.java b/server/src/main/java/org/opensearch/search/aggregations/bucket/filterrewrite/PointTreeTraversal.java
new file mode 100644
index 0000000000000..581ecc416f486
--- /dev/null
+++ b/server/src/main/java/org/opensearch/search/aggregations/bucket/filterrewrite/PointTreeTraversal.java
@@ -0,0 +1,223 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * The OpenSearch Contributors require contributions made to
+ * this file be licensed under the Apache-2.0 license or a
+ * compatible open source license.
+ */
+
+package org.opensearch.search.aggregations.bucket.filterrewrite;
+
+import org.apache.logging.log4j.LogManager;
+import org.apache.logging.log4j.Logger;
+import org.apache.lucene.index.PointValues;
+import org.apache.lucene.search.CollectionTerminatedException;
+import org.apache.lucene.search.DocIdSetIterator;
+import org.opensearch.common.CheckedRunnable;
+
+import java.io.IOException;
+import java.util.function.BiConsumer;
+
+import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
+
+/**
+ * Utility class for traversing a {@link PointValues.PointTree} and collecting document counts for the ranges.
+ *
+ * The main entry point is the {@link #multiRangesTraverse(PointValues.PointTree, Ranges,
+ * BiConsumer, int)} method
+ *
+ * <p>The class uses a {@link RangeCollectorForPointTree} to keep track of the active ranges and
+ * determine which parts of the tree to visit. The {@link
+ * PointValues.IntersectVisitor} implementation is responsible for the actual visitation and
+ * document count collection.
+ */
+final class PointTreeTraversal {
+ private PointTreeTraversal() {}
+
+ private static final Logger logger = LogManager.getLogger(Helper.loggerName);
+
+ /**
+ * Traverses the given {@link PointValues.PointTree} and collects document counts for the intersecting ranges.
+ *
+ * @param tree the point tree to traverse
+ * @param ranges the set of ranges to intersect with
+ * @param incrementDocCount a callback to increment the document count for a range bucket
+ * @param maxNumNonZeroRanges the maximum number of non-zero ranges to collect
+ * @return a {@link FilterRewriteOptimizationContext.DebugInfo} object containing debug information about the traversal
+ */
+ static FilterRewriteOptimizationContext.DebugInfo multiRangesTraverse(
+ final PointValues.PointTree tree,
+ final Ranges ranges,
+ final BiConsumer incrementDocCount,
+ final int maxNumNonZeroRanges
+ ) throws IOException {
+ FilterRewriteOptimizationContext.DebugInfo debugInfo = new FilterRewriteOptimizationContext.DebugInfo();
+ int activeIndex = ranges.firstRangeIndex(tree.getMinPackedValue(), tree.getMaxPackedValue());
+ if (activeIndex < 0) {
+ logger.debug("No ranges match the query, skip the fast filter optimization");
+ return debugInfo;
+ }
+ RangeCollectorForPointTree collector = new RangeCollectorForPointTree(incrementDocCount, maxNumNonZeroRanges, ranges, activeIndex);
+ PointValues.IntersectVisitor visitor = getIntersectVisitor(collector);
+ try {
+ intersectWithRanges(visitor, tree, collector, debugInfo);
+ } catch (CollectionTerminatedException e) {
+ logger.debug("Early terminate since no more range to collect");
+ }
+ collector.finalizePreviousRange();
+
+ return debugInfo;
+ }
+
+ private static void intersectWithRanges(
+ PointValues.IntersectVisitor visitor,
+ PointValues.PointTree pointTree,
+ RangeCollectorForPointTree collector,
+ FilterRewriteOptimizationContext.DebugInfo debug
+ ) throws IOException {
+ PointValues.Relation r = visitor.compare(pointTree.getMinPackedValue(), pointTree.getMaxPackedValue());
+
+ switch (r) {
+ case CELL_INSIDE_QUERY:
+ collector.countNode((int) pointTree.size());
+ debug.visitInner();
+ break;
+ case CELL_CROSSES_QUERY:
+ if (pointTree.moveToChild()) {
+ do {
+ intersectWithRanges(visitor, pointTree, collector, debug);
+ } while (pointTree.moveToSibling());
+ pointTree.moveToParent();
+ } else {
+ pointTree.visitDocValues(visitor);
+ debug.visitLeaf();
+ }
+ break;
+ case CELL_OUTSIDE_QUERY:
+ }
+ }
+
+ /**
+ * Creates the visitor used while traversing the point tree.
+ *
+ * The visitor is stateful through the given collector: both {@code compare} and the value-level
+ * {@code visit} methods may flush the count of the active range and move the collector on to a later range,
+ * and throw {@link CollectionTerminatedException} once the collector reports there is nothing more to collect.
+ *
+ * @param collector the collector tracking the active range and its document count
+ * @return the visitor bound to the collector
+ */
+ private static PointValues.IntersectVisitor getIntersectVisitor(RangeCollectorForPointTree collector) {
+ return new PointValues.IntersectVisitor() {
+ @Override
+ public void visit(int docID) {
+ // this branch should be unreachable
+ // (nodes fully inside a range are counted via countNode in intersectWithRanges, not visited per doc)
+ throw new UnsupportedOperationException(
+ "This IntersectVisitor does not perform any actions on a " + "docID=" + docID + " node being visited"
+ );
+ }
+
+ @Override
+ public void visit(int docID, byte[] packedValue) throws IOException {
+ // a single document carrying this value: count it once if the value falls in the active range
+ visitPoints(packedValue, collector::count);
+ }
+
+ @Override
+ public void visit(DocIdSetIterator iterator, byte[] packedValue) throws IOException {
+ // several documents share this value: count one per document in the iterator
+ visitPoints(packedValue, () -> {
+ for (int doc = iterator.nextDoc(); doc != NO_MORE_DOCS; doc = iterator.nextDoc()) {
+ collector.count();
+ }
+ });
+ }
+
+ // Moves the collector to the range that may contain the value (flushing the previous range's count first),
+ // then runs the given counting action only when the value falls inside that range.
+ private void visitPoints(byte[] packedValue, CheckedRunnable collect) throws IOException {
+ if (!collector.withinUpperBound(packedValue)) {
+ collector.finalizePreviousRange();
+ if (collector.iterateRangeEnd(packedValue)) {
+ throw new CollectionTerminatedException();
+ }
+ }
+
+ if (collector.withinRange(packedValue)) {
+ collect.run();
+ }
+ }
+
+ @Override
+ public PointValues.Relation compare(byte[] minPackedValue, byte[] maxPackedValue) {
+ // try to find the first range that may collect values from this cell
+ if (!collector.withinUpperBound(minPackedValue)) {
+ collector.finalizePreviousRange();
+ if (collector.iterateRangeEnd(minPackedValue)) {
+ throw new CollectionTerminatedException();
+ }
+ }
+ // at this point the cell's min value is within the upper bound of the active range
+ // the cell may still lie entirely below the lower bound of the active range
+ if (!collector.withinLowerBound(maxPackedValue)) {
+ return PointValues.Relation.CELL_OUTSIDE_QUERY;
+ }
+ // both ends of the cell inside the active range means every value of the cell is inside it
+ if (collector.withinRange(minPackedValue) && collector.withinRange(maxPackedValue)) {
+ return PointValues.Relation.CELL_INSIDE_QUERY;
+ }
+ return PointValues.Relation.CELL_CROSSES_QUERY;
+ }
+ };
+ }
+
+ /**
+ * Tracks the range currently being collected (the "active" range) and the number of documents counted for it.
+ * The count is reported through the callback when the collector is asked to finalize the range.
+ */
+ private static class RangeCollectorForPointTree {
+ // callback invoked with (index of the active range, documents counted for it)
+ private final BiConsumer incrementRangeDocCount;
+ // documents counted for the active range since the last finalize
+ private int counter = 0;
+
+ private final Ranges ranges;
+ // index into ranges of the range currently being collected
+ private int activeIndex;
+
+ // number of times the collector has moved on to a later range
+ private int visitedRange = 0;
+ // limit on visitedRange; once exceeded, iterateRangeEnd reports that collection should stop
+ private final int maxNumNonZeroRange;
+
+ public RangeCollectorForPointTree(
+ BiConsumer incrementRangeDocCount,
+ int maxNumNonZeroRange,
+ Ranges ranges,
+ int activeIndex
+ ) {
+ this.incrementRangeDocCount = incrementRangeDocCount;
+ this.maxNumNonZeroRange = maxNumNonZeroRange;
+ this.ranges = ranges;
+ this.activeIndex = activeIndex;
+ }
+
+ // counts one document for the active range
+ private void count() {
+ counter++;
+ }
+
+ // counts a whole node's worth of documents for the active range
+ private void countNode(int count) {
+ counter += count;
+ }
+
+ // reports the accumulated count of the active range (if any) and resets the counter
+ private void finalizePreviousRange() {
+ if (counter > 0) {
+ incrementRangeDocCount.accept(activeIndex, counter);
+ counter = 0;
+ }
+ }
+
+ /**
+ * Advances the active range until the given value is within its upper bound.
+ *
+ * @return true when iterator exhausted or collect enough non-zero ranges
+ */
+ private boolean iterateRangeEnd(byte[] value) {
+ // the new value may not be contiguous to the previous one
+ // so try to find the first next range that cross the new value
+ while (!withinUpperBound(value)) {
+ if (++activeIndex >= ranges.size) {
+ return true;
+ }
+ }
+ visitedRange++;
+ return visitedRange > maxNumNonZeroRange;
+ }
+
+ // bound checks against the active range; exact inclusive/exclusive semantics are defined by Ranges
+ private boolean withinLowerBound(byte[] value) {
+ return Ranges.withinLowerBound(value, ranges.lowers[activeIndex]);
+ }
+
+ private boolean withinUpperBound(byte[] value) {
+ return Ranges.withinUpperBound(value, ranges.uppers[activeIndex]);
+ }
+
+ private boolean withinRange(byte[] value) {
+ return withinLowerBound(value) && withinUpperBound(value);
+ }
+ }
+}
diff --git a/server/src/main/java/org/opensearch/search/aggregations/bucket/filterrewrite/RangeAggregatorBridge.java b/server/src/main/java/org/opensearch/search/aggregations/bucket/filterrewrite/RangeAggregatorBridge.java
new file mode 100644
index 0000000000000..b590a444c8b04
--- /dev/null
+++ b/server/src/main/java/org/opensearch/search/aggregations/bucket/filterrewrite/RangeAggregatorBridge.java
@@ -0,0 +1,96 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * The OpenSearch Contributors require contributions made to
+ * this file be licensed under the Apache-2.0 license or a
+ * compatible open source license.
+ */
+
+package org.opensearch.search.aggregations.bucket.filterrewrite;
+
+import org.apache.lucene.index.LeafReaderContext;
+import org.apache.lucene.index.PointValues;
+import org.opensearch.index.mapper.MappedFieldType;
+import org.opensearch.index.mapper.NumericPointEncoder;
+import org.opensearch.search.aggregations.bucket.range.RangeAggregator;
+import org.opensearch.search.aggregations.support.ValuesSource;
+import org.opensearch.search.aggregations.support.ValuesSourceConfig;
+
+import java.io.IOException;
+import java.util.function.BiConsumer;
+import java.util.function.Function;
+
+import static org.opensearch.search.aggregations.bucket.filterrewrite.PointTreeTraversal.multiRangesTraverse;
+
+/**
+ * Bridge that lets a range aggregation take part in the filter rewrite optimization.
+ *
+ * It decides whether the aggregation qualifies, encodes the requested ranges once, and counts values per range
+ * by traversing the point tree handed to {@link #tryOptimize}.
+ */
+public abstract class RangeAggregatorBridge extends AggregatorBridge {
+
+    /**
+     * Checks whether the aggregation can be served by the optimization.
+     *
+     * All of the following must hold: the field is mapped, searchable and a {@link NumericPointEncoder};
+     * neither a script nor a missing value is configured; the values source is numeric field data; and no
+     * range starts before the previous one ends. On success the field type is remembered in
+     * {@code this.fieldType}.
+     *
+     * @param config values source configuration of the aggregation
+     * @param ranges requested ranges, expected to be sorted by from and then to
+     * @return true if every condition above holds
+     */
+    protected boolean canOptimize(ValuesSourceConfig config, RangeAggregator.Range[] ranges) {
+        final MappedFieldType mappedType = config.fieldType();
+        if (mappedType == null) {
+            return false;
+        }
+        if (!mappedType.isSearchable() || !(mappedType instanceof NumericPointEncoder)) {
+            return false;
+        }
+        if (config.script() != null || config.missing() != null) {
+            return false;
+        }
+        if (!(config.getValuesSource() instanceof ValuesSource.Numeric.FieldData)) {
+            return false;
+        }
+
+        // ranges are already sorted by from and then to, so comparing each range with its
+        // predecessor is enough to find an overlap
+        double previousTo = ranges[0].getTo();
+        for (int i = 1; i < ranges.length; i++) {
+            final RangeAggregator.Range current = ranges[i];
+            if (current.getFrom() < previousTo) {
+                return false;
+            }
+            previousTo = current.getTo();
+        }
+
+        this.fieldType = mappedType;
+        return true;
+    }
+
+    /**
+     * Encodes the from/to of every range with the field's point encoder and publishes the result through
+     * {@code setRanges}.
+     *
+     * @param ranges requested ranges; index {@code i} of the encoded bounds corresponds to {@code ranges[i]}
+     */
+    protected void buildRanges(RangeAggregator.Range[] ranges) {
+        assert fieldType instanceof NumericPointEncoder;
+        final NumericPointEncoder encoder = (NumericPointEncoder) fieldType;
+        final int count = ranges.length;
+        final byte[][] lowers = new byte[count][];
+        final byte[][] uppers = new byte[count][];
+
+        int i = 0;
+        for (RangeAggregator.Range range : ranges) {
+            final double from = range.getFrom();
+            final double to = range.getTo();
+            lowers[i] = encoder.encodePoint(from);
+            uppers[i] = encoder.encodePoint(to);
+            i++;
+        }
+
+        setRanges.accept(new Ranges(lowers, uppers));
+    }
+
+    /**
+     * Not supported here.
+     *
+     * @throws UnsupportedOperationException always
+     */
+    @Override
+    final Ranges tryBuildRangesFromSegment(LeafReaderContext leaf) {
+        throw new UnsupportedOperationException("Range aggregation should not build ranges at segment level");
+    }
+
+    /**
+     * Counts values per range by traversing the point tree of {@code values}.
+     *
+     * @param values            point values to traverse
+     * @param incrementDocCount receives (bucket ordinal, count) for each reported range
+     * @param ranges            the encoded ranges
+     * @return the debug info produced by the traversal
+     */
+    @Override
+    final FilterRewriteOptimizationContext.DebugInfo tryOptimize(PointValues values, BiConsumer incrementDocCount, Ranges ranges)
+        throws IOException {
+        // the traversal reports (range index, count); translate the index into a bucket ordinal first
+        BiConsumer incrementFunc = (activeIndex, docCount) -> {
+            long ordinal = bucketOrdProducer().apply(activeIndex);
+            incrementDocCount.accept(ordinal, (long) docCount);
+        };
+
+        // NOTE(review): the last argument looks like a cap on the number of ranges collected - confirm
+        // against multiRangesTraverse; MAX_VALUE means no practical limit
+        return multiRangesTraverse(values.getPointTree(), ranges, incrementFunc, Integer.MAX_VALUE);
+    }
+
+    /**
+     * Provides a function to produce bucket ordinals from index of the corresponding range in the range array
+     */
+    protected abstract Function bucketOrdProducer();
+}
diff --git a/server/src/main/java/org/opensearch/search/aggregations/bucket/filterrewrite/Ranges.java b/server/src/main/java/org/opensearch/search/aggregations/bucket/filterrewrite/Ranges.java
new file mode 100644
index 0000000000000..2819778ce215b
--- /dev/null
+++ b/server/src/main/java/org/opensearch/search/aggregations/bucket/filterrewrite/Ranges.java
@@ -0,0 +1,57 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * The OpenSearch Contributors require contributions made to
+ * this file be licensed under the Apache-2.0 license or a
+ * compatible open source license.
+ */
+
+package org.opensearch.search.aggregations.bucket.filterrewrite;
+
+import org.apache.lucene.util.ArrayUtil;
+
+/**
+ * Internal ranges representation for the filter rewrite optimization.
+ *
+ * Range {@code i} spans {@code lowers[i]} (inclusive) to {@code uppers[i]} (exclusive). Bounds are encoded
+ * values that are compared byte by byte as unsigned numbers.
+ */
+final class Ranges {
+    byte[][] lowers; // inclusive
+    byte[][] uppers; // exclusive
+    int size;
+    int byteLen;
+    /**
+     * Comparator matching the width of the most recently constructed instance.
+     *
+     * This is shared mutable static state: every constructor call overwrites it. The comparison helpers in
+     * this class therefore do not read it; it is still assigned in case other code in the package does.
+     */
+    static ArrayUtil.ByteArrayComparator comparator;
+
+    /**
+     * @param lowers inclusive lower bounds, one per range
+     * @param uppers exclusive upper bounds, one per range; must have the same length as {@code lowers}
+     */
+    Ranges(byte[][] lowers, byte[][] uppers) {
+        this.lowers = lowers;
+        this.uppers = uppers;
+        assert lowers.length == uppers.length;
+        this.size = lowers.length;
+        // an empty set of ranges has no bound to take the encoded width from
+        this.byteLen = size == 0 ? 0 : lowers[0].length;
+        comparator = ArrayUtil.getUnsignedComparator(byteLen);
+    }
+
+    /**
+     * Finds the first range that can contain a value between {@code globalMin} and {@code globalMax}.
+     *
+     * NOTE(review): this assumes the ranges are in ascending order - confirm against the callers.
+     *
+     * @param globalMin smallest encoded value present
+     * @param globalMax largest encoded value present
+     * @return index of the first range whose upper bound is above {@code globalMin}, or -1 when there are no
+     *         ranges, the first range starts above {@code globalMax}, or every range ends at or below
+     *         {@code globalMin}
+     */
+    public int firstRangeIndex(byte[] globalMin, byte[] globalMax) {
+        if (size == 0 || compareByteValue(lowers[0], globalMax) > 0) {
+            return -1;
+        }
+        for (int i = 0; i < size; i++) {
+            if (compareByteValue(uppers[i], globalMin) > 0) {
+                return i;
+            }
+        }
+        return -1;
+    }
+
+    /**
+     * Compares two encoded values as unsigned bytes, starting at offset 0.
+     *
+     * @return a negative number, zero or a positive number as {@code value1} is less than, equal to or greater
+     *         than {@code value2}
+     */
+    public static int compareByteValue(byte[] value1, byte[] value2) {
+        // Take the width from the operands instead of the shared static comparator: that field is replaced by
+        // every new Ranges, so it may belong to a different field with a different encoded width.
+        final int width = Math.min(value1.length, value2.length);
+        return ArrayUtil.getUnsignedComparator(width).compare(value1, 0, value2, 0);
+    }
+
+    /** @return true if {@code value} is greater than or equal to {@code lowerBound} */
+    public static boolean withinLowerBound(byte[] value, byte[] lowerBound) {
+        return compareByteValue(value, lowerBound) >= 0;
+    }
+
+    /** @return true if {@code value} is strictly less than {@code upperBound} */
+    public static boolean withinUpperBound(byte[] value, byte[] upperBound) {
+        return compareByteValue(value, upperBound) < 0;
+    }
+}
diff --git a/server/src/main/java/org/opensearch/search/aggregations/bucket/filterrewrite/package-info.java b/server/src/main/java/org/opensearch/search/aggregations/bucket/filterrewrite/package-info.java
new file mode 100644
index 0000000000000..bbd0a8db6cbb6
--- /dev/null
+++ b/server/src/main/java/org/opensearch/search/aggregations/bucket/filterrewrite/package-info.java
@@ -0,0 +1,19 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * The OpenSearch Contributors require contributions made to
+ * this file be licensed under the Apache-2.0 license or a
+ * compatible open source license.
+ */
+
+/**
+ * This package contains filter rewrite optimization for range-type aggregations
+ *
+ * The idea is to:
+ *
+ * 1. figure out the "ranges" from the aggregation
+ * 2. leverage the ranges and the BKD index to get the result of each range bucket quickly
+ *
+ * More details in https://github.com/opensearch-project/OpenSearch/pull/14464
+ */
+package org.opensearch.search.aggregations.bucket.filterrewrite;
diff --git a/server/src/main/java/org/opensearch/search/aggregations/bucket/histogram/AutoDateHistogramAggregator.java b/server/src/main/java/org/opensearch/search/aggregations/bucket/histogram/AutoDateHistogramAggregator.java
index d13d575a9d696..f3a36b4882d19 100644
--- a/server/src/main/java/org/opensearch/search/aggregations/bucket/histogram/AutoDateHistogramAggregator.java
+++ b/server/src/main/java/org/opensearch/search/aggregations/bucket/histogram/AutoDateHistogramAggregator.java
@@ -42,7 +42,6 @@
import org.opensearch.common.util.IntArray;
import org.opensearch.common.util.LongArray;
import org.opensearch.core.common.util.ByteArray;
-import org.opensearch.index.mapper.MappedFieldType;
import org.opensearch.search.DocValueFormat;
import org.opensearch.search.aggregations.Aggregator;
import org.opensearch.search.aggregations.AggregatorFactories;
@@ -53,8 +52,9 @@
import org.opensearch.search.aggregations.LeafBucketCollectorBase;
import org.opensearch.search.aggregations.bucket.DeferableBucketAggregator;
import org.opensearch.search.aggregations.bucket.DeferringBucketCollector;
-import org.opensearch.search.aggregations.bucket.FastFilterRewriteHelper;
import org.opensearch.search.aggregations.bucket.MergingBucketsDeferringCollector;
+import org.opensearch.search.aggregations.bucket.filterrewrite.DateHistogramAggregatorBridge;
+import org.opensearch.search.aggregations.bucket.filterrewrite.FilterRewriteOptimizationContext;
import org.opensearch.search.aggregations.bucket.histogram.AutoDateHistogramAggregationBuilder.RoundingInfo;
import org.opensearch.search.aggregations.bucket.terms.LongKeyedBucketOrds;
import org.opensearch.search.aggregations.support.ValuesSource;
@@ -64,11 +64,12 @@
import java.io.IOException;
import java.util.Collections;
import java.util.Map;
-import java.util.Objects;
import java.util.function.BiConsumer;
import java.util.function.Function;
import java.util.function.LongToIntFunction;
+import static org.opensearch.search.aggregations.bucket.filterrewrite.DateHistogramAggregatorBridge.segmentMatchAll;
+
/**
* An aggregator for date values that attempts to return a specific number of
* buckets, reconfiguring how it rounds dates to buckets on the fly as new
@@ -135,7 +136,7 @@ static AutoDateHistogramAggregator build(
protected int roundingIdx;
protected Rounding.Prepared preparedRounding;
- private final FastFilterRewriteHelper.FastFilterContext fastFilterContext;
+ private final FilterRewriteOptimizationContext filterRewriteOptimizationContext;
private AutoDateHistogramAggregator(
String name,
@@ -158,53 +159,52 @@ private AutoDateHistogramAggregator(
this.roundingPreparer = roundingPreparer;
this.preparedRounding = prepareRounding(0);
- fastFilterContext = new FastFilterRewriteHelper.FastFilterContext(
- context,
- new AutoHistogramAggregationType(
- valuesSourceConfig.fieldType(),
- valuesSourceConfig.missing() != null,
- valuesSourceConfig.script() != null
- )
- );
- if (fastFilterContext.isRewriteable(parent, subAggregators.length)) {
- fastFilterContext.buildRanges(Objects.requireNonNull(valuesSourceConfig.fieldType()));
- }
- }
-
- private class AutoHistogramAggregationType extends FastFilterRewriteHelper.AbstractDateHistogramAggregationType {
+ DateHistogramAggregatorBridge bridge = new DateHistogramAggregatorBridge() {
+ @Override
+ protected boolean canOptimize() {
+ return canOptimize(valuesSourceConfig);
+ }
- public AutoHistogramAggregationType(MappedFieldType fieldType, boolean missing, boolean hasScript) {
- super(fieldType, missing, hasScript);
- }
+ @Override
+ protected void prepare() throws IOException {
+ buildRanges(context);
+ }
- @Override
- protected Rounding getRounding(final long low, final long high) {
- // max - min / targetBuckets = bestDuration
- // find the right innerInterval this bestDuration belongs to
- // since we cannot exceed targetBuckets, bestDuration should go up,
- // so the right innerInterval should be an upper bound
- long bestDuration = (high - low) / targetBuckets;
- // reset so this function is idempotent
- roundingIdx = 0;
- while (roundingIdx < roundingInfos.length - 1) {
- final RoundingInfo curRoundingInfo = roundingInfos[roundingIdx];
- final int temp = curRoundingInfo.innerIntervals[curRoundingInfo.innerIntervals.length - 1];
- // If the interval duration is covered by the maximum inner interval,
- // we can start with this outer interval for creating the buckets
- if (bestDuration <= temp * curRoundingInfo.roughEstimateDurationMillis) {
- break;
+ @Override
+ protected Rounding getRounding(final long low, final long high) {
+ // max - min / targetBuckets = bestDuration
+ // find the right innerInterval this bestDuration belongs to
+ // since we cannot exceed targetBuckets, bestDuration should go up,
+ // so the right innerInterval should be an upper bound
+ long bestDuration = (high - low) / targetBuckets;
+ // reset so this function is idempotent
+ roundingIdx = 0;
+ while (roundingIdx < roundingInfos.length - 1) {
+ final RoundingInfo curRoundingInfo = roundingInfos[roundingIdx];
+ final int temp = curRoundingInfo.innerIntervals[curRoundingInfo.innerIntervals.length - 1];
+ // If the interval duration is covered by the maximum inner interval,
+ // we can start with this outer interval for creating the buckets
+ if (bestDuration <= temp * curRoundingInfo.roughEstimateDurationMillis) {
+ break;
+ }
+ roundingIdx++;
}
- roundingIdx++;
+
+ preparedRounding = prepareRounding(roundingIdx);
+ return roundingInfos[roundingIdx].rounding;
}
- preparedRounding = prepareRounding(roundingIdx);
- return roundingInfos[roundingIdx].rounding;
- }
+ @Override
+ protected Prepared getRoundingPrepared() {
+ return preparedRounding;
+ }
- @Override
- protected Prepared getRoundingPrepared() {
- return preparedRounding;
- }
+ @Override
+ protected Function bucketOrdProducer() {
+ return (key) -> getBucketOrds().add(0, preparedRounding.round((long) key));
+ }
+ };
+ filterRewriteOptimizationContext = new FilterRewriteOptimizationContext(bridge, parent, subAggregators.length, context);
}
protected abstract LongKeyedBucketOrds getBucketOrds();
@@ -236,11 +236,7 @@ public final LeafBucketCollector getLeafCollector(LeafReaderContext ctx, LeafBuc
return LeafBucketCollector.NO_OP_COLLECTOR;
}
- boolean optimized = fastFilterContext.tryFastFilterAggregation(
- ctx,
- this::incrementBucketDocCount,
- (key) -> getBucketOrds().add(0, preparedRounding.round((long) key))
- );
+ boolean optimized = filterRewriteOptimizationContext.tryOptimize(ctx, this::incrementBucketDocCount, segmentMatchAll(context, ctx));
if (optimized) throw new CollectionTerminatedException();
final SortedNumericDocValues values = valuesSource.longValues(ctx);
@@ -308,12 +304,7 @@ protected final void merge(long[] mergeMap, long newNumBuckets) {
@Override
public void collectDebugInfo(BiConsumer add) {
super.collectDebugInfo(add);
- if (fastFilterContext.optimizedSegments > 0) {
- add.accept("optimized_segments", fastFilterContext.optimizedSegments);
- add.accept("unoptimized_segments", fastFilterContext.segments - fastFilterContext.optimizedSegments);
- add.accept("leaf_visited", fastFilterContext.leaf);
- add.accept("inner_visited", fastFilterContext.inner);
- }
+ filterRewriteOptimizationContext.populateDebugInfo(add);
}
/**
diff --git a/server/src/main/java/org/opensearch/search/aggregations/bucket/histogram/DateHistogramAggregator.java b/server/src/main/java/org/opensearch/search/aggregations/bucket/histogram/DateHistogramAggregator.java
index 4b84797c18922..96a49bc3fd5f6 100644
--- a/server/src/main/java/org/opensearch/search/aggregations/bucket/histogram/DateHistogramAggregator.java
+++ b/server/src/main/java/org/opensearch/search/aggregations/bucket/histogram/DateHistogramAggregator.java
@@ -39,7 +39,6 @@
import org.opensearch.common.Nullable;
import org.opensearch.common.Rounding;
import org.opensearch.common.lease.Releasables;
-import org.opensearch.index.mapper.MappedFieldType;
import org.opensearch.search.DocValueFormat;
import org.opensearch.search.aggregations.Aggregator;
import org.opensearch.search.aggregations.AggregatorFactories;
@@ -49,7 +48,8 @@
import org.opensearch.search.aggregations.LeafBucketCollector;
import org.opensearch.search.aggregations.LeafBucketCollectorBase;
import org.opensearch.search.aggregations.bucket.BucketsAggregator;
-import org.opensearch.search.aggregations.bucket.FastFilterRewriteHelper;
+import org.opensearch.search.aggregations.bucket.filterrewrite.DateHistogramAggregatorBridge;
+import org.opensearch.search.aggregations.bucket.filterrewrite.FilterRewriteOptimizationContext;
import org.opensearch.search.aggregations.bucket.terms.LongKeyedBucketOrds;
import org.opensearch.search.aggregations.support.ValuesSource;
import org.opensearch.search.aggregations.support.ValuesSourceConfig;
@@ -58,8 +58,10 @@
import java.io.IOException;
import java.util.Collections;
import java.util.Map;
-import java.util.Objects;
import java.util.function.BiConsumer;
+import java.util.function.Function;
+
+import static org.opensearch.search.aggregations.bucket.filterrewrite.DateHistogramAggregatorBridge.segmentMatchAll;
/**
* An aggregator for date values. Every date is rounded down using a configured
@@ -84,7 +86,7 @@ class DateHistogramAggregator extends BucketsAggregator implements SizedBucketAg
private final LongBounds hardBounds;
private final LongKeyedBucketOrds bucketOrds;
- private final FastFilterRewriteHelper.FastFilterContext fastFilterContext;
+ private final FilterRewriteOptimizationContext filterRewriteOptimizationContext;
DateHistogramAggregator(
String name,
@@ -117,35 +119,38 @@ class DateHistogramAggregator extends BucketsAggregator implements SizedBucketAg
bucketOrds = LongKeyedBucketOrds.build(context.bigArrays(), cardinality);
- fastFilterContext = new FastFilterRewriteHelper.FastFilterContext(
- context,
- new DateHistogramAggregationType(
- valuesSourceConfig.fieldType(),
- valuesSourceConfig.missing() != null,
- valuesSourceConfig.script() != null,
- hardBounds
- )
- );
- if (fastFilterContext.isRewriteable(parent, subAggregators.length)) {
- fastFilterContext.buildRanges(Objects.requireNonNull(valuesSourceConfig.fieldType()));
- }
- }
+ DateHistogramAggregatorBridge bridge = new DateHistogramAggregatorBridge() {
+ @Override
+ protected boolean canOptimize() {
+ return canOptimize(valuesSourceConfig);
+ }
- private class DateHistogramAggregationType extends FastFilterRewriteHelper.AbstractDateHistogramAggregationType {
+ @Override
+ protected void prepare() throws IOException {
+ buildRanges(context);
+ }
- public DateHistogramAggregationType(MappedFieldType fieldType, boolean missing, boolean hasScript, LongBounds hardBounds) {
- super(fieldType, missing, hasScript, hardBounds);
- }
+ @Override
+ protected Rounding getRounding(long low, long high) {
+ return rounding;
+ }
- @Override
- protected Rounding getRounding(long low, long high) {
- return rounding;
- }
+ @Override
+ protected Rounding.Prepared getRoundingPrepared() {
+ return preparedRounding;
+ }
- @Override
- protected Rounding.Prepared getRoundingPrepared() {
- return preparedRounding;
- }
+ @Override
+ protected long[] processHardBounds(long[] bounds) {
+ return super.processHardBounds(bounds, hardBounds);
+ }
+
+ @Override
+ protected Function bucketOrdProducer() {
+ return (key) -> bucketOrds.add(0, preparedRounding.round((long) key));
+ }
+ };
+ filterRewriteOptimizationContext = new FilterRewriteOptimizationContext(bridge, parent, subAggregators.length, context);
}
@Override
@@ -162,11 +167,7 @@ public LeafBucketCollector getLeafCollector(LeafReaderContext ctx, LeafBucketCol
return LeafBucketCollector.NO_OP_COLLECTOR;
}
- boolean optimized = fastFilterContext.tryFastFilterAggregation(
- ctx,
- this::incrementBucketDocCount,
- (key) -> bucketOrds.add(0, preparedRounding.round((long) key))
- );
+ boolean optimized = filterRewriteOptimizationContext.tryOptimize(ctx, this::incrementBucketDocCount, segmentMatchAll(context, ctx));
if (optimized) throw new CollectionTerminatedException();
SortedNumericDocValues values = valuesSource.longValues(ctx);
@@ -253,12 +254,7 @@ public void doClose() {
@Override
public void collectDebugInfo(BiConsumer add) {
add.accept("total_buckets", bucketOrds.size());
- if (fastFilterContext.optimizedSegments > 0) {
- add.accept("optimized_segments", fastFilterContext.optimizedSegments);
- add.accept("unoptimized_segments", fastFilterContext.segments - fastFilterContext.optimizedSegments);
- add.accept("leaf_visited", fastFilterContext.leaf);
- add.accept("inner_visited", fastFilterContext.inner);
- }
+ filterRewriteOptimizationContext.populateDebugInfo(add);
}
/**
diff --git a/server/src/main/java/org/opensearch/search/aggregations/bucket/range/RangeAggregator.java b/server/src/main/java/org/opensearch/search/aggregations/bucket/range/RangeAggregator.java
index 2ba2b06514de1..17461f228e993 100644
--- a/server/src/main/java/org/opensearch/search/aggregations/bucket/range/RangeAggregator.java
+++ b/server/src/main/java/org/opensearch/search/aggregations/bucket/range/RangeAggregator.java
@@ -55,7 +55,8 @@
import org.opensearch.search.aggregations.LeafBucketCollectorBase;
import org.opensearch.search.aggregations.NonCollectingAggregator;
import org.opensearch.search.aggregations.bucket.BucketsAggregator;
-import org.opensearch.search.aggregations.bucket.FastFilterRewriteHelper;
+import org.opensearch.search.aggregations.bucket.filterrewrite.FilterRewriteOptimizationContext;
+import org.opensearch.search.aggregations.bucket.filterrewrite.RangeAggregatorBridge;
import org.opensearch.search.aggregations.support.ValuesSource;
import org.opensearch.search.aggregations.support.ValuesSourceConfig;
import org.opensearch.search.internal.SearchContext;
@@ -66,6 +67,7 @@
import java.util.Map;
import java.util.Objects;
import java.util.function.BiConsumer;
+import java.util.function.Function;
import static org.opensearch.core.xcontent.ConstructingObjectParser.optionalConstructorArg;
@@ -249,7 +251,7 @@ public boolean equals(Object obj) {
final double[] maxTo;
- private final FastFilterRewriteHelper.FastFilterContext fastFilterContext;
+ private final FilterRewriteOptimizationContext filterRewriteOptimizationContext;
public RangeAggregator(
String name,
@@ -279,13 +281,23 @@ public RangeAggregator(
maxTo[i] = Math.max(this.ranges[i].to, maxTo[i - 1]);
}
- fastFilterContext = new FastFilterRewriteHelper.FastFilterContext(
- context,
- new FastFilterRewriteHelper.RangeAggregationType(config, ranges)
- );
- if (fastFilterContext.isRewriteable(parent, subAggregators.length)) {
- fastFilterContext.buildRanges(Objects.requireNonNull(config.fieldType()));
- }
+ RangeAggregatorBridge bridge = new RangeAggregatorBridge() {
+ @Override
+ protected boolean canOptimize() {
+ return canOptimize(config, ranges);
+ }
+
+ @Override
+ protected void prepare() {
+ buildRanges(ranges);
+ }
+
+ @Override
+ protected Function bucketOrdProducer() {
+ return (activeIndex) -> subBucketOrdinal(0, (int) activeIndex);
+ }
+ };
+ filterRewriteOptimizationContext = new FilterRewriteOptimizationContext(bridge, parent, subAggregators.length, context);
}
@Override
@@ -298,11 +310,7 @@ public ScoreMode scoreMode() {
@Override
public LeafBucketCollector getLeafCollector(LeafReaderContext ctx, final LeafBucketCollector sub) throws IOException {
- boolean optimized = fastFilterContext.tryFastFilterAggregation(
- ctx,
- this::incrementBucketDocCount,
- (activeIndex) -> subBucketOrdinal(0, (int) activeIndex)
- );
+ boolean optimized = filterRewriteOptimizationContext.tryOptimize(ctx, this::incrementBucketDocCount, false);
if (optimized) throw new CollectionTerminatedException();
final SortedNumericDoubleValues values = valuesSource.doubleValues(ctx);
@@ -452,11 +460,6 @@ public InternalAggregation buildEmptyAggregation() {
@Override
public void collectDebugInfo(BiConsumer add) {
super.collectDebugInfo(add);
- if (fastFilterContext.optimizedSegments > 0) {
- add.accept("optimized_segments", fastFilterContext.optimizedSegments);
- add.accept("unoptimized_segments", fastFilterContext.segments - fastFilterContext.optimizedSegments);
- add.accept("leaf_visited", fastFilterContext.leaf);
- add.accept("inner_visited", fastFilterContext.inner);
- }
+ filterRewriteOptimizationContext.populateDebugInfo(add);
}
}