Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

LUCENE-10057: Use Lucene abstractions to store KnnVectorDict #252

Merged
merged 3 commits into from
Aug 19, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 24 additions & 10 deletions lucene/demo/src/java/org/apache/lucene/demo/IndexFiles.java
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.store.Directory;
import org.apache.lucene.store.FSDirectory;
import org.apache.lucene.util.IOUtils;

/**
* Index all text files under a directory.
Expand All @@ -55,17 +56,18 @@
* command-line arguments for usage information.
*/
public class IndexFiles implements AutoCloseable {
static final String KNN_DICT = "knn-dict";

// Calculates embedding vectors for KnnVector search
private final DemoEmbeddings demoEmbeddings;
private final KnnVectorDict vectorDict;

private IndexFiles(Path vectorDictPath) throws IOException {
if (vectorDictPath != null) {
vectorDict = new KnnVectorDict(vectorDictPath);
private IndexFiles(KnnVectorDict vectorDict) throws IOException {
if (vectorDict != null) {
this.vectorDict = vectorDict;
demoEmbeddings = new DemoEmbeddings(vectorDict);
} else {
vectorDict = null;
this.vectorDict = null;
demoEmbeddings = null;
}
}
Expand All @@ -80,7 +82,7 @@ public static void main(String[] args) throws Exception {
+ "IF DICT_PATH contains a KnnVector dictionary, the index will also support KnnVector search";
String indexPath = "index";
String docsPath = null;
Path vectorDictPath = null;
String vectorDictSource = null;
boolean create = true;
for (int i = 0; i < args.length; i++) {
switch (args[i]) {
Expand All @@ -91,7 +93,7 @@ public static void main(String[] args) throws Exception {
docsPath = args[++i];
break;
case "-knn_dict":
vectorDictPath = Paths.get(args[++i]);
vectorDictSource = args[++i];
break;
case "-update":
create = false;
Expand Down Expand Up @@ -142,8 +144,16 @@ public static void main(String[] args) throws Exception {
//
// iwc.setRAMBufferSizeMB(256.0);

KnnVectorDict vectorDictInstance = null;
long vectorDictSize = 0;
if (vectorDictSource != null) {
KnnVectorDict.build(Paths.get(vectorDictSource), dir, KNN_DICT);
vectorDictInstance = new KnnVectorDict(dir, KNN_DICT);
vectorDictSize = vectorDictInstance.ramBytesUsed();
}

try (IndexWriter writer = new IndexWriter(dir, iwc);
IndexFiles indexFiles = new IndexFiles(vectorDictPath)) {
IndexFiles indexFiles = new IndexFiles(vectorDictInstance)) {
indexFiles.indexDocs(writer, docDir);

// NOTE: if you want to maximize search performance,
Expand All @@ -153,6 +163,8 @@ public static void main(String[] args) throws Exception {
// you're done adding documents to it):
//
// writer.forceMerge(1);
} finally {
IOUtils.close(vectorDictInstance);
}

Date end = new Date();
Expand All @@ -163,6 +175,10 @@ public static void main(String[] args) throws Exception {
+ " documents in "
+ (end.getTime() - start.getTime())
+ " milliseconds");
if (reader.numDocs() > 100 && vectorDictSize < 1_000_000) {
throw new RuntimeException(
"Are you (ab)using the toy vector dictionary? See the package javadocs to understand why you got this exception.");
}
}
} catch (IOException e) {
System.out.println(" caught a " + e.getClass() + "\n with message: " + e.getMessage());
Expand Down Expand Up @@ -263,8 +279,6 @@ void indexDoc(IndexWriter writer, Path file, long lastModified) throws IOExcepti

@Override
public void close() throws IOException {
if (vectorDict != null) {
vectorDict.close();
}
IOUtils.close(vectorDict);
}
}
5 changes: 2 additions & 3 deletions lucene/demo/src/java/org/apache/lucene/demo/SearchFiles.java
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@
import org.apache.lucene.demo.knn.KnnVectorDict;
import org.apache.lucene.document.Document;
import org.apache.lucene.index.DirectoryReader;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.Term;
import org.apache.lucene.queryparser.classic.QueryParser;
import org.apache.lucene.search.BooleanClause;
Expand Down Expand Up @@ -103,12 +102,12 @@ public static void main(String[] args) throws Exception {
}
}

IndexReader reader = DirectoryReader.open(FSDirectory.open(Paths.get(index)));
DirectoryReader reader = DirectoryReader.open(FSDirectory.open(Paths.get(index)));
IndexSearcher searcher = new IndexSearcher(reader);
Analyzer analyzer = new StandardAnalyzer();
KnnVectorDict vectorDict = null;
if (knnVectors > 0) {
vectorDict = new KnnVectorDict(Paths.get(index).resolve("knn-dict"));
vectorDict = new KnnVectorDict(reader.directory(), IndexFiles.KNN_DICT);
}
BufferedReader in;
if (queries != null) {
Expand Down
85 changes: 43 additions & 42 deletions lucene/demo/src/java/org/apache/lucene/demo/knn/KnnVectorDict.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,19 @@
package org.apache.lucene.demo.knn;

import java.io.BufferedReader;
import java.io.DataOutputStream;
import java.io.Closeable;
import java.io.IOException;
import java.io.OutputStream;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.FloatBuffer;
import java.nio.channels.FileChannel;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.Arrays;
import java.util.regex.Pattern;
import org.apache.lucene.store.Directory;
import org.apache.lucene.store.IOContext;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.store.IndexOutput;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.IntsRefBuilder;
import org.apache.lucene.util.VectorUtil;
Expand All @@ -40,32 +42,29 @@
* Manages a map from token to numeric vector for use with KnnVector indexing and search. The map is
* stored as an FST: token-to-ordinal plus a dense binary file holding the vectors.
*/
public class KnnVectorDict implements AutoCloseable {
public class KnnVectorDict implements Closeable {

private final FST<Long> fst;
private final FileChannel vectors;
private final ByteBuffer vbuffer;
private final IndexInput vectors;
private final int dimension;

/**
* Sole constructor
*
* @param knnDictPath the base path name of the files that will store the KnnVectorDict. The file
* with extension '.bin' holds the vectors and the '.fst' maps tokens to offsets in the '.bin'
* file.
* @param directory Lucene directory from which the knn vector dictionary should be read.
* @param dictName the base name of the directory files that store the knn vector dictionary. A
* file with extension '.bin' holds the vectors and the '.fst' maps tokens to offsets in the
* '.bin' file.
*/
public KnnVectorDict(Path knnDictPath) throws IOException {
String dictName = knnDictPath.getFileName().toString();
Path fstPath = knnDictPath.resolveSibling(dictName + ".fst");
Path binPath = knnDictPath.resolveSibling(dictName + ".bin");
fst = FST.read(fstPath, PositiveIntOutputs.getSingleton());
vectors = FileChannel.open(binPath);
long size = vectors.size();
if (size > Integer.MAX_VALUE) {
throw new IllegalArgumentException("vector file is too large: " + size + " bytes");
public KnnVectorDict(Directory directory, String dictName) throws IOException {
try (IndexInput fstIn = directory.openInput(dictName + ".fst", IOContext.READ)) {
fst = new FST<>(fstIn, fstIn, PositiveIntOutputs.getSingleton());
}
vbuffer = vectors.map(FileChannel.MapMode.READ_ONLY, 0, size);
dimension = vbuffer.getInt((int) (size - Integer.BYTES));

vectors = directory.openInput(dictName + ".bin", IOContext.READ);
long size = vectors.length();
vectors.seek(size - Integer.BYTES);
dimension = vectors.readInt();
if ((size - Integer.BYTES) % (dimension * Float.BYTES) != 0) {
throw new IllegalStateException(
"vector file size " + size + " is not consonant with the vector dimension " + dimension);
Expand Down Expand Up @@ -96,8 +95,8 @@ public void get(BytesRef token, byte[] output) throws IOException {
if (ord == null) {
Arrays.fill(output, (byte) 0);
} else {
vbuffer.position((int) (ord * dimension * Float.BYTES));
vbuffer.get(output);
vectors.seek(ord * dimension * Float.BYTES);
vectors.readBytes(output, 0, output.length);
}
}

Expand All @@ -122,11 +121,12 @@ public void close() throws IOException {
* and each line is space-delimited. The first column has the token, and the remaining columns
* are the vector components, as text. The dictionary must be sorted by its leading tokens
* (considered as bytes).
* @param dictOutput a dictionary path prefix. The output will be two files, named by appending
* '.fst' and '.bin' to this path.
* @param directory a Lucene directory to write the dictionary to.
* @param dictName Base name for the knn dictionary files.
*/
public static void build(Path gloveInput, Path dictOutput) throws IOException {
new Builder().build(gloveInput, dictOutput);
public static void build(Path gloveInput, Directory directory, String dictName)
throws IOException {
new Builder().build(gloveInput, directory, dictName);
}

private static class Builder {
Expand All @@ -140,25 +140,20 @@ private static class Builder {
private long ordinal = 1;
private int numFields;

void build(Path gloveInput, Path dictOutput) throws IOException {
String dictName = dictOutput.getFileName().toString();
Path fstPath = dictOutput.resolveSibling(dictName + ".fst");
Path binPath = dictOutput.resolveSibling(dictName + ".bin");
void build(Path gloveInput, Directory directory, String dictName) throws IOException {
try (BufferedReader in = Files.newBufferedReader(gloveInput);
OutputStream binOut = Files.newOutputStream(binPath);
DataOutputStream binDataOut = new DataOutputStream(binOut)) {
IndexOutput binOut = directory.createOutput(dictName + ".bin", IOContext.DEFAULT);
IndexOutput fstOut = directory.createOutput(dictName + ".fst", IOContext.DEFAULT)) {
writeFirstLine(in, binOut);
while (true) {
if (addOneLine(in, binOut) == false) {
break;
}
while (addOneLine(in, binOut)) {
// continue;
}
fstCompiler.compile().save(fstPath);
binDataOut.writeInt(numFields - 1);
fstCompiler.compile().save(fstOut, fstOut);
binOut.writeInt(numFields - 1);
}
}

private void writeFirstLine(BufferedReader in, OutputStream out) throws IOException {
private void writeFirstLine(BufferedReader in, IndexOutput out) throws IOException {
String[] fields = readOneLine(in);
if (fields == null) {
return;
Expand All @@ -178,7 +173,7 @@ private String[] readOneLine(BufferedReader in) throws IOException {
return SPACE_RE.split(line, 0);
}

private boolean addOneLine(BufferedReader in, OutputStream out) throws IOException {
private boolean addOneLine(BufferedReader in, IndexOutput out) throws IOException {
String[] fields = readOneLine(in);
if (fields == null) {
return false;
Expand All @@ -197,15 +192,21 @@ private boolean addOneLine(BufferedReader in, OutputStream out) throws IOExcepti
return true;
}

private void writeVector(String[] fields, OutputStream out) throws IOException {
private void writeVector(String[] fields, IndexOutput out) throws IOException {
byteBuffer.position(0);
FloatBuffer floatBuffer = byteBuffer.asFloatBuffer();
for (int i = 1; i < fields.length; i++) {
scratch[i - 1] = Float.parseFloat(fields[i]);
}
VectorUtil.l2normalize(scratch);
floatBuffer.put(scratch);
out.write(byteBuffer.array());
byte[] bytes = byteBuffer.array();
out.writeBytes(bytes, bytes.length);
}
}

/**
 * Return the approximate size of the dictionary in bytes: the heap bytes used by the
 * token-to-ordinal FST plus the length of the '.bin' vector file.
 *
 * <p>NOTE(review): despite the name, {@code vectors.length()} is the on-disk/IndexInput length
 * of the vector data, not resident heap — treat this as a total footprint estimate, and as the
 * value IndexFiles compares against its toy-dictionary threshold.
 */
public long ramBytesUsed() {
return fst.ramBytesUsed() + vectors.length();
}
}
12 changes: 12 additions & 0 deletions lucene/demo/src/java/overview.html
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ <h1>Apache Lucene - Building and Installing the Basic Demo</h1>
<li><a href="#Location_of_the_source">Location of the source</a></li>
<li><a href="#IndexFiles">IndexFiles</a></li>
<li><a href="#Searching_Files">Searching Files</a></li>
<li><a href="#Embeddings">Working with vector embeddings</a></li>
</ul>
</div>
<a id="About_this_Document"></a>
Expand Down Expand Up @@ -203,6 +204,17 @@ <h2 class="boxed">Searching Files</h2>
<span class="codefrag">n</span> hits. The results are printed in pages, sorted
by score (i.e. relevance).</p>
</div>
<h2 id="Embeddings" class="boxed">Working with vector embeddings</h2>
<div class="section">
<p>In addition to indexing and searching text, IndexFiles and SearchFiles can also index and search
numeric vectors derived from that text, known as "embeddings." This demo code uses pre-computed embeddings
provided by the <a href="https://nlp.stanford.edu/projects/glove/">GloVe</a> project, which are in the public
domain. The dictionary here is a tiny subset of the full GloVe dataset. It includes only the words that occur
in the toy data set, and is definitely <i>not ready for production use</i>! If you use this code to create
a vector index for a larger document set, the indexer will throw an exception because
a more complete set of embeddings is needed to get reasonable results.
</p>
</div>
</body>
</html>

7 changes: 2 additions & 5 deletions lucene/demo/src/test/org/apache/lucene/demo/TestDemo.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
import java.io.PrintStream;
import java.nio.charset.Charset;
import java.nio.file.Path;
import org.apache.lucene.demo.knn.KnnVectorDict;
import org.apache.lucene.util.LuceneTestCase;

public class TestDemo extends LuceneTestCase {
Expand Down Expand Up @@ -90,10 +89,8 @@ private void testVectorSearch(
public void testKnnVectorSearch() throws Exception {
Path dir = getDataPath("test-files/docs");
Path indexDir = createTempDir("ContribDemoTest");
Path dictPath = indexDir.resolve("knn-dict");
Path vectorDictSource = getDataPath("test-files/knn-dict").resolve("knn-token-vectors");
KnnVectorDict.build(vectorDictSource, dictPath);

Path vectorDictSource = getDataPath("test-files/knn-dict").resolve("knn-token-vectors");
IndexFiles.main(
new String[] {
"-create",
Expand All @@ -102,7 +99,7 @@ public void testKnnVectorSearch() throws Exception {
"-index",
indexDir.toString(),
"-knn_dict",
dictPath.toString()
vectorDictSource.toString()
});

// We add a single semantic hit by passing the "-knn_vector 1" argument to SearchFiles. The
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.file.Path;
import org.apache.lucene.store.Directory;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.LuceneTestCase;
import org.apache.lucene.util.VectorUtil;
Expand All @@ -28,30 +29,31 @@ public class TestDemoEmbeddings extends LuceneTestCase {

public void testComputeEmbedding() throws IOException {
Path testVectors = getDataPath("../test-files/knn-dict").resolve("knn-token-vectors");
Path dictPath = createTempDir("knn-demo").resolve("dict");
KnnVectorDict.build(testVectors, dictPath);
try (KnnVectorDict dict = new KnnVectorDict(dictPath)) {
DemoEmbeddings demoEmbeddings = new DemoEmbeddings(dict);
try (Directory directory = newDirectory()) {
KnnVectorDict.build(testVectors, directory, "dict");
try (KnnVectorDict dict = new KnnVectorDict(directory, "dict")) {
DemoEmbeddings demoEmbeddings = new DemoEmbeddings(dict);

// test garbage
float[] garbageVector =
demoEmbeddings.computeEmbedding("garbagethathasneverbeen seeneverinlife");
assertEquals(50, garbageVector.length);
assertArrayEquals(new float[50], garbageVector, 0);
// test garbage
float[] garbageVector =
demoEmbeddings.computeEmbedding("garbagethathasneverbeen seeneverinlife");
assertEquals(50, garbageVector.length);
assertArrayEquals(new float[50], garbageVector, 0);

// test space
assertArrayEquals(new float[50], demoEmbeddings.computeEmbedding(" "), 0);
// test space
assertArrayEquals(new float[50], demoEmbeddings.computeEmbedding(" "), 0);

// test some real words that are in the dictionary and some that are not
float[] realVector = demoEmbeddings.computeEmbedding("the real fact");
assertEquals(50, realVector.length);
// test some real words that are in the dictionary and some that are not
float[] realVector = demoEmbeddings.computeEmbedding("the real fact");
assertEquals(50, realVector.length);

float[] the = getTermVector(dict, "the");
assertArrayEquals(new float[50], getTermVector(dict, "real"), 0);
float[] fact = getTermVector(dict, "fact");
VectorUtil.add(the, fact);
VectorUtil.l2normalize(the);
assertArrayEquals(the, realVector, 0);
float[] the = getTermVector(dict, "the");
assertArrayEquals(new float[50], getTermVector(dict, "real"), 0);
float[] fact = getTermVector(dict, "fact");
VectorUtil.add(the, fact);
VectorUtil.l2normalize(the);
assertArrayEquals(the, realVector, 0);
}
}
}

Expand Down
Loading