diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt index 7997c5e619be..faba24e3c107 100644 --- a/lucene/CHANGES.txt +++ b/lucene/CHANGES.txt @@ -23,7 +23,8 @@ API Changes New Features --------------------- -(No changes) +* GITHUB#12548: Added similarityToQueryVector API to compute vector similarity scores + with DoubleValuesSource. (Shubham Chaudhary) Improvements --------------------- diff --git a/lucene/core/src/java/org/apache/lucene/search/ByteVectorSimilarityValuesSource.java b/lucene/core/src/java/org/apache/lucene/search/ByteVectorSimilarityValuesSource.java new file mode 100644 index 000000000000..06afe3d3a54f --- /dev/null +++ b/lucene/core/src/java/org/apache/lucene/search/ByteVectorSimilarityValuesSource.java @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package org.apache.lucene.search;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.Objects;
+import org.apache.lucene.index.ByteVectorValues;
+import org.apache.lucene.index.LeafReaderContext;
+import org.apache.lucene.index.VectorSimilarityFunction;
+
+/**
+ * A {@link DoubleValuesSource} which computes the vector similarity scores between the query vector
+ * and the {@link org.apache.lucene.document.KnnByteVectorField} for documents.
+ */
+class ByteVectorSimilarityValuesSource extends VectorSimilarityValuesSource {
+  private final byte[] queryVector;
+
+  /**
+   * Creates a source that scores documents against the given byte query vector.
+   *
+   * @param vector the query vector; the array is stored as-is, not copied
+   * @param fieldName name of the byte vector field to read document vectors from
+   */
+  public ByteVectorSimilarityValuesSource(byte[] vector, String fieldName) {
+    super(fieldName);
+    this.queryVector = vector;
+  }
+
+  /**
+   * Returns per-document similarity scores for one segment, computed with the similarity function
+   * recorded in the field's {@code FieldInfo}.
+   *
+   * @param ctx the segment to read vectors from
+   * @param scores not used by this implementation
+   * @return the similarity values, or {@link DoubleValues#EMPTY} when the segment holds no byte
+   *     vectors for the field
+   */
+  @Override
+  public DoubleValues getValues(LeafReaderContext ctx, DoubleValues scores) throws IOException {
+    final ByteVectorValues vectorValues = ctx.reader().getByteVectorValues(fieldName);
+    if (vectorValues == null) {
+      // No byte vectors for this field in this segment: nothing can match here. Returning early
+      // also avoids dereferencing a FieldInfo that may not exist in this segment.
+      return DoubleValues.EMPTY;
+    }
+    final VectorSimilarityFunction function =
+        ctx.reader().getFieldInfos().fieldInfo(fieldName).getVectorSimilarityFunction();
+    return new DoubleValues() {
+      @Override
+      public double doubleValue() throws IOException {
+        return function.compare(queryVector, vectorValues.vectorValue());
+      }
+
+      @Override
+      public boolean advanceExact(int doc) throws IOException {
+        // Never ask the iterator to move backwards; only advance when it is not already
+        // positioned on the requested document.
+        return doc >= vectorValues.docID()
+            && (vectorValues.docID() == doc || vectorValues.advance(doc) == doc);
+      }
+    };
+  }
+
+  @Override
+  public int hashCode() {
+    return Objects.hash(fieldName, Arrays.hashCode(queryVector));
+  }
+
+  @Override
+  public boolean equals(Object obj) {
+    if (this == obj) return true;
+    if (obj == null || getClass() != obj.getClass()) return false;
+    ByteVectorSimilarityValuesSource other = (ByteVectorSimilarityValuesSource) obj;
+    return Objects.equals(fieldName, other.fieldName)
+        && Arrays.equals(queryVector, other.queryVector);
+  }
+
+  @Override
+  public String toString() {
+    return "ByteVectorSimilarityValuesSource(fieldName="
+        + fieldName
+        + " queryVector="
+        + Arrays.toString(queryVector)
+        + ")";
+  }
+}
diff --git a/lucene/core/src/java/org/apache/lucene/search/DoubleValuesSource.java b/lucene/core/src/java/org/apache/lucene/search/DoubleValuesSource.java
index f27b791dd95a..6034c97b8c08 100644
--- a/lucene/core/src/java/org/apache/lucene/search/DoubleValuesSource.java
+++ b/lucene/core/src/java/org/apache/lucene/search/DoubleValuesSource.java
@@ -24,6 +24,7 @@
 import org.apache.lucene.index.DocValues;
 import org.apache.lucene.index.LeafReaderContext;
 import org.apache.lucene.index.NumericDocValues;
+import org.apache.lucene.index.VectorEncoding;
 import org.apache.lucene.search.comparators.DoubleComparator;
 
 /**
@@ -172,6 +173,52 @@ public LongValuesSource rewrite(IndexSearcher searcher) throws IOException {
     }
   }
 
+  /**
+   * Returns a DoubleValues instance for computing the vector similarity score per document against
+   * the byte query vector. If the field does not exist in the given segment, {@link
+   * DoubleValues#EMPTY} is returned.
+   *
+   * @param ctx the context for which to return the DoubleValues
+   * @param queryVector byte query vector
+   * @param vectorField knn byte field name
+   * @return DoubleValues instance
+   * @throws IOException if an {@link IOException} occurs
+   * @throws IllegalArgumentException if the field's vector encoding is not {@link
+   *     VectorEncoding#BYTE}
+   */
+  public static DoubleValues similarityToQueryVector(
+      LeafReaderContext ctx, byte[] queryVector, String vectorField) throws IOException {
+    // Fully qualified so this method does not depend on an additional import.
+    final org.apache.lucene.index.FieldInfo fieldInfo =
+        ctx.reader().getFieldInfos().fieldInfo(vectorField);
+    if (fieldInfo == null) {
+      // The field is absent from this segment, so no document here can have a vector for it.
+      return DoubleValues.EMPTY;
+    }
+    if (fieldInfo.getVectorEncoding() != VectorEncoding.BYTE) {
+      throw new IllegalArgumentException(
+          "Field "
+              + vectorField
+              + " does not have the expected vector encoding: "
+              + VectorEncoding.BYTE);
+    }
+    return new ByteVectorSimilarityValuesSource(queryVector, vectorField).getValues(ctx, null);
+  }
+
+  /**
+   * Returns a DoubleValues instance for computing the vector similarity score per document against
+   * the float query vector
+   *
+   * @param ctx the context for which to return the DoubleValues
+   * @param queryVector float query vector
+   * @param vectorField knn float field name
+   * @return DoubleValues instance
+   * @throws IOException if an {@link
IOException} occurs
+   * @throws IllegalArgumentException if the field's vector encoding is not {@link
+   *     VectorEncoding#FLOAT32}
+   */
+  public static DoubleValues similarityToQueryVector(
+      LeafReaderContext ctx, float[] queryVector, String vectorField) throws IOException {
+    // Fully qualified so this method does not depend on an additional import.
+    final org.apache.lucene.index.FieldInfo fieldInfo =
+        ctx.reader().getFieldInfos().fieldInfo(vectorField);
+    if (fieldInfo == null) {
+      // The field is absent from this segment, so no document here can have a vector for it.
+      // Report "no values" rather than failing with a NullPointerException.
+      return DoubleValues.EMPTY;
+    }
+    if (fieldInfo.getVectorEncoding() != VectorEncoding.FLOAT32) {
+      throw new IllegalArgumentException(
+          "Field "
+              + vectorField
+              + " does not have the expected vector encoding: "
+              + VectorEncoding.FLOAT32);
+    }
+    return new FloatVectorSimilarityValuesSource(queryVector, vectorField).getValues(ctx, null);
+  }
+
   /**
    * Creates a DoubleValuesSource that wraps a generic NumericDocValues field
    *
diff --git a/lucene/core/src/java/org/apache/lucene/search/FloatVectorSimilarityValuesSource.java b/lucene/core/src/java/org/apache/lucene/search/FloatVectorSimilarityValuesSource.java
new file mode 100644
index 000000000000..b914ce75e20b
--- /dev/null
+++ b/lucene/core/src/java/org/apache/lucene/search/FloatVectorSimilarityValuesSource.java
@@ -0,0 +1,81 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.lucene.search;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.Objects;
+import org.apache.lucene.index.FloatVectorValues;
+import org.apache.lucene.index.LeafReaderContext;
+import org.apache.lucene.index.VectorSimilarityFunction;
+
+/**
+ * A {@link DoubleValuesSource} which computes the vector similarity scores between the query vector
+ * and the {@link org.apache.lucene.document.KnnFloatVectorField} for documents.
+ */
+class FloatVectorSimilarityValuesSource extends VectorSimilarityValuesSource {
+
+  private final float[] queryVector;
+
+  /**
+   * Creates a source that scores documents against the given float query vector.
+   *
+   * @param vector the query vector; the array is stored as-is, not copied
+   * @param fieldName name of the float vector field to read document vectors from
+   */
+  public FloatVectorSimilarityValuesSource(float[] vector, String fieldName) {
+    super(fieldName);
+    this.queryVector = vector;
+  }
+
+  /**
+   * Returns per-document similarity scores for one segment, computed with the similarity function
+   * recorded in the field's {@code FieldInfo}.
+   *
+   * @param ctx the segment to read vectors from
+   * @param scores not used by this implementation
+   * @return the similarity values, or {@link DoubleValues#EMPTY} when the segment holds no float
+   *     vectors for the field
+   */
+  @Override
+  public DoubleValues getValues(LeafReaderContext ctx, DoubleValues scores) throws IOException {
+    final FloatVectorValues vectorValues = ctx.reader().getFloatVectorValues(fieldName);
+    if (vectorValues == null) {
+      // No float vectors for this field in this segment: nothing can match here. Returning early
+      // also avoids dereferencing a FieldInfo that may not exist in this segment.
+      return DoubleValues.EMPTY;
+    }
+    final VectorSimilarityFunction function =
+        ctx.reader().getFieldInfos().fieldInfo(fieldName).getVectorSimilarityFunction();
+    return new DoubleValues() {
+      @Override
+      public double doubleValue() throws IOException {
+        return function.compare(queryVector, vectorValues.vectorValue());
+      }
+
+      @Override
+      public boolean advanceExact(int doc) throws IOException {
+        // Never ask the iterator to move backwards; only advance when it is not already
+        // positioned on the requested document.
+        return doc >= vectorValues.docID()
+            && (vectorValues.docID() == doc || vectorValues.advance(doc) == doc);
+      }
+    };
+  }
+
+  @Override
+  public int hashCode() {
+    return Objects.hash(fieldName, Arrays.hashCode(queryVector));
+  }
+
+  @Override
+  public boolean equals(Object obj) {
+    if (this == obj) return true;
+    if (obj == null || getClass() != obj.getClass()) return false;
+    FloatVectorSimilarityValuesSource other = (FloatVectorSimilarityValuesSource) obj;
+    return Objects.equals(fieldName, other.fieldName)
+        && Arrays.equals(queryVector, other.queryVector);
+  }
+
+  @Override
+  public String toString() {
+    return "FloatVectorSimilarityValuesSource(fieldName="
+        + fieldName
+        + " queryVector="
+        + Arrays.toString(queryVector)
+        + ")";
+  }
+}
diff --git a/lucene/core/src/java/org/apache/lucene/search/VectorSimilarityValuesSource.java b/lucene/core/src/java/org/apache/lucene/search/VectorSimilarityValuesSource.java
new file mode 100644
index 000000000000..639e225d665e
--- /dev/null
+++ b/lucene/core/src/java/org/apache/lucene/search/VectorSimilarityValuesSource.java
@@ -0,0 +1,53 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.lucene.search;
+
+import java.io.IOException;
+import org.apache.lucene.index.LeafReaderContext;
+
+/**
+ * An abstract class that provides the vector similarity scores between the query vector and the
+ * {@link org.apache.lucene.document.KnnFloatVectorField} or {@link
+ * org.apache.lucene.document.KnnByteVectorField} for documents.
+ */
+abstract class VectorSimilarityValuesSource extends DoubleValuesSource {
+  /** Name of the vector field that subclasses read document vectors from. */
+  protected final String fieldName;
+
+  /**
+   * Records the vector field name for use by concrete subclasses.
+   *
+   * @param fieldName name of the vector field
+   */
+  public VectorSimilarityValuesSource(String fieldName) {
+    this.fieldName = fieldName;
+  }
+
+  /** Always returns this instance unchanged. */
+  @Override
+  public DoubleValuesSource rewrite(IndexSearcher searcher) throws IOException {
+    return this;
+  }
+
+  /** Always {@code false}. */
+  @Override
+  public boolean needsScores() {
+    return false;
+  }
+
+  /** Reports every leaf as cacheable. */
+  @Override
+  public boolean isCacheable(LeafReaderContext ctx) {
+    return true;
+  }
+
+  /** Implemented by subclasses to produce the per-document similarity values for a leaf. */
+  @Override
+  public abstract DoubleValues getValues(LeafReaderContext ctx, DoubleValues scores)
+      throws IOException;
+}
diff --git a/lucene/core/src/test/org/apache/lucene/search/TestVectorSimilarityValuesSource.java b/lucene/core/src/test/org/apache/lucene/search/TestVectorSimilarityValuesSource.java
new file mode 100644
index 000000000000..840d3fc6b8b7
--- /dev/null
+++ b/lucene/core/src/test/org/apache/lucene/search/TestVectorSimilarityValuesSource.java
@@ -0,0 +1,381 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.lucene.search;
+
+import java.io.IOException;
+import org.apache.lucene.analysis.Analyzer;
+import org.apache.lucene.document.Document;
+import org.apache.lucene.document.Field;
+import org.apache.lucene.document.KnnByteVectorField;
+import org.apache.lucene.document.KnnFloatVectorField;
+import org.apache.lucene.document.SortedDocValuesField;
+import org.apache.lucene.document.StringField;
+import org.apache.lucene.index.IndexReader;
+import org.apache.lucene.index.IndexWriterConfig;
+import org.apache.lucene.index.LeafReaderContext;
+import org.apache.lucene.index.VectorSimilarityFunction;
+import org.apache.lucene.store.Directory;
+import org.apache.lucene.tests.analysis.MockAnalyzer;
+import org.apache.lucene.tests.index.RandomIndexWriter;
+import org.apache.lucene.tests.util.LuceneTestCase;
+import org.apache.lucene.util.BytesRef;
+import org.apache.lucene.util.IOUtils;
+import org.junit.AfterClass;
+import org.junit.BeforeClass;
+
+/**
+ * Tests the {@code DoubleValuesSource.similarityToQueryVector} overloads against a three-document
+ * index holding float and byte vector fields with each similarity function.
+ */
+public class TestVectorSimilarityValuesSource extends LuceneTestCase {
+  private static Directory dir;
+  private static Analyzer analyzer;
+  private static IndexReader reader;
+  private static IndexSearcher searcher;
+
+  @BeforeClass
+  public static void beforeClass() throws Exception {
+    dir = newDirectory();
+    analyzer = new MockAnalyzer(random());
+    IndexWriterConfig iwConfig = newIndexWriterConfig(analyzer);
+    iwConfig.setMergePolicy(newLogMergePolicy());
+    RandomIndexWriter iw = new RandomIndexWriter(random(), dir, iwConfig);
+
+    Document document = new Document();
+    document.add(new StringField("id", "1", Field.Store.NO));
+    document.add(new SortedDocValuesField("id", new BytesRef("1")));
+    document.add(new KnnFloatVectorField("knnFloatField1", new float[] {1.f, 2.f, 3.f}));
+    document.add(
+        new KnnFloatVectorField(
+            "knnFloatField2",
+            new float[] {2.2f, -3.2f, -3.1f},
+            VectorSimilarityFunction.DOT_PRODUCT));
+    document.add(
+        new KnnFloatVectorField(
+            "knnFloatField3", new float[] {4.5f, 10.3f, -7.f}, VectorSimilarityFunction.COSINE));
+    document.add(
+        new KnnFloatVectorField(
+            "knnFloatField4",
+            new float[] {-1.3f, 1.0f, 1.0f},
+            VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT));
+    document.add(new KnnFloatVectorField("knnFloatField5", new float[] {-6.7f, -1.0f, -0.9f}));
+    document.add(new KnnByteVectorField("knnByteField1", new byte[] {106, 80, 127}));
+    document.add(
+        new KnnByteVectorField(
+            "knnByteField2", new byte[] {4, 2, 3}, VectorSimilarityFunction.DOT_PRODUCT));
+    document.add(
+        new KnnByteVectorField(
+            "knnByteField3", new byte[] {-121, -64, -1}, VectorSimilarityFunction.COSINE));
+    document.add(
+        new KnnByteVectorField(
+            "knnByteField4",
+            new byte[] {-127, 127, 127},
+            VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT));
+    iw.addDocument(document);
+
+    Document document2 = new Document();
+    document2.add(new StringField("id", "2", Field.Store.NO));
+    document2.add(new SortedDocValuesField("id", new BytesRef("2")));
+    document2.add(new KnnFloatVectorField("knnFloatField1", new float[] {1.f, 2.f, 3.f}));
+    document2.add(
+        new KnnFloatVectorField(
+            "knnFloatField2",
+            new float[] {-5.2f, 8.7f, 3.1f},
+            VectorSimilarityFunction.DOT_PRODUCT));
+    document2.add(
+        new KnnFloatVectorField(
+            "knnFloatField3", new float[] {0.2f, -3.2f, 3.1f}, VectorSimilarityFunction.COSINE));
+    document2.add(new KnnFloatVectorField("knnFloatField5", new float[] {2.f, 13.2f, 9.1f}));
+    document2.add(new KnnByteVectorField("knnByteField1", new byte[] {1, -2, -30}));
+    document2.add(
+        new KnnByteVectorField(
+            "knnByteField2", new byte[] {40, 21, 3}, VectorSimilarityFunction.DOT_PRODUCT));
+    document2.add(
+        new KnnByteVectorField(
+            "knnByteField3", new byte[] {9, 2, 3}, VectorSimilarityFunction.COSINE));
+    document2.add(
+        new KnnByteVectorField(
+            "knnByteField4",
+            new byte[] {14, 29, 31},
+            VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT));
+    iw.addDocument(document2);
+
+    Document document3 = new Document();
+    document3.add(new StringField("id", "3", Field.Store.NO));
+    document3.add(new SortedDocValuesField("id", new BytesRef("3")));
+    document3.add(new KnnFloatVectorField("knnFloatField1", new float[] {1.f, 2.f, 3.f}));
+    document3.add(
+        new KnnFloatVectorField(
+            "knnFloatField2", new float[] {-8.f, 7.f, -6.f}, VectorSimilarityFunction.DOT_PRODUCT));
+    document3.add(new KnnFloatVectorField("knnFloatField5", new float[] {5.2f, 3.2f, 3.1f}));
+    document3.add(new KnnByteVectorField("knnByteField1", new byte[] {-128, 0, 127}));
+    document3.add(
+        new KnnByteVectorField(
+            "knnByteField2", new byte[] {-1, -2, -3}, VectorSimilarityFunction.DOT_PRODUCT));
+    document3.add(
+        new KnnByteVectorField(
+            "knnByteField3", new byte[] {4, 2, 3}, VectorSimilarityFunction.COSINE));
+    document3.add(
+        new KnnByteVectorField(
+            "knnByteField4",
+            new byte[] {-4, -2, -128},
+            VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT));
+    document3.add(new KnnByteVectorField("knnByteField5", new byte[] {-120, -2, 3}));
+    iw.addDocument(document3);
+
+    // Every test reads leaves().get(0) and addresses documents by the fixed doc IDs 0..2, so the
+    // randomized writer must not leave the three documents spread over several segments.
+    iw.forceMerge(1);
+
+    reader = iw.getReader();
+    searcher = newSearcher(reader);
+    iw.close();
+  }
+
+  @AfterClass
+  public static void afterClass() throws Exception {
+    searcher = null;
+    IOUtils.close(reader, dir, analyzer);
+  }
+
+  /** Returns the first leaf of the test index, which all tests score against. */
+  private static LeafReaderContext leaf() {
+    return searcher.reader.leaves().get(0);
+  }
+
+  /**
+   * Asserts that {@code dv} has a value for {@code doc} and that it equals {@code expected}
+   * exactly. The two checks are separate so a failure says which one broke.
+   */
+  private static void assertSimilarity(DoubleValues dv, int doc, float expected)
+      throws IOException {
+    assertTrue("expected a vector value for doc " + doc, dv.advanceExact(doc));
+    assertEquals("similarity score for doc " + doc, expected, dv.doubleValue(), 0.0);
+  }
+
+  public void testEuclideanSimilarityValuesSource() throws Exception {
+    float[] floatQueryVector = new float[] {9.f, 1.f, -10.f};
+
+    // Checks the computed similarity score between indexed vectors and query vector
+    // using DVS is correct by passing indexed and query vector in #compare
+    DoubleValues dv =
+        DoubleValuesSource.similarityToQueryVector(leaf(), floatQueryVector, "knnFloatField1");
+    assertSimilarity(
+        dv,
+        0,
+        VectorSimilarityFunction.EUCLIDEAN.compare(new float[] {1.f, 2.f, 3.f}, floatQueryVector));
+    assertSimilarity(
+        dv,
+        1,
+        VectorSimilarityFunction.EUCLIDEAN.compare(new float[] {1.f, 2.f, 3.f}, floatQueryVector));
+    assertSimilarity(
+        dv,
+        2,
+        VectorSimilarityFunction.EUCLIDEAN.compare(new float[] {1.f, 2.f, 3.f}, floatQueryVector));
+
+    dv = DoubleValuesSource.similarityToQueryVector(leaf(), floatQueryVector, "knnFloatField5");
+    assertSimilarity(
+        dv,
+        0,
+        VectorSimilarityFunction.EUCLIDEAN.compare(
+            new float[] {-6.7f, -1.0f, -0.9f}, floatQueryVector));
+    assertSimilarity(
+        dv,
+        1,
+        VectorSimilarityFunction.EUCLIDEAN.compare(
+            new float[] {2.f, 13.2f, 9.1f}, floatQueryVector));
+    assertSimilarity(
+        dv,
+        2,
+        VectorSimilarityFunction.EUCLIDEAN.compare(
+            new float[] {5.2f, 3.2f, 3.1f}, floatQueryVector));
+
+    byte[] byteQueryVector = new byte[] {-128, 2, 127};
+
+    dv = DoubleValuesSource.similarityToQueryVector(leaf(), byteQueryVector, "knnByteField1");
+    assertSimilarity(
+        dv,
+        0,
+        VectorSimilarityFunction.EUCLIDEAN.compare(new byte[] {106, 80, 127}, byteQueryVector));
+    assertSimilarity(
+        dv, 1, VectorSimilarityFunction.EUCLIDEAN.compare(new byte[] {1, -2, -30}, byteQueryVector));
+    assertSimilarity(
+        dv,
+        2,
+        VectorSimilarityFunction.EUCLIDEAN.compare(new byte[] {-128, 0, 127}, byteQueryVector));
+
+    // knnByteField5 is only present on the third document
+    dv = DoubleValuesSource.similarityToQueryVector(leaf(), byteQueryVector, "knnByteField5");
+    assertFalse(dv.advanceExact(0));
+    assertFalse(dv.advanceExact(1));
+    assertSimilarity(
+        dv, 2, VectorSimilarityFunction.EUCLIDEAN.compare(new byte[] {-120, -2, 3}, byteQueryVector));
+  }
+
+  public void testDotSimilarityValuesSource() throws Exception {
+    float[] floatQueryVector = new float[] {10.f, 1.f, -8.5f};
+
+    // Checks the computed similarity score between indexed vectors and query vector
+    // using DVS is correct by passing indexed and query vector in #compare
+    DoubleValues dv =
+        DoubleValuesSource.similarityToQueryVector(leaf(), floatQueryVector, "knnFloatField2");
+    assertSimilarity(
+        dv,
+        0,
+        VectorSimilarityFunction.DOT_PRODUCT.compare(
+            new float[] {2.2f, -3.2f, -3.1f}, floatQueryVector));
+    assertSimilarity(
+        dv,
+        1,
+        VectorSimilarityFunction.DOT_PRODUCT.compare(
+            new float[] {-5.2f, 8.7f, 3.1f}, floatQueryVector));
+    assertSimilarity(
+        dv,
+        2,
+        VectorSimilarityFunction.DOT_PRODUCT.compare(
+            new float[] {-8.f, 7.f, -6.f}, floatQueryVector));
+
+    byte[] byteQueryVector = new byte[] {-128, 2, 127};
+
+    dv = DoubleValuesSource.similarityToQueryVector(leaf(), byteQueryVector, "knnByteField2");
+    assertSimilarity(
+        dv, 0, VectorSimilarityFunction.DOT_PRODUCT.compare(new byte[] {4, 2, 3}, byteQueryVector));
+    assertSimilarity(
+        dv, 1, VectorSimilarityFunction.DOT_PRODUCT.compare(new byte[] {40, 21, 3}, byteQueryVector));
+    assertSimilarity(
+        dv, 2, VectorSimilarityFunction.DOT_PRODUCT.compare(new byte[] {-1, -2, -3}, byteQueryVector));
+  }
+
+  public void testCosineSimilarityValuesSource() throws Exception {
+    float[] floatQueryVector = new float[] {0.6f, -1.6f, 38.0f};
+
+    // Checks the computed similarity score between indexed vectors and query vector
+    // using DVS is correct by passing indexed and query vector in #compare
+    DoubleValues dv =
+        DoubleValuesSource.similarityToQueryVector(leaf(), floatQueryVector, "knnFloatField3");
+    assertSimilarity(
+        dv,
+        0,
+        VectorSimilarityFunction.COSINE.compare(new float[] {4.5f, 10.3f, -7.f}, floatQueryVector));
+    assertSimilarity(
+        dv,
+        1,
+        VectorSimilarityFunction.COSINE.compare(new float[] {0.2f, -3.2f, 3.1f}, floatQueryVector));
+    // knnFloatField3 is not present on the third document
+    assertFalse(dv.advanceExact(2));
+
+    byte[] byteQueryVector = new byte[] {-10, 8, 0};
+
+    dv = DoubleValuesSource.similarityToQueryVector(leaf(), byteQueryVector, "knnByteField3");
+    assertSimilarity(
+        dv, 0, VectorSimilarityFunction.COSINE.compare(new byte[] {-121, -64, -1}, byteQueryVector));
+    assertSimilarity(
+        dv, 1, VectorSimilarityFunction.COSINE.compare(new byte[] {9, 2, 3}, byteQueryVector));
+    assertSimilarity(
+        dv, 2, VectorSimilarityFunction.COSINE.compare(new byte[] {4, 2, 3}, byteQueryVector));
+  }
+
+  public void testMaximumProductSimilarityValuesSource() throws Exception {
+    float[] floatQueryVector = new float[] {1.f, -6.f, -10.f};
+
+    // Checks the computed similarity score between indexed vectors and query vector
+    // using DVS is correct by passing indexed and query vector in #compare
+    DoubleValues dv =
+        DoubleValuesSource.similarityToQueryVector(leaf(), floatQueryVector, "knnFloatField4");
+    assertSimilarity(
+        dv,
+        0,
+        VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT.compare(
+            new float[] {-1.3f, 1.0f, 1.0f}, floatQueryVector));
+    // knnFloatField4 is only present on the first document
+    assertFalse(dv.advanceExact(1));
+    assertFalse(dv.advanceExact(2));
+
+    byte[] byteQueryVector = new byte[] {-127, 127, 127};
+
+    dv = DoubleValuesSource.similarityToQueryVector(leaf(), byteQueryVector, "knnByteField4");
+    assertSimilarity(
+        dv,
+        0,
+        VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT.compare(
+            new byte[] {-127, 127, 127}, byteQueryVector));
+    assertSimilarity(
+        dv,
+        1,
+        VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT.compare(
+            new byte[] {14, 29, 31}, byteQueryVector));
+    assertSimilarity(
+        dv,
+        2,
+        VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT.compare(
+            new byte[] {-4, -2, -128}, byteQueryVector));
+  }
+
+  public void testFailuresWithSimilarityValuesSource() throws Exception {
+    float[] floatQueryVector = new float[] {1.1f, 2.2f, 3.3f};
+    byte[] byteQueryVector = new byte[] {-10, 20, 30};
+
+    // A query vector whose type does not match the field's encoding must be rejected
+    expectThrows(
+        IllegalArgumentException.class,
+        () -> DoubleValuesSource.similarityToQueryVector(leaf(), floatQueryVector, "knnByteField1"));
+    expectThrows(
+        IllegalArgumentException.class,
+        () -> DoubleValuesSource.similarityToQueryVector(leaf(), byteQueryVector, "knnFloatField1"));
+
+    DoubleValues dv =
+        DoubleValuesSource.similarityToQueryVector(leaf(), floatQueryVector, "knnFloatField1");
+    assertTrue(dv.advanceExact(0));
+    final double score = dv.doubleValue();
+    assertEquals(
+        VectorSimilarityFunction.EUCLIDEAN.compare(new float[] {1.f, 2.f, 3.f}, floatQueryVector),
+        score,
+        0.0);
+    // The delta overloads are required here: two-argument assertNotEquals(double, float) would
+    // compare a boxed Double with a boxed Float, which are never equal, so it could never fail.
+    assertNotEquals(
+        VectorSimilarityFunction.DOT_PRODUCT.compare(new float[] {1.f, 2.f, 3.f}, floatQueryVector),
+        score,
+        0.0);
+    assertNotEquals(
+        VectorSimilarityFunction.COSINE.compare(new float[] {1.f, 2.f, 3.f}, floatQueryVector),
+        score,
+        0.0);
+    assertNotEquals(
+        VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT.compare(
+            new float[] {1.f, 2.f, 3.f}, floatQueryVector),
+        score,
+        0.0);
+  }
+}