Skip to content

Commit

Permalink
Fix ByteKnnVectorFieldSource & FloatKnnVectorFieldSource to work corr…
Browse files Browse the repository at this point in the history
…ectly when a segment does not contain any docs with vectors (apache#13105)

(cherry picked from commit bf6f386)
  • Loading branch information
hossman committed Feb 26, 2024
1 parent 5683676 commit 8a555eb
Show file tree
Hide file tree
Showing 5 changed files with 83 additions and 9 deletions.
4 changes: 3 additions & 1 deletion lucene/CHANGES.txt
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,9 @@ Optimizations

Bug Fixes
---------------------
(No changes)

* GITHUB#13105: Fix ByteKnnVectorFieldSource & FloatKnnVectorFieldSource to work correctly when a segment does not contain
any docs with vectors (hossman)

Other
---------------------
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@
import java.util.Map;
import java.util.Objects;
import org.apache.lucene.index.ByteVectorValues;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.queries.function.FunctionValues;
import org.apache.lucene.queries.function.ValueSource;
import org.apache.lucene.search.DocIdSetIterator;
Expand All @@ -39,11 +41,25 @@ public ByteKnnVectorFieldSource(String fieldName) {
public FunctionValues getValues(Map<Object, Object> context, LeafReaderContext readerContext)
throws IOException {

final ByteVectorValues vectorValues = readerContext.reader().getByteVectorValues(fieldName);
final LeafReader reader = readerContext.reader();
final ByteVectorValues vectorValues = reader.getByteVectorValues(fieldName);

if (vectorValues == null) {
throw new IllegalArgumentException(
"no byte vector value is indexed for field '" + fieldName + "'");
VectorFieldFunction.checkField(reader, fieldName, VectorEncoding.BYTE);

return new VectorFieldFunction(this) {
private final DocIdSetIterator empty = DocIdSetIterator.empty();

@Override
public byte[] byteVectorVal(int doc) throws IOException {
return null;
}

@Override
protected DocIdSetIterator getVectorIterator() {
return empty;
}
};
}

return new VectorFieldFunction(this) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@
import java.util.Map;
import java.util.Objects;
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.queries.function.FunctionValues;
import org.apache.lucene.queries.function.ValueSource;
import org.apache.lucene.search.DocIdSetIterator;
Expand All @@ -39,12 +41,26 @@ public FloatKnnVectorFieldSource(String fieldName) {
public FunctionValues getValues(Map<Object, Object> context, LeafReaderContext readerContext)
throws IOException {

final FloatVectorValues vectorValues = readerContext.reader().getFloatVectorValues(fieldName);
final LeafReader reader = readerContext.reader();
final FloatVectorValues vectorValues = reader.getFloatVectorValues(fieldName);

if (vectorValues == null) {
throw new IllegalArgumentException(
"no float vector value is indexed for field '" + fieldName + "'");
VectorFieldFunction.checkField(reader, fieldName, VectorEncoding.FLOAT32);
return new VectorFieldFunction(this) {
private final DocIdSetIterator empty = DocIdSetIterator.empty();

@Override
public float[] floatVectorVal(int doc) throws IOException {
return null;
}

@Override
protected DocIdSetIterator getVectorIterator() {
return empty;
}
};
}

return new VectorFieldFunction(this) {

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@
package org.apache.lucene.queries.function.valuesource;

import java.io.IOException;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.queries.function.FunctionValues;
import org.apache.lucene.queries.function.ValueSource;
import org.apache.lucene.search.DocIdSetIterator;
Expand Down Expand Up @@ -53,4 +56,29 @@ public boolean exists(int doc) throws IOException {
}
return doc == curDocID;
}

/**
 * Verifies that {@code field}, when it is known to the given reader, carries vector values of
 * the expected encoding. A field that has no {@link FieldInfo} in this reader is accepted
 * silently.
 *
 * @param in reader whose field infos are consulted
 * @param field name of the field to check
 * @param expectedEncoding encoding the caller requires for {@code field}
 * @throws IllegalStateException if {@code field} exists but has no vector values, or has vector
 *     values whose encoding differs from {@code expectedEncoding}
 * @lucene.internal
 * @lucene.experimental
 */
static void checkField(LeafReader in, String field, VectorEncoding expectedEncoding) {
  final FieldInfo info = in.getFieldInfos().fieldInfo(field);
  if (info == null) {
    // The reader has no record of this field at all: nothing to validate.
    return;
  }

  // A field without vector values is represented by a null encoding.
  final VectorEncoding found;
  if (info.hasVectorValues()) {
    found = info.getVectorEncoding();
  } else {
    found = null;
  }

  if (found == expectedEncoding) {
    return;
  }

  throw new IllegalStateException(
      "Unexpected vector encoding ("
          + found
          + ") for field "
          + field
          + "(expected="
          + expectedEncoding
          + ")");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,10 @@ public static void beforeClass() throws Exception {
document.add(new KnnByteVectorField("knnByteField2", new byte[] {4, 2, 3}));
iw.addDocument(document);

if (usually(random())) {
iw.commit();
}

Document document2 = new Document();
document2.add(new StringField("id", "2", Field.Store.NO));
document2.add(new SortedDocValuesField("id", new BytesRef("2")));
Expand Down Expand Up @@ -232,7 +236,7 @@ public void vectorSimilarity_wrongFieldType_shouldRaiseException() {
new ByteVectorSimilarityFunction(VectorSimilarityFunction.EUCLIDEAN, v1, v2);

assertThrows(
IllegalArgumentException.class,
IllegalStateException.class,
() -> searcher.search(new FunctionQuery(byteDenseVectorSimilarityFunction), 10));

v1 = new FloatKnnVectorFieldSource("knnByteField1");
Expand All @@ -241,8 +245,16 @@ public void vectorSimilarity_wrongFieldType_shouldRaiseException() {
new FloatVectorSimilarityFunction(VectorSimilarityFunction.EUCLIDEAN, v1, v2);

assertThrows(
IllegalArgumentException.class,
IllegalStateException.class,
() -> searcher.search(new FunctionQuery(floatVectorSimilarityFunction), 10));

v1 = new FloatKnnVectorFieldSource("id");
FloatVectorSimilarityFunction idVectorSimilarityFunction =
new FloatVectorSimilarityFunction(VectorSimilarityFunction.EUCLIDEAN, v1, v2);

assertThrows(
IllegalStateException.class,
() -> searcher.search(new FunctionQuery(idVectorSimilarityFunction), 10));
}

private static void assertHits(Query q, float[] scores) throws Exception {
Expand Down

0 comments on commit 8a555eb

Please sign in to comment.