From 2cb0e26075559e4ce38d2fa9765bcccaa187ce0d Mon Sep 17 00:00:00 2001 From: Michael Sokolov Date: Wed, 4 May 2022 18:22:48 -0400 Subject: [PATCH] LUCENE-10504: KnnGraphTester to use KnnVectorQuery (#796) * LUCENE-10504: KnnGraphTester to use KnnVectorQuery --- .../lucene/util/hnsw/KnnGraphTester.java | 52 +++++++++---------- 1 file changed, 26 insertions(+), 26 deletions(-) diff --git a/lucene/core/src/test/org/apache/lucene/util/hnsw/KnnGraphTester.java b/lucene/core/src/test/org/apache/lucene/util/hnsw/KnnGraphTester.java index 0ae7dea51058..1e9fdbbf1260 100644 --- a/lucene/core/src/test/org/apache/lucene/util/hnsw/KnnGraphTester.java +++ b/lucene/core/src/test/org/apache/lucene/util/hnsw/KnnGraphTester.java @@ -32,8 +32,10 @@ import java.nio.file.Files; import java.nio.file.Path; import java.nio.file.Paths; +import java.nio.file.attribute.FileTime; import java.util.HashSet; import java.util.Locale; +import java.util.Objects; import java.util.Set; import org.apache.lucene.codecs.KnnVectorsFormat; import org.apache.lucene.codecs.KnnVectorsReader; @@ -47,7 +49,6 @@ import org.apache.lucene.document.StoredField; import org.apache.lucene.index.CodecReader; import org.apache.lucene.index.DirectoryReader; -import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.IndexWriter; import org.apache.lucene.index.IndexWriterConfig; import org.apache.lucene.index.LeafReader; @@ -55,11 +56,12 @@ import org.apache.lucene.index.RandomAccessVectorValues; import org.apache.lucene.index.RandomAccessVectorValuesProducer; import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.KnnVectorQuery; import org.apache.lucene.search.ScoreDoc; import org.apache.lucene.search.TopDocs; import org.apache.lucene.store.Directory; import org.apache.lucene.store.FSDirectory; -import org.apache.lucene.util.Bits; import org.apache.lucene.util.BytesRef; import org.apache.lucene.util.IntroSorter; import org.apache.lucene.util.PrintStreamInfoStream; @@ -79,7 +81,6 @@ public class KnnGraphTester { private int numDocs; private int dim; private int topK; - private int warmCount; private int numIters; private int fanout; private Path indexPath; @@ -98,7 +99,6 @@ private KnnGraphTester() { numIters = 1000; dim = 256; topK = 100; - warmCount = 1000; fanout = topK; similarityFunction = VectorSimilarityFunction.DOT_PRODUCT; } @@ -178,9 +178,6 @@ private void run(String... args) throws Exception { case "-out": outputPath = Paths.get(args[++iarg]); break; - case "-warm": - warmCount = Integer.parseInt(args[++iarg]); - break; case "-docs": docVectorsPath = Paths.get(args[++iarg]); break; @@ -349,8 +346,9 @@ private void testSearch(Path indexPath, Path queryPath, Path outputPath, int[][] TopDocs[] results = new TopDocs[numIters]; long elapsed, totalCpuTime, totalVisited = 0; try (FileChannel q = FileChannel.open(queryPath)) { + int bufferSize = numIters * dim * Float.BYTES; FloatBuffer targets = - q.map(FileChannel.MapMode.READ_ONLY, 0, numIters * dim * Float.BYTES) + q.map(FileChannel.MapMode.READ_ONLY, 0, bufferSize) .order(ByteOrder.LITTLE_ENDIAN) .asFloatBuffer(); float[] target = new float[dim]; @@ -362,18 +360,19 @@ private void testSearch(Path indexPath, Path queryPath, Path outputPath, int[][] long cpuTimeStartNs; try (Directory dir = FSDirectory.open(indexPath); DirectoryReader reader = DirectoryReader.open(dir)) { + IndexSearcher searcher = new IndexSearcher(reader); numDocs = reader.maxDoc(); - for (int i = 0; i < warmCount; i++) { + for (int i = 0; i < numIters; i++) { // warm up targets.get(target); - results[i] = doKnnSearch(reader, KNN_FIELD, target, topK, fanout); + doKnnVectorQuery(searcher, KNN_FIELD, target, topK, fanout); } targets.position(0); start = System.nanoTime(); cpuTimeStartNs = bean.getCurrentThreadCpuTime(); for (int i = 0; i < numIters; i++) { targets.get(target); - results[i] = doKnnSearch(reader, KNN_FIELD, target, topK, fanout); + results[i] = doKnnVectorQuery(searcher, KNN_FIELD, target, topK, fanout); } totalCpuTime = (bean.getCurrentThreadCpuTime() - cpuTimeStartNs) / 1_000_000; elapsed = (System.nanoTime() - start) / 1_000_000; // ns -> ms @@ -430,19 +429,9 @@ private void testSearch(Path indexPath, Path queryPath, Path outputPath, int[][] } } - private static TopDocs doKnnSearch( - IndexReader reader, String field, float[] vector, int k, int fanout) throws IOException { - TopDocs[] results = new TopDocs[reader.leaves().size()]; - for (LeafReaderContext ctx : reader.leaves()) { - Bits liveDocs = ctx.reader().getLiveDocs(); - results[ctx.ord] = - ctx.reader().searchNearestVectors(field, vector, k + fanout, liveDocs, Integer.MAX_VALUE); - int docBase = ctx.docBase; - for (ScoreDoc scoreDoc : results[ctx.ord].scoreDocs) { - scoreDoc.doc += docBase; - } - } - return TopDocs.merge(k, results); + private static TopDocs doKnnVectorQuery( + IndexSearcher searcher, String field, float[] vector, int k, int fanout) throws IOException { + return searcher.search(new KnnVectorQuery(field, vector, k + fanout), k); } private float checkResults(TopDocs[] results, int[][] nn) { @@ -487,9 +476,10 @@ private int compareNN(int[] expected, TopDocs results) { private int[][] getNN(Path docPath, Path queryPath) throws IOException { // look in working directory for cached nn file - String nnFileName = "nn-" + numDocs + "-" + numIters + "-" + topK + "-" + dim + ".bin"; + String hash = Integer.toString(Objects.hash(docPath, queryPath, numDocs, numIters, topK), 36); + String nnFileName = "nn-" + hash + ".bin"; Path nnPath = Paths.get(nnFileName); - if (Files.exists(nnPath)) { + if (Files.exists(nnPath) && isNewer(nnPath, docPath, queryPath)) { return readNN(nnPath); } else { int[][] nn = computeNN(docPath, queryPath); @@ -498,6 +488,16 @@ private int[][] getNN(Path docPath, Path queryPath) throws IOException { } } + private boolean isNewer(Path path, Path... others) throws IOException { + FileTime modified = Files.getLastModifiedTime(path); + for (Path other : others) { + if (Files.getLastModifiedTime(other).compareTo(modified) >= 0) { + return false; + } + } + return true; + } + private int[][] readNN(Path nnPath) throws IOException { int[][] result = new int[numIters][]; try (FileChannel in = FileChannel.open(nnPath)) {