Skip to content

Commit

Permalink
Add filter function for NeuralQueryBuilder and HybridQueryBuilder and…
Browse files Browse the repository at this point in the history
… modify fromXContent function in HybridQueryBuilder to support filter field.

Signed-off-by: Chloe Gao <chloewq@amazon.com>
  • Loading branch information
chloewqg committed Mar 10, 2025
1 parent 5f25d6c commit cb7dff7
Show file tree
Hide file tree
Showing 8 changed files with 733 additions and 1 deletion.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
## [Unreleased 3.x](https://github.com/opensearch-project/neural-search/compare/main...HEAD)
### Features
- Lower bound for min-max normalization technique in hybrid query ([#1195](https://github.com/opensearch-project/neural-search/pull/1195))
- Support filter function for HybridQueryBuilder and NeuralQueryBuilder ([#1206](https://github.com/opensearch-project/neural-search/pull/1206))
### Enhancements
### Bug Fixes
### Infrastructure
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ public final class HybridQueryBuilder extends AbstractQueryBuilder<HybridQueryBu
public static final String NAME = "hybrid";

private static final ParseField QUERIES_FIELD = new ParseField("queries");
private static final ParseField FILTER_FIELD = new ParseField("filter");
private static final ParseField PAGINATION_DEPTH_FIELD = new ParseField("pagination_depth");

private final List<QueryBuilder> queries = new ArrayList<>();
Expand Down Expand Up @@ -94,6 +95,28 @@ public HybridQueryBuilder add(QueryBuilder queryBuilder) {
return this;
}

/**
* Function to support filter on HybridQueryBuilder filter. Currently pushing down a filter
* to HybridQueryBuilder is not supported by design. We would simply check if the filter is valid
* and throw exception telling this is an unsupported operation. If the filter is null, then we do nothing and
* return.
* @param filter the filter parameter
* @return HybridQueryBuilder itself
*/
public QueryBuilder filter(QueryBuilder filter) {
if (validateFilterParams(filter) == false) {
return this;
}
for (int i = 0; i < queries.size(); i++) {
QueryBuilder query = queries.get(i);
if (query instanceof HybridQueryBuilder) {
throw new UnsupportedOperationException("Cannot push filter to nested hybridQueryBuilder");
}
queries.set(i, query.filter(filter));
}
return this;
}

/**
* Create builder object with a content of this hybrid query
* @param builder
Expand Down Expand Up @@ -155,6 +178,10 @@ protected Query doToQuery(QueryShardContext queryShardContext) throws IOExceptio
* }
* }
* ]
* "filter":
* "term": {
* "text": "keyword"
* }
* }
* }
* }
Expand All @@ -168,6 +195,7 @@ public static HybridQueryBuilder fromXContent(XContentParser parser) throws IOEx

Integer paginationDepth = null;
final List<QueryBuilder> queries = new ArrayList<>();
QueryBuilder filter = null;
String queryName = null;

String currentFieldName = null;
Expand All @@ -178,6 +206,8 @@ public static HybridQueryBuilder fromXContent(XContentParser parser) throws IOEx
} else if (token == XContentParser.Token.START_OBJECT) {
if (QUERIES_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
queries.add(parseInnerQueryBuilder(parser));
} else if (FILTER_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
filter = parseInnerQueryBuilder(parser);
} else {
log.error(String.format(Locale.ROOT, "[%s] query does not support [%s]", NAME, currentFieldName));
throw new ParsingException(
Expand Down Expand Up @@ -240,7 +270,11 @@ public static HybridQueryBuilder fromXContent(XContentParser parser) throws IOEx
compoundQueryBuilder.paginationDepth(paginationDepth);
}
for (QueryBuilder query : queries) {
compoundQueryBuilder.add(query);
if (filter == null) {
compoundQueryBuilder.add(query);
} else {
compoundQueryBuilder.add(query.filter(filter));
}
}
return compoundQueryBuilder;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,23 @@ protected void doWriteTo(StreamOutput out) throws IOException {
RescoreParser.streamOutput(out, rescoreContext);
}

/**
* Add a filter to Neural Query Builder
* @param filterToBeAdded filter to be added
* @return return itself with underlying filter combined with passed in filter
*/
public QueryBuilder filter(QueryBuilder filterToBeAdded) {
if (validateFilterParams(filterToBeAdded) == false) {
return this;
}
if (filter == null) {
filter = filterToBeAdded;
return this;
}
filter = filter.filter(filterToBeAdded);
return this;
}

@Override
protected void doXContent(XContentBuilder xContentBuilder, Params params) throws IOException {
xContentBuilder.startObject(NAME);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
import org.opensearch.index.IndexSettings;
import org.opensearch.index.mapper.MappedFieldType;
import org.opensearch.index.mapper.TextFieldMapper;
import org.opensearch.index.query.BoolQueryBuilder;
import org.opensearch.index.query.MatchAllQueryBuilder;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.index.query.QueryBuilders;
Expand All @@ -82,6 +83,7 @@ public class HybridQueryBuilderTests extends OpenSearchQueryTestCase {
static final String TEXT_FIELD_NAME = "field";
static final String QUERY_TEXT = "Hello world!";
static final String TERM_QUERY_TEXT = "keyword";
static final String FILTER_TERM_QUERY_TEXT = "filterKeyword";
static final String MODEL_ID = "mfgfgdsfgfdgsde";
static final int K = 10;
static final float BOOST = 1.8f;
Expand Down Expand Up @@ -436,6 +438,121 @@ public void testFromXContent_whenMultipleSubQueries_thenBuildSuccessfully() {
assertEquals(TERM_QUERY_TEXT, termQueryBuilder.value());
}

/**
* Tests basic query:
* {
* "query": {
* "hybrid": {
* "queries": [
* {
* "neural": {
* "text_knn": {
* "query_text": "Hello world",
* "model_id": "dcsdcasd",
* "k": 1
* }
* }
* },
* {
* "term": {
* "text": "keyword"
* }
* }
* ]
* "filter": {
* "term": {
* "text": "filterKeyword"
* }
* }
* }
* }
* }
*/
@SneakyThrows
public void testFromXContent_whenMultipleSubQueriesAndFilter_thenBuildSuccessfully() {
setUpClusterService();
XContentBuilder xContentBuilder = XContentFactory.jsonBuilder()
.startObject()
.startArray("queries")
.startObject()
.startObject(NeuralQueryBuilder.NAME)
.startObject(VECTOR_FIELD_NAME)
.field(QUERY_TEXT_FIELD.getPreferredName(), QUERY_TEXT)
.field(MODEL_ID_FIELD.getPreferredName(), MODEL_ID)
.field(K_FIELD.getPreferredName(), K)
.field(BOOST_FIELD.getPreferredName(), BOOST)
.endObject()
.endObject()
.endObject()
.startObject()
.startObject(TermQueryBuilder.NAME)
.field(TEXT_FIELD_NAME, TERM_QUERY_TEXT)
.endObject()
.endObject()
.endArray()

.field("pagination_depth", 10)
.startObject("filter")
.startObject(TermQueryBuilder.NAME)
.field(TEXT_FIELD_NAME, FILTER_TERM_QUERY_TEXT)
.endObject()
.endObject()
.endObject();

NamedXContentRegistry namedXContentRegistry = new NamedXContentRegistry(
List.of(
new NamedXContentRegistry.Entry(QueryBuilder.class, new ParseField(TermQueryBuilder.NAME), TermQueryBuilder::fromXContent),
new NamedXContentRegistry.Entry(
QueryBuilder.class,
new ParseField(NeuralQueryBuilder.NAME),
NeuralQueryBuilder::fromXContent
),
new NamedXContentRegistry.Entry(
QueryBuilder.class,
new ParseField(HybridQueryBuilder.NAME),
HybridQueryBuilder::fromXContent
)
)
);
XContentParser contentParser = createParser(
namedXContentRegistry,
xContentBuilder.contentType().xContent(),
BytesReference.bytes(xContentBuilder)
);
contentParser.nextToken();

HybridQueryBuilder queryTwoSubQueries = HybridQueryBuilder.fromXContent(contentParser);
assertEquals(2, queryTwoSubQueries.queries().size());
assertTrue(queryTwoSubQueries.queries().get(0) instanceof NeuralQueryBuilder);

assertTrue(queryTwoSubQueries.queries().get(1) instanceof BoolQueryBuilder);
assertEquals(1, ((BoolQueryBuilder) queryTwoSubQueries.queries().get(1)).must().size());
assertTrue(((BoolQueryBuilder) queryTwoSubQueries.queries().get(1)).must().get(0) instanceof TermQueryBuilder);
assertEquals(1, ((BoolQueryBuilder) queryTwoSubQueries.queries().get(1)).filter().size());

assertEquals(10, queryTwoSubQueries.paginationDepth().intValue());
// verify knn vector query
NeuralQueryBuilder neuralQueryBuilder = (NeuralQueryBuilder) queryTwoSubQueries.queries().get(0);
assertEquals(VECTOR_FIELD_NAME, neuralQueryBuilder.fieldName());
assertEquals(QUERY_TEXT, neuralQueryBuilder.queryText());
assertEquals(K, (int) neuralQueryBuilder.k());
assertEquals(MODEL_ID, neuralQueryBuilder.modelId());
assertEquals(BOOST, neuralQueryBuilder.boost(), 0f);
assertEquals(
new TermQueryBuilder(TEXT_FIELD_NAME, FILTER_TERM_QUERY_TEXT),
((NeuralQueryBuilder) queryTwoSubQueries.queries().get(0)).filter()
);
// verify term query
assertEquals(
new TermQueryBuilder(TEXT_FIELD_NAME, TERM_QUERY_TEXT),
((BoolQueryBuilder) queryTwoSubQueries.queries().get(1)).must().get(0)
);
assertEquals(
new TermQueryBuilder(TEXT_FIELD_NAME, FILTER_TERM_QUERY_TEXT),
((BoolQueryBuilder) queryTwoSubQueries.queries().get(1)).filter().get(0)
);
}

@SneakyThrows
public void testFromXContent_whenIncorrectFormat_thenFail() {
XContentBuilder unsupportedFieldXContentBuilder = XContentFactory.jsonBuilder()
Expand Down Expand Up @@ -960,6 +1077,29 @@ public void testVisit() {
assertEquals(3, visitedQueries.size());
}

public void testFilter() {
HybridQueryBuilder hybridQueryBuilder = new HybridQueryBuilder().add(
NeuralQueryBuilder.builder().fieldName("test").queryText("test").build()
).add(new NeuralSparseQueryBuilder());
// Test for Null filter Case
QueryBuilder queryBuilder = hybridQueryBuilder.filter(null);
assertEquals(queryBuilder, hybridQueryBuilder);

// Test for Non-Null filter case and assert every field as expected
HybridQueryBuilder updatedHybridQueryBuilder = (HybridQueryBuilder) hybridQueryBuilder.filter(new MatchAllQueryBuilder());
assertEquals(updatedHybridQueryBuilder.queryName(), hybridQueryBuilder.queryName());
assertEquals(updatedHybridQueryBuilder.paginationDepth(), hybridQueryBuilder.paginationDepth());
NeuralQueryBuilder updatedNeuralQueryBuilder = (NeuralQueryBuilder) updatedHybridQueryBuilder.queries().get(0);
assertEquals(new MatchAllQueryBuilder(), updatedNeuralQueryBuilder.filter());
BoolQueryBuilder updatedNeuralSparseQueryBuilder = (BoolQueryBuilder) updatedHybridQueryBuilder.queries().get(1);
assertEquals(new NeuralSparseQueryBuilder(), updatedNeuralSparseQueryBuilder.must().get(0));
assertEquals(new MatchAllQueryBuilder(), updatedNeuralSparseQueryBuilder.filter().get(0));

// Test for Non-Null filter case but encountered Nested HybridQueryBuilder to throw Unsupported Exception
updatedHybridQueryBuilder.add(new HybridQueryBuilder());
assertThrows(UnsupportedOperationException.class, () -> updatedHybridQueryBuilder.filter(new MatchAllQueryBuilder()));
}

private Map<String, Object> getInnerMap(Object innerObject, String queryName, String fieldName) {
if (!(innerObject instanceof Map)) {
fail("field name does not map to nested object");
Expand Down
Loading

0 comments on commit cb7dff7

Please sign in to comment.