Skip to content

Commit

Permalink
use double math internally when computing similarities (continue to r…
Browse files Browse the repository at this point in the history
…eturn float at the end)
  • Loading branch information
jbellis committed Jun 9, 2023
1 parent 11b7105 commit 6802884
Showing 1 changed file with 25 additions and 23 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ final class VectorUtilDefaultProvider implements VectorUtilProvider {

@Override
public float dotProduct(float[] a, float[] b) {
float res = 0f;
double res = 0.0;
/*
* If length of vector is larger than 8, we use unrolled dot product to accelerate the
* calculation.
Expand All @@ -36,7 +36,7 @@ public float dotProduct(float[] a, float[] b) {
res += b[i] * a[i];
}
if (a.length < 8) {
return res;
return (float) res;
}
for (; i + 31 < a.length; i += 32) {
res +=
Expand Down Expand Up @@ -87,43 +87,45 @@ public float dotProduct(float[] a, float[] b) {
+ b[i + 6] * a[i + 6]
+ b[i + 7] * a[i + 7];
}
checkFinite(res, a, b, "dot product");
return res;
float resF = (float) res;
checkFinite(resF, a, b, "dot product");
return resF;
}

@Override
public float cosine(float[] a, float[] b) {
float sum = 0.0f;
float norm1 = 0.0f;
float norm2 = 0.0f;
double sum = 0.0;
double norm1 = 0.0;
double norm2 = 0.0;
int dim = a.length;

for (int i = 0; i < dim; i++) {
float elem1 = a[i];
float elem2 = b[i];
double elem1 = a[i];
double elem2 = b[i];
sum += elem1 * elem2;
norm1 += elem1 * elem1;
norm2 += elem2 * elem2;
}
var r = (float) (sum / Math.sqrt(norm1 * norm2));
float r = (float) (sum / Math.sqrt(norm1 * norm2));
checkFinite(r, a, b, "cosine");
return r;
}

@Override
public float squareDistance(float[] a, float[] b) {
float squareSum = 0.0f;
double squareSum = 0.0;
int dim = a.length;
int i;
for (i = 0; i + 8 <= dim; i += 8) {
squareSum += squareDistanceUnrolled(a, b, i);
}
for (; i < dim; i++) {
float diff = a[i] - b[i];
double diff = a[i] - b[i];
squareSum += diff * diff;
}
checkFinite(squareSum, a, b, "square distance");
return squareSum;
float r = (float) squareSum;
checkFinite(r, a, b, "square distance");
return r;
}

private static void checkFinite(float r, float[] a, float[] b, String optype) {
Expand All @@ -148,15 +150,15 @@ private static void checkFinite(float r, float[] a, float[] b, String optype) {
}
}

private static float squareDistanceUnrolled(float[] v1, float[] v2, int index) {
float diff0 = v1[index + 0] - v2[index + 0];
float diff1 = v1[index + 1] - v2[index + 1];
float diff2 = v1[index + 2] - v2[index + 2];
float diff3 = v1[index + 3] - v2[index + 3];
float diff4 = v1[index + 4] - v2[index + 4];
float diff5 = v1[index + 5] - v2[index + 5];
float diff6 = v1[index + 6] - v2[index + 6];
float diff7 = v1[index + 7] - v2[index + 7];
private static double squareDistanceUnrolled(float[] v1, float[] v2, int index) {
double diff0 = v1[index + 0] - v2[index + 0];
double diff1 = v1[index + 1] - v2[index + 1];
double diff2 = v1[index + 2] - v2[index + 2];
double diff3 = v1[index + 3] - v2[index + 3];
double diff4 = v1[index + 4] - v2[index + 4];
double diff5 = v1[index + 5] - v2[index + 5];
double diff6 = v1[index + 6] - v2[index + 6];
double diff7 = v1[index + 7] - v2[index + 7];
return diff0 * diff0
+ diff1 * diff1
+ diff2 * diff2
Expand Down

0 comments on commit 6802884

Please sign in to comment.