Introduced the Word2VecSynonymFilter #12169

Merged · 6 commits · Apr 24, 2023
Changes from 1 commit
@@ -89,6 +89,9 @@
import org.apache.lucene.analysis.standard.StandardTokenizer;
import org.apache.lucene.analysis.stempel.StempelStemmer;
import org.apache.lucene.analysis.synonym.SynonymMap;
import org.apache.lucene.analysis.synonym.word2vec.SynonymProvider;
import org.apache.lucene.analysis.synonym.word2vec.Word2VecModel;
import org.apache.lucene.analysis.synonym.word2vec.Word2VecSynonymProvider;
import org.apache.lucene.store.ByteBuffersDirectory;
import org.apache.lucene.tests.analysis.BaseTokenStreamTestCase;
import org.apache.lucene.tests.analysis.MockTokenFilter;
@@ -99,8 +102,10 @@
import org.apache.lucene.tests.util.automaton.AutomatonTestUtil;
import org.apache.lucene.util.AttributeFactory;
import org.apache.lucene.util.AttributeSource;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.CharsRef;
import org.apache.lucene.util.IgnoreRandomChains;
import org.apache.lucene.util.TermAndVector;
import org.apache.lucene.util.Version;
import org.apache.lucene.util.automaton.Automaton;
import org.apache.lucene.util.automaton.CharacterRunAutomaton;
@@ -415,6 +420,27 @@ private String randomNonEmptyString(Random random) {
}
}
});
put(
SynonymProvider.class,
random -> {
final int numEntries = atLeast(10);
final int vectorDimension = random.nextInt(99) + 1;
Word2VecModel model = new Word2VecModel(numEntries, vectorDimension);
for (int j = 0; j < numEntries; j++) {
String s = TestUtil.randomSimpleString(random, 10, 20);
float[] vec = new float[vectorDimension];
for (int i = 0; i < vectorDimension; i++) {
vec[i] = random.nextFloat();
}
model.addTermAndVector(new TermAndVector(new BytesRef(s), vec));
}
try {
return new Word2VecSynonymProvider(model);
} catch (IOException e) {
Rethrow.rethrow(e);
return null; // unreachable code
}
});
put(
DateFormat.class,
random -> {
2 changes: 2 additions & 0 deletions lucene/analysis/common/src/java/module-info.java
@@ -79,6 +79,7 @@
exports org.apache.lucene.analysis.sr;
exports org.apache.lucene.analysis.sv;
exports org.apache.lucene.analysis.synonym;
exports org.apache.lucene.analysis.synonym.word2vec;
exports org.apache.lucene.analysis.ta;
exports org.apache.lucene.analysis.te;
exports org.apache.lucene.analysis.th;
@@ -257,6 +258,7 @@
org.apache.lucene.analysis.sv.SwedishMinimalStemFilterFactory,
org.apache.lucene.analysis.synonym.SynonymFilterFactory,
org.apache.lucene.analysis.synonym.SynonymGraphFilterFactory,
org.apache.lucene.analysis.synonym.word2vec.Word2VecSynonymFilterFactory,
org.apache.lucene.analysis.core.FlattenGraphFilterFactory,
org.apache.lucene.analysis.te.TeluguNormalizationFilterFactory,
org.apache.lucene.analysis.te.TeluguStemFilterFactory,
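Since the new factory is registered through module-info.java (and the analysis SPI), it can be wired into an analyzer by name. A minimal sketch, assuming the SPI name is "Word2VecSynonym" and the parameter keys are "model", "maxSynonymsPerTerm" and "minAcceptedSimilarity" — none of which are confirmed by this diff:

import java.nio.file.Paths;
import org.apache.lucene.analysis.Analyzer;
import org.apache.lucene.analysis.custom.CustomAnalyzer;

public class Word2VecAnalyzerSketch {
  public static Analyzer build() throws Exception {
    // Hypothetical wiring of the new filter; SPI name and parameter keys are assumptions.
    return CustomAnalyzer.builder(Paths.get("/path/to/resources"))
        .withTokenizer("standard")
        .addTokenFilter(
            "Word2VecSynonym",
            "model", "word2vec-model.zip", // dl4j-exported zip resolved via the ResourceLoader
            "maxSynonymsPerTerm", "5",
            "minAcceptedSimilarity", "0.8")
        .build();
  }
}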
New file: org/apache/lucene/analysis/synonym/word2vec/Dl4jModelReader.java
@@ -0,0 +1,122 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.lucene.analysis.synonym.word2vec;

import java.io.BufferedInputStream;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.UnsupportedEncodingException;
import java.nio.charset.StandardCharsets;
import java.util.Base64;
import java.util.Locale;
import java.util.zip.ZipEntry;
import java.util.zip.ZipInputStream;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.TermAndVector;

/**
 * Dl4jModelReader is a Word2VecModelReader that reads the model file generated by the library
 * Deeplearning4j
 *
 * <p>Dl4j Word2Vec documentation:
 * https://deeplearning4j.konduit.ai/v/en-1.0.0-beta7/language-processing/word2vec Example to
 * generate a model using dl4j:
 * https://github.com/eclipse/deeplearning4j-examples/blob/master/dl4j-examples/src/main/java/org/deeplearning4j/examples/advanced/modelling/embeddingsfromcorpus/word2vec/Word2VecRawTextExample.java
 *
 * @lucene.experimental
 */
public class Dl4jModelReader implements Word2VecModelReader {
Contributor comment:
For the sake of context, I worked with @dantuzi on deciding whether to implement a custom reader (this one) or to use dl4j as an imported library.
Multiple attempts were made to include dl4j as a dependency in Lucene, but the effort and impact were not worth it, so we reverted to a simple custom reader class.
There are downsides to this of course, but it is much more lightweight.

Contributor comment:
+1 to avoid dependencies whenever practical

private static final String MODEL_FILE_NAME_PREFIX = "syn0";

private final String word2vecModelFilePath;
Contributor comment:
why pass in the path when it is only used in toString()? Can we choose between accepting a java.nio.Path that does its own open/close of the zip file and an (anonymous) InputStream?

Contributor Author comment:
Everything comes from the Word2VecSynonymFilterFactory, which implements ResourceLoaderAware. This interface provides us with an org.apache.lucene.util.ResourceLoader and the possibility to obtain an anonymous InputStream.
I decided to also pass the model file path to enrich the exception message and make the user's life easier.
BTW I don't have a strong opinion about this. I can easily remove that string.

Contributor Author comment:
Addressed the comment by removing the word2vecModelFilePath string and changing the exception message.

private final ZipInputStream word2VecModelZipFile;

public Dl4jModelReader(String word2vecModelFilePath, InputStream stream) {
this.word2vecModelFilePath = word2vecModelFilePath;
this.word2VecModelZipFile = new ZipInputStream(new BufferedInputStream(stream));
}

@Override
public Word2VecModel read() throws IOException {

ZipEntry entry;
while ((entry = word2VecModelZipFile.getNextEntry()) != null) {
String fileName = entry.getName();
if (fileName.startsWith(MODEL_FILE_NAME_PREFIX)) {
BufferedReader reader =
new BufferedReader(new InputStreamReader(word2VecModelZipFile, StandardCharsets.UTF_8));

String header = reader.readLine();
String[] headerValues = header.split(" ");
int dictionarySize = Integer.parseInt(headerValues[0]);
int vectorDimension = Integer.parseInt(headerValues[1]);

Word2VecModel model = new Word2VecModel(dictionarySize, vectorDimension);
reader
Contributor comment:
to me this would read much clearer using a traditional while loop:

while ((line = reader.readLine()) != null) ...

.lines()
.forEach(
line -> {
String[] tokens = line.split(" ");
BytesRef term = decodeTerm(tokens[0]);

float[] vector = new float[tokens.length - 1];

if (vectorDimension != vector.length) {
throw new RuntimeException(
String.format(
Locale.ROOT,
"Word2Vec model file corrupted. "
+ "Declared vectors of size %d but found vector of size %d for word %s (%s)",
vectorDimension,
vector.length,
tokens[0],
term.utf8ToString()));
}

for (int i = 1; i < tokens.length; i++) {
vector[i - 1] = Float.parseFloat(tokens[i]);
}
model.addTermAndVector(new TermAndVector(term, vector));
});
return model;
}
}
throw new UnsupportedEncodingException(
Contributor comment:
I think this exception is really intended for use with character set encodings only. Maybe IllegalArgumentException would fit better?

Contributor Author comment:
When we use the DL4J library to train a model and export it, we obtain a compressed zip file.
This zip contains multiple files, but we are only interested in the syn0 file. The exception is thrown if the passed zip does not contain any syn0 file.
I guess IllegalArgumentException would fit.

"The ZIP file '"
+ word2vecModelFilePath
+ "' does not contain any "
+ MODEL_FILE_NAME_PREFIX
+ " file");
}
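For reference, a sketch of the traditional while loop suggested in the review above, replacing the reader.lines().forEach(...) chain inside read() while keeping the same parsing logic (not the committed code). One side benefit: a plain loop can throw the checked IOException directly instead of wrapping the error in a RuntimeException inside a lambda.

// Inside read(), after parsing the header line:
String line;
while ((line = reader.readLine()) != null) {
  String[] tokens = line.split(" ");
  BytesRef term = decodeTerm(tokens[0]);
  float[] vector = new float[tokens.length - 1];
  if (vectorDimension != vector.length) {
    throw new IOException(
        "Word2Vec model file corrupted. Declared vectors of size "
            + vectorDimension + " but found vector of size " + vector.length
            + " for word " + tokens[0] + " (" + term.utf8ToString() + ")");
  }
  for (int i = 1; i < tokens.length; i++) {
    vector[i - 1] = Float.parseFloat(tokens[i]);
  }
  model.addTermAndVector(new TermAndVector(term, vector));
}
return model;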

static BytesRef decodeTerm(String term) {
if (term.toLowerCase(Locale.ROOT).startsWith("b64:")) {
Contributor comment:
It seems wasteful to lower case every term here, even when they are not b64-encoded. Also: is this something that dl4j will do consistently throughout the file? If so, we can peek at the first term and then assume the remainder will also be b64-encoded. I also wonder about the trim() - why do we need it? Does Base64.decode leave garbage at the end of the terms sometimes?

Contributor Author comment:
I like your suggestion to read the first term and assume the remaining terms are encoded in the same way.
I did some checks and the trim() was useless. Thank you for noticing it.


byte[] buffer = Base64.getDecoder().decode(term.substring(4).trim());
return new BytesRef(buffer, 0, buffer.length);
}
return new BytesRef(term);
}
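One possible shape for the change agreed above — compare only the prefix case-insensitively instead of lower-casing the whole term, and drop the unnecessary trim(). A sketch, not the final code:

static BytesRef decodeTerm(String term) {
  // Check only the 4-character "b64:" prefix, ignoring case, rather than lower-casing the term.
  if (term.regionMatches(true, 0, "b64:", 0, 4)) {
    byte[] buffer = Base64.getDecoder().decode(term.substring(4));
    return new BytesRef(buffer, 0, buffer.length);
  }
  return new BytesRef(term);
}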

@Override
public void close() throws IOException {
word2VecModelZipFile.close();
}
}
New file: org/apache/lucene/analysis/synonym/word2vec/SynonymProvider.java
@@ -0,0 +1,42 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.lucene.analysis.synonym.word2vec;

import java.io.IOException;
import java.util.List;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.TermAndBoost;

/**
* Generic synonym provider
*
* @lucene.experimental
*/
public interface SynonymProvider {

/**
 * Returns the synonyms of the given term
 *
 * @param term the term for which we want to find synonyms
 * @param maxSynonymsPerTerm maximum number of results returned by the synonym search
Contributor comment:
I don't see that we need this interface if its only use is in this one Dl4jWord2VecSynonymFilter. Can't we simply refer directly to the implementing class?

* @param minAcceptedSimilarity minimal value of cosine similarity between the searched vector and
* the retrieved ones
*/
List<TermAndBoost> getSynonyms(BytesRef term, int maxSynonymsPerTerm, float minAcceptedSimilarity)
throws IOException;
}
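A minimal usage sketch of the interface, assuming a synonymProvider instance has already been built from a loaded model (the term and thresholds below are arbitrary examples):

// Hypothetical call: at most 5 synonyms with cosine similarity of at least 0.8 to "king".
List<TermAndBoost> synonyms =
    synonymProvider.getSynonyms(new BytesRef("king"), 5, 0.8f);
System.out.println("found " + synonyms.size() + " synonyms");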
New file: org/apache/lucene/analysis/synonym/word2vec/Word2VecModel.java
@@ -0,0 +1,93 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.lucene.analysis.synonym.word2vec;

import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.TermAndVector;
import org.apache.lucene.util.hnsw.RandomAccessVectorValues;

/**
* Word2VecModel is a class representing the parsed Word2Vec model containing the vectors for each
* word in dictionary
*
* @lucene.experimental
*/
public class Word2VecModel implements RandomAccessVectorValues<float[]> {

private final int dictionarySize;
private final int vectorDimension;
private final TermAndVector[] data;
private final Map<BytesRef, TermAndVector> word2Vec;
Contributor comment:
did you consider using BytesRefHash?

Contributor Author comment:
I've never seen BytesRefHash before, thank you for your suggestion.

private int loadedCount = 0;

public Word2VecModel(int dictionarySize, int vectorDimension) {
this.dictionarySize = dictionarySize;
this.vectorDimension = vectorDimension;
this.data = new TermAndVector[dictionarySize];
this.word2Vec = new HashMap<>();
}

private Word2VecModel(
int dictionarySize,
int vectorDimension,
TermAndVector[] data,
Map<BytesRef, TermAndVector> word2Vec) {
this.dictionarySize = dictionarySize;
this.vectorDimension = vectorDimension;
this.data = data;
this.word2Vec = word2Vec;
}

public void addTermAndVector(TermAndVector modelEntry) {
modelEntry.normalizeVector();
this.data[loadedCount++] = modelEntry;
this.word2Vec.put(modelEntry.getTerm(), modelEntry);
}

@Override
public float[] vectorValue(int ord) throws IOException {
return data[ord].getVector();
}

public float[] vectorValue(BytesRef term) {
TermAndVector entry = word2Vec.get(term);
return (entry == null) ? null : entry.getVector();
}

public BytesRef binaryValue(int targetOrd) throws IOException {
return data[targetOrd].getTerm();
Contributor comment:
This is incorrect - the purpose of this method is to return some bytes representing the vector value. I think instead you ought to simply throw UnsupportedOperationException since this implementation will never be used in an indexing context where this method is required.

Contributor Author comment (@dantuzi, Apr 4, 2023):
As you can see, this method is not @Override, so it is not an implementation of the RandomAccessVectorValues interface. This is a custom method used in our implementation.

}

@Override
public int dimension() {
return vectorDimension;
}

@Override
public int size() {
return dictionarySize;
}

@Override
public RandomAccessVectorValues<float[]> copy() throws IOException {
Contributor comment:
if this copy is meant to do a deep copy, I suspect we'll need to handle it differently. I am not sure it is copying the internal elements rather than reusing them, so a copy could end up adding elements to data structures used by the original object.

Contributor comment:
It does not need to do a deep copy; the purpose of this method is to enable multiple concurrent accesses to the underlying data. Since this implementation doesn't have any temporary variable into which vectors are decoded (which could be overwritten), I think it's safe to simply return this.

Contributor Author comment:
@msokolov I tried to implement your suggestion but it looks like the method HnswGraphBuilder::build doesn't accept the same reference that was passed to HnswGraphBuilder.create [1].
To be honest I still don't understand why this check [2] is required.

[1]

Vectors to build must be independent of the source of vectors provided to HnswGraphBuilder()
java.lang.IllegalArgumentException: Vectors to build must be independent of the source of vectors provided to HnswGraphBuilder()
	at __randomizedtesting.SeedInfo.seed([994075DD4398F0A4:E100BB05917EA0E6]:0)
	at org.apache.lucene.core@10.0.0-SNAPSHOT/org.apache.lucene.util.hnsw.HnswGraphBuilder.build(HnswGraphBuilder.java:165)
	at org.apache.lucene.analysis.synonym.word2vec.Word2VecSynonymProvider.<init>(Word2VecSynonymProvider.java:64)
	at org.apache.lucene.analysis.synonym.word2vec.TestWord2VecSynonymProvider.<init>(TestWord2VecSynonymProvider.java:39)

[2]

if (vectorsToAdd == this.vectors) {
  throw new IllegalArgumentException(
      "Vectors to build must be independent of the source of vectors provided to HnswGraphBuilder()");
}


return new Word2VecModel(this.dictionarySize, this.vectorDimension, this.data, this.word2Vec);
}
}
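A small sketch of the in-memory model API defined above (terms and vector values are arbitrary; note that addTermAndVector normalizes each vector in place):

Word2VecModel model = new Word2VecModel(2, 3);
model.addTermAndVector(new TermAndVector(new BytesRef("cat"), new float[] {1f, 0f, 0f}));
model.addTermAndVector(new TermAndVector(new BytesRef("dog"), new float[] {0.9f, 0.1f, 0f}));
float[] catVector = model.vectorValue(new BytesRef("cat")); // lookup by term, already normalized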
New file: org/apache/lucene/analysis/synonym/word2vec/Word2VecModelReader.java
@@ -0,0 +1,32 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.lucene.analysis.synonym.word2vec;

import java.io.Closeable;
import java.io.IOException;

/**
 * Each class implementing this interface must be able to read a Word2Vec model format and provide
 * a Word2VecModel with normalized vectors
*
* @lucene.experimental
*/
public interface Word2VecModelReader extends Closeable {
Contributor comment:
Do we have more than one implementation? I don't think this interface is necessary. Later we can always add it if we have multiple implementations and need to abstract. For now it's just extra stuff to maintain.


Word2VecModel read() throws IOException;
}
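Putting the reader and the model together, a sketch of loading a dl4j-exported archive with the two-argument Dl4jModelReader constructor shown in this commit (a later commit drops the path string); the file path is a placeholder and the snippet assumes java.io.InputStream, java.nio.file.Files and java.nio.file.Paths imports:

try (InputStream stream = Files.newInputStream(Paths.get("/path/to/word2vec-model.zip"));
    Word2VecModelReader modelReader =
        new Dl4jModelReader("/path/to/word2vec-model.zip", stream)) {
  Word2VecModel model = modelReader.read();
  System.out.println("loaded " + model.size() + " vectors of dimension " + model.dimension());
}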