Skip to content

Commit

Permalink
Switch from sun.misc.Unsafe to java.nio.ByteBuffer for vectoer (de-)s…
Browse files Browse the repository at this point in the history
…erialization (forwards and backwards compatible) (#608)
  • Loading branch information
alexklibisz authored Dec 3, 2023
1 parent 151a96c commit d39e53c
Show file tree
Hide file tree
Showing 22 changed files with 452 additions and 34 deletions.
2 changes: 1 addition & 1 deletion build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ lazy val `elastiknn-models` = project

lazy val `elastiknn-jmh-benchmarks` = project
.in(file("elastiknn-jmh-benchmarks"))
.dependsOn(`elastiknn-models`, `elastiknn-api4s`, `elastiknn-lucene`)
.dependsOn(`elastiknn-models` % "compile->compile;compile->test", `elastiknn-api4s`, `elastiknn-lucene`)
.enablePlugins(JmhPlugin)
.settings(
Jmh / javaOptions ++= Seq("--add-modules", "jdk.incubator.vector"),
Expand Down
2 changes: 1 addition & 1 deletion docs/pages/installation.md
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ The image version (`elasticsearch:A.B.C`) must match the plugin's version (e.g.

```docker
FROM docker.elastic.co/elasticsearch/elasticsearch:8.11.1
RUN elasticsearch-plugin install --batch https://github.com/alexklibisz/elastiknn/releases/download/8.11.1.0/elastiknn-8.11.1.0.zip
RUN elasticsearch-plugin install --batch https://github.com/alexklibisz/elastiknn/releases/download/8.11.1.2/elastiknn-8.11.1.2.zip
```

Build and run the Dockerfile. If you have any issues please refer to the [official docs.](https://www.elastic.co/guide/en/elasticsearch/reference/current/docker.html)
Expand Down
2 changes: 1 addition & 1 deletion docs/pages/performance/fashion-mnist/plot.b64

Large diffs are not rendered by default.

Binary file modified docs/pages/performance/fashion-mnist/plot.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
16 changes: 8 additions & 8 deletions docs/pages/performance/fashion-mnist/results.md
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
|Model|Parameters|Recall|Queries per Second|
|---|---|---|---|
|eknn-l2lsh|L=100 k=4 w=1024 candidates=500 probes=0|0.379|353.162|
|eknn-l2lsh|L=100 k=4 w=1024 candidates=1000 probes=0|0.447|295.007|
|eknn-l2lsh|L=100 k=4 w=1024 candidates=500 probes=3|0.634|286.531|
|eknn-l2lsh|L=100 k=4 w=1024 candidates=1000 probes=3|0.716|245.690|
|eknn-l2lsh|L=100 k=4 w=2048 candidates=500 probes=0|0.767|312.826|
|eknn-l2lsh|L=100 k=4 w=2048 candidates=1000 probes=0|0.846|265.204|
|eknn-l2lsh|L=100 k=4 w=2048 candidates=500 probes=3|0.921|221.817|
|eknn-l2lsh|L=100 k=4 w=2048 candidates=1000 probes=3|0.960|195.653|
|eknn-l2lsh|L=100 k=4 w=1024 candidates=500 probes=0|0.378|349.851|
|eknn-l2lsh|L=100 k=4 w=1024 candidates=1000 probes=0|0.446|296.219|
|eknn-l2lsh|L=100 k=4 w=1024 candidates=500 probes=3|0.635|286.468|
|eknn-l2lsh|L=100 k=4 w=1024 candidates=1000 probes=3|0.716|244.536|
|eknn-l2lsh|L=100 k=4 w=2048 candidates=500 probes=0|0.767|315.023|
|eknn-l2lsh|L=100 k=4 w=2048 candidates=1000 probes=0|0.847|264.479|
|eknn-l2lsh|L=100 k=4 w=2048 candidates=500 probes=3|0.922|220.714|
|eknn-l2lsh|L=100 k=4 w=2048 candidates=1000 probes=3|0.960|193.597|
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
package com.klibisz.elastiknn.jmhbenchmarks

import com.klibisz.elastiknn.storage.{ByteBufferSerialization, UnsafeSerialization}
import org.openjdk.jmh.annotations._

import scala.util.Random

@State(Scope.Benchmark)
class VectorSerializationBenchmarksState {
implicit private val rng: Random = new Random(0)
val floatArray = (0 until 1000).map(_ => rng.nextFloat()).toArray
val floatArraySerialized = ByteBufferSerialization.writeFloats(floatArray)
val intArray = (0 until 1000).map(_ => rng.nextInt()).toArray
val intArraySerialized = ByteBufferSerialization.writeInts(intArray)
val ints = Array(Int.MinValue + 1, Short.MinValue + 1, Byte.MinValue + 1, 0, Byte.MaxValue - 1, Short.MaxValue - 1, Int.MaxValue - 1)
val intsSerialized = ints.map(ByteBufferSerialization.writeInt)
}

class VectorSerializationBenchmarks {

@Benchmark
@BenchmarkMode(Array(Mode.Throughput))
@Fork(value = 1)
@Warmup(time = 5, iterations = 1)
@Measurement(time = 5, iterations = 1)
def writeFloats_Unsafe(state: VectorSerializationBenchmarksState): Array[Byte] = {
UnsafeSerialization.writeFloats(state.floatArray)
}

@Benchmark
@BenchmarkMode(Array(Mode.Throughput))
@Fork(value = 1)
@Warmup(time = 5, iterations = 1)
@Measurement(time = 5, iterations = 1)
def writeFloats_ByteBuffer(state: VectorSerializationBenchmarksState): Array[Byte] = {
ByteBufferSerialization.writeFloats(state.floatArray)
}

@Benchmark
@BenchmarkMode(Array(Mode.Throughput))
@Fork(value = 1)
@Warmup(time = 5, iterations = 1)
@Measurement(time = 5, iterations = 1)
def readFloats_Unsafe(state: VectorSerializationBenchmarksState): Array[Float] = {
UnsafeSerialization.readFloats(state.floatArraySerialized, 0, state.floatArraySerialized.length)
}

@Benchmark
@BenchmarkMode(Array(Mode.Throughput))
@Fork(value = 1)
@Warmup(time = 5, iterations = 1)
@Measurement(time = 5, iterations = 1)
def readFloats_ByteBuffer(state: VectorSerializationBenchmarksState): Array[Float] = {
ByteBufferSerialization.readFloats(state.floatArraySerialized, 0, state.floatArraySerialized.length)
}

@Benchmark
@BenchmarkMode(Array(Mode.Throughput))
@Fork(value = 1)
@Warmup(time = 5, iterations = 1)
@Measurement(time = 5, iterations = 1)
def writeInts_Unsafe(state: VectorSerializationBenchmarksState): Array[Byte] = {
UnsafeSerialization.writeInts(state.intArray)
}

@Benchmark
@BenchmarkMode(Array(Mode.Throughput))
@Fork(value = 1)
@Warmup(time = 5, iterations = 1)
@Measurement(time = 5, iterations = 1)
def writeInts_ByteBuffer(state: VectorSerializationBenchmarksState): Array[Byte] = {
ByteBufferSerialization.writeInts(state.intArray)
}

@Benchmark
@BenchmarkMode(Array(Mode.Throughput))
@Fork(value = 1)
@Warmup(time = 5, iterations = 1)
@Measurement(time = 5, iterations = 1)
def readInts_Unsafe(state: VectorSerializationBenchmarksState): Array[Int] = {
UnsafeSerialization.readInts(state.intArraySerialized, 0, state.intArraySerialized.length)
}

@Benchmark
@BenchmarkMode(Array(Mode.Throughput))
@Fork(value = 1)
@Warmup(time = 5, iterations = 1)
@Measurement(time = 5, iterations = 1)
def readInts_ByteBuffer(state: VectorSerializationBenchmarksState): Array[Int] = {
ByteBufferSerialization.readInts(state.intArraySerialized, 0, state.intArraySerialized.length)
}

@Benchmark
@BenchmarkMode(Array(Mode.Throughput))
@Fork(value = 1)
@Warmup(time = 5, iterations = 1)
@Measurement(time = 5, iterations = 1)
def writeIntsWithPrefix_Unsafe(state: VectorSerializationBenchmarksState): Array[Byte] = {
UnsafeSerialization.writeIntsWithPrefix(state.intArray.length, state.intArray)
}

@Benchmark
@BenchmarkMode(Array(Mode.Throughput))
@Fork(value = 1)
@Warmup(time = 5, iterations = 1)
@Measurement(time = 5, iterations = 1)
def writeIntsWithPrefix_ByteBuffer(state: VectorSerializationBenchmarksState): Array[Byte] = {
ByteBufferSerialization.writeIntsWithPrefix(state.intArray.length, state.intArray)
}

@Benchmark
@BenchmarkMode(Array(Mode.Throughput))
@Fork(value = 1)
@Warmup(time = 5, iterations = 1)
@Measurement(time = 5, iterations = 1)
def writeInt_Unsafe(state: VectorSerializationBenchmarksState): Unit = {
state.ints.foreach(UnsafeSerialization.writeInt)
}

@Benchmark
@BenchmarkMode(Array(Mode.Throughput))
@Fork(value = 1)
@Warmup(time = 5, iterations = 1)
@Measurement(time = 5, iterations = 1)
def writeInt_ByteBuffer(state: VectorSerializationBenchmarksState): Unit = {
state.ints.foreach(ByteBufferSerialization.writeInt)
}

@Benchmark
@BenchmarkMode(Array(Mode.Throughput))
@Fork(value = 1)
@Warmup(time = 5, iterations = 1)
@Measurement(time = 5, iterations = 1)
def readInt_Unsafe(state: VectorSerializationBenchmarksState): Unit = {
state.intsSerialized.foreach(UnsafeSerialization.readInt)
}

@Benchmark
@BenchmarkMode(Array(Mode.Throughput))
@Fork(value = 1)
@Warmup(time = 5, iterations = 1)
@Measurement(time = 5, iterations = 1)
def readInt_ByteBuffer(state: VectorSerializationBenchmarksState): Unit = {
state.intsSerialized.foreach(ByteBufferSerialization.readInt)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ package org.apache.lucene.search

import com.klibisz.elastiknn.lucene.{HashFieldType, LuceneSupport}
import com.klibisz.elastiknn.models.HashAndFreq
import com.klibisz.elastiknn.storage.UnsafeSerialization._
import com.klibisz.elastiknn.storage.ByteBufferSerialization._
import org.apache.lucene.document.{Document, Field, FieldType}
import org.apache.lucene.index._
import org.scalatest._
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import com.klibisz.elastiknn.storage.BitBuffer;
import com.klibisz.elastiknn.vectors.FloatVectorOps;

import static com.klibisz.elastiknn.storage.UnsafeSerialization.writeInt;
import static com.klibisz.elastiknn.storage.ByteBufferSerialization.writeInt;

import java.util.Random;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import java.util.*;
import java.util.stream.Collectors;

import static com.klibisz.elastiknn.storage.UnsafeSerialization.writeInt;
import static com.klibisz.elastiknn.storage.ByteBufferSerialization.writeInt;

public class HammingLshModel implements HashingModel.SparseBool{

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import java.util.Arrays;
import java.util.Random;

import static com.klibisz.elastiknn.storage.UnsafeSerialization.*;
import static com.klibisz.elastiknn.storage.ByteBufferSerialization.*;

public class JaccardLshModel implements HashingModel.SparseBool {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@

import java.util.*;

import static com.klibisz.elastiknn.storage.UnsafeSerialization.writeInts;
import static com.klibisz.elastiknn.storage.UnsafeSerialization.writeIntsWithPrefix;
import static com.klibisz.elastiknn.storage.ByteBufferSerialization.writeInts;
import static com.klibisz.elastiknn.storage.ByteBufferSerialization.writeIntsWithPrefix;

public class L2LshModel implements HashingModel.DenseFloat {
private final int L;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import java.util.PriorityQueue;

import static com.klibisz.elastiknn.storage.UnsafeSerialization.writeInt;
import static com.klibisz.elastiknn.storage.ByteBufferSerialization.writeInt;

public class PermutationLshModel implements HashingModel.DenseFloat {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ public void putZero() {

@Override
public byte[] toByteArray() {
byte[] barr = UnsafeSerialization.writeInt(b);
byte[] barr = ByteBufferSerialization.writeInt(b);
byte[] res = new byte[prefix.length + barr.length];
System.arraycopy(prefix, 0, res, 0, prefix.length);
System.arraycopy(barr, 0, res, prefix.length, barr.length);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
package com.klibisz.elastiknn.storage;

import scala.util.control.Exception;

import java.nio.ByteBuffer;
import java.nio.ByteOrder;

public class ByteBufferSerialization {
public static final int numBytesInInt = 4;
public static final int numBytesInFloat = 4;

public static final ByteOrder byteOrder = ByteOrder.LITTLE_ENDIAN;

public static byte[] writeInt(final int i) {
final int a = Math.abs(i);
if (a <= Byte.MAX_VALUE) {
return new byte[]{(byte) i};
} else if (a <= Short.MAX_VALUE) {
return new byte[]{
(byte) (i & 0xFF),
(byte) ((i >> 8) & 0xFF)
};
} else {
return new byte[]{
(byte) (i & 0xFF),
(byte) ((i >> 8) & 0xFF),
(byte) ((i >> 16) & 0xFF),
(byte) ((i >> 24) & 0xFF),
};
}
}

public static int readInt(final byte[] barr) {
if (barr.length == 1) {
return barr[0];
} else if (barr.length == 2) {
ByteBuffer bb = ByteBuffer.wrap(barr).order(byteOrder);
return bb.getShort();
} else {
ByteBuffer bb = ByteBuffer.wrap(barr).order(byteOrder);
return bb.getInt();
}
}

public static byte[] writeInts(final int[] iarr) {
ByteBuffer bb = ByteBuffer.allocate(iarr.length * numBytesInInt).order(byteOrder);
bb.asIntBuffer().put(iarr);
return bb.array();
}

public static byte[] writeIntsWithPrefix(int prefix, final int[] iarr) {
ByteBuffer bb = ByteBuffer.allocate((iarr.length + 1) * numBytesInInt).order(byteOrder);
bb.asIntBuffer().put(prefix).position(1).put(iarr);
return bb.array();
}

public static int[] readInts(final byte[] barr, final int offset, final int length) {
int[] dst = new int[length / numBytesInInt];
ByteBuffer bb = ByteBuffer.wrap(barr, offset, length).order(byteOrder);
bb.asIntBuffer().get(dst);
return dst;
}

public static byte[] writeFloats(final float[] farr) {
ByteBuffer bb = ByteBuffer.allocate(farr.length * numBytesInFloat).order(byteOrder);
bb.asFloatBuffer().put(farr);
return bb.array();
}

public static float[] readFloats(final byte[] barr, int offset, int length) {
float[] dst = new float[length / numBytesInFloat];
ByteBuffer bb = ByteBuffer.wrap(barr, offset, length).order(byteOrder);
bb.asFloatBuffer().get(dst);
return dst;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,15 @@ class BitBufferSuite extends AnyFunSuite with Matchers {
ib.putZero() // +0 = 1
ib.putOne() // +4 = 5
ib.putOne() // +8 = 13
ib.toByteArray shouldBe UnsafeSerialization.writeInt(13)
ib.toByteArray shouldBe ByteBufferSerialization.writeInt(13)
}

test("IntBuffer randomized test") {
val rng = new Random(0)
for (_ <- 0 until 100) {
val len = rng.nextInt(32)
val bits = (0 until len).map(_ => rng.nextInt(2))
val prefix = UnsafeSerialization.writeInt(rng.nextInt(Int.MaxValue))
val prefix = ByteBufferSerialization.writeInt(rng.nextInt(Int.MaxValue))
val expected = bits.zipWithIndex
.map { case (b, i) =>
b * math.pow(2, i)
Expand All @@ -30,8 +30,7 @@ class BitBufferSuite extends AnyFunSuite with Matchers {
.toInt
val bitBuf = new BitBuffer.IntBuffer(prefix)
bits.foreach(b => if (b == 0) bitBuf.putZero() else bitBuf.putOne())
bitBuf.toByteArray shouldBe prefix ++ UnsafeSerialization.writeInt(expected)
bitBuf.toByteArray shouldBe prefix ++ ByteBufferSerialization.writeInt(expected)
}
}

}
Loading

0 comments on commit d39e53c

Please sign in to comment.