java -cp .../lib/*.jar org.apache.lucene.util.hnsw.KnnGraphTester -ndoc 1000000 -search
+ * .../vectors.bin
+ */
+public class KnnGraphTester {
+
+ private static final String KNN_FIELD = "knn"; // vector field that is searched and inspected
+ private static final String ID_FIELD = "id"; // NOTE(review): presumably the stored id field; testSearch reads the literal "id" - confirm
+
+ private int numDocs; // -ndoc (default 1000); overwritten with reader.maxDoc() in testSearch
+ private int dim; // -dim (default 256); number of components per vector
+ private int topK; // -topK (default 100); hits requested per query
+ private int numIters; // -niter (default 1000); number of query vectors to run
+ private int fanout; // -fanout (defaults to the default topK); added to k when building the KNN query
+ private Path indexPath; // derived in run() from the docs file name, maxConn and beamWidth
+ private boolean quiet; // -quiet; suppresses progress messages
+ private boolean reindex; // -reindex; build the index before running the operation
+ private boolean forceMerge; // -forceMerge; merge to a single segment after reindexing
+ private int reindexTimeMsec; // value returned by createIndex; echoed in the result line
+ private int beamWidth; // -beamWidthIndex; also part of the index directory name
+ private int maxConn; // -maxConn; also part of the index directory name
+ private VectorSimilarityFunction similarityFunction; // -metric (default DOT_PRODUCT)
+ private VectorEncoding vectorEncoding; // -encoding (default FLOAT32)
+ private FixedBitSet matchDocs; // random filter bits, generated in run() only when selectivity < 1
+ private float selectivity; // -filterSelectivity, strictly between 0 and 1; 1f means no filter
+ private boolean prefilter; // -prefilter; pass the filter to the query instead of filtering hits afterwards
+
+ /** Creates a tester holding the default settings; {@code run} overrides them from the arguments. */
+ private KnnGraphTester() {
+   // query defaults
+   topK = 100;
+   fanout = topK;
+   numIters = 1000;
+   // data defaults
+   numDocs = 1000;
+   dim = 256;
+   vectorEncoding = VectorEncoding.FLOAT32;
+   similarityFunction = VectorSimilarityFunction.DOT_PRODUCT;
+   // no filtering unless asked for
+   prefilter = false;
+   selectivity = 1f;
+ }
+
+ public static void main(String... args) throws Exception {
+ new KnnGraphTester().run(args);
+ }
+
+ /**
+  * Parses the command line, optionally rebuilds the index, then runs the requested operation.
+  *
+  * <p>Exactly one of {@code -search <queryfile>}, {@code -check}, {@code -stats} or {@code -dump}
+  * may be given; {@code -reindex} may be given alone or together with an operation. The index
+  * directory name is derived from the {@code -docs} file name, so {@code -docs} is needed by every
+  * code path that reaches the index.
+  *
+  * <p>NOTE: {@code -check} and {@code -dump} are accepted as operations but nothing is dispatched
+  * for them in this method.
+  *
+  * @param args command-line arguments
+  * @throws IllegalArgumentException if an option is unknown, is missing its value, has an invalid
+  *     value, or if a required option is absent
+  * @throws Exception if indexing or searching fails
+  */
+ private void run(String... args) throws Exception {
+   String operation = null;
+   Path docVectorsPath = null, queryPath = null, outputPath = null;
+   for (int iarg = 0; iarg < args.length; iarg++) {
+     String arg = args[iarg];
+     switch (arg) {
+       case "-search":
+       case "-check":
+       case "-stats":
+       case "-dump":
+         if (operation != null) {
+           throw new IllegalArgumentException(
+               "Specify only one operation, not both " + arg + " and " + operation);
+         }
+         operation = arg;
+         if (operation.equals("-search")) {
+           queryPath =
+               Paths.get(
+                   requireArgValue(
+                       args, ++iarg, "Operation " + arg + " requires a following pathname"));
+         }
+         break;
+       case "-fanout":
+         fanout =
+             Integer.parseInt(
+                 requireArgValue(args, ++iarg, "-fanout requires a following number"));
+         break;
+       case "-beamWidthIndex":
+         beamWidth =
+             Integer.parseInt(
+                 requireArgValue(args, ++iarg, "-beamWidthIndex requires a following number"));
+         break;
+       case "-maxConn":
+         maxConn =
+             Integer.parseInt(
+                 requireArgValue(args, ++iarg, "-maxConn requires a following number"));
+         break;
+       case "-dim":
+         dim = Integer.parseInt(requireArgValue(args, ++iarg, "-dim requires a following number"));
+         break;
+       case "-ndoc":
+         numDocs =
+             Integer.parseInt(requireArgValue(args, ++iarg, "-ndoc requires a following number"));
+         break;
+       case "-niter":
+         numIters =
+             Integer.parseInt(requireArgValue(args, ++iarg, "-niter requires a following number"));
+         break;
+       case "-reindex":
+         reindex = true;
+         break;
+       case "-topK":
+         topK =
+             Integer.parseInt(requireArgValue(args, ++iarg, "-topK requires a following number"));
+         break;
+       case "-out":
+         outputPath =
+             Paths.get(requireArgValue(args, ++iarg, "-out requires a following pathname"));
+         break;
+       case "-docs":
+         docVectorsPath =
+             Paths.get(requireArgValue(args, ++iarg, "-docs requires a following pathname"));
+         break;
+       case "-encoding":
+         String encoding =
+             requireArgValue(args, ++iarg, "-encoding can be 'byte' or 'float32' only");
+         switch (encoding) {
+           case "byte":
+             vectorEncoding = VectorEncoding.BYTE;
+             break;
+           case "float32":
+             vectorEncoding = VectorEncoding.FLOAT32;
+             break;
+           default:
+             throw new IllegalArgumentException("-encoding can be 'byte' or 'float32' only");
+         }
+         break;
+       case "-metric":
+         String metric =
+             requireArgValue(args, ++iarg, "-metric can be 'angular' or 'euclidean' only");
+         switch (metric) {
+           case "euclidean":
+             similarityFunction = VectorSimilarityFunction.EUCLIDEAN;
+             break;
+           case "angular":
+             similarityFunction = VectorSimilarityFunction.DOT_PRODUCT;
+             break;
+           default:
+             throw new IllegalArgumentException("-metric can be 'angular' or 'euclidean' only");
+         }
+         break;
+       case "-forceMerge":
+         forceMerge = true;
+         break;
+       case "-prefilter":
+         prefilter = true;
+         break;
+       case "-filterSelectivity":
+         selectivity =
+             Float.parseFloat(
+                 requireArgValue(args, ++iarg, "-filterSelectivity requires a following float"));
+         // Written as a negated range test so that NaN is rejected as well.
+         if (!(selectivity > 0 && selectivity < 1)) {
+           throw new IllegalArgumentException("-filterSelectivity must be between 0 and 1");
+         }
+         break;
+       case "-quiet":
+         quiet = true;
+         break;
+       default:
+         throw new IllegalArgumentException("unknown argument " + arg);
+     }
+   }
+   if (operation == null && reindex == false) {
+     usage();
+   }
+   if (prefilter && selectivity == 1f) {
+     throw new IllegalArgumentException("-prefilter requires filterSelectivity between 0 and 1");
+   }
+   // The index path is built from the docs file name, so check for -docs before deriving it;
+   // otherwise a missing -docs surfaces as a NullPointerException in formatIndexPath.
+   if (docVectorsPath == null) {
+     throw new IllegalArgumentException(
+         reindex ? "-docs argument is required when indexing" : "missing -docs arg");
+   }
+   indexPath = Paths.get(formatIndexPath(docVectorsPath));
+   if (reindex) {
+     reindexTimeMsec = createIndex(docVectorsPath, indexPath);
+     if (forceMerge) {
+       forceMerge();
+     }
+   }
+   if (operation != null) {
+     switch (operation) {
+       case "-search":
+         if (selectivity < 1) {
+           matchDocs = generateRandomBitSet(numDocs, selectivity);
+         }
+         if (outputPath != null) {
+           // Results are written to the output file; no recall check is done.
+           testSearch(indexPath, queryPath, outputPath, null);
+         } else {
+           // No output file: compare the hits against the neighbors returned by getNN.
+           testSearch(indexPath, queryPath, null, getNN(docVectorsPath, queryPath));
+         }
+         break;
+       case "-stats":
+         printFanoutHist(indexPath);
+         break;
+     }
+   }
+ }
+
+ /**
+  * Returns the option value at {@code valueIndex}.
+  *
+  * @throws IllegalArgumentException with {@code message} if the command line ends before it
+  */
+ private static String requireArgValue(String[] args, int valueIndex, String message) {
+   if (valueIndex >= args.length) {
+     throw new IllegalArgumentException(message);
+   }
+   return args[valueIndex];
+ }
+
+ /**
+  * Builds the index directory name {@code <docs file name>-<maxConn>-<beamWidth>.index} from the
+  * last element of {@code docsPath} and the current graph parameters.
+  */
+ private String formatIndexPath(Path docsPath) {
+   StringBuilder name = new StringBuilder();
+   name.append(docsPath.getFileName());
+   name.append("-").append(maxConn);
+   name.append("-").append(beamWidth);
+   name.append(".index");
+   return name.toString();
+ }
+
+ /**
+  * Opens the index at {@code indexPath} and, for each leaf, prints its document count followed by
+  * the fanout statistics of the HNSW graph stored for the {@code knn} field.
+  */
+ @SuppressForbidden(reason = "Prints stuff")
+ private void printFanoutHist(Path indexPath) throws IOException {
+   try (Directory directory = FSDirectory.open(indexPath);
+       DirectoryReader indexReader = DirectoryReader.open(directory)) {
+     for (LeafReaderContext leaf : indexReader.leaves()) {
+       LeafReader segment = leaf.reader();
+       // Unwrap the per-field format reader to reach the HNSW reader of the vector field.
+       CodecReader codecReader = (CodecReader) segment;
+       PerFieldKnnVectorsFormat.FieldsReader perFieldReader =
+           (PerFieldKnnVectorsFormat.FieldsReader) codecReader.getVectorReader();
+       KnnVectorsReader fieldReader = perFieldReader.getFieldReader(KNN_FIELD);
+       Lucene95HnswVectorsReader hnswReader = (Lucene95HnswVectorsReader) fieldReader;
+       HnswGraph graph = hnswReader.getGraph(KNN_FIELD);
+       System.out.printf("Leaf %d has %d documents\n", leaf.ord, segment.maxDoc());
+       printGraphFanout(graph, segment.maxDoc());
+     }
+   }
+ }
+
+ /**
+  * Reopens the existing index at {@code indexPath} in append mode and merges it down to a single
+  * segment, echoing the writer's info stream to standard output.
+  *
+  * @throws IOException if the index cannot be opened or the merge fails
+  */
+ @SuppressForbidden(reason = "Prints stuff")
+ private void forceMerge() throws IOException {
+   IndexWriterConfig iwc = new IndexWriterConfig().setOpenMode(IndexWriterConfig.OpenMode.APPEND);
+   iwc.setInfoStream(new PrintStreamInfoStream(System.out));
+   System.out.println("Force merge index in " + indexPath);
+   // Closing an IndexWriter does not close its Directory, so the directory is managed as its
+   // own resource; resources are closed in reverse order (writer first, then directory).
+   try (Directory dir = FSDirectory.open(indexPath);
+       IndexWriter iw = new IndexWriter(dir, iwc)) {
+     iw.forceMerge(1);
+   }
+ }
+
+ /**
+  * Counts the level-0 neighbors of nodes {@code 0..numDocs-1} in {@code knnValues}, prints the
+  * number of nodes with at least one neighbor together with the min/mean/max neighbor count, and
+  * then prints a histogram of the counts. The mean is taken over nodes that have neighbors.
+  */
+ @SuppressForbidden(reason = "Prints stuff")
+ private void printGraphFanout(HnswGraph knnValues, int numDocs) throws IOException {
+   // leafHist[d] = number of nodes having exactly d neighbors
+   int[] leafHist = new int[numDocs];
+   int smallest = Integer.MAX_VALUE;
+   int largest = 0;
+   int connected = 0;
+   int edgeSum = 0;
+   for (int node = 0; node < numDocs; node++) {
+     knnValues.seek(0, node);
+     int degree = 0;
+     for (int nbr = knnValues.nextNeighbor();
+         nbr != NO_MORE_DOCS;
+         nbr = knnValues.nextNeighbor()) {
+       degree++;
+     }
+     leafHist[degree]++;
+     if (degree > largest) {
+       largest = degree;
+     }
+     if (degree < smallest) {
+       smallest = degree;
+     }
+     if (degree > 0) {
+       connected++;
+       edgeSum += degree;
+     }
+   }
+   float mean = edgeSum / (float) connected;
+   System.out.printf(
+       "Graph size=%d, Fanout min=%d, mean=%.2f, max=%d\n", connected, smallest, mean, largest);
+   printHist(leafHist, largest, connected, 10);
+ }
+
+ /**
+  * Prints a two-line percentile table for a histogram of neighbor counts.
+  *
+  * <p>The first line is a header of {@code nbuckets + 1} percentages (0..100). The second line
+  * starts with {@code hist[0]} and then, for each bucket {@code b} in {@code 1..nbuckets}, the
+  * smallest index {@code i >= 1} whose running total {@code hist[1] + ... + hist[i]} reaches
+  * {@code count * b / nbuckets}.
+  *
+  * @param hist histogram indexed by neighbor count; must have at least {@code max + 1} entries
+  * @param max largest index of {@code hist} to consider
+  * @param count total used to scale the bucket thresholds
+  * @param nbuckets number of percentile buckets
+  */
+ @SuppressForbidden(reason = "Prints stuff")
+ private void printHist(int[] hist, int max, int count, int nbuckets) {
+   System.out.print("%");
+   for (int i = 0; i <= nbuckets; i++) {
+     System.out.printf("%4d", i * 100 / nbuckets);
+   }
+   System.out.printf("\n %4d", hist[0]);
+   int total = 0, ibucket = 1;
+   for (int i = 1; i <= max && ibucket <= nbuckets; i++) {
+     total += hist[i];
+     // Stop at nbuckets: without the bound, a count smaller than nbuckets makes the threshold
+     // round down and the row grows wider than the header. The long cast avoids int overflow
+     // in count * ibucket.
+     while (ibucket <= nbuckets && total >= (long) count * ibucket / nbuckets) {
+       System.out.printf("%4d", i);
+       ++ibucket;
+     }
+   }
+   System.out.println();
+ }
+
+ /**
+  * Runs {@code numIters} KNN queries read from {@code queryPath} against the index at {@code
+  * indexPath}, once as warm-up and once timed, then either writes the hits or reports recall.
+  *
+  * <p>With pre-filtering the filter query is passed to the KNN query. Otherwise {@code topK /
+  * selectivity} hits are requested and, if a filter bit set exists, hits outside it are dropped
+  * afterwards. After timing, every hit's doc number is replaced by the integer parsed from its
+  * stored id field.
+  *
+  * <p>If {@code outputPath} is non-null the ids are written to it as little-endian 32-bit ints.
+  * Otherwise recall is computed against {@code nn} and one tab-separated line is printed: recall,
+  * CPU ms per query, numDocs, fanout, maxConn, beamWidth, mean totalHits per query, reindex time,
+  * selectivity and the filter mode.
+  *
+  * @param indexPath directory of the index to search
+  * @param queryPath file of query vectors
+  * @param outputPath file to receive result ids, or null to check recall instead
+  * @param nn expected neighbors per query, passed to checkResults when outputPath is null
+  * @throws IOException if reading the queries or the index, or writing the output, fails
+  */
+ @SuppressForbidden(reason = "Prints stuff")
+ private void testSearch(Path indexPath, Path queryPath, Path outputPath, int[][] nn)
+     throws IOException {
+   TopDocs[] results = new TopDocs[numIters];
+   long elapsed, totalCpuTime, totalVisited = 0;
+   try (FileChannel input = FileChannel.open(queryPath)) {
+     VectorReader targetReader = VectorReader.create(input, dim, vectorEncoding);
+     if (quiet == false) {
+       System.out.println("running " + numIters + " targets; topK=" + topK + ", fanout=" + fanout);
+     }
+     long start;
+     ThreadMXBean bean = ManagementFactory.getThreadMXBean();
+     long cpuTimeStartNs;
+     try (Directory dir = FSDirectory.open(indexPath);
+         DirectoryReader reader = DirectoryReader.open(dir)) {
+       IndexSearcher searcher = new IndexSearcher(reader);
+       numDocs = reader.maxDoc();
+       Query bitSetQuery = prefilter ? new BitSetQuery(matchDocs) : null;
+       for (int i = 0; i < numIters; i++) {
+         // warm up
+         float[] target = targetReader.next();
+         if (prefilter) {
+           doKnnVectorQuery(searcher, KNN_FIELD, target, topK, fanout, bitSetQuery);
+         } else {
+           doKnnVectorQuery(searcher, KNN_FIELD, target, (int) (topK / selectivity), fanout, null);
+         }
+       }
+       // Rewind the query file so the timed pass sees the same vectors as the warm-up.
+       targetReader.reset();
+       start = System.nanoTime();
+       cpuTimeStartNs = bean.getCurrentThreadCpuTime();
+       for (int i = 0; i < numIters; i++) {
+         float[] target = targetReader.next();
+         if (prefilter) {
+           results[i] = doKnnVectorQuery(searcher, KNN_FIELD, target, topK, fanout, bitSetQuery);
+         } else {
+           results[i] =
+               doKnnVectorQuery(
+                   searcher, KNN_FIELD, target, (int) (topK / selectivity), fanout, null);
+
+           if (matchDocs != null) {
+             // post-filter: keep only the hits that the filter accepts
+             results[i].scoreDocs =
+                 Arrays.stream(results[i].scoreDocs)
+                     .filter(scoreDoc -> matchDocs.get(scoreDoc.doc))
+                     .toArray(ScoreDoc[]::new);
+           }
+         }
+       }
+       totalCpuTime =
+           TimeUnit.NANOSECONDS.toMillis(bean.getCurrentThreadCpuTime() - cpuTimeStartNs);
+       elapsed = TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - start); // ns -> ms
+       StoredFields storedFields = reader.storedFields();
+       for (int i = 0; i < numIters; i++) {
+         totalVisited += results[i].totalHits.value;
+         for (ScoreDoc doc : results[i].scoreDocs) {
+           if (doc.doc != NO_MORE_DOCS) {
+             // there is a bug somewhere that can result in doc=NO_MORE_DOCS! I think it happens
+             // in some degenerate case (like input query has NaN in it?) that causes no results to
+             // be returned from HNSW search?
+             doc.doc = Integer.parseInt(storedFields.document(doc.doc).get(ID_FIELD));
+           } else {
+             System.out.println("NO_MORE_DOCS!");
+           }
+         }
+       }
+     }
+     if (quiet == false) {
+       // elapsed is in whole milliseconds and can be 0 for a short run; clamp it so the QPS
+       // computation cannot divide by zero, and multiply in long to avoid int overflow.
+       long qps = (1000L * numIters) / Math.max(1, elapsed);
+       System.out.println(
+           "completed "
+               + numIters
+               + " searches in "
+               + elapsed
+               + " ms: "
+               + qps
+               + " QPS "
+               + "CPU time="
+               + totalCpuTime
+               + "ms");
+     }
+   }
+   if (outputPath != null) {
+     // buf and ibuf share storage: each id is put through the int view, then the 4 backing
+     // bytes are written out.
+     ByteBuffer buf = ByteBuffer.allocate(4);
+     IntBuffer ibuf = buf.order(ByteOrder.LITTLE_ENDIAN).asIntBuffer();
+     try (OutputStream out = Files.newOutputStream(outputPath)) {
+       for (int i = 0; i < numIters; i++) {
+         for (ScoreDoc doc : results[i].scoreDocs) {
+           ibuf.position(0);
+           ibuf.put(doc.doc);
+           out.write(buf.array());
+         }
+       }
+     }
+   } else {
+     if (quiet == false) {
+       System.out.println("checking results");
+     }
+     float recall = checkResults(results, nn);
+     totalVisited /= numIters;
+     System.out.printf(
+         Locale.ROOT,
+         "%5.3f\t%5.2f\t%d\t%d\t%d\t%d\t%d\t%d\t%.2f\t%s\n",
+         recall,
+         totalCpuTime / (float) numIters,
+         numDocs,
+         fanout,
+         maxConn,
+         beamWidth,
+         totalVisited,
+         reindexTimeMsec,
+         selectivity,
+         prefilter ? "pre-filter" : "post-filter");
+   }
+ }
+
+ /**
+  * Sequential reader of fixed-size vectors from a file channel. Each call to {@link #next()}
+  * loads one vector's worth of bytes into {@link #bytes} and decodes it into {@link #target}; the
+  * same array instance is returned on every call, so callers must not retain it across calls.
+  */
+ private abstract static class VectorReader {
+   /** Reusable decoded vector, {@code dim} entries long. */
+   final float[] target;
+
+   /** Reusable little-endian buffer holding the raw bytes of one vector. */
+   final ByteBuffer bytes;
+
+   /** Channel the vectors are read from. */
+   final FileChannel input;
+
+   /** Creates the reader matching {@code vectorEncoding} for vectors of {@code dim} components. */
+   static VectorReader create(FileChannel input, int dim, VectorEncoding vectorEncoding) {
+     int bufferSize = dim * vectorEncoding.byteSize;
+     return switch (vectorEncoding) {
+       case BYTE -> new VectorReaderByte(input, dim, bufferSize);
+       case FLOAT32 -> new VectorReaderFloat32(input, dim, bufferSize);
+     };
+   }
+
+   VectorReader(FileChannel input, int dim, int bufferSize) {
+     this.bytes = ByteBuffer.wrap(new byte[bufferSize]).order(ByteOrder.LITTLE_ENDIAN);
+     this.input = input;
+     target = new float[dim];
+   }
+
+   /** Repositions the channel at the start of the file. */
+   void reset() throws IOException {
+     input.position(0);
+   }
+
+   /**
+    * Fills {@link #bytes} with the next vector's raw bytes and leaves it positioned at 0.
+    *
+    * @throws java.io.EOFException if the file ends before a whole vector has been read
+    * @throws IOException if reading from the channel fails
+    */
+   protected final void readNext() throws IOException {
+     // A subclass may have consumed the buffer with relative gets; without clearing it there
+     // would be no room left and the read would transfer nothing, replaying the old vector.
+     bytes.clear();
+     // A single read may transfer fewer bytes than requested, so keep reading until full.
+     while (bytes.hasRemaining()) {
+       if (input.read(bytes) < 0) {
+         throw new java.io.EOFException(
+             "vector file ended after "
+                 + bytes.position()
+                 + " of "
+                 + bytes.capacity()
+                 + " bytes of the next vector");
+       }
+     }
+     bytes.flip();
+   }
+
+   /** Reads and decodes the next vector; returns the shared {@link #target} array. */
+   abstract float[] next() throws IOException;
+ }
+
+ private static class VectorReaderFloat32 extends VectorReader {
+ VectorReaderFloat32(FileChannel input, int dim, int bufferSize) {
+ super(input, dim, bufferSize);
+ }
+
+ @Override
+ float[] next() throws IOException {
+ readNext();
+ bytes.asFloatBuffer().get(target);
+ return target;
+ }
+ }
+
+ private static class VectorReaderByte extends VectorReader {
+ private final byte[] scratch;
+
+ VectorReaderByte(FileChannel input, int dim, int bufferSize) {
+ super(input, dim, bufferSize);
+ scratch = new byte[dim];
+ }
+
+ @Override
+ float[] next() throws IOException {
+ readNext();
+ bytes.get(scratch);
+ for (int i = 0; i < scratch.length; i++) {
+ target[i] = scratch[i];
+ }
+ return target;
+ }
+
+ byte[] nextBytes() throws IOException {
+ readNext();
+ bytes.get(scratch);
+ return scratch;
+ }
+ }
+
+ /**
+  * Searches {@code field} for the vectors nearest to {@code vector}: the KNN query is built with
+  * {@code k + fanout} as its neighbor count and the optional {@code filter}, and the searcher is
+  * asked for the top {@code k} hits.
+  */
+ private static TopDocs doKnnVectorQuery(
+     IndexSearcher searcher, String field, float[] vector, int k, int fanout, Query filter)
+     throws IOException {
+   int neighborCount = k + fanout;
+   KnnFloatVectorQuery knnQuery = new KnnFloatVectorQuery(field, vector, neighborCount, filter);
+   return searcher.search(knnQuery, k);
+ }
+
+ /**
+  * Computes recall: the sum over all queries of {@code compareNN(nn[i], results[i])}, divided by
+  * {@code results.length * topK}.
+  */
+ private float checkResults(TopDocs[] results, int[][] nn) {
+   int expectedCount = results.length * topK;
+   int matchCount = 0;
+   int query = 0;
+   while (query < results.length) {
+     matchCount += compareNN(nn[query], results[query]);
+     query++;
+   }
+   return matchCount / (float) expectedCount;
+ }
+
+ private int compareNN(int[] expected, TopDocs results) {
+ int matched = 0;
+ /*
+ System.out.print("expected=");
+ for (int j = 0; j < expected.length; j++) {
+ System.out.print(expected[j]);
+ System.out.print(", ");
+ }
+ System.out.print('\n');
+ System.out.println("results=");
+ for (int j = 0; j < results.scoreDocs.length; j++) {
+ System.out.print("" + results.scoreDocs[j].doc + ":" + results.scoreDocs[j].score + ", ");
+ }
+ System.out.print('\n');
+ */
+ Set