Skip to content

Commit

Permalink
Addressing review
Browse files Browse the repository at this point in the history
  • Loading branch information
eliaporciani committed Jun 13, 2023
1 parent 72dd3fc commit 481f2a0
Show file tree
Hide file tree
Showing 5 changed files with 17 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,9 @@ protected float func(int doc, FunctionValues f1, FunctionValues f2) throws IOExc
return Float.NaN;
}

assert f1.byteVectorVal(doc).length == f2.byteVectorVal(doc).length
assert v1.length == v2.length
: "Vectors must have the same length";

return similarityFunction.compare(f1.byteVectorVal(doc), f2.byteVectorVal(doc));
return similarityFunction.compare(v1, v2);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@

/** Function that returns a constant byte vector value for every document. */
public class ConstKnnByteVectorValueSource extends ValueSource {
byte[] vector;
private final byte[] vector;

public ConstKnnByteVectorValueSource(List<Number> constVector) {
this.vector = new byte[constVector.size()];
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,10 @@
import org.apache.lucene.queries.function.ValueSource;

/** Function that returns a constant float vector value for every document. */
public class ConsKnnFloatValueSource extends ValueSource {
float[] vector;
public class ConstKnnFloatValueSource extends ValueSource {
private final float[] vector;

public ConsKnnFloatValueSource(List<Number> constVector) {
public ConstKnnFloatValueSource(List<Number> constVector) {
this.vector = new float[constVector.size()];
for (int i = 0; i < constVector.size(); i++) {
vector[i] = constVector.get(i).floatValue();
Expand Down Expand Up @@ -58,8 +58,8 @@ public String toString(int doc) throws IOException {

@Override
public boolean equals(Object o) {
if (!(o instanceof ConsKnnFloatValueSource)) return false;
ConsKnnFloatValueSource other = (ConsKnnFloatValueSource) o;
if (!(o instanceof ConstKnnFloatValueSource)) return false;
ConstKnnFloatValueSource other = (ConstKnnFloatValueSource) o;
return Arrays.equals(vector, other.vector);
}

Expand All @@ -70,6 +70,6 @@ public int hashCode() {

@Override
public String description() {
return "ConsKnnFloatValueSource(" + Arrays.toString(vector) + ')';
return "ConstKnnFloatValueSource(" + Arrays.toString(vector) + ')';
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,7 @@ protected float func(int doc, FunctionValues f1, FunctionValues f2) throws IOExc
return Float.NaN;
}

assert f1.floatVectorVal(doc).length == f2.floatVectorVal(doc).length
: "Vectors must have the same length";
return similarityFunction.compare(f1.floatVectorVal(doc), f2.floatVectorVal(doc));
assert v1.length == v2.length : "Vectors must have the same length";
return similarityFunction.compare(v1, v2);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.queries.function.valuesource.ByteKnnVectorFieldSource;
import org.apache.lucene.queries.function.valuesource.ByteVectorSimilarityFunction;
import org.apache.lucene.queries.function.valuesource.ConsKnnFloatValueSource;
import org.apache.lucene.queries.function.valuesource.ConstKnnFloatValueSource;
import org.apache.lucene.queries.function.valuesource.ConstKnnByteVectorValueSource;
import org.apache.lucene.queries.function.valuesource.FloatKnnVectorFieldSource;
import org.apache.lucene.queries.function.valuesource.FloatVectorSimilarityFunction;
Expand Down Expand Up @@ -106,8 +106,8 @@ public static void afterClass() throws Exception {

@Test
public void vectorSimilarity_floatConstantVectors_shouldReturnFloatSimilarity() throws Exception {
var v1 = new ConsKnnFloatValueSource(List.of(1, 2, 3));
var v2 = new ConsKnnFloatValueSource(List.of(5, 4, 1));
var v1 = new ConstKnnFloatValueSource(List.of(1, 2, 3));
var v2 = new ConstKnnFloatValueSource(List.of(5, 4, 1));
assertHits(
new FunctionQuery(
new FloatVectorSimilarityFunction(VectorSimilarityFunction.EUCLIDEAN, v1, v2)),
Expand Down Expand Up @@ -147,7 +147,7 @@ public void vectorSimilarity_byteFieldVectors_shouldReturnFloatSimilarity() thro
@Test
public void vectorSimilarity_FloatConstAndFloatFieldVectors_shouldReturnFloatSimilarity()
throws Exception {
var v1 = new ConsKnnFloatValueSource(List.of(1, 2, 4));
var v1 = new ConstKnnFloatValueSource(List.of(1, 2, 4));
var v2 = new FloatKnnVectorFieldSource("knnFloatField1");
assertHits(
new FunctionQuery(
Expand All @@ -168,7 +168,7 @@ public void vectorSimilarity_ByteConstAndByteFieldVectors_shouldReturnFloatSimil

@Test
public void vectorSimilarity_missingFloatVectorField_shouldReturnNaN() throws Exception {
var v1 = new ConsKnnFloatValueSource(List.of(2.0, 1.0, 1.0));
var v1 = new ConstKnnFloatValueSource(List.of(2.0, 1.0, 1.0));
var v2 = new FloatKnnVectorFieldSource("knnFloatField3");
assertHits(
new FunctionQuery(
Expand Down Expand Up @@ -196,7 +196,7 @@ public void vectorSimilarity_twoVectorsWithDifferentDimensions_shouldRaiseExcept
AssertionError.class,
() -> searcher.search(new FunctionQuery(byteDenseVectorSimilarityFunction), 10));

v1 = new ConsKnnFloatValueSource(List.of(1, 2));
v1 = new ConstKnnFloatValueSource(List.of(1, 2));
v2 = new FloatKnnVectorFieldSource("knnFloatField1");
FloatVectorSimilarityFunction floatDenseVectorSimilarityFunction =
new FloatVectorSimilarityFunction(VectorSimilarityFunction.EUCLIDEAN, v1, v2);
Expand Down

0 comments on commit 481f2a0

Please sign in to comment.