diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt index f17590635447..baafaa93a4f6 100644 --- a/lucene/CHANGES.txt +++ b/lucene/CHANGES.txt @@ -58,6 +58,8 @@ API Changes New Features --------------------- +* GITHUB#12252 Add function queries for computing similarity scores between knn vectors. (Elia Porciani, Alessandro Benedetti) + * LUCENE-10010 Introduce NFARunAutomaton to run NFA directly. (Patrick Zhai) * LUCENE-10626 Hunspell: add tools to aid dictionary editing: diff --git a/lucene/queries/src/java/org/apache/lucene/queries/function/FunctionValues.java b/lucene/queries/src/java/org/apache/lucene/queries/function/FunctionValues.java index 0d0a99919006..6b2a8a1f546c 100644 --- a/lucene/queries/src/java/org/apache/lucene/queries/function/FunctionValues.java +++ b/lucene/queries/src/java/org/apache/lucene/queries/function/FunctionValues.java @@ -70,6 +70,14 @@ public boolean boolVal(int doc) throws IOException { return intVal(doc) != 0; } + public float[] floatVectorVal(int doc) throws IOException { + throw new UnsupportedOperationException(); + } + + public byte[] byteVectorVal(int doc) throws IOException { + throw new UnsupportedOperationException(); + } + /** * returns the bytes representation of the string val - TODO: should this return the indexed raw * bytes not? diff --git a/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/ByteKnnVectorFieldSource.java b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/ByteKnnVectorFieldSource.java new file mode 100644 index 000000000000..c8a4a93a2dfc --- /dev/null +++ b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/ByteKnnVectorFieldSource.java @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.queries.function.valuesource; + +import java.io.IOException; +import java.util.Map; +import java.util.Objects; +import org.apache.lucene.index.ByteVectorValues; +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.queries.function.FunctionValues; +import org.apache.lucene.queries.function.ValueSource; +import org.apache.lucene.search.DocIdSetIterator; + +/** + * An implementation for retrieving {@link FunctionValues} instances for byte knn vectors fields. + */ +public class ByteKnnVectorFieldSource extends ValueSource { + private final String fieldName; + + public ByteKnnVectorFieldSource(String fieldName) { + this.fieldName = fieldName; + } + + @Override + public FunctionValues getValues(Map context, LeafReaderContext readerContext) + throws IOException { + + final ByteVectorValues vectorValues = readerContext.reader().getByteVectorValues(fieldName); + + if (vectorValues == null) { + throw new IllegalArgumentException( + "no byte vector value is indexed for field '" + fieldName + "'"); + } + + return new VectorFieldFunction(this) { + + @Override + public byte[] byteVectorVal(int doc) throws IOException { + if (exists(doc)) { + return vectorValues.vectorValue(); + } else { + return null; + } + } + + @Override + protected DocIdSetIterator getVectorIterator() { + return vectorValues; + } + }; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + ByteKnnVectorFieldSource other = (ByteKnnVectorFieldSource) o; + return Objects.equals(fieldName, other.fieldName); + } + + @Override + public int hashCode() { + return Objects.hash(getClass().hashCode(), fieldName); + } + + @Override + public String description() { + return "ByteKnnVectorFieldSource(" + fieldName + ")"; + } +} diff --git a/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/ByteVectorSimilarityFunction.java b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/ByteVectorSimilarityFunction.java new file mode 100644 index 000000000000..fb6ec68ee9e5 --- /dev/null +++ b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/ByteVectorSimilarityFunction.java @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.queries.function.valuesource; + +import java.io.IOException; +import org.apache.lucene.queries.function.FunctionValues; +import org.apache.lucene.queries.function.ValueSource; + +/** + * ByteVectorSimilarityFunction returns a similarity function between two knn vectors + * with byte elements. + */ +public class ByteVectorSimilarityFunction extends VectorSimilarityFunction { + public ByteVectorSimilarityFunction( + org.apache.lucene.index.VectorSimilarityFunction similarityFunction, + ValueSource vector1, + ValueSource vector2) { + super(similarityFunction, vector1, vector2); + } + + @Override + protected float func(int doc, FunctionValues f1, FunctionValues f2) throws IOException { + + var v1 = f1.byteVectorVal(doc); + var v2 = f2.byteVectorVal(doc); + + if (v1 == null || v2 == null) { + return 0.f; + } + + assert v1.length == v2.length : "Vectors must have the same length"; + + return similarityFunction.compare(v1, v2); + } +} diff --git a/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/ConstKnnByteVectorValueSource.java b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/ConstKnnByteVectorValueSource.java new file mode 100644 index 000000000000..4996e026abee --- /dev/null +++ b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/ConstKnnByteVectorValueSource.java @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.queries.function.valuesource; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Map; +import java.util.Objects; +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.queries.function.FunctionValues; +import org.apache.lucene.queries.function.ValueSource; + +/** Function that returns a constant byte vector value for every document. */ +public class ConstKnnByteVectorValueSource extends ValueSource { + private final byte[] vector; + + public ConstKnnByteVectorValueSource(byte[] constVector) { + this.vector = Objects.requireNonNull(constVector, "constVector"); + } + + @Override + public FunctionValues getValues(Map context, LeafReaderContext readerContext) + throws IOException { + return new FunctionValues() { + @Override + public byte[] byteVectorVal(int doc) { + return vector; + } + + @Override + public String strVal(int doc) { + return Arrays.toString(vector); + } + + @Override + public String toString(int doc) throws IOException { + return description() + '=' + strVal(doc); + } + }; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + ConstKnnByteVectorValueSource other = (ConstKnnByteVectorValueSource) o; + return Arrays.equals(vector, other.vector); + } + + @Override + public int hashCode() { + return Objects.hash(getClass().hashCode(), Arrays.hashCode(vector)); + } + + @Override + public String description() { + return "ConstKnnByteVectorValueSource(" + Arrays.toString(vector) + ')'; + } +} diff --git a/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/ConstKnnFloatValueSource.java b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/ConstKnnFloatValueSource.java new file mode 100644 index 000000000000..57c016eb793e --- /dev/null +++ b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/ConstKnnFloatValueSource.java @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.queries.function.valuesource; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Map; +import java.util.Objects; +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.queries.function.FunctionValues; +import org.apache.lucene.queries.function.ValueSource; +import org.apache.lucene.util.VectorUtil; + +/** Function that returns a constant float vector value for every document. */ +public class ConstKnnFloatValueSource extends ValueSource { + private final float[] vector; + + public ConstKnnFloatValueSource(float[] constVector) { + this.vector = VectorUtil.checkFinite(Objects.requireNonNull(constVector, "constVector")); + } + + @Override + public FunctionValues getValues(Map context, LeafReaderContext readerContext) + throws IOException { + return new FunctionValues() { + @Override + public float[] floatVectorVal(int doc) { + return vector; + } + + @Override + public String strVal(int doc) { + return Arrays.toString(vector); + } + + @Override + public String toString(int doc) throws IOException { + return description() + '=' + strVal(doc); + } + }; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + ConstKnnFloatValueSource other = (ConstKnnFloatValueSource) o; + return Arrays.equals(vector, other.vector); + } + + @Override + public int hashCode() { + return Objects.hash(getClass().hashCode(), Arrays.hashCode(vector)); + } + + @Override + public String description() { + return "ConstKnnFloatValueSource(" + Arrays.toString(vector) + ')'; + } +} diff --git a/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/FloatKnnVectorFieldSource.java b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/FloatKnnVectorFieldSource.java new file mode 100644 index 000000000000..9a1f27a7c79d --- /dev/null +++ b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/FloatKnnVectorFieldSource.java @@ -0,0 +1,83 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.queries.function.valuesource; + +import java.io.IOException; +import java.util.Map; +import java.util.Objects; +import org.apache.lucene.index.FloatVectorValues; +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.queries.function.FunctionValues; +import org.apache.lucene.queries.function.ValueSource; +import org.apache.lucene.search.DocIdSetIterator; + +/** + * An implementation for retrieving {@link FunctionValues} instances for float knn vectors fields. + */ +public class FloatKnnVectorFieldSource extends ValueSource { + private final String fieldName; + + public FloatKnnVectorFieldSource(String fieldName) { + this.fieldName = fieldName; + } + + @Override + public FunctionValues getValues(Map context, LeafReaderContext readerContext) + throws IOException { + + final FloatVectorValues vectorValues = readerContext.reader().getFloatVectorValues(fieldName); + + if (vectorValues == null) { + throw new IllegalArgumentException( + "no float vector value is indexed for field '" + fieldName + "'"); + } + return new VectorFieldFunction(this) { + + @Override + public float[] floatVectorVal(int doc) throws IOException { + if (exists(doc)) { + return vectorValues.vectorValue(); + } else { + return null; + } + } + + @Override + protected DocIdSetIterator getVectorIterator() { + return vectorValues; + } + }; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + FloatKnnVectorFieldSource other = (FloatKnnVectorFieldSource) o; + return Objects.equals(fieldName, other.fieldName); + } + + @Override + public int hashCode() { + return Objects.hash(getClass().hashCode(), fieldName); + } + + @Override + public String description() { + return "FloatKnnVectorFieldSource(" + fieldName + ")"; + } +} diff --git a/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/FloatVectorSimilarityFunction.java b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/FloatVectorSimilarityFunction.java new file mode 100644 index 000000000000..296775388856 --- /dev/null +++ b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/FloatVectorSimilarityFunction.java @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.queries.function.valuesource; + +import java.io.IOException; +import org.apache.lucene.queries.function.FunctionValues; +import org.apache.lucene.queries.function.ValueSource; + +/** + * FloatVectorSimilarityFunction returns a similarity function between two knn vectors + * with float elements. + */ +public class FloatVectorSimilarityFunction extends VectorSimilarityFunction { + public FloatVectorSimilarityFunction( + org.apache.lucene.index.VectorSimilarityFunction similarityFunction, + ValueSource vector1, + ValueSource vector2) { + super(similarityFunction, vector1, vector2); + } + + @Override + protected float func(int doc, FunctionValues f1, FunctionValues f2) throws IOException { + + var v1 = f1.floatVectorVal(doc); + var v2 = f2.floatVectorVal(doc); + + if (v1 == null || v2 == null) { + return 0.f; + } + + assert v1.length == v2.length : "Vectors must have the same length"; + return similarityFunction.compare(v1, v2); + } +} diff --git a/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/VectorFieldFunction.java b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/VectorFieldFunction.java new file mode 100644 index 000000000000..de64984249fe --- /dev/null +++ b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/VectorFieldFunction.java @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.queries.function.valuesource; + +import java.io.IOException; +import org.apache.lucene.queries.function.FunctionValues; +import org.apache.lucene.queries.function.ValueSource; +import org.apache.lucene.search.DocIdSetIterator; + +/** An implementation for retrieving {@link FunctionValues} instances for knn vectors fields. */ +public abstract class VectorFieldFunction extends FunctionValues { + + protected final ValueSource valueSource; + int lastDocID; + + protected VectorFieldFunction(ValueSource valueSource) { + this.valueSource = valueSource; + } + + protected abstract DocIdSetIterator getVectorIterator(); + + @Override + public String toString(int doc) throws IOException { + return valueSource.description() + strVal(doc); + } + + @Override + public boolean exists(int doc) throws IOException { + if (doc < lastDocID) { + throw new IllegalArgumentException( + "docs were sent out-of-order: lastDocID=" + lastDocID + " vs docID=" + doc); + } + + lastDocID = doc; + + int curDocID = getVectorIterator().docID(); + if (doc > curDocID) { + curDocID = getVectorIterator().advance(doc); + } + return doc == curDocID; + } +} diff --git a/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/VectorSimilarityFunction.java b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/VectorSimilarityFunction.java new file mode 100644 index 000000000000..9ba2d359a568 --- /dev/null +++ b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/VectorSimilarityFunction.java @@ -0,0 +1,96 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.queries.function.valuesource; + +import java.io.IOException; +import java.util.Map; +import java.util.Objects; +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.queries.function.FunctionValues; +import org.apache.lucene.queries.function.ValueSource; + +/** VectorSimilarityFunction returns a similarity function between two knn vectors. */ +public abstract class VectorSimilarityFunction extends ValueSource { + + protected final org.apache.lucene.index.VectorSimilarityFunction similarityFunction; + protected final ValueSource vector1; + protected final ValueSource vector2; + + public VectorSimilarityFunction( + org.apache.lucene.index.VectorSimilarityFunction similarityFunction, + ValueSource vector1, + ValueSource vector2) { + + this.similarityFunction = similarityFunction; + this.vector1 = vector1; + this.vector2 = vector2; + } + + @Override + public FunctionValues getValues(Map context, LeafReaderContext readerContext) + throws IOException { + + final FunctionValues vector1Vals = vector1.getValues(context, readerContext); + final FunctionValues vector2Vals = vector2.getValues(context, readerContext); + return new FunctionValues() { + @Override + public float floatVal(int doc) throws IOException { + return func(doc, vector1Vals, vector2Vals); + } + + @Override + public String strVal(int doc) throws IOException { + return Float.toString(floatVal(doc)); + } + + @Override + public boolean exists(int doc) throws IOException { + return MultiFunction.allExists(doc, vector1Vals, vector2Vals); + } + + @Override + public String toString(int doc) throws IOException { + return description() + " = " + strVal(doc); + } + }; + } + + protected abstract float func(int doc, FunctionValues f1, FunctionValues f2) throws IOException; + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + return Objects.equals(vector1, ((VectorSimilarityFunction) o).vector1) + && Objects.equals(vector2, ((VectorSimilarityFunction) o).vector2); + } + + @Override + public int hashCode() { + return Objects.hash(similarityFunction, vector1, vector2); + } + + @Override + public String description() { + return similarityFunction.name() + + "(" + + vector1.description() + + ", " + + vector2.description() + + ")"; + } +} diff --git a/lucene/queries/src/test/org/apache/lucene/queries/function/TestKnnVectorSimilarityFunctions.java b/lucene/queries/src/test/org/apache/lucene/queries/function/TestKnnVectorSimilarityFunctions.java new file mode 100644 index 000000000000..12144b252ba0 --- /dev/null +++ b/lucene/queries/src/test/org/apache/lucene/queries/function/TestKnnVectorSimilarityFunctions.java @@ -0,0 +1,260 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.queries.function; + +import java.util.List; +import org.apache.lucene.analysis.Analyzer; +import org.apache.lucene.document.Document; +import org.apache.lucene.document.Field; +import org.apache.lucene.document.KnnByteVectorField; +import org.apache.lucene.document.KnnFloatVectorField; +import org.apache.lucene.document.SortedDocValuesField; +import org.apache.lucene.document.StringField; +import org.apache.lucene.index.IndexReader; +import org.apache.lucene.index.IndexWriterConfig; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.queries.function.valuesource.ByteKnnVectorFieldSource; +import org.apache.lucene.queries.function.valuesource.ByteVectorSimilarityFunction; +import org.apache.lucene.queries.function.valuesource.ConstKnnByteVectorValueSource; +import org.apache.lucene.queries.function.valuesource.ConstKnnFloatValueSource; +import org.apache.lucene.queries.function.valuesource.FloatKnnVectorFieldSource; +import org.apache.lucene.queries.function.valuesource.FloatVectorSimilarityFunction; +import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.Query; +import org.apache.lucene.search.ScoreDoc; +import org.apache.lucene.search.Sort; +import org.apache.lucene.search.SortField; +import org.apache.lucene.search.TopDocs; +import org.apache.lucene.store.Directory; +import org.apache.lucene.tests.analysis.MockAnalyzer; +import org.apache.lucene.tests.index.RandomIndexWriter; +import org.apache.lucene.tests.search.CheckHits; +import org.apache.lucene.tests.util.LuceneTestCase; +import org.apache.lucene.util.BytesRef; +import org.junit.AfterClass; +import org.junit.BeforeClass; +import org.junit.Test; + +public class TestKnnVectorSimilarityFunctions extends LuceneTestCase { + static Directory dir; + static Analyzer analyzer; + static IndexReader reader; + static IndexSearcher searcher; + static final List documents = List.of("1", "2"); + + @BeforeClass + public static void beforeClass() throws Exception { + dir = newDirectory(); + analyzer = new MockAnalyzer(random()); + IndexWriterConfig iwConfig = newIndexWriterConfig(analyzer); + iwConfig.setMergePolicy(newLogMergePolicy()); + RandomIndexWriter iw = new RandomIndexWriter(random(), dir, iwConfig); + + Document document = new Document(); + document.add(new StringField("id", "1", Field.Store.NO)); + document.add(new SortedDocValuesField("id", new BytesRef("1"))); + document.add(new KnnFloatVectorField("knnFloatField1", new float[] {1.f, 2.f, 3.f})); + document.add(new KnnFloatVectorField("knnFloatField2", new float[] {5.2f, 3.2f, 3.1f})); + + // add only to the first document + document.add(new KnnFloatVectorField("knnFloatField3", new float[] {1.0f, 1.0f, 1.0f})); + document.add(new KnnByteVectorField("knnByteField3", new byte[] {1, 1, 1})); + + document.add(new KnnByteVectorField("knnByteField1", new byte[] {1, 2, 3})); + document.add(new KnnByteVectorField("knnByteField2", new byte[] {4, 2, 3})); + iw.addDocument(document); + + Document document2 = new Document(); + document2.add(new StringField("id", "2", Field.Store.NO)); + document2.add(new SortedDocValuesField("id", new BytesRef("2"))); + document2.add(new KnnFloatVectorField("knnFloatField1", new float[] {1.f, 2.f, 3.f})); + document2.add(new KnnFloatVectorField("knnFloatField2", new float[] {5.2f, 3.2f, 3.1f})); + + document2.add(new KnnByteVectorField("knnByteField1", new byte[] {1, 2, 3})); + document2.add(new KnnByteVectorField("knnByteField2", new byte[] {4, 2, 3})); + iw.addDocument(document2); + + reader = iw.getReader(); + searcher = newSearcher(reader); + iw.close(); + } + + @AfterClass + public static void afterClass() throws Exception { + searcher = null; + reader.close(); + reader = null; + dir.close(); + dir = null; + analyzer.close(); + analyzer = null; + } + + @Test + public void vectorSimilarity_floatConstantVectors_shouldReturnFloatSimilarity() throws Exception { + var v1 = new ConstKnnFloatValueSource(new float[] {1, 2, 3}); + var v2 = new ConstKnnFloatValueSource(new float[] {5, 4, 1}); + assertHits( + new FunctionQuery( + new FloatVectorSimilarityFunction(VectorSimilarityFunction.EUCLIDEAN, v1, v2)), + new float[] {0.04f, 0.04f}); + } + + @Test + public void vectorSimilarity_byteConstantVectors_shouldReturnFloatSimilarity() throws Exception { + var v1 = new ConstKnnByteVectorValueSource(new byte[] {1, 2, 3}); + var v2 = new ConstKnnByteVectorValueSource(new byte[] {2, 5, 6}); + assertHits( + new FunctionQuery( + new ByteVectorSimilarityFunction(VectorSimilarityFunction.EUCLIDEAN, v1, v2)), + new float[] {0.05f, 0.05f}); + } + + @Test + public void vectorSimilarity_floatFieldVectors_shouldReturnFloatSimilarity() throws Exception { + var v1 = new FloatKnnVectorFieldSource("knnFloatField1"); + var v2 = new FloatKnnVectorFieldSource("knnFloatField2"); + assertHits( + new FunctionQuery( + new FloatVectorSimilarityFunction(VectorSimilarityFunction.EUCLIDEAN, v1, v2)), + new float[] {0.049776014f, 0.049776014f}); + } + + @Test + public void vectorSimilarity_byteFieldVectors_shouldReturnFloatSimilarity() throws Exception { + var v1 = new ByteKnnVectorFieldSource("knnByteField1"); + var v2 = new ByteKnnVectorFieldSource("knnByteField2"); + assertHits( + new FunctionQuery( + new ByteVectorSimilarityFunction(VectorSimilarityFunction.EUCLIDEAN, v1, v2)), + new float[] {0.1f, 0.1f}); + } + + @Test + public void vectorSimilarity_FloatConstAndFloatFieldVectors_shouldReturnFloatSimilarity() + throws Exception { + var v1 = new ConstKnnFloatValueSource(new float[] {1, 2, 4}); + var v2 = new FloatKnnVectorFieldSource("knnFloatField1"); + assertHits( + new FunctionQuery( + new FloatVectorSimilarityFunction(VectorSimilarityFunction.EUCLIDEAN, v1, v2)), + new float[] {0.5f, 0.5f}); + } + + @Test + public void vectorSimilarity_ByteConstAndByteFieldVectors_shouldReturnFloatSimilarity() + throws Exception { + var v1 = new ConstKnnByteVectorValueSource(new byte[] {1, 2, 4}); + var v2 = new ByteKnnVectorFieldSource("knnByteField1"); + assertHits( + new FunctionQuery( + new ByteVectorSimilarityFunction(VectorSimilarityFunction.EUCLIDEAN, v1, v2)), + new float[] {0.5f, 0.5f}); + } + + @Test + public void vectorSimilarity_missingFloatVectorField_shouldReturnZero() throws Exception { + var v1 = new ConstKnnFloatValueSource(new float[] {2.f, 1.f, 1.f}); + var v2 = new FloatKnnVectorFieldSource("knnFloatField3"); + assertHits( + new FunctionQuery( + new FloatVectorSimilarityFunction(VectorSimilarityFunction.EUCLIDEAN, v1, v2)), + new float[] {0.5f, 0.f}); + } + + @Test + public void vectorSimilarity_missingByteVectorField_shouldReturnZero() throws Exception { + var v1 = new ConstKnnByteVectorValueSource(new byte[] {2, 1, 1}); + var v2 = new ByteKnnVectorFieldSource("knnByteField3"); + assertHits( + new FunctionQuery( + new ByteVectorSimilarityFunction(VectorSimilarityFunction.EUCLIDEAN, v1, v2)), + new float[] {0.5f, 0.f}); + } + + @Test + public void vectorSimilarity_twoVectorsWithDifferentDimensions_shouldRaiseException() { + ValueSource v1 = new ConstKnnByteVectorValueSource(new byte[] {1, 2, 3, 4}); + ValueSource v2 = new ByteKnnVectorFieldSource("knnByteField1"); + ByteVectorSimilarityFunction byteDenseVectorSimilarityFunction = + new ByteVectorSimilarityFunction(VectorSimilarityFunction.EUCLIDEAN, v1, v2); + assertThrows( + AssertionError.class, + () -> searcher.search(new FunctionQuery(byteDenseVectorSimilarityFunction), 10)); + + v1 = new ConstKnnFloatValueSource(new float[] {1.f, 2.f}); + v2 = new FloatKnnVectorFieldSource("knnFloatField1"); + FloatVectorSimilarityFunction floatDenseVectorSimilarityFunction = + new FloatVectorSimilarityFunction(VectorSimilarityFunction.EUCLIDEAN, v1, v2); + assertThrows( + AssertionError.class, + () -> searcher.search(new FunctionQuery(floatDenseVectorSimilarityFunction), 10)); + } + + @Test + public void vectorSimilarity_byteAndFloatVectors_shouldRaiseException() { + var v1 = new ConstKnnByteVectorValueSource(new byte[] {1, 2, 3}); + ValueSource v2 = new ByteKnnVectorFieldSource("knnByteField1"); + FloatVectorSimilarityFunction floatDenseVectorSimilarityFunction = + new FloatVectorSimilarityFunction(VectorSimilarityFunction.EUCLIDEAN, v1, v2); + assertThrows( + UnsupportedOperationException.class, + () -> searcher.search(new FunctionQuery(floatDenseVectorSimilarityFunction), 10)); + + v1 = new ConstKnnByteVectorValueSource(new byte[] {1, 2, 3}); + v2 = new FloatKnnVectorFieldSource("knnFloatField1"); + ByteVectorSimilarityFunction byteDenseVectorSimilarityFunction = + new ByteVectorSimilarityFunction(VectorSimilarityFunction.EUCLIDEAN, v1, v2); + assertThrows( + UnsupportedOperationException.class, + () -> searcher.search(new FunctionQuery(byteDenseVectorSimilarityFunction), 10)); + } + + @Test + public void vectorSimilarity_wrongFieldType_shouldRaiseException() { + ValueSource v1 = new ByteKnnVectorFieldSource("knnByteField1"); + ValueSource v2 = new ByteKnnVectorFieldSource("knnFloatField2"); + ByteVectorSimilarityFunction byteDenseVectorSimilarityFunction = + new ByteVectorSimilarityFunction(VectorSimilarityFunction.EUCLIDEAN, v1, v2); + + assertThrows( + IllegalArgumentException.class, + () -> searcher.search(new FunctionQuery(byteDenseVectorSimilarityFunction), 10)); + + v1 = new FloatKnnVectorFieldSource("knnByteField1"); + v2 = new FloatKnnVectorFieldSource("knnFloatField2"); + FloatVectorSimilarityFunction floatVectorSimilarityFunction = + new FloatVectorSimilarityFunction(VectorSimilarityFunction.EUCLIDEAN, v1, v2); + + assertThrows( + IllegalArgumentException.class, + () -> searcher.search(new FunctionQuery(floatVectorSimilarityFunction), 10)); + } + + private static void assertHits(Query q, float[] scores) throws Exception { + ScoreDoc[] expected = new ScoreDoc[scores.length]; + int[] expectedDocs = new int[scores.length]; + for (int i = 0; i < expected.length; i++) { + expectedDocs[i] = i; + expected[i] = new ScoreDoc(i, scores[i]); + } + TopDocs docs = + searcher.search( + q, documents.size(), new Sort(new SortField("id", SortField.Type.STRING)), true); + CheckHits.checkHitsQuery(q, expected, docs.scoreDocs, expectedDocs); + } +}