From 5b5815a26d14201f32ba60a0a133a092ac666032 Mon Sep 17 00:00:00 2001 From: panguixin Date: Mon, 11 Mar 2024 20:07:04 +0800 Subject: [PATCH] Fix NPE when LeafReader returns null VectorValues (#13162) ### Description `LeafReader#getXXXVectorValues` may return a null value. **Reproduction**: ``` public class TestKnnByteVectorQuery extends BaseKnnVectorQueryTestCase { public void testVectorEncodingMismatch() throws IOException { try (Directory indexStore = getIndexStore("field", new float[] {0, 1}, new float[] {1, 2}, new float[] {0, 0}); IndexReader reader = DirectoryReader.open(indexStore)) { AbstractKnnVectorQuery query = new KnnFloatVectorQuery("field", new float[] {0, 1}, 10); IndexSearcher searcher = newSearcher(reader); searcher.search(query, 10); } } } ``` **Output**: ``` java.lang.NullPointerException: Cannot invoke "org.apache.lucene.index.FloatVectorValues.size()" because the return value of "org.apache.lucene.index.LeafReader.getFloatVectorValues(String)" is null ``` --- lucene/CHANGES.txt | 2 ++ .../lucene94/Lucene94HnswVectorsReader.java | 2 +- .../lucene95/Lucene95HnswVectorsReader.java | 2 +- .../lucene99/Lucene99FlatVectorsReader.java | 2 +- .../apache/lucene/index/ByteVectorValues.java | 21 +++++++++++++++++++ .../org/apache/lucene/index/CodecReader.java | 8 +++++-- .../lucene/index/FloatVectorValues.java | 21 +++++++++++++++++++ .../org/apache/lucene/index/LeafReader.java | 12 +++++------ .../lucene/search/AbstractKnnVectorQuery.java | 3 +++ .../search/AbstractVectorSimilarityQuery.java | 5 +++++ .../ByteVectorSimilarityValuesSource.java | 4 ++++ .../lucene/search/FieldExistsQuery.java | 9 ++++---- .../FloatVectorSimilarityValuesSource.java | 4 ++++ .../lucene/search/KnnByteVectorQuery.java | 17 +++++++-------- .../lucene/search/KnnFloatVectorQuery.java | 17 +++++++-------- .../apache/lucene/search/VectorScorer.java | 8 +++++++ .../lucene/search/TestKnnByteVectorQuery.java | 15 +++++++++++++ .../search/TestKnnFloatVectorQuery.java | 14 +++++++++++++ 
...iversifyingChildrenByteKnnVectorQuery.java | 15 +++++++------ ...versifyingChildrenFloatKnnVectorQuery.java | 15 +++++++------ ...TestParentBlockJoinByteKnnVectorQuery.java | 21 +++++++++++++++++++ ...estParentBlockJoinFloatKnnVectorQuery.java | 16 ++++++++++++++ 22 files changed, 182 insertions(+), 51 deletions(-) diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt index 2d346c4167d4..efebe3614fb7 100644 --- a/lucene/CHANGES.txt +++ b/lucene/CHANGES.txt @@ -230,6 +230,8 @@ Bug Fixes * GITHUB#13154: Hunspell GeneratingSuggester: ensure there are never more than 100 roots to process (Peter Gromov) +* GITHUB#13162: Fix NPE when LeafReader return null VectorValues (Pan Guixin) + Other --------------------- diff --git a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/Lucene94HnswVectorsReader.java b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/Lucene94HnswVectorsReader.java index 47d56aec192f..1a54fff5913c 100644 --- a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/Lucene94HnswVectorsReader.java +++ b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/Lucene94HnswVectorsReader.java @@ -254,7 +254,7 @@ public ByteVectorValues getByteVectorValues(String field) throws IOException { + "\" is encoded as: " + fieldEntry.vectorEncoding + " expected: " - + VectorEncoding.FLOAT32); + + VectorEncoding.BYTE); } return OffHeapByteVectorValues.load(fieldEntry, vectorData); } diff --git a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene95/Lucene95HnswVectorsReader.java b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene95/Lucene95HnswVectorsReader.java index 961a77f1bb11..9c00c5ef93c0 100644 --- a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene95/Lucene95HnswVectorsReader.java +++ b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene95/Lucene95HnswVectorsReader.java @@ -270,7 
+270,7 @@ public ByteVectorValues getByteVectorValues(String field) throws IOException { + "\" is encoded as: " + fieldEntry.vectorEncoding + " expected: " - + VectorEncoding.FLOAT32); + + VectorEncoding.BYTE); } return OffHeapByteVectorValues.load( fieldEntry.ordToDocVectorValues, diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99FlatVectorsReader.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99FlatVectorsReader.java index 6f44bcaa545d..365883f18b80 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99FlatVectorsReader.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99FlatVectorsReader.java @@ -232,7 +232,7 @@ public ByteVectorValues getByteVectorValues(String field) throws IOException { + "\" is encoded as: " + fieldEntry.vectorEncoding + " expected: " - + VectorEncoding.FLOAT32); + + VectorEncoding.BYTE); } return OffHeapByteVectorValues.load( fieldEntry.ordToDoc, diff --git a/lucene/core/src/java/org/apache/lucene/index/ByteVectorValues.java b/lucene/core/src/java/org/apache/lucene/index/ByteVectorValues.java index e731e727aa8c..4792a7d84744 100644 --- a/lucene/core/src/java/org/apache/lucene/index/ByteVectorValues.java +++ b/lucene/core/src/java/org/apache/lucene/index/ByteVectorValues.java @@ -54,4 +54,25 @@ public final long cost() { * @return the vector value */ public abstract byte[] vectorValue() throws IOException; + + /** + * Checks the Vector Encoding of a field + * + * @throws IllegalStateException if {@code field} has vectors, but using a different encoding + * @lucene.internal + * @lucene.experimental + */ + public static void checkField(LeafReader in, String field) { + FieldInfo fi = in.getFieldInfos().fieldInfo(field); + if (fi != null && fi.hasVectorValues() && fi.getVectorEncoding() != VectorEncoding.BYTE) { + throw new IllegalStateException( + "Unexpected vector encoding (" + + fi.getVectorEncoding() + + ") for field " + + field + + "(expected=" + + 
VectorEncoding.BYTE + + ")"); + } + } } diff --git a/lucene/core/src/java/org/apache/lucene/index/CodecReader.java b/lucene/core/src/java/org/apache/lucene/index/CodecReader.java index c5ff5cf271d9..bd4718c9c4c5 100644 --- a/lucene/core/src/java/org/apache/lucene/index/CodecReader.java +++ b/lucene/core/src/java/org/apache/lucene/index/CodecReader.java @@ -246,7 +246,9 @@ public final void searchNearestVectors( String field, float[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException { ensureOpen(); FieldInfo fi = getFieldInfos().fieldInfo(field); - if (fi == null || fi.getVectorDimension() == 0) { + if (fi == null + || fi.getVectorDimension() == 0 + || fi.getVectorEncoding() != VectorEncoding.FLOAT32) { // Field does not exist or does not index vectors return; } @@ -258,7 +260,9 @@ public final void searchNearestVectors( String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException { ensureOpen(); FieldInfo fi = getFieldInfos().fieldInfo(field); - if (fi == null || fi.getVectorDimension() == 0) { + if (fi == null + || fi.getVectorDimension() == 0 + || fi.getVectorEncoding() != VectorEncoding.BYTE) { // Field does not exist or does not index vectors return; } diff --git a/lucene/core/src/java/org/apache/lucene/index/FloatVectorValues.java b/lucene/core/src/java/org/apache/lucene/index/FloatVectorValues.java index 0c3194bfac8f..61f0157ddfde 100644 --- a/lucene/core/src/java/org/apache/lucene/index/FloatVectorValues.java +++ b/lucene/core/src/java/org/apache/lucene/index/FloatVectorValues.java @@ -54,4 +54,25 @@ public final long cost() { * @return the vector value */ public abstract float[] vectorValue() throws IOException; + + /** + * Checks the Vector Encoding of a field + * + * @throws IllegalStateException if {@code field} has vectors, but using a different encoding + * @lucene.internal + * @lucene.experimental + */ + public static void checkField(LeafReader in, String field) { + FieldInfo fi = 
in.getFieldInfos().fieldInfo(field); + if (fi != null && fi.hasVectorValues() && fi.getVectorEncoding() != VectorEncoding.FLOAT32) { + throw new IllegalStateException( + "Unexpected vector encoding (" + + fi.getVectorEncoding() + + ") for field " + + field + + "(expected=" + + VectorEncoding.FLOAT32 + + ")"); + } + } } diff --git a/lucene/core/src/java/org/apache/lucene/index/LeafReader.java b/lucene/core/src/java/org/apache/lucene/index/LeafReader.java index df691df1c710..4856feedca3c 100644 --- a/lucene/core/src/java/org/apache/lucene/index/LeafReader.java +++ b/lucene/core/src/java/org/apache/lucene/index/LeafReader.java @@ -246,11 +246,11 @@ public final PostingsEnum postings(Term term) throws IOException { public final TopDocs searchNearestVectors( String field, float[] target, int k, Bits acceptDocs, int visitedLimit) throws IOException { FieldInfo fi = getFieldInfos().fieldInfo(field); - if (fi == null || fi.getVectorDimension() == 0) { - // The field does not exist or does not index vectors + FloatVectorValues floatVectorValues = getFloatVectorValues(fi.name); + if (floatVectorValues == null) { return TopDocsCollector.EMPTY_TOPDOCS; } - k = Math.min(k, getFloatVectorValues(fi.name).size()); + k = Math.min(k, floatVectorValues.size()); if (k == 0) { return TopDocsCollector.EMPTY_TOPDOCS; } @@ -287,11 +287,11 @@ public final TopDocs searchNearestVectors( public final TopDocs searchNearestVectors( String field, byte[] target, int k, Bits acceptDocs, int visitedLimit) throws IOException { FieldInfo fi = getFieldInfos().fieldInfo(field); - if (fi == null || fi.getVectorDimension() == 0) { - // The field does not exist or does not index vectors + ByteVectorValues byteVectorValues = getByteVectorValues(fi.name); + if (byteVectorValues == null) { return TopDocsCollector.EMPTY_TOPDOCS; } - k = Math.min(k, getByteVectorValues(fi.name).size()); + k = Math.min(k, byteVectorValues.size()); if (k == 0) { return TopDocsCollector.EMPTY_TOPDOCS; } diff --git 
a/lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java b/lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java index 525156a26660..d713b9151cf5 100644 --- a/lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java +++ b/lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java @@ -187,6 +187,9 @@ protected TopDocs exactSearch(LeafReaderContext context, DocIdSetIterator accept } VectorScorer vectorScorer = createVectorScorer(context, fi); + if (vectorScorer == null) { + return NO_RESULTS; + } HitQueue queue = new HitQueue(k, true); ScoreDoc topDoc = queue.top(); int doc; diff --git a/lucene/core/src/java/org/apache/lucene/search/AbstractVectorSimilarityQuery.java b/lucene/core/src/java/org/apache/lucene/search/AbstractVectorSimilarityQuery.java index 03dea82578c1..4aea90cee38d 100644 --- a/lucene/core/src/java/org/apache/lucene/search/AbstractVectorSimilarityQuery.java +++ b/lucene/core/src/java/org/apache/lucene/search/AbstractVectorSimilarityQuery.java @@ -105,6 +105,9 @@ public Scorer scorer(LeafReaderContext context) throws IOException { if (filterWeight == null) { // Return exhaustive results TopDocs results = approximateSearch(context, liveDocs, Integer.MAX_VALUE); + if (results.scoreDocs.length == 0) { + return null; + } return VectorSimilarityScorer.fromScoreDocs(this, boost, results.scoreDocs); } @@ -148,6 +151,8 @@ protected boolean match(int doc) { createVectorScorer(context), new BitSetIterator(acceptDocs, cardinality), resultSimilarity); + } else if (results.scoreDocs.length == 0) { + return null; } else { // Return an iterator over the collected results return VectorSimilarityScorer.fromScoreDocs(this, boost, results.scoreDocs); diff --git a/lucene/core/src/java/org/apache/lucene/search/ByteVectorSimilarityValuesSource.java b/lucene/core/src/java/org/apache/lucene/search/ByteVectorSimilarityValuesSource.java index 06afe3d3a54f..89d029ed1409 100644 --- 
a/lucene/core/src/java/org/apache/lucene/search/ByteVectorSimilarityValuesSource.java +++ b/lucene/core/src/java/org/apache/lucene/search/ByteVectorSimilarityValuesSource.java @@ -39,6 +39,10 @@ public ByteVectorSimilarityValuesSource(byte[] vector, String fieldName) { @Override public DoubleValues getValues(LeafReaderContext ctx, DoubleValues scores) throws IOException { final ByteVectorValues vectorValues = ctx.reader().getByteVectorValues(fieldName); + if (vectorValues == null) { + ByteVectorValues.checkField(ctx.reader(), fieldName); + return DoubleValues.EMPTY; + } VectorSimilarityFunction function = ctx.reader().getFieldInfos().fieldInfo(fieldName).getVectorSimilarityFunction(); return new DoubleValues() { diff --git a/lucene/core/src/java/org/apache/lucene/search/FieldExistsQuery.java b/lucene/core/src/java/org/apache/lucene/search/FieldExistsQuery.java index 537a0f23ce08..8f482874adb6 100644 --- a/lucene/core/src/java/org/apache/lucene/search/FieldExistsQuery.java +++ b/lucene/core/src/java/org/apache/lucene/search/FieldExistsQuery.java @@ -128,12 +128,13 @@ public Query rewrite(IndexSearcher indexSearcher) throws IOException { break; } } else if (fieldInfo.getVectorDimension() != 0) { // the field indexes vectors - int numVectors = + DocIdSetIterator vectorValues = switch (fieldInfo.getVectorEncoding()) { - case FLOAT32 -> leaf.getFloatVectorValues(field).size(); - case BYTE -> leaf.getByteVectorValues(field).size(); + case FLOAT32 -> leaf.getFloatVectorValues(field); + case BYTE -> leaf.getByteVectorValues(field); }; - if (numVectors != leaf.maxDoc()) { + assert vectorValues != null : "unexpected null vector values"; + if (vectorValues != null && vectorValues.cost() != leaf.maxDoc()) { allReadersRewritable = false; break; } diff --git a/lucene/core/src/java/org/apache/lucene/search/FloatVectorSimilarityValuesSource.java b/lucene/core/src/java/org/apache/lucene/search/FloatVectorSimilarityValuesSource.java index b914ce75e20b..8198467fc508 100644 --- 
a/lucene/core/src/java/org/apache/lucene/search/FloatVectorSimilarityValuesSource.java +++ b/lucene/core/src/java/org/apache/lucene/search/FloatVectorSimilarityValuesSource.java @@ -40,6 +40,10 @@ public FloatVectorSimilarityValuesSource(float[] vector, String fieldName) { @Override public DoubleValues getValues(LeafReaderContext ctx, DoubleValues scores) throws IOException { final FloatVectorValues vectorValues = ctx.reader().getFloatVectorValues(fieldName); + if (vectorValues == null) { + FloatVectorValues.checkField(ctx.reader(), fieldName); + return DoubleValues.EMPTY; + } VectorSimilarityFunction function = ctx.reader().getFieldInfos().fieldInfo(fieldName).getVectorSimilarityFunction(); return new DoubleValues() { diff --git a/lucene/core/src/java/org/apache/lucene/search/KnnByteVectorQuery.java b/lucene/core/src/java/org/apache/lucene/search/KnnByteVectorQuery.java index a07342bff33f..ba94a243e07f 100644 --- a/lucene/core/src/java/org/apache/lucene/search/KnnByteVectorQuery.java +++ b/lucene/core/src/java/org/apache/lucene/search/KnnByteVectorQuery.java @@ -21,9 +21,9 @@ import java.util.Objects; import org.apache.lucene.codecs.KnnVectorsReader; import org.apache.lucene.document.KnnFloatVectorField; +import org.apache.lucene.index.ByteVectorValues; import org.apache.lucene.index.FieldInfo; import org.apache.lucene.index.LeafReaderContext; -import org.apache.lucene.index.VectorEncoding; import org.apache.lucene.search.knn.KnnCollectorManager; import org.apache.lucene.util.ArrayUtil; import org.apache.lucene.util.Bits; @@ -83,13 +83,13 @@ protected TopDocs approximateSearch( KnnCollectorManager knnCollectorManager) throws IOException { KnnCollector knnCollector = knnCollectorManager.newCollector(visitedLimit, context); - FieldInfo fi = context.reader().getFieldInfos().fieldInfo(field); - if (fi == null || fi.getVectorDimension() == 0) { - // The field does not exist or does not index vectors - return TopDocsCollector.EMPTY_TOPDOCS; + ByteVectorValues 
byteVectorValues = context.reader().getByteVectorValues(field); + if (byteVectorValues == null) { + ByteVectorValues.checkField(context.reader(), field); + return NO_RESULTS; } - if (Math.min(knnCollector.k(), context.reader().getByteVectorValues(fi.name).size()) == 0) { - return TopDocsCollector.EMPTY_TOPDOCS; + if (Math.min(knnCollector.k(), byteVectorValues.size()) == 0) { + return NO_RESULTS; } context.reader().searchNearestVectors(field, target, knnCollector, acceptDocs); TopDocs results = knnCollector.topDocs(); @@ -98,9 +98,6 @@ protected TopDocs approximateSearch( @Override VectorScorer createVectorScorer(LeafReaderContext context, FieldInfo fi) throws IOException { - if (fi.getVectorEncoding() != VectorEncoding.BYTE) { - return null; - } return VectorScorer.create(context, fi, target); } diff --git a/lucene/core/src/java/org/apache/lucene/search/KnnFloatVectorQuery.java b/lucene/core/src/java/org/apache/lucene/search/KnnFloatVectorQuery.java index 3d8430a45ff4..e6e38192e746 100644 --- a/lucene/core/src/java/org/apache/lucene/search/KnnFloatVectorQuery.java +++ b/lucene/core/src/java/org/apache/lucene/search/KnnFloatVectorQuery.java @@ -22,8 +22,8 @@ import org.apache.lucene.codecs.KnnVectorsReader; import org.apache.lucene.document.KnnFloatVectorField; import org.apache.lucene.index.FieldInfo; +import org.apache.lucene.index.FloatVectorValues; import org.apache.lucene.index.LeafReaderContext; -import org.apache.lucene.index.VectorEncoding; import org.apache.lucene.search.knn.KnnCollectorManager; import org.apache.lucene.util.ArrayUtil; import org.apache.lucene.util.Bits; @@ -84,13 +84,13 @@ protected TopDocs approximateSearch( KnnCollectorManager knnCollectorManager) throws IOException { KnnCollector knnCollector = knnCollectorManager.newCollector(visitedLimit, context); - FieldInfo fi = context.reader().getFieldInfos().fieldInfo(field); - if (fi == null || fi.getVectorDimension() == 0) { - // The field does not exist or does not index vectors - return 
TopDocsCollector.EMPTY_TOPDOCS; + FloatVectorValues floatVectorValues = context.reader().getFloatVectorValues(field); + if (floatVectorValues == null) { + FloatVectorValues.checkField(context.reader(), field); + return NO_RESULTS; } - if (Math.min(knnCollector.k(), context.reader().getFloatVectorValues(fi.name).size()) == 0) { - return TopDocsCollector.EMPTY_TOPDOCS; + if (Math.min(knnCollector.k(), floatVectorValues.size()) == 0) { + return NO_RESULTS; } context.reader().searchNearestVectors(field, target, knnCollector, acceptDocs); TopDocs results = knnCollector.topDocs(); @@ -99,9 +99,6 @@ protected TopDocs approximateSearch( @Override VectorScorer createVectorScorer(LeafReaderContext context, FieldInfo fi) throws IOException { - if (fi.getVectorEncoding() != VectorEncoding.FLOAT32) { - return null; - } return VectorScorer.create(context, fi, target); } diff --git a/lucene/core/src/java/org/apache/lucene/search/VectorScorer.java b/lucene/core/src/java/org/apache/lucene/search/VectorScorer.java index 29b76e53af26..249c7e88d875 100644 --- a/lucene/core/src/java/org/apache/lucene/search/VectorScorer.java +++ b/lucene/core/src/java/org/apache/lucene/search/VectorScorer.java @@ -41,6 +41,10 @@ abstract class VectorScorer { static FloatVectorScorer create(LeafReaderContext context, FieldInfo fi, float[] query) throws IOException { FloatVectorValues values = context.reader().getFloatVectorValues(fi.name); + if (values == null) { + FloatVectorValues.checkField(context.reader(), fi.name); + return null; + } final VectorSimilarityFunction similarity = fi.getVectorSimilarityFunction(); return new FloatVectorScorer(values, query, similarity); } @@ -48,6 +52,10 @@ static FloatVectorScorer create(LeafReaderContext context, FieldInfo fi, float[] static ByteVectorScorer create(LeafReaderContext context, FieldInfo fi, byte[] query) throws IOException { ByteVectorValues values = context.reader().getByteVectorValues(fi.name); + if (values == null) { + 
ByteVectorValues.checkField(context.reader(), fi.name); + return null; + } VectorSimilarityFunction similarity = fi.getVectorSimilarityFunction(); return new ByteVectorScorer(values, query, similarity); } diff --git a/lucene/core/src/test/org/apache/lucene/search/TestKnnByteVectorQuery.java b/lucene/core/src/test/org/apache/lucene/search/TestKnnByteVectorQuery.java index 8badf34f2bbe..1b912ae7aad4 100644 --- a/lucene/core/src/test/org/apache/lucene/search/TestKnnByteVectorQuery.java +++ b/lucene/core/src/test/org/apache/lucene/search/TestKnnByteVectorQuery.java @@ -87,6 +87,21 @@ public void testGetTarget() { assertNotSame(queryVectorBytes, q1.getTargetCopy()); } + public void testVectorEncodingMismatch() throws IOException { + try (Directory indexStore = + getIndexStore("field", new float[] {0, 1}, new float[] {1, 2}, new float[] {0, 0}); + IndexReader reader = DirectoryReader.open(indexStore)) { + Query filter = null; + if (random().nextBoolean()) { + filter = new MatchAllDocsQuery(); + } + AbstractKnnVectorQuery query = + new KnnFloatVectorQuery("field", new float[] {0, 1}, 10, filter); + IndexSearcher searcher = newSearcher(reader); + expectThrows(IllegalStateException.class, () -> searcher.search(query, 10)); + } + } + private static class ThrowingKnnVectorQuery extends KnnByteVectorQuery { public ThrowingKnnVectorQuery(String field, byte[] target, int k, Query filter) { diff --git a/lucene/core/src/test/org/apache/lucene/search/TestKnnFloatVectorQuery.java b/lucene/core/src/test/org/apache/lucene/search/TestKnnFloatVectorQuery.java index 325b65ef9b2c..60969663361d 100644 --- a/lucene/core/src/test/org/apache/lucene/search/TestKnnFloatVectorQuery.java +++ b/lucene/core/src/test/org/apache/lucene/search/TestKnnFloatVectorQuery.java @@ -79,6 +79,20 @@ public void testToString() throws IOException { } } + public void testVectorEncodingMismatch() throws IOException { + try (Directory indexStore = + getIndexStore("field", new float[] {0, 1}, new float[] {1, 2}, new 
float[] {0, 0}); + IndexReader reader = DirectoryReader.open(indexStore)) { + Query filter = null; + if (random().nextBoolean()) { + filter = new MatchAllDocsQuery(); + } + AbstractKnnVectorQuery query = new KnnByteVectorQuery("field", new byte[] {0, 1}, 10, filter); + IndexSearcher searcher = newSearcher(reader); + expectThrows(IllegalStateException.class, () -> searcher.search(query, 10)); + } + } + public void testGetTarget() { float[] queryVector = new float[] {0, 1}; KnnFloatVectorQuery q1 = new KnnFloatVectorQuery("f1", queryVector, 10); diff --git a/lucene/join/src/java/org/apache/lucene/search/join/DiversifyingChildrenByteKnnVectorQuery.java b/lucene/join/src/java/org/apache/lucene/search/join/DiversifyingChildrenByteKnnVectorQuery.java index d3e9deac7a6f..833e36197ccc 100644 --- a/lucene/join/src/java/org/apache/lucene/search/join/DiversifyingChildrenByteKnnVectorQuery.java +++ b/lucene/join/src/java/org/apache/lucene/search/join/DiversifyingChildrenByteKnnVectorQuery.java @@ -22,7 +22,6 @@ import org.apache.lucene.index.ByteVectorValues; import org.apache.lucene.index.FieldInfo; import org.apache.lucene.index.LeafReaderContext; -import org.apache.lucene.index.VectorEncoding; import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.HitQueue; @@ -80,21 +79,21 @@ public DiversifyingChildrenByteKnnVectorQuery( @Override protected TopDocs exactSearch(LeafReaderContext context, DocIdSetIterator acceptIterator) throws IOException { - FieldInfo fi = context.reader().getFieldInfos().fieldInfo(field); - if (fi == null || fi.getVectorDimension() == 0) { - // The field does not exist or does not index vectors + ByteVectorValues byteVectorValues = context.reader().getByteVectorValues(field); + if (byteVectorValues == null) { + ByteVectorValues.checkField(context.reader(), field); return NO_RESULTS; } - if (fi.getVectorEncoding() != VectorEncoding.BYTE) { - return null; - } + BitSet 
parentBitSet = parentsFilter.getBitSet(context); if (parentBitSet == null) { return NO_RESULTS; } + + FieldInfo fi = context.reader().getFieldInfos().fieldInfo(field); ParentBlockJoinByteVectorScorer vectorScorer = new ParentBlockJoinByteVectorScorer( - context.reader().getByteVectorValues(field), + byteVectorValues, acceptIterator, parentBitSet, query, diff --git a/lucene/join/src/java/org/apache/lucene/search/join/DiversifyingChildrenFloatKnnVectorQuery.java b/lucene/join/src/java/org/apache/lucene/search/join/DiversifyingChildrenFloatKnnVectorQuery.java index 0520f180025c..22877167d6b4 100644 --- a/lucene/join/src/java/org/apache/lucene/search/join/DiversifyingChildrenFloatKnnVectorQuery.java +++ b/lucene/join/src/java/org/apache/lucene/search/join/DiversifyingChildrenFloatKnnVectorQuery.java @@ -22,7 +22,6 @@ import org.apache.lucene.index.FieldInfo; import org.apache.lucene.index.FloatVectorValues; import org.apache.lucene.index.LeafReaderContext; -import org.apache.lucene.index.VectorEncoding; import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.HitQueue; @@ -80,21 +79,21 @@ public DiversifyingChildrenFloatKnnVectorQuery( @Override protected TopDocs exactSearch(LeafReaderContext context, DocIdSetIterator acceptIterator) throws IOException { - FieldInfo fi = context.reader().getFieldInfos().fieldInfo(field); - if (fi == null || fi.getVectorDimension() == 0) { - // The field does not exist or does not index vectors + FloatVectorValues floatVectorValues = context.reader().getFloatVectorValues(field); + if (floatVectorValues == null) { + FloatVectorValues.checkField(context.reader(), field); return NO_RESULTS; } - if (fi.getVectorEncoding() != VectorEncoding.FLOAT32) { - return null; - } + BitSet parentBitSet = parentsFilter.getBitSet(context); if (parentBitSet == null) { return NO_RESULTS; } + + FieldInfo fi = context.reader().getFieldInfos().fieldInfo(field); 
DiversifyingChildrenFloatVectorScorer vectorScorer = new DiversifyingChildrenFloatVectorScorer( - context.reader().getFloatVectorValues(field), + floatVectorValues, acceptIterator, parentBitSet, query, diff --git a/lucene/join/src/test/org/apache/lucene/search/join/TestParentBlockJoinByteKnnVectorQuery.java b/lucene/join/src/test/org/apache/lucene/search/join/TestParentBlockJoinByteKnnVectorQuery.java index 08c3434e8ede..5b70763d044b 100644 --- a/lucene/join/src/test/org/apache/lucene/search/join/TestParentBlockJoinByteKnnVectorQuery.java +++ b/lucene/join/src/test/org/apache/lucene/search/join/TestParentBlockJoinByteKnnVectorQuery.java @@ -17,10 +17,17 @@ package org.apache.lucene.search.join; +import java.io.IOException; import org.apache.lucene.document.Field; import org.apache.lucene.document.KnnByteVectorField; +import org.apache.lucene.index.DirectoryReader; +import org.apache.lucene.index.IndexReader; +import org.apache.lucene.index.Term; import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.Query; +import org.apache.lucene.search.TermQuery; +import org.apache.lucene.store.Directory; public class TestParentBlockJoinByteKnnVectorQuery extends ParentBlockJoinKnnVectorQueryTestCase { @@ -46,6 +53,20 @@ Field getKnnVectorField( return new KnnByteVectorField(name, fromFloat(vector), vectorSimilarityFunction); } + public void testVectorEncodingMismatch() throws IOException { + try (Directory indexStore = + getIndexStore("field", new float[] {0, 1}, new float[] {1, 2}, new float[] {0, 0}); + IndexReader reader = DirectoryReader.open(indexStore)) { + IndexSearcher searcher = newSearcher(reader); + Query filter = new TermQuery(new Term("other", "value")); + BitSetProducer parentFilter = parentFilter(reader); + Query kvq = + new DiversifyingChildrenFloatKnnVectorQuery( + "field", new float[] {1, 2}, filter, 2, parentFilter); + assertThrows(IllegalStateException.class, () -> 
searcher.search(kvq, 3)); + } + } + private static byte[] fromFloat(float[] queryVector) { byte[] query = new byte[queryVector.length]; for (int i = 0; i < queryVector.length; i++) { diff --git a/lucene/join/src/test/org/apache/lucene/search/join/TestParentBlockJoinFloatKnnVectorQuery.java b/lucene/join/src/test/org/apache/lucene/search/join/TestParentBlockJoinFloatKnnVectorQuery.java index aa4983cb781a..fee519a1e1d6 100644 --- a/lucene/join/src/test/org/apache/lucene/search/join/TestParentBlockJoinFloatKnnVectorQuery.java +++ b/lucene/join/src/test/org/apache/lucene/search/join/TestParentBlockJoinFloatKnnVectorQuery.java @@ -29,9 +29,11 @@ import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.IndexWriter; import org.apache.lucene.index.IndexWriterConfig; +import org.apache.lucene.index.Term; import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.Query; +import org.apache.lucene.search.TermQuery; import org.apache.lucene.store.Directory; public class TestParentBlockJoinFloatKnnVectorQuery extends ParentBlockJoinKnnVectorQueryTestCase { @@ -47,6 +49,20 @@ Query getParentJoinKnnQuery( fieldName, queryVector, childFilter, k, parentBitSet); } + public void testVectorEncodingMismatch() throws IOException { + try (Directory indexStore = + getIndexStore("field", new float[] {0, 1}, new float[] {1, 2}, new float[] {0, 0}); + IndexReader reader = DirectoryReader.open(indexStore)) { + IndexSearcher searcher = newSearcher(reader); + Query filter = new TermQuery(new Term("other", "value")); + BitSetProducer parentFilter = parentFilter(reader); + Query kvq = + new DiversifyingChildrenByteKnnVectorQuery( + "field", new byte[] {1, 2}, filter, 2, parentFilter); + assertThrows(IllegalStateException.class, () -> searcher.search(kvq, 3)); + } + } + public void testScoreCosine() throws IOException { try (Directory d = newDirectory()) { try (IndexWriter w =