Skip to content

Commit

Permalink
Refactor mapper to extend FieldMapper, adjust tests accordingly
Browse files Browse the repository at this point in the history
Signed-off-by: Martin Gaievski <gaievski@amazon.com>
  • Loading branch information
martin-gaievski committed Jul 6, 2022
1 parent 0325d0b commit ecb18fd
Show file tree
Hide file tree
Showing 4 changed files with 569 additions and 663 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -9,30 +9,32 @@
import org.apache.lucene.document.KnnVectorField;
import org.apache.lucene.index.DocValuesType;
import org.apache.lucene.index.IndexOptions;
import org.apache.lucene.search.FieldExistsQuery;
import org.apache.lucene.search.MultiTermQuery;
import org.apache.lucene.search.Query;
import org.opensearch.common.Explicit;
import org.opensearch.common.xcontent.ToXContent;
import org.opensearch.common.Nullable;
import org.opensearch.common.unit.Fuzziness;
import org.opensearch.common.xcontent.XContentBuilder;
import org.opensearch.common.xcontent.XContentParser;
import org.opensearch.common.xcontent.support.XContentMapValues;
import org.opensearch.index.query.QueryShardContext;
import org.opensearch.index.query.QueryShardException;
import org.opensearch.search.lookup.SearchLookup;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;

/**
* Field Mapper for Dense vector type. Extends FieldMapper; mapping parameters are set through the nested Builder.
*
* @opensearch.internal
*/
public final class DenseVectorFieldMapper extends ParametrizedFieldMapper {
public final class DenseVectorFieldMapper extends FieldMapper {

public static final String CONTENT_TYPE = "dense_vector";

Expand All @@ -49,12 +51,29 @@ private static DenseVectorFieldMapper toType(FieldMapper in) {
* Builder for DenseVectorFieldMapper. This class defines the set of parameters that can be applied to the dense_vector
* field type
*/
public static class Builder extends ParametrizedFieldMapper.Builder {
public static class Builder extends FieldMapper.Builder<Builder> {
private CopyTo copyTo = CopyTo.empty();
private Integer dimension = 1;
private KnnContext knnContext = null;

private final Parameter<Boolean> hasDocValues = Parameter.docValuesParam(m -> toType(m).hasDocValues, false);
/**
 * Creates a builder for the field {@code name}, handing {@code Defaults.FIELD_TYPE} to the parent builder.
 *
 * @param name name of the field being mapped
 */
public Builder(String name) {
super(name, Defaults.FIELD_TYPE);
// NOTE(review): presumably the self-typed reference inherited from the parent builder — confirm
builder = this;
}

/**
 * Builds the mapper from the values collected by this builder.
 *
 * @param context builder context used to resolve the field's full name and its multi-fields
 * @return a new {@code DenseVectorFieldMapper} carrying the configured dimension and knn context
 */
@Override
public DenseVectorFieldMapper build(BuilderContext context) {
    // Resolve the full path once and reuse it for both the field type and the mapper.
    final String fullName = buildFullName(context);
    final DenseVectorFieldType mappedFieldType = new DenseVectorFieldType(fullName, dimension, knnContext);
    // NOTE(review): the mapper constructor names its first parameter simpleName but receives the
    // full path here — confirm this is intended.
    return new DenseVectorFieldMapper(fullName, fieldType, mappedFieldType, multiFieldsBuilder.build(this, context), copyTo);
}

protected final Parameter<Integer> dimension = new Parameter<>(Names.DIMENSION.getValue(), false, () -> 1, (n, c, o) -> {
int value = XContentMapValues.nodeIntegerValue(o);
public Builder dimension(int value) {
if (value > MAX_DIMENSION) {
throw new IllegalArgumentException(
String.format(Locale.ROOT, "[dimension] value %d cannot be greater than %d for vector [%s]", value, MAX_DIMENSION, name)
Expand All @@ -65,41 +84,13 @@ public static class Builder extends ParametrizedFieldMapper.Builder {
String.format(Locale.ROOT, "[dimension] value %d must be greater than 0 for vector [%s]", value, name)
);
}
return value;
}, m -> toType(m).dimension).setSerializer((b, n, v) -> b.field(n, v.intValue()), v -> Integer.toString(v.intValue()));

private final Parameter<KnnContext> knnContext = new Parameter<>(
Names.KNN.getValue(),
false,
() -> null,
(n, c, o) -> KnnContext.parse(o),
m -> toType(m).knnContext
).setSerializer(((b, n, v) -> {
if (v == null) {
return;
}
b.startObject(n);
v.toXContent(b, ToXContent.EMPTY_PARAMS);
b.endObject();
}), m -> m.getKnnAlgorithmContext().getMethod().name());

public Builder(String name) {
super(name);
this.dimension = value;
return this;
}

@Override
protected List<Parameter<?>> getParameters() {
return List.of(dimension, knnContext, hasDocValues);
}

@Override
public DenseVectorFieldMapper build(BuilderContext context) {
return new DenseVectorFieldMapper(
buildFullName(context),
new DenseVectorFieldType(buildFullName(context), dimension.get(), knnContext.get()),
multiFieldsBuilder.build(this, context),
copyTo.build()
);
/**
 * Stores the knn context on this builder.
 *
 * @param value knn context to keep; it is stored as-is, without validation
 * @return this builder, for chaining
 */
public Builder knn(KnnContext value) {
    knnContext = value;
    return this;
}
}

Expand All @@ -113,12 +104,30 @@ public static class TypeParser implements Mapper.TypeParser {
@Override
public Mapper.Builder<?> parse(String name, Map<String, Object> node, ParserContext parserContext) throws MapperParsingException {
Builder builder = new DenseVectorFieldMapper.Builder(name);
Object dimensionField = node.get(Names.DIMENSION.getValue());
String dimension = XContentMapValues.nodeStringValue(dimensionField, null);
if (dimension == null) {
throw new MapperParsingException(String.format(Locale.ROOT, "[dimension] property must be specified for field [%s]", name));
TypeParsers.parseField(builder, name, node, parserContext);

for (Iterator<Map.Entry<String, Object>> iterator = node.entrySet().iterator(); iterator.hasNext();) {
Map.Entry<String, Object> entry = iterator.next();
String fieldName = entry.getKey();
Object fieldNode = entry.getValue();
switch (fieldName) {
case "dimension":
if (fieldNode == null) {
throw new MapperParsingException(
String.format(Locale.ROOT, "[dimension] property must be specified for field [%s]", name)
);
}
builder.dimension(XContentMapValues.nodeIntegerValue(fieldNode, 1));
iterator.remove();
break;
case "knn":
builder.knn(KnnContext.parse(fieldNode));
iterator.remove();
break;
default:
break;
}
}
builder.parse(name, parserContext, node);
return builder;
}
}
Expand All @@ -145,7 +154,7 @@ public DenseVectorFieldType(String name, Map<String, String> meta, int dimension

@Override
public ValueFetcher valueFetcher(QueryShardContext context, SearchLookup searchLookup, String format) {
throw new UnsupportedOperationException("Dense_vector does not support fields search");
throw new UnsupportedOperationException("[fields search] are not supported on [" + CONTENT_TYPE + "] fields.");
}

@Override
Expand All @@ -154,16 +163,47 @@ public String typeName() {
}

@Override
public Query existsQuery(QueryShardContext context) {
return new FieldExistsQuery(name());
public Query termQuery(Object value, QueryShardContext context) {
    // Term queries are rejected unconditionally for this field type.
    final String reason = "[term] queries are not supported on [" + CONTENT_TYPE + "] fields.";
    throw new UnsupportedOperationException(reason);
}

@Override
public Query termQuery(Object value, QueryShardContext context) {
throw new QueryShardException(
context,
"Dense_vector does not support exact searching, use KNN queries instead [" + name() + "]"
);
public Query fuzzyQuery(
    Object value,
    Fuzziness fuzziness,
    int prefixLength,
    int maxExpansions,
    boolean transpositions,
    QueryShardContext context
) {
    // Fuzzy queries are rejected unconditionally for this field type; no argument is inspected.
    final String reason = "[fuzzy] queries are not supported on [" + CONTENT_TYPE + "] fields.";
    throw new UnsupportedOperationException(reason);
}

/**
 * Always fails: prefix queries are not supported on this field type.
 *
 * @throws UnsupportedOperationException on every call
 */
@Override
public Query prefixQuery(String value, MultiTermQuery.RewriteMethod method, boolean caseInsensitive, QueryShardContext context) {
    final String reason = "[prefix] queries are not supported on [" + CONTENT_TYPE + "] fields.";
    throw new UnsupportedOperationException(reason);
}

/**
 * Always fails: wildcard queries are not supported on this field type.
 *
 * @throws UnsupportedOperationException on every call
 */
@Override
public Query wildcardQuery(
    String value,
    @Nullable MultiTermQuery.RewriteMethod method,
    boolean caseInsensitive,
    QueryShardContext context
) {
    final String reason = "[wildcard] queries are not supported on [" + CONTENT_TYPE + "] fields.";
    throw new UnsupportedOperationException(reason);
}

/**
 * Always fails: regexp queries are not supported on this field type.
 *
 * @throws UnsupportedOperationException on every call
 */
@Override
public Query regexpQuery(
    String value,
    int syntaxFlags,
    int matchFlags,
    int maxDeterminizedStates,
    MultiTermQuery.RewriteMethod method,
    QueryShardContext context
) {
    final String reason = "[regexp] queries are not supported on [" + CONTENT_TYPE + "] fields.";
    throw new UnsupportedOperationException(reason);
}

public int getDimension() {
Expand All @@ -182,8 +222,14 @@ public KnnContext getKnnContext() {
protected boolean hasDocValues;
protected String modelId;

public DenseVectorFieldMapper(String simpleName, DenseVectorFieldType mappedFieldType, MultiFields multiFields, CopyTo copyTo) {
super(simpleName, mappedFieldType, multiFields, copyTo);
public DenseVectorFieldMapper(
String simpleName,
FieldType fieldType,
DenseVectorFieldType mappedFieldType,
MultiFields multiFields,
CopyTo copyTo
) {
super(simpleName, fieldType, mappedFieldType, multiFields, copyTo);
dimension = mappedFieldType.getDimension();
fieldType = new FieldType(DenseVectorFieldMapper.Defaults.FIELD_TYPE);
isKnnEnabled = mappedFieldType.getKnnContext() != null;
Expand All @@ -207,6 +253,57 @@ protected void parseCreateField(ParseContext context) throws IOException {
parseCreateField(context, fieldType().getDimension());
}

/**
 * Compares this mapper's vector settings with {@code other} and records every mismatch in {@code conflicts}.
 *
 * <p>Checks, in order: dimension, knn metric, knn algorithm method and knn algorithm parameters. A knn
 * context present on only one side is reported as a {@code [metric]} conflict; an algorithm context present
 * on only one side is reported as a {@code [method]} conflict. Each conflict is reported at most once.
 *
 * @param other     mapper to merge with; must be a {@code DenseVectorFieldMapper}
 * @param conflicts list that receives one message per detected conflict
 */
@Override
protected void mergeOptions(FieldMapper other, List<String> conflicts) {
    DenseVectorFieldMapper denseVectorMergeWith = (DenseVectorFieldMapper) other;
    if (!Objects.equals(dimension, denseVectorMergeWith.dimension)) {
        conflicts.add("mapper [" + name() + "] has different [dimension]");
    }

    if (isOnlyOneObjectNull(knnContext, denseVectorMergeWith.knnContext)) {
        // knn is configured on exactly one side; there is nothing deeper to compare.
        conflicts.add("mapper [" + name() + "] has different [metric]");
        return;
    }
    if (knnContext == null) {
        // Neither side configures knn.
        return;
    }

    // Both knn contexts are present from here on. The metric conflict is added once
    // (it used to be added twice when the metrics differed).
    if (!Objects.equals(knnContext.getMetric(), denseVectorMergeWith.knnContext.getMetric())) {
        conflicts.add("mapper [" + name() + "] has different [metric]");
    }

    KnnAlgorithmContext knnAlgorithmContext = knnContext.getKnnAlgorithmContext();
    KnnAlgorithmContext mergeWithKnnAlgorithmContext = denseVectorMergeWith.knnContext.getKnnAlgorithmContext();

    // This check used to sit inside a "both not null" guard and could never fire, so an
    // algorithm context defined on only one side was silently accepted.
    if (isOnlyOneObjectNull(knnAlgorithmContext, mergeWithKnnAlgorithmContext)) {
        conflicts.add("mapper [" + name() + "] has different [method]");
        return;
    }
    if (knnAlgorithmContext == null) {
        // Neither side configures an algorithm.
        return;
    }

    if (!Objects.equals(knnAlgorithmContext.getMethod(), mergeWithKnnAlgorithmContext.getMethod())) {
        conflicts.add("mapper [" + name() + "] has different [method]");
    }

    // Objects.equals is null-safe: it covers "only one side null" and "both present but different".
    if (!Objects.equals(knnAlgorithmContext.getParameters(), mergeWithKnnAlgorithmContext.getParameters())) {
        conflicts.add("mapper [" + name() + "] has different [knn algorithm parameters]");
    }
}

/**
 * Tells whether exactly one of the two references is {@code null}.
 *
 * @return {@code true} when one argument is null and the other is not; {@code false} when both or neither are null
 */
private boolean isOnlyOneObjectNull(Object object1, Object object2) {
    // The two null-ness flags differ precisely when exactly one reference is null.
    return (object1 == null) != (object2 == null);
}

/**
 * Tells whether both references are non-null.
 *
 * @return {@code true} only when neither argument is {@code null}
 */
private boolean isBothObjectsNotNull(Object object1, Object object2) {
    return !(object1 == null || object2 == null);
}

protected void parseCreateField(ParseContext context, int dimension) throws IOException {

context.path().add(simpleName());
Expand Down Expand Up @@ -276,12 +373,7 @@ protected boolean docValuesByDefault() {
}

@Override
public ParametrizedFieldMapper.Builder getMergeBuilder() {
return new DenseVectorFieldMapper.Builder(simpleName()).init(this);
}

@Override
public final boolean parsesArrayValue() {
public boolean parsesArrayValue() {
return true;
}

Expand All @@ -293,6 +385,13 @@ public DenseVectorFieldType fieldType() {
/**
 * Writes this mapper's settings after the parent's body: always the {@code dimension} field, followed by a
 * {@code knn} object when a knn context is set.
 *
 * @throws IOException if writing to {@code builder} fails
 */
@Override
protected void doXContentBody(XContentBuilder builder, boolean includeDefaults, Params params) throws IOException {
    super.doXContentBody(builder, includeDefaults, params);
    builder.field("dimension", dimension);

    if (knnContext == null) {
        // No knn configuration: nothing more to emit.
        return;
    }
    builder.startObject("knn");
    knnContext.toXContent(builder, params);
    builder.endObject();
}

/**
Expand Down Expand Up @@ -325,6 +424,7 @@ static class Defaults {

static {
FIELD_TYPE.setTokenized(false);
FIELD_TYPE.setOmitNorms(true);
FIELD_TYPE.setIndexOptions(IndexOptions.NONE);
FIELD_TYPE.setDocValuesType(DocValuesType.NONE);
FIELD_TYPE.freeze();
Expand Down
Loading

0 comments on commit ecb18fd

Please sign in to comment.