Skip to content

Commit

Permalink
Extracted KnnIndexer class out of KnnGraphTester code so that the log…
Browse files Browse the repository at this point in the history
…ic can be reused
  • Loading branch information
nitirajrathore committed Feb 7, 2024
1 parent e2e3197 commit cac5a3e
Show file tree
Hide file tree
Showing 8 changed files with 426 additions and 86 deletions.
2 changes: 1 addition & 1 deletion build.xml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

<javac srcdir="src/main"
destdir="build"
includes="KnnGraphTester.java,WikiVectors.java,perf/VectorDictionary.java"
includes="knn/KnnGraphTester.java,WikiVectors.java,perf/VectorDictionary.java"
classpathref="build.classpath"
includeantruntime="false"/>

Expand Down
123 changes: 39 additions & 84 deletions src/main/KnnGraphTester.java → src/main/knn/KnnGraphTester.java
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
* limitations under the License.
*/

package knn;

import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;

import java.io.IOException;
Expand Down Expand Up @@ -88,13 +90,13 @@
/**
* For testing indexing and search performance of a knn-graph
*
* <p>java -cp .../lib/*.jar org.apache.lucene.util.hnsw.KnnGraphTester -ndoc 1000000 -search
* <p>java -cp .../lib/*.jar knn.KnnGraphTester -ndoc 1000000 -search
* .../vectors.bin
*/
public class KnnGraphTester {

private static final String KNN_FIELD = "knn";
private static final String ID_FIELD = "id";
public static final String KNN_FIELD = "knn";
public static final String ID_FIELD = "id";
private static final double WRITER_BUFFER_MB = 1994d;

private int numDocs;
Expand Down Expand Up @@ -231,6 +233,9 @@ private void run(String... args) throws Exception {
case "-docs":
docVectorsPath = Paths.get(args[++iarg]);
break;
case "-indexPath":
indexPath = Paths.get(args[++iarg]);
break;
case "-encoding":
String encoding = args[++iarg];
switch (encoding) {
Expand Down Expand Up @@ -308,12 +313,15 @@ private void run(String... args) throws Exception {
if (prefilter && selectivity == 1f) {
throw new IllegalArgumentException("-prefilter requires filterSelectivity between 0 and 1");
}
indexPath = Paths.get(formatIndexPath(docVectorsPath));
if (indexPath == null) {
indexPath = Paths.get(formatIndexPath(docVectorsPath)); // derive index path
}
if (reindex) {
if (docVectorsPath == null) {
throw new IllegalArgumentException("-docs argument is required when indexing");
}
reindexTimeMsec = createIndex(docVectorsPath, indexPath);
reindexTimeMsec = new KnnIndexer(docVectorsPath, indexPath, maxConn, beamWidth, vectorEncoding, dim,
similarityFunction, numDocs, quiet).createIndex();
}
if (forceMerge) {
forceMerge();
Expand Down Expand Up @@ -347,7 +355,7 @@ private String formatIndexPath(Path docsPath) {
@SuppressForbidden(reason = "Prints stuff")
private void printFanoutHist(Path indexPath) throws IOException {
try (Directory dir = FSDirectory.open(indexPath);
DirectoryReader reader = DirectoryReader.open(dir)) {
DirectoryReader reader = DirectoryReader.open(dir)) {
for (LeafReaderContext context : reader.leaves()) {
LeafReader leafReader = context.reader();
KnnVectorsReader vectorsReader =
Expand Down Expand Up @@ -470,7 +478,7 @@ private void testSearch(Path indexPath, Path queryPath, Path outputPath, int[][]
ThreadMXBean bean = ManagementFactory.getThreadMXBean();
long cpuTimeStartNs;
try (Directory dir = FSDirectory.open(indexPath);
DirectoryReader reader = DirectoryReader.open(dir)) {
DirectoryReader reader = DirectoryReader.open(dir)) {
IndexSearcher searcher = new IndexSearcher(reader);
numDocs = reader.maxDoc();
Query bitSetQuery = prefilter ? new BitSetQuery(matchDocs) : null;
Expand Down Expand Up @@ -557,7 +565,22 @@ private void testSearch(Path indexPath, Path queryPath, Path outputPath, int[][]
totalVisited /= numIters;
System.out.printf(
Locale.ROOT,
"%5.3f\t%5.2f\t%d\t%d\t%d\t%d\t%d\t%d\t%.2f\t%s\n",
"|%s\t|%s\t|%s\t|%s\t|%s\t|%s\t|%s\t|%s\t|%s\t|%s|\n",
"recall",
"avgCpuTime",
"numDocs",
"fanout",
"maxConn",
"beamWidth",
"totalVisited",
"reindexTimeMsec",
"selectivity",
"prefilter");
System.out.println(
"|---|---|---|---|---|---|---|---|---|---|");
System.out.printf(
Locale.ROOT,
"|%5.3f\t|%5.2f\t|%d\t|%d\t|%d\t|%d\t|%d\t|%d\t|%.2f\t|%s|\n",
recall,
totalCpuTime / (float) numIters,
numDocs,
Expand All @@ -571,75 +594,6 @@ private void testSearch(Path indexPath, Path queryPath, Path outputPath, int[][]
}
}

/**
 * Reads fixed-size vector records sequentially from a {@link FileChannel}.
 *
 * <p>Each record occupies {@code dim * vectorEncoding.byteSize} bytes and is staged in the
 * little-endian {@link #bytes} buffer; subclasses decode that buffer in {@link #next()}.
 */
private abstract static class VectorReader {
  /** Scratch array of length {@code dim}, allocated once in the constructor. */
  final float[] target;

  /** Staging buffer holding the raw bytes of the most recently read record. */
  final ByteBuffer bytes;

  /** Channel the records are read from. */
  final FileChannel input;

  /**
   * Creates the reader matching the given encoding.
   *
   * @param input channel positioned at the first record to read
   * @param dim number of components per vector
   * @param vectorEncoding storage encoding of each component
   * @return a reader for {@code vectorEncoding}
   */
  static VectorReader create(FileChannel input, int dim, VectorEncoding vectorEncoding) {
    int bufferSize = dim * vectorEncoding.byteSize;
    return switch (vectorEncoding) {
      case BYTE -> new VectorReaderByte(input, dim, bufferSize);
      case FLOAT32 -> new VectorReaderFloat32(input, dim, bufferSize);
    };
  }

  VectorReader(FileChannel input, int dim, int bufferSize) {
    this.bytes = ByteBuffer.wrap(new byte[bufferSize]).order(ByteOrder.LITTLE_ENDIAN);
    this.input = input;
    target = new float[dim];
  }

  /** Repositions the channel at the start of the file. */
  void reset() throws IOException {
    input.position(0);
  }

  /**
   * Loads the next record into {@link #bytes} and leaves the buffer positioned at 0.
   *
   * <p>The buffer is rewound <em>before</em> reading: a subclass that consumed it with a relative
   * {@code get} leaves no remaining space, in which case {@code FileChannel.read} would transfer
   * zero bytes and the previous record would be returned again. The read is also repeated until
   * the buffer is full, because a single {@code read} call may transfer fewer bytes than
   * requested.
   *
   * <p>Reaching end-of-file is not reported; the buffer then keeps whatever it already held.
   *
   * @throws IOException if the underlying channel read fails
   */
  protected final void readNext() throws IOException {
    bytes.clear();
    while (bytes.hasRemaining()) {
      if (this.input.read(bytes) < 0) {
        break; // end of file: keep previous lenient behaviour
      }
    }
    bytes.position(0);
  }

  /**
   * Reads and decodes the next vector.
   *
   * @return the decoded vector as floats
   * @throws IOException if the underlying channel read fails
   */
  abstract float[] next() throws IOException;
}

/** {@link VectorReader} for records stored as 32-bit floats. */
private static class VectorReaderFloat32 extends VectorReader {

  VectorReaderFloat32(FileChannel input, int dim, int bufferSize) {
    super(input, dim, bufferSize);
  }

  /**
   * Loads the next record and copies it, viewed as floats, into the reader's {@code target}
   * array.
   *
   * @return the {@code target} array holding the freshly read values
   */
  @Override
  float[] next() throws IOException {
    this.readNext();
    // The float view starts at the buffer's current position, which readNext() leaves at 0.
    this.bytes.asFloatBuffer().get(this.target);
    return this.target;
  }
}

/** {@link VectorReader} for records stored as one signed byte per component. */
private static class VectorReaderByte extends VectorReader {
  /** Reusable holder for the raw bytes of the current record; length is {@code dim}. */
  private final byte[] scratch;

  VectorReaderByte(FileChannel input, int dim, int bufferSize) {
    super(input, dim, bufferSize);
    this.scratch = new byte[dim];
  }

  /**
   * Loads the next record and widens each byte into the reader's {@code target} array.
   *
   * @return the {@code target} array holding the widened values
   */
  @Override
  float[] next() throws IOException {
    final byte[] raw = nextBytes();
    int idx = 0;
    for (byte component : raw) {
      target[idx++] = component;
    }
    return target;
  }

  /**
   * Loads the next record and returns its raw bytes.
   *
   * @return the internal scratch array, overwritten by this call
   */
  byte[] nextBytes() throws IOException {
    readNext();
    bytes.get(scratch);
    return scratch;
  }
}

private static TopDocs doKnnVectorQuery(
IndexSearcher searcher, String field, float[] vector, int k, int fanout, Query filter)
throws IOException {
Expand Down Expand Up @@ -763,7 +717,7 @@ private int[][] computeNN(Path docPath, Path queryPath, VectorEncoding encoding)
System.out.println("computing true nearest neighbors of " + numIters + " target vectors");
}
try (FileChannel in = FileChannel.open(docPath);
FileChannel qIn = FileChannel.open(queryPath)) {
FileChannel qIn = FileChannel.open(queryPath)) {
VectorReader docReader = VectorReader.create(in, dim, encoding);
VectorReader queryReader = VectorReader.create(qIn, dim, encoding);
for (int i = 0; i < numIters; i++) {
Expand Down Expand Up @@ -810,7 +764,7 @@ private int createIndex(Path docsPath, Path indexPath) throws IOException {
}
long start = System.nanoTime();
try (FSDirectory dir = FSDirectory.open(indexPath);
IndexWriter iw = new IndexWriter(dir, iwc)) {
IndexWriter iw = new IndexWriter(dir, iwc)) {
try (FileChannel in = FileChannel.open(docsPath)) {
VectorReader vectorReader = VectorReader.create(in, dim, vectorEncoding);
for (int i = 0; i < numDocs; i++) {
Expand Down Expand Up @@ -847,17 +801,17 @@ private static Codec getCodec(int maxConn, int beamWidth, ExecutorService exec,
@Override
public KnnVectorsFormat getKnnVectorsFormatForField(String field) {
return quantize ?
new Lucene99HnswScalarQuantizedVectorsFormat(maxConn, beamWidth) :
new Lucene99HnswVectorsFormat(maxConn, beamWidth);
new Lucene99HnswScalarQuantizedVectorsFormat(maxConn, beamWidth) :
new Lucene99HnswVectorsFormat(maxConn, beamWidth);
}
};
} else {
return new Lucene99Codec() {
@Override
public KnnVectorsFormat getKnnVectorsFormatForField(String field) {
return quantize ?
new Lucene99HnswScalarQuantizedVectorsFormat(maxConn, beamWidth, numMergeWorker, null, exec) :
new Lucene99HnswVectorsFormat(maxConn, beamWidth, numMergeWorker, exec);
new Lucene99HnswScalarQuantizedVectorsFormat(maxConn, beamWidth, numMergeWorker, null, exec) :
new Lucene99HnswVectorsFormat(maxConn, beamWidth, numMergeWorker, exec);
}
};
}
Expand Down Expand Up @@ -927,7 +881,8 @@ public boolean isCacheable(LeafReaderContext ctx) {
}

@Override
public void visit(QueryVisitor visitor) {}
public void visit(QueryVisitor visitor) {
}

@Override
public String toString(String field) {
Expand Down
139 changes: 139 additions & 0 deletions src/main/knn/KnnIndexer.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package knn;

import knn.KnnGraphTester;
import knn.VectorReader;
import org.apache.lucene.codecs.KnnVectorsFormat;
import org.apache.lucene.codecs.lucene99.Lucene99Codec;
import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat;
import org.apache.lucene.document.*;
import org.apache.lucene.index.IndexWriter;
import org.apache.lucene.index.IndexWriterConfig;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.store.FSDirectory;
import org.apache.lucene.util.PrintStreamInfoStream;

import java.io.IOException;
import java.nio.channels.FileChannel;
import java.nio.file.Path;
import java.util.concurrent.TimeUnit;

/**
 * Builds a Lucene index with one KNN vector per document, reading the raw vectors from a flat
 * binary file.
 *
 * <p>Each document gets the vector under {@link KnnGraphTester#KNN_FIELD} and a stored
 * {@link KnnGraphTester#ID_FIELD} holding the document's zero-based position within this run.
 */
public class KnnIndexer {

  /** RAM buffer size, in MB, given to the {@link IndexWriter}. */
  private static final double WRITER_BUFFER_MB = 1994d;

  /** A progress line is printed every this many indexed documents (unless quiet). */
  private static final int PROGRESS_INTERVAL = 10000;

  Path docsPath;
  Path indexPath;
  int maxConn;
  int beamWidth;
  VectorEncoding vectorEncoding;
  int dim;
  VectorSimilarityFunction similarityFunction;
  int numDocs;
  int docsStartIndex;
  boolean quiet;

  /**
   * Creates an indexer that starts reading at the first vector of {@code docsPath}.
   *
   * @param docsPath file containing the raw vectors
   * @param indexPath directory in which the index is created
   * @param maxConn HNSW max-connections parameter passed to the vectors format
   * @param beamWidth HNSW beam width passed to the vectors format
   * @param vectorEncoding encoding of each vector component in {@code docsPath}
   * @param dim number of components per vector
   * @param similarityFunction similarity function configured on the vector field
   * @param numDocs number of vectors to index
   * @param quiet when {@code true}, suppresses progress and info-stream output
   */
  public KnnIndexer(Path docsPath, Path indexPath, int maxConn, int beamWidth, VectorEncoding vectorEncoding, int dim,
                    VectorSimilarityFunction similarityFunction, int numDocs, boolean quiet) {
    this(docsPath, indexPath, maxConn, beamWidth, vectorEncoding, dim, similarityFunction, numDocs, 0, quiet);
  }

  /**
   * Creates an indexer that skips the first {@code docsStartIndex} vectors of {@code docsPath}.
   *
   * @param docsPath file containing the raw vectors
   * @param indexPath directory in which the index is created
   * @param maxConn HNSW max-connections parameter passed to the vectors format
   * @param beamWidth HNSW beam width passed to the vectors format
   * @param vectorEncoding encoding of each vector component in {@code docsPath}
   * @param dim number of components per vector
   * @param similarityFunction similarity function configured on the vector field
   * @param numDocs number of vectors to index
   * @param docsStartIndex index of the first vector to read; values {@code <= 0} mean "from the start"
   * @param quiet when {@code true}, suppresses progress and info-stream output
   */
  public KnnIndexer(Path docsPath, Path indexPath, int maxConn, int beamWidth, VectorEncoding vectorEncoding, int dim,
                    VectorSimilarityFunction similarityFunction, int numDocs, int docsStartIndex, boolean quiet) {
    this.docsPath = docsPath;
    this.indexPath = indexPath;
    this.maxConn = maxConn;
    this.beamWidth = beamWidth;
    this.vectorEncoding = vectorEncoding;
    this.dim = dim;
    this.similarityFunction = similarityFunction;
    this.numDocs = numDocs;
    this.docsStartIndex = docsStartIndex;
    this.quiet = quiet;
  }

  /**
   * Creates the index from scratch ({@code OpenMode.CREATE}) and adds {@code numDocs} documents.
   *
   * @return elapsed wall-clock time in milliseconds, including closing the writer
   * @throws IOException if reading the vectors or writing the index fails
   */
  public int createIndex() throws IOException {
    IndexWriterConfig iwc = new IndexWriterConfig().setOpenMode(IndexWriterConfig.OpenMode.CREATE);
    iwc.setCodec(
        new Lucene99Codec() {
          @Override
          public KnnVectorsFormat getKnnVectorsFormatForField(String field) {
            return new Lucene99HnswVectorsFormat(maxConn, beamWidth);
          }
        });
    // iwc.setMergePolicy(NoMergePolicy.INSTANCE);
    iwc.setRAMBufferSizeMB(WRITER_BUFFER_MB);
    iwc.setUseCompoundFile(false);
    // iwc.setMaxBufferedDocs(10000);

    FieldType fieldType =
        switch (vectorEncoding) {
          case BYTE -> KnnByteVectorField.createFieldType(dim, similarityFunction);
          case FLOAT32 -> KnnFloatVectorField.createFieldType(dim, similarityFunction);
        };
    if (quiet == false) {
      iwc.setInfoStream(new PrintStreamInfoStream(System.out));
      System.out.println("creating index in " + indexPath);
    }

    if (!indexPath.toFile().exists()) {
      indexPath.toFile().mkdirs();
    }

    long start = System.nanoTime();
    try (FSDirectory dir = FSDirectory.open(indexPath);
         IndexWriter iw = new IndexWriter(dir, iwc)) {
      try (FileChannel in = FileChannel.open(docsPath)) {
        if (docsStartIndex > 0) {
          seekToStartDoc(in, dim, vectorEncoding, docsStartIndex);
        }
        VectorReader vectorReader = VectorReader.create(in, dim, vectorEncoding);
        for (int i = 0; i < numDocs; i++) {
          Document doc = new Document();
          switch (vectorEncoding) {
            case BYTE -> doc.add(
                new KnnByteVectorField(
                    KnnGraphTester.KNN_FIELD, ((VectorReaderByte) vectorReader).nextBytes(), fieldType));
            case FLOAT32 -> doc.add(
                new KnnFloatVectorField(KnnGraphTester.KNN_FIELD, vectorReader.next(), fieldType));
          }
          // NOTE(review): the stored id is relative to this run, not offset by docsStartIndex —
          // confirm that is what consumers of ID_FIELD expect.
          doc.add(new StoredField(KnnGraphTester.ID_FIELD, i));
          iw.addDocument(doc);

          // Report on exact multiples of the interval so the printed count is accurate.
          if (quiet == false && (i + 1) % PROGRESS_INTERVAL == 0) {
            System.out.println("Done indexing " + (i + 1) + " documents.");
          }
        }
        if (quiet == false) {
          System.out.println("Done indexing " + numDocs + " documents; now flush");
        }
      }
    }
    long elapsed = System.nanoTime() - start;
    if (quiet == false) {
      System.out.println(
          "Indexed " + numDocs + " documents in " + TimeUnit.NANOSECONDS.toSeconds(elapsed) + "s");
    }
    return (int) TimeUnit.NANOSECONDS.toMillis(elapsed);
  }

  /**
   * Positions {@code in} at the first byte of vector number {@code docsStartIndex}.
   *
   * <p>The offset is computed in {@code long} arithmetic; the product of three ints can exceed
   * {@code Integer.MAX_VALUE} for large vector files and would otherwise wrap around.
   */
  private void seekToStartDoc(FileChannel in, int dim, VectorEncoding vectorEncoding, int docsStartIndex) throws IOException {
    long startByte = (long) docsStartIndex * dim * vectorEncoding.byteSize;
    in.position(startByte);
  }
}
Loading

0 comments on commit cac5a3e

Please sign in to comment.