From 113c4ad5d43652c102b105e02e3bcfeb490822f5 Mon Sep 17 00:00:00 2001 From: Jonathan Ellis Date: Sat, 10 Jun 2023 13:38:54 -0500 Subject: [PATCH] revert changes to VUDefaultProvider; add checkFinite to VectorUtil instead, and call from KFVF.createType --- .../lucene/document/KnnFloatVectorField.java | 3 ++ .../org/apache/lucene/util/VectorUtil.java | 38 +++++++++++++++++-- .../util/VectorUtilDefaultProvider.java | 30 +-------------- 3 files changed, 39 insertions(+), 32 deletions(-) diff --git a/lucene/core/src/java/org/apache/lucene/document/KnnFloatVectorField.java b/lucene/core/src/java/org/apache/lucene/document/KnnFloatVectorField.java index d6673293c720..20e00dc1427e 100644 --- a/lucene/core/src/java/org/apache/lucene/document/KnnFloatVectorField.java +++ b/lucene/core/src/java/org/apache/lucene/document/KnnFloatVectorField.java @@ -17,6 +17,8 @@ package org.apache.lucene.document; +import static org.apache.lucene.util.VectorUtil.checkFinite; + import org.apache.lucene.index.FloatVectorValues; import org.apache.lucene.index.VectorEncoding; import org.apache.lucene.index.VectorSimilarityFunction; @@ -50,6 +52,7 @@ private static FieldType createType(float[] v, VectorSimilarityFunction similari throw new IllegalArgumentException( "cannot index vectors with dimension greater than " + FloatVectorValues.MAX_DIMENSIONS); } + checkFinite(v); if (similarityFunction == null) { throw new IllegalArgumentException("similarity function must not be null"); } diff --git a/lucene/core/src/java/org/apache/lucene/util/VectorUtil.java b/lucene/core/src/java/org/apache/lucene/util/VectorUtil.java index 068a6edc035b..dc40b3876be9 100644 --- a/lucene/core/src/java/org/apache/lucene/util/VectorUtil.java +++ b/lucene/core/src/java/org/apache/lucene/util/VectorUtil.java @@ -17,6 +17,8 @@ package org.apache.lucene.util; +import java.util.Arrays; + /** Utilities for computations with numeric arrays */ public final class VectorUtil { @@ -34,7 +36,9 @@ public static float 
dotProduct(float[] a, float[] b) { if (a.length != b.length) { throw new IllegalArgumentException("vector dimensions differ: " + a.length + "!=" + b.length); } - return PROVIDER.dotProduct(a, b); + float r = PROVIDER.dotProduct(a, b); + checkFinite(r, a, b, "dot product"); + return r; } /** @@ -46,7 +50,9 @@ public static float cosine(float[] a, float[] b) { if (a.length != b.length) { throw new IllegalArgumentException("vector dimensions differ: " + a.length + "!=" + b.length); } - return PROVIDER.cosine(a, b); + float r = PROVIDER.cosine(a, b); + checkFinite(r, a, b, "cosine"); + return r; } /** Returns the cosine similarity between the two vectors. */ @@ -66,7 +72,9 @@ public static float squareDistance(float[] a, float[] b) { if (a.length != b.length) { throw new IllegalArgumentException("vector dimensions differ: " + a.length + "!=" + b.length); } - return PROVIDER.squareDistance(a, b); + float r = PROVIDER.squareDistance(a, b); + checkFinite(r, a, b, "square distance"); + return r; } /** Returns the sum of squared differences of the two vectors. 
*/ @@ -154,4 +162,28 @@ public static float dotProductScore(byte[] a, byte[] b) { float denom = (float) (a.length * (1 << 15)); return 0.5f + dotProduct(a, b) / denom; } + + private static void checkFinite(float r, float[] a, float[] b, String optype) { + if (!Float.isFinite(r)) { + checkFinite(a); + checkFinite(b); + throw new IllegalArgumentException( + "Non-finite (" + + r + + ") " + + optype + + " similarity from " + + Arrays.toString(a) + + " and " + + Arrays.toString(b)); + } + } + + public static void checkFinite(float[] a) { + for (int i = 0; i < a.length; i++) { + if (!Float.isFinite(a[i])) { + throw new IllegalArgumentException("non-finite value at vector[" + i + "]=" + a[i]); + } + } + } } diff --git a/lucene/core/src/java/org/apache/lucene/util/VectorUtilDefaultProvider.java b/lucene/core/src/java/org/apache/lucene/util/VectorUtilDefaultProvider.java index 4f87ffd92d10..da8483ed04de 100644 --- a/lucene/core/src/java/org/apache/lucene/util/VectorUtilDefaultProvider.java +++ b/lucene/core/src/java/org/apache/lucene/util/VectorUtilDefaultProvider.java @@ -17,8 +17,6 @@ package org.apache.lucene.util; -import java.util.Arrays; - /** The default VectorUtil provider implementation. 
*/ final class VectorUtilDefaultProvider implements VectorUtilProvider { @@ -87,7 +85,6 @@ public float dotProduct(float[] a, float[] b) { + b[i + 6] * a[i + 6] + b[i + 7] * a[i + 7]; } - checkFinite(res, a, b, "dot product"); return res; } @@ -105,9 +102,7 @@ public float cosine(float[] a, float[] b) { norm1 += elem1 * elem1; norm2 += elem2 * elem2; } - var r = (float) (sum / Math.sqrt(norm1 * norm2)); - checkFinite(r, a, b, "cosine"); - return r; + return (float) (sum / Math.sqrt(norm1 * norm2)); } @Override @@ -122,32 +117,9 @@ public float squareDistance(float[] a, float[] b) { float diff = a[i] - b[i]; squareSum += diff * diff; } - checkFinite(squareSum, a, b, "square distance"); return squareSum; } - private static void checkFinite(float r, float[] a, float[] b, String optype) { - if (!Float.isFinite(r)) { - for (int i = 0; i < a.length; i++) { - if (!Float.isFinite(a[i])) { - throw new IllegalArgumentException("v1[" + i + "]=" + a[i]); - } - if (!Float.isFinite(b[i])) { - throw new IllegalArgumentException("v2[" + i + "]=" + b[i]); - } - } - throw new IllegalArgumentException( - "Non-finite (" - + r - + ") " - + optype - + " similarity from " - + Arrays.toString(a) - + " and " - + Arrays.toString(b)); - } - } - private static float squareDistanceUnrolled(float[] v1, float[] v2, int index) { float diff0 = v1[index + 0] - v2[index + 0]; float diff1 = v1[index + 1] - v2[index + 1];