Skip to content

Commit

Permalink
Fix NPE when LeafReader return null VectorValues (#13162)
Browse files Browse the repository at this point in the history
### Description
`LeafReader#getXXXVectorValues` may return null value.

**Reproduction**:
```
public class TestKnnByteVectorQuery extends BaseKnnVectorQueryTestCase {
  public void testVectorEncodingMismatch() throws IOException {
    try (Directory indexStore =
        getIndexStore("field", new float[] {0, 1}, new float[] {1, 2}, new float[] {0, 0});
        IndexReader reader = DirectoryReader.open(indexStore)) {
      AbstractKnnVectorQuery query =
          new KnnFloatVectorQuery("field", new float[] {0, 1}, 10);
      IndexSearcher searcher = newSearcher(reader);
      searcher.search(query, 10);
    }
  }
}
```
**Output**:
```
java.lang.NullPointerException: Cannot invoke "org.apache.lucene.index.FloatVectorValues.size()" because the return value of "org.apache.lucene.index.LeafReader.getFloatVectorValues(String)" is null
```
  • Loading branch information
bugmakerrrrrr authored Mar 11, 2024
1 parent 6445bc0 commit 5b5815a
Show file tree
Hide file tree
Showing 22 changed files with 182 additions and 51 deletions.
2 changes: 2 additions & 0 deletions lucene/CHANGES.txt
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,8 @@ Bug Fixes

* GITHUB#13154: Hunspell GeneratingSuggester: ensure there are never more than 100 roots to process (Peter Gromov)

* GITHUB#13162: Fix NPE when LeafReader return null VectorValues (Pan Guixin)

Other
---------------------

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,7 @@ public ByteVectorValues getByteVectorValues(String field) throws IOException {
+ "\" is encoded as: "
+ fieldEntry.vectorEncoding
+ " expected: "
+ VectorEncoding.FLOAT32);
+ VectorEncoding.BYTE);
}
return OffHeapByteVectorValues.load(fieldEntry, vectorData);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,7 @@ public ByteVectorValues getByteVectorValues(String field) throws IOException {
+ "\" is encoded as: "
+ fieldEntry.vectorEncoding
+ " expected: "
+ VectorEncoding.FLOAT32);
+ VectorEncoding.BYTE);
}
return OffHeapByteVectorValues.load(
fieldEntry.ordToDocVectorValues,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@ public ByteVectorValues getByteVectorValues(String field) throws IOException {
+ "\" is encoded as: "
+ fieldEntry.vectorEncoding
+ " expected: "
+ VectorEncoding.FLOAT32);
+ VectorEncoding.BYTE);
}
return OffHeapByteVectorValues.load(
fieldEntry.ordToDoc,
Expand Down
21 changes: 21 additions & 0 deletions lucene/core/src/java/org/apache/lucene/index/ByteVectorValues.java
Original file line number Diff line number Diff line change
Expand Up @@ -54,4 +54,25 @@ public final long cost() {
* @return the vector value
*/
public abstract byte[] vectorValue() throws IOException;

/**
* Checks the Vector Encoding of a field
*
* @throws IllegalStateException if {@code field} has vectors, but using a different encoding
* @lucene.internal
* @lucene.experimental
*/
public static void checkField(LeafReader in, String field) {
FieldInfo fi = in.getFieldInfos().fieldInfo(field);
if (fi != null && fi.hasVectorValues() && fi.getVectorEncoding() != VectorEncoding.BYTE) {
throw new IllegalStateException(
"Unexpected vector encoding ("
+ fi.getVectorEncoding()
+ ") for field "
+ field
+ "(expected="
+ VectorEncoding.BYTE
+ ")");
}
}
}
8 changes: 6 additions & 2 deletions lucene/core/src/java/org/apache/lucene/index/CodecReader.java
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,9 @@ public final void searchNearestVectors(
String field, float[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException {
ensureOpen();
FieldInfo fi = getFieldInfos().fieldInfo(field);
if (fi == null || fi.getVectorDimension() == 0) {
if (fi == null
|| fi.getVectorDimension() == 0
|| fi.getVectorEncoding() != VectorEncoding.FLOAT32) {
// Field does not exist or does not index vectors
return;
}
Expand All @@ -258,7 +260,9 @@ public final void searchNearestVectors(
String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException {
ensureOpen();
FieldInfo fi = getFieldInfos().fieldInfo(field);
if (fi == null || fi.getVectorDimension() == 0) {
if (fi == null
|| fi.getVectorDimension() == 0
|| fi.getVectorEncoding() != VectorEncoding.BYTE) {
// Field does not exist or does not index vectors
return;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,4 +54,25 @@ public final long cost() {
* @return the vector value
*/
public abstract float[] vectorValue() throws IOException;

/**
* Checks the Vector Encoding of a field
*
* @throws IllegalStateException if {@code field} has vectors, but using a different encoding
* @lucene.internal
* @lucene.experimental
*/
public static void checkField(LeafReader in, String field) {
FieldInfo fi = in.getFieldInfos().fieldInfo(field);
if (fi != null && fi.hasVectorValues() && fi.getVectorEncoding() != VectorEncoding.FLOAT32) {
throw new IllegalStateException(
"Unexpected vector encoding ("
+ fi.getVectorEncoding()
+ ") for field "
+ field
+ "(expected="
+ VectorEncoding.FLOAT32
+ ")");
}
}
}
12 changes: 6 additions & 6 deletions lucene/core/src/java/org/apache/lucene/index/LeafReader.java
Original file line number Diff line number Diff line change
Expand Up @@ -246,11 +246,11 @@ public final PostingsEnum postings(Term term) throws IOException {
public final TopDocs searchNearestVectors(
String field, float[] target, int k, Bits acceptDocs, int visitedLimit) throws IOException {
FieldInfo fi = getFieldInfos().fieldInfo(field);
if (fi == null || fi.getVectorDimension() == 0) {
// The field does not exist or does not index vectors
FloatVectorValues floatVectorValues = getFloatVectorValues(fi.name);
if (floatVectorValues == null) {
return TopDocsCollector.EMPTY_TOPDOCS;
}
k = Math.min(k, getFloatVectorValues(fi.name).size());
k = Math.min(k, floatVectorValues.size());
if (k == 0) {
return TopDocsCollector.EMPTY_TOPDOCS;
}
Expand Down Expand Up @@ -287,11 +287,11 @@ public final TopDocs searchNearestVectors(
public final TopDocs searchNearestVectors(
String field, byte[] target, int k, Bits acceptDocs, int visitedLimit) throws IOException {
FieldInfo fi = getFieldInfos().fieldInfo(field);
if (fi == null || fi.getVectorDimension() == 0) {
// The field does not exist or does not index vectors
ByteVectorValues byteVectorValues = getByteVectorValues(fi.name);
if (byteVectorValues == null) {
return TopDocsCollector.EMPTY_TOPDOCS;
}
k = Math.min(k, getByteVectorValues(fi.name).size());
k = Math.min(k, byteVectorValues.size());
if (k == 0) {
return TopDocsCollector.EMPTY_TOPDOCS;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,9 @@ protected TopDocs exactSearch(LeafReaderContext context, DocIdSetIterator accept
}

VectorScorer vectorScorer = createVectorScorer(context, fi);
if (vectorScorer == null) {
return NO_RESULTS;
}
HitQueue queue = new HitQueue(k, true);
ScoreDoc topDoc = queue.top();
int doc;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,9 @@ public Scorer scorer(LeafReaderContext context) throws IOException {
if (filterWeight == null) {
// Return exhaustive results
TopDocs results = approximateSearch(context, liveDocs, Integer.MAX_VALUE);
if (results.scoreDocs.length == 0) {
return null;
}
return VectorSimilarityScorer.fromScoreDocs(this, boost, results.scoreDocs);
}

Expand Down Expand Up @@ -148,6 +151,8 @@ protected boolean match(int doc) {
createVectorScorer(context),
new BitSetIterator(acceptDocs, cardinality),
resultSimilarity);
} else if (results.scoreDocs.length == 0) {
return null;
} else {
// Return an iterator over the collected results
return VectorSimilarityScorer.fromScoreDocs(this, boost, results.scoreDocs);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,10 @@ public ByteVectorSimilarityValuesSource(byte[] vector, String fieldName) {
@Override
public DoubleValues getValues(LeafReaderContext ctx, DoubleValues scores) throws IOException {
final ByteVectorValues vectorValues = ctx.reader().getByteVectorValues(fieldName);
if (vectorValues == null) {
ByteVectorValues.checkField(ctx.reader(), fieldName);
return DoubleValues.EMPTY;
}
VectorSimilarityFunction function =
ctx.reader().getFieldInfos().fieldInfo(fieldName).getVectorSimilarityFunction();
return new DoubleValues() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -128,12 +128,13 @@ public Query rewrite(IndexSearcher indexSearcher) throws IOException {
break;
}
} else if (fieldInfo.getVectorDimension() != 0) { // the field indexes vectors
int numVectors =
DocIdSetIterator vectorValues =
switch (fieldInfo.getVectorEncoding()) {
case FLOAT32 -> leaf.getFloatVectorValues(field).size();
case BYTE -> leaf.getByteVectorValues(field).size();
case FLOAT32 -> leaf.getFloatVectorValues(field);
case BYTE -> leaf.getByteVectorValues(field);
};
if (numVectors != leaf.maxDoc()) {
assert vectorValues != null : "unexpected null vector values";
if (vectorValues != null && vectorValues.cost() != leaf.maxDoc()) {
allReadersRewritable = false;
break;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,10 @@ public FloatVectorSimilarityValuesSource(float[] vector, String fieldName) {
@Override
public DoubleValues getValues(LeafReaderContext ctx, DoubleValues scores) throws IOException {
final FloatVectorValues vectorValues = ctx.reader().getFloatVectorValues(fieldName);
if (vectorValues == null) {
FloatVectorValues.checkField(ctx.reader(), fieldName);
return DoubleValues.EMPTY;
}
VectorSimilarityFunction function =
ctx.reader().getFieldInfos().fieldInfo(fieldName).getVectorSimilarityFunction();
return new DoubleValues() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@
import java.util.Objects;
import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.document.KnnFloatVectorField;
import org.apache.lucene.index.ByteVectorValues;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.search.knn.KnnCollectorManager;
import org.apache.lucene.util.ArrayUtil;
import org.apache.lucene.util.Bits;
Expand Down Expand Up @@ -83,13 +83,13 @@ protected TopDocs approximateSearch(
KnnCollectorManager knnCollectorManager)
throws IOException {
KnnCollector knnCollector = knnCollectorManager.newCollector(visitedLimit, context);
FieldInfo fi = context.reader().getFieldInfos().fieldInfo(field);
if (fi == null || fi.getVectorDimension() == 0) {
// The field does not exist or does not index vectors
return TopDocsCollector.EMPTY_TOPDOCS;
ByteVectorValues byteVectorValues = context.reader().getByteVectorValues(field);
if (byteVectorValues == null) {
ByteVectorValues.checkField(context.reader(), field);
return NO_RESULTS;
}
if (Math.min(knnCollector.k(), context.reader().getByteVectorValues(fi.name).size()) == 0) {
return TopDocsCollector.EMPTY_TOPDOCS;
if (Math.min(knnCollector.k(), byteVectorValues.size()) == 0) {
return NO_RESULTS;
}
context.reader().searchNearestVectors(field, target, knnCollector, acceptDocs);
TopDocs results = knnCollector.topDocs();
Expand All @@ -98,9 +98,6 @@ protected TopDocs approximateSearch(

@Override
VectorScorer createVectorScorer(LeafReaderContext context, FieldInfo fi) throws IOException {
if (fi.getVectorEncoding() != VectorEncoding.BYTE) {
return null;
}
return VectorScorer.create(context, fi, target);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@
import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.document.KnnFloatVectorField;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.search.knn.KnnCollectorManager;
import org.apache.lucene.util.ArrayUtil;
import org.apache.lucene.util.Bits;
Expand Down Expand Up @@ -84,13 +84,13 @@ protected TopDocs approximateSearch(
KnnCollectorManager knnCollectorManager)
throws IOException {
KnnCollector knnCollector = knnCollectorManager.newCollector(visitedLimit, context);
FieldInfo fi = context.reader().getFieldInfos().fieldInfo(field);
if (fi == null || fi.getVectorDimension() == 0) {
// The field does not exist or does not index vectors
return TopDocsCollector.EMPTY_TOPDOCS;
FloatVectorValues floatVectorValues = context.reader().getFloatVectorValues(field);
if (floatVectorValues == null) {
FloatVectorValues.checkField(context.reader(), field);
return NO_RESULTS;
}
if (Math.min(knnCollector.k(), context.reader().getFloatVectorValues(fi.name).size()) == 0) {
return TopDocsCollector.EMPTY_TOPDOCS;
if (Math.min(knnCollector.k(), floatVectorValues.size()) == 0) {
return NO_RESULTS;
}
context.reader().searchNearestVectors(field, target, knnCollector, acceptDocs);
TopDocs results = knnCollector.topDocs();
Expand All @@ -99,9 +99,6 @@ protected TopDocs approximateSearch(

@Override
VectorScorer createVectorScorer(LeafReaderContext context, FieldInfo fi) throws IOException {
if (fi.getVectorEncoding() != VectorEncoding.FLOAT32) {
return null;
}
return VectorScorer.create(context, fi, target);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,13 +41,21 @@ abstract class VectorScorer {
static FloatVectorScorer create(LeafReaderContext context, FieldInfo fi, float[] query)
throws IOException {
FloatVectorValues values = context.reader().getFloatVectorValues(fi.name);
if (values == null) {
FloatVectorValues.checkField(context.reader(), fi.name);
return null;
}
final VectorSimilarityFunction similarity = fi.getVectorSimilarityFunction();
return new FloatVectorScorer(values, query, similarity);
}

static ByteVectorScorer create(LeafReaderContext context, FieldInfo fi, byte[] query)
throws IOException {
ByteVectorValues values = context.reader().getByteVectorValues(fi.name);
if (values == null) {
ByteVectorValues.checkField(context.reader(), fi.name);
return null;
}
VectorSimilarityFunction similarity = fi.getVectorSimilarityFunction();
return new ByteVectorScorer(values, query, similarity);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,21 @@ public void testGetTarget() {
assertNotSame(queryVectorBytes, q1.getTargetCopy());
}

public void testVectorEncodingMismatch() throws IOException {
try (Directory indexStore =
getIndexStore("field", new float[] {0, 1}, new float[] {1, 2}, new float[] {0, 0});
IndexReader reader = DirectoryReader.open(indexStore)) {
Query filter = null;
if (random().nextBoolean()) {
filter = new MatchAllDocsQuery();
}
AbstractKnnVectorQuery query =
new KnnFloatVectorQuery("field", new float[] {0, 1}, 10, filter);
IndexSearcher searcher = newSearcher(reader);
expectThrows(IllegalStateException.class, () -> searcher.search(query, 10));
}
}

private static class ThrowingKnnVectorQuery extends KnnByteVectorQuery {

public ThrowingKnnVectorQuery(String field, byte[] target, int k, Query filter) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,20 @@ public void testToString() throws IOException {
}
}

public void testVectorEncodingMismatch() throws IOException {
try (Directory indexStore =
getIndexStore("field", new float[] {0, 1}, new float[] {1, 2}, new float[] {0, 0});
IndexReader reader = DirectoryReader.open(indexStore)) {
Query filter = null;
if (random().nextBoolean()) {
filter = new MatchAllDocsQuery();
}
AbstractKnnVectorQuery query = new KnnByteVectorQuery("field", new byte[] {0, 1}, 10, filter);
IndexSearcher searcher = newSearcher(reader);
expectThrows(IllegalStateException.class, () -> searcher.search(query, 10));
}
}

public void testGetTarget() {
float[] queryVector = new float[] {0, 1};
KnnFloatVectorQuery q1 = new KnnFloatVectorQuery("f1", queryVector, 10);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
import org.apache.lucene.index.ByteVectorValues;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.HitQueue;
Expand Down Expand Up @@ -80,21 +79,21 @@ public DiversifyingChildrenByteKnnVectorQuery(
@Override
protected TopDocs exactSearch(LeafReaderContext context, DocIdSetIterator acceptIterator)
throws IOException {
FieldInfo fi = context.reader().getFieldInfos().fieldInfo(field);
if (fi == null || fi.getVectorDimension() == 0) {
// The field does not exist or does not index vectors
ByteVectorValues byteVectorValues = context.reader().getByteVectorValues(field);
if (byteVectorValues == null) {
ByteVectorValues.checkField(context.reader(), field);
return NO_RESULTS;
}
if (fi.getVectorEncoding() != VectorEncoding.BYTE) {
return null;
}

BitSet parentBitSet = parentsFilter.getBitSet(context);
if (parentBitSet == null) {
return NO_RESULTS;
}

FieldInfo fi = context.reader().getFieldInfos().fieldInfo(field);
ParentBlockJoinByteVectorScorer vectorScorer =
new ParentBlockJoinByteVectorScorer(
context.reader().getByteVectorValues(field),
byteVectorValues,
acceptIterator,
parentBitSet,
query,
Expand Down
Loading

0 comments on commit 5b5815a

Please sign in to comment.