diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt index 42bb06993dc9..4e530357fafc 100644 --- a/lucene/CHANGES.txt +++ b/lucene/CHANGES.txt @@ -206,7 +206,9 @@ Optimizations Bug Fixes --------------------- -(No changes) + +* GITHUB#13105: Fix ByteKnnVectorFieldSource & FloatKnnVectorFieldSource to work correctly when a segment does not contain + any docs with vectors (hossman) Other --------------------- diff --git a/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/ByteKnnVectorFieldSource.java b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/ByteKnnVectorFieldSource.java index c8a4a93a2dfc..32517496d542 100644 --- a/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/ByteKnnVectorFieldSource.java +++ b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/ByteKnnVectorFieldSource.java @@ -20,7 +20,9 @@ import java.util.Map; import java.util.Objects; import org.apache.lucene.index.ByteVectorValues; +import org.apache.lucene.index.LeafReader; import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.index.VectorEncoding; import org.apache.lucene.queries.function.FunctionValues; import org.apache.lucene.queries.function.ValueSource; import org.apache.lucene.search.DocIdSetIterator; @@ -39,11 +41,25 @@ public ByteKnnVectorFieldSource(String fieldName) { public FunctionValues getValues(Map context, LeafReaderContext readerContext) throws IOException { - final ByteVectorValues vectorValues = readerContext.reader().getByteVectorValues(fieldName); + final LeafReader reader = readerContext.reader(); + final ByteVectorValues vectorValues = reader.getByteVectorValues(fieldName); if (vectorValues == null) { - throw new IllegalArgumentException( - "no byte vector value is indexed for field '" + fieldName + "'"); + VectorFieldFunction.checkField(reader, fieldName, VectorEncoding.BYTE); + + return new VectorFieldFunction(this) { + private final DocIdSetIterator empty = DocIdSetIterator.empty(); + + @Override + public byte[] byteVectorVal(int doc) throws IOException { + return null; + } + + @Override + protected DocIdSetIterator getVectorIterator() { + return empty; + } + }; } return new VectorFieldFunction(this) { diff --git a/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/FloatKnnVectorFieldSource.java b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/FloatKnnVectorFieldSource.java index 9a1f27a7c79d..43cc3aff880e 100644 --- a/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/FloatKnnVectorFieldSource.java +++ b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/FloatKnnVectorFieldSource.java @@ -20,7 +20,9 @@ import java.util.Map; import java.util.Objects; import org.apache.lucene.index.FloatVectorValues; +import org.apache.lucene.index.LeafReader; import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.index.VectorEncoding; import org.apache.lucene.queries.function.FunctionValues; import org.apache.lucene.queries.function.ValueSource; import org.apache.lucene.search.DocIdSetIterator; @@ -39,12 +41,26 @@ public FloatKnnVectorFieldSource(String fieldName) { public FunctionValues getValues(Map context, LeafReaderContext readerContext) throws IOException { - final FloatVectorValues vectorValues = readerContext.reader().getFloatVectorValues(fieldName); + final LeafReader reader = readerContext.reader(); + final FloatVectorValues vectorValues = reader.getFloatVectorValues(fieldName); if (vectorValues == null) { - throw new IllegalArgumentException( - "no float vector value is indexed for field '" + fieldName + "'"); + VectorFieldFunction.checkField(reader, fieldName, VectorEncoding.FLOAT32); + return new VectorFieldFunction(this) { + private final DocIdSetIterator empty = DocIdSetIterator.empty(); + + @Override + public float[] floatVectorVal(int doc) throws IOException { + return null; + } + + @Override + protected DocIdSetIterator getVectorIterator() { + return empty; + } + }; } + return new VectorFieldFunction(this) { @Override diff --git a/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/VectorFieldFunction.java b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/VectorFieldFunction.java index de64984249fe..aace2795e525 100644 --- a/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/VectorFieldFunction.java +++ b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/VectorFieldFunction.java @@ -17,6 +17,9 @@ package org.apache.lucene.queries.function.valuesource; import java.io.IOException; +import org.apache.lucene.index.FieldInfo; +import org.apache.lucene.index.LeafReader; +import org.apache.lucene.index.VectorEncoding; import org.apache.lucene.queries.function.FunctionValues; import org.apache.lucene.queries.function.ValueSource; import org.apache.lucene.search.DocIdSetIterator; @@ -53,4 +56,29 @@ public boolean exists(int doc) throws IOException { } return doc == curDocID; } + + /** + * Checks the Vector Encoding of a field + * + * @throws IllegalStateException if {@code field} exists, but was not indexed with vectors. + * @throws IllegalStateException if {@code field} has vectors, but using a different encoding + * @lucene.internal + * @lucene.experimental + */ + static void checkField(LeafReader in, String field, VectorEncoding expectedEncoding) { + FieldInfo fi = in.getFieldInfos().fieldInfo(field); + if (fi != null) { + final VectorEncoding actual = fi.hasVectorValues() ? fi.getVectorEncoding() : null; + if (expectedEncoding != actual) { + throw new IllegalStateException( + "Unexpected vector encoding (" + + actual + + ") for field " + + field + + "(expected=" + + expectedEncoding + + ")"); + } + } + } } diff --git a/lucene/queries/src/test/org/apache/lucene/queries/function/TestKnnVectorSimilarityFunctions.java b/lucene/queries/src/test/org/apache/lucene/queries/function/TestKnnVectorSimilarityFunctions.java index 12144b252ba0..df13582259cc 100644 --- a/lucene/queries/src/test/org/apache/lucene/queries/function/TestKnnVectorSimilarityFunctions.java +++ b/lucene/queries/src/test/org/apache/lucene/queries/function/TestKnnVectorSimilarityFunctions.java @@ -78,6 +78,10 @@ public static void beforeClass() throws Exception { document.add(new KnnByteVectorField("knnByteField2", new byte[] {4, 2, 3})); iw.addDocument(document); + if (usually(random())) { + iw.commit(); + } + Document document2 = new Document(); document2.add(new StringField("id", "2", Field.Store.NO)); document2.add(new SortedDocValuesField("id", new BytesRef("2"))); @@ -232,7 +236,7 @@ public void vectorSimilarity_wrongFieldType_shouldRaiseException() { new ByteVectorSimilarityFunction(VectorSimilarityFunction.EUCLIDEAN, v1, v2); assertThrows( - IllegalArgumentException.class, + IllegalStateException.class, () -> searcher.search(new FunctionQuery(byteDenseVectorSimilarityFunction), 10)); v1 = new FloatKnnVectorFieldSource("knnByteField1"); @@ -241,8 +245,16 @@ public void vectorSimilarity_wrongFieldType_shouldRaiseException() { new FloatVectorSimilarityFunction(VectorSimilarityFunction.EUCLIDEAN, v1, v2); assertThrows( - IllegalArgumentException.class, + IllegalStateException.class, () -> searcher.search(new FunctionQuery(floatVectorSimilarityFunction), 10)); + + v1 = new FloatKnnVectorFieldSource("id"); + FloatVectorSimilarityFunction idVectorSimilarityFunction = + new FloatVectorSimilarityFunction(VectorSimilarityFunction.EUCLIDEAN, v1, v2); + + assertThrows( + IllegalStateException.class, + () -> searcher.search(new FunctionQuery(idVectorSimilarityFunction), 10)); } private static void assertHits(Query q, float[] scores) throws Exception {