Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

pre aggregate changes #22

Open
wants to merge 15 commits into
base: 2.11
Choose a base branch
from
Prev Previous commit
Next Next commit
Search query and aggregation changes; fixes NDV (NumericDocValues) iterator bug
Signed-off-by: Bharathwaj G <bharath78910@gmail.com>
  • Loading branch information
bharath-techie committed Feb 5, 2024
commit d73d470c40ff202e88d640b141ee94137465a888
Original file line number Diff line number Diff line change
@@ -27,7 +27,9 @@
import java.util.Queue;
import java.util.Set;
import java.util.function.Predicate;
import org.apache.lucene.index.DocValues;
import org.apache.lucene.index.NumericDocValues;
import org.apache.lucene.index.SortedNumericDocValues;
import org.apache.lucene.search.ConjunctionUtils;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.util.DocIdSetBuilder;
@@ -92,11 +94,11 @@ public DocIdSetIterator getStarTreeResult()
// TODO : set to max value of doc values
DocIdSetBuilder builder = new DocIdSetBuilder(Integer.MAX_VALUE);
List<Predicate<Long>> compositePredicateEvaluators = _predicateEvaluators.get(remainingPredicateColumn);
NumericDocValues ndv = this.dimValueMap.get(remainingPredicateColumn);
SortedNumericDocValues ndv = DocValues.singleton(this.dimValueMap.get(remainingPredicateColumn));
for (int docID = ndv.nextDoc(); docID != NO_MORE_DOCS; docID = ndv.nextDoc()) {
for (Predicate<Long> compositePredicateEvaluator : compositePredicateEvaluators) {
// TODO : this might be expensive as its done against all doc values docs
if (compositePredicateEvaluator.test(ndv.longValue())) {
if (compositePredicateEvaluator.test(ndv.nextValue())) {
builder.grow(1).add(docID);
break;
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/

package org.opensearch.index.codec.freshstartree.query;

import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.Predicate;
import org.apache.lucene.search.Query;
import org.opensearch.common.lucene.search.Queries;
import org.opensearch.core.ParseField;
import org.opensearch.core.common.ParsingException;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.xcontent.ObjectParser;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.index.query.AbstractQueryBuilder;
import org.opensearch.index.query.BoolQueryBuilder;
import org.opensearch.index.query.MatchAllQueryBuilder;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.index.query.QueryShardContext;

import static org.opensearch.core.xcontent.ObjectParser.fromList;


/**
 * Query builder for the experimental star-tree index. Parses an optional
 * {@code filter} array of inner queries and a {@code groupby} array of
 * dimension names, and produces a {@link StarTreeQuery}.
 *
 * <p>NOTE(review): the dimension predicates are currently hard coded in
 * {@link #filter(QueryBuilder)} (status == 200) as a prototype shortcut —
 * the parsed filter clauses are collected but not yet translated into
 * predicates. Confirm before relying on filter semantics.
 */
public class StarTreeQueryBuilder extends AbstractQueryBuilder<StarTreeQueryBuilder> {
    public static final String NAME = "startree";
    private static final ParseField FILTER = new ParseField("filter");

    /** Inner filter clauses parsed from the "filter" array. */
    private final List<QueryBuilder> filterClauses = new ArrayList<>();

    /** Dimension names parsed from the "groupby" array. */
    private final Set<String> groupBy = new HashSet<>();

    /** Per-dimension value predicates handed to the StarTreeQuery. */
    Map<String, List<Predicate<Long>>> predicateMap = new HashMap<>();

    public StarTreeQueryBuilder() {}

    /**
     * Read from a stream. Mirrors {@link #doWriteTo(StreamOutput)}: the
     * filter clauses are read after the superclass state.
     *
     * <p>NOTE(review): {@code groupBy} and {@code predicateMap} are not
     * transported over the wire — confirm whether that is intended.
     */
    public StarTreeQueryBuilder(StreamInput in) throws IOException {
        super(in);
        filterClauses.addAll(readQueries(in));
    }

    /** Reads a vInt-prefixed list of named-writeable queries from the stream. */
    static List<QueryBuilder> readQueries(StreamInput in) throws IOException {
        int size = in.readVInt();
        List<QueryBuilder> queries = new ArrayList<>(size);
        for (int i = 0; i < size; i++) {
            queries.add(in.readNamedWriteable(QueryBuilder.class));
        }
        return queries;
    }

    @Override
    protected void doWriteTo(StreamOutput out) throws IOException {
        // Must mirror the StreamInput constructor, which reads the filter
        // clauses after the superclass state. Writing nothing here (the
        // previous behavior) corrupts the stream for the reading side.
        out.writeVInt(filterClauses.size());
        for (QueryBuilder clause : filterClauses) {
            out.writeNamedWriteable(clause);
        }
    }

    @Override
    protected void doXContent(XContentBuilder builder, Params params) throws IOException {
        builder.startObject(NAME);
        doXArrayContent(FILTER, filterClauses, builder, params);
        builder.endObject();
    }

    /** Renders a non-empty clause list as an xContent array under {@code field}. */
    private static void doXArrayContent(ParseField field, List<QueryBuilder> clauses, XContentBuilder builder, Params params)
        throws IOException {
        if (clauses.isEmpty()) {
            return;
        }
        builder.startArray(field.getPreferredName());
        for (QueryBuilder clause : clauses) {
            clause.toXContent(builder, params);
        }
        builder.endArray();
    }

    private static final ObjectParser<StarTreeQueryBuilder, Void> PARSER = new ObjectParser<>(NAME, StarTreeQueryBuilder::new);

    static {
        PARSER.declareObjectArrayOrNull(
            (builder, clauses) -> clauses.forEach(builder::filter),
            (p, c) -> parseInnerQueryBuilder(p),
            FILTER
        );
        PARSER.declareStringArray(StarTreeQueryBuilder::groupby, new ParseField("groupby"));
    }

    /** Accumulates parsed "groupby" dimension names. */
    private void groupby(List<String> strings) {
        groupBy.addAll(strings);
    }

    /**
     * Adds an inner filter clause.
     *
     * @param queryBuilder the inner query; must not be null
     * @return this builder, for chaining
     * @throws IllegalArgumentException if {@code queryBuilder} is null
     */
    public StarTreeQueryBuilder filter(QueryBuilder queryBuilder) {
        if (queryBuilder == null) {
            throw new IllegalArgumentException("inner bool query clause cannot be null");
        }
        filterClauses.add(queryBuilder);

        // TODO(prototype): the actual clause is ignored and a fixed
        // "status == 200" predicate is installed instead. Replace with a
        // translation of queryBuilder into dimension predicates.
        List<Predicate<Long>> predicates = new ArrayList<>();
        predicates.add(status -> status == 200);
        predicateMap.put("status", predicates);

        return this;
    }

    public static StarTreeQueryBuilder fromXContent(XContentParser parser) {
        try {
            return PARSER.apply(parser, null);
        } catch (IllegalArgumentException e) {
            throw new ParsingException(parser.getTokenLocation(), e.getMessage(), e);
        }
    }

    @Override
    protected Query doToQuery(QueryShardContext context) {
        // Predicates take precedence; otherwise fall back to a pure group-by query.
        if (!predicateMap.isEmpty()) {
            return new StarTreeQuery(predicateMap, new HashSet<>());
        }
        return new StarTreeQuery(new HashMap<>(), this.groupBy);
    }

    @Override
    protected boolean doEquals(StarTreeQueryBuilder other) {
        // Previously always true, which made every StarTreeQueryBuilder
        // compare equal regardless of state (breaks caching/dedup).
        return filterClauses.equals(other.filterClauses) && groupBy.equals(other.groupBy);
    }

    @Override
    protected int doHashCode() {
        // Keep consistent with doEquals (equal builders hash equally).
        return 31 * filterClauses.hashCode() + groupBy.hashCode();
    }

    @Override
    public String getWriteableName() {
        return NAME;
    }
}
13 changes: 13 additions & 0 deletions server/src/main/java/org/opensearch/search/SearchModule.java
Original file line number Diff line number Diff line change
@@ -47,6 +47,7 @@
import org.opensearch.core.common.io.stream.Writeable;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.index.codec.freshstartree.query.StarTreeQueryBuilder;
import org.opensearch.index.query.BoolQueryBuilder;
import org.opensearch.index.query.BoostingQueryBuilder;
import org.opensearch.index.query.CommonTermsQueryBuilder;
@@ -155,6 +156,9 @@
import org.opensearch.search.aggregations.bucket.sampler.InternalSampler;
import org.opensearch.search.aggregations.bucket.sampler.SamplerAggregationBuilder;
import org.opensearch.search.aggregations.bucket.sampler.UnmappedSampler;
import org.opensearch.search.aggregations.bucket.startree.InternalStarTree;
import org.opensearch.search.aggregations.bucket.startree.StarTreeAggregationBuilder;
import org.opensearch.search.aggregations.bucket.startree.StarTreeAggregator;
import org.opensearch.search.aggregations.bucket.terms.DoubleTerms;
import org.opensearch.search.aggregations.bucket.terms.InternalMultiTerms;
import org.opensearch.search.aggregations.bucket.terms.LongRareTerms;
@@ -493,6 +497,11 @@ private ValuesSourceRegistry registerAggregations(List<SearchPlugin> plugins) {
.addResultReader(InternalGlobal::new),
builder
);
registerAggregation(
new AggregationSpec(StarTreeAggregationBuilder.NAME, StarTreeAggregationBuilder::new, StarTreeAggregationBuilder::parse)
.addResultReader(InternalStarTree::new),
builder
);
registerAggregation(
new AggregationSpec(MissingAggregationBuilder.NAME, MissingAggregationBuilder::new, MissingAggregationBuilder.PARSER)
.addResultReader(InternalMissing::new)
@@ -675,6 +684,7 @@ private ValuesSourceRegistry registerAggregations(List<SearchPlugin> plugins) {
.setAggregatorRegistrar(MultiTermsAggregationFactory::registerAggregators),
builder
);

registerFromPlugin(plugins, SearchPlugin::getAggregations, (agg) -> this.registerAggregation(agg, builder));

// after aggs have been registered, see if there are any new VSTypes that need to be linked to core fields
@@ -1202,6 +1212,9 @@ private void registerQueryParsers(List<SearchPlugin> plugins) {
registerQuery(new QuerySpec<>(GeoShapeQueryBuilder.NAME, GeoShapeQueryBuilder::new, GeoShapeQueryBuilder::fromXContent));
}

registerQuery(new QuerySpec<>(StarTreeQueryBuilder.NAME, StarTreeQueryBuilder::new, StarTreeQueryBuilder::fromXContent));


registerFromPlugin(plugins, SearchPlugin::getQueries, this::registerQuery);
}

Original file line number Diff line number Diff line change
@@ -31,6 +31,8 @@

package org.opensearch.search.aggregations.bucket.range;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.ScoreMode;
import org.opensearch.core.ParseField;
@@ -62,6 +64,7 @@
import java.util.List;
import java.util.Map;
import java.util.Objects;
import org.opensearch.search.query.QueryPhase;

import static org.opensearch.core.xcontent.ConstructingObjectParser.optionalConstructorArg;

@@ -75,6 +78,8 @@ public class RangeAggregator extends BucketsAggregator {
public static final ParseField RANGES_FIELD = new ParseField("ranges");
public static final ParseField KEYED_FIELD = new ParseField("keyed");

private static final Logger LOGGER = LogManager.getLogger(RangeAggregator.class);

/**
* Range for the range aggregator
*
@@ -351,6 +356,9 @@ private int collect(int doc, double value, long owningBucketOrdinal, int lowBoun
}

private long subBucketOrdinal(long owningBucketOrdinal, int rangeOrd) {
    // Each owning bucket owns a contiguous run of ranges.length sub-ordinals;
    // compute once and return it (the original recomputed the expression).
    long subOrd = owningBucketOrdinal * ranges.length + rangeOrd;
    // Per-document diagnostic: TRACE, not INFO — this runs once per collected
    // doc in the aggregation hot path and would flood logs at default levels.
    LOGGER.trace(
        "Owning bucket ordinal : {} , rangeord : {} , len : {} == SubOrd : {}",
        owningBucketOrdinal,
        rangeOrd,
        ranges.length,
        subOrd
    );
    return subOrd;
}

Loading