Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix NPE when LeafReader returns null VectorValues #13162

Merged
merged 2 commits into from
Mar 11, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,7 @@ public ByteVectorValues getByteVectorValues(String field) throws IOException {
+ "\" is encoded as: "
+ fieldEntry.vectorEncoding
+ " expected: "
+ VectorEncoding.FLOAT32);
+ VectorEncoding.BYTE);
}
return OffHeapByteVectorValues.load(fieldEntry, vectorData);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,7 @@ public ByteVectorValues getByteVectorValues(String field) throws IOException {
+ "\" is encoded as: "
+ fieldEntry.vectorEncoding
+ " expected: "
+ VectorEncoding.FLOAT32);
+ VectorEncoding.BYTE);
}
return OffHeapByteVectorValues.load(
fieldEntry.ordToDocVectorValues,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@ public ByteVectorValues getByteVectorValues(String field) throws IOException {
+ "\" is encoded as: "
+ fieldEntry.vectorEncoding
+ " expected: "
+ VectorEncoding.FLOAT32);
+ VectorEncoding.BYTE);
}
return OffHeapByteVectorValues.load(
fieldEntry.ordToDoc,
Expand Down
21 changes: 21 additions & 0 deletions lucene/core/src/java/org/apache/lucene/index/ByteVectorValues.java
Original file line number Diff line number Diff line change
Expand Up @@ -54,4 +54,25 @@ public final long cost() {
* @return the vector value
*/
public abstract byte[] vectorValue() throws IOException;

/**
 * Checks the vector encoding of a field: if {@code field} indexes vectors in the given reader,
 * they must use {@link VectorEncoding#BYTE}.
 *
 * <p>Does nothing when the reader has no field info for {@code field}, or when the field has no
 * vector values.
 *
 * @param in the leaf reader whose field infos are consulted
 * @param field the name of the field to check
 * @throws IllegalStateException if {@code field} has vectors, but using a different encoding
 * @lucene.internal
 * @lucene.experimental
 */
public static void checkField(LeafReader in, String field) {
  FieldInfo fi = in.getFieldInfos().fieldInfo(field);
  if (fi != null && fi.hasVectorValues() && fi.getVectorEncoding() != VectorEncoding.BYTE) {
    // Keep a space between the field name and the "(expected=...)" suffix so the
    // message stays readable, e.g. "... for field foo (expected=BYTE)".
    throw new IllegalStateException(
        "Unexpected vector encoding ("
            + fi.getVectorEncoding()
            + ") for field "
            + field
            + " (expected="
            + VectorEncoding.BYTE
            + ")");
  }
}
}
8 changes: 6 additions & 2 deletions lucene/core/src/java/org/apache/lucene/index/CodecReader.java
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,9 @@ public final void searchNearestVectors(
String field, float[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException {
ensureOpen();
FieldInfo fi = getFieldInfos().fieldInfo(field);
if (fi == null || fi.getVectorDimension() == 0) {
if (fi == null
|| fi.getVectorDimension() == 0
|| fi.getVectorEncoding() != VectorEncoding.FLOAT32) {
// Field does not exist or does not index vectors
return;
}
Expand All @@ -258,7 +260,9 @@ public final void searchNearestVectors(
String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException {
ensureOpen();
FieldInfo fi = getFieldInfos().fieldInfo(field);
if (fi == null || fi.getVectorDimension() == 0) {
if (fi == null
|| fi.getVectorDimension() == 0
|| fi.getVectorEncoding() != VectorEncoding.BYTE) {
// Field does not exist or does not index vectors
return;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,4 +54,25 @@ public final long cost() {
* @return the vector value
*/
public abstract float[] vectorValue() throws IOException;

/**
 * Checks the vector encoding of a field: if {@code field} indexes vectors in the given reader,
 * they must use {@link VectorEncoding#FLOAT32}.
 *
 * <p>Does nothing when the reader has no field info for {@code field}, or when the field has no
 * vector values.
 *
 * @param in the leaf reader whose field infos are consulted
 * @param field the name of the field to check
 * @throws IllegalStateException if {@code field} has vectors, but using a different encoding
 * @lucene.internal
 * @lucene.experimental
 */
public static void checkField(LeafReader in, String field) {
  FieldInfo fi = in.getFieldInfos().fieldInfo(field);
  if (fi != null && fi.hasVectorValues() && fi.getVectorEncoding() != VectorEncoding.FLOAT32) {
    // Keep a space between the field name and the "(expected=...)" suffix so the
    // message stays readable, e.g. "... for field foo (expected=FLOAT32)".
    throw new IllegalStateException(
        "Unexpected vector encoding ("
            + fi.getVectorEncoding()
            + ") for field "
            + field
            + " (expected="
            + VectorEncoding.FLOAT32
            + ")");
  }
}
}
14 changes: 8 additions & 6 deletions lucene/core/src/java/org/apache/lucene/index/LeafReader.java
Original file line number Diff line number Diff line change
Expand Up @@ -246,11 +246,12 @@ public final PostingsEnum postings(Term term) throws IOException {
public final TopDocs searchNearestVectors(
String field, float[] target, int k, Bits acceptDocs, int visitedLimit) throws IOException {
FieldInfo fi = getFieldInfos().fieldInfo(field);
if (fi == null || fi.getVectorDimension() == 0) {
// The field does not exist or does not index vectors
FloatVectorValues floatVectorValues = getFloatVectorValues(fi.name);
if (floatVectorValues == null) {
FloatVectorValues.checkField(this, field);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The leaf reader here shouldn't throw, especially since the companion method that accepts a KnnCollector doesn't.

return TopDocsCollector.EMPTY_TOPDOCS;
}
k = Math.min(k, getFloatVectorValues(fi.name).size());
k = Math.min(k, floatVectorValues.size());
if (k == 0) {
return TopDocsCollector.EMPTY_TOPDOCS;
}
Expand Down Expand Up @@ -287,11 +288,12 @@ public final TopDocs searchNearestVectors(
public final TopDocs searchNearestVectors(
String field, byte[] target, int k, Bits acceptDocs, int visitedLimit) throws IOException {
FieldInfo fi = getFieldInfos().fieldInfo(field);
if (fi == null || fi.getVectorDimension() == 0) {
// The field does not exist or does not index vectors
ByteVectorValues byteVectorValues = getByteVectorValues(fi.name);
if (byteVectorValues == null) {
ByteVectorValues.checkField(this, field);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The leaf reader here shouldn't throw, especially since the companion method that accepts a KnnCollector doesn't.

return TopDocsCollector.EMPTY_TOPDOCS;
}
k = Math.min(k, getByteVectorValues(fi.name).size());
k = Math.min(k, byteVectorValues.size());
if (k == 0) {
return TopDocsCollector.EMPTY_TOPDOCS;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,9 @@ protected TopDocs exactSearch(LeafReaderContext context, DocIdSetIterator accept
}

VectorScorer vectorScorer = createVectorScorer(context, fi);
if (vectorScorer == null) {
return NO_RESULTS;
}
HitQueue queue = new HitQueue(k, true);
ScoreDoc topDoc = queue.top();
int doc;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,9 @@ public Scorer scorer(LeafReaderContext context) throws IOException {
if (filterWeight == null) {
// Return exhaustive results
TopDocs results = approximateSearch(context, liveDocs, Integer.MAX_VALUE);
if (results.scoreDocs.length == 0) {
return null;
}
return VectorSimilarityScorer.fromScoreDocs(this, boost, results.scoreDocs);
}

Expand Down Expand Up @@ -148,6 +151,8 @@ protected boolean match(int doc) {
createVectorScorer(context),
new BitSetIterator(acceptDocs, cardinality),
resultSimilarity);
} else if (results.scoreDocs.length == 0) {
return null;
} else {
// Return an iterator over the collected results
return VectorSimilarityScorer.fromScoreDocs(this, boost, results.scoreDocs);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,10 @@ public ByteVectorSimilarityValuesSource(byte[] vector, String fieldName) {
@Override
public DoubleValues getValues(LeafReaderContext ctx, DoubleValues scores) throws IOException {
final ByteVectorValues vectorValues = ctx.reader().getByteVectorValues(fieldName);
if (vectorValues == null) {
ByteVectorValues.checkField(ctx.reader(), fieldName);
return DoubleValues.EMPTY;
}
VectorSimilarityFunction function =
ctx.reader().getFieldInfos().fieldInfo(fieldName).getVectorSimilarityFunction();
return new DoubleValues() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -128,12 +128,12 @@ public Query rewrite(IndexSearcher indexSearcher) throws IOException {
break;
}
} else if (fieldInfo.getVectorDimension() != 0) { // the field indexes vectors
int numVectors =
DocIdSetIterator vectorValues =
switch (fieldInfo.getVectorEncoding()) {
case FLOAT32 -> leaf.getFloatVectorValues(field).size();
case BYTE -> leaf.getByteVectorValues(field).size();
case FLOAT32 -> leaf.getFloatVectorValues(field);
case BYTE -> leaf.getByteVectorValues(field);
};
if (numVectors != leaf.maxDoc()) {
if (vectorValues != null && vectorValues.cost() != leaf.maxDoc()) {
benwtrent marked this conversation as resolved.
Show resolved Hide resolved
allReadersRewritable = false;
break;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,9 @@ public FloatVectorSimilarityValuesSource(float[] vector, String fieldName) {
@Override
public DoubleValues getValues(LeafReaderContext ctx, DoubleValues scores) throws IOException {
final FloatVectorValues vectorValues = ctx.reader().getFloatVectorValues(fieldName);
if (vectorValues == null) {
return DoubleValues.EMPTY;
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do you throw in the byte similarity source, but not here? We need to be consistent. I think throwing here is acceptable as well (via FloatVectorValues.checkField).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice catch. I will fix it.

VectorSimilarityFunction function =
ctx.reader().getFieldInfos().fieldInfo(fieldName).getVectorSimilarityFunction();
return new DoubleValues() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@
import java.util.Objects;
import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.document.KnnFloatVectorField;
import org.apache.lucene.index.ByteVectorValues;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.search.knn.KnnCollectorManager;
import org.apache.lucene.util.ArrayUtil;
import org.apache.lucene.util.Bits;
Expand Down Expand Up @@ -83,13 +83,13 @@ protected TopDocs approximateSearch(
KnnCollectorManager knnCollectorManager)
throws IOException {
KnnCollector knnCollector = knnCollectorManager.newCollector(visitedLimit, context);
FieldInfo fi = context.reader().getFieldInfos().fieldInfo(field);
if (fi == null || fi.getVectorDimension() == 0) {
// The field does not exist or does not index vectors
return TopDocsCollector.EMPTY_TOPDOCS;
ByteVectorValues byteVectorValues = context.reader().getByteVectorValues(field);
if (byteVectorValues == null) {
ByteVectorValues.checkField(context.reader(), field);
return NO_RESULTS;
}
if (Math.min(knnCollector.k(), context.reader().getByteVectorValues(fi.name).size()) == 0) {
return TopDocsCollector.EMPTY_TOPDOCS;
if (Math.min(knnCollector.k(), byteVectorValues.size()) == 0) {
return NO_RESULTS;
}
context.reader().searchNearestVectors(field, target, knnCollector, acceptDocs);
TopDocs results = knnCollector.topDocs();
Expand All @@ -98,9 +98,6 @@ protected TopDocs approximateSearch(

@Override
VectorScorer createVectorScorer(LeafReaderContext context, FieldInfo fi) throws IOException {
if (fi.getVectorEncoding() != VectorEncoding.BYTE) {
return null;
}
return VectorScorer.create(context, fi, target);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@
import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.document.KnnFloatVectorField;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.search.knn.KnnCollectorManager;
import org.apache.lucene.util.ArrayUtil;
import org.apache.lucene.util.Bits;
Expand Down Expand Up @@ -84,13 +84,13 @@ protected TopDocs approximateSearch(
KnnCollectorManager knnCollectorManager)
throws IOException {
KnnCollector knnCollector = knnCollectorManager.newCollector(visitedLimit, context);
FieldInfo fi = context.reader().getFieldInfos().fieldInfo(field);
if (fi == null || fi.getVectorDimension() == 0) {
// The field does not exist or does not index vectors
return TopDocsCollector.EMPTY_TOPDOCS;
FloatVectorValues floatVectorValues = context.reader().getFloatVectorValues(field);
if (floatVectorValues == null) {
FloatVectorValues.checkField(context.reader(), field);
return NO_RESULTS;
}
if (Math.min(knnCollector.k(), context.reader().getFloatVectorValues(fi.name).size()) == 0) {
return TopDocsCollector.EMPTY_TOPDOCS;
if (Math.min(knnCollector.k(), floatVectorValues.size()) == 0) {
return NO_RESULTS;
}
context.reader().searchNearestVectors(field, target, knnCollector, acceptDocs);
TopDocs results = knnCollector.topDocs();
Expand All @@ -99,9 +99,6 @@ protected TopDocs approximateSearch(

@Override
VectorScorer createVectorScorer(LeafReaderContext context, FieldInfo fi) throws IOException {
if (fi.getVectorEncoding() != VectorEncoding.FLOAT32) {
return null;
}
return VectorScorer.create(context, fi, target);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,13 +41,21 @@ abstract class VectorScorer {
static FloatVectorScorer create(LeafReaderContext context, FieldInfo fi, float[] query)
    throws IOException {
  // Look up the float vectors this leaf holds for the field; the reader may return null.
  final FloatVectorValues vectors = context.reader().getFloatVectorValues(fi.name);
  if (vectors != null) {
    return new FloatVectorScorer(vectors, query, fi.getVectorSimilarityFunction());
  }
  // No float vectors available: let checkField validate the field's encoding,
  // then signal "no scorer" to the caller with null.
  FloatVectorValues.checkField(context.reader(), fi.name);
  return null;
}

static ByteVectorScorer create(LeafReaderContext context, FieldInfo fi, byte[] query)
    throws IOException {
  // Look up the byte vectors this leaf holds for the field; the reader may return null.
  final ByteVectorValues vectors = context.reader().getByteVectorValues(fi.name);
  if (vectors != null) {
    return new ByteVectorScorer(vectors, query, fi.getVectorSimilarityFunction());
  }
  // No byte vectors available: let checkField validate the field's encoding,
  // then signal "no scorer" to the caller with null.
  ByteVectorValues.checkField(context.reader(), fi.name);
  return null;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,21 @@ public void testGetTarget() {
assertNotSame(queryVectorBytes, q1.getTargetCopy());
}

public void testVectorEncodingMismatch() throws IOException {
  try (Directory dir =
          getIndexStore("field", new float[] {0, 1}, new float[] {1, 2}, new float[] {0, 0});
      IndexReader indexReader = DirectoryReader.open(dir)) {
    // Randomly run with or without a filter query.
    final Query filter = random().nextBoolean() ? new MatchAllDocsQuery() : null;
    final AbstractKnnVectorQuery floatQuery =
        new KnnFloatVectorQuery("field", new float[] {0, 1}, 10, filter);
    final IndexSearcher searcher = newSearcher(indexReader);
    // Searching with the float-vector query is expected to fail with IllegalStateException.
    expectThrows(IllegalStateException.class, () -> searcher.search(floatQuery, 10));
  }
}

private static class ThrowingKnnVectorQuery extends KnnByteVectorQuery {

public ThrowingKnnVectorQuery(String field, byte[] target, int k, Query filter) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,20 @@ public void testToString() throws IOException {
}
}

public void testVectorEncodingMismatch() throws IOException {
  try (Directory dir =
          getIndexStore("field", new float[] {0, 1}, new float[] {1, 2}, new float[] {0, 0});
      IndexReader indexReader = DirectoryReader.open(dir)) {
    // Randomly run with or without a filter query.
    final Query filter = random().nextBoolean() ? new MatchAllDocsQuery() : null;
    final AbstractKnnVectorQuery byteQuery =
        new KnnByteVectorQuery("field", new byte[] {0, 1}, 10, filter);
    final IndexSearcher searcher = newSearcher(indexReader);
    // Searching with the byte-vector query is expected to fail with IllegalStateException.
    expectThrows(IllegalStateException.class, () -> searcher.search(byteQuery, 10));
  }
}

public void testGetTarget() {
float[] queryVector = new float[] {0, 1};
KnnFloatVectorQuery q1 = new KnnFloatVectorQuery("f1", queryVector, 10);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
import org.apache.lucene.index.ByteVectorValues;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.HitQueue;
Expand Down Expand Up @@ -80,21 +79,21 @@ public DiversifyingChildrenByteKnnVectorQuery(
@Override
protected TopDocs exactSearch(LeafReaderContext context, DocIdSetIterator acceptIterator)
throws IOException {
FieldInfo fi = context.reader().getFieldInfos().fieldInfo(field);
if (fi == null || fi.getVectorDimension() == 0) {
// The field does not exist or does not index vectors
ByteVectorValues byteVectorValues = context.reader().getByteVectorValues(field);
if (byteVectorValues == null) {
ByteVectorValues.checkField(context.reader(), field);
return NO_RESULTS;
}
if (fi.getVectorEncoding() != VectorEncoding.BYTE) {
return null;
}

BitSet parentBitSet = parentsFilter.getBitSet(context);
if (parentBitSet == null) {
return NO_RESULTS;
}

FieldInfo fi = context.reader().getFieldInfos().fieldInfo(field);
ParentBlockJoinByteVectorScorer vectorScorer =
new ParentBlockJoinByteVectorScorer(
context.reader().getByteVectorValues(field),
byteVectorValues,
acceptIterator,
parentBitSet,
query,
Expand Down
Loading