Skip to content

Commit

Permalink
add checkFinite and fix TestExitableDirectoryReader (cosine taken wit…
Browse files Browse the repository at this point in the history
…h 0 is undefined)
  • Loading branch information
jbellis committed Jun 10, 2023
1 parent 90aea15 commit c9b3cb8
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@

package org.apache.lucene.util;

import java.util.Arrays;

/** The default VectorUtil provider implementation. */
final class VectorUtilDefaultProvider implements VectorUtilProvider {

Expand Down Expand Up @@ -85,6 +87,7 @@ public float dotProduct(float[] a, float[] b) {
+ b[i + 6] * a[i + 6]
+ b[i + 7] * a[i + 7];
}
checkFinite(res, a, b, "dot product");
return res;
}

Expand All @@ -102,7 +105,9 @@ public float cosine(float[] a, float[] b) {
norm1 += elem1 * elem1;
norm2 += elem2 * elem2;
}
return (float) (sum / Math.sqrt(norm1 * norm2));
var r = (float) (sum / Math.sqrt(norm1 * norm2));
checkFinite(r, a, b, "cosine");
return r;
}

@Override
Expand All @@ -117,9 +122,32 @@ public float squareDistance(float[] a, float[] b) {
float diff = a[i] - b[i];
squareSum += diff * diff;
}
checkFinite(squareSum, a, b, "square distance");
return squareSum;
}

/**
 * Rejects a non-finite similarity/distance result with a diagnostic exception.
 *
 * <p>When {@code r} is NaN or infinite, the inputs are scanned for the first non-finite component
 * so that the error can name the offending element. At a given index the first vector is checked
 * before the second. If every component is finite, the exception instead reports the result
 * together with both input vectors.
 *
 * @param r the computed result to validate
 * @param a the first input vector, reported as {@code v1} in error messages
 * @param b the second input vector, reported as {@code v2} in error messages
 * @param optype name of the operation that produced {@code r}, embedded in the error message
 * @throws IllegalArgumentException if {@code r} is not finite
 */
private static void checkFinite(float r, float[] a, float[] b, String optype) {
  if (Float.isFinite(r)) {
    return;
  }
  // Walk both vectors in lockstep up to the longer length: a length mismatch must neither raise
  // ArrayIndexOutOfBoundsException nor hide a non-finite element in the longer vector's tail.
  final int limit = Math.max(a.length, b.length);
  for (int i = 0; i < limit; i++) {
    if (i < a.length && !Float.isFinite(a[i])) {
      throw new IllegalArgumentException("v1[" + i + "]=" + a[i]);
    }
    if (i < b.length && !Float.isFinite(b[i])) {
      throw new IllegalArgumentException("v2[" + i + "]=" + b[i]);
    }
  }
  // All components are finite, so the computation itself produced the non-finite value;
  // include both vectors so the failing input can be reproduced.
  throw new IllegalArgumentException(
      "Non-finite ("
          + r
          + ") "
          + optype
          + " similarity from "
          + Arrays.toString(a)
          + " and "
          + Arrays.toString(b));
}

private static float squareDistanceUnrolled(float[] v1, float[] v2, int index) {
float diff0 = v1[index + 0] - v2[index + 0];
float diff1 = v1[index + 1] - v2[index + 1];
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
import org.apache.lucene.tests.analysis.MockAnalyzer;
import org.apache.lucene.tests.util.LuceneTestCase;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.TestVectorUtil;

/**
* Test that uses a default/lucene Implementation of {@link QueryTimeout} to exit out long running
Expand Down Expand Up @@ -463,13 +464,21 @@ public void testVectorValues() throws IOException {
ExitingReaderException.class,
() ->
leaf.searchNearestVectors(
"vector", new float[dimension], 5, leaf.getLiveDocs(), Integer.MAX_VALUE));
"vector",
TestVectorUtil.randomVector(dimension),
5,
leaf.getLiveDocs(),
Integer.MAX_VALUE));
} else {
DocIdSetIterator iter = leaf.getFloatVectorValues("vector");
scanAndRetrieve(leaf, iter);

leaf.searchNearestVectors(
"vector", new float[dimension], 5, leaf.getLiveDocs(), Integer.MAX_VALUE);
"vector",
TestVectorUtil.randomVector(dimension),
5,
leaf.getLiveDocs(),
Integer.MAX_VALUE);
}

reader.close();
Expand Down
10 changes: 10 additions & 0 deletions lucene/core/src/test/org/apache/lucene/util/TestVectorUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,16 @@ public void testCosineThrowsForDimensionMismatch() {
expectThrows(IllegalArgumentException.class, () -> VectorUtil.cosine(u, v));
}

/**
 * Expects {@code VectorUtil.cosine} to throw {@link IllegalArgumentException} when given an
 * all-zero vector and a vector containing {@code NaN}.
 */
public void testCosineThrowsForNaN() {
  final float[] u = {0, 0, 0};
  final float[] v = {1, 0, Float.NaN};
  expectThrows(IllegalArgumentException.class, () -> VectorUtil.cosine(u, v));
}

/**
 * Expects {@code VectorUtil.cosine} to throw {@link IllegalArgumentException} when given an
 * all-zero vector and a vector containing negative infinity.
 */
public void testCosineThrowsForInfinity() {
  final float[] u = {0, 0, 0};
  final float[] v = {1, 0, Float.NEGATIVE_INFINITY};
  expectThrows(IllegalArgumentException.class, () -> VectorUtil.cosine(u, v));
}

public void testNormalize() {
float[] v = randomVector();
v[random().nextInt(v.length)] = 1; // ensure vector is not all zeroes
Expand Down

0 comments on commit c9b3cb8

Please sign in to comment.