From 7431f8666c776d507263e2cb0cc27429c5a4724d Mon Sep 17 00:00:00 2001 From: Elia Date: Fri, 28 Apr 2023 12:48:37 +0200 Subject: [PATCH 01/35] Implementation of function values for dense vector --- .../queries/function/FunctionValues.java | 8 + .../ByteDenseVectorSimilarityFunction.java | 35 ++++ .../DenseVectorByteConstValueSource.java | 72 ++++++++ .../DenseVectorByteFieldSource.java | 82 +++++++++ .../valuesource/DenseVectorFieldFunction.java | 54 ++++++ .../DenseVectorFloatConstValueSource.java | 72 ++++++++ .../DenseVectorFloatFieldSource.java | 82 +++++++++ .../DenseVectorSimilarityFunction.java | 92 ++++++++++ .../FloatDenseVectorSimilarityFunction.java | 35 ++++ .../TestKnnVectorSimilarityFunctions.java | 165 ++++++++++++++++++ 10 files changed, 697 insertions(+) create mode 100644 lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/ByteDenseVectorSimilarityFunction.java create mode 100644 lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/DenseVectorByteConstValueSource.java create mode 100644 lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/DenseVectorByteFieldSource.java create mode 100644 lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/DenseVectorFieldFunction.java create mode 100644 lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/DenseVectorFloatConstValueSource.java create mode 100644 lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/DenseVectorFloatFieldSource.java create mode 100644 lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/DenseVectorSimilarityFunction.java create mode 100644 lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/FloatDenseVectorSimilarityFunction.java create mode 100644 lucene/queries/src/test/org/apache/lucene/queries/function/TestKnnVectorSimilarityFunctions.java diff --git a/lucene/queries/src/java/org/apache/lucene/queries/function/FunctionValues.java b/lucene/queries/src/java/org/apache/lucene/queries/function/FunctionValues.java index 0d0a99919006..6b2a8a1f546c 100644 --- a/lucene/queries/src/java/org/apache/lucene/queries/function/FunctionValues.java +++ b/lucene/queries/src/java/org/apache/lucene/queries/function/FunctionValues.java @@ -70,6 +70,14 @@ public boolean boolVal(int doc) throws IOException { return intVal(doc) != 0; } + public float[] floatVectorVal(int doc) throws IOException { + throw new UnsupportedOperationException(); + } + + public byte[] byteVectorVal(int doc) throws IOException { + throw new UnsupportedOperationException(); + } + /** * returns the bytes representation of the string val - TODO: should this return the indexed raw * bytes not? diff --git a/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/ByteDenseVectorSimilarityFunction.java b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/ByteDenseVectorSimilarityFunction.java new file mode 100644 index 000000000000..91adcdb0b7da --- /dev/null +++ b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/ByteDenseVectorSimilarityFunction.java @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.queries.function.valuesource; + +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.queries.function.FunctionValues; +import org.apache.lucene.queries.function.ValueSource; + +import java.io.IOException; + +public class ByteDenseVectorSimilarityFunction extends DenseVectorSimilarityFunction{ + public ByteDenseVectorSimilarityFunction(VectorSimilarityFunction similarityFunction, ValueSource vector1, ValueSource vector2) { + super(similarityFunction, vector1, vector2); + } + + @Override + protected float func(int doc, FunctionValues f1, FunctionValues f2) throws IOException { + checkSize(f1.byteVectorVal(doc).length, f2.byteVectorVal(doc).length); + return similarityFunction.compare(f1.byteVectorVal(doc), f2.byteVectorVal(doc)); + } +} diff --git a/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/DenseVectorByteConstValueSource.java b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/DenseVectorByteConstValueSource.java new file mode 100644 index 000000000000..6c26da55f1d7 --- /dev/null +++ b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/DenseVectorByteConstValueSource.java @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.queries.function.valuesource; + +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.queries.function.FunctionValues; +import org.apache.lucene.queries.function.ValueSource; + +import java.io.IOException; +import java.util.Arrays; +import java.util.List; +import java.util.Map; + +public class DenseVectorByteConstValueSource extends ValueSource { + byte[] vector; + public DenseVectorByteConstValueSource(List constVector) { + this.vector = new byte[constVector.size()]; + for (int i = 0; i < constVector.size(); i++) { + vector[i] = constVector.get(i).byteValue(); + } + } + + @Override + public FunctionValues getValues(Map context, LeafReaderContext readerContext) throws IOException { + return new FunctionValues() { + @Override + public byte[] byteVectorVal(int doc) { + return vector; + } + @Override + public String strVal(int doc) { + return Arrays.toString(vector); + } + + @Override + public String toString(int doc) throws IOException { + return description() + '=' + strVal(doc); + } + }; + } + + @Override + public boolean equals(Object o) { + if (!(o instanceof DenseVectorByteConstValueSource)) return false; + DenseVectorByteConstValueSource other = (DenseVectorByteConstValueSource) o; + return Arrays.equals(vector, other.vector); + } + + @Override + public int hashCode() { + return getClass().hashCode() * 31 + Arrays.hashCode(vector); + } + + @Override + public String description() { + return "denseVectorConst(" + Arrays.toString(vector) + ')'; + } +} diff --git a/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/DenseVectorByteFieldSource.java b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/DenseVectorByteFieldSource.java new file mode 100644 index 000000000000..4349f71e5b7b --- /dev/null +++ b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/DenseVectorByteFieldSource.java @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.queries.function.valuesource; + +import org.apache.lucene.index.ByteVectorValues; +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.queries.function.FunctionValues; +import org.apache.lucene.queries.function.ValueSource; +import org.apache.lucene.search.DocIdSetIterator; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Map; + +public class DenseVectorByteFieldSource extends ValueSource { + private final String fieldName; + public DenseVectorByteFieldSource(String fieldName) { + this.fieldName = fieldName; + } + + @Override + public FunctionValues getValues(Map context, LeafReaderContext readerContext) throws IOException { + + final ByteVectorValues vectorValues = readerContext.reader().getByteVectorValues(fieldName); + return new DenseVectorFieldFunction(this){ + byte[] defaultVector = null; + + @Override + public byte[] byteVectorVal(int doc) throws IOException { + if (exists(doc)){ + return vectorValues.vectorValue(); + } else { + return defaultVector(); + } + } + + @Override + protected DocIdSetIterator getVectorIterator() { + return vectorValues; + } + + private byte[] defaultVector(){ + if (defaultVector == null){ + defaultVector = new byte[vectorValues.dimension()]; + Arrays.fill(defaultVector, (byte) 0); + } + return defaultVector; + } + }; + } + + @Override + public boolean equals(Object o) { + if (o.getClass() != DenseVectorByteFieldSource.class) return false; + DenseVectorByteFieldSource other = (DenseVectorByteFieldSource) o; + return fieldName.equals(other.fieldName); + } + + @Override + public int hashCode() { + return getClass().hashCode() * 31 + fieldName.getClass().hashCode(); + } + + @Override + public String description() { + return "denseByteVectorField(" + fieldName + ")"; + } +} diff --git a/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/DenseVectorFieldFunction.java b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/DenseVectorFieldFunction.java new file mode 100644 index 000000000000..2d056d1d011e --- /dev/null +++ b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/DenseVectorFieldFunction.java @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.queries.function.valuesource; + +import org.apache.lucene.queries.function.FunctionValues; +import org.apache.lucene.queries.function.ValueSource; +import org.apache.lucene.search.DocIdSetIterator; + +import java.io.IOException; + +public abstract class DenseVectorFieldFunction extends FunctionValues { + + protected final ValueSource vs; + int lastDocID; + + protected DenseVectorFieldFunction(ValueSource vs) { + this.vs = vs; + } + + abstract protected DocIdSetIterator getVectorIterator(); + + @Override + public String toString(int doc) throws IOException { + return vs.description() + strVal(doc); + } + + @Override + public boolean exists(int doc) throws IOException { + if (doc < lastDocID) { + throw new IllegalArgumentException( + "docs were sent out-of-order: lastDocID=" + lastDocID + " vs docID=" + doc); + } + lastDocID = doc; + int curDocID = getVectorIterator().docID(); + if (doc > curDocID) { + curDocID = getVectorIterator().advance(doc); + } + return doc == curDocID; + } +} diff --git a/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/DenseVectorFloatConstValueSource.java b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/DenseVectorFloatConstValueSource.java new file mode 100644 index 000000000000..a04498ec8467 --- /dev/null +++ b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/DenseVectorFloatConstValueSource.java @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.queries.function.valuesource; + +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.queries.function.FunctionValues; +import org.apache.lucene.queries.function.ValueSource; + +import java.io.IOException; +import java.util.Arrays; +import java.util.List; +import java.util.Map; + +public class DenseVectorFloatConstValueSource extends ValueSource { + float[] vector; + public DenseVectorFloatConstValueSource(List constVector) { + this.vector = new float[constVector.size()]; + for (int i = 0; i < constVector.size(); i++) { + vector[i] = constVector.get(i).floatValue(); + } + } + + @Override + public FunctionValues getValues(Map context, LeafReaderContext readerContext) throws IOException { + return new FunctionValues() { + @Override + public float[] floatVectorVal(int doc) { + return vector; + } + @Override + public String strVal(int doc) { + return Arrays.toString(vector); + } + + @Override + public String toString(int doc) throws IOException { + return description() + '=' + strVal(doc); + } + }; + } + + @Override + public boolean equals(Object o) { + if (!(o instanceof DenseVectorFloatConstValueSource)) return false; + DenseVectorFloatConstValueSource other = (DenseVectorFloatConstValueSource) o; + return Arrays.equals(vector, other.vector); + } + + @Override + public int hashCode() { + return getClass().hashCode() * 31 + Arrays.hashCode(vector); + } + + @Override + public String description() { + return "denseVectorConst(" + Arrays.toString(vector) + ')'; + } +} diff --git a/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/DenseVectorFloatFieldSource.java b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/DenseVectorFloatFieldSource.java new file mode 100644 index 000000000000..2e5016a42159 --- /dev/null +++ b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/DenseVectorFloatFieldSource.java @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.queries.function.valuesource; + +import org.apache.lucene.index.FloatVectorValues; +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.queries.function.FunctionValues; +import org.apache.lucene.queries.function.ValueSource; +import org.apache.lucene.search.DocIdSetIterator; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Map; + +public class DenseVectorFloatFieldSource extends ValueSource { + private final String fieldName; + public DenseVectorFloatFieldSource(String fieldName) { + this.fieldName = fieldName; + } + + @Override + public FunctionValues getValues(Map context, LeafReaderContext readerContext) throws IOException { + + final FloatVectorValues vectorValues = readerContext.reader().getFloatVectorValues(fieldName); + return new DenseVectorFieldFunction(this){ + float[] defaultVector = null; + + @Override + public float[] floatVectorVal(int doc) throws IOException { + if (exists(doc)){ + return vectorValues.vectorValue(); + } else { + return defaultVector(); + } + } + + @Override + protected DocIdSetIterator getVectorIterator() { + return vectorValues; + } + + private float[] defaultVector(){ + if (defaultVector == null){ + defaultVector = new float[vectorValues.dimension()]; + Arrays.fill(defaultVector, 0.f); + } + return defaultVector; + } + }; + } + + @Override + public boolean equals(Object o) { + if (o.getClass() != DenseVectorFloatFieldSource.class) return false; + DenseVectorFloatFieldSource other = (DenseVectorFloatFieldSource) o; + return fieldName.equals(other.fieldName); + } + + @Override + public int hashCode() { + return getClass().hashCode() * 31 + fieldName.getClass().hashCode(); + } + + @Override + public String description() { + return "denseFloatVectorField(" + fieldName + ")"; + } +} diff --git a/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/DenseVectorSimilarityFunction.java b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/DenseVectorSimilarityFunction.java new file mode 100644 index 000000000000..8334882875f0 --- /dev/null +++ b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/DenseVectorSimilarityFunction.java @@ -0,0 +1,92 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.queries.function.valuesource; + +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.queries.function.FunctionValues; +import org.apache.lucene.queries.function.ValueSource; + +import java.io.IOException; +import java.util.Map; + +public abstract class DenseVectorSimilarityFunction extends ValueSource { + + protected final VectorSimilarityFunction similarityFunction; + protected final ValueSource vector1; + protected final ValueSource vector2; + + public DenseVectorSimilarityFunction(VectorSimilarityFunction similarityFunction, ValueSource vector1, ValueSource vector2) { + + this.similarityFunction = similarityFunction; + this.vector1 = vector1; + this.vector2 = vector2; + } + + @Override + public FunctionValues getValues(Map context, LeafReaderContext readerContext) + throws IOException { + + final FunctionValues vector1Vals = vector1.getValues(context, readerContext); + final FunctionValues vector2Vals = vector2.getValues(context, readerContext); + return new FunctionValues() { + @Override + public float floatVal(int doc) throws IOException { + return func(doc, vector1Vals, vector2Vals); + } + + @Override + public boolean exists(int doc) throws IOException { + return MultiFunction.allExists(doc, vector1Vals, vector2Vals); + } + + @Override + public String toString(int doc) throws IOException { + return null; + } + }; + } + + protected void checkSize(int sizeVector1, int sizeVector2) throws IOException { + if (sizeVector1 != sizeVector2){ + throw new UnsupportedOperationException("Vectors must have the same size"); + } + } + + protected abstract float func(int doc, FunctionValues f1, FunctionValues f2) throws IOException; + + @Override + public boolean equals(Object o) { + return o instanceof DenseVectorSimilarityFunction && + similarityFunction.equals(((DenseVectorSimilarityFunction) o).similarityFunction) && + vector1.equals(((DenseVectorSimilarityFunction) o).vector1) && + vector2.equals(((DenseVectorSimilarityFunction) o).vector2); + } + + @Override + public int hashCode() { + int h = similarityFunction.hashCode(); + h = 31 * h + vector1.hashCode(); + h = 31 * h + vector2.hashCode(); + return h; + } + + @Override + public String description() { + return similarityFunction.name() + "(" + vector1.description() + ", " + vector2.description() + ")"; + } +} \ No newline at end of file diff --git a/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/FloatDenseVectorSimilarityFunction.java b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/FloatDenseVectorSimilarityFunction.java new file mode 100644 index 000000000000..61ae487ba789 --- /dev/null +++ b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/FloatDenseVectorSimilarityFunction.java @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.queries.function.valuesource; + +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.queries.function.FunctionValues; +import org.apache.lucene.queries.function.ValueSource; + +import java.io.IOException; + +public class FloatDenseVectorSimilarityFunction extends DenseVectorSimilarityFunction { + public FloatDenseVectorSimilarityFunction(VectorSimilarityFunction similarityFunction, ValueSource vector1, ValueSource vector2) { + super(similarityFunction, vector1, vector2); + } + + @Override + protected float func(int doc, FunctionValues f1, FunctionValues f2) throws IOException { + checkSize(f1.floatVectorVal(doc).length, f2.floatVectorVal(doc).length); + return similarityFunction.compare(f1.floatVectorVal(doc), f2.floatVectorVal(doc)); + } +} diff --git a/lucene/queries/src/test/org/apache/lucene/queries/function/TestKnnVectorSimilarityFunctions.java b/lucene/queries/src/test/org/apache/lucene/queries/function/TestKnnVectorSimilarityFunctions.java new file mode 100644 index 000000000000..a23801278d30 --- /dev/null +++ b/lucene/queries/src/test/org/apache/lucene/queries/function/TestKnnVectorSimilarityFunctions.java @@ -0,0 +1,165 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.queries.function; + +import org.apache.lucene.analysis.Analyzer; +import org.apache.lucene.document.Document; +import org.apache.lucene.document.Field; +import org.apache.lucene.document.KnnByteVectorField; +import org.apache.lucene.document.KnnFloatVectorField; +import org.apache.lucene.document.SortedDocValuesField; +import org.apache.lucene.document.StringField; +import org.apache.lucene.index.IndexReader; +import org.apache.lucene.index.IndexWriterConfig; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.queries.function.valuesource.ByteDenseVectorSimilarityFunction; +import org.apache.lucene.queries.function.valuesource.DenseVectorByteConstValueSource; +import org.apache.lucene.queries.function.valuesource.DenseVectorByteFieldSource; +import org.apache.lucene.queries.function.valuesource.DenseVectorFloatConstValueSource; +import org.apache.lucene.queries.function.valuesource.DenseVectorFloatFieldSource; +import org.apache.lucene.queries.function.valuesource.FloatDenseVectorSimilarityFunction; +import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.Query; +import org.apache.lucene.search.ScoreDoc; +import org.apache.lucene.search.Sort; +import org.apache.lucene.search.SortField; +import org.apache.lucene.search.TopDocs; +import org.apache.lucene.store.Directory; +import org.apache.lucene.tests.analysis.MockAnalyzer; +import org.apache.lucene.tests.index.RandomIndexWriter; +import org.apache.lucene.tests.search.CheckHits; +import org.apache.lucene.tests.util.LuceneTestCase; +import org.apache.lucene.util.BytesRef; +import org.junit.AfterClass; +import org.junit.BeforeClass; +import java.util.List; + +public class TestKnnVectorSimilarityFunctions extends LuceneTestCase { + static Directory dir; + static Analyzer analyzer; + static IndexReader reader; + static IndexSearcher searcher; + static final List documents = List.of("1", "2"); + + @BeforeClass + public static void beforeClass() throws Exception { + dir = newDirectory(); + analyzer = new MockAnalyzer(random()); + IndexWriterConfig iwConfig = newIndexWriterConfig(analyzer); + iwConfig.setMergePolicy(newLogMergePolicy()); + RandomIndexWriter iw = new RandomIndexWriter(random(), dir, iwConfig); + for (String docId : documents) { + Document document = new Document(); + document.add(new StringField("id", docId, Field.Store.NO)); + document.add(new SortedDocValuesField("id", new BytesRef(docId))); + document.add(new KnnFloatVectorField("knnFloatField1", new float[] {1.f, 2.f, 3.f})); + document.add(new KnnFloatVectorField("knnFloatField2", new float[]{5.2f, 3.2f, 3.1f})); + document.add(new KnnByteVectorField("knnByteField1", new byte[]{1, 2, 3})); + document.add(new KnnByteVectorField("knnByteField2", new byte[]{4, 2, 3})); + iw.addDocument(document); + } + + reader = iw.getReader(); + searcher = newSearcher(reader); + iw.close(); + } + + @AfterClass + public static void afterClass() throws Exception { + searcher = null; + reader.close(); + reader = null; + dir.close(); + dir = null; + analyzer.close(); + analyzer = null; + } + + public void testFloatVectorSimilarityFunctionConst() throws Exception { + ValueSource v1 = new DenseVectorFloatConstValueSource(List.of(1,2,3)); + ValueSource v2 = new DenseVectorFloatConstValueSource(List.of(5,4,1)); + assertHits(new FunctionQuery(new FloatDenseVectorSimilarityFunction(VectorSimilarityFunction.EUCLIDEAN, v1, v2)), new float[] {0.04f, 0.04f}); + } + + public void testByteVectorSimilarityFunctionConst() throws Exception { + ValueSource v1 = new DenseVectorByteConstValueSource(List.of(1,2,3)); + ValueSource v2 = new DenseVectorByteConstValueSource(List.of(2,5,6)); + assertHits(new FunctionQuery(new ByteDenseVectorSimilarityFunction(VectorSimilarityFunction.EUCLIDEAN, v1, v2)), new float[] {0.05f, 0.05f}); + } + + public void testFloatVectorSimilarityFunctionField() throws Exception { + ValueSource v1 = new DenseVectorFloatFieldSource("knnFloatField1"); + ValueSource v2 = new DenseVectorFloatFieldSource("knnFloatField2"); + assertHits(new FunctionQuery(new FloatDenseVectorSimilarityFunction(VectorSimilarityFunction.EUCLIDEAN, v1, v2)), new float[] {0.049776014f, 0.049776014f}); + } + + public void testByteVectorSimilarityFunctionField() throws Exception { + ValueSource v1 = new DenseVectorByteFieldSource("knnByteField1"); + ValueSource v2 = new DenseVectorByteFieldSource("knnByteField2"); + assertHits(new FunctionQuery(new ByteDenseVectorSimilarityFunction(VectorSimilarityFunction.EUCLIDEAN, v1, v2)), new float[] {0.1f, 0.1f}); + } + + public void testFloatVectorSimilarityFunctionMixed() throws Exception { + ValueSource v1 = new DenseVectorFloatConstValueSource(List.of(1,2,4)); + ValueSource v2 = new DenseVectorFloatFieldSource("knnFloatField1"); + assertHits(new FunctionQuery(new FloatDenseVectorSimilarityFunction(VectorSimilarityFunction.EUCLIDEAN, v1, v2)), new float[] {0.5f, 0.5f}); + } + + public void testByteVectorSimilarityFunctionMixed() throws Exception { + ValueSource v1 = new DenseVectorByteConstValueSource(List.of(1,2,4)); + ValueSource v2 = new DenseVectorByteFieldSource("knnByteField1"); + assertHits(new FunctionQuery(new ByteDenseVectorSimilarityFunction(VectorSimilarityFunction.EUCLIDEAN, v1, v2)), new float[] {0.5f, 0.5f}); + } + + public void testDismatchDimension() { + ValueSource v1 = new DenseVectorByteConstValueSource(List.of(1,2,3,4)); + ValueSource v2 = new DenseVectorByteFieldSource("knnByteField1"); + ByteDenseVectorSimilarityFunction byteDenseVectorSimilarityFunction = new ByteDenseVectorSimilarityFunction(VectorSimilarityFunction.EUCLIDEAN, v1, v2); + assertThrows(UnsupportedOperationException.class, () -> searcher.search(new FunctionQuery(byteDenseVectorSimilarityFunction), 10)); + + v1 = new DenseVectorFloatConstValueSource(List.of(1,2)); + v2 = new DenseVectorFloatFieldSource("knnFloatField1"); + FloatDenseVectorSimilarityFunction floatDenseVectorSimilarityFunction = new FloatDenseVectorSimilarityFunction(VectorSimilarityFunction.EUCLIDEAN, v1, v2); + assertThrows(UnsupportedOperationException.class, () -> searcher.search(new FunctionQuery(floatDenseVectorSimilarityFunction), 10)); + } + + public void testMismatchType() { + ValueSource v1 = new DenseVectorByteConstValueSource(List.of(1,2,3)); + ValueSource v2 = new DenseVectorByteFieldSource("knnByteField1"); + FloatDenseVectorSimilarityFunction floatDenseVectorSimilarityFunction = new FloatDenseVectorSimilarityFunction(VectorSimilarityFunction.EUCLIDEAN, v1, v2); + assertThrows(UnsupportedOperationException.class, () -> searcher.search(new FunctionQuery(floatDenseVectorSimilarityFunction), 10)); + + v1 = new DenseVectorByteConstValueSource(List.of(1,2,3)); + v2 = new DenseVectorFloatFieldSource("knnByteField1"); + ByteDenseVectorSimilarityFunction byteDenseVectorSimilarityFunction = new ByteDenseVectorSimilarityFunction(VectorSimilarityFunction.EUCLIDEAN, v1, v2); + assertThrows(UnsupportedOperationException.class, () -> searcher.search(new FunctionQuery(byteDenseVectorSimilarityFunction), 10)); + } + + + public static void assertHits(Query q, float[] scores) throws Exception { + ScoreDoc[] expected = new ScoreDoc[scores.length]; + int[] expectedDocs = new int[scores.length]; + for (int i = 0; i < expected.length; i++) { + expectedDocs[i] = i; + expected[i] = new ScoreDoc(i, scores[i]); + } + TopDocs docs = + searcher.search( + q, documents.size(), new Sort(new SortField("id", SortField.Type.STRING)), true); + CheckHits.checkHitsQuery(q, expected, docs.scoreDocs, expectedDocs); + } +} From ffe7bab1d6f65117023a4a00c22f648e8af0d8e9 Mon Sep 17 00:00:00 2001 From: Elia Date: Fri, 28 Apr 2023 13:06:52 +0200 Subject: [PATCH 02/35] tidy --- .../ByteDenseVectorSimilarityFunction.java | 22 +- .../DenseVectorByteConstValueSource.java | 84 +++--- .../DenseVectorByteFieldSource.java | 97 +++---- .../valuesource/DenseVectorFieldFunction.java | 47 ++-- .../DenseVectorFloatConstValueSource.java | 84 +++--- .../DenseVectorFloatFieldSource.java | 97 +++---- .../DenseVectorSimilarityFunction.java | 115 ++++---- .../FloatDenseVectorSimilarityFunction.java | 20 +- .../TestKnnVectorSimilarityFunctions.java | 251 ++++++++++-------- 9 files changed, 428 insertions(+), 389 deletions(-) diff --git a/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/ByteDenseVectorSimilarityFunction.java b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/ByteDenseVectorSimilarityFunction.java index 91adcdb0b7da..a58b17b6e59a 100644 --- a/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/ByteDenseVectorSimilarityFunction.java +++ b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/ByteDenseVectorSimilarityFunction.java @@ -16,20 +16,20 @@ */ package org.apache.lucene.queries.function.valuesource; +import java.io.IOException; import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.queries.function.FunctionValues; import org.apache.lucene.queries.function.ValueSource; -import java.io.IOException; - -public class ByteDenseVectorSimilarityFunction extends DenseVectorSimilarityFunction{ - public ByteDenseVectorSimilarityFunction(VectorSimilarityFunction similarityFunction, ValueSource vector1, ValueSource vector2) { - super(similarityFunction, vector1, vector2); - } +public class ByteDenseVectorSimilarityFunction extends DenseVectorSimilarityFunction { + public ByteDenseVectorSimilarityFunction( + VectorSimilarityFunction similarityFunction, ValueSource vector1, ValueSource vector2) { + super(similarityFunction, vector1, vector2); + } - @Override - protected float func(int doc, FunctionValues f1, FunctionValues f2) throws IOException { - checkSize(f1.byteVectorVal(doc).length, f2.byteVectorVal(doc).length); - return similarityFunction.compare(f1.byteVectorVal(doc), f2.byteVectorVal(doc)); - } + @Override + protected float func(int doc, FunctionValues f1, FunctionValues f2) throws IOException { + checkSize(f1.byteVectorVal(doc).length, f2.byteVectorVal(doc).length); + return similarityFunction.compare(f1.byteVectorVal(doc), f2.byteVectorVal(doc)); + } } diff --git a/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/DenseVectorByteConstValueSource.java b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/DenseVectorByteConstValueSource.java index 6c26da55f1d7..057813882aa1 100644 --- a/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/DenseVectorByteConstValueSource.java +++ b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/DenseVectorByteConstValueSource.java @@ -16,57 +16,59 @@ */ package org.apache.lucene.queries.function.valuesource; -import org.apache.lucene.index.LeafReaderContext; -import org.apache.lucene.queries.function.FunctionValues; -import org.apache.lucene.queries.function.ValueSource; - import java.io.IOException; import java.util.Arrays; import java.util.List; import java.util.Map; +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.queries.function.FunctionValues; +import org.apache.lucene.queries.function.ValueSource; public class DenseVectorByteConstValueSource extends ValueSource { - byte[] vector; - public DenseVectorByteConstValueSource(List constVector) { - this.vector = new byte[constVector.size()]; - for (int i = 0; i < constVector.size(); i++) { - vector[i] = constVector.get(i).byteValue(); - } + byte[] vector; + + public DenseVectorByteConstValueSource(List constVector) { + this.vector = new byte[constVector.size()]; + for (int i = 0; i < constVector.size(); i++) { + vector[i] = constVector.get(i).byteValue(); } + } - @Override - public FunctionValues getValues(Map context, LeafReaderContext readerContext) throws IOException { - return new FunctionValues() { - @Override - public byte[] byteVectorVal(int doc) { - return vector; - } - @Override - public String strVal(int doc) { - return Arrays.toString(vector); - } + @Override + public FunctionValues getValues(Map context, LeafReaderContext readerContext) + throws IOException { + return new FunctionValues() { + @Override + public byte[] byteVectorVal(int doc) { + return vector; + } - @Override - public String toString(int doc) throws IOException { - return description() + '=' + strVal(doc); - } - }; - } + @Override + public String strVal(int doc) { + return Arrays.toString(vector); + } - @Override - public boolean equals(Object o) { - if (!(o instanceof DenseVectorByteConstValueSource)) return false; - DenseVectorByteConstValueSource other = (DenseVectorByteConstValueSource) o; - return Arrays.equals(vector, other.vector); - } + @Override + public String toString(int doc) throws IOException { + return description() + '=' + strVal(doc); + } + }; + } - @Override - public int hashCode() { - return getClass().hashCode() * 31 + Arrays.hashCode(vector); - } + @Override + public boolean equals(Object o) { + if (!(o instanceof DenseVectorByteConstValueSource)) return false; + DenseVectorByteConstValueSource other = (DenseVectorByteConstValueSource) o; + return Arrays.equals(vector, other.vector); + } - @Override - public String description() { - return "denseVectorConst(" + Arrays.toString(vector) + ')'; - } + @Override + public int hashCode() { + return getClass().hashCode() * 31 + Arrays.hashCode(vector); + } + + @Override + public String description() { + return "denseVectorConst(" + Arrays.toString(vector) + ')'; + } } diff --git a/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/DenseVectorByteFieldSource.java b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/DenseVectorByteFieldSource.java index 4349f71e5b7b..4367261a5a9e 100644 --- a/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/DenseVectorByteFieldSource.java +++ b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/DenseVectorByteFieldSource.java @@ -16,67 +16,68 @@ */ package org.apache.lucene.queries.function.valuesource; +import java.io.IOException; +import java.util.Arrays; +import java.util.Map; import org.apache.lucene.index.ByteVectorValues; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.queries.function.FunctionValues; import org.apache.lucene.queries.function.ValueSource; import org.apache.lucene.search.DocIdSetIterator; -import java.io.IOException; -import java.util.Arrays; -import java.util.Map; - public class DenseVectorByteFieldSource extends ValueSource { - private final String fieldName; - public DenseVectorByteFieldSource(String fieldName) { - this.fieldName = fieldName; - } + private final String fieldName; + + public DenseVectorByteFieldSource(String fieldName) { + this.fieldName = fieldName; + } - @Override - public FunctionValues getValues(Map context, LeafReaderContext readerContext) throws IOException { + @Override + public FunctionValues getValues(Map context, LeafReaderContext readerContext) + throws IOException { - final ByteVectorValues vectorValues = readerContext.reader().getByteVectorValues(fieldName); - return new DenseVectorFieldFunction(this){ - byte[] defaultVector = null; + final ByteVectorValues vectorValues = readerContext.reader().getByteVectorValues(fieldName); + return new DenseVectorFieldFunction(this) { + byte[] defaultVector = null; - @Override - public byte[] byteVectorVal(int doc) throws IOException { - if (exists(doc)){ - return vectorValues.vectorValue(); - } else { - return defaultVector(); - } - } + @Override + public byte[] byteVectorVal(int doc) throws IOException { + if (exists(doc)) { + return vectorValues.vectorValue(); + } else { + return defaultVector(); + } + } - @Override - protected DocIdSetIterator getVectorIterator() { - return vectorValues; - } + @Override + protected DocIdSetIterator getVectorIterator() { + return vectorValues; + } - private byte[] defaultVector(){ - if (defaultVector == null){ - defaultVector = new byte[vectorValues.dimension()]; - Arrays.fill(defaultVector, (byte) 0); - } - return defaultVector; - } - }; - } + private byte[] defaultVector() { + if (defaultVector == null) { + defaultVector = new byte[vectorValues.dimension()]; + Arrays.fill(defaultVector, (byte) 0); + } + return defaultVector; + } + }; + } - @Override - public boolean equals(Object o) { - if (o.getClass() != DenseVectorByteFieldSource.class) return false; - DenseVectorByteFieldSource other = (DenseVectorByteFieldSource) o; - return fieldName.equals(other.fieldName); - } + @Override + public boolean equals(Object o) { + if (o.getClass() != DenseVectorByteFieldSource.class) return false; + DenseVectorByteFieldSource other = (DenseVectorByteFieldSource) o; + return fieldName.equals(other.fieldName); + } - @Override - public int hashCode() { - return getClass().hashCode() * 31 + fieldName.getClass().hashCode(); - } + @Override + public int hashCode() { + return getClass().hashCode() * 31 + fieldName.getClass().hashCode(); + } - @Override - public String description() { - return "denseByteVectorField(" + fieldName + ")"; - } + @Override + public String description() { + return "denseByteVectorField(" + fieldName + ")"; + } } diff --git a/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/DenseVectorFieldFunction.java b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/DenseVectorFieldFunction.java index 2d056d1d011e..8629e1c9f8a6 100644 --- a/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/DenseVectorFieldFunction.java +++ b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/DenseVectorFieldFunction.java @@ -16,39 +16,38 @@ */ package org.apache.lucene.queries.function.valuesource; +import java.io.IOException; import org.apache.lucene.queries.function.FunctionValues; import org.apache.lucene.queries.function.ValueSource; import org.apache.lucene.search.DocIdSetIterator; -import java.io.IOException; - public abstract class DenseVectorFieldFunction extends FunctionValues { - protected final ValueSource vs; - int lastDocID; + protected final ValueSource vs; + int lastDocID; - protected DenseVectorFieldFunction(ValueSource vs) { - this.vs = vs; - } + protected DenseVectorFieldFunction(ValueSource vs) { + this.vs = vs; + } - abstract protected DocIdSetIterator getVectorIterator(); + protected abstract DocIdSetIterator getVectorIterator(); - @Override - public String toString(int doc) throws IOException { - return vs.description() + strVal(doc); - } + @Override + public String toString(int doc) throws IOException { + return vs.description() + strVal(doc); + } - @Override - public boolean exists(int doc) throws IOException { - if (doc < lastDocID) { - throw new IllegalArgumentException( - "docs were sent out-of-order: lastDocID=" + lastDocID + " vs docID=" + doc); - } - lastDocID = doc; - int curDocID = getVectorIterator().docID(); - if (doc > curDocID) { - curDocID = getVectorIterator().advance(doc); - } - return doc == curDocID; + @Override + public boolean exists(int doc) throws IOException { + if (doc < lastDocID) { + throw new IllegalArgumentException( + "docs were sent out-of-order: lastDocID=" + lastDocID + " vs docID=" + doc); + } + lastDocID = doc; + int curDocID = getVectorIterator().docID(); + if (doc > curDocID) { + curDocID = getVectorIterator().advance(doc); } + return doc == curDocID; + } } diff --git a/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/DenseVectorFloatConstValueSource.java b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/DenseVectorFloatConstValueSource.java index a04498ec8467..d5b976f2275e 100644 --- a/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/DenseVectorFloatConstValueSource.java +++ b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/DenseVectorFloatConstValueSource.java @@ -16,57 +16,59 @@ */ package org.apache.lucene.queries.function.valuesource; -import org.apache.lucene.index.LeafReaderContext; -import org.apache.lucene.queries.function.FunctionValues; -import org.apache.lucene.queries.function.ValueSource; - import java.io.IOException; import java.util.Arrays; import java.util.List; import java.util.Map; +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.queries.function.FunctionValues; +import org.apache.lucene.queries.function.ValueSource; public class DenseVectorFloatConstValueSource extends ValueSource { - float[] vector; - public DenseVectorFloatConstValueSource(List constVector) { - this.vector = new float[constVector.size()]; - for (int i = 0; i < constVector.size(); i++) { - vector[i] = constVector.get(i).floatValue(); - } + float[] vector; + + public DenseVectorFloatConstValueSource(List constVector) { + this.vector = new float[constVector.size()]; + for (int i = 0; i < constVector.size(); i++) { + vector[i] = constVector.get(i).floatValue(); } + } - @Override - public FunctionValues getValues(Map context, LeafReaderContext readerContext) throws IOException { - return new FunctionValues() { - @Override - public float[] floatVectorVal(int doc) { - return vector; - } - @Override - public String strVal(int doc) { - return Arrays.toString(vector); - } + @Override + public FunctionValues getValues(Map context, LeafReaderContext readerContext) + throws IOException { + return new FunctionValues() { + @Override + public float[] floatVectorVal(int doc) { + return vector; + } - @Override - public String toString(int doc) throws IOException { - return description() + '=' + strVal(doc); - } - }; - } + @Override + public String strVal(int doc) { + return Arrays.toString(vector); + } - @Override - public boolean equals(Object o) { - if (!(o instanceof DenseVectorFloatConstValueSource)) return false; - DenseVectorFloatConstValueSource other = (DenseVectorFloatConstValueSource) o; - return Arrays.equals(vector, other.vector); - } + @Override + public String toString(int doc) throws IOException { + return description() + '=' + strVal(doc); + } + }; + } - @Override - public int hashCode() { - return getClass().hashCode() * 31 + Arrays.hashCode(vector); - } + @Override + public boolean equals(Object o) { + if (!(o instanceof DenseVectorFloatConstValueSource)) return false; + DenseVectorFloatConstValueSource other = (DenseVectorFloatConstValueSource) o; + return Arrays.equals(vector, other.vector); + } - @Override - public String description() { - return "denseVectorConst(" + Arrays.toString(vector) + ')'; - } + @Override + public int hashCode() { + return getClass().hashCode() * 31 + Arrays.hashCode(vector); + } + + @Override + public String description() { + return "denseVectorConst(" + Arrays.toString(vector) + ')'; + } } diff --git a/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/DenseVectorFloatFieldSource.java b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/DenseVectorFloatFieldSource.java index 2e5016a42159..6cbe45ad04c3 100644 --- a/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/DenseVectorFloatFieldSource.java +++ b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/DenseVectorFloatFieldSource.java @@ -16,67 +16,68 @@ */ package org.apache.lucene.queries.function.valuesource; +import java.io.IOException; +import java.util.Arrays; +import java.util.Map; import org.apache.lucene.index.FloatVectorValues; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.queries.function.FunctionValues; import org.apache.lucene.queries.function.ValueSource; import org.apache.lucene.search.DocIdSetIterator; -import java.io.IOException; -import java.util.Arrays; -import java.util.Map; - public class DenseVectorFloatFieldSource extends ValueSource { - private final String fieldName; - public DenseVectorFloatFieldSource(String fieldName) { - this.fieldName = fieldName; - } + private final String fieldName; + + public DenseVectorFloatFieldSource(String fieldName) { + this.fieldName = fieldName; + } - @Override - public FunctionValues getValues(Map context, LeafReaderContext readerContext) throws IOException { + @Override + public FunctionValues getValues(Map context, LeafReaderContext readerContext) + throws IOException { - final FloatVectorValues vectorValues = readerContext.reader().getFloatVectorValues(fieldName); - return new DenseVectorFieldFunction(this){ - float[] defaultVector = null; + final FloatVectorValues vectorValues = readerContext.reader().getFloatVectorValues(fieldName); + return new DenseVectorFieldFunction(this) { + float[] defaultVector = null; - @Override - public float[] floatVectorVal(int doc) throws IOException { - if (exists(doc)){ - return vectorValues.vectorValue(); - } else { - return defaultVector(); - } - } + @Override + public float[] floatVectorVal(int doc) throws IOException { + if (exists(doc)) { + return vectorValues.vectorValue(); + } else { + return defaultVector(); + } + } - @Override - protected DocIdSetIterator getVectorIterator() { - return vectorValues; - } + @Override + protected DocIdSetIterator getVectorIterator() { + return vectorValues; + } - private float[] defaultVector(){ - if (defaultVector == null){ - defaultVector = new float[vectorValues.dimension()]; - Arrays.fill(defaultVector, 0.f); - } - return defaultVector; - } - }; - } + private float[] defaultVector() { + if (defaultVector == null) { + defaultVector = new float[vectorValues.dimension()]; + Arrays.fill(defaultVector, 0.f); + } + return defaultVector; + } + }; + } - @Override - public boolean equals(Object o) { - if (o.getClass() != DenseVectorFloatFieldSource.class) return false; - DenseVectorFloatFieldSource other = (DenseVectorFloatFieldSource) o; - return fieldName.equals(other.fieldName); - } + @Override + public boolean equals(Object o) { + if (o.getClass() != DenseVectorFloatFieldSource.class) return false; + DenseVectorFloatFieldSource other = (DenseVectorFloatFieldSource) o; + return fieldName.equals(other.fieldName); + } - @Override - public int hashCode() { - return getClass().hashCode() * 31 + fieldName.getClass().hashCode(); - } + @Override + public int hashCode() { + return getClass().hashCode() * 31 + fieldName.getClass().hashCode(); + } - @Override - public String description() { - return "denseFloatVectorField(" + fieldName + ")"; - } + @Override + public String description() { + return "denseFloatVectorField(" + fieldName + ")"; + } } diff --git a/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/DenseVectorSimilarityFunction.java b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/DenseVectorSimilarityFunction.java index 8334882875f0..7a022cc83fe5 100644 --- a/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/DenseVectorSimilarityFunction.java +++ b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/DenseVectorSimilarityFunction.java @@ -16,77 +16,82 @@ */ package org.apache.lucene.queries.function.valuesource; +import java.io.IOException; +import java.util.Map; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.queries.function.FunctionValues; import org.apache.lucene.queries.function.ValueSource; -import java.io.IOException; -import java.util.Map; - public abstract class DenseVectorSimilarityFunction extends ValueSource { - protected final VectorSimilarityFunction similarityFunction; - protected final ValueSource vector1; - protected final ValueSource vector2; + protected final VectorSimilarityFunction similarityFunction; + protected final ValueSource vector1; + protected final ValueSource vector2; - public DenseVectorSimilarityFunction(VectorSimilarityFunction similarityFunction, ValueSource vector1, ValueSource vector2) { + public DenseVectorSimilarityFunction( + VectorSimilarityFunction similarityFunction, ValueSource vector1, ValueSource vector2) { - this.similarityFunction = similarityFunction; - this.vector1 = vector1; - this.vector2 = vector2; - } + this.similarityFunction = similarityFunction; + this.vector1 = vector1; + this.vector2 = vector2; + } - @Override - public FunctionValues getValues(Map context, LeafReaderContext readerContext) - throws IOException { + @Override + public FunctionValues getValues(Map context, LeafReaderContext readerContext) + throws IOException { - final FunctionValues vector1Vals = vector1.getValues(context, readerContext); - final FunctionValues vector2Vals = vector2.getValues(context, readerContext); - return new FunctionValues() { - @Override - public float floatVal(int doc) throws IOException { - return func(doc, vector1Vals, vector2Vals); - } + final FunctionValues vector1Vals = vector1.getValues(context, readerContext); + final FunctionValues vector2Vals = vector2.getValues(context, readerContext); + return new FunctionValues() { + @Override + public float floatVal(int doc) throws IOException { + return func(doc, vector1Vals, vector2Vals); + } - @Override - public boolean exists(int doc) throws IOException { - return MultiFunction.allExists(doc, vector1Vals, vector2Vals); - } + @Override + public boolean exists(int doc) throws IOException { + return MultiFunction.allExists(doc, vector1Vals, vector2Vals); + } - @Override - public String toString(int doc) throws IOException { - return null; - } - }; - } + @Override + public String toString(int doc) throws IOException { + return null; + } + }; + } - protected void checkSize(int sizeVector1, int sizeVector2) throws IOException { - if (sizeVector1 != sizeVector2){ - throw new UnsupportedOperationException("Vectors must have the same size"); - } + protected void checkSize(int sizeVector1, int sizeVector2) throws IOException { + if (sizeVector1 != sizeVector2) { + throw new UnsupportedOperationException("Vectors must have the same size"); } + } - protected abstract float func(int doc, FunctionValues f1, FunctionValues f2) throws IOException; + protected abstract float func(int doc, FunctionValues f1, FunctionValues f2) throws IOException; - @Override - public boolean equals(Object o) { - return o instanceof DenseVectorSimilarityFunction && - similarityFunction.equals(((DenseVectorSimilarityFunction) o).similarityFunction) && - vector1.equals(((DenseVectorSimilarityFunction) o).vector1) && - vector2.equals(((DenseVectorSimilarityFunction) o).vector2); - } + @Override + public boolean equals(Object o) { + return o instanceof DenseVectorSimilarityFunction + && similarityFunction.equals(((DenseVectorSimilarityFunction) o).similarityFunction) + && vector1.equals(((DenseVectorSimilarityFunction) o).vector1) + && vector2.equals(((DenseVectorSimilarityFunction) o).vector2); + } - @Override - public int hashCode() { - int h = similarityFunction.hashCode(); - h = 31 * h + vector1.hashCode(); - h = 31 * h + vector2.hashCode(); - return h; - } + @Override + public int hashCode() { + int h = similarityFunction.hashCode(); + h = 31 * h + vector1.hashCode(); + h = 31 * h + vector2.hashCode(); + return h; + } - @Override - public String description() { - return similarityFunction.name() + "(" + vector1.description() + ", " + vector2.description() + ")"; - } -} \ No newline at end of file + @Override + public String description() { + return similarityFunction.name() + + "(" + + vector1.description() + + ", " + + vector2.description() + + ")"; + } +} diff --git a/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/FloatDenseVectorSimilarityFunction.java b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/FloatDenseVectorSimilarityFunction.java index 61ae487ba789..336438cc360b 100644 --- a/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/FloatDenseVectorSimilarityFunction.java +++ b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/FloatDenseVectorSimilarityFunction.java @@ -16,20 +16,20 @@ */ package org.apache.lucene.queries.function.valuesource; +import java.io.IOException; import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.queries.function.FunctionValues; import org.apache.lucene.queries.function.ValueSource; -import java.io.IOException; - public class FloatDenseVectorSimilarityFunction extends DenseVectorSimilarityFunction { - public FloatDenseVectorSimilarityFunction(VectorSimilarityFunction similarityFunction, ValueSource vector1, ValueSource vector2) { - super(similarityFunction, vector1, vector2); - } + public FloatDenseVectorSimilarityFunction( + VectorSimilarityFunction similarityFunction, ValueSource vector1, ValueSource vector2) { + super(similarityFunction, vector1, vector2); + } - @Override - protected float func(int doc, FunctionValues f1, FunctionValues f2) throws IOException { - checkSize(f1.floatVectorVal(doc).length, f2.floatVectorVal(doc).length); - return similarityFunction.compare(f1.floatVectorVal(doc), f2.floatVectorVal(doc)); - } + @Override + protected float func(int doc, FunctionValues f1, FunctionValues f2) throws IOException { + checkSize(f1.floatVectorVal(doc).length, f2.floatVectorVal(doc).length); + return similarityFunction.compare(f1.floatVectorVal(doc), f2.floatVectorVal(doc)); + } } diff --git a/lucene/queries/src/test/org/apache/lucene/queries/function/TestKnnVectorSimilarityFunctions.java b/lucene/queries/src/test/org/apache/lucene/queries/function/TestKnnVectorSimilarityFunctions.java index a23801278d30..44b8317e8bc6 100644 --- a/lucene/queries/src/test/org/apache/lucene/queries/function/TestKnnVectorSimilarityFunctions.java +++ b/lucene/queries/src/test/org/apache/lucene/queries/function/TestKnnVectorSimilarityFunctions.java @@ -16,6 +16,7 @@ */ package org.apache.lucene.queries.function; +import java.util.List; import org.apache.lucene.analysis.Analyzer; import org.apache.lucene.document.Document; import org.apache.lucene.document.Field; @@ -46,120 +47,148 @@ import org.apache.lucene.util.BytesRef; import org.junit.AfterClass; import org.junit.BeforeClass; -import java.util.List; public class TestKnnVectorSimilarityFunctions extends LuceneTestCase { - static Directory dir; - static Analyzer analyzer; - static IndexReader reader; - static IndexSearcher searcher; - static final List documents = List.of("1", "2"); - - @BeforeClass - public static void beforeClass() throws Exception { - dir = newDirectory(); - analyzer = new MockAnalyzer(random()); - IndexWriterConfig iwConfig = newIndexWriterConfig(analyzer); - iwConfig.setMergePolicy(newLogMergePolicy()); - RandomIndexWriter iw = new RandomIndexWriter(random(), dir, iwConfig); - for (String docId : documents) { - Document document = new Document(); - document.add(new StringField("id", docId, Field.Store.NO)); - document.add(new SortedDocValuesField("id", new BytesRef(docId))); - document.add(new KnnFloatVectorField("knnFloatField1", new float[] {1.f, 2.f, 3.f})); - document.add(new KnnFloatVectorField("knnFloatField2", new float[]{5.2f, 3.2f, 3.1f})); - document.add(new KnnByteVectorField("knnByteField1", new byte[]{1, 2, 3})); - document.add(new KnnByteVectorField("knnByteField2", new byte[]{4, 2, 3})); - iw.addDocument(document); - } - - reader = iw.getReader(); - searcher = newSearcher(reader); - iw.close(); - } - - @AfterClass - public static void afterClass() throws Exception { - searcher = null; - reader.close(); - reader = null; - dir.close(); - dir = null; - analyzer.close(); - analyzer = null; - } - - public void testFloatVectorSimilarityFunctionConst() throws Exception { - ValueSource v1 = new DenseVectorFloatConstValueSource(List.of(1,2,3)); - ValueSource v2 = new DenseVectorFloatConstValueSource(List.of(5,4,1)); - assertHits(new FunctionQuery(new FloatDenseVectorSimilarityFunction(VectorSimilarityFunction.EUCLIDEAN, v1, v2)), new float[] {0.04f, 0.04f}); - } - - public void testByteVectorSimilarityFunctionConst() throws Exception { - ValueSource v1 = new DenseVectorByteConstValueSource(List.of(1,2,3)); - ValueSource v2 = new DenseVectorByteConstValueSource(List.of(2,5,6)); - assertHits(new FunctionQuery(new ByteDenseVectorSimilarityFunction(VectorSimilarityFunction.EUCLIDEAN, v1, v2)), new float[] {0.05f, 0.05f}); - } - - public void testFloatVectorSimilarityFunctionField() throws Exception { - ValueSource v1 = new DenseVectorFloatFieldSource("knnFloatField1"); - ValueSource v2 = new DenseVectorFloatFieldSource("knnFloatField2"); - assertHits(new FunctionQuery(new FloatDenseVectorSimilarityFunction(VectorSimilarityFunction.EUCLIDEAN, v1, v2)), new float[] {0.049776014f, 0.049776014f}); - } - - public void testByteVectorSimilarityFunctionField() throws Exception { - ValueSource v1 = new DenseVectorByteFieldSource("knnByteField1"); - ValueSource v2 = new DenseVectorByteFieldSource("knnByteField2"); - assertHits(new FunctionQuery(new ByteDenseVectorSimilarityFunction(VectorSimilarityFunction.EUCLIDEAN, v1, v2)), new float[] {0.1f, 0.1f}); + static Directory dir; + static Analyzer analyzer; + static IndexReader reader; + static IndexSearcher searcher; + static final List documents = List.of("1", "2"); + + @BeforeClass + public static void beforeClass() throws Exception { + dir = newDirectory(); + analyzer = new MockAnalyzer(random()); + IndexWriterConfig iwConfig = newIndexWriterConfig(analyzer); + iwConfig.setMergePolicy(newLogMergePolicy()); + RandomIndexWriter iw = new RandomIndexWriter(random(), dir, iwConfig); + for (String docId : documents) { + Document document = new Document(); + document.add(new StringField("id", docId, Field.Store.NO)); + document.add(new SortedDocValuesField("id", new BytesRef(docId))); + document.add(new KnnFloatVectorField("knnFloatField1", new float[] {1.f, 2.f, 3.f})); + document.add(new KnnFloatVectorField("knnFloatField2", new float[] {5.2f, 3.2f, 3.1f})); + document.add(new KnnByteVectorField("knnByteField1", new byte[] {1, 2, 3})); + document.add(new KnnByteVectorField("knnByteField2", new byte[] {4, 2, 3})); + iw.addDocument(document); } - public void testFloatVectorSimilarityFunctionMixed() throws Exception { - ValueSource v1 = new DenseVectorFloatConstValueSource(List.of(1,2,4)); - ValueSource v2 = new DenseVectorFloatFieldSource("knnFloatField1"); - assertHits(new FunctionQuery(new FloatDenseVectorSimilarityFunction(VectorSimilarityFunction.EUCLIDEAN, v1, v2)), new float[] {0.5f, 0.5f}); - } - - public void testByteVectorSimilarityFunctionMixed() throws Exception { - ValueSource v1 = new DenseVectorByteConstValueSource(List.of(1,2,4)); - ValueSource v2 = new DenseVectorByteFieldSource("knnByteField1"); - assertHits(new FunctionQuery(new ByteDenseVectorSimilarityFunction(VectorSimilarityFunction.EUCLIDEAN, v1, v2)), new float[] {0.5f, 0.5f}); - } - - public void testDismatchDimension() { - ValueSource v1 = new DenseVectorByteConstValueSource(List.of(1,2,3,4)); - ValueSource v2 = new DenseVectorByteFieldSource("knnByteField1"); - ByteDenseVectorSimilarityFunction byteDenseVectorSimilarityFunction = new ByteDenseVectorSimilarityFunction(VectorSimilarityFunction.EUCLIDEAN, v1, v2); - assertThrows(UnsupportedOperationException.class, () -> searcher.search(new FunctionQuery(byteDenseVectorSimilarityFunction), 10)); - - v1 = new DenseVectorFloatConstValueSource(List.of(1,2)); - v2 = new DenseVectorFloatFieldSource("knnFloatField1"); - FloatDenseVectorSimilarityFunction floatDenseVectorSimilarityFunction = new FloatDenseVectorSimilarityFunction(VectorSimilarityFunction.EUCLIDEAN, v1, v2); - assertThrows(UnsupportedOperationException.class, () -> searcher.search(new FunctionQuery(floatDenseVectorSimilarityFunction), 10)); - } - - public void testMismatchType() { - ValueSource v1 = new DenseVectorByteConstValueSource(List.of(1,2,3)); - ValueSource v2 = new DenseVectorByteFieldSource("knnByteField1"); - FloatDenseVectorSimilarityFunction floatDenseVectorSimilarityFunction = new FloatDenseVectorSimilarityFunction(VectorSimilarityFunction.EUCLIDEAN, v1, v2); - assertThrows(UnsupportedOperationException.class, () -> searcher.search(new FunctionQuery(floatDenseVectorSimilarityFunction), 10)); - - v1 = new DenseVectorByteConstValueSource(List.of(1,2,3)); - v2 = new DenseVectorFloatFieldSource("knnByteField1"); - ByteDenseVectorSimilarityFunction byteDenseVectorSimilarityFunction = new ByteDenseVectorSimilarityFunction(VectorSimilarityFunction.EUCLIDEAN, v1, v2); - assertThrows(UnsupportedOperationException.class, () -> searcher.search(new FunctionQuery(byteDenseVectorSimilarityFunction), 10)); - } - - - public static void assertHits(Query q, float[] scores) throws Exception { - ScoreDoc[] expected = new ScoreDoc[scores.length]; - int[] expectedDocs = new int[scores.length]; - for (int i = 0; i < expected.length; i++) { - expectedDocs[i] = i; - expected[i] = new ScoreDoc(i, scores[i]); - } - TopDocs docs = - searcher.search( - q, documents.size(), new Sort(new SortField("id", SortField.Type.STRING)), true); - CheckHits.checkHitsQuery(q, expected, docs.scoreDocs, expectedDocs); + reader = iw.getReader(); + searcher = newSearcher(reader); + iw.close(); + } + + @AfterClass + public static void afterClass() throws Exception { + searcher = null; + reader.close(); + reader = null; + dir.close(); + dir = null; + analyzer.close(); + analyzer = null; + } + + public void testFloatVectorSimilarityFunctionConst() throws Exception { + ValueSource v1 = new DenseVectorFloatConstValueSource(List.of(1, 2, 3)); + ValueSource v2 = new DenseVectorFloatConstValueSource(List.of(5, 4, 1)); + assertHits( + new FunctionQuery( + new FloatDenseVectorSimilarityFunction(VectorSimilarityFunction.EUCLIDEAN, v1, v2)), + new float[] {0.04f, 0.04f}); + } + + public void testByteVectorSimilarityFunctionConst() throws Exception { + ValueSource v1 = new DenseVectorByteConstValueSource(List.of(1, 2, 3)); + ValueSource v2 = new DenseVectorByteConstValueSource(List.of(2, 5, 6)); + assertHits( + new FunctionQuery( + new ByteDenseVectorSimilarityFunction(VectorSimilarityFunction.EUCLIDEAN, v1, v2)), + new float[] {0.05f, 0.05f}); + } + + public void testFloatVectorSimilarityFunctionField() throws Exception { + ValueSource v1 = new DenseVectorFloatFieldSource("knnFloatField1"); + ValueSource v2 = new DenseVectorFloatFieldSource("knnFloatField2"); + assertHits( + new FunctionQuery( + new FloatDenseVectorSimilarityFunction(VectorSimilarityFunction.EUCLIDEAN, v1, v2)), + new float[] {0.049776014f, 0.049776014f}); + } + + public void testByteVectorSimilarityFunctionField() throws Exception { + ValueSource v1 = new DenseVectorByteFieldSource("knnByteField1"); + ValueSource v2 = new DenseVectorByteFieldSource("knnByteField2"); + assertHits( + new FunctionQuery( + new ByteDenseVectorSimilarityFunction(VectorSimilarityFunction.EUCLIDEAN, v1, v2)), + new float[] {0.1f, 0.1f}); + } + + public void testFloatVectorSimilarityFunctionMixed() throws Exception { + ValueSource v1 = new DenseVectorFloatConstValueSource(List.of(1, 2, 4)); + ValueSource v2 = new DenseVectorFloatFieldSource("knnFloatField1"); + assertHits( + new FunctionQuery( + new FloatDenseVectorSimilarityFunction(VectorSimilarityFunction.EUCLIDEAN, v1, v2)), + new float[] {0.5f, 0.5f}); + } + + public void testByteVectorSimilarityFunctionMixed() throws Exception { + ValueSource v1 = new DenseVectorByteConstValueSource(List.of(1, 2, 4)); + ValueSource v2 = new DenseVectorByteFieldSource("knnByteField1"); + assertHits( + new FunctionQuery( + new ByteDenseVectorSimilarityFunction(VectorSimilarityFunction.EUCLIDEAN, v1, v2)), + new float[] {0.5f, 0.5f}); + } + + public void testDismatchDimension() { + ValueSource v1 = new DenseVectorByteConstValueSource(List.of(1, 2, 3, 4)); + ValueSource v2 = new DenseVectorByteFieldSource("knnByteField1"); + ByteDenseVectorSimilarityFunction byteDenseVectorSimilarityFunction = + new ByteDenseVectorSimilarityFunction(VectorSimilarityFunction.EUCLIDEAN, v1, v2); + assertThrows( + UnsupportedOperationException.class, + () -> searcher.search(new FunctionQuery(byteDenseVectorSimilarityFunction), 10)); + + v1 = new DenseVectorFloatConstValueSource(List.of(1, 2)); + v2 = new DenseVectorFloatFieldSource("knnFloatField1"); + FloatDenseVectorSimilarityFunction floatDenseVectorSimilarityFunction = + new FloatDenseVectorSimilarityFunction(VectorSimilarityFunction.EUCLIDEAN, v1, v2); + assertThrows( + UnsupportedOperationException.class, + () -> searcher.search(new FunctionQuery(floatDenseVectorSimilarityFunction), 10)); + } + + public void testMismatchType() { + ValueSource v1 = new DenseVectorByteConstValueSource(List.of(1, 2, 3)); + ValueSource v2 = new DenseVectorByteFieldSource("knnByteField1"); + FloatDenseVectorSimilarityFunction floatDenseVectorSimilarityFunction = + new FloatDenseVectorSimilarityFunction(VectorSimilarityFunction.EUCLIDEAN, v1, v2); + assertThrows( + UnsupportedOperationException.class, + () -> searcher.search(new FunctionQuery(floatDenseVectorSimilarityFunction), 10)); + + v1 = new DenseVectorByteConstValueSource(List.of(1, 2, 3)); + v2 = new DenseVectorFloatFieldSource("knnByteField1"); + ByteDenseVectorSimilarityFunction byteDenseVectorSimilarityFunction = + new ByteDenseVectorSimilarityFunction(VectorSimilarityFunction.EUCLIDEAN, v1, v2); + assertThrows( + UnsupportedOperationException.class, + () -> searcher.search(new FunctionQuery(byteDenseVectorSimilarityFunction), 10)); + } + + public static void assertHits(Query q, float[] scores) throws Exception { + ScoreDoc[] expected = new ScoreDoc[scores.length]; + int[] expectedDocs = new int[scores.length]; + for (int i = 0; i < expected.length; i++) { + expectedDocs[i] = i; + expected[i] = new ScoreDoc(i, scores[i]); } + TopDocs docs = + searcher.search( + q, documents.size(), new Sort(new SortField("id", SortField.Type.STRING)), true); + CheckHits.checkHitsQuery(q, expected, docs.scoreDocs, expectedDocs); + } } From 92b96511ad11f3e2dbd8f9a6dd384f97c5c3e153 Mon Sep 17 00:00:00 2001 From: Elia Date: Fri, 28 Apr 2023 13:21:51 +0200 Subject: [PATCH 03/35] Add toString for DenseVectorSimilarityFunction --- .../valuesource/DenseVectorSimilarityFunction.java | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/DenseVectorSimilarityFunction.java b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/DenseVectorSimilarityFunction.java index 7a022cc83fe5..7894961adc09 100644 --- a/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/DenseVectorSimilarityFunction.java +++ b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/DenseVectorSimilarityFunction.java @@ -49,6 +49,11 @@ public float floatVal(int doc) throws IOException { return func(doc, vector1Vals, vector2Vals); } + @Override + String strVal(int doc) throws IOException { + return Float.toString(floatVal(doc)); + } + @Override public boolean exists(int doc) throws IOException { return MultiFunction.allExists(doc, vector1Vals, vector2Vals); @@ -56,7 +61,7 @@ public boolean exists(int doc) throws IOException { @Override public String toString(int doc) throws IOException { - return null; + return description() + " = " + strVal(doc); } }; } From 36d682417e114665b7f0721778f6ec024311239a Mon Sep 17 00:00:00 2001 From: Elia Date: Fri, 28 Apr 2023 14:33:17 +0200 Subject: [PATCH 04/35] minor fix --- .../function/valuesource/DenseVectorSimilarityFunction.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/DenseVectorSimilarityFunction.java b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/DenseVectorSimilarityFunction.java index 7894961adc09..e24653cd3491 100644 --- a/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/DenseVectorSimilarityFunction.java +++ b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/DenseVectorSimilarityFunction.java @@ -50,7 +50,7 @@ public float floatVal(int doc) throws IOException { } @Override - String strVal(int doc) throws IOException { + public String strVal(int doc) throws IOException { return Float.toString(floatVal(doc)); } From 533c6fc23bf45d8eb7764a9f5b20d628a5d0fd99 Mon Sep 17 00:00:00 2001 From: Jonathan Ellis Date: Thu, 27 Apr 2023 10:50:22 -0500 Subject: [PATCH 05/35] Use HashMap (was TreeMap) for OnHeapHnswGraph neighbors --- .../lucene91/Lucene91HnswVectorsWriter.java | 14 +-- .../lucene92/Lucene92HnswVectorsWriter.java | 14 +-- .../lucene94/Lucene94HnswVectorsWriter.java | 18 ++- .../lucene95/Lucene95HnswVectorsWriter.java | 22 ++-- .../apache/lucene/util/hnsw/HnswGraph.java | 6 +- .../lucene/util/hnsw/OnHeapHnswGraph.java | 22 +++- .../lucene/util/hnsw/HnswGraphTestCase.java | 114 ++++++++++++++---- 7 files changed, 146 insertions(+), 64 deletions(-) diff --git a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswVectorsWriter.java b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswVectorsWriter.java index 68c8967b9b28..9281b6374411 100644 --- a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswVectorsWriter.java +++ b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswVectorsWriter.java @@ -25,6 +25,7 @@ import java.util.Arrays; import org.apache.lucene.codecs.BufferingKnnVectorsWriter; import org.apache.lucene.codecs.CodecUtil; +import org.apache.lucene.codecs.lucene95.Lucene95HnswVectorsWriter; import org.apache.lucene.index.ByteVectorValues; import org.apache.lucene.index.DocsWithFieldSet; import org.apache.lucene.index.FieldInfo; @@ -36,7 +37,6 @@ import org.apache.lucene.store.IndexInput; import org.apache.lucene.store.IndexOutput; import org.apache.lucene.util.IOUtils; -import org.apache.lucene.util.hnsw.HnswGraph.NodesIterator; import org.apache.lucene.util.hnsw.RandomAccessVectorValues; /** @@ -227,11 +227,10 @@ private void writeMeta( } else { meta.writeInt(graph.numLevels()); for (int level = 0; level < graph.numLevels(); level++) { - NodesIterator nodesOnLevel = graph.getNodesOnLevel(level); - meta.writeInt(nodesOnLevel.size()); // number of nodes on a level + int[] sortedNodes = Lucene95HnswVectorsWriter.getSortedNodes(graph.getNodesOnLevel(level)); + meta.writeInt(sortedNodes.length); // number of nodes on a level if (level > 0) { - while (nodesOnLevel.hasNext()) { - int node = nodesOnLevel.nextInt(); + for (int node : sortedNodes) { meta.writeInt(node); // list of nodes on a level } } @@ -257,9 +256,8 @@ private Lucene91OnHeapHnswGraph writeGraph( // write vectors' neighbours on each level into the vectorIndex file int countOnLevel0 = graph.size(); for (int level = 0; level < graph.numLevels(); level++) { - NodesIterator nodesOnLevel = graph.getNodesOnLevel(level); - while (nodesOnLevel.hasNext()) { - int node = nodesOnLevel.nextInt(); + int[] sortedNodes = Lucene95HnswVectorsWriter.getSortedNodes(graph.getNodesOnLevel(level)); + for (int node : sortedNodes) { Lucene91NeighborArray neighbors = graph.getNeighbors(level, node); int size = neighbors.size(); vectorIndex.writeInt(size); diff --git a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene92/Lucene92HnswVectorsWriter.java b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene92/Lucene92HnswVectorsWriter.java index 1480d1aea2e4..b2e7629aed1c 100644 --- a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene92/Lucene92HnswVectorsWriter.java +++ b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene92/Lucene92HnswVectorsWriter.java @@ -27,6 +27,7 @@ import org.apache.lucene.codecs.BufferingKnnVectorsWriter; import org.apache.lucene.codecs.CodecUtil; import org.apache.lucene.codecs.lucene90.IndexedDISI; +import org.apache.lucene.codecs.lucene95.Lucene95HnswVectorsWriter; import org.apache.lucene.index.ByteVectorValues; import org.apache.lucene.index.DocsWithFieldSet; import org.apache.lucene.index.FieldInfo; @@ -39,7 +40,6 @@ import org.apache.lucene.store.IndexInput; import org.apache.lucene.store.IndexOutput; import org.apache.lucene.util.IOUtils; -import org.apache.lucene.util.hnsw.HnswGraph.NodesIterator; import org.apache.lucene.util.hnsw.HnswGraphBuilder; import org.apache.lucene.util.hnsw.NeighborArray; import org.apache.lucene.util.hnsw.OnHeapHnswGraph; @@ -261,11 +261,10 @@ private void writeMeta( } else { meta.writeInt(graph.numLevels()); for (int level = 0; level < graph.numLevels(); level++) { - NodesIterator nodesOnLevel = graph.getNodesOnLevel(level); - meta.writeInt(nodesOnLevel.size()); // number of nodes on a level + int[] sortedNodes = Lucene95HnswVectorsWriter.getSortedNodes(graph.getNodesOnLevel(level)); + meta.writeInt(sortedNodes.length); // number of nodes on a level if (level > 0) { - while (nodesOnLevel.hasNext()) { - int node = nodesOnLevel.nextInt(); + for (int node : sortedNodes) { meta.writeInt(node); // list of nodes on a level } } @@ -293,9 +292,8 @@ private OnHeapHnswGraph writeGraph( int countOnLevel0 = graph.size(); for (int level = 0; level < graph.numLevels(); level++) { int maxConnOnLevel = level == 0 ? (M * 2) : M; - NodesIterator nodesOnLevel = graph.getNodesOnLevel(level); - while (nodesOnLevel.hasNext()) { - int node = nodesOnLevel.nextInt(); + int[] sortedNodes = Lucene95HnswVectorsWriter.getSortedNodes(graph.getNodesOnLevel(level)); + for (int node : sortedNodes) { NeighborArray neighbors = graph.getNeighbors(level, node); int size = neighbors.size(); vectorIndex.writeInt(size); diff --git a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene94/Lucene94HnswVectorsWriter.java b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene94/Lucene94HnswVectorsWriter.java index f6f378027603..9a2a156f98ac 100644 --- a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene94/Lucene94HnswVectorsWriter.java +++ b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene94/Lucene94HnswVectorsWriter.java @@ -30,6 +30,7 @@ import org.apache.lucene.codecs.KnnFieldVectorsWriter; import org.apache.lucene.codecs.KnnVectorsWriter; import org.apache.lucene.codecs.lucene90.IndexedDISI; +import org.apache.lucene.codecs.lucene95.Lucene95HnswVectorsWriter; import org.apache.lucene.index.ByteVectorValues; import org.apache.lucene.index.DocsWithFieldSet; import org.apache.lucene.index.FieldInfo; @@ -303,9 +304,8 @@ private HnswGraph reconstructAndWriteGraph( for (int level = 1; level < graph.numLevels(); level++) { NodesIterator nodesOnLevel = graph.getNodesOnLevel(level); int[] newNodes = new int[nodesOnLevel.size()]; - int n = 0; - while (nodesOnLevel.hasNext()) { - newNodes[n++] = oldToNewMap[nodesOnLevel.nextInt()]; + for (int n = 0; nodesOnLevel.hasNext(); n++) { + newNodes[n] = oldToNewMap[nodesOnLevel.nextInt()]; } Arrays.sort(newNodes); nodesByLevel.add(newNodes); @@ -481,9 +481,8 @@ private void writeGraph(OnHeapHnswGraph graph) throws IOException { int countOnLevel0 = graph.size(); for (int level = 0; level < graph.numLevels(); level++) { int maxConnOnLevel = level == 0 ? (M * 2) : M; - NodesIterator nodesOnLevel = graph.getNodesOnLevel(level); - while (nodesOnLevel.hasNext()) { - int node = nodesOnLevel.nextInt(); + int[] sortedNodes = Lucene95HnswVectorsWriter.getSortedNodes(graph.getNodesOnLevel(level)); + for (int node : sortedNodes) { NeighborArray neighbors = graph.getNeighbors(level, node); int size = neighbors.size(); vectorIndex.writeInt(size); @@ -570,11 +569,10 @@ private void writeMeta( } else { meta.writeInt(graph.numLevels()); for (int level = 0; level < graph.numLevels(); level++) { - NodesIterator nodesOnLevel = graph.getNodesOnLevel(level); - meta.writeInt(nodesOnLevel.size()); // number of nodes on a level + int[] sortedNodes = Lucene95HnswVectorsWriter.getSortedNodes(graph.getNodesOnLevel(level)); + meta.writeInt(sortedNodes.length); // number of nodes on a level if (level > 0) { - while (nodesOnLevel.hasNext()) { - int node = nodesOnLevel.nextInt(); + for (int node : sortedNodes) { meta.writeInt(node); // list of nodes on a level } } diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene95/Lucene95HnswVectorsWriter.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene95/Lucene95HnswVectorsWriter.java index bf0b79807f06..5358d66f16e2 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene95/Lucene95HnswVectorsWriter.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene95/Lucene95HnswVectorsWriter.java @@ -315,9 +315,8 @@ private HnswGraph reconstructAndWriteGraph( for (int level = 1; level < graph.numLevels(); level++) { NodesIterator nodesOnLevel = graph.getNodesOnLevel(level); int[] newNodes = new int[nodesOnLevel.size()]; - int n = 0; - while (nodesOnLevel.hasNext()) { - newNodes[n++] = oldToNewMap[nodesOnLevel.nextInt()]; + for (int n = 0; nodesOnLevel.hasNext(); n++) { + newNodes[n] = oldToNewMap[nodesOnLevel.nextInt()]; } Arrays.sort(newNodes); nodesByLevel.add(newNodes); @@ -677,11 +676,10 @@ private int[][] writeGraph(OnHeapHnswGraph graph) throws IOException { int countOnLevel0 = graph.size(); int[][] offsets = new int[graph.numLevels()][]; for (int level = 0; level < graph.numLevels(); level++) { - NodesIterator nodesOnLevel = graph.getNodesOnLevel(level); - offsets[level] = new int[nodesOnLevel.size()]; + int[] sortedNodes = getSortedNodes(graph.getNodesOnLevel(level)); + offsets[level] = new int[sortedNodes.length]; int nodeOffsetId = 0; - while (nodesOnLevel.hasNext()) { - int node = nodesOnLevel.nextInt(); + for (int node : sortedNodes) { NeighborArray neighbors = graph.getNeighbors(level, node); int size = neighbors.size(); // Write size in VInt as the neighbors list is typically small @@ -706,6 +704,15 @@ private int[][] writeGraph(OnHeapHnswGraph graph) throws IOException { return offsets; } + public static int[] getSortedNodes(NodesIterator nodesOnLevel) { + int[] sortedNodes = new int[nodesOnLevel.size()]; + for (int n = 0; nodesOnLevel.hasNext(); n++) { + sortedNodes[n] = nodesOnLevel.nextInt(); + } + Arrays.sort(sortedNodes); + return sortedNodes; + } + private void writeMeta( FieldInfo field, int maxDoc, @@ -779,6 +786,7 @@ private void writeMeta( if (level > 0) { int[] nol = new int[nodesOnLevel.size()]; int numberConsumed = nodesOnLevel.consume(nol); + Arrays.sort(nol); assert numberConsumed == nodesOnLevel.size(); meta.writeVInt(nol.length); // number of nodes on a level for (int i = nodesOnLevel.size() - 1; i > 0; --i) { diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraph.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraph.java index 9086ab55d2eb..9b3d0d62c905 100644 --- a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraph.java +++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraph.java @@ -81,7 +81,8 @@ protected HnswGraph() {} public abstract int entryNode() throws IOException; /** - * Get all nodes on a given level as node 0th ordinals + * Get all nodes on a given level as node 0th ordinals. The nodes are NOT guaranteed to be + * presented in any particular order. * * @param level level for which to get all nodes * @return an iterator over nodes where {@code nextInt} returns a next node on the level @@ -123,7 +124,8 @@ public NodesIterator getNodesOnLevel(int level) { /** * Iterator over the graph nodes on a certain level, Iterator also provides the size – the total - * number of nodes to be iterated over. + * number of nodes to be iterated over. The nodes are NOT guaranteed to be presented in any + * particular order. */ public abstract static class NodesIterator implements PrimitiveIterator.OfInt { protected final int size; diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/OnHeapHnswGraph.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/OnHeapHnswGraph.java index 9862536de08c..ae39614f160e 100644 --- a/lucene/core/src/java/org/apache/lucene/util/hnsw/OnHeapHnswGraph.java +++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/OnHeapHnswGraph.java @@ -20,8 +20,9 @@ import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; import java.util.ArrayList; +import java.util.HashMap; import java.util.List; -import java.util.TreeMap; +import java.util.Map; import org.apache.lucene.util.Accountable; import org.apache.lucene.util.RamUsageEstimator; @@ -40,12 +41,12 @@ public final class OnHeapHnswGraph extends HnswGraph implements Accountable { // added to HnswBuilder, and the node values are the ordinals of those vectors. // Thus, on all levels, neighbors expressed as the level 0's nodes' ordinals. private final List graphLevel0; - // Represents levels 1-N. Each level is represented with a TreeMap that maps a levels level 0 + // Represents levels 1-N. Each level is represented with a Map that maps a levels level 0 // ordinal to its neighbors on that level. All nodes are in level 0, so we do not need to maintain // it in this list. However, to avoid changing list indexing, we always will make the first // element // null. - private final List> graphUpperLevels; + private final List> graphUpperLevels; private final int nsize; private final int nsize0; @@ -76,7 +77,7 @@ public NeighborArray getNeighbors(int level, int node) { if (level == 0) { return graphLevel0.get(node); } - TreeMap levelMap = graphUpperLevels.get(level); + Map levelMap = graphUpperLevels.get(level); assert levelMap.containsKey(node); return levelMap.get(node); } @@ -103,7 +104,7 @@ public void addNode(int level, int node) { // and make this node the graph's new entry point if (level >= numLevels) { for (int i = numLevels; i <= level; i++) { - graphUpperLevels.add(new TreeMap<>()); + graphUpperLevels.add(new HashMap<>()); } numLevels = level + 1; entryNode = node; @@ -204,4 +205,15 @@ public long ramBytesUsed() { } return total; } + + @Override + public String toString() { + return "OnHeapHnswGraph(size=" + + size() + + ", numLevels=" + + numLevels + + ", entryNode=" + + entryNode + + ")"; + } } diff --git a/lucene/core/src/test/org/apache/lucene/util/hnsw/HnswGraphTestCase.java b/lucene/core/src/test/org/apache/lucene/util/hnsw/HnswGraphTestCase.java index 80c9c7a93cf4..9825d4a5f419 100644 --- a/lucene/core/src/test/org/apache/lucene/util/hnsw/HnswGraphTestCase.java +++ b/lucene/core/src/test/org/apache/lucene/util/hnsw/HnswGraphTestCase.java @@ -29,6 +29,7 @@ import java.util.HashMap; import java.util.HashSet; import java.util.List; +import java.util.Locale; import java.util.Map; import java.util.Random; import java.util.Set; @@ -265,19 +266,50 @@ public KnnVectorsFormat getKnnVectorsFormatForField(String field) { } } + List sortedNodesOnLevel(HnswGraph h, int level) throws IOException { + NodesIterator nodesOnLevel = h.getNodesOnLevel(level); + List nodes = new ArrayList<>(); + while (nodesOnLevel.hasNext()) { + nodes.add(nodesOnLevel.next()); + } + Collections.sort(nodes); + return nodes; + } + void assertGraphEqual(HnswGraph g, HnswGraph h) throws IOException { - assertEquals("the number of levels in the graphs are different!", g.numLevels(), h.numLevels()); - assertEquals("the number of nodes in the graphs are different!", g.size(), h.size()); + // construct these up front since they call seek which will mess up our test loop + String prettyG = prettyPrint(g); + String prettyH = prettyPrint(h); + assertEquals( + String.format( + Locale.ROOT, + "the number of levels in the graphs are different:%n%s%n%s", + prettyG, + prettyH), + g.numLevels(), + h.numLevels()); + assertEquals( + String.format( + Locale.ROOT, + "the number of nodes in the graphs are different:%n%s%n%s", + prettyG, + prettyH), + g.size(), + h.size()); // assert equal nodes on each level for (int level = 0; level < g.numLevels(); level++) { - NodesIterator nodesOnLevel = g.getNodesOnLevel(level); - NodesIterator nodesOnLevel2 = h.getNodesOnLevel(level); - while (nodesOnLevel.hasNext() && nodesOnLevel2.hasNext()) { - int node = nodesOnLevel.nextInt(); - int node2 = nodesOnLevel2.nextInt(); - assertEquals("nodes in the graphs are different", node, node2); - } + List hNodes = sortedNodesOnLevel(h, level); + List gNodes = sortedNodesOnLevel(g, level); + assertEquals( + String.format( + Locale.ROOT, + "nodes in the graphs are different on level %d:%n%s%n%s", + level, + prettyG, + prettyH), + gNodes, + hNodes); } // assert equal nodes' neighbours on each level @@ -287,7 +319,16 @@ void assertGraphEqual(HnswGraph g, HnswGraph h) throws IOException { int node = nodesOnLevel.nextInt(); g.seek(level, node); h.seek(level, node); - assertEquals("arcs differ for node " + node, getNeighborNodes(g), getNeighborNodes(h)); + assertEquals( + String.format( + Locale.ROOT, + "arcs differ for node %d on level %d:%n%s%n%s", + node, + level, + prettyG, + prettyH), + getNeighborNodes(g), + getNeighborNodes(h)); } } } @@ -495,14 +536,12 @@ public void testBuildOnHeapHnswGraphOutOfOrder() throws IOException { } for (int currLevel = 1; currLevel < numLevels; currLevel++) { - NodesIterator nodesIterator = bottomUpExpectedHnsw.getNodesOnLevel(currLevel); List expectedNodesOnLevel = nodesPerLevel.get(currLevel); - assertEquals(expectedNodesOnLevel.size(), nodesIterator.size()); - for (Integer expectedNode : expectedNodesOnLevel) { - int currentNode = nodesIterator.nextInt(); - assertEquals(expectedNode.intValue(), currentNode); - assertEquals(0, bottomUpExpectedHnsw.getNeighbors(currLevel, currentNode).size()); - } + List sortedNodes = sortedNodesOnLevel(bottomUpExpectedHnsw, currLevel); + assertEquals( + String.format(Locale.ROOT, "Nodes on level %d do not match", currLevel), + expectedNodesOnLevel, + sortedNodes); } assertGraphEqual(bottomUpExpectedHnsw, topDownOrderReversedHnsw); @@ -607,13 +646,10 @@ private void assertGraphInitializedFromGraph( // assert the nodes from the previous graph are successfully to levels > 0 in the new graph for (int level = 1; level < g.numLevels(); level++) { - NodesIterator nodesOnLevel = g.getNodesOnLevel(level); - NodesIterator nodesOnLevel2 = h.getNodesOnLevel(level); - while (nodesOnLevel.hasNext() && nodesOnLevel2.hasNext()) { - int node = nodesOnLevel.nextInt(); - int node2 = oldToNewOrdMap.get(nodesOnLevel2.nextInt()); - assertEquals("nodes in the graphs are different", node, node2); - } + List nodesOnLevel = sortedNodesOnLevel(g, level); + List nodesOnLevel2 = + sortedNodesOnLevel(h, level).stream().map(oldToNewOrdMap::get).toList(); + assertEquals(nodesOnLevel, nodesOnLevel2); } // assert that the neighbors from the old graph are successfully transferred to the new graph @@ -1196,4 +1232,34 @@ static byte[] randomVector8(Random random, int dim) { } return bvec; } + + static String prettyPrint(HnswGraph hnsw) { + StringBuilder sb = new StringBuilder(); + sb.append(hnsw); + sb.append("\n"); + + try { + for (int level = 0; level < hnsw.numLevels(); level++) { + sb.append("# Level ").append(level).append("\n"); + NodesIterator it = hnsw.getNodesOnLevel(level); + while (it.hasNext()) { + int node = it.nextInt(); + sb.append(" ").append(node).append(" -> "); + hnsw.seek(level, node); + while (true) { + int neighbor = hnsw.nextNeighbor(); + if (neighbor == NO_MORE_DOCS) { + break; + } + sb.append(" ").append(neighbor); + } + sb.append("\n"); + } + } + } catch (IOException e) { + throw new RuntimeException(e); + } + + return sb.toString(); + } } From 9cb17b0b7bb6bd820627a08c58ac3a4e64d38379 Mon Sep 17 00:00:00 2001 From: Luca Cavanna Date: Wed, 3 May 2023 11:27:33 +0200 Subject: [PATCH 06/35] Fix SynonymQuery equals implementation (#12260) The term member of TermAndBoost used to be a Term instance and became a BytesRef with #11941, which means its equals impl won't take the field name into account. The SynonymQuery equals impl needs to be updated accordingly to take the field into account as well, otherwise synonym queries with same term and boost across different fields are equal which is a bug. --- .../org/apache/lucene/search/SynonymQuery.java | 6 ++++-- .../apache/lucene/search/TestSynonymQuery.java | 18 ++++++++++++++++++ 2 files changed, 22 insertions(+), 2 deletions(-) diff --git a/lucene/core/src/java/org/apache/lucene/search/SynonymQuery.java b/lucene/core/src/java/org/apache/lucene/search/SynonymQuery.java index ca32a7d7bb8e..7e314870f9c7 100644 --- a/lucene/core/src/java/org/apache/lucene/search/SynonymQuery.java +++ b/lucene/core/src/java/org/apache/lucene/search/SynonymQuery.java @@ -113,7 +113,7 @@ public SynonymQuery build() { */ private SynonymQuery(TermAndBoost[] terms, String field) { this.terms = Objects.requireNonNull(terms); - this.field = field; + this.field = Objects.requireNonNull(field); } public List getTerms() { @@ -146,7 +146,9 @@ public int hashCode() { @Override public boolean equals(Object other) { - return sameClassAs(other) && Arrays.equals(terms, ((SynonymQuery) other).terms); + return sameClassAs(other) + && field.equals(((SynonymQuery) other).field) + && Arrays.equals(terms, ((SynonymQuery) other).terms); } @Override diff --git a/lucene/core/src/test/org/apache/lucene/search/TestSynonymQuery.java b/lucene/core/src/test/org/apache/lucene/search/TestSynonymQuery.java index 123307e75f8e..ae026a87367f 100644 --- a/lucene/core/src/test/org/apache/lucene/search/TestSynonymQuery.java +++ b/lucene/core/src/test/org/apache/lucene/search/TestSynonymQuery.java @@ -73,6 +73,18 @@ public void testEquals() { .addTerm(new Term("field", "c"), 0.2f) .addTerm(new Term("field", "d")) .build()); + + QueryUtils.checkUnequal( + new SynonymQuery.Builder("field").addTerm(new Term("field", "a"), 0.4f).build(), + new SynonymQuery.Builder("field").addTerm(new Term("field", "b"), 0.4f).build()); + + QueryUtils.checkUnequal( + new SynonymQuery.Builder("field").addTerm(new Term("field", "a"), 0.2f).build(), + new SynonymQuery.Builder("field").addTerm(new Term("field", "a"), 0.4f).build()); + + QueryUtils.checkUnequal( + new SynonymQuery.Builder("field1").addTerm(new Term("field1", "b"), 0.4f).build(), + new SynonymQuery.Builder("field2").addTerm(new Term("field2", "b"), 0.4f).build()); } public void testBogusParams() { @@ -127,6 +139,12 @@ public void testBogusParams() { () -> { new SynonymQuery.Builder("field1").addTerm(new Term("field1", "a"), -0f); }); + + expectThrows( + NullPointerException.class, + () -> new SynonymQuery.Builder(null).addTerm(new Term("field1", "a"), -0f)); + + expectThrows(NullPointerException.class, () -> new SynonymQuery.Builder(null).build()); } public void testToString() { From ca5b83175281f9c7c9ee3594e26534520fe69076 Mon Sep 17 00:00:00 2001 From: Uwe Schindler Date: Fri, 5 May 2023 12:04:38 +0200 Subject: [PATCH 07/35] Fix MMapDirectory documentation for Java 20 (#12265) --- .../org/apache/lucene/store/MMapDirectory.java | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/lucene/core/src/java/org/apache/lucene/store/MMapDirectory.java b/lucene/core/src/java/org/apache/lucene/store/MMapDirectory.java index 58be07ab4372..9ca636b4cd71 100644 --- a/lucene/core/src/java/org/apache/lucene/store/MMapDirectory.java +++ b/lucene/core/src/java/org/apache/lucene/store/MMapDirectory.java @@ -76,9 +76,9 @@ *
  • {@code permission java.lang.RuntimePermission "accessClassInPackage.sun.misc";} * * - *

    On exactly Java 19 this class will use the modern {@code MemorySegment} API which - * allows to safely unmap (if you discover any problems with this preview API, you can disable it by - * using system property {@link #ENABLE_MEMORY_SEGMENTS_SYSPROP}). + *

    On exactly Java 19 and Java 20 this class will use the modern {@code + * MemorySegment} API which allows to safely unmap (if you discover any problems with this preview + * API, you can disable it by using system property {@link #ENABLE_MEMORY_SEGMENTS_SYSPROP}). * *

    NOTE: Accessing this class either directly or indirectly from a thread while it's * interrupted can close the underlying channel immediately if at the same time the thread is @@ -123,7 +123,7 @@ public class MMapDirectory extends FSDirectory { * Default max chunk size: * *

      - *
    • 16 GiBytes for 64 bit Java 19 JVMs + *
    • 16 GiBytes for 64 bit Java 19 and Java 20 JVMs *
    • 1 GiBytes for other 64 bit JVMs *
    • 256 MiBytes for 32 bit JVMs *
    @@ -198,9 +198,9 @@ public MMapDirectory(Path path, long maxChunkSize) throws IOException { * files cannot be mapped. Using a lower chunk size makes the directory implementation a little * bit slower (as the correct chunk may be resolved on lots of seeks) but the chance is higher * that mmap does not fail. On 64 bit Java platforms, this parameter should always be large (like - * 1 GiBytes, or even larger with Java 19), as the address space is big enough. If it is larger, - * fragmentation of address space increases, but number of file handles and mappings is lower for - * huge installations with many open indexes. + * 1 GiBytes, or even larger with recent Java versions), as the address space is big enough. If it + * is larger, fragmentation of address space increases, but number of file handles and mappings is + * lower for huge installations with many open indexes. * *

    Please note: The chunk size is always rounded down to a power of 2. * From 6f8edc2551efd87cf7501b3fcdd7cf5c4f34cb3b Mon Sep 17 00:00:00 2001 From: Michael Sokolov Date: Mon, 8 May 2023 10:12:36 -0400 Subject: [PATCH 08/35] GITHUB-12224: remove KnnGraphTester (moved to luceneutil) (#12238) --- .../lucene/util/hnsw/KnnGraphTester.java | 794 ------------------ 1 file changed, 794 deletions(-) delete mode 100644 lucene/core/src/test/org/apache/lucene/util/hnsw/KnnGraphTester.java diff --git a/lucene/core/src/test/org/apache/lucene/util/hnsw/KnnGraphTester.java b/lucene/core/src/test/org/apache/lucene/util/hnsw/KnnGraphTester.java deleted file mode 100644 index 8b625a29a163..000000000000 --- a/lucene/core/src/test/org/apache/lucene/util/hnsw/KnnGraphTester.java +++ /dev/null @@ -1,794 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.lucene.util.hnsw; - -import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; - -import java.io.IOException; -import java.io.OutputStream; -import java.lang.management.ManagementFactory; -import java.lang.management.ThreadMXBean; -import java.nio.ByteBuffer; -import java.nio.ByteOrder; -import java.nio.IntBuffer; -import java.nio.channels.FileChannel; -import java.nio.file.Files; -import java.nio.file.Path; -import java.nio.file.Paths; -import java.nio.file.attribute.FileTime; -import java.util.Arrays; -import java.util.HashSet; -import java.util.Locale; -import java.util.Objects; -import java.util.Set; -import java.util.concurrent.TimeUnit; -import org.apache.lucene.codecs.KnnVectorsFormat; -import org.apache.lucene.codecs.KnnVectorsReader; -import org.apache.lucene.codecs.lucene95.Lucene95Codec; -import org.apache.lucene.codecs.lucene95.Lucene95HnswVectorsFormat; -import org.apache.lucene.codecs.lucene95.Lucene95HnswVectorsReader; -import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat; -import org.apache.lucene.document.Document; -import org.apache.lucene.document.FieldType; -import org.apache.lucene.document.KnnByteVectorField; -import org.apache.lucene.document.KnnFloatVectorField; -import org.apache.lucene.document.StoredField; -import org.apache.lucene.index.CodecReader; -import org.apache.lucene.index.DirectoryReader; -import org.apache.lucene.index.IndexWriter; -import org.apache.lucene.index.IndexWriterConfig; -import org.apache.lucene.index.LeafReader; -import org.apache.lucene.index.LeafReaderContext; -import org.apache.lucene.index.StoredFields; -import org.apache.lucene.index.VectorEncoding; -import org.apache.lucene.index.VectorSimilarityFunction; -import org.apache.lucene.search.ConstantScoreScorer; -import org.apache.lucene.search.ConstantScoreWeight; -import org.apache.lucene.search.IndexSearcher; -import org.apache.lucene.search.KnnFloatVectorQuery; -import org.apache.lucene.search.Query; -import org.apache.lucene.search.QueryVisitor; -import org.apache.lucene.search.ScoreDoc; -import org.apache.lucene.search.ScoreMode; -import org.apache.lucene.search.Scorer; -import org.apache.lucene.search.TopDocs; -import org.apache.lucene.search.Weight; -import org.apache.lucene.store.Directory; -import org.apache.lucene.store.FSDirectory; -import org.apache.lucene.util.BitSetIterator; -import org.apache.lucene.util.FixedBitSet; -import org.apache.lucene.util.PrintStreamInfoStream; -import org.apache.lucene.util.SuppressForbidden; - -/** - * For testing indexing and search performance of a knn-graph - * - *

    java -cp .../lib/*.jar org.apache.lucene.util.hnsw.KnnGraphTester -ndoc 1000000 -search - * .../vectors.bin - */ -public class KnnGraphTester { - - private static final String KNN_FIELD = "knn"; - private static final String ID_FIELD = "id"; - - private int numDocs; - private int dim; - private int topK; - private int numIters; - private int fanout; - private Path indexPath; - private boolean quiet; - private boolean reindex; - private boolean forceMerge; - private int reindexTimeMsec; - private int beamWidth; - private int maxConn; - private VectorSimilarityFunction similarityFunction; - private VectorEncoding vectorEncoding; - private FixedBitSet matchDocs; - private float selectivity; - private boolean prefilter; - - private KnnGraphTester() { - // set defaults - numDocs = 1000; - numIters = 1000; - dim = 256; - topK = 100; - fanout = topK; - similarityFunction = VectorSimilarityFunction.DOT_PRODUCT; - vectorEncoding = VectorEncoding.FLOAT32; - selectivity = 1f; - prefilter = false; - } - - public static void main(String... args) throws Exception { - new KnnGraphTester().run(args); - } - - private void run(String... args) throws Exception { - String operation = null; - Path docVectorsPath = null, queryPath = null, outputPath = null; - for (int iarg = 0; iarg < args.length; iarg++) { - String arg = args[iarg]; - switch (arg) { - case "-search": - case "-check": - case "-stats": - case "-dump": - if (operation != null) { - throw new IllegalArgumentException( - "Specify only one operation, not both " + arg + " and " + operation); - } - operation = arg; - if (operation.equals("-search")) { - if (iarg == args.length - 1) { - throw new IllegalArgumentException( - "Operation " + arg + " requires a following pathname"); - } - queryPath = Paths.get(args[++iarg]); - } - break; - case "-fanout": - if (iarg == args.length - 1) { - throw new IllegalArgumentException("-fanout requires a following number"); - } - fanout = Integer.parseInt(args[++iarg]); - break; - case "-beamWidthIndex": - if (iarg == args.length - 1) { - throw new IllegalArgumentException("-beamWidthIndex requires a following number"); - } - beamWidth = Integer.parseInt(args[++iarg]); - break; - case "-maxConn": - if (iarg == args.length - 1) { - throw new IllegalArgumentException("-maxConn requires a following number"); - } - maxConn = Integer.parseInt(args[++iarg]); - break; - case "-dim": - if (iarg == args.length - 1) { - throw new IllegalArgumentException("-dim requires a following number"); - } - dim = Integer.parseInt(args[++iarg]); - break; - case "-ndoc": - if (iarg == args.length - 1) { - throw new IllegalArgumentException("-ndoc requires a following number"); - } - numDocs = Integer.parseInt(args[++iarg]); - break; - case "-niter": - if (iarg == args.length - 1) { - throw new IllegalArgumentException("-niter requires a following number"); - } - numIters = Integer.parseInt(args[++iarg]); - break; - case "-reindex": - reindex = true; - break; - case "-topK": - if (iarg == args.length - 1) { - throw new IllegalArgumentException("-topK requires a following number"); - } - topK = Integer.parseInt(args[++iarg]); - break; - case "-out": - outputPath = Paths.get(args[++iarg]); - break; - case "-docs": - docVectorsPath = Paths.get(args[++iarg]); - break; - case "-encoding": - String encoding = args[++iarg]; - switch (encoding) { - case "byte": - vectorEncoding = VectorEncoding.BYTE; - break; - case "float32": - vectorEncoding = VectorEncoding.FLOAT32; - break; - default: - throw new IllegalArgumentException("-encoding can be 'byte' or 'float32' only"); - } - break; - case "-metric": - String metric = args[++iarg]; - switch (metric) { - case "euclidean": - similarityFunction = VectorSimilarityFunction.EUCLIDEAN; - break; - case "angular": - similarityFunction = VectorSimilarityFunction.DOT_PRODUCT; - break; - default: - throw new IllegalArgumentException("-metric can be 'angular' or 'euclidean' only"); - } - break; - case "-forceMerge": - forceMerge = true; - break; - case "-prefilter": - prefilter = true; - break; - case "-filterSelectivity": - if (iarg == args.length - 1) { - throw new IllegalArgumentException("-filterSelectivity requires a following float"); - } - selectivity = Float.parseFloat(args[++iarg]); - if (selectivity <= 0 || selectivity >= 1) { - throw new IllegalArgumentException("-filterSelectivity must be between 0 and 1"); - } - break; - case "-quiet": - quiet = true; - break; - default: - throw new IllegalArgumentException("unknown argument " + arg); - // usage(); - } - } - if (operation == null && reindex == false) { - usage(); - } - if (prefilter && selectivity == 1f) { - throw new IllegalArgumentException("-prefilter requires filterSelectivity between 0 and 1"); - } - indexPath = Paths.get(formatIndexPath(docVectorsPath)); - if (reindex) { - if (docVectorsPath == null) { - throw new IllegalArgumentException("-docs argument is required when indexing"); - } - reindexTimeMsec = createIndex(docVectorsPath, indexPath); - if (forceMerge) { - forceMerge(); - } - } - if (operation != null) { - switch (operation) { - case "-search": - if (docVectorsPath == null) { - throw new IllegalArgumentException("missing -docs arg"); - } - if (selectivity < 1) { - matchDocs = generateRandomBitSet(numDocs, selectivity); - } - if (outputPath != null) { - testSearch(indexPath, queryPath, outputPath, null); - } else { - testSearch(indexPath, queryPath, null, getNN(docVectorsPath, queryPath)); - } - break; - case "-stats": - printFanoutHist(indexPath); - break; - } - } - } - - private String formatIndexPath(Path docsPath) { - return docsPath.getFileName() + "-" + maxConn + "-" + beamWidth + ".index"; - } - - @SuppressForbidden(reason = "Prints stuff") - private void printFanoutHist(Path indexPath) throws IOException { - try (Directory dir = FSDirectory.open(indexPath); - DirectoryReader reader = DirectoryReader.open(dir)) { - for (LeafReaderContext context : reader.leaves()) { - LeafReader leafReader = context.reader(); - KnnVectorsReader vectorsReader = - ((PerFieldKnnVectorsFormat.FieldsReader) ((CodecReader) leafReader).getVectorReader()) - .getFieldReader(KNN_FIELD); - HnswGraph knnValues = ((Lucene95HnswVectorsReader) vectorsReader).getGraph(KNN_FIELD); - System.out.printf("Leaf %d has %d documents\n", context.ord, leafReader.maxDoc()); - printGraphFanout(knnValues, leafReader.maxDoc()); - } - } - } - - @SuppressForbidden(reason = "Prints stuff") - private void forceMerge() throws IOException { - IndexWriterConfig iwc = new IndexWriterConfig().setOpenMode(IndexWriterConfig.OpenMode.APPEND); - iwc.setInfoStream(new PrintStreamInfoStream(System.out)); - System.out.println("Force merge index in " + indexPath); - try (IndexWriter iw = new IndexWriter(FSDirectory.open(indexPath), iwc)) { - iw.forceMerge(1); - } - } - - @SuppressForbidden(reason = "Prints stuff") - private void printGraphFanout(HnswGraph knnValues, int numDocs) throws IOException { - int min = Integer.MAX_VALUE, max = 0, total = 0; - int count = 0; - int[] leafHist = new int[numDocs]; - for (int node = 0; node < numDocs; node++) { - knnValues.seek(0, node); - int n = 0; - while (knnValues.nextNeighbor() != NO_MORE_DOCS) { - ++n; - } - ++leafHist[n]; - max = Math.max(max, n); - min = Math.min(min, n); - if (n > 0) { - ++count; - total += n; - } - } - System.out.printf( - "Graph size=%d, Fanout min=%d, mean=%.2f, max=%d\n", - count, min, total / (float) count, max); - printHist(leafHist, max, count, 10); - } - - @SuppressForbidden(reason = "Prints stuff") - private void printHist(int[] hist, int max, int count, int nbuckets) { - System.out.print("%"); - for (int i = 0; i <= nbuckets; i++) { - System.out.printf("%4d", i * 100 / nbuckets); - } - System.out.printf("\n %4d", hist[0]); - int total = 0, ibucket = 1; - for (int i = 1; i <= max && ibucket <= nbuckets; i++) { - total += hist[i]; - while (total >= count * ibucket / nbuckets) { - System.out.printf("%4d", i); - ++ibucket; - } - } - System.out.println(); - } - - @SuppressForbidden(reason = "Prints stuff") - private void testSearch(Path indexPath, Path queryPath, Path outputPath, int[][] nn) - throws IOException { - TopDocs[] results = new TopDocs[numIters]; - long elapsed, totalCpuTime, totalVisited = 0; - try (FileChannel input = FileChannel.open(queryPath)) { - VectorReader targetReader = VectorReader.create(input, dim, vectorEncoding); - if (quiet == false) { - System.out.println("running " + numIters + " targets; topK=" + topK + ", fanout=" + fanout); - } - long start; - ThreadMXBean bean = ManagementFactory.getThreadMXBean(); - long cpuTimeStartNs; - try (Directory dir = FSDirectory.open(indexPath); - DirectoryReader reader = DirectoryReader.open(dir)) { - IndexSearcher searcher = new IndexSearcher(reader); - numDocs = reader.maxDoc(); - Query bitSetQuery = prefilter ? new BitSetQuery(matchDocs) : null; - for (int i = 0; i < numIters; i++) { - // warm up - float[] target = targetReader.next(); - if (prefilter) { - doKnnVectorQuery(searcher, KNN_FIELD, target, topK, fanout, bitSetQuery); - } else { - doKnnVectorQuery(searcher, KNN_FIELD, target, (int) (topK / selectivity), fanout, null); - } - } - targetReader.reset(); - start = System.nanoTime(); - cpuTimeStartNs = bean.getCurrentThreadCpuTime(); - for (int i = 0; i < numIters; i++) { - float[] target = targetReader.next(); - if (prefilter) { - results[i] = doKnnVectorQuery(searcher, KNN_FIELD, target, topK, fanout, bitSetQuery); - } else { - results[i] = - doKnnVectorQuery( - searcher, KNN_FIELD, target, (int) (topK / selectivity), fanout, null); - - if (matchDocs != null) { - results[i].scoreDocs = - Arrays.stream(results[i].scoreDocs) - .filter(scoreDoc -> matchDocs.get(scoreDoc.doc)) - .toArray(ScoreDoc[]::new); - } - } - } - totalCpuTime = - TimeUnit.NANOSECONDS.toMillis(bean.getCurrentThreadCpuTime() - cpuTimeStartNs); - elapsed = TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - start); // ns -> ms - StoredFields storedFields = reader.storedFields(); - for (int i = 0; i < numIters; i++) { - totalVisited += results[i].totalHits.value; - for (ScoreDoc doc : results[i].scoreDocs) { - if (doc.doc != NO_MORE_DOCS) { - // there is a bug somewhere that can result in doc=NO_MORE_DOCS! I think it happens - // in some degenerate case (like input query has NaN in it?) that causes no results to - // be returned from HNSW search? - doc.doc = Integer.parseInt(storedFields.document(doc.doc).get("id")); - } else { - System.out.println("NO_MORE_DOCS!"); - } - } - } - } - if (quiet == false) { - System.out.println( - "completed " - + numIters - + " searches in " - + elapsed - + " ms: " - + ((1000 * numIters) / elapsed) - + " QPS " - + "CPU time=" - + totalCpuTime - + "ms"); - } - } - if (outputPath != null) { - ByteBuffer buf = ByteBuffer.allocate(4); - IntBuffer ibuf = buf.order(ByteOrder.LITTLE_ENDIAN).asIntBuffer(); - try (OutputStream out = Files.newOutputStream(outputPath)) { - for (int i = 0; i < numIters; i++) { - for (ScoreDoc doc : results[i].scoreDocs) { - ibuf.position(0); - ibuf.put(doc.doc); - out.write(buf.array()); - } - } - } - } else { - if (quiet == false) { - System.out.println("checking results"); - } - float recall = checkResults(results, nn); - totalVisited /= numIters; - System.out.printf( - Locale.ROOT, - "%5.3f\t%5.2f\t%d\t%d\t%d\t%d\t%d\t%d\t%.2f\t%s\n", - recall, - totalCpuTime / (float) numIters, - numDocs, - fanout, - maxConn, - beamWidth, - totalVisited, - reindexTimeMsec, - selectivity, - prefilter ? "pre-filter" : "post-filter"); - } - } - - private abstract static class VectorReader { - final float[] target; - final ByteBuffer bytes; - final FileChannel input; - - static VectorReader create(FileChannel input, int dim, VectorEncoding vectorEncoding) { - int bufferSize = dim * vectorEncoding.byteSize; - return switch (vectorEncoding) { - case BYTE -> new VectorReaderByte(input, dim, bufferSize); - case FLOAT32 -> new VectorReaderFloat32(input, dim, bufferSize); - }; - } - - VectorReader(FileChannel input, int dim, int bufferSize) { - this.bytes = ByteBuffer.wrap(new byte[bufferSize]).order(ByteOrder.LITTLE_ENDIAN); - this.input = input; - target = new float[dim]; - } - - void reset() throws IOException { - input.position(0); - } - - protected final void readNext() throws IOException { - this.input.read(bytes); - bytes.position(0); - } - - abstract float[] next() throws IOException; - } - - private static class VectorReaderFloat32 extends VectorReader { - VectorReaderFloat32(FileChannel input, int dim, int bufferSize) { - super(input, dim, bufferSize); - } - - @Override - float[] next() throws IOException { - readNext(); - bytes.asFloatBuffer().get(target); - return target; - } - } - - private static class VectorReaderByte extends VectorReader { - private final byte[] scratch; - - VectorReaderByte(FileChannel input, int dim, int bufferSize) { - super(input, dim, bufferSize); - scratch = new byte[dim]; - } - - @Override - float[] next() throws IOException { - readNext(); - bytes.get(scratch); - for (int i = 0; i < scratch.length; i++) { - target[i] = scratch[i]; - } - return target; - } - - byte[] nextBytes() throws IOException { - readNext(); - bytes.get(scratch); - return scratch; - } - } - - private static TopDocs doKnnVectorQuery( - IndexSearcher searcher, String field, float[] vector, int k, int fanout, Query filter) - throws IOException { - return searcher.search(new KnnFloatVectorQuery(field, vector, k + fanout, filter), k); - } - - private float checkResults(TopDocs[] results, int[][] nn) { - int totalMatches = 0; - int totalResults = results.length * topK; - for (int i = 0; i < results.length; i++) { - // System.out.println(Arrays.toString(nn[i])); - // System.out.println(Arrays.toString(results[i].scoreDocs)); - totalMatches += compareNN(nn[i], results[i]); - } - return totalMatches / (float) totalResults; - } - - private int compareNN(int[] expected, TopDocs results) { - int matched = 0; - /* - System.out.print("expected="); - for (int j = 0; j < expected.length; j++) { - System.out.print(expected[j]); - System.out.print(", "); - } - System.out.print('\n'); - System.out.println("results="); - for (int j = 0; j < results.scoreDocs.length; j++) { - System.out.print("" + results.scoreDocs[j].doc + ":" + results.scoreDocs[j].score + ", "); - } - System.out.print('\n'); - */ - Set expectedSet = new HashSet<>(); - for (int i = 0; i < topK; i++) { - expectedSet.add(expected[i]); - } - for (ScoreDoc scoreDoc : results.scoreDocs) { - if (expectedSet.contains(scoreDoc.doc)) { - ++matched; - } - } - return matched; - } - - private int[][] getNN(Path docPath, Path queryPath) throws IOException { - // look in working directory for cached nn file - String hash = Integer.toString(Objects.hash(docPath, queryPath, numDocs, numIters, topK), 36); - String nnFileName = "nn-" + hash + ".bin"; - Path nnPath = Paths.get(nnFileName); - if (Files.exists(nnPath) && isNewer(nnPath, docPath, queryPath) && selectivity == 1f) { - return readNN(nnPath); - } else { - // TODO: enable computing NN from high precision vectors when - // checking low-precision recall - int[][] nn = computeNN(docPath, queryPath, vectorEncoding); - if (selectivity == 1f) { - writeNN(nn, nnPath); - } - return nn; - } - } - - private boolean isNewer(Path path, Path... others) throws IOException { - FileTime modified = Files.getLastModifiedTime(path); - for (Path other : others) { - if (Files.getLastModifiedTime(other).compareTo(modified) >= 0) { - return false; - } - } - return true; - } - - private int[][] readNN(Path nnPath) throws IOException { - int[][] result = new int[numIters][]; - try (FileChannel in = FileChannel.open(nnPath)) { - IntBuffer intBuffer = - in.map(FileChannel.MapMode.READ_ONLY, 0, numIters * topK * Integer.BYTES) - .order(ByteOrder.LITTLE_ENDIAN) - .asIntBuffer(); - for (int i = 0; i < numIters; i++) { - result[i] = new int[topK]; - intBuffer.get(result[i]); - } - } - return result; - } - - private void writeNN(int[][] nn, Path nnPath) throws IOException { - if (quiet == false) { - System.out.println("writing true nearest neighbors to " + nnPath); - } - ByteBuffer tmp = - ByteBuffer.allocate(nn[0].length * Integer.BYTES).order(ByteOrder.LITTLE_ENDIAN); - try (OutputStream out = Files.newOutputStream(nnPath)) { - for (int i = 0; i < numIters; i++) { - tmp.asIntBuffer().put(nn[i]); - out.write(tmp.array()); - } - } - } - - @SuppressForbidden(reason = "Uses random()") - private static FixedBitSet generateRandomBitSet(int size, float selectivity) { - FixedBitSet bitSet = new FixedBitSet(size); - for (int i = 0; i < size; i++) { - if (Math.random() < selectivity) { - bitSet.set(i); - } else { - bitSet.clear(i); - } - } - return bitSet; - } - - private int[][] computeNN(Path docPath, Path queryPath, VectorEncoding encoding) - throws IOException { - int[][] result = new int[numIters][]; - if (quiet == false) { - System.out.println("computing true nearest neighbors of " + numIters + " target vectors"); - } - try (FileChannel in = FileChannel.open(docPath); - FileChannel qIn = FileChannel.open(queryPath)) { - VectorReader docReader = VectorReader.create(in, dim, encoding); - VectorReader queryReader = VectorReader.create(qIn, dim, encoding); - for (int i = 0; i < numIters; i++) { - float[] query = queryReader.next(); - NeighborQueue queue = new NeighborQueue(topK, false); - for (int j = 0; j < numDocs; j++) { - float[] doc = docReader.next(); - float d = similarityFunction.compare(query, doc); - if (matchDocs == null || matchDocs.get(j)) { - queue.insertWithOverflow(j, d); - } - } - docReader.reset(); - result[i] = new int[topK]; - for (int k = topK - 1; k >= 0; k--) { - result[i][k] = queue.topNode(); - queue.pop(); - // System.out.print(" " + n); - } - if (quiet == false && (i + 1) % 10 == 0) { - System.out.print(" " + (i + 1)); - System.out.flush(); - } - } - } - return result; - } - - private int createIndex(Path docsPath, Path indexPath) throws IOException { - IndexWriterConfig iwc = new IndexWriterConfig().setOpenMode(IndexWriterConfig.OpenMode.CREATE); - iwc.setCodec( - new Lucene95Codec() { - @Override - public KnnVectorsFormat getKnnVectorsFormatForField(String field) { - return new Lucene95HnswVectorsFormat(maxConn, beamWidth); - } - }); - // iwc.setMergePolicy(NoMergePolicy.INSTANCE); - iwc.setRAMBufferSizeMB(1994d); - iwc.setUseCompoundFile(false); - // iwc.setMaxBufferedDocs(10000); - - FieldType fieldType = - switch (vectorEncoding) { - case BYTE -> KnnByteVectorField.createFieldType(dim, similarityFunction); - case FLOAT32 -> KnnFloatVectorField.createFieldType(dim, similarityFunction); - }; - if (quiet == false) { - iwc.setInfoStream(new PrintStreamInfoStream(System.out)); - System.out.println("creating index in " + indexPath); - } - long start = System.nanoTime(); - try (FSDirectory dir = FSDirectory.open(indexPath); - IndexWriter iw = new IndexWriter(dir, iwc)) { - try (FileChannel in = FileChannel.open(docsPath)) { - VectorReader vectorReader = VectorReader.create(in, dim, vectorEncoding); - for (int i = 0; i < numDocs; i++) { - Document doc = new Document(); - switch (vectorEncoding) { - case BYTE -> doc.add( - new KnnByteVectorField( - KNN_FIELD, ((VectorReaderByte) vectorReader).nextBytes(), fieldType)); - case FLOAT32 -> doc.add( - new KnnFloatVectorField(KNN_FIELD, vectorReader.next(), fieldType)); - } - doc.add(new StoredField(ID_FIELD, i)); - iw.addDocument(doc); - } - if (quiet == false) { - System.out.println("Done indexing " + numDocs + " documents; now flush"); - } - } - } - long elapsed = System.nanoTime() - start; - if (quiet == false) { - System.out.println( - "Indexed " + numDocs + " documents in " + TimeUnit.NANOSECONDS.toSeconds(elapsed) + "s"); - } - return (int) TimeUnit.NANOSECONDS.toMillis(elapsed); - } - - private static void usage() { - String error = - "Usage: TestKnnGraph [-reindex] [-search {queryfile}|-stats|-check] [-docs {datafile}] [-niter N] [-fanout N] [-maxConn N] [-beamWidth N] [-filterSelectivity N] [-prefilter]"; - System.err.println(error); - System.exit(1); - } - - private static class BitSetQuery extends Query { - private final FixedBitSet docs; - private final int cardinality; - - BitSetQuery(FixedBitSet docs) { - this.docs = docs; - this.cardinality = docs.cardinality(); - } - - @Override - public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) - throws IOException { - return new ConstantScoreWeight(this, boost) { - @Override - public Scorer scorer(LeafReaderContext context) throws IOException { - return new ConstantScoreScorer( - this, score(), scoreMode, new BitSetIterator(docs, cardinality)); - } - - @Override - public boolean isCacheable(LeafReaderContext ctx) { - return false; - } - }; - } - - @Override - public void visit(QueryVisitor visitor) {} - - @Override - public String toString(String field) { - return "BitSetQuery"; - } - - @Override - public boolean equals(Object other) { - return sameClassAs(other) && docs.equals(((BitSetQuery) other).docs); - } - - @Override - public int hashCode() { - return 31 * classHash() + docs.hashCode(); - } - } -} From ab643b7dbc4b9667a1af41659ea6120dd5aeef6e Mon Sep 17 00:00:00 2001 From: Jonathan Ellis Date: Mon, 8 May 2023 16:22:58 -0500 Subject: [PATCH 09/35] allocate one NeighborQueue per search for results (#12255) --- .../lucene/util/hnsw/HnswGraphSearcher.java | 35 ++++++++++++------- 1 file changed, 23 insertions(+), 12 deletions(-) diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java index 4857d5b9d577..84d40c841bf8 100644 --- a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java +++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java @@ -100,7 +100,7 @@ public static NeighborQueue search( similarityFunction, new NeighborQueue(topK, true), new SparseFixedBitSet(vectors.size())); - NeighborQueue results; + NeighborQueue results = new NeighborQueue(topK, false); int initialEp = graph.entryNode(); if (initialEp == -1) { @@ -109,7 +109,8 @@ public static NeighborQueue search( int[] eps = new int[] {initialEp}; int numVisited = 0; for (int level = graph.numLevels() - 1; level >= 1; level--) { - results = graphSearcher.searchLevel(query, 1, level, eps, vectors, graph, null, visitedLimit); + results.clear(); + graphSearcher.searchLevel(results, query, 1, level, eps, vectors, graph, null, visitedLimit); numVisited += results.visitedCount(); visitedLimit -= results.visitedCount(); if (results.incomplete()) { @@ -118,8 +119,9 @@ public static NeighborQueue search( } eps[0] = results.pop(); } - results = - graphSearcher.searchLevel(query, topK, 0, eps, vectors, graph, acceptOrds, visitedLimit); + results.clear(); + graphSearcher.searchLevel( + results, query, topK, 0, eps, vectors, graph, acceptOrds, visitedLimit); results.setVisitedCount(results.visitedCount() + numVisited); return results; } @@ -161,11 +163,12 @@ public static NeighborQueue search( similarityFunction, new NeighborQueue(topK, true), new SparseFixedBitSet(vectors.size())); - NeighborQueue results; + NeighborQueue results = new NeighborQueue(topK, false); int[] eps = new int[] {graph.entryNode()}; int numVisited = 0; for (int level = graph.numLevels() - 1; level >= 1; level--) { - results = graphSearcher.searchLevel(query, 1, level, eps, vectors, graph, null, visitedLimit); + results.clear(); + graphSearcher.searchLevel(results, query, 1, level, eps, vectors, graph, null, visitedLimit); numVisited += results.visitedCount(); visitedLimit -= results.visitedCount(); @@ -176,8 +179,9 @@ public static NeighborQueue search( } eps[0] = results.pop(); } - results = - graphSearcher.searchLevel(query, topK, 0, eps, vectors, graph, acceptOrds, visitedLimit); + results.clear(); + graphSearcher.searchLevel( + results, query, topK, 0, eps, vectors, graph, acceptOrds, visitedLimit); results.setVisitedCount(results.visitedCount() + numVisited); return results; } @@ -205,10 +209,19 @@ public NeighborQueue searchLevel( RandomAccessVectorValues vectors, HnswGraph graph) throws IOException { - return searchLevel(query, topK, level, eps, vectors, graph, null, Integer.MAX_VALUE); + NeighborQueue results = new NeighborQueue(topK, false); + searchLevel(results, query, topK, level, eps, vectors, graph, null, Integer.MAX_VALUE); + return results; } - private NeighborQueue searchLevel( + /** + * Add the closest neighbors found to a priority queue (heap). These are returned in REVERSE + * proximity order -- the most distant neighbor of the topK found, i.e. the one with the lowest + * score/comparison value, will be at the top of the heap, while the closest neighbor will be the + * last to be popped. + */ + private void searchLevel( + NeighborQueue results, T query, int topK, int level, @@ -219,7 +232,6 @@ private NeighborQueue searchLevel( int visitedLimit) throws IOException { int size = graph.size(); - NeighborQueue results = new NeighborQueue(topK, false); prepareScratchState(vectors.size()); int numVisited = 0; @@ -280,7 +292,6 @@ private NeighborQueue searchLevel( results.pop(); } results.setVisitedCount(numVisited); - return results; } private float compare(T query, RandomAccessVectorValues vectors, int ord) throws IOException { From 815e840c5a39e7a6e722d56b98fe419d8ae7dcfb Mon Sep 17 00:00:00 2001 From: Armin Braun Date: Tue, 9 May 2023 10:18:52 +0200 Subject: [PATCH 10/35] Don't generate stacktrace in CollectionTerminatedException (#12270) CollectionTerminatedException is always caught and never exposed to users so there's no point in filling in a stack-trace for it. --- lucene/CHANGES.txt | 2 ++ .../apache/lucene/search/CollectionTerminatedException.java | 6 ++++++ 2 files changed, 8 insertions(+) diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt index db018fb21446..7179685d0306 100644 --- a/lucene/CHANGES.txt +++ b/lucene/CHANGES.txt @@ -72,6 +72,8 @@ Optimizations * GITHUB#12160: Concurrent rewrite for AbstractKnnVectorQuery. (Kaival Parikh) +* GITHUB#12270 Don't generate stacktrace in CollectionTerminatedException. (Armin Braun) + Bug Fixes --------------------- diff --git a/lucene/core/src/java/org/apache/lucene/search/CollectionTerminatedException.java b/lucene/core/src/java/org/apache/lucene/search/CollectionTerminatedException.java index 2a7e04447481..89f14fff20bc 100644 --- a/lucene/core/src/java/org/apache/lucene/search/CollectionTerminatedException.java +++ b/lucene/core/src/java/org/apache/lucene/search/CollectionTerminatedException.java @@ -31,4 +31,10 @@ public final class CollectionTerminatedException extends RuntimeException { public CollectionTerminatedException() { super(); } + + @Override + public Throwable fillInStackTrace() { + // never re-thrown so we can save the expensive stacktrace + return this; + } } From 2937c8da97e3dc50bbfb1416d1644d8bed2d5eca Mon Sep 17 00:00:00 2001 From: Luca Cavanna Date: Tue, 9 May 2023 10:28:22 +0200 Subject: [PATCH 11/35] Move changes entry for #12270 to 9.7.0 section --- lucene/CHANGES.txt | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt index 7179685d0306..198cd5be330a 100644 --- a/lucene/CHANGES.txt +++ b/lucene/CHANGES.txt @@ -72,8 +72,6 @@ Optimizations * GITHUB#12160: Concurrent rewrite for AbstractKnnVectorQuery. (Kaival Parikh) -* GITHUB#12270 Don't generate stacktrace in CollectionTerminatedException. (Armin Braun) - Bug Fixes --------------------- @@ -124,7 +122,8 @@ Improvements Optimizations --------------------- -(No changes) + +* GITHUB#12270 Don't generate stacktrace in CollectionTerminatedException. (Armin Braun) Bug Fixes --------------------- From 8befea59b0a5d4ad2eaad729ce424b0028ded4f7 Mon Sep 17 00:00:00 2001 From: Luca Cavanna Date: Tue, 9 May 2023 10:52:03 +0200 Subject: [PATCH 12/35] add missing changelog entry for #12260 --- lucene/CHANGES.txt | 2 ++ 1 file changed, 2 insertions(+) diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt index 198cd5be330a..79281e9ccf89 100644 --- a/lucene/CHANGES.txt +++ b/lucene/CHANGES.txt @@ -220,6 +220,8 @@ Bug Fixes * GITHUB#12212: Bug fix for a DrillSideways issue where matching hits could occasionally be missed. (Frederic Thevenet) +* GITHUB#12260: Fix SynonymQuery equals implementation to take the targeted field name into account (Luca Cavanna) + Build --------------------- From 3f93aa6f85a624a64c66e56c5bf261d6a92ccf7a Mon Sep 17 00:00:00 2001 From: Luca Cavanna Date: Tue, 9 May 2023 10:57:28 +0200 Subject: [PATCH 13/35] add missing changelog entry for #12220 --- lucene/CHANGES.txt | 2 ++ 1 file changed, 2 insertions(+) diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt index 79281e9ccf89..0c8c56c28afb 100644 --- a/lucene/CHANGES.txt +++ b/lucene/CHANGES.txt @@ -220,6 +220,8 @@ Bug Fixes * GITHUB#12212: Bug fix for a DrillSideways issue where matching hits could occasionally be missed. (Frederic Thevenet) +* GITHUB#12220: Hunspell: disallow hidden title-case entries from compound middle/end (Peter Gromov) + * GITHUB#12260: Fix SynonymQuery equals implementation to take the targeted field name into account (Luca Cavanna) Build From 3a6fa03f018bd2a4dbd9f56e510b7d13985ddfa7 Mon Sep 17 00:00:00 2001 From: Luca Cavanna Date: Tue, 9 May 2023 11:27:06 +0200 Subject: [PATCH 14/35] Make query timeout members final in ExitableDirectoryReader (#12274) There's a couple of places in the Exitable wrapper classes where queryTimeout is set within the constructor and never modified. This commit makes such members final. --- .../org/apache/lucene/index/ExitableDirectoryReader.java | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/lucene/core/src/java/org/apache/lucene/index/ExitableDirectoryReader.java b/lucene/core/src/java/org/apache/lucene/index/ExitableDirectoryReader.java index b180d3bdda3d..a572b2258af5 100644 --- a/lucene/core/src/java/org/apache/lucene/index/ExitableDirectoryReader.java +++ b/lucene/core/src/java/org/apache/lucene/index/ExitableDirectoryReader.java @@ -36,7 +36,7 @@ */ public class ExitableDirectoryReader extends FilterDirectoryReader { - private QueryTimeout queryTimeout; + private final QueryTimeout queryTimeout; /** Exception that is thrown to prematurely terminate a term enumeration. */ @SuppressWarnings("serial") @@ -50,7 +50,7 @@ public ExitingReaderException(String msg) { /** Wrapper class for a SubReaderWrapper that is used by the ExitableDirectoryReader. */ public static class ExitableSubReaderWrapper extends SubReaderWrapper { - private QueryTimeout queryTimeout; + private final QueryTimeout queryTimeout; /** Constructor * */ public ExitableSubReaderWrapper(QueryTimeout queryTimeout) { @@ -810,7 +810,7 @@ public static class ExitableTermsEnum extends FilterTermsEnum { // Create bit mask in the form of 0000 1111 for efficient checking private static final int NUM_CALLS_PER_TIMEOUT_CHECK = (1 << 4) - 1; // 15 private int calls; - private QueryTimeout queryTimeout; + private final QueryTimeout queryTimeout; /** Constructor * */ public ExitableTermsEnum(TermsEnum termsEnum, QueryTimeout queryTimeout) { From 295113836ed6d9c1bec50dfa2a4e5b07c1629507 Mon Sep 17 00:00:00 2001 From: Luca Cavanna Date: Tue, 9 May 2023 11:27:47 +0200 Subject: [PATCH 15/35] Update javadocs for QueryTimeout (#12272) QueryTimeout was introduced together with ExitableDirectoryReader but is now also optionally set to the IndexSearcher to wrap the bulk scorer with a TimeLimitingBulkScorer. Its javadocs needs updating. --- .../java/org/apache/lucene/index/QueryTimeout.java | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/lucene/core/src/java/org/apache/lucene/index/QueryTimeout.java b/lucene/core/src/java/org/apache/lucene/index/QueryTimeout.java index 0c64f4c2c9ac..f1e543423670 100644 --- a/lucene/core/src/java/org/apache/lucene/index/QueryTimeout.java +++ b/lucene/core/src/java/org/apache/lucene/index/QueryTimeout.java @@ -17,14 +17,17 @@ package org.apache.lucene.index; /** - * Base for query timeout implementations, which will provide a {@code shouldExit()} method, used - * with {@link ExitableDirectoryReader}. + * Query timeout abstraction that controls whether a query should continue or be stopped. Can be set + * to the searcher through {@link org.apache.lucene.search.IndexSearcher#setTimeout(QueryTimeout)}, + * in which case bulk scoring will be time-bound. Can also be used in combination with {@link + * ExitableDirectoryReader}. */ public interface QueryTimeout { /** - * Called from {@link ExitableDirectoryReader.ExitableTermsEnum#next()} to determine whether to - * stop processing a query. + * Called to determine whether to stop processing a query + * + * @return true if the query should stop, false otherwise */ - public abstract boolean shouldExit(); + boolean shouldExit(); } From 5f9dfd49041cb3f631548273ed22d3ed78a567f0 Mon Sep 17 00:00:00 2001 From: Luca Cavanna Date: Tue, 9 May 2023 11:28:23 +0200 Subject: [PATCH 16/35] Make TimeExceededException members final (#12271) TimeExceededException has three members that are set within its constructor and never modified. They can be made final. --- .../org/apache/lucene/search/TimeLimitingCollector.java | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/lucene/core/src/java/org/apache/lucene/search/TimeLimitingCollector.java b/lucene/core/src/java/org/apache/lucene/search/TimeLimitingCollector.java index 4a208a3f0c5f..c50f3a372e97 100644 --- a/lucene/core/src/java/org/apache/lucene/search/TimeLimitingCollector.java +++ b/lucene/core/src/java/org/apache/lucene/search/TimeLimitingCollector.java @@ -33,9 +33,9 @@ public class TimeLimitingCollector implements Collector { /** Thrown when elapsed search time exceeds allowed search time. */ @SuppressWarnings("serial") public static class TimeExceededException extends RuntimeException { - private long timeAllowed; - private long timeElapsed; - private int lastDocCollected; + private final long timeAllowed; + private final long timeElapsed; + private final int lastDocCollected; private TimeExceededException(long timeAllowed, long timeElapsed, int lastDocCollected) { super( From 88856f77fe820b6194f91c63ef08b7d7ae31e112 Mon Sep 17 00:00:00 2001 From: Alan Woodward Date: Wed, 10 May 2023 08:36:53 +0100 Subject: [PATCH 17/35] DOAP changes for release 9.6.0 --- dev-tools/doap/lucene.rdf | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/dev-tools/doap/lucene.rdf b/dev-tools/doap/lucene.rdf index 48c5743bce40..1b8447ab8615 100644 --- a/dev-tools/doap/lucene.rdf +++ b/dev-tools/doap/lucene.rdf @@ -67,6 +67,13 @@ + + + lucene-9.6.0 + 2023-05-09 + 9.6.0 + + lucene-9.5.0 @@ -74,7 +81,6 @@ 9.5.0 - lucene-9.4.2 From ec5f3d4276a6d656a892c1a4f99e71b36060e4bc Mon Sep 17 00:00:00 2001 From: Alessandro Benedetti Date: Fri, 12 May 2023 13:09:00 +0100 Subject: [PATCH 18/35] reasoning about thread safety --- lucene/queries/src/java/module-info.java | 1 + .../ByteVectorFieldSource.java} | 12 +-- .../ByteVectorSimilarityFunction.java} | 12 +-- .../ByteVectorValueSource.java} | 12 +-- .../FloatVectorFieldSource.java} | 12 +-- .../FloatVectorSimilarityFunction.java} | 12 +-- .../FloatVectorValueSource.java} | 10 +-- .../VectorFieldFunction.java} | 6 +- .../VectorSimilarityFunction.java} | 22 ++--- .../TestKnnVectorSimilarityFunctions.java | 80 +++++++++---------- 10 files changed, 90 insertions(+), 89 deletions(-) rename lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/{DenseVectorByteFieldSource.java => densevectors/ByteVectorFieldSource.java} (86%) rename lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/{ByteDenseVectorSimilarityFunction.java => densevectors/ByteVectorSimilarityFunction.java} (74%) rename lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/{DenseVectorByteConstValueSource.java => densevectors/ByteVectorValueSource.java} (84%) rename lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/{DenseVectorFloatFieldSource.java => densevectors/FloatVectorFieldSource.java} (86%) rename lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/{FloatDenseVectorSimilarityFunction.java => densevectors/FloatVectorSimilarityFunction.java} (74%) rename lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/{DenseVectorFloatConstValueSource.java => densevectors/FloatVectorValueSource.java} (85%) rename lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/{DenseVectorFieldFunction.java => densevectors/VectorFieldFunction.java} (89%) rename lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/{DenseVectorSimilarityFunction.java => densevectors/VectorSimilarityFunction.java} (77%) diff --git a/lucene/queries/src/java/module-info.java b/lucene/queries/src/java/module-info.java index ae3ff4d0712f..b1923557bcba 100644 --- a/lucene/queries/src/java/module-info.java +++ b/lucene/queries/src/java/module-info.java @@ -27,4 +27,5 @@ exports org.apache.lucene.queries.mlt; exports org.apache.lucene.queries.payloads; exports org.apache.lucene.queries.spans; + exports org.apache.lucene.queries.function.valuesource.densevectors; } diff --git a/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/DenseVectorByteFieldSource.java b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/densevectors/ByteVectorFieldSource.java similarity index 86% rename from lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/DenseVectorByteFieldSource.java rename to lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/densevectors/ByteVectorFieldSource.java index 4367261a5a9e..535baa76d4fe 100644 --- a/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/DenseVectorByteFieldSource.java +++ b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/densevectors/ByteVectorFieldSource.java @@ -14,7 +14,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.lucene.queries.function.valuesource; +package org.apache.lucene.queries.function.valuesource.densevectors; import java.io.IOException; import java.util.Arrays; @@ -25,10 +25,10 @@ import org.apache.lucene.queries.function.ValueSource; import org.apache.lucene.search.DocIdSetIterator; -public class DenseVectorByteFieldSource extends ValueSource { +public class ByteVectorFieldSource extends ValueSource { private final String fieldName; - public DenseVectorByteFieldSource(String fieldName) { + public ByteVectorFieldSource(String fieldName) { this.fieldName = fieldName; } @@ -37,7 +37,7 @@ public FunctionValues getValues(Map context, LeafReaderContext r throws IOException { final ByteVectorValues vectorValues = readerContext.reader().getByteVectorValues(fieldName); - return new DenseVectorFieldFunction(this) { + return new VectorFieldFunction(this) { byte[] defaultVector = null; @Override @@ -66,8 +66,8 @@ private byte[] defaultVector() { @Override public boolean equals(Object o) { - if (o.getClass() != DenseVectorByteFieldSource.class) return false; - DenseVectorByteFieldSource other = (DenseVectorByteFieldSource) o; + if (o.getClass() != ByteVectorFieldSource.class) return false; + ByteVectorFieldSource other = (ByteVectorFieldSource) o; return fieldName.equals(other.fieldName); } diff --git a/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/ByteDenseVectorSimilarityFunction.java b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/densevectors/ByteVectorSimilarityFunction.java similarity index 74% rename from lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/ByteDenseVectorSimilarityFunction.java rename to lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/densevectors/ByteVectorSimilarityFunction.java index a58b17b6e59a..0414e35ec622 100644 --- a/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/ByteDenseVectorSimilarityFunction.java +++ b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/densevectors/ByteVectorSimilarityFunction.java @@ -14,22 +14,22 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.lucene.queries.function.valuesource; +package org.apache.lucene.queries.function.valuesource.densevectors; import java.io.IOException; -import org.apache.lucene.index.VectorSimilarityFunction; + import org.apache.lucene.queries.function.FunctionValues; import org.apache.lucene.queries.function.ValueSource; -public class ByteDenseVectorSimilarityFunction extends DenseVectorSimilarityFunction { - public ByteDenseVectorSimilarityFunction( - VectorSimilarityFunction similarityFunction, ValueSource vector1, ValueSource vector2) { +public class ByteVectorSimilarityFunction extends VectorSimilarityFunction { + public ByteVectorSimilarityFunction( + org.apache.lucene.index.VectorSimilarityFunction similarityFunction, ValueSource vector1, ValueSource vector2) { super(similarityFunction, vector1, vector2); } @Override protected float func(int doc, FunctionValues f1, FunctionValues f2) throws IOException { - checkSize(f1.byteVectorVal(doc).length, f2.byteVectorVal(doc).length); + assertSameSize(f1.byteVectorVal(doc).length, f2.byteVectorVal(doc).length); return similarityFunction.compare(f1.byteVectorVal(doc), f2.byteVectorVal(doc)); } } diff --git a/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/DenseVectorByteConstValueSource.java b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/densevectors/ByteVectorValueSource.java similarity index 84% rename from lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/DenseVectorByteConstValueSource.java rename to lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/densevectors/ByteVectorValueSource.java index 057813882aa1..19e532a1647a 100644 --- a/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/DenseVectorByteConstValueSource.java +++ b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/densevectors/ByteVectorValueSource.java @@ -14,7 +14,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.lucene.queries.function.valuesource; +package org.apache.lucene.queries.function.valuesource.densevectors; import java.io.IOException; import java.util.Arrays; @@ -23,11 +23,11 @@ import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.queries.function.FunctionValues; import org.apache.lucene.queries.function.ValueSource; - -public class DenseVectorByteConstValueSource extends ValueSource { +/** Function that returns a constant byte vector value for every document. */ +public class ByteVectorValueSource extends ValueSource { byte[] vector; - public DenseVectorByteConstValueSource(List constVector) { + public ByteVectorValueSource(List constVector) { this.vector = new byte[constVector.size()]; for (int i = 0; i < constVector.size(); i++) { vector[i] = constVector.get(i).byteValue(); @@ -57,8 +57,8 @@ public String toString(int doc) throws IOException { @Override public boolean equals(Object o) { - if (!(o instanceof DenseVectorByteConstValueSource)) return false; - DenseVectorByteConstValueSource other = (DenseVectorByteConstValueSource) o; + if (!(o instanceof ByteVectorValueSource)) return false; + ByteVectorValueSource other = (ByteVectorValueSource) o; return Arrays.equals(vector, other.vector); } diff --git a/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/DenseVectorFloatFieldSource.java b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/densevectors/FloatVectorFieldSource.java similarity index 86% rename from lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/DenseVectorFloatFieldSource.java rename to lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/densevectors/FloatVectorFieldSource.java index 6cbe45ad04c3..12570904d1b4 100644 --- a/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/DenseVectorFloatFieldSource.java +++ b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/densevectors/FloatVectorFieldSource.java @@ -14,7 +14,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.lucene.queries.function.valuesource; +package org.apache.lucene.queries.function.valuesource.densevectors; import java.io.IOException; import java.util.Arrays; @@ -25,10 +25,10 @@ import org.apache.lucene.queries.function.ValueSource; import org.apache.lucene.search.DocIdSetIterator; -public class DenseVectorFloatFieldSource extends ValueSource { +public class FloatVectorFieldSource extends ValueSource { private final String fieldName; - public DenseVectorFloatFieldSource(String fieldName) { + public FloatVectorFieldSource(String fieldName) { this.fieldName = fieldName; } @@ -37,7 +37,7 @@ public FunctionValues getValues(Map context, LeafReaderContext r throws IOException { final FloatVectorValues vectorValues = readerContext.reader().getFloatVectorValues(fieldName); - return new DenseVectorFieldFunction(this) { + return new VectorFieldFunction(this) { float[] defaultVector = null; @Override @@ -66,8 +66,8 @@ private float[] defaultVector() { @Override public boolean equals(Object o) { - if (o.getClass() != DenseVectorFloatFieldSource.class) return false; - DenseVectorFloatFieldSource other = (DenseVectorFloatFieldSource) o; + if (o.getClass() != FloatVectorFieldSource.class) return false; + FloatVectorFieldSource other = (FloatVectorFieldSource) o; return fieldName.equals(other.fieldName); } diff --git a/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/FloatDenseVectorSimilarityFunction.java b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/densevectors/FloatVectorSimilarityFunction.java similarity index 74% rename from lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/FloatDenseVectorSimilarityFunction.java rename to lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/densevectors/FloatVectorSimilarityFunction.java index 336438cc360b..9a59eca61e6f 100644 --- a/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/FloatDenseVectorSimilarityFunction.java +++ b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/densevectors/FloatVectorSimilarityFunction.java @@ -14,22 +14,22 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.lucene.queries.function.valuesource; +package org.apache.lucene.queries.function.valuesource.densevectors; import java.io.IOException; -import org.apache.lucene.index.VectorSimilarityFunction; + import org.apache.lucene.queries.function.FunctionValues; import org.apache.lucene.queries.function.ValueSource; -public class FloatDenseVectorSimilarityFunction extends DenseVectorSimilarityFunction { - public FloatDenseVectorSimilarityFunction( - VectorSimilarityFunction similarityFunction, ValueSource vector1, ValueSource vector2) { +public class FloatVectorSimilarityFunction extends VectorSimilarityFunction { + public FloatVectorSimilarityFunction( + org.apache.lucene.index.VectorSimilarityFunction similarityFunction, ValueSource vector1, ValueSource vector2) { super(similarityFunction, vector1, vector2); } @Override protected float func(int doc, FunctionValues f1, FunctionValues f2) throws IOException { - checkSize(f1.floatVectorVal(doc).length, f2.floatVectorVal(doc).length); + assertSameSize(f1.floatVectorVal(doc).length, f2.floatVectorVal(doc).length); return similarityFunction.compare(f1.floatVectorVal(doc), f2.floatVectorVal(doc)); } } diff --git a/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/DenseVectorFloatConstValueSource.java b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/densevectors/FloatVectorValueSource.java similarity index 85% rename from lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/DenseVectorFloatConstValueSource.java rename to lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/densevectors/FloatVectorValueSource.java index d5b976f2275e..5fa6bc9bf63d 100644 --- a/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/DenseVectorFloatConstValueSource.java +++ b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/densevectors/FloatVectorValueSource.java @@ -14,7 +14,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.lucene.queries.function.valuesource; +package org.apache.lucene.queries.function.valuesource.densevectors; import java.io.IOException; import java.util.Arrays; @@ -24,10 +24,10 @@ import org.apache.lucene.queries.function.FunctionValues; import org.apache.lucene.queries.function.ValueSource; -public class DenseVectorFloatConstValueSource extends ValueSource { +public class FloatVectorValueSource extends ValueSource { float[] vector; - public DenseVectorFloatConstValueSource(List constVector) { + public FloatVectorValueSource(List constVector) { this.vector = new float[constVector.size()]; for (int i = 0; i < constVector.size(); i++) { vector[i] = constVector.get(i).floatValue(); @@ -57,8 +57,8 @@ public String toString(int doc) throws IOException { @Override public boolean equals(Object o) { - if (!(o instanceof DenseVectorFloatConstValueSource)) return false; - DenseVectorFloatConstValueSource other = (DenseVectorFloatConstValueSource) o; + if (!(o instanceof FloatVectorValueSource)) return false; + FloatVectorValueSource other = (FloatVectorValueSource) o; return Arrays.equals(vector, other.vector); } diff --git a/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/DenseVectorFieldFunction.java b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/densevectors/VectorFieldFunction.java similarity index 89% rename from lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/DenseVectorFieldFunction.java rename to lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/densevectors/VectorFieldFunction.java index 8629e1c9f8a6..59cacc245b07 100644 --- a/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/DenseVectorFieldFunction.java +++ b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/densevectors/VectorFieldFunction.java @@ -14,19 +14,19 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.lucene.queries.function.valuesource; +package org.apache.lucene.queries.function.valuesource.densevectors; import java.io.IOException; import org.apache.lucene.queries.function.FunctionValues; import org.apache.lucene.queries.function.ValueSource; import org.apache.lucene.search.DocIdSetIterator; -public abstract class DenseVectorFieldFunction extends FunctionValues { +public abstract class VectorFieldFunction extends FunctionValues { protected final ValueSource vs; int lastDocID; - protected DenseVectorFieldFunction(ValueSource vs) { + protected VectorFieldFunction(ValueSource vs) { this.vs = vs; } diff --git a/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/DenseVectorSimilarityFunction.java b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/densevectors/VectorSimilarityFunction.java similarity index 77% rename from lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/DenseVectorSimilarityFunction.java rename to lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/densevectors/VectorSimilarityFunction.java index e24653cd3491..70bfc864b57d 100644 --- a/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/DenseVectorSimilarityFunction.java +++ b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/densevectors/VectorSimilarityFunction.java @@ -14,23 +14,23 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.lucene.queries.function.valuesource; +package org.apache.lucene.queries.function.valuesource.densevectors; import java.io.IOException; import java.util.Map; import org.apache.lucene.index.LeafReaderContext; -import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.queries.function.FunctionValues; import org.apache.lucene.queries.function.ValueSource; +import org.apache.lucene.queries.function.valuesource.MultiFunction; -public abstract class DenseVectorSimilarityFunction extends ValueSource { +public abstract class VectorSimilarityFunction extends ValueSource { - protected final VectorSimilarityFunction similarityFunction; + protected final org.apache.lucene.index.VectorSimilarityFunction similarityFunction; protected final ValueSource vector1; protected final ValueSource vector2; - public DenseVectorSimilarityFunction( - VectorSimilarityFunction similarityFunction, ValueSource vector1, ValueSource vector2) { + public VectorSimilarityFunction( + org.apache.lucene.index.VectorSimilarityFunction similarityFunction, ValueSource vector1, ValueSource vector2) { this.similarityFunction = similarityFunction; this.vector1 = vector1; @@ -66,7 +66,7 @@ public String toString(int doc) throws IOException { }; } - protected void checkSize(int sizeVector1, int sizeVector2) throws IOException { + protected void assertSameSize(int sizeVector1, int sizeVector2){ if (sizeVector1 != sizeVector2) { throw new UnsupportedOperationException("Vectors must have the same size"); } @@ -76,10 +76,10 @@ protected void checkSize(int sizeVector1, int sizeVector2) throws IOException { @Override public boolean equals(Object o) { - return o instanceof DenseVectorSimilarityFunction - && similarityFunction.equals(((DenseVectorSimilarityFunction) o).similarityFunction) - && vector1.equals(((DenseVectorSimilarityFunction) o).vector1) - && vector2.equals(((DenseVectorSimilarityFunction) o).vector2); + return o instanceof VectorSimilarityFunction + && similarityFunction.equals(((VectorSimilarityFunction) o).similarityFunction) + && vector1.equals(((VectorSimilarityFunction) o).vector1) + && vector2.equals(((VectorSimilarityFunction) o).vector2); } @Override diff --git a/lucene/queries/src/test/org/apache/lucene/queries/function/TestKnnVectorSimilarityFunctions.java b/lucene/queries/src/test/org/apache/lucene/queries/function/TestKnnVectorSimilarityFunctions.java index 44b8317e8bc6..222f5b1a5e9b 100644 --- a/lucene/queries/src/test/org/apache/lucene/queries/function/TestKnnVectorSimilarityFunctions.java +++ b/lucene/queries/src/test/org/apache/lucene/queries/function/TestKnnVectorSimilarityFunctions.java @@ -27,12 +27,12 @@ import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.IndexWriterConfig; import org.apache.lucene.index.VectorSimilarityFunction; -import org.apache.lucene.queries.function.valuesource.ByteDenseVectorSimilarityFunction; -import org.apache.lucene.queries.function.valuesource.DenseVectorByteConstValueSource; -import org.apache.lucene.queries.function.valuesource.DenseVectorByteFieldSource; -import org.apache.lucene.queries.function.valuesource.DenseVectorFloatConstValueSource; -import org.apache.lucene.queries.function.valuesource.DenseVectorFloatFieldSource; -import org.apache.lucene.queries.function.valuesource.FloatDenseVectorSimilarityFunction; +import org.apache.lucene.queries.function.valuesource.densevectors.ByteVectorSimilarityFunction; +import org.apache.lucene.queries.function.valuesource.densevectors.ByteVectorValueSource; +import org.apache.lucene.queries.function.valuesource.densevectors.ByteVectorFieldSource; +import org.apache.lucene.queries.function.valuesource.densevectors.FloatVectorValueSource; +import org.apache.lucene.queries.function.valuesource.densevectors.FloatVectorFieldSource; +import org.apache.lucene.queries.function.valuesource.densevectors.FloatVectorSimilarityFunction; import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.Query; import org.apache.lucene.search.ScoreDoc; @@ -90,90 +90,90 @@ public static void afterClass() throws Exception { } public void testFloatVectorSimilarityFunctionConst() throws Exception { - ValueSource v1 = new DenseVectorFloatConstValueSource(List.of(1, 2, 3)); - ValueSource v2 = new DenseVectorFloatConstValueSource(List.of(5, 4, 1)); + ValueSource v1 = new FloatVectorValueSource(List.of(1, 2, 3)); + ValueSource v2 = new FloatVectorValueSource(List.of(5, 4, 1)); assertHits( new FunctionQuery( - new FloatDenseVectorSimilarityFunction(VectorSimilarityFunction.EUCLIDEAN, v1, v2)), + new FloatVectorSimilarityFunction(VectorSimilarityFunction.EUCLIDEAN, v1, v2)), new float[] {0.04f, 0.04f}); } public void testByteVectorSimilarityFunctionConst() throws Exception { - ValueSource v1 = new DenseVectorByteConstValueSource(List.of(1, 2, 3)); - ValueSource v2 = new DenseVectorByteConstValueSource(List.of(2, 5, 6)); + ValueSource v1 = new ByteVectorValueSource(List.of(1, 2, 3)); + ValueSource v2 = new ByteVectorValueSource(List.of(2, 5, 6)); assertHits( new FunctionQuery( - new ByteDenseVectorSimilarityFunction(VectorSimilarityFunction.EUCLIDEAN, v1, v2)), + new ByteVectorSimilarityFunction(VectorSimilarityFunction.EUCLIDEAN, v1, v2)), new float[] {0.05f, 0.05f}); } public void testFloatVectorSimilarityFunctionField() throws Exception { - ValueSource v1 = new DenseVectorFloatFieldSource("knnFloatField1"); - ValueSource v2 = new DenseVectorFloatFieldSource("knnFloatField2"); + ValueSource v1 = new FloatVectorFieldSource("knnFloatField1"); + ValueSource v2 = new FloatVectorFieldSource("knnFloatField2"); assertHits( new FunctionQuery( - new FloatDenseVectorSimilarityFunction(VectorSimilarityFunction.EUCLIDEAN, v1, v2)), + new FloatVectorSimilarityFunction(VectorSimilarityFunction.EUCLIDEAN, v1, v2)), new float[] {0.049776014f, 0.049776014f}); } public void testByteVectorSimilarityFunctionField() throws Exception { - ValueSource v1 = new DenseVectorByteFieldSource("knnByteField1"); - ValueSource v2 = new DenseVectorByteFieldSource("knnByteField2"); + ValueSource v1 = new ByteVectorFieldSource("knnByteField1"); + ValueSource v2 = new ByteVectorFieldSource("knnByteField2"); assertHits( new FunctionQuery( - new ByteDenseVectorSimilarityFunction(VectorSimilarityFunction.EUCLIDEAN, v1, v2)), + new ByteVectorSimilarityFunction(VectorSimilarityFunction.EUCLIDEAN, v1, v2)), new float[] {0.1f, 0.1f}); } public void testFloatVectorSimilarityFunctionMixed() throws Exception { - ValueSource v1 = new DenseVectorFloatConstValueSource(List.of(1, 2, 4)); - ValueSource v2 = new DenseVectorFloatFieldSource("knnFloatField1"); + ValueSource v1 = new FloatVectorValueSource(List.of(1, 2, 4)); + ValueSource v2 = new FloatVectorFieldSource("knnFloatField1"); assertHits( new FunctionQuery( - new FloatDenseVectorSimilarityFunction(VectorSimilarityFunction.EUCLIDEAN, v1, v2)), + new FloatVectorSimilarityFunction(VectorSimilarityFunction.EUCLIDEAN, v1, v2)), new float[] {0.5f, 0.5f}); } public void testByteVectorSimilarityFunctionMixed() throws Exception { - ValueSource v1 = new DenseVectorByteConstValueSource(List.of(1, 2, 4)); - ValueSource v2 = new DenseVectorByteFieldSource("knnByteField1"); + ValueSource v1 = new ByteVectorValueSource(List.of(1, 2, 4)); + ValueSource v2 = new ByteVectorFieldSource("knnByteField1"); assertHits( new FunctionQuery( - new ByteDenseVectorSimilarityFunction(VectorSimilarityFunction.EUCLIDEAN, v1, v2)), + new ByteVectorSimilarityFunction(VectorSimilarityFunction.EUCLIDEAN, v1, v2)), new float[] {0.5f, 0.5f}); } public void testDismatchDimension() { - ValueSource v1 = new DenseVectorByteConstValueSource(List.of(1, 2, 3, 4)); - ValueSource v2 = new DenseVectorByteFieldSource("knnByteField1"); - ByteDenseVectorSimilarityFunction byteDenseVectorSimilarityFunction = - new ByteDenseVectorSimilarityFunction(VectorSimilarityFunction.EUCLIDEAN, v1, v2); + ValueSource v1 = new ByteVectorValueSource(List.of(1, 2, 3, 4)); + ValueSource v2 = new ByteVectorFieldSource("knnByteField1"); + ByteVectorSimilarityFunction byteDenseVectorSimilarityFunction = + new ByteVectorSimilarityFunction(VectorSimilarityFunction.EUCLIDEAN, v1, v2); assertThrows( UnsupportedOperationException.class, () -> searcher.search(new FunctionQuery(byteDenseVectorSimilarityFunction), 10)); - v1 = new DenseVectorFloatConstValueSource(List.of(1, 2)); - v2 = new DenseVectorFloatFieldSource("knnFloatField1"); - FloatDenseVectorSimilarityFunction floatDenseVectorSimilarityFunction = - new FloatDenseVectorSimilarityFunction(VectorSimilarityFunction.EUCLIDEAN, v1, v2); + v1 = new FloatVectorValueSource(List.of(1, 2)); + v2 = new FloatVectorFieldSource("knnFloatField1"); + FloatVectorSimilarityFunction floatDenseVectorSimilarityFunction = + new FloatVectorSimilarityFunction(VectorSimilarityFunction.EUCLIDEAN, v1, v2); assertThrows( UnsupportedOperationException.class, () -> searcher.search(new FunctionQuery(floatDenseVectorSimilarityFunction), 10)); } public void testMismatchType() { - ValueSource v1 = new DenseVectorByteConstValueSource(List.of(1, 2, 3)); - ValueSource v2 = new DenseVectorByteFieldSource("knnByteField1"); - FloatDenseVectorSimilarityFunction floatDenseVectorSimilarityFunction = - new FloatDenseVectorSimilarityFunction(VectorSimilarityFunction.EUCLIDEAN, v1, v2); + ValueSource v1 = new ByteVectorValueSource(List.of(1, 2, 3)); + ValueSource v2 = new ByteVectorFieldSource("knnByteField1"); + FloatVectorSimilarityFunction floatDenseVectorSimilarityFunction = + new FloatVectorSimilarityFunction(VectorSimilarityFunction.EUCLIDEAN, v1, v2); assertThrows( UnsupportedOperationException.class, () -> searcher.search(new FunctionQuery(floatDenseVectorSimilarityFunction), 10)); - v1 = new DenseVectorByteConstValueSource(List.of(1, 2, 3)); - v2 = new DenseVectorFloatFieldSource("knnByteField1"); - ByteDenseVectorSimilarityFunction byteDenseVectorSimilarityFunction = - new ByteDenseVectorSimilarityFunction(VectorSimilarityFunction.EUCLIDEAN, v1, v2); + v1 = new ByteVectorValueSource(List.of(1, 2, 3)); + v2 = new FloatVectorFieldSource("knnByteField1"); + ByteVectorSimilarityFunction byteDenseVectorSimilarityFunction = + new ByteVectorSimilarityFunction(VectorSimilarityFunction.EUCLIDEAN, v1, v2); assertThrows( UnsupportedOperationException.class, () -> searcher.search(new FunctionQuery(byteDenseVectorSimilarityFunction), 10)); From b916d84bb82830cb529b30f51c1cf4e68ea7b88f Mon Sep 17 00:00:00 2001 From: Alessandro Benedetti Date: Tue, 16 May 2023 11:01:10 +0100 Subject: [PATCH 19/35] minor rename --- .../{densevectors => }/ByteVectorFieldSource.java | 2 +- .../ByteVectorSimilarityFunction.java | 2 +- .../{densevectors => }/ByteVectorValueSource.java | 2 +- .../{densevectors => }/FloatVectorFieldSource.java | 2 +- .../FloatVectorSimilarityFunction.java | 2 +- .../{densevectors => }/FloatVectorValueSource.java | 2 +- .../{densevectors => }/VectorFieldFunction.java | 2 +- .../{densevectors => }/VectorSimilarityFunction.java | 3 +-- .../function/TestKnnVectorSimilarityFunctions.java | 12 ++++++------ 9 files changed, 14 insertions(+), 15 deletions(-) rename lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/{densevectors => }/ByteVectorFieldSource.java (97%) rename lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/{densevectors => }/ByteVectorSimilarityFunction.java (95%) rename lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/{densevectors => }/ByteVectorValueSource.java (97%) rename lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/{densevectors => }/FloatVectorFieldSource.java (97%) rename lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/{densevectors => }/FloatVectorSimilarityFunction.java (95%) rename lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/{densevectors => }/FloatVectorValueSource.java (97%) rename lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/{densevectors => }/VectorFieldFunction.java (96%) rename lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/{densevectors => }/VectorSimilarityFunction.java (96%) diff --git a/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/densevectors/ByteVectorFieldSource.java b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/ByteVectorFieldSource.java similarity index 97% rename from lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/densevectors/ByteVectorFieldSource.java rename to lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/ByteVectorFieldSource.java index 535baa76d4fe..1ba0a8b9ddae 100644 --- a/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/densevectors/ByteVectorFieldSource.java +++ b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/ByteVectorFieldSource.java @@ -14,7 +14,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.lucene.queries.function.valuesource.densevectors; +package org.apache.lucene.queries.function.valuesource; import java.io.IOException; import java.util.Arrays; diff --git a/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/densevectors/ByteVectorSimilarityFunction.java b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/ByteVectorSimilarityFunction.java similarity index 95% rename from lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/densevectors/ByteVectorSimilarityFunction.java rename to lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/ByteVectorSimilarityFunction.java index 0414e35ec622..a3bde3759886 100644 --- a/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/densevectors/ByteVectorSimilarityFunction.java +++ b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/ByteVectorSimilarityFunction.java @@ -14,7 +14,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.lucene.queries.function.valuesource.densevectors; +package org.apache.lucene.queries.function.valuesource; import java.io.IOException; diff --git a/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/densevectors/ByteVectorValueSource.java b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/ByteVectorValueSource.java similarity index 97% rename from lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/densevectors/ByteVectorValueSource.java rename to lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/ByteVectorValueSource.java index 19e532a1647a..a53bbfc3e71e 100644 --- a/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/densevectors/ByteVectorValueSource.java +++ b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/ByteVectorValueSource.java @@ -14,7 +14,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.lucene.queries.function.valuesource.densevectors; +package org.apache.lucene.queries.function.valuesource; import java.io.IOException; import java.util.Arrays; diff --git a/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/densevectors/FloatVectorFieldSource.java b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/FloatVectorFieldSource.java similarity index 97% rename from lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/densevectors/FloatVectorFieldSource.java rename to lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/FloatVectorFieldSource.java index 12570904d1b4..474c0af5dac3 100644 --- a/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/densevectors/FloatVectorFieldSource.java +++ b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/FloatVectorFieldSource.java @@ -14,7 +14,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.lucene.queries.function.valuesource.densevectors; +package org.apache.lucene.queries.function.valuesource; import java.io.IOException; import java.util.Arrays; diff --git a/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/densevectors/FloatVectorSimilarityFunction.java b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/FloatVectorSimilarityFunction.java similarity index 95% rename from lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/densevectors/FloatVectorSimilarityFunction.java rename to lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/FloatVectorSimilarityFunction.java index 9a59eca61e6f..b826c24bf3c2 100644 --- a/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/densevectors/FloatVectorSimilarityFunction.java +++ b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/FloatVectorSimilarityFunction.java @@ -14,7 +14,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.lucene.queries.function.valuesource.densevectors; +package org.apache.lucene.queries.function.valuesource; import java.io.IOException; diff --git a/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/densevectors/FloatVectorValueSource.java b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/FloatVectorValueSource.java similarity index 97% rename from lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/densevectors/FloatVectorValueSource.java rename to lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/FloatVectorValueSource.java index 5fa6bc9bf63d..db8e74c352a6 100644 --- a/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/densevectors/FloatVectorValueSource.java +++ b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/FloatVectorValueSource.java @@ -14,7 +14,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.lucene.queries.function.valuesource.densevectors; +package org.apache.lucene.queries.function.valuesource; import java.io.IOException; import java.util.Arrays; diff --git a/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/densevectors/VectorFieldFunction.java b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/VectorFieldFunction.java similarity index 96% rename from lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/densevectors/VectorFieldFunction.java rename to lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/VectorFieldFunction.java index 59cacc245b07..6c53d100d85f 100644 --- a/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/densevectors/VectorFieldFunction.java +++ b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/VectorFieldFunction.java @@ -14,7 +14,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.lucene.queries.function.valuesource.densevectors; +package org.apache.lucene.queries.function.valuesource; import java.io.IOException; import org.apache.lucene.queries.function.FunctionValues; diff --git a/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/densevectors/VectorSimilarityFunction.java b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/VectorSimilarityFunction.java similarity index 96% rename from lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/densevectors/VectorSimilarityFunction.java rename to lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/VectorSimilarityFunction.java index 70bfc864b57d..c5d17d695495 100644 --- a/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/densevectors/VectorSimilarityFunction.java +++ b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/VectorSimilarityFunction.java @@ -14,14 +14,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.lucene.queries.function.valuesource.densevectors; +package org.apache.lucene.queries.function.valuesource; import java.io.IOException; import java.util.Map; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.queries.function.FunctionValues; import org.apache.lucene.queries.function.ValueSource; -import org.apache.lucene.queries.function.valuesource.MultiFunction; public abstract class VectorSimilarityFunction extends ValueSource { diff --git a/lucene/queries/src/test/org/apache/lucene/queries/function/TestKnnVectorSimilarityFunctions.java b/lucene/queries/src/test/org/apache/lucene/queries/function/TestKnnVectorSimilarityFunctions.java index 222f5b1a5e9b..debf6e10f1f8 100644 --- a/lucene/queries/src/test/org/apache/lucene/queries/function/TestKnnVectorSimilarityFunctions.java +++ b/lucene/queries/src/test/org/apache/lucene/queries/function/TestKnnVectorSimilarityFunctions.java @@ -27,12 +27,12 @@ import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.IndexWriterConfig; import org.apache.lucene.index.VectorSimilarityFunction; -import org.apache.lucene.queries.function.valuesource.densevectors.ByteVectorSimilarityFunction; -import org.apache.lucene.queries.function.valuesource.densevectors.ByteVectorValueSource; -import org.apache.lucene.queries.function.valuesource.densevectors.ByteVectorFieldSource; -import org.apache.lucene.queries.function.valuesource.densevectors.FloatVectorValueSource; -import org.apache.lucene.queries.function.valuesource.densevectors.FloatVectorFieldSource; -import org.apache.lucene.queries.function.valuesource.densevectors.FloatVectorSimilarityFunction; +import org.apache.lucene.queries.function.valuesource.ByteVectorSimilarityFunction; +import org.apache.lucene.queries.function.valuesource.ByteVectorValueSource; +import org.apache.lucene.queries.function.valuesource.ByteVectorFieldSource; +import org.apache.lucene.queries.function.valuesource.FloatVectorValueSource; +import org.apache.lucene.queries.function.valuesource.FloatVectorFieldSource; +import org.apache.lucene.queries.function.valuesource.FloatVectorSimilarityFunction; import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.Query; import org.apache.lucene.search.ScoreDoc; From a48dc2e4e743906a57ac58ffbbaab381b61a39f9 Mon Sep 17 00:00:00 2001 From: Alessandro Benedetti Date: Tue, 16 May 2023 11:03:43 +0100 Subject: [PATCH 20/35] minor rename --- lucene/queries/src/java/module-info.java | 1 - 1 file changed, 1 deletion(-) diff --git a/lucene/queries/src/java/module-info.java b/lucene/queries/src/java/module-info.java index b1923557bcba..ae3ff4d0712f 100644 --- a/lucene/queries/src/java/module-info.java +++ b/lucene/queries/src/java/module-info.java @@ -27,5 +27,4 @@ exports org.apache.lucene.queries.mlt; exports org.apache.lucene.queries.payloads; exports org.apache.lucene.queries.spans; - exports org.apache.lucene.queries.function.valuesource.densevectors; } From 01497a6a60037a918698cf1802e79fbf92e9da7f Mon Sep 17 00:00:00 2001 From: Alessandro Benedetti Date: Tue, 16 May 2023 11:25:44 +0100 Subject: [PATCH 21/35] moved to assertion --- .../function/valuesource/ByteVectorSimilarityFunction.java | 2 +- .../function/valuesource/FloatVectorSimilarityFunction.java | 2 +- .../function/valuesource/VectorSimilarityFunction.java | 6 ------ .../queries/function/TestKnnVectorSimilarityFunctions.java | 4 ++-- 4 files changed, 4 insertions(+), 10 deletions(-) diff --git a/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/ByteVectorSimilarityFunction.java b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/ByteVectorSimilarityFunction.java index a3bde3759886..ea7c844eadd2 100644 --- a/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/ByteVectorSimilarityFunction.java +++ b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/ByteVectorSimilarityFunction.java @@ -29,7 +29,7 @@ public ByteVectorSimilarityFunction( @Override protected float func(int doc, FunctionValues f1, FunctionValues f2) throws IOException { - assertSameSize(f1.byteVectorVal(doc).length, f2.byteVectorVal(doc).length); + assert f1.byteVectorVal(doc).length == f2.byteVectorVal(doc).length: "Vectors must have the same length"; return similarityFunction.compare(f1.byteVectorVal(doc), f2.byteVectorVal(doc)); } } diff --git a/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/FloatVectorSimilarityFunction.java b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/FloatVectorSimilarityFunction.java index b826c24bf3c2..af8a11995a77 100644 --- a/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/FloatVectorSimilarityFunction.java +++ b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/FloatVectorSimilarityFunction.java @@ -29,7 +29,7 @@ public FloatVectorSimilarityFunction( @Override protected float func(int doc, FunctionValues f1, FunctionValues f2) throws IOException { - assertSameSize(f1.floatVectorVal(doc).length, f2.floatVectorVal(doc).length); + assert f1.floatVectorVal(doc).length == f2.floatVectorVal(doc).length: "Vectors must have the same length"; return similarityFunction.compare(f1.floatVectorVal(doc), f2.floatVectorVal(doc)); } } diff --git a/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/VectorSimilarityFunction.java b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/VectorSimilarityFunction.java index c5d17d695495..b14eca1413d1 100644 --- a/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/VectorSimilarityFunction.java +++ b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/VectorSimilarityFunction.java @@ -65,12 +65,6 @@ public String toString(int doc) throws IOException { }; } - protected void assertSameSize(int sizeVector1, int sizeVector2){ - if (sizeVector1 != sizeVector2) { - throw new UnsupportedOperationException("Vectors must have the same size"); - } - } - protected abstract float func(int doc, FunctionValues f1, FunctionValues f2) throws IOException; @Override diff --git a/lucene/queries/src/test/org/apache/lucene/queries/function/TestKnnVectorSimilarityFunctions.java b/lucene/queries/src/test/org/apache/lucene/queries/function/TestKnnVectorSimilarityFunctions.java index debf6e10f1f8..93e376e98a0c 100644 --- a/lucene/queries/src/test/org/apache/lucene/queries/function/TestKnnVectorSimilarityFunctions.java +++ b/lucene/queries/src/test/org/apache/lucene/queries/function/TestKnnVectorSimilarityFunctions.java @@ -149,7 +149,7 @@ public void testDismatchDimension() { ByteVectorSimilarityFunction byteDenseVectorSimilarityFunction = new ByteVectorSimilarityFunction(VectorSimilarityFunction.EUCLIDEAN, v1, v2); assertThrows( - UnsupportedOperationException.class, + AssertionError.class, () -> searcher.search(new FunctionQuery(byteDenseVectorSimilarityFunction), 10)); v1 = new FloatVectorValueSource(List.of(1, 2)); @@ -157,7 +157,7 @@ public void testDismatchDimension() { FloatVectorSimilarityFunction floatDenseVectorSimilarityFunction = new FloatVectorSimilarityFunction(VectorSimilarityFunction.EUCLIDEAN, v1, v2); assertThrows( - UnsupportedOperationException.class, + AssertionError.class, () -> searcher.search(new FunctionQuery(floatDenseVectorSimilarityFunction), 10)); } From 5a2eabfaca6f9ffd4d889899c518b019f396fb69 Mon Sep 17 00:00:00 2001 From: Alessandro Benedetti Date: Tue, 16 May 2023 12:15:25 +0100 Subject: [PATCH 22/35] tydy and checks --- .../function/valuesource/ByteVectorFieldSource.java | 3 +++ .../valuesource/ByteVectorSimilarityFunction.java | 12 +++++++++--- .../function/valuesource/ByteVectorValueSource.java | 1 + .../function/valuesource/FloatVectorFieldSource.java | 3 +++ .../valuesource/FloatVectorSimilarityFunction.java | 12 +++++++++--- .../function/valuesource/FloatVectorValueSource.java | 1 + .../function/valuesource/VectorFieldFunction.java | 1 + .../valuesource/VectorSimilarityFunction.java | 5 ++++- .../function/TestKnnVectorSimilarityFunctions.java | 8 ++++---- 9 files changed, 35 insertions(+), 11 deletions(-) diff --git a/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/ByteVectorFieldSource.java b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/ByteVectorFieldSource.java index 1ba0a8b9ddae..863221767cf3 100644 --- a/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/ByteVectorFieldSource.java +++ b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/ByteVectorFieldSource.java @@ -25,6 +25,9 @@ import org.apache.lucene.queries.function.ValueSource; import org.apache.lucene.search.DocIdSetIterator; +/** + * An implementation for retrieving {@link FunctionValues} instances for byte knn vectors fields. + */ public class ByteVectorFieldSource extends ValueSource { private final String fieldName; diff --git a/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/ByteVectorSimilarityFunction.java b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/ByteVectorSimilarityFunction.java index ea7c844eadd2..f4bd9d8c8bfa 100644 --- a/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/ByteVectorSimilarityFunction.java +++ b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/ByteVectorSimilarityFunction.java @@ -17,19 +17,25 @@ package org.apache.lucene.queries.function.valuesource; import java.io.IOException; - import org.apache.lucene.queries.function.FunctionValues; import org.apache.lucene.queries.function.ValueSource; +/** + * ByteVectorSimilarityFunction returns a similarity function between two knn vectors + * with byte elements. + */ public class ByteVectorSimilarityFunction extends VectorSimilarityFunction { public ByteVectorSimilarityFunction( - org.apache.lucene.index.VectorSimilarityFunction similarityFunction, ValueSource vector1, ValueSource vector2) { + org.apache.lucene.index.VectorSimilarityFunction similarityFunction, + ValueSource vector1, + ValueSource vector2) { super(similarityFunction, vector1, vector2); } @Override protected float func(int doc, FunctionValues f1, FunctionValues f2) throws IOException { - assert f1.byteVectorVal(doc).length == f2.byteVectorVal(doc).length: "Vectors must have the same length"; + assert f1.byteVectorVal(doc).length == f2.byteVectorVal(doc).length + : "Vectors must have the same length"; return similarityFunction.compare(f1.byteVectorVal(doc), f2.byteVectorVal(doc)); } } diff --git a/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/ByteVectorValueSource.java b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/ByteVectorValueSource.java index a53bbfc3e71e..88f3184f3e88 100644 --- a/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/ByteVectorValueSource.java +++ b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/ByteVectorValueSource.java @@ -23,6 +23,7 @@ import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.queries.function.FunctionValues; import org.apache.lucene.queries.function.ValueSource; + /** Function that returns a constant byte vector value for every document. */ public class ByteVectorValueSource extends ValueSource { byte[] vector; diff --git a/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/FloatVectorFieldSource.java b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/FloatVectorFieldSource.java index 474c0af5dac3..feef8d0033fd 100644 --- a/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/FloatVectorFieldSource.java +++ b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/FloatVectorFieldSource.java @@ -25,6 +25,9 @@ import org.apache.lucene.queries.function.ValueSource; import org.apache.lucene.search.DocIdSetIterator; +/** + * An implementation for retrieving {@link FunctionValues} instances for float knn vectors fields. + */ public class FloatVectorFieldSource extends ValueSource { private final String fieldName; diff --git a/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/FloatVectorSimilarityFunction.java b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/FloatVectorSimilarityFunction.java index af8a11995a77..a671a5c46762 100644 --- a/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/FloatVectorSimilarityFunction.java +++ b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/FloatVectorSimilarityFunction.java @@ -17,19 +17,25 @@ package org.apache.lucene.queries.function.valuesource; import java.io.IOException; - import org.apache.lucene.queries.function.FunctionValues; import org.apache.lucene.queries.function.ValueSource; +/** + * FloatVectorSimilarityFunction returns a similarity function between two knn vectors + * with float elements. + */ public class FloatVectorSimilarityFunction extends VectorSimilarityFunction { public FloatVectorSimilarityFunction( - org.apache.lucene.index.VectorSimilarityFunction similarityFunction, ValueSource vector1, ValueSource vector2) { + org.apache.lucene.index.VectorSimilarityFunction similarityFunction, + ValueSource vector1, + ValueSource vector2) { super(similarityFunction, vector1, vector2); } @Override protected float func(int doc, FunctionValues f1, FunctionValues f2) throws IOException { - assert f1.floatVectorVal(doc).length == f2.floatVectorVal(doc).length: "Vectors must have the same length"; + assert f1.floatVectorVal(doc).length == f2.floatVectorVal(doc).length + : "Vectors must have the same length"; return similarityFunction.compare(f1.floatVectorVal(doc), f2.floatVectorVal(doc)); } } diff --git a/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/FloatVectorValueSource.java b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/FloatVectorValueSource.java index db8e74c352a6..825157d05c9b 100644 --- a/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/FloatVectorValueSource.java +++ b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/FloatVectorValueSource.java @@ -24,6 +24,7 @@ import org.apache.lucene.queries.function.FunctionValues; import org.apache.lucene.queries.function.ValueSource; +/** Function that returns a constant float vector value for every document. */ public class FloatVectorValueSource extends ValueSource { float[] vector; diff --git a/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/VectorFieldFunction.java b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/VectorFieldFunction.java index 6c53d100d85f..fda853aaeec4 100644 --- a/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/VectorFieldFunction.java +++ b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/VectorFieldFunction.java @@ -21,6 +21,7 @@ import org.apache.lucene.queries.function.ValueSource; import org.apache.lucene.search.DocIdSetIterator; +/** An implementation for retrieving {@link FunctionValues} instances for knn vectors fields. */ public abstract class VectorFieldFunction extends FunctionValues { protected final ValueSource vs; diff --git a/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/VectorSimilarityFunction.java b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/VectorSimilarityFunction.java index b14eca1413d1..ebb1400529f5 100644 --- a/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/VectorSimilarityFunction.java +++ b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/VectorSimilarityFunction.java @@ -22,6 +22,7 @@ import org.apache.lucene.queries.function.FunctionValues; import org.apache.lucene.queries.function.ValueSource; +/** VectorSimilarityFunction returns a similarity function between two knn vectors. */ public abstract class VectorSimilarityFunction extends ValueSource { protected final org.apache.lucene.index.VectorSimilarityFunction similarityFunction; @@ -29,7 +30,9 @@ public abstract class VectorSimilarityFunction extends ValueSource { protected final ValueSource vector2; public VectorSimilarityFunction( - org.apache.lucene.index.VectorSimilarityFunction similarityFunction, ValueSource vector1, ValueSource vector2) { + org.apache.lucene.index.VectorSimilarityFunction similarityFunction, + ValueSource vector1, + ValueSource vector2) { this.similarityFunction = similarityFunction; this.vector1 = vector1; diff --git a/lucene/queries/src/test/org/apache/lucene/queries/function/TestKnnVectorSimilarityFunctions.java b/lucene/queries/src/test/org/apache/lucene/queries/function/TestKnnVectorSimilarityFunctions.java index 93e376e98a0c..a7444ceecab8 100644 --- a/lucene/queries/src/test/org/apache/lucene/queries/function/TestKnnVectorSimilarityFunctions.java +++ b/lucene/queries/src/test/org/apache/lucene/queries/function/TestKnnVectorSimilarityFunctions.java @@ -27,12 +27,12 @@ import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.IndexWriterConfig; import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.queries.function.valuesource.ByteVectorFieldSource; import org.apache.lucene.queries.function.valuesource.ByteVectorSimilarityFunction; import org.apache.lucene.queries.function.valuesource.ByteVectorValueSource; -import org.apache.lucene.queries.function.valuesource.ByteVectorFieldSource; -import org.apache.lucene.queries.function.valuesource.FloatVectorValueSource; import org.apache.lucene.queries.function.valuesource.FloatVectorFieldSource; import org.apache.lucene.queries.function.valuesource.FloatVectorSimilarityFunction; +import org.apache.lucene.queries.function.valuesource.FloatVectorValueSource; import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.Query; import org.apache.lucene.search.ScoreDoc; @@ -149,7 +149,7 @@ public void testDismatchDimension() { ByteVectorSimilarityFunction byteDenseVectorSimilarityFunction = new ByteVectorSimilarityFunction(VectorSimilarityFunction.EUCLIDEAN, v1, v2); assertThrows( - AssertionError.class, + AssertionError.class, () -> searcher.search(new FunctionQuery(byteDenseVectorSimilarityFunction), 10)); v1 = new FloatVectorValueSource(List.of(1, 2)); @@ -157,7 +157,7 @@ public void testDismatchDimension() { FloatVectorSimilarityFunction floatDenseVectorSimilarityFunction = new FloatVectorSimilarityFunction(VectorSimilarityFunction.EUCLIDEAN, v1, v2); assertThrows( - AssertionError.class, + AssertionError.class, () -> searcher.search(new FunctionQuery(floatDenseVectorSimilarityFunction), 10)); } From 7a6393328a6e94257b8ab4fab17c84814daf728c Mon Sep 17 00:00:00 2001 From: Elia Date: Thu, 25 May 2023 14:18:30 +0200 Subject: [PATCH 23/35] Managed case of field not indexed for vector valuesource Addressing review --- .../valuesource/ByteVectorFieldSource.java | 17 +-- .../ByteVectorSimilarityFunction.java | 9 ++ .../valuesource/FloatVectorFieldSource.java | 16 +- .../FloatVectorSimilarityFunction.java | 8 + .../valuesource/VectorFieldFunction.java | 3 + .../TestKnnVectorSimilarityFunctions.java | 142 +++++++++++++----- 6 files changed, 137 insertions(+), 58 deletions(-) diff --git a/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/ByteVectorFieldSource.java b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/ByteVectorFieldSource.java index 863221767cf3..a15d281fc097 100644 --- a/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/ByteVectorFieldSource.java +++ b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/ByteVectorFieldSource.java @@ -25,6 +25,8 @@ import org.apache.lucene.queries.function.ValueSource; import org.apache.lucene.search.DocIdSetIterator; +import static java.util.Optional.ofNullable; + /** * An implementation for retrieving {@link FunctionValues} instances for byte knn vectors fields. */ @@ -40,15 +42,19 @@ public FunctionValues getValues(Map context, LeafReaderContext r throws IOException { final ByteVectorValues vectorValues = readerContext.reader().getByteVectorValues(fieldName); + + if (vectorValues == null) { + throw new IllegalArgumentException("no byte vector value is indexed for field '" + fieldName + "'"); + } + return new VectorFieldFunction(this) { - byte[] defaultVector = null; @Override public byte[] byteVectorVal(int doc) throws IOException { if (exists(doc)) { return vectorValues.vectorValue(); } else { - return defaultVector(); + return null; } } @@ -57,13 +63,6 @@ protected DocIdSetIterator getVectorIterator() { return vectorValues; } - private byte[] defaultVector() { - if (defaultVector == null) { - defaultVector = new byte[vectorValues.dimension()]; - Arrays.fill(defaultVector, (byte) 0); - } - return defaultVector; - } }; } diff --git a/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/ByteVectorSimilarityFunction.java b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/ByteVectorSimilarityFunction.java index f4bd9d8c8bfa..2b9208ea199c 100644 --- a/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/ByteVectorSimilarityFunction.java +++ b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/ByteVectorSimilarityFunction.java @@ -34,8 +34,17 @@ public ByteVectorSimilarityFunction( @Override protected float func(int doc, FunctionValues f1, FunctionValues f2) throws IOException { + + var v1 = f1.byteVectorVal(doc); + var v2 = f2.byteVectorVal(doc); + + if (v1 == null || v2 == null) { + return Float.NaN; + } + assert f1.byteVectorVal(doc).length == f2.byteVectorVal(doc).length : "Vectors must have the same length"; + return similarityFunction.compare(f1.byteVectorVal(doc), f2.byteVectorVal(doc)); } } diff --git a/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/FloatVectorFieldSource.java b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/FloatVectorFieldSource.java index feef8d0033fd..ad2eb49a7b40 100644 --- a/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/FloatVectorFieldSource.java +++ b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/FloatVectorFieldSource.java @@ -17,7 +17,6 @@ package org.apache.lucene.queries.function.valuesource; import java.io.IOException; -import java.util.Arrays; import java.util.Map; import org.apache.lucene.index.FloatVectorValues; import org.apache.lucene.index.LeafReaderContext; @@ -40,15 +39,18 @@ public FunctionValues getValues(Map context, LeafReaderContext r throws IOException { final FloatVectorValues vectorValues = readerContext.reader().getFloatVectorValues(fieldName); + + if (vectorValues == null) { + throw new IllegalArgumentException("no float vector value is indexed for field '" + fieldName + "'"); + } return new VectorFieldFunction(this) { - float[] defaultVector = null; @Override public float[] floatVectorVal(int doc) throws IOException { if (exists(doc)) { return vectorValues.vectorValue(); } else { - return defaultVector(); + return null; } } @@ -56,14 +58,6 @@ public float[] floatVectorVal(int doc) throws IOException { protected DocIdSetIterator getVectorIterator() { return vectorValues; } - - private float[] defaultVector() { - if (defaultVector == null) { - defaultVector = new float[vectorValues.dimension()]; - Arrays.fill(defaultVector, 0.f); - } - return defaultVector; - } }; } diff --git a/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/FloatVectorSimilarityFunction.java b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/FloatVectorSimilarityFunction.java index a671a5c46762..fd7206ce9990 100644 --- a/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/FloatVectorSimilarityFunction.java +++ b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/FloatVectorSimilarityFunction.java @@ -34,6 +34,14 @@ public FloatVectorSimilarityFunction( @Override protected float func(int doc, FunctionValues f1, FunctionValues f2) throws IOException { + + var v1 = f1.floatVectorVal(doc); + var v2 = f2.floatVectorVal(doc); + + if (v1 == null || v2 == null) { + return Float.NaN; + } + assert f1.floatVectorVal(doc).length == f2.floatVectorVal(doc).length : "Vectors must have the same length"; return similarityFunction.compare(f1.floatVectorVal(doc), f2.floatVectorVal(doc)); diff --git a/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/VectorFieldFunction.java b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/VectorFieldFunction.java index fda853aaeec4..e1bc54f3e712 100644 --- a/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/VectorFieldFunction.java +++ b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/VectorFieldFunction.java @@ -38,13 +38,16 @@ public String toString(int doc) throws IOException { return vs.description() + strVal(doc); } + @Override public boolean exists(int doc) throws IOException { if (doc < lastDocID) { throw new IllegalArgumentException( "docs were sent out-of-order: lastDocID=" + lastDocID + " vs docID=" + doc); } + lastDocID = doc; + int curDocID = getVectorIterator().docID(); if (doc > curDocID) { curDocID = getVectorIterator().advance(doc); diff --git a/lucene/queries/src/test/org/apache/lucene/queries/function/TestKnnVectorSimilarityFunctions.java b/lucene/queries/src/test/org/apache/lucene/queries/function/TestKnnVectorSimilarityFunctions.java index a7444ceecab8..a3294004ceee 100644 --- a/lucene/queries/src/test/org/apache/lucene/queries/function/TestKnnVectorSimilarityFunctions.java +++ b/lucene/queries/src/test/org/apache/lucene/queries/function/TestKnnVectorSimilarityFunctions.java @@ -16,7 +16,7 @@ */ package org.apache.lucene.queries.function; -import java.util.List; +import com.sun.jdi.Value; import org.apache.lucene.analysis.Analyzer; import org.apache.lucene.document.Document; import org.apache.lucene.document.Field; @@ -47,6 +47,9 @@ import org.apache.lucene.util.BytesRef; import org.junit.AfterClass; import org.junit.BeforeClass; +import org.junit.Test; + +import java.util.List; public class TestKnnVectorSimilarityFunctions extends LuceneTestCase { static Directory dir; @@ -62,16 +65,30 @@ public static void beforeClass() throws Exception { IndexWriterConfig iwConfig = newIndexWriterConfig(analyzer); iwConfig.setMergePolicy(newLogMergePolicy()); RandomIndexWriter iw = new RandomIndexWriter(random(), dir, iwConfig); - for (String docId : documents) { - Document document = new Document(); - document.add(new StringField("id", docId, Field.Store.NO)); - document.add(new SortedDocValuesField("id", new BytesRef(docId))); - document.add(new KnnFloatVectorField("knnFloatField1", new float[] {1.f, 2.f, 3.f})); - document.add(new KnnFloatVectorField("knnFloatField2", new float[] {5.2f, 3.2f, 3.1f})); - document.add(new KnnByteVectorField("knnByteField1", new byte[] {1, 2, 3})); - document.add(new KnnByteVectorField("knnByteField2", new byte[] {4, 2, 3})); - iw.addDocument(document); - } + + Document document = new Document(); + document.add(new StringField("id", "1", Field.Store.NO)); + document.add(new SortedDocValuesField("id", new BytesRef("1"))); + document.add(new KnnFloatVectorField("knnFloatField1", new float[] {1.f, 2.f, 3.f})); + document.add(new KnnFloatVectorField("knnFloatField2", new float[] {5.2f, 3.2f, 3.1f})); + + // add only to the first document + document.add(new KnnFloatVectorField("knnFloatField3", new float[] {1.0f, 1.0f, 1.0f})); + document.add(new KnnByteVectorField("knnByteField3", new byte[] {1, 1, 1})); + + document.add(new KnnByteVectorField("knnByteField1", new byte[] {1, 2, 3})); + document.add(new KnnByteVectorField("knnByteField2", new byte[] {4, 2, 3})); + iw.addDocument(document); + + Document document2 = new Document(); + document2.add(new StringField("id", "2", Field.Store.NO)); + document2.add(new SortedDocValuesField("id", new BytesRef("2"))); + document2.add(new KnnFloatVectorField("knnFloatField1", new float[] {1.f, 2.f, 3.f})); + document2.add(new KnnFloatVectorField("knnFloatField2", new float[] {5.2f, 3.2f, 3.1f})); + + document2.add(new KnnByteVectorField("knnByteField1", new byte[] {1, 2, 3})); + document2.add(new KnnByteVectorField("knnByteField2", new byte[] {4, 2, 3})); + iw.addDocument(document2); reader = iw.getReader(); searcher = newSearcher(reader); @@ -89,80 +106,108 @@ public static void afterClass() throws Exception { analyzer = null; } - public void testFloatVectorSimilarityFunctionConst() throws Exception { - ValueSource v1 = new FloatVectorValueSource(List.of(1, 2, 3)); - ValueSource v2 = new FloatVectorValueSource(List.of(5, 4, 1)); + @Test + public void floatVectorSimilarityBetweenConstVector_shouldReturnCorrectResult() throws Exception { + var v1 = new FloatVectorValueSource(List.of(1, 2, 3)); + var v2 = new FloatVectorValueSource(List.of(5, 4, 1)); assertHits( new FunctionQuery( new FloatVectorSimilarityFunction(VectorSimilarityFunction.EUCLIDEAN, v1, v2)), new float[] {0.04f, 0.04f}); } - public void testByteVectorSimilarityFunctionConst() throws Exception { - ValueSource v1 = new ByteVectorValueSource(List.of(1, 2, 3)); - ValueSource v2 = new ByteVectorValueSource(List.of(2, 5, 6)); + @Test + public void byteVectorSimilarityBetweenConstVector_shouldReturnCorrectResult() throws Exception { + var v1 = new ByteVectorValueSource(List.of(1, 2, 3)); + var v2 = new ByteVectorValueSource(List.of(2, 5, 6)); assertHits( new FunctionQuery( new ByteVectorSimilarityFunction(VectorSimilarityFunction.EUCLIDEAN, v1, v2)), new float[] {0.05f, 0.05f}); } - public void testFloatVectorSimilarityFunctionField() throws Exception { - ValueSource v1 = new FloatVectorFieldSource("knnFloatField1"); - ValueSource v2 = new FloatVectorFieldSource("knnFloatField2"); + @Test + public void floatVectorSimilarityBetweenVectorFields_shouldReturnCorrectResult() throws Exception { + var v1 = new FloatVectorFieldSource("knnFloatField1"); + var v2 = new FloatVectorFieldSource("knnFloatField2"); assertHits( new FunctionQuery( new FloatVectorSimilarityFunction(VectorSimilarityFunction.EUCLIDEAN, v1, v2)), new float[] {0.049776014f, 0.049776014f}); } - public void testByteVectorSimilarityFunctionField() throws Exception { - ValueSource v1 = new ByteVectorFieldSource("knnByteField1"); - ValueSource v2 = new ByteVectorFieldSource("knnByteField2"); + @Test + public void byteVectorSimilarityBetweenVectorFields_shouldReturnCorrectResult() throws Exception { + var v1 = new ByteVectorFieldSource("knnByteField1"); + var v2 = new ByteVectorFieldSource("knnByteField2"); assertHits( new FunctionQuery( new ByteVectorSimilarityFunction(VectorSimilarityFunction.EUCLIDEAN, v1, v2)), new float[] {0.1f, 0.1f}); } - public void testFloatVectorSimilarityFunctionMixed() throws Exception { - ValueSource v1 = new FloatVectorValueSource(List.of(1, 2, 4)); - ValueSource v2 = new FloatVectorFieldSource("knnFloatField1"); + @Test + public void floatVectorSimilarityBetweenConstAndVectorField_shouldReturnCorrectResult() throws Exception { + var v1 = new FloatVectorValueSource(List.of(1, 2, 4)); + var v2 = new FloatVectorFieldSource("knnFloatField1"); assertHits( new FunctionQuery( new FloatVectorSimilarityFunction(VectorSimilarityFunction.EUCLIDEAN, v1, v2)), new float[] {0.5f, 0.5f}); } - public void testByteVectorSimilarityFunctionMixed() throws Exception { - ValueSource v1 = new ByteVectorValueSource(List.of(1, 2, 4)); - ValueSource v2 = new ByteVectorFieldSource("knnByteField1"); + @Test + public void byteVectorSimilarityBetweenConstAndVectorField_shouldReturnCorrectResult() throws Exception { + var v1 = new ByteVectorValueSource(List.of(1, 2, 4)); + var v2 = new ByteVectorFieldSource("knnByteField1"); assertHits( new FunctionQuery( new ByteVectorSimilarityFunction(VectorSimilarityFunction.EUCLIDEAN, v1, v2)), new float[] {0.5f, 0.5f}); } - public void testDismatchDimension() { + @Test + public void floatVectorSimilarityComputedOnDocumentWithMissingFieldVector_shouldReturnNaN() throws Exception { + var v1 = new FloatVectorValueSource(List.of(2.0, 1.0, 1.0)); + var v2 = new FloatVectorFieldSource("knnFloatField3"); + assertHits( + new FunctionQuery( + new FloatVectorSimilarityFunction(VectorSimilarityFunction.EUCLIDEAN, v1, v2)), + new float[] {0.5f, Float.NaN}); + } + + @Test + public void byteVectorSimilarityComputedOnDocumentWithMissingFieldVector_shouldReturnNaN() throws Exception { + var v1 = new ByteVectorValueSource(List.of(2.0, 1.0, 1.0)); + var v2 = new ByteVectorFieldSource("knnByteField3"); + assertHits( + new FunctionQuery( + new ByteVectorSimilarityFunction(VectorSimilarityFunction.EUCLIDEAN, v1, v2)), + new float[] {0.5f, Float.NaN}); + } + + @Test + public void vectorSimilarityBetweenTwoVectorsWithDifferentDimensions_shouldRaiseException() { ValueSource v1 = new ByteVectorValueSource(List.of(1, 2, 3, 4)); ValueSource v2 = new ByteVectorFieldSource("knnByteField1"); ByteVectorSimilarityFunction byteDenseVectorSimilarityFunction = - new ByteVectorSimilarityFunction(VectorSimilarityFunction.EUCLIDEAN, v1, v2); + new ByteVectorSimilarityFunction(VectorSimilarityFunction.EUCLIDEAN, v1, v2); assertThrows( - AssertionError.class, - () -> searcher.search(new FunctionQuery(byteDenseVectorSimilarityFunction), 10)); + AssertionError.class, + () -> searcher.search(new FunctionQuery(byteDenseVectorSimilarityFunction), 10)); v1 = new FloatVectorValueSource(List.of(1, 2)); v2 = new FloatVectorFieldSource("knnFloatField1"); FloatVectorSimilarityFunction floatDenseVectorSimilarityFunction = - new FloatVectorSimilarityFunction(VectorSimilarityFunction.EUCLIDEAN, v1, v2); + new FloatVectorSimilarityFunction(VectorSimilarityFunction.EUCLIDEAN, v1, v2); assertThrows( AssertionError.class, () -> searcher.search(new FunctionQuery(floatDenseVectorSimilarityFunction), 10)); } - public void testMismatchType() { - ValueSource v1 = new ByteVectorValueSource(List.of(1, 2, 3)); + @Test + public void vectorSimilarityBetweenByteAndFloatVectors_shouldRaiseException() { + var v1 = new ByteVectorValueSource(List.of(1, 2, 3)); ValueSource v2 = new ByteVectorFieldSource("knnByteField1"); FloatVectorSimilarityFunction floatDenseVectorSimilarityFunction = new FloatVectorSimilarityFunction(VectorSimilarityFunction.EUCLIDEAN, v1, v2); @@ -171,7 +216,7 @@ public void testMismatchType() { () -> searcher.search(new FunctionQuery(floatDenseVectorSimilarityFunction), 10)); v1 = new ByteVectorValueSource(List.of(1, 2, 3)); - v2 = new FloatVectorFieldSource("knnByteField1"); + v2 = new FloatVectorFieldSource("knnFloatField1"); ByteVectorSimilarityFunction byteDenseVectorSimilarityFunction = new ByteVectorSimilarityFunction(VectorSimilarityFunction.EUCLIDEAN, v1, v2); assertThrows( @@ -179,7 +224,28 @@ public void testMismatchType() { () -> searcher.search(new FunctionQuery(byteDenseVectorSimilarityFunction), 10)); } - public static void assertHits(Query q, float[] scores) throws Exception { + @Test + public void vectorFielValueSourceWithIncorrectFields_shouldRaiseException() { + ValueSource v1 = new ByteVectorFieldSource("knnByteField1"); + ValueSource v2 = new ByteVectorFieldSource("knnFloatField2"); + ByteVectorSimilarityFunction byteDenseVectorSimilarityFunction = + new ByteVectorSimilarityFunction(VectorSimilarityFunction.EUCLIDEAN, v1, v2); + + assertThrows( + IllegalArgumentException.class, + () -> searcher.search(new FunctionQuery(byteDenseVectorSimilarityFunction), 10)); + + v1 = new FloatVectorFieldSource("knnByteField1"); + v2 = new FloatVectorFieldSource("knnFloatField2"); + FloatVectorSimilarityFunction floatVectorSimilarityFunction = + new FloatVectorSimilarityFunction(VectorSimilarityFunction.EUCLIDEAN, v1, v2); + + assertThrows( + IllegalArgumentException.class, + () -> searcher.search(new FunctionQuery(floatVectorSimilarityFunction), 10)); + } + + private static void assertHits(Query q, float[] scores) throws Exception { ScoreDoc[] expected = new ScoreDoc[scores.length]; int[] expectedDocs = new int[scores.length]; for (int i = 0; i < expected.length; i++) { From 363aea0d8c14233367abb3b696109e8991ac6a0e Mon Sep 17 00:00:00 2001 From: Elia Date: Thu, 25 May 2023 14:23:05 +0200 Subject: [PATCH 24/35] spotless apply --- .../valuesource/ByteVectorFieldSource.java | 7 +-- .../valuesource/FloatVectorFieldSource.java | 3 +- .../valuesource/VectorFieldFunction.java | 1 - .../TestKnnVectorSimilarityFunctions.java | 51 ++++++++++--------- 4 files changed, 31 insertions(+), 31 deletions(-) diff --git a/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/ByteVectorFieldSource.java b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/ByteVectorFieldSource.java index a15d281fc097..84bd2dd053a5 100644 --- a/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/ByteVectorFieldSource.java +++ b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/ByteVectorFieldSource.java @@ -17,7 +17,6 @@ package org.apache.lucene.queries.function.valuesource; import java.io.IOException; -import java.util.Arrays; import java.util.Map; import org.apache.lucene.index.ByteVectorValues; import org.apache.lucene.index.LeafReaderContext; @@ -25,8 +24,6 @@ import org.apache.lucene.queries.function.ValueSource; import org.apache.lucene.search.DocIdSetIterator; -import static java.util.Optional.ofNullable; - /** * An implementation for retrieving {@link FunctionValues} instances for byte knn vectors fields. */ @@ -44,7 +41,8 @@ public FunctionValues getValues(Map context, LeafReaderContext r final ByteVectorValues vectorValues = readerContext.reader().getByteVectorValues(fieldName); if (vectorValues == null) { - throw new IllegalArgumentException("no byte vector value is indexed for field '" + fieldName + "'"); + throw new IllegalArgumentException( + "no byte vector value is indexed for field '" + fieldName + "'"); } return new VectorFieldFunction(this) { @@ -62,7 +60,6 @@ public byte[] byteVectorVal(int doc) throws IOException { protected DocIdSetIterator getVectorIterator() { return vectorValues; } - }; } diff --git a/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/FloatVectorFieldSource.java b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/FloatVectorFieldSource.java index ad2eb49a7b40..7d0dafdb6a8c 100644 --- a/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/FloatVectorFieldSource.java +++ b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/FloatVectorFieldSource.java @@ -41,7 +41,8 @@ public FunctionValues getValues(Map context, LeafReaderContext r final FloatVectorValues vectorValues = readerContext.reader().getFloatVectorValues(fieldName); if (vectorValues == null) { - throw new IllegalArgumentException("no float vector value is indexed for field '" + fieldName + "'"); + throw new IllegalArgumentException( + "no float vector value is indexed for field '" + fieldName + "'"); } return new VectorFieldFunction(this) { diff --git a/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/VectorFieldFunction.java b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/VectorFieldFunction.java index e1bc54f3e712..4488a4e40258 100644 --- a/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/VectorFieldFunction.java +++ b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/VectorFieldFunction.java @@ -38,7 +38,6 @@ public String toString(int doc) throws IOException { return vs.description() + strVal(doc); } - @Override public boolean exists(int doc) throws IOException { if (doc < lastDocID) { diff --git a/lucene/queries/src/test/org/apache/lucene/queries/function/TestKnnVectorSimilarityFunctions.java b/lucene/queries/src/test/org/apache/lucene/queries/function/TestKnnVectorSimilarityFunctions.java index a3294004ceee..496867e84549 100644 --- a/lucene/queries/src/test/org/apache/lucene/queries/function/TestKnnVectorSimilarityFunctions.java +++ b/lucene/queries/src/test/org/apache/lucene/queries/function/TestKnnVectorSimilarityFunctions.java @@ -16,7 +16,7 @@ */ package org.apache.lucene.queries.function; -import com.sun.jdi.Value; +import java.util.List; import org.apache.lucene.analysis.Analyzer; import org.apache.lucene.document.Document; import org.apache.lucene.document.Field; @@ -49,8 +49,6 @@ import org.junit.BeforeClass; import org.junit.Test; -import java.util.List; - public class TestKnnVectorSimilarityFunctions extends LuceneTestCase { static Directory dir; static Analyzer analyzer; @@ -127,7 +125,8 @@ public void byteVectorSimilarityBetweenConstVector_shouldReturnCorrectResult() t } @Test - public void floatVectorSimilarityBetweenVectorFields_shouldReturnCorrectResult() throws Exception { + public void floatVectorSimilarityBetweenVectorFields_shouldReturnCorrectResult() + throws Exception { var v1 = new FloatVectorFieldSource("knnFloatField1"); var v2 = new FloatVectorFieldSource("knnFloatField2"); assertHits( @@ -147,7 +146,8 @@ public void byteVectorSimilarityBetweenVectorFields_shouldReturnCorrectResult() } @Test - public void floatVectorSimilarityBetweenConstAndVectorField_shouldReturnCorrectResult() throws Exception { + public void floatVectorSimilarityBetweenConstAndVectorField_shouldReturnCorrectResult() + throws Exception { var v1 = new FloatVectorValueSource(List.of(1, 2, 4)); var v2 = new FloatVectorFieldSource("knnFloatField1"); assertHits( @@ -157,7 +157,8 @@ public void floatVectorSimilarityBetweenConstAndVectorField_shouldReturnCorrectR } @Test - public void byteVectorSimilarityBetweenConstAndVectorField_shouldReturnCorrectResult() throws Exception { + public void byteVectorSimilarityBetweenConstAndVectorField_shouldReturnCorrectResult() + throws Exception { var v1 = new ByteVectorValueSource(List.of(1, 2, 4)); var v2 = new ByteVectorFieldSource("knnByteField1"); assertHits( @@ -167,23 +168,25 @@ public void byteVectorSimilarityBetweenConstAndVectorField_shouldReturnCorrectRe } @Test - public void floatVectorSimilarityComputedOnDocumentWithMissingFieldVector_shouldReturnNaN() throws Exception { + public void floatVectorSimilarityComputedOnDocumentWithMissingFieldVector_shouldReturnNaN() + throws Exception { var v1 = new FloatVectorValueSource(List.of(2.0, 1.0, 1.0)); var v2 = new FloatVectorFieldSource("knnFloatField3"); assertHits( - new FunctionQuery( - new FloatVectorSimilarityFunction(VectorSimilarityFunction.EUCLIDEAN, v1, v2)), - new float[] {0.5f, Float.NaN}); + new FunctionQuery( + new FloatVectorSimilarityFunction(VectorSimilarityFunction.EUCLIDEAN, v1, v2)), + new float[] {0.5f, Float.NaN}); } @Test - public void byteVectorSimilarityComputedOnDocumentWithMissingFieldVector_shouldReturnNaN() throws Exception { + public void byteVectorSimilarityComputedOnDocumentWithMissingFieldVector_shouldReturnNaN() + throws Exception { var v1 = new ByteVectorValueSource(List.of(2.0, 1.0, 1.0)); var v2 = new ByteVectorFieldSource("knnByteField3"); assertHits( - new FunctionQuery( - new ByteVectorSimilarityFunction(VectorSimilarityFunction.EUCLIDEAN, v1, v2)), - new float[] {0.5f, Float.NaN}); + new FunctionQuery( + new ByteVectorSimilarityFunction(VectorSimilarityFunction.EUCLIDEAN, v1, v2)), + new float[] {0.5f, Float.NaN}); } @Test @@ -191,15 +194,15 @@ public void vectorSimilarityBetweenTwoVectorsWithDifferentDimensions_shouldRaise ValueSource v1 = new ByteVectorValueSource(List.of(1, 2, 3, 4)); ValueSource v2 = new ByteVectorFieldSource("knnByteField1"); ByteVectorSimilarityFunction byteDenseVectorSimilarityFunction = - new ByteVectorSimilarityFunction(VectorSimilarityFunction.EUCLIDEAN, v1, v2); + new ByteVectorSimilarityFunction(VectorSimilarityFunction.EUCLIDEAN, v1, v2); assertThrows( - AssertionError.class, - () -> searcher.search(new FunctionQuery(byteDenseVectorSimilarityFunction), 10)); + AssertionError.class, + () -> searcher.search(new FunctionQuery(byteDenseVectorSimilarityFunction), 10)); v1 = new FloatVectorValueSource(List.of(1, 2)); v2 = new FloatVectorFieldSource("knnFloatField1"); FloatVectorSimilarityFunction floatDenseVectorSimilarityFunction = - new FloatVectorSimilarityFunction(VectorSimilarityFunction.EUCLIDEAN, v1, v2); + new FloatVectorSimilarityFunction(VectorSimilarityFunction.EUCLIDEAN, v1, v2); assertThrows( AssertionError.class, () -> searcher.search(new FunctionQuery(floatDenseVectorSimilarityFunction), 10)); @@ -229,20 +232,20 @@ public void vectorFielValueSourceWithIncorrectFields_shouldRaiseException() { ValueSource v1 = new ByteVectorFieldSource("knnByteField1"); ValueSource v2 = new ByteVectorFieldSource("knnFloatField2"); ByteVectorSimilarityFunction byteDenseVectorSimilarityFunction = - new ByteVectorSimilarityFunction(VectorSimilarityFunction.EUCLIDEAN, v1, v2); + new ByteVectorSimilarityFunction(VectorSimilarityFunction.EUCLIDEAN, v1, v2); assertThrows( - IllegalArgumentException.class, - () -> searcher.search(new FunctionQuery(byteDenseVectorSimilarityFunction), 10)); + IllegalArgumentException.class, + () -> searcher.search(new FunctionQuery(byteDenseVectorSimilarityFunction), 10)); v1 = new FloatVectorFieldSource("knnByteField1"); v2 = new FloatVectorFieldSource("knnFloatField2"); FloatVectorSimilarityFunction floatVectorSimilarityFunction = - new FloatVectorSimilarityFunction(VectorSimilarityFunction.EUCLIDEAN, v1, v2); + new FloatVectorSimilarityFunction(VectorSimilarityFunction.EUCLIDEAN, v1, v2); assertThrows( - IllegalArgumentException.class, - () -> searcher.search(new FunctionQuery(floatVectorSimilarityFunction), 10)); + IllegalArgumentException.class, + () -> searcher.search(new FunctionQuery(floatVectorSimilarityFunction), 10)); } private static void assertHits(Query q, float[] scores) throws Exception { From 108c4a8f0b40a2ea36ab5e5a54fd807c07fb882e Mon Sep 17 00:00:00 2001 From: Elia Date: Thu, 25 May 2023 19:34:07 +0200 Subject: [PATCH 25/35] fix typo --- .../queries/function/TestKnnVectorSimilarityFunctions.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lucene/queries/src/test/org/apache/lucene/queries/function/TestKnnVectorSimilarityFunctions.java b/lucene/queries/src/test/org/apache/lucene/queries/function/TestKnnVectorSimilarityFunctions.java index 496867e84549..20f13a9f9d98 100644 --- a/lucene/queries/src/test/org/apache/lucene/queries/function/TestKnnVectorSimilarityFunctions.java +++ b/lucene/queries/src/test/org/apache/lucene/queries/function/TestKnnVectorSimilarityFunctions.java @@ -228,7 +228,7 @@ public void vectorSimilarityBetweenByteAndFloatVectors_shouldRaiseException() { } @Test - public void vectorFielValueSourceWithIncorrectFields_shouldRaiseException() { + public void vectorFieldValueSourceWithDifferentFieldType_shouldRaiseException() { ValueSource v1 = new ByteVectorFieldSource("knnByteField1"); ValueSource v2 = new ByteVectorFieldSource("knnFloatField2"); ByteVectorSimilarityFunction byteDenseVectorSimilarityFunction = From d95c6c16efb04bfef608c90dec4c6a1441bd4593 Mon Sep 17 00:00:00 2001 From: Elia Date: Tue, 13 Jun 2023 11:13:57 +0200 Subject: [PATCH 26/35] Renaming classes --- ...e.java => ConstByteVectorValueSource.java} | 8 +++--- ....java => ConstFloatVectorValueSource.java} | 8 +++--- .../TestKnnVectorSimilarityFunctions.java | 28 +++++++++---------- 3 files changed, 22 insertions(+), 22 deletions(-) rename lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/{ByteVectorValueSource.java => ConstByteVectorValueSource.java} (89%) rename lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/{FloatVectorValueSource.java => ConstFloatVectorValueSource.java} (89%) diff --git a/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/ByteVectorValueSource.java b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/ConstByteVectorValueSource.java similarity index 89% rename from lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/ByteVectorValueSource.java rename to lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/ConstByteVectorValueSource.java index 88f3184f3e88..9807a17e409a 100644 --- a/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/ByteVectorValueSource.java +++ b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/ConstByteVectorValueSource.java @@ -25,10 +25,10 @@ import org.apache.lucene.queries.function.ValueSource; /** Function that returns a constant byte vector value for every document. */ -public class ByteVectorValueSource extends ValueSource { +public class ConstByteVectorValueSource extends ValueSource { byte[] vector; - public ByteVectorValueSource(List constVector) { + public ConstByteVectorValueSource(List constVector) { this.vector = new byte[constVector.size()]; for (int i = 0; i < constVector.size(); i++) { vector[i] = constVector.get(i).byteValue(); @@ -58,8 +58,8 @@ public String toString(int doc) throws IOException { @Override public boolean equals(Object o) { - if (!(o instanceof ByteVectorValueSource)) return false; - ByteVectorValueSource other = (ByteVectorValueSource) o; + if (!(o instanceof ConstByteVectorValueSource)) return false; + ConstByteVectorValueSource other = (ConstByteVectorValueSource) o; return Arrays.equals(vector, other.vector); } diff --git a/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/FloatVectorValueSource.java b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/ConstFloatVectorValueSource.java similarity index 89% rename from lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/FloatVectorValueSource.java rename to lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/ConstFloatVectorValueSource.java index 825157d05c9b..b3caf83b2664 100644 --- a/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/FloatVectorValueSource.java +++ b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/ConstFloatVectorValueSource.java @@ -25,10 +25,10 @@ import org.apache.lucene.queries.function.ValueSource; /** Function that returns a constant float vector value for every document. */ -public class FloatVectorValueSource extends ValueSource { +public class ConstFloatVectorValueSource extends ValueSource { float[] vector; - public FloatVectorValueSource(List constVector) { + public ConstFloatVectorValueSource(List constVector) { this.vector = new float[constVector.size()]; for (int i = 0; i < constVector.size(); i++) { vector[i] = constVector.get(i).floatValue(); @@ -58,8 +58,8 @@ public String toString(int doc) throws IOException { @Override public boolean equals(Object o) { - if (!(o instanceof FloatVectorValueSource)) return false; - FloatVectorValueSource other = (FloatVectorValueSource) o; + if (!(o instanceof ConstFloatVectorValueSource)) return false; + ConstFloatVectorValueSource other = (ConstFloatVectorValueSource) o; return Arrays.equals(vector, other.vector); } diff --git a/lucene/queries/src/test/org/apache/lucene/queries/function/TestKnnVectorSimilarityFunctions.java b/lucene/queries/src/test/org/apache/lucene/queries/function/TestKnnVectorSimilarityFunctions.java index 20f13a9f9d98..60e684c3881e 100644 --- a/lucene/queries/src/test/org/apache/lucene/queries/function/TestKnnVectorSimilarityFunctions.java +++ b/lucene/queries/src/test/org/apache/lucene/queries/function/TestKnnVectorSimilarityFunctions.java @@ -29,10 +29,10 @@ import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.queries.function.valuesource.ByteVectorFieldSource; import org.apache.lucene.queries.function.valuesource.ByteVectorSimilarityFunction; -import org.apache.lucene.queries.function.valuesource.ByteVectorValueSource; +import org.apache.lucene.queries.function.valuesource.ConstByteVectorValueSource; +import org.apache.lucene.queries.function.valuesource.ConstFloatVectorValueSource; import org.apache.lucene.queries.function.valuesource.FloatVectorFieldSource; import org.apache.lucene.queries.function.valuesource.FloatVectorSimilarityFunction; -import org.apache.lucene.queries.function.valuesource.FloatVectorValueSource; import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.Query; import org.apache.lucene.search.ScoreDoc; @@ -106,8 +106,8 @@ public static void afterClass() throws Exception { @Test public void floatVectorSimilarityBetweenConstVector_shouldReturnCorrectResult() throws Exception { - var v1 = new FloatVectorValueSource(List.of(1, 2, 3)); - var v2 = new FloatVectorValueSource(List.of(5, 4, 1)); + var v1 = new ConstFloatVectorValueSource(List.of(1, 2, 3)); + var v2 = new ConstFloatVectorValueSource(List.of(5, 4, 1)); assertHits( new FunctionQuery( new FloatVectorSimilarityFunction(VectorSimilarityFunction.EUCLIDEAN, v1, v2)), @@ -116,8 +116,8 @@ public void floatVectorSimilarityBetweenConstVector_shouldReturnCorrectResult() @Test public void byteVectorSimilarityBetweenConstVector_shouldReturnCorrectResult() throws Exception { - var v1 = new ByteVectorValueSource(List.of(1, 2, 3)); - var v2 = new ByteVectorValueSource(List.of(2, 5, 6)); + var v1 = new ConstByteVectorValueSource(List.of(1, 2, 3)); + var v2 = new ConstByteVectorValueSource(List.of(2, 5, 6)); assertHits( new FunctionQuery( new ByteVectorSimilarityFunction(VectorSimilarityFunction.EUCLIDEAN, v1, v2)), @@ -148,7 +148,7 @@ public void byteVectorSimilarityBetweenVectorFields_shouldReturnCorrectResult() @Test public void floatVectorSimilarityBetweenConstAndVectorField_shouldReturnCorrectResult() throws Exception { - var v1 = new FloatVectorValueSource(List.of(1, 2, 4)); + var v1 = new ConstFloatVectorValueSource(List.of(1, 2, 4)); var v2 = new FloatVectorFieldSource("knnFloatField1"); assertHits( new FunctionQuery( @@ -159,7 +159,7 @@ public void floatVectorSimilarityBetweenConstAndVectorField_shouldReturnCorrectR @Test public void byteVectorSimilarityBetweenConstAndVectorField_shouldReturnCorrectResult() throws Exception { - var v1 = new ByteVectorValueSource(List.of(1, 2, 4)); + var v1 = new ConstByteVectorValueSource(List.of(1, 2, 4)); var v2 = new ByteVectorFieldSource("knnByteField1"); assertHits( new FunctionQuery( @@ -170,7 +170,7 @@ public void byteVectorSimilarityBetweenConstAndVectorField_shouldReturnCorrectRe @Test public void floatVectorSimilarityComputedOnDocumentWithMissingFieldVector_shouldReturnNaN() throws Exception { - var v1 = new FloatVectorValueSource(List.of(2.0, 1.0, 1.0)); + var v1 = new ConstFloatVectorValueSource(List.of(2.0, 1.0, 1.0)); var v2 = new FloatVectorFieldSource("knnFloatField3"); assertHits( new FunctionQuery( @@ -181,7 +181,7 @@ public void floatVectorSimilarityComputedOnDocumentWithMissingFieldVector_should @Test public void byteVectorSimilarityComputedOnDocumentWithMissingFieldVector_shouldReturnNaN() throws Exception { - var v1 = new ByteVectorValueSource(List.of(2.0, 1.0, 1.0)); + var v1 = new ConstByteVectorValueSource(List.of(2.0, 1.0, 1.0)); var v2 = new ByteVectorFieldSource("knnByteField3"); assertHits( new FunctionQuery( @@ -191,7 +191,7 @@ public void byteVectorSimilarityComputedOnDocumentWithMissingFieldVector_shouldR @Test public void vectorSimilarityBetweenTwoVectorsWithDifferentDimensions_shouldRaiseException() { - ValueSource v1 = new ByteVectorValueSource(List.of(1, 2, 3, 4)); + ValueSource v1 = new ConstByteVectorValueSource(List.of(1, 2, 3, 4)); ValueSource v2 = new ByteVectorFieldSource("knnByteField1"); ByteVectorSimilarityFunction byteDenseVectorSimilarityFunction = new ByteVectorSimilarityFunction(VectorSimilarityFunction.EUCLIDEAN, v1, v2); @@ -199,7 +199,7 @@ public void vectorSimilarityBetweenTwoVectorsWithDifferentDimensions_shouldRaise AssertionError.class, () -> searcher.search(new FunctionQuery(byteDenseVectorSimilarityFunction), 10)); - v1 = new FloatVectorValueSource(List.of(1, 2)); + v1 = new ConstFloatVectorValueSource(List.of(1, 2)); v2 = new FloatVectorFieldSource("knnFloatField1"); FloatVectorSimilarityFunction floatDenseVectorSimilarityFunction = new FloatVectorSimilarityFunction(VectorSimilarityFunction.EUCLIDEAN, v1, v2); @@ -210,7 +210,7 @@ public void vectorSimilarityBetweenTwoVectorsWithDifferentDimensions_shouldRaise @Test public void vectorSimilarityBetweenByteAndFloatVectors_shouldRaiseException() { - var v1 = new ByteVectorValueSource(List.of(1, 2, 3)); + var v1 = new ConstByteVectorValueSource(List.of(1, 2, 3)); ValueSource v2 = new ByteVectorFieldSource("knnByteField1"); FloatVectorSimilarityFunction floatDenseVectorSimilarityFunction = new FloatVectorSimilarityFunction(VectorSimilarityFunction.EUCLIDEAN, v1, v2); @@ -218,7 +218,7 @@ public void vectorSimilarityBetweenByteAndFloatVectors_shouldRaiseException() { UnsupportedOperationException.class, () -> searcher.search(new FunctionQuery(floatDenseVectorSimilarityFunction), 10)); - v1 = new ByteVectorValueSource(List.of(1, 2, 3)); + v1 = new ConstByteVectorValueSource(List.of(1, 2, 3)); v2 = new FloatVectorFieldSource("knnFloatField1"); ByteVectorSimilarityFunction byteDenseVectorSimilarityFunction = new ByteVectorSimilarityFunction(VectorSimilarityFunction.EUCLIDEAN, v1, v2); From 69ce94de196fdac379a9c3e87df06d8373f44671 Mon Sep 17 00:00:00 2001 From: Elia Date: Tue, 13 Jun 2023 14:03:41 +0200 Subject: [PATCH 27/35] code refactoring --- ...rce.java => ByteKnnVectorFieldSource.java} | 10 +-- ...urce.java => ConsKnnFloatValueSource.java} | 10 +-- ...ava => ConstKnnByteVectorValueSource.java} | 10 +-- ...ce.java => FloatKnnVectorFieldSource.java} | 10 +-- .../TestKnnVectorSimilarityFunctions.java | 89 +++++++++---------- 5 files changed, 63 insertions(+), 66 deletions(-) rename lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/{ByteVectorFieldSource.java => ByteKnnVectorFieldSource.java} (88%) rename lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/{ConstFloatVectorValueSource.java => ConsKnnFloatValueSource.java} (86%) rename lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/{ConstByteVectorValueSource.java => ConstKnnByteVectorValueSource.java} (85%) rename lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/{FloatVectorFieldSource.java => FloatKnnVectorFieldSource.java} (88%) diff --git a/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/ByteVectorFieldSource.java b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/ByteKnnVectorFieldSource.java similarity index 88% rename from lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/ByteVectorFieldSource.java rename to lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/ByteKnnVectorFieldSource.java index 84bd2dd053a5..83cc49c772eb 100644 --- a/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/ByteVectorFieldSource.java +++ b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/ByteKnnVectorFieldSource.java @@ -27,10 +27,10 @@ /** * An implementation for retrieving {@link FunctionValues} instances for byte knn vectors fields. */ -public class ByteVectorFieldSource extends ValueSource { +public class ByteKnnVectorFieldSource extends ValueSource { private final String fieldName; - public ByteVectorFieldSource(String fieldName) { + public ByteKnnVectorFieldSource(String fieldName) { this.fieldName = fieldName; } @@ -65,8 +65,8 @@ protected DocIdSetIterator getVectorIterator() { @Override public boolean equals(Object o) { - if (o.getClass() != ByteVectorFieldSource.class) return false; - ByteVectorFieldSource other = (ByteVectorFieldSource) o; + if (o.getClass() != ByteKnnVectorFieldSource.class) return false; + ByteKnnVectorFieldSource other = (ByteKnnVectorFieldSource) o; return fieldName.equals(other.fieldName); } @@ -77,6 +77,6 @@ public int hashCode() { @Override public String description() { - return "denseByteVectorField(" + fieldName + ")"; + return "ByteKnnVectorFieldSource(" + fieldName + ")"; } } diff --git a/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/ConstFloatVectorValueSource.java b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/ConsKnnFloatValueSource.java similarity index 86% rename from lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/ConstFloatVectorValueSource.java rename to lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/ConsKnnFloatValueSource.java index b3caf83b2664..45871d2836be 100644 --- a/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/ConstFloatVectorValueSource.java +++ b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/ConsKnnFloatValueSource.java @@ -25,10 +25,10 @@ import org.apache.lucene.queries.function.ValueSource; /** Function that returns a constant float vector value for every document. */ -public class ConstFloatVectorValueSource extends ValueSource { +public class ConsKnnFloatValueSource extends ValueSource { float[] vector; - public ConstFloatVectorValueSource(List constVector) { + public ConsKnnFloatValueSource(List constVector) { this.vector = new float[constVector.size()]; for (int i = 0; i < constVector.size(); i++) { vector[i] = constVector.get(i).floatValue(); @@ -58,8 +58,8 @@ public String toString(int doc) throws IOException { @Override public boolean equals(Object o) { - if (!(o instanceof ConstFloatVectorValueSource)) return false; - ConstFloatVectorValueSource other = (ConstFloatVectorValueSource) o; + if (!(o instanceof ConsKnnFloatValueSource)) return false; + ConsKnnFloatValueSource other = (ConsKnnFloatValueSource) o; return Arrays.equals(vector, other.vector); } @@ -70,6 +70,6 @@ public int hashCode() { @Override public String description() { - return "denseVectorConst(" + Arrays.toString(vector) + ')'; + return "ConsKnnFloatValueSource(" + Arrays.toString(vector) + ')'; } } diff --git a/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/ConstByteVectorValueSource.java b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/ConstKnnByteVectorValueSource.java similarity index 85% rename from lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/ConstByteVectorValueSource.java rename to lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/ConstKnnByteVectorValueSource.java index 9807a17e409a..79115eecaa6f 100644 --- a/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/ConstByteVectorValueSource.java +++ b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/ConstKnnByteVectorValueSource.java @@ -25,10 +25,10 @@ import org.apache.lucene.queries.function.ValueSource; /** Function that returns a constant byte vector value for every document. */ -public class ConstByteVectorValueSource extends ValueSource { +public class ConstKnnByteVectorValueSource extends ValueSource { byte[] vector; - public ConstByteVectorValueSource(List constVector) { + public ConstKnnByteVectorValueSource(List constVector) { this.vector = new byte[constVector.size()]; for (int i = 0; i < constVector.size(); i++) { vector[i] = constVector.get(i).byteValue(); @@ -58,8 +58,8 @@ public String toString(int doc) throws IOException { @Override public boolean equals(Object o) { - if (!(o instanceof ConstByteVectorValueSource)) return false; - ConstByteVectorValueSource other = (ConstByteVectorValueSource) o; + if (!(o instanceof ConstKnnByteVectorValueSource)) return false; + ConstKnnByteVectorValueSource other = (ConstKnnByteVectorValueSource) o; return Arrays.equals(vector, other.vector); } @@ -70,6 +70,6 @@ public int hashCode() { @Override public String description() { - return "denseVectorConst(" + Arrays.toString(vector) + ')'; + return "ConstKnnByteVectorValueSource(" + Arrays.toString(vector) + ')'; } } diff --git a/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/FloatVectorFieldSource.java b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/FloatKnnVectorFieldSource.java similarity index 88% rename from lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/FloatVectorFieldSource.java rename to lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/FloatKnnVectorFieldSource.java index 7d0dafdb6a8c..64def5ed67fb 100644 --- a/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/FloatVectorFieldSource.java +++ b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/FloatKnnVectorFieldSource.java @@ -27,10 +27,10 @@ /** * An implementation for retrieving {@link FunctionValues} instances for float knn vectors fields. */ -public class FloatVectorFieldSource extends ValueSource { +public class FloatKnnVectorFieldSource extends ValueSource { private final String fieldName; - public FloatVectorFieldSource(String fieldName) { + public FloatKnnVectorFieldSource(String fieldName) { this.fieldName = fieldName; } @@ -64,8 +64,8 @@ protected DocIdSetIterator getVectorIterator() { @Override public boolean equals(Object o) { - if (o.getClass() != FloatVectorFieldSource.class) return false; - FloatVectorFieldSource other = (FloatVectorFieldSource) o; + if (o.getClass() != FloatKnnVectorFieldSource.class) return false; + FloatKnnVectorFieldSource other = (FloatKnnVectorFieldSource) o; return fieldName.equals(other.fieldName); } @@ -76,6 +76,6 @@ public int hashCode() { @Override public String description() { - return "denseFloatVectorField(" + fieldName + ")"; + return "FloatKnnVectorFieldSource(" + fieldName + ")"; } } diff --git a/lucene/queries/src/test/org/apache/lucene/queries/function/TestKnnVectorSimilarityFunctions.java b/lucene/queries/src/test/org/apache/lucene/queries/function/TestKnnVectorSimilarityFunctions.java index 60e684c3881e..228d6d597c51 100644 --- a/lucene/queries/src/test/org/apache/lucene/queries/function/TestKnnVectorSimilarityFunctions.java +++ b/lucene/queries/src/test/org/apache/lucene/queries/function/TestKnnVectorSimilarityFunctions.java @@ -27,11 +27,11 @@ import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.IndexWriterConfig; import org.apache.lucene.index.VectorSimilarityFunction; -import org.apache.lucene.queries.function.valuesource.ByteVectorFieldSource; +import org.apache.lucene.queries.function.valuesource.ByteKnnVectorFieldSource; import org.apache.lucene.queries.function.valuesource.ByteVectorSimilarityFunction; -import org.apache.lucene.queries.function.valuesource.ConstByteVectorValueSource; -import org.apache.lucene.queries.function.valuesource.ConstFloatVectorValueSource; -import org.apache.lucene.queries.function.valuesource.FloatVectorFieldSource; +import org.apache.lucene.queries.function.valuesource.ConsKnnFloatValueSource; +import org.apache.lucene.queries.function.valuesource.ConstKnnByteVectorValueSource; +import org.apache.lucene.queries.function.valuesource.FloatKnnVectorFieldSource; import org.apache.lucene.queries.function.valuesource.FloatVectorSimilarityFunction; import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.Query; @@ -105,9 +105,9 @@ public static void afterClass() throws Exception { } @Test - public void floatVectorSimilarityBetweenConstVector_shouldReturnCorrectResult() throws Exception { - var v1 = new ConstFloatVectorValueSource(List.of(1, 2, 3)); - var v2 = new ConstFloatVectorValueSource(List.of(5, 4, 1)); + public void vectorSimilarity_floatConstantVectors_shouldReturnFloatSimilarity() throws Exception { + var v1 = new ConsKnnFloatValueSource(List.of(1, 2, 3)); + var v2 = new ConsKnnFloatValueSource(List.of(5, 4, 1)); assertHits( new FunctionQuery( new FloatVectorSimilarityFunction(VectorSimilarityFunction.EUCLIDEAN, v1, v2)), @@ -115,9 +115,9 @@ public void floatVectorSimilarityBetweenConstVector_shouldReturnCorrectResult() } @Test - public void byteVectorSimilarityBetweenConstVector_shouldReturnCorrectResult() throws Exception { - var v1 = new ConstByteVectorValueSource(List.of(1, 2, 3)); - var v2 = new ConstByteVectorValueSource(List.of(2, 5, 6)); + public void vectorSimilarity_byteConstantVectors_shouldReturnFloatSimilarity() throws Exception { + var v1 = new ConstKnnByteVectorValueSource(List.of(1, 2, 3)); + var v2 = new ConstKnnByteVectorValueSource(List.of(2, 5, 6)); assertHits( new FunctionQuery( new ByteVectorSimilarityFunction(VectorSimilarityFunction.EUCLIDEAN, v1, v2)), @@ -125,10 +125,9 @@ public void byteVectorSimilarityBetweenConstVector_shouldReturnCorrectResult() t } @Test - public void floatVectorSimilarityBetweenVectorFields_shouldReturnCorrectResult() - throws Exception { - var v1 = new FloatVectorFieldSource("knnFloatField1"); - var v2 = new FloatVectorFieldSource("knnFloatField2"); + public void vectorSimilarity_floatFieldVectors_shouldReturnFloatSimilarity() throws Exception { + var v1 = new FloatKnnVectorFieldSource("knnFloatField1"); + var v2 = new FloatKnnVectorFieldSource("knnFloatField2"); assertHits( new FunctionQuery( new FloatVectorSimilarityFunction(VectorSimilarityFunction.EUCLIDEAN, v1, v2)), @@ -136,9 +135,9 @@ public void floatVectorSimilarityBetweenVectorFields_shouldReturnCorrectResult() } @Test - public void byteVectorSimilarityBetweenVectorFields_shouldReturnCorrectResult() throws Exception { - var v1 = new ByteVectorFieldSource("knnByteField1"); - var v2 = new ByteVectorFieldSource("knnByteField2"); + public void vectorSimilarity_byteFieldVectors_shouldReturnFloatSimilarity() throws Exception { + var v1 = new ByteKnnVectorFieldSource("knnByteField1"); + var v2 = new ByteKnnVectorFieldSource("knnByteField2"); assertHits( new FunctionQuery( new ByteVectorSimilarityFunction(VectorSimilarityFunction.EUCLIDEAN, v1, v2)), @@ -146,10 +145,10 @@ public void byteVectorSimilarityBetweenVectorFields_shouldReturnCorrectResult() } @Test - public void floatVectorSimilarityBetweenConstAndVectorField_shouldReturnCorrectResult() + public void vectorSimilarity_FloatConstAndFloatFieldVectors_shouldReturnFloatSimilarity() throws Exception { - var v1 = new ConstFloatVectorValueSource(List.of(1, 2, 4)); - var v2 = new FloatVectorFieldSource("knnFloatField1"); + var v1 = new ConsKnnFloatValueSource(List.of(1, 2, 4)); + var v2 = new FloatKnnVectorFieldSource("knnFloatField1"); assertHits( new FunctionQuery( new FloatVectorSimilarityFunction(VectorSimilarityFunction.EUCLIDEAN, v1, v2)), @@ -157,10 +156,10 @@ public void floatVectorSimilarityBetweenConstAndVectorField_shouldReturnCorrectR } @Test - public void byteVectorSimilarityBetweenConstAndVectorField_shouldReturnCorrectResult() + public void vectorSimilarity_ByteConstAndByteFieldVectors_shouldReturnFloatSimilarity() throws Exception { - var v1 = new ConstByteVectorValueSource(List.of(1, 2, 4)); - var v2 = new ByteVectorFieldSource("knnByteField1"); + var v1 = new ConstKnnByteVectorValueSource(List.of(1, 2, 4)); + var v2 = new ByteKnnVectorFieldSource("knnByteField1"); assertHits( new FunctionQuery( new ByteVectorSimilarityFunction(VectorSimilarityFunction.EUCLIDEAN, v1, v2)), @@ -168,10 +167,9 @@ public void byteVectorSimilarityBetweenConstAndVectorField_shouldReturnCorrectRe } @Test - public void floatVectorSimilarityComputedOnDocumentWithMissingFieldVector_shouldReturnNaN() - throws Exception { - var v1 = new ConstFloatVectorValueSource(List.of(2.0, 1.0, 1.0)); - var v2 = new FloatVectorFieldSource("knnFloatField3"); + public void vectorSimilarity_missingFloatVectorField_shouldReturnNaN() throws Exception { + var v1 = new ConsKnnFloatValueSource(List.of(2.0, 1.0, 1.0)); + var v2 = new FloatKnnVectorFieldSource("knnFloatField3"); assertHits( new FunctionQuery( new FloatVectorSimilarityFunction(VectorSimilarityFunction.EUCLIDEAN, v1, v2)), @@ -179,10 +177,9 @@ public void floatVectorSimilarityComputedOnDocumentWithMissingFieldVector_should } @Test - public void byteVectorSimilarityComputedOnDocumentWithMissingFieldVector_shouldReturnNaN() - throws Exception { - var v1 = new ConstByteVectorValueSource(List.of(2.0, 1.0, 1.0)); - var v2 = new ByteVectorFieldSource("knnByteField3"); + public void vectorSimilarity_missingByteVectorField_shouldReturnNaN() throws Exception { + var v1 = new ConstKnnByteVectorValueSource(List.of(2.0, 1.0, 1.0)); + var v2 = new ByteKnnVectorFieldSource("knnByteField3"); assertHits( new FunctionQuery( new ByteVectorSimilarityFunction(VectorSimilarityFunction.EUCLIDEAN, v1, v2)), @@ -190,17 +187,17 @@ public void byteVectorSimilarityComputedOnDocumentWithMissingFieldVector_shouldR } @Test - public void vectorSimilarityBetweenTwoVectorsWithDifferentDimensions_shouldRaiseException() { - ValueSource v1 = new ConstByteVectorValueSource(List.of(1, 2, 3, 4)); - ValueSource v2 = new ByteVectorFieldSource("knnByteField1"); + public void vectorSimilarity_twoVectorsWithDifferentDimensions_shouldRaiseException() { + ValueSource v1 = new ConstKnnByteVectorValueSource(List.of(1, 2, 3, 4)); + ValueSource v2 = new ByteKnnVectorFieldSource("knnByteField1"); ByteVectorSimilarityFunction byteDenseVectorSimilarityFunction = new ByteVectorSimilarityFunction(VectorSimilarityFunction.EUCLIDEAN, v1, v2); assertThrows( AssertionError.class, () -> searcher.search(new FunctionQuery(byteDenseVectorSimilarityFunction), 10)); - v1 = new ConstFloatVectorValueSource(List.of(1, 2)); - v2 = new FloatVectorFieldSource("knnFloatField1"); + v1 = new ConsKnnFloatValueSource(List.of(1, 2)); + v2 = new FloatKnnVectorFieldSource("knnFloatField1"); FloatVectorSimilarityFunction floatDenseVectorSimilarityFunction = new FloatVectorSimilarityFunction(VectorSimilarityFunction.EUCLIDEAN, v1, v2); assertThrows( @@ -209,17 +206,17 @@ public void vectorSimilarityBetweenTwoVectorsWithDifferentDimensions_shouldRaise } @Test - public void vectorSimilarityBetweenByteAndFloatVectors_shouldRaiseException() { - var v1 = new ConstByteVectorValueSource(List.of(1, 2, 3)); - ValueSource v2 = new ByteVectorFieldSource("knnByteField1"); + public void vectorSimilarity_byteAndFloatVectors_shouldRaiseException() { + var v1 = new ConstKnnByteVectorValueSource(List.of(1, 2, 3)); + ValueSource v2 = new ByteKnnVectorFieldSource("knnByteField1"); FloatVectorSimilarityFunction floatDenseVectorSimilarityFunction = new FloatVectorSimilarityFunction(VectorSimilarityFunction.EUCLIDEAN, v1, v2); assertThrows( UnsupportedOperationException.class, () -> searcher.search(new FunctionQuery(floatDenseVectorSimilarityFunction), 10)); - v1 = new ConstByteVectorValueSource(List.of(1, 2, 3)); - v2 = new FloatVectorFieldSource("knnFloatField1"); + v1 = new ConstKnnByteVectorValueSource(List.of(1, 2, 3)); + v2 = new FloatKnnVectorFieldSource("knnFloatField1"); ByteVectorSimilarityFunction byteDenseVectorSimilarityFunction = new ByteVectorSimilarityFunction(VectorSimilarityFunction.EUCLIDEAN, v1, v2); assertThrows( @@ -228,9 +225,9 @@ public void vectorSimilarityBetweenByteAndFloatVectors_shouldRaiseException() { } @Test - public void vectorFieldValueSourceWithDifferentFieldType_shouldRaiseException() { - ValueSource v1 = new ByteVectorFieldSource("knnByteField1"); - ValueSource v2 = new ByteVectorFieldSource("knnFloatField2"); + public void vectorSimilarity_wrongFieldType_shouldRaiseException() { + ValueSource v1 = new ByteKnnVectorFieldSource("knnByteField1"); + ValueSource v2 = new ByteKnnVectorFieldSource("knnFloatField2"); ByteVectorSimilarityFunction byteDenseVectorSimilarityFunction = new ByteVectorSimilarityFunction(VectorSimilarityFunction.EUCLIDEAN, v1, v2); @@ -238,8 +235,8 @@ public void vectorFieldValueSourceWithDifferentFieldType_shouldRaiseException() IllegalArgumentException.class, () -> searcher.search(new FunctionQuery(byteDenseVectorSimilarityFunction), 10)); - v1 = new FloatVectorFieldSource("knnByteField1"); - v2 = new FloatVectorFieldSource("knnFloatField2"); + v1 = new FloatKnnVectorFieldSource("knnByteField1"); + v2 = new FloatKnnVectorFieldSource("knnFloatField2"); FloatVectorSimilarityFunction floatVectorSimilarityFunction = new FloatVectorSimilarityFunction(VectorSimilarityFunction.EUCLIDEAN, v1, v2); From 72dd3fc7635ce0d3eba6e32659c10f09554e5b7e Mon Sep 17 00:00:00 2001 From: Elia Date: Tue, 13 Jun 2023 14:16:29 +0200 Subject: [PATCH 28/35] rename variable --- .../queries/function/valuesource/VectorFieldFunction.java | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/VectorFieldFunction.java b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/VectorFieldFunction.java index 4488a4e40258..de64984249fe 100644 --- a/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/VectorFieldFunction.java +++ b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/VectorFieldFunction.java @@ -24,18 +24,18 @@ /** An implementation for retrieving {@link FunctionValues} instances for knn vectors fields. */ public abstract class VectorFieldFunction extends FunctionValues { - protected final ValueSource vs; + protected final ValueSource valueSource; int lastDocID; - protected VectorFieldFunction(ValueSource vs) { - this.vs = vs; + protected VectorFieldFunction(ValueSource valueSource) { + this.valueSource = valueSource; } protected abstract DocIdSetIterator getVectorIterator(); @Override public String toString(int doc) throws IOException { - return vs.description() + strVal(doc); + return valueSource.description() + strVal(doc); } @Override From 481f2a0c6d6b1f5261b0472c430706862facab58 Mon Sep 17 00:00:00 2001 From: Elia Date: Tue, 13 Jun 2023 14:26:10 +0200 Subject: [PATCH 29/35] Addressing review --- .../valuesource/ByteVectorSimilarityFunction.java | 4 ++-- .../valuesource/ConstKnnByteVectorValueSource.java | 2 +- ...alueSource.java => ConstKnnFloatValueSource.java} | 12 ++++++------ .../valuesource/FloatVectorSimilarityFunction.java | 5 ++--- .../function/TestKnnVectorSimilarityFunctions.java | 12 ++++++------ 5 files changed, 17 insertions(+), 18 deletions(-) rename lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/{ConsKnnFloatValueSource.java => ConstKnnFloatValueSource.java} (85%) diff --git a/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/ByteVectorSimilarityFunction.java b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/ByteVectorSimilarityFunction.java index 2b9208ea199c..a17b60b402a4 100644 --- a/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/ByteVectorSimilarityFunction.java +++ b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/ByteVectorSimilarityFunction.java @@ -42,9 +42,9 @@ protected float func(int doc, FunctionValues f1, FunctionValues f2) throws IOExc return Float.NaN; } - assert f1.byteVectorVal(doc).length == f2.byteVectorVal(doc).length + assert v1.length == v2.length : "Vectors must have the same length"; - return similarityFunction.compare(f1.byteVectorVal(doc), f2.byteVectorVal(doc)); + return similarityFunction.compare(v1, v2); } } diff --git a/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/ConstKnnByteVectorValueSource.java b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/ConstKnnByteVectorValueSource.java index 79115eecaa6f..eaa257d01832 100644 --- a/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/ConstKnnByteVectorValueSource.java +++ b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/ConstKnnByteVectorValueSource.java @@ -26,7 +26,7 @@ /** Function that returns a constant byte vector value for every document. */ public class ConstKnnByteVectorValueSource extends ValueSource { - byte[] vector; + private final byte[] vector; public ConstKnnByteVectorValueSource(List constVector) { this.vector = new byte[constVector.size()]; diff --git a/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/ConsKnnFloatValueSource.java b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/ConstKnnFloatValueSource.java similarity index 85% rename from lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/ConsKnnFloatValueSource.java rename to lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/ConstKnnFloatValueSource.java index 45871d2836be..b45826622f42 100644 --- a/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/ConsKnnFloatValueSource.java +++ b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/ConstKnnFloatValueSource.java @@ -25,10 +25,10 @@ import org.apache.lucene.queries.function.ValueSource; /** Function that returns a constant float vector value for every document. */ -public class ConsKnnFloatValueSource extends ValueSource { - float[] vector; +public class ConstKnnFloatValueSource extends ValueSource { + private final float[] vector; - public ConsKnnFloatValueSource(List constVector) { + public ConstKnnFloatValueSource(List constVector) { this.vector = new float[constVector.size()]; for (int i = 0; i < constVector.size(); i++) { vector[i] = constVector.get(i).floatValue(); @@ -58,8 +58,8 @@ public String toString(int doc) throws IOException { @Override public boolean equals(Object o) { - if (!(o instanceof ConsKnnFloatValueSource)) return false; - ConsKnnFloatValueSource other = (ConsKnnFloatValueSource) o; + if (!(o instanceof ConstKnnFloatValueSource)) return false; + ConstKnnFloatValueSource other = (ConstKnnFloatValueSource) o; return Arrays.equals(vector, other.vector); } @@ -70,6 +70,6 @@ public int hashCode() { @Override public String description() { - return "ConsKnnFloatValueSource(" + Arrays.toString(vector) + ')'; + return "ConstKnnFloatValueSource(" + Arrays.toString(vector) + ')'; } } diff --git a/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/FloatVectorSimilarityFunction.java b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/FloatVectorSimilarityFunction.java index fd7206ce9990..d3c2527a3f70 100644 --- a/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/FloatVectorSimilarityFunction.java +++ b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/FloatVectorSimilarityFunction.java @@ -42,8 +42,7 @@ protected float func(int doc, FunctionValues f1, FunctionValues f2) throws IOExc return Float.NaN; } - assert f1.floatVectorVal(doc).length == f2.floatVectorVal(doc).length - : "Vectors must have the same length"; - return similarityFunction.compare(f1.floatVectorVal(doc), f2.floatVectorVal(doc)); + assert v1.length == v2.length : "Vectors must have the same length"; + return similarityFunction.compare(v1, v2); } } diff --git a/lucene/queries/src/test/org/apache/lucene/queries/function/TestKnnVectorSimilarityFunctions.java b/lucene/queries/src/test/org/apache/lucene/queries/function/TestKnnVectorSimilarityFunctions.java index 228d6d597c51..f932d31b2d7e 100644 --- a/lucene/queries/src/test/org/apache/lucene/queries/function/TestKnnVectorSimilarityFunctions.java +++ b/lucene/queries/src/test/org/apache/lucene/queries/function/TestKnnVectorSimilarityFunctions.java @@ -29,7 +29,7 @@ import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.queries.function.valuesource.ByteKnnVectorFieldSource; import org.apache.lucene.queries.function.valuesource.ByteVectorSimilarityFunction; -import org.apache.lucene.queries.function.valuesource.ConsKnnFloatValueSource; +import org.apache.lucene.queries.function.valuesource.ConstKnnFloatValueSource; import org.apache.lucene.queries.function.valuesource.ConstKnnByteVectorValueSource; import org.apache.lucene.queries.function.valuesource.FloatKnnVectorFieldSource; import org.apache.lucene.queries.function.valuesource.FloatVectorSimilarityFunction; @@ -106,8 +106,8 @@ public static void afterClass() throws Exception { @Test public void vectorSimilarity_floatConstantVectors_shouldReturnFloatSimilarity() throws Exception { - var v1 = new ConsKnnFloatValueSource(List.of(1, 2, 3)); - var v2 = new ConsKnnFloatValueSource(List.of(5, 4, 1)); + var v1 = new ConstKnnFloatValueSource(List.of(1, 2, 3)); + var v2 = new ConstKnnFloatValueSource(List.of(5, 4, 1)); assertHits( new FunctionQuery( new FloatVectorSimilarityFunction(VectorSimilarityFunction.EUCLIDEAN, v1, v2)), @@ -147,7 +147,7 @@ public void vectorSimilarity_byteFieldVectors_shouldReturnFloatSimilarity() thro @Test public void vectorSimilarity_FloatConstAndFloatFieldVectors_shouldReturnFloatSimilarity() throws Exception { - var v1 = new ConsKnnFloatValueSource(List.of(1, 2, 4)); + var v1 = new ConstKnnFloatValueSource(List.of(1, 2, 4)); var v2 = new FloatKnnVectorFieldSource("knnFloatField1"); assertHits( new FunctionQuery( @@ -168,7 +168,7 @@ public void vectorSimilarity_ByteConstAndByteFieldVectors_shouldReturnFloatSimil @Test public void vectorSimilarity_missingFloatVectorField_shouldReturnNaN() throws Exception { - var v1 = new ConsKnnFloatValueSource(List.of(2.0, 1.0, 1.0)); + var v1 = new ConstKnnFloatValueSource(List.of(2.0, 1.0, 1.0)); var v2 = new FloatKnnVectorFieldSource("knnFloatField3"); assertHits( new FunctionQuery( @@ -196,7 +196,7 @@ public void vectorSimilarity_twoVectorsWithDifferentDimensions_shouldRaiseExcept AssertionError.class, () -> searcher.search(new FunctionQuery(byteDenseVectorSimilarityFunction), 10)); - v1 = new ConsKnnFloatValueSource(List.of(1, 2)); + v1 = new ConstKnnFloatValueSource(List.of(1, 2)); v2 = new FloatKnnVectorFieldSource("knnFloatField1"); FloatVectorSimilarityFunction floatDenseVectorSimilarityFunction = new FloatVectorSimilarityFunction(VectorSimilarityFunction.EUCLIDEAN, v1, v2); From 901febe400bff2e8bc719e35007e874913a11e79 Mon Sep 17 00:00:00 2001 From: Elia Date: Tue, 13 Jun 2023 14:43:42 +0200 Subject: [PATCH 30/35] tidy --- .../function/valuesource/ByteVectorSimilarityFunction.java | 3 +-- .../queries/function/TestKnnVectorSimilarityFunctions.java | 2 +- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/ByteVectorSimilarityFunction.java b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/ByteVectorSimilarityFunction.java index a17b60b402a4..81e57ab94970 100644 --- a/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/ByteVectorSimilarityFunction.java +++ b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/ByteVectorSimilarityFunction.java @@ -42,8 +42,7 @@ protected float func(int doc, FunctionValues f1, FunctionValues f2) throws IOExc return Float.NaN; } - assert v1.length == v2.length - : "Vectors must have the same length"; + assert v1.length == v2.length : "Vectors must have the same length"; return similarityFunction.compare(v1, v2); } diff --git a/lucene/queries/src/test/org/apache/lucene/queries/function/TestKnnVectorSimilarityFunctions.java b/lucene/queries/src/test/org/apache/lucene/queries/function/TestKnnVectorSimilarityFunctions.java index f932d31b2d7e..332914a09e3a 100644 --- a/lucene/queries/src/test/org/apache/lucene/queries/function/TestKnnVectorSimilarityFunctions.java +++ b/lucene/queries/src/test/org/apache/lucene/queries/function/TestKnnVectorSimilarityFunctions.java @@ -29,8 +29,8 @@ import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.queries.function.valuesource.ByteKnnVectorFieldSource; import org.apache.lucene.queries.function.valuesource.ByteVectorSimilarityFunction; -import org.apache.lucene.queries.function.valuesource.ConstKnnFloatValueSource; import org.apache.lucene.queries.function.valuesource.ConstKnnByteVectorValueSource; +import org.apache.lucene.queries.function.valuesource.ConstKnnFloatValueSource; import org.apache.lucene.queries.function.valuesource.FloatKnnVectorFieldSource; import org.apache.lucene.queries.function.valuesource.FloatVectorSimilarityFunction; import org.apache.lucene.search.IndexSearcher; From 7219987d52922ab571ceca631ee325fdc8846492 Mon Sep 17 00:00:00 2001 From: Elia Date: Tue, 13 Jun 2023 15:46:23 +0200 Subject: [PATCH 31/35] Addressing review --- .../valuesource/ByteKnnVectorFieldSource.java | 8 +++++--- .../valuesource/ConstKnnByteVectorValueSource.java | 3 ++- .../valuesource/ConstKnnFloatValueSource.java | 3 ++- .../valuesource/FloatKnnVectorFieldSource.java | 8 +++++--- .../valuesource/VectorSimilarityFunction.java | 13 +++++++------ 5 files changed, 21 insertions(+), 14 deletions(-) diff --git a/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/ByteKnnVectorFieldSource.java b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/ByteKnnVectorFieldSource.java index 83cc49c772eb..7f862c67ee24 100644 --- a/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/ByteKnnVectorFieldSource.java +++ b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/ByteKnnVectorFieldSource.java @@ -18,6 +18,7 @@ import java.io.IOException; import java.util.Map; +import java.util.Objects; import org.apache.lucene.index.ByteVectorValues; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.queries.function.FunctionValues; @@ -65,14 +66,15 @@ protected DocIdSetIterator getVectorIterator() { @Override public boolean equals(Object o) { - if (o.getClass() != ByteKnnVectorFieldSource.class) return false; + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; ByteKnnVectorFieldSource other = (ByteKnnVectorFieldSource) o; - return fieldName.equals(other.fieldName); + return Objects.equals(fieldName, other.fieldName); } @Override public int hashCode() { - return getClass().hashCode() * 31 + fieldName.getClass().hashCode(); + return getClass().hashCode() * 31 + Objects.hashCode(fieldName); } @Override diff --git a/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/ConstKnnByteVectorValueSource.java b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/ConstKnnByteVectorValueSource.java index eaa257d01832..71b025b0490d 100644 --- a/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/ConstKnnByteVectorValueSource.java +++ b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/ConstKnnByteVectorValueSource.java @@ -58,7 +58,8 @@ public String toString(int doc) throws IOException { @Override public boolean equals(Object o) { - if (!(o instanceof ConstKnnByteVectorValueSource)) return false; + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; ConstKnnByteVectorValueSource other = (ConstKnnByteVectorValueSource) o; return Arrays.equals(vector, other.vector); } diff --git a/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/ConstKnnFloatValueSource.java b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/ConstKnnFloatValueSource.java index b45826622f42..c9a22a5beb18 100644 --- a/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/ConstKnnFloatValueSource.java +++ b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/ConstKnnFloatValueSource.java @@ -58,7 +58,8 @@ public String toString(int doc) throws IOException { @Override public boolean equals(Object o) { - if (!(o instanceof ConstKnnFloatValueSource)) return false; + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; ConstKnnFloatValueSource other = (ConstKnnFloatValueSource) o; return Arrays.equals(vector, other.vector); } diff --git a/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/FloatKnnVectorFieldSource.java b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/FloatKnnVectorFieldSource.java index 64def5ed67fb..b526e9030aee 100644 --- a/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/FloatKnnVectorFieldSource.java +++ b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/FloatKnnVectorFieldSource.java @@ -18,6 +18,7 @@ import java.io.IOException; import java.util.Map; +import java.util.Objects; import org.apache.lucene.index.FloatVectorValues; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.queries.function.FunctionValues; @@ -64,14 +65,15 @@ protected DocIdSetIterator getVectorIterator() { @Override public boolean equals(Object o) { - if (o.getClass() != FloatKnnVectorFieldSource.class) return false; + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; FloatKnnVectorFieldSource other = (FloatKnnVectorFieldSource) o; - return fieldName.equals(other.fieldName); + return Objects.equals(fieldName, other.fieldName); } @Override public int hashCode() { - return getClass().hashCode() * 31 + fieldName.getClass().hashCode(); + return getClass().hashCode() * 31 + Objects.hashCode(fieldName); } @Override diff --git a/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/VectorSimilarityFunction.java b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/VectorSimilarityFunction.java index ebb1400529f5..3fcb43b07f87 100644 --- a/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/VectorSimilarityFunction.java +++ b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/VectorSimilarityFunction.java @@ -18,6 +18,7 @@ import java.io.IOException; import java.util.Map; +import java.util.Objects; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.queries.function.FunctionValues; import org.apache.lucene.queries.function.ValueSource; @@ -72,17 +73,17 @@ public String toString(int doc) throws IOException { @Override public boolean equals(Object o) { - return o instanceof VectorSimilarityFunction - && similarityFunction.equals(((VectorSimilarityFunction) o).similarityFunction) - && vector1.equals(((VectorSimilarityFunction) o).vector1) - && vector2.equals(((VectorSimilarityFunction) o).vector2); + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + return Objects.equals(vector1, ((VectorSimilarityFunction) o).vector1) + && Objects.equals(vector2, ((VectorSimilarityFunction) o).vector2); } @Override public int hashCode() { int h = similarityFunction.hashCode(); - h = 31 * h + vector1.hashCode(); - h = 31 * h + vector2.hashCode(); + h = 31 * h + Objects.hashCode(vector1); + h = 31 * h + Objects.hashCode(vector2); return h; } From bc504291ccc7050c7dea8c52e3403b7282051c89 Mon Sep 17 00:00:00 2001 From: Elia Date: Tue, 13 Jun 2023 16:08:57 +0200 Subject: [PATCH 32/35] Addressing review: fixed hash computations --- .../function/valuesource/ByteKnnVectorFieldSource.java | 2 +- .../function/valuesource/ConstKnnByteVectorValueSource.java | 3 ++- .../function/valuesource/ConstKnnFloatValueSource.java | 3 ++- .../function/valuesource/FloatKnnVectorFieldSource.java | 2 +- .../function/valuesource/VectorSimilarityFunction.java | 5 +---- 5 files changed, 7 insertions(+), 8 deletions(-) diff --git a/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/ByteKnnVectorFieldSource.java b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/ByteKnnVectorFieldSource.java index 7f862c67ee24..c8a4a93a2dfc 100644 --- a/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/ByteKnnVectorFieldSource.java +++ b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/ByteKnnVectorFieldSource.java @@ -74,7 +74,7 @@ public boolean equals(Object o) { @Override public int hashCode() { - return getClass().hashCode() * 31 + Objects.hashCode(fieldName); + return Objects.hash(getClass().hashCode(), fieldName); } @Override diff --git a/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/ConstKnnByteVectorValueSource.java b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/ConstKnnByteVectorValueSource.java index 71b025b0490d..446920c64e82 100644 --- a/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/ConstKnnByteVectorValueSource.java +++ b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/ConstKnnByteVectorValueSource.java @@ -20,6 +20,7 @@ import java.util.Arrays; import java.util.List; import java.util.Map; +import java.util.Objects; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.queries.function.FunctionValues; import org.apache.lucene.queries.function.ValueSource; @@ -66,7 +67,7 @@ public boolean equals(Object o) { @Override public int hashCode() { - return getClass().hashCode() * 31 + Arrays.hashCode(vector); + return Objects.hash(getClass().hashCode(), Arrays.hashCode(vector)); } @Override diff --git a/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/ConstKnnFloatValueSource.java b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/ConstKnnFloatValueSource.java index c9a22a5beb18..9a6c45f14136 100644 --- a/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/ConstKnnFloatValueSource.java +++ b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/ConstKnnFloatValueSource.java @@ -20,6 +20,7 @@ import java.util.Arrays; import java.util.List; import java.util.Map; +import java.util.Objects; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.queries.function.FunctionValues; import org.apache.lucene.queries.function.ValueSource; @@ -66,7 +67,7 @@ public boolean equals(Object o) { @Override public int hashCode() { - return getClass().hashCode() * 31 + Arrays.hashCode(vector); + return Objects.hash(getClass().hashCode(), Arrays.hashCode(vector)); } @Override diff --git a/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/FloatKnnVectorFieldSource.java b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/FloatKnnVectorFieldSource.java index b526e9030aee..9a1f27a7c79d 100644 --- a/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/FloatKnnVectorFieldSource.java +++ b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/FloatKnnVectorFieldSource.java @@ -73,7 +73,7 @@ public boolean equals(Object o) { @Override public int hashCode() { - return getClass().hashCode() * 31 + Objects.hashCode(fieldName); + return Objects.hash(getClass().hashCode(), fieldName); } @Override diff --git a/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/VectorSimilarityFunction.java b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/VectorSimilarityFunction.java index 3fcb43b07f87..9ba2d359a568 100644 --- a/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/VectorSimilarityFunction.java +++ b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/VectorSimilarityFunction.java @@ -81,10 +81,7 @@ public boolean equals(Object o) { @Override public int hashCode() { - int h = similarityFunction.hashCode(); - h = 31 * h + Objects.hashCode(vector1); - h = 31 * h + Objects.hashCode(vector2); - return h; + return Objects.hash(similarityFunction, vector1, vector2); } @Override From 1b695514741d26bef22e3a36b588f462d5314fcf Mon Sep 17 00:00:00 2001 From: Elia Date: Tue, 13 Jun 2023 18:17:51 +0200 Subject: [PATCH 33/35] Changed default similarity from NaN to 0.f --- .../valuesource/ByteVectorSimilarityFunction.java | 2 +- .../valuesource/FloatVectorSimilarityFunction.java | 2 +- .../function/TestKnnVectorSimilarityFunctions.java | 8 ++++---- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/ByteVectorSimilarityFunction.java b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/ByteVectorSimilarityFunction.java index 81e57ab94970..fb6ec68ee9e5 100644 --- a/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/ByteVectorSimilarityFunction.java +++ b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/ByteVectorSimilarityFunction.java @@ -39,7 +39,7 @@ protected float func(int doc, FunctionValues f1, FunctionValues f2) throws IOExc var v2 = f2.byteVectorVal(doc); if (v1 == null || v2 == null) { - return Float.NaN; + return 0.f; } assert v1.length == v2.length : "Vectors must have the same length"; diff --git a/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/FloatVectorSimilarityFunction.java b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/FloatVectorSimilarityFunction.java index d3c2527a3f70..296775388856 100644 --- a/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/FloatVectorSimilarityFunction.java +++ b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/FloatVectorSimilarityFunction.java @@ -39,7 +39,7 @@ protected float func(int doc, FunctionValues f1, FunctionValues f2) throws IOExc var v2 = f2.floatVectorVal(doc); if (v1 == null || v2 == null) { - return Float.NaN; + return 0.f; } assert v1.length == v2.length : "Vectors must have the same length"; diff --git a/lucene/queries/src/test/org/apache/lucene/queries/function/TestKnnVectorSimilarityFunctions.java b/lucene/queries/src/test/org/apache/lucene/queries/function/TestKnnVectorSimilarityFunctions.java index 332914a09e3a..41f2d4d23df0 100644 --- a/lucene/queries/src/test/org/apache/lucene/queries/function/TestKnnVectorSimilarityFunctions.java +++ b/lucene/queries/src/test/org/apache/lucene/queries/function/TestKnnVectorSimilarityFunctions.java @@ -167,23 +167,23 @@ public void vectorSimilarity_ByteConstAndByteFieldVectors_shouldReturnFloatSimil } @Test - public void vectorSimilarity_missingFloatVectorField_shouldReturnNaN() throws Exception { + public void vectorSimilarity_missingFloatVectorField_shouldReturnZero() throws Exception { var v1 = new ConstKnnFloatValueSource(List.of(2.0, 1.0, 1.0)); var v2 = new FloatKnnVectorFieldSource("knnFloatField3"); assertHits( new FunctionQuery( new FloatVectorSimilarityFunction(VectorSimilarityFunction.EUCLIDEAN, v1, v2)), - new float[] {0.5f, Float.NaN}); + new float[] {0.5f, 0.f}); } @Test - public void vectorSimilarity_missingByteVectorField_shouldReturnNaN() throws Exception { + public void vectorSimilarity_missingByteVectorField_shouldReturnZero() throws Exception { var v1 = new ConstKnnByteVectorValueSource(List.of(2.0, 1.0, 1.0)); var v2 = new ByteKnnVectorFieldSource("knnByteField3"); assertHits( new FunctionQuery( new ByteVectorSimilarityFunction(VectorSimilarityFunction.EUCLIDEAN, v1, v2)), - new float[] {0.5f, Float.NaN}); + new float[] {0.5f, 0.f}); } @Test From 8ece485522cfc279cf087e2b18e4a9206a998e09 Mon Sep 17 00:00:00 2001 From: Elia Date: Wed, 14 Jun 2023 10:06:40 +0200 Subject: [PATCH 34/35] Changed constructors using arrays instead of lists --- .../ConstKnnByteVectorValueSource.java | 8 ++----- .../valuesource/ConstKnnFloatValueSource.java | 9 +++---- .../TestKnnVectorSimilarityFunctions.java | 24 +++++++++---------- 3 files changed, 17 insertions(+), 24 deletions(-) diff --git a/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/ConstKnnByteVectorValueSource.java b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/ConstKnnByteVectorValueSource.java index 446920c64e82..4996e026abee 100644 --- a/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/ConstKnnByteVectorValueSource.java +++ b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/ConstKnnByteVectorValueSource.java @@ -18,7 +18,6 @@ import java.io.IOException; import java.util.Arrays; -import java.util.List; import java.util.Map; import java.util.Objects; import org.apache.lucene.index.LeafReaderContext; @@ -29,11 +28,8 @@ public class ConstKnnByteVectorValueSource extends ValueSource { private final byte[] vector; - public ConstKnnByteVectorValueSource(List constVector) { - this.vector = new byte[constVector.size()]; - for (int i = 0; i < constVector.size(); i++) { - vector[i] = constVector.get(i).byteValue(); - } + public ConstKnnByteVectorValueSource(byte[] constVector) { + this.vector = Objects.requireNonNull(constVector, "constVector"); } @Override diff --git a/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/ConstKnnFloatValueSource.java b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/ConstKnnFloatValueSource.java index 9a6c45f14136..57c016eb793e 100644 --- a/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/ConstKnnFloatValueSource.java +++ b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/ConstKnnFloatValueSource.java @@ -18,22 +18,19 @@ import java.io.IOException; import java.util.Arrays; -import java.util.List; import java.util.Map; import java.util.Objects; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.queries.function.FunctionValues; import org.apache.lucene.queries.function.ValueSource; +import org.apache.lucene.util.VectorUtil; /** Function that returns a constant float vector value for every document. */ public class ConstKnnFloatValueSource extends ValueSource { private final float[] vector; - public ConstKnnFloatValueSource(List constVector) { - this.vector = new float[constVector.size()]; - for (int i = 0; i < constVector.size(); i++) { - vector[i] = constVector.get(i).floatValue(); - } + public ConstKnnFloatValueSource(float[] constVector) { + this.vector = VectorUtil.checkFinite(Objects.requireNonNull(constVector, "constVector")); } @Override diff --git a/lucene/queries/src/test/org/apache/lucene/queries/function/TestKnnVectorSimilarityFunctions.java b/lucene/queries/src/test/org/apache/lucene/queries/function/TestKnnVectorSimilarityFunctions.java index 41f2d4d23df0..12144b252ba0 100644 --- a/lucene/queries/src/test/org/apache/lucene/queries/function/TestKnnVectorSimilarityFunctions.java +++ b/lucene/queries/src/test/org/apache/lucene/queries/function/TestKnnVectorSimilarityFunctions.java @@ -106,8 +106,8 @@ public static void afterClass() throws Exception { @Test public void vectorSimilarity_floatConstantVectors_shouldReturnFloatSimilarity() throws Exception { - var v1 = new ConstKnnFloatValueSource(List.of(1, 2, 3)); - var v2 = new ConstKnnFloatValueSource(List.of(5, 4, 1)); + var v1 = new ConstKnnFloatValueSource(new float[] {1, 2, 3}); + var v2 = new ConstKnnFloatValueSource(new float[] {5, 4, 1}); assertHits( new FunctionQuery( new FloatVectorSimilarityFunction(VectorSimilarityFunction.EUCLIDEAN, v1, v2)), @@ -116,8 +116,8 @@ public void vectorSimilarity_floatConstantVectors_shouldReturnFloatSimilarity() @Test public void vectorSimilarity_byteConstantVectors_shouldReturnFloatSimilarity() throws Exception { - var v1 = new ConstKnnByteVectorValueSource(List.of(1, 2, 3)); - var v2 = new ConstKnnByteVectorValueSource(List.of(2, 5, 6)); + var v1 = new ConstKnnByteVectorValueSource(new byte[] {1, 2, 3}); + var v2 = new ConstKnnByteVectorValueSource(new byte[] {2, 5, 6}); assertHits( new FunctionQuery( new ByteVectorSimilarityFunction(VectorSimilarityFunction.EUCLIDEAN, v1, v2)), @@ -147,7 +147,7 @@ public void vectorSimilarity_byteFieldVectors_shouldReturnFloatSimilarity() thro @Test public void vectorSimilarity_FloatConstAndFloatFieldVectors_shouldReturnFloatSimilarity() throws Exception { - var v1 = new ConstKnnFloatValueSource(List.of(1, 2, 4)); + var v1 = new ConstKnnFloatValueSource(new float[] {1, 2, 4}); var v2 = new FloatKnnVectorFieldSource("knnFloatField1"); assertHits( new FunctionQuery( @@ -158,7 +158,7 @@ public void vectorSimilarity_FloatConstAndFloatFieldVectors_shouldReturnFloatSim @Test public void vectorSimilarity_ByteConstAndByteFieldVectors_shouldReturnFloatSimilarity() throws Exception { - var v1 = new ConstKnnByteVectorValueSource(List.of(1, 2, 4)); + var v1 = new ConstKnnByteVectorValueSource(new byte[] {1, 2, 4}); var v2 = new ByteKnnVectorFieldSource("knnByteField1"); assertHits( new FunctionQuery( @@ -168,7 +168,7 @@ public void vectorSimilarity_ByteConstAndByteFieldVectors_shouldReturnFloatSimil @Test public void vectorSimilarity_missingFloatVectorField_shouldReturnZero() throws Exception { - var v1 = new ConstKnnFloatValueSource(List.of(2.0, 1.0, 1.0)); + var v1 = new ConstKnnFloatValueSource(new float[] {2.f, 1.f, 1.f}); var v2 = new FloatKnnVectorFieldSource("knnFloatField3"); assertHits( new FunctionQuery( @@ -178,7 +178,7 @@ public void vectorSimilarity_missingFloatVectorField_shouldReturnZero() throws E @Test public void vectorSimilarity_missingByteVectorField_shouldReturnZero() throws Exception { - var v1 = new ConstKnnByteVectorValueSource(List.of(2.0, 1.0, 1.0)); + var v1 = new ConstKnnByteVectorValueSource(new byte[] {2, 1, 1}); var v2 = new ByteKnnVectorFieldSource("knnByteField3"); assertHits( new FunctionQuery( @@ -188,7 +188,7 @@ public void vectorSimilarity_missingByteVectorField_shouldReturnZero() throws Ex @Test public void vectorSimilarity_twoVectorsWithDifferentDimensions_shouldRaiseException() { - ValueSource v1 = new ConstKnnByteVectorValueSource(List.of(1, 2, 3, 4)); + ValueSource v1 = new ConstKnnByteVectorValueSource(new byte[] {1, 2, 3, 4}); ValueSource v2 = new ByteKnnVectorFieldSource("knnByteField1"); ByteVectorSimilarityFunction byteDenseVectorSimilarityFunction = new ByteVectorSimilarityFunction(VectorSimilarityFunction.EUCLIDEAN, v1, v2); @@ -196,7 +196,7 @@ public void vectorSimilarity_twoVectorsWithDifferentDimensions_shouldRaiseExcept AssertionError.class, () -> searcher.search(new FunctionQuery(byteDenseVectorSimilarityFunction), 10)); - v1 = new ConstKnnFloatValueSource(List.of(1, 2)); + v1 = new ConstKnnFloatValueSource(new float[] {1.f, 2.f}); v2 = new FloatKnnVectorFieldSource("knnFloatField1"); FloatVectorSimilarityFunction floatDenseVectorSimilarityFunction = new FloatVectorSimilarityFunction(VectorSimilarityFunction.EUCLIDEAN, v1, v2); @@ -207,7 +207,7 @@ public void vectorSimilarity_twoVectorsWithDifferentDimensions_shouldRaiseExcept @Test public void vectorSimilarity_byteAndFloatVectors_shouldRaiseException() { - var v1 = new ConstKnnByteVectorValueSource(List.of(1, 2, 3)); + var v1 = new ConstKnnByteVectorValueSource(new byte[] {1, 2, 3}); ValueSource v2 = new ByteKnnVectorFieldSource("knnByteField1"); FloatVectorSimilarityFunction floatDenseVectorSimilarityFunction = new FloatVectorSimilarityFunction(VectorSimilarityFunction.EUCLIDEAN, v1, v2); @@ -215,7 +215,7 @@ public void vectorSimilarity_byteAndFloatVectors_shouldRaiseException() { UnsupportedOperationException.class, () -> searcher.search(new FunctionQuery(floatDenseVectorSimilarityFunction), 10)); - v1 = new ConstKnnByteVectorValueSource(List.of(1, 2, 3)); + v1 = new ConstKnnByteVectorValueSource(new byte[] {1, 2, 3}); v2 = new FloatKnnVectorFieldSource("knnFloatField1"); ByteVectorSimilarityFunction byteDenseVectorSimilarityFunction = new ByteVectorSimilarityFunction(VectorSimilarityFunction.EUCLIDEAN, v1, v2); From cf1e31c6a457e10c519b409c967f162d3f62f4ef Mon Sep 17 00:00:00 2001 From: Elia Date: Wed, 14 Jun 2023 15:06:25 +0200 Subject: [PATCH 35/35] Updated CHANGES.txt --- lucene/CHANGES.txt | 2 ++ 1 file changed, 2 insertions(+) diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt index f17590635447..baafaa93a4f6 100644 --- a/lucene/CHANGES.txt +++ b/lucene/CHANGES.txt @@ -58,6 +58,8 @@ API Changes New Features --------------------- +* GITHUB#12252 Add function queries for computing similarity scores between knn vectors. (Elia Porciani, Alessandro Benedetti) + * LUCENE-10010 Introduce NFARunAutomaton to run NFA directly. (Patrick Zhai) * LUCENE-10626 Hunspell: add tools to aid dictionary editing: