Skip to content

Commit

Permalink
Restructure term query validation (#808)
Browse files Browse the repository at this point in the history
  • Loading branch information
aprudhomme authored Jan 8, 2025
1 parent fd2415a commit 1c2d664
Show file tree
Hide file tree
Showing 12 changed files with 42 additions and 43 deletions.
16 changes: 16 additions & 0 deletions src/main/java/com/yelp/nrtsearch/server/field/AtomFieldDef.java
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,15 @@
import com.yelp.nrtsearch.server.grpc.Field;
import com.yelp.nrtsearch.server.grpc.RangeQuery;
import com.yelp.nrtsearch.server.grpc.SortType;
import java.util.List;
import java.util.stream.Collectors;
import org.apache.lucene.analysis.Analyzer;
import org.apache.lucene.analysis.core.KeywordAnalyzer;
import org.apache.lucene.document.FieldType;
import org.apache.lucene.document.SortedSetDocValuesField;
import org.apache.lucene.index.DocValuesType;
import org.apache.lucene.index.IndexOptions;
import org.apache.lucene.index.Term;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.SortField;
import org.apache.lucene.search.SortedSetSortField;
Expand Down Expand Up @@ -82,6 +85,19 @@ protected Analyzer parseSearchAnalyzer(Field requestField) {
return keywordAnalyzer;
}

/**
 * Builds an exact-match term query on this field for the given text.
 *
 * <p>{@code verifySearchable} is invoked first; it is defined outside this view and presumably
 * rejects fields that are not searchable.
 *
 * @param textValue term text to match, used as-is
 * @return a Lucene term query on this field's name
 */
@Override
public Query getTermQueryFromTextValue(String textValue) {
  verifySearchable("Term query");
  Term exactTerm = new Term(getName(), textValue);
  return new org.apache.lucene.search.TermQuery(exactTerm);
}

/**
 * Builds a query matching documents whose field value equals any of the given texts.
 *
 * <p>{@code verifySearchable} is invoked first; it is defined outside this view and presumably
 * rejects fields that are not searchable.
 *
 * @param textValues candidate term texts; each one is wrapped in a {@code BytesRef}
 * @return a Lucene term-in-set query on this field's name
 */
@Override
public Query getTermInSetQueryFromTextValues(List<String> textValues) {
  verifySearchable("Term in set query");
  List<BytesRef> encodedTerms =
      textValues.stream().map(value -> new BytesRef(value)).collect(Collectors.toList());
  String fieldName = getName();
  return new org.apache.lucene.search.TermInSetQuery(fieldName, encodedTerms);
}

@Override
public SortField getSortField(SortType type) {
verifyDocValues("Sort field");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -140,15 +140,14 @@ public String getType() {

/**
 * Builds a term query for a boolean value.
 *
 * <p>The term text used is {@code "1"} for {@code true} and {@code "0"} for {@code false}.
 * {@code verifySearchable} is invoked first; it is defined outside this view and presumably
 * rejects fields that are not searchable.
 *
 * @param booleanValue value to match
 * @return a Lucene term query on this field's name
 */
@Override
public Query getTermQueryFromBooleanValue(boolean booleanValue) {
  verifySearchable("Term query");
  Term indexedTerm = new Term(getName(), booleanValue ? "1" : "0");
  return new org.apache.lucene.search.TermQuery(indexedTerm);
}

/**
 * Builds a term query from the text form of a boolean.
 *
 * <p>The text is converted with {@code parseBooleanOrThrow} (defined outside this view;
 * presumably throws for text that is not a valid boolean) and the result is handed to
 * {@link #getTermQueryFromBooleanValue(boolean)}, so the searchable check and the
 * {@code "1"}/{@code "0"} term encoding live in exactly one place.
 *
 * @param textValue text representation of the boolean to match
 * @return the query produced by {@link #getTermQueryFromBooleanValue(boolean)}
 */
@Override
public Query getTermQueryFromTextValue(String textValue) {
  // Single exit point: building the TermQuery here as well would skip the validation done by
  // the boolean overload and leave a second, unreachable return behind.
  boolean parsedValue = parseBooleanOrThrow(textValue);
  return getTermQueryFromBooleanValue(parsedValue);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -149,18 +149,9 @@ public Query getRangeQuery(RangeQuery rangeQuery) {
return new IndexOrDocValuesQuery(pointQuery, dvQuery);
}

/**
 * Verifies that term / term-in-set queries can be run against this field.
 *
 * <p>The check passes when the field is searchable or has doc values.
 *
 * @throws IllegalStateException if the field is neither searchable nor has doc values
 */
@Override
public void checkTermQueriesSupported() {
  if (isSearchable() || hasDocValues()) {
    return;
  }
  String message =
      "Field \""
          + getName()
          + "\" is not searchable or does not have doc values, which is required for TermQuery / TermInSetQuery";
  throw new IllegalStateException(message);
}

@Override
public Query getTermQueryFromLongValue(long longValue) {
verifySearchableOrDocValues("Term query");
Query pointQuery = null;
Query dvQuery = null;
if (isSearchable()) {
Expand All @@ -180,6 +171,7 @@ public Query getTermQueryFromLongValue(long longValue) {

@Override
public Query getTermInSetQueryFromLongValues(List<Long> longValues) {
verifySearchableOrDocValues("Term in set query");
Query pointQuery = null;
Query dvQuery = null;
if (isSearchable()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -150,23 +150,25 @@ private void ensureUpperIsMoreThanLower(RangeQuery rangeQuery, double lower, dou

/**
 * Builds an exact-match point query for a double value.
 *
 * <p>{@code verifySearchable} is invoked first; it is defined outside this view and presumably
 * rejects fields that are not searchable.
 *
 * @param doubleValue value to match
 * @return a {@code DoublePoint} exact query on this field's name
 */
@Override
public Query getTermQueryFromDoubleValue(double doubleValue) {
  verifySearchable("Term query");
  String fieldName = getName();
  return DoublePoint.newExactQuery(fieldName, doubleValue);
}

/**
 * Builds a point set query matching any of the given double values.
 *
 * <p>{@code verifySearchable} is invoked first; it is defined outside this view and presumably
 * rejects fields that are not searchable.
 *
 * @param doubleValues values to match
 * @return a {@code DoublePoint} set query on this field's name
 */
@Override
public Query getTermInSetQueryFromDoubleValues(List<Double> doubleValues) {
  verifySearchable("Term in set query");
  String fieldName = getName();
  return DoublePoint.newSetQuery(fieldName, doubleValues);
}

/**
 * Builds an exact-match query from the text form of a double.
 *
 * <p>The text is parsed with {@link Double#parseDouble(String)} and the result is handed to
 * {@link #getTermQueryFromDoubleValue(double)}, so the searchable check performed there also
 * applies to text input.
 *
 * @param textValue text representation of the double to match
 * @return the query produced by {@link #getTermQueryFromDoubleValue(double)}
 * @throws NumberFormatException if {@code textValue} is not a parsable double
 */
@Override
public Query getTermQueryFromTextValue(String textValue) {
  // Single exit point: calling DoublePoint directly here would bypass the validation in the
  // typed overload and leave a second, unreachable return behind.
  double parsedValue = Double.parseDouble(textValue);
  return getTermQueryFromDoubleValue(parsedValue);
}

/**
 * Builds a set query from the text forms of several doubles.
 *
 * <p>Each entry is parsed with {@link Double#parseDouble(String)}; the parsed list is handed to
 * {@link #getTermInSetQueryFromDoubleValues(List)}, so the searchable check performed there
 * also applies to text input.
 *
 * @param textValues text representations of the doubles to match
 * @return the query produced by {@link #getTermInSetQueryFromDoubleValues(List)}
 * @throws NumberFormatException if any entry is not a parsable double
 */
@Override
public Query getTermInSetQueryFromTextValues(List<String> textValues) {
  List<Double> doubleTerms = new ArrayList<>(textValues.size());
  for (String textValue : textValues) {
    doubleTerms.add(Double.parseDouble(textValue));
  }
  // Single exit point: calling DoublePoint directly here would bypass the validation in the
  // typed overload and leave a second, unreachable return behind.
  return getTermInSetQueryFromDoubleValues(doubleTerms);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -149,23 +149,25 @@ private void ensureUpperIsMoreThanLower(RangeQuery rangeQuery, float lower, floa

/**
 * Builds an exact-match point query for a float value.
 *
 * <p>{@code verifySearchable} is invoked first; it is defined outside this view and presumably
 * rejects fields that are not searchable.
 *
 * @param floatValue value to match
 * @return a {@code FloatPoint} exact query on this field's name
 */
@Override
public Query getTermQueryFromFloatValue(float floatValue) {
  verifySearchable("Term query");
  String fieldName = getName();
  return FloatPoint.newExactQuery(fieldName, floatValue);
}

/**
 * Builds a point set query matching any of the given float values.
 *
 * <p>{@code verifySearchable} is invoked first; it is defined outside this view and presumably
 * rejects fields that are not searchable.
 *
 * @param floatValues values to match
 * @return a {@code FloatPoint} set query on this field's name
 */
@Override
public Query getTermInSetQueryFromFloatValues(List<Float> floatValues) {
  verifySearchable("Term in set query");
  String fieldName = getName();
  return FloatPoint.newSetQuery(fieldName, floatValues);
}

/**
 * Builds an exact-match query from the text form of a float.
 *
 * <p>The text is parsed with {@link Float#parseFloat(String)} and the result is handed to
 * {@link #getTermQueryFromFloatValue(float)}, so the searchable check performed there also
 * applies to text input.
 *
 * @param textValue text representation of the float to match
 * @return the query produced by {@link #getTermQueryFromFloatValue(float)}
 * @throws NumberFormatException if {@code textValue} is not a parsable float
 */
@Override
public Query getTermQueryFromTextValue(String textValue) {
  // Single exit point: calling FloatPoint directly here would bypass the validation in the
  // typed overload and leave a second, unreachable return behind.
  float parsedValue = Float.parseFloat(textValue);
  return getTermQueryFromFloatValue(parsedValue);
}

/**
 * Builds a set query from the text forms of several floats.
 *
 * <p>Each entry is parsed with {@link Float#parseFloat(String)}; the parsed list is handed to
 * {@link #getTermInSetQueryFromFloatValues(List)}, so the searchable check performed there also
 * applies to text input.
 *
 * @param textValues text representations of the floats to match
 * @return the query produced by {@link #getTermInSetQueryFromFloatValues(List)}
 * @throws NumberFormatException if any entry is not a parsable float
 */
@Override
public Query getTermInSetQueryFromTextValues(List<String> textValues) {
  List<Float> floatTerms = new ArrayList<>(textValues.size());
  for (String textValue : textValues) {
    floatTerms.add(Float.parseFloat(textValue));
  }
  // Single exit point: calling FloatPoint directly here would bypass the validation in the
  // typed overload and leave a second, unreachable return behind.
  return getTermInSetQueryFromFloatValues(floatTerms);
}
}
2 changes: 2 additions & 0 deletions src/main/java/com/yelp/nrtsearch/server/field/IdFieldDef.java
Original file line number Diff line number Diff line change
Expand Up @@ -164,11 +164,13 @@ public Term getTerm(Document document) {

/**
 * Builds an exact-match term query on the id field for the given text.
 *
 * @param textValue id text to match, used as-is
 * @return a Lucene term query on this field's name
 */
@Override
public Query getTermQueryFromTextValue(String textValue) {
  // No searchable check is needed here: _ID fields are always searchable.
  Term idTerm = new Term(getName(), textValue);
  return new org.apache.lucene.search.TermQuery(idTerm);
}

/**
 * Builds a query matching documents whose id equals any of the given texts.
 *
 * @param textValues candidate id texts; each one is wrapped in a {@code BytesRef}
 * @return a Lucene term-in-set query on this field's name
 */
@Override
public Query getTermInSetQueryFromTextValues(List<String> textValues) {
  // No searchable check is needed here: _ID fields are always searchable.
  List<BytesRef> encodedIds =
      textValues.stream().map(value -> new BytesRef(value)).collect(Collectors.toList());
  String fieldName = getName();
  return new org.apache.lucene.search.TermInSetQuery(fieldName, encodedIds);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -143,24 +143,26 @@ private void ensureUpperIsMoreThanLower(RangeQuery rangeQuery, int lower, int up

/**
 * Builds an exact-match point query for an int value.
 *
 * <p>{@code verifySearchable} is invoked first; it is defined outside this view and presumably
 * rejects fields that are not searchable.
 *
 * @param intValue value to match
 * @return an {@code IntPoint} exact query on this field's name
 */
@Override
public Query getTermQueryFromIntValue(int intValue) {
  verifySearchable("Term query");
  String fieldName = getName();
  return IntPoint.newExactQuery(fieldName, intValue);
}

/**
 * Builds a point set query matching any of the given int values.
 *
 * <p>{@code verifySearchable} is invoked first; it is defined outside this view and presumably
 * rejects fields that are not searchable.
 *
 * @param intValues values to match
 * @return an {@code IntPoint} set query on this field's name
 */
@Override
public Query getTermInSetQueryFromIntValues(List<Integer> intValues) {
  verifySearchable("Term in set query");
  String fieldName = getName();
  return IntPoint.newSetQuery(fieldName, intValues);
}

/**
 * Builds an exact-match query from the text form of an int.
 *
 * <p>The text is parsed with {@link Integer#parseInt(String)} and the result is handed to
 * {@link #getTermQueryFromIntValue(int)}, so the searchable check performed there also applies
 * to text input.
 *
 * @param textValue text representation of the int to match
 * @return the query produced by {@link #getTermQueryFromIntValue(int)}
 * @throws NumberFormatException if {@code textValue} is not a parsable int
 */
@Override
public Query getTermQueryFromTextValue(String textValue) {
  // Single exit point: calling IntPoint directly here would bypass the validation in the
  // typed overload and leave a second, unreachable return behind.
  int parsedValue = Integer.parseInt(textValue);
  return getTermQueryFromIntValue(parsedValue);
}

/**
 * Builds a set query from the text forms of several ints.
 *
 * <p>Each entry is parsed with {@link Integer#parseInt(String)}; the parsed list is handed to
 * {@link #getTermInSetQueryFromIntValues(List)}, so the searchable check performed there also
 * applies to text input.
 *
 * @param textValues text representations of the ints to match
 * @return the query produced by {@link #getTermInSetQueryFromIntValues(List)}
 * @throws NumberFormatException if any entry is not a parsable int
 */
@Override
public Query getTermInSetQueryFromTextValues(List<String> textValues) {
  List<Integer> intTerms = new ArrayList<>(textValues.size());
  for (String textValue : textValues) {
    intTerms.add(Integer.parseInt(textValue));
  }
  // Single exit point: calling IntPoint directly here would bypass the validation in the
  // typed overload and leave a second, unreachable return behind.
  return getTermInSetQueryFromIntValues(intTerms);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -139,24 +139,26 @@ private void ensureUpperIsMoreThanLower(RangeQuery rangeQuery, long lower, long

/**
 * Builds an exact-match point query for a long value.
 *
 * <p>{@code verifySearchable} is invoked first; it is defined outside this view and presumably
 * rejects fields that are not searchable.
 *
 * @param longValue value to match
 * @return a {@code LongPoint} exact query on this field's name
 */
@Override
public Query getTermQueryFromLongValue(long longValue) {
  verifySearchable("Term query");
  String fieldName = getName();
  return LongPoint.newExactQuery(fieldName, longValue);
}

/**
 * Builds a point set query matching any of the given long values.
 *
 * <p>{@code verifySearchable} is invoked first; it is defined outside this view and presumably
 * rejects fields that are not searchable.
 *
 * @param longValues values to match
 * @return a {@code LongPoint} set query on this field's name
 */
@Override
public Query getTermInSetQueryFromLongValues(List<Long> longValues) {
  verifySearchable("Term in set query");
  String fieldName = getName();
  return LongPoint.newSetQuery(fieldName, longValues);
}

/**
 * Builds an exact-match query from the text form of a long.
 *
 * <p>The text is parsed with {@link Long#parseLong(String)} and the result is handed to
 * {@link #getTermQueryFromLongValue(long)}, so the searchable check performed there also
 * applies to text input.
 *
 * @param textValue text representation of the long to match
 * @return the query produced by {@link #getTermQueryFromLongValue(long)}
 * @throws NumberFormatException if {@code textValue} is not a parsable long
 */
@Override
public Query getTermQueryFromTextValue(String textValue) {
  // Single exit point: calling LongPoint directly here would bypass the validation in the
  // typed overload and leave a second, unreachable return behind.
  long parsedValue = Long.parseLong(textValue);
  return getTermQueryFromLongValue(parsedValue);
}

/**
 * Builds a set query from the text forms of several longs.
 *
 * <p>Each entry is parsed with {@link Long#parseLong(String)}; the parsed list is handed to
 * {@link #getTermInSetQueryFromLongValues(List)}, so the searchable check performed there also
 * applies to text input.
 *
 * @param textValues text representations of the longs to match
 * @return the query produced by {@link #getTermInSetQueryFromLongValues(List)}
 * @throws NumberFormatException if any entry is not a parsable long
 */
@Override
public Query getTermInSetQueryFromTextValues(List<String> textValues) {
  List<Long> longTerms = new ArrayList<>(textValues.size());
  for (String textValue : textValues) {
    longTerms.add(Long.parseLong(textValue));
  }
  // Single exit point: calling LongPoint directly here would bypass the validation in the
  // typed overload and leave a second, unreachable return behind.
  return getTermInSetQueryFromLongValues(longTerms);
}

protected Number parseNumberString(String numberString) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -287,11 +287,13 @@ private void addFacet(Document document, String value, List<String> paths) {

/**
 * Builds an exact-match term query on this field for the given text.
 *
 * <p>{@code verifySearchable} is invoked first; it is defined outside this view and presumably
 * rejects fields that are not searchable.
 *
 * @param textValue term text to match, used as-is
 * @return a Lucene term query on this field's name
 */
@Override
public Query getTermQueryFromTextValue(String textValue) {
  verifySearchable("Term query");
  Term exactTerm = new Term(getName(), textValue);
  return new org.apache.lucene.search.TermQuery(exactTerm);
}

/**
 * Builds a query matching documents whose field value equals any of the given texts.
 *
 * <p>{@code verifySearchable} is invoked first; it is defined outside this view and presumably
 * rejects fields that are not searchable.
 *
 * @param textValues candidate term texts; each one is wrapped in a {@code BytesRef}
 * @return a Lucene term-in-set query on this field's name
 */
@Override
public Query getTermInSetQueryFromTextValues(List<String> textValues) {
  verifySearchable("Term in set query");
  List<BytesRef> encodedTerms =
      textValues.stream().map(value -> new BytesRef(value)).collect(Collectors.toList());
  String fieldName = getName();
  return new org.apache.lucene.search.TermInSetQuery(fieldName, encodedTerms);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
package com.yelp.nrtsearch.server.field.properties;

import com.yelp.nrtsearch.server.field.FieldDef;
import com.yelp.nrtsearch.server.field.IndexableFieldDef;
import com.yelp.nrtsearch.server.grpc.TermInSetQuery;
import com.yelp.nrtsearch.server.grpc.TermQuery;
import java.util.List;
Expand Down Expand Up @@ -261,21 +260,4 @@ default Query getTermInSetQueryFromLongValues(List<Long> longValues) {
default Query getTermInSetQueryFromTextValues(List<String> textValues) {
return null;
}

/**
 * Checks whether term and term-in-set queries may be issued against this field.
 *
 * <p>The default implementation requires the implementing object to be an
 * {@code IndexableFieldDef} that reports itself as searchable.
 *
 * @throws IllegalStateException if this instance is not an {@code IndexableFieldDef}, or if it
 *     is one but is not searchable
 */
default void checkTermQueriesSupported() {
  if (this instanceof IndexableFieldDef<?> indexable) {
    if (indexable.isSearchable()) {
      return;
    }
    throw new IllegalStateException(
        "Field "
            + indexable.getName()
            + " is not searchable, which is required for TermQuery / TermInSetQuery");
  }
  throw new IllegalStateException("Instance is not an IndexableFieldDef");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,6 @@ private Query getTermQuery(com.yelp.nrtsearch.server.grpc.TermQuery termQuery, I
FieldDef fieldDef = state.getFieldOrThrow(fieldName);

if (fieldDef instanceof TermQueryable termQueryable) {
termQueryable.checkTermQueriesSupported();
return termQueryable.getTermQuery(termQuery);
}

Expand All @@ -312,7 +311,6 @@ private Query getTermInSetQuery(
FieldDef fieldDef = state.getFieldOrThrow(fieldName);

if (fieldDef instanceof TermQueryable termQueryable) {
termQueryable.checkTermQueriesSupported();
return termQueryable.getTermInSetQuery(termInSetQuery);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -658,7 +658,7 @@ public void testTermQuery_noSearchOrDV() {
assertTrue(
e.getMessage()
.contains(
"Field \"stored_only\" is not searchable or does not have doc values, which is required for TermQuery / TermInSetQuery"));
"Term query requires field to be searchable or have doc values: stored_only"));
}
}

Expand Down Expand Up @@ -830,7 +830,7 @@ public void testTermInSetQuery_noSearchOrDV() {
assertTrue(
e.getMessage()
.contains(
"Field \"stored_only\" is not searchable or does not have doc values, which is required for TermQuery / TermInSetQuery"));
"Term in set query requires field to be searchable or have doc values: stored_only"));
}
}

Expand Down

0 comments on commit 1c2d664

Please sign in to comment.