diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt index a7016627cd78..91f1ea039a71 100644 --- a/lucene/CHANGES.txt +++ b/lucene/CHANGES.txt @@ -135,6 +135,8 @@ New Features crash the JVM. To disable this feature, pass the following sysprop on Java command line: "-Dorg.apache.lucene.store.MMapDirectory.enableMemorySegments=false" (Uwe Schindler) +* GITHUB#12169: Introduce a new token filter to expand synonyms based on Word2Vec DL4j models. (Daniele Antuzi, Ilaria Petreti, Alessandro Benedetti) + Improvements --------------------- diff --git a/lucene/analysis.tests/src/test/org/apache/lucene/analysis/tests/TestRandomChains.java b/lucene/analysis.tests/src/test/org/apache/lucene/analysis/tests/TestRandomChains.java index 8c245e7058c7..988deaf99e59 100644 --- a/lucene/analysis.tests/src/test/org/apache/lucene/analysis/tests/TestRandomChains.java +++ b/lucene/analysis.tests/src/test/org/apache/lucene/analysis/tests/TestRandomChains.java @@ -89,6 +89,8 @@ import org.apache.lucene.analysis.standard.StandardTokenizer; import org.apache.lucene.analysis.stempel.StempelStemmer; import org.apache.lucene.analysis.synonym.SynonymMap; +import org.apache.lucene.analysis.synonym.word2vec.Word2VecModel; +import org.apache.lucene.analysis.synonym.word2vec.Word2VecSynonymProvider; import org.apache.lucene.store.ByteBuffersDirectory; import org.apache.lucene.tests.analysis.BaseTokenStreamTestCase; import org.apache.lucene.tests.analysis.MockTokenFilter; @@ -99,8 +101,10 @@ import org.apache.lucene.tests.util.automaton.AutomatonTestUtil; import org.apache.lucene.util.AttributeFactory; import org.apache.lucene.util.AttributeSource; +import org.apache.lucene.util.BytesRef; import org.apache.lucene.util.CharsRef; import org.apache.lucene.util.IgnoreRandomChains; +import org.apache.lucene.util.TermAndVector; import org.apache.lucene.util.Version; import org.apache.lucene.util.automaton.Automaton; import org.apache.lucene.util.automaton.CharacterRunAutomaton; @@ -415,6 +419,27 @@ private 
String randomNonEmptyString(Random random) { } } }); + put( + Word2VecSynonymProvider.class, + random -> { + final int numEntries = atLeast(10); + final int vectorDimension = random.nextInt(99) + 1; + Word2VecModel model = new Word2VecModel(numEntries, vectorDimension); + for (int j = 0; j < numEntries; j++) { + String s = TestUtil.randomSimpleString(random, 10, 20); + float[] vec = new float[vectorDimension]; + for (int i = 0; i < vectorDimension; i++) { + vec[i] = random.nextFloat(); + } + model.addTermAndVector(new TermAndVector(new BytesRef(s), vec)); + } + try { + return new Word2VecSynonymProvider(model); + } catch (IOException e) { + Rethrow.rethrow(e); + return null; // unreachable code + } + }); put( DateFormat.class, random -> { diff --git a/lucene/analysis/common/src/java/module-info.java b/lucene/analysis/common/src/java/module-info.java index 96ec0d293d3d..84654191961b 100644 --- a/lucene/analysis/common/src/java/module-info.java +++ b/lucene/analysis/common/src/java/module-info.java @@ -79,6 +79,7 @@ exports org.apache.lucene.analysis.sr; exports org.apache.lucene.analysis.sv; exports org.apache.lucene.analysis.synonym; + exports org.apache.lucene.analysis.synonym.word2vec; exports org.apache.lucene.analysis.ta; exports org.apache.lucene.analysis.te; exports org.apache.lucene.analysis.th; @@ -257,6 +258,7 @@ org.apache.lucene.analysis.sv.SwedishMinimalStemFilterFactory, org.apache.lucene.analysis.synonym.SynonymFilterFactory, org.apache.lucene.analysis.synonym.SynonymGraphFilterFactory, + org.apache.lucene.analysis.synonym.word2vec.Word2VecSynonymFilterFactory, org.apache.lucene.analysis.core.FlattenGraphFilterFactory, org.apache.lucene.analysis.te.TeluguNormalizationFilterFactory, org.apache.lucene.analysis.te.TeluguStemFilterFactory, diff --git a/lucene/analysis/common/src/java/org/apache/lucene/analysis/synonym/word2vec/Dl4jModelReader.java b/lucene/analysis/common/src/java/org/apache/lucene/analysis/synonym/word2vec/Dl4jModelReader.java new file 
mode 100644 index 000000000000..f022dd8eca67 --- /dev/null +++ b/lucene/analysis/common/src/java/org/apache/lucene/analysis/synonym/word2vec/Dl4jModelReader.java @@ -0,0 +1,126 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.lucene.analysis.synonym.word2vec; + +import java.io.BufferedInputStream; +import java.io.BufferedReader; +import java.io.Closeable; +import java.io.IOException; +import java.io.InputStream; +import java.io.InputStreamReader; +import java.nio.charset.StandardCharsets; +import java.util.Base64; +import java.util.Locale; +import java.util.zip.ZipEntry; +import java.util.zip.ZipInputStream; +import org.apache.lucene.util.BytesRef; +import org.apache.lucene.util.TermAndVector; + +/** + * Dl4jModelReader reads the file generated by the library Deeplearning4j and provides a + * Word2VecModel with normalized vectors + * + *

Dl4j Word2Vec documentation: + * https://deeplearning4j.konduit.ai/v/en-1.0.0-beta7/language-processing/word2vec Example to + * generate a model using dl4j: + * https://github.com/eclipse/deeplearning4j-examples/blob/master/dl4j-examples/src/main/java/org/deeplearning4j/examples/advanced/modelling/embeddingsfromcorpus/word2vec/Word2VecRawTextExample.java + * + * @lucene.experimental + */ +public class Dl4jModelReader implements Closeable { + + private static final String MODEL_FILE_NAME_PREFIX = "syn0"; + + private final ZipInputStream word2VecModelZipFile; + + public Dl4jModelReader(InputStream stream) { + this.word2VecModelZipFile = new ZipInputStream(new BufferedInputStream(stream)); + } + + public Word2VecModel read() throws IOException { + + ZipEntry entry; + while ((entry = word2VecModelZipFile.getNextEntry()) != null) { + String fileName = entry.getName(); + if (fileName.startsWith(MODEL_FILE_NAME_PREFIX)) { + BufferedReader reader = + new BufferedReader(new InputStreamReader(word2VecModelZipFile, StandardCharsets.UTF_8)); + + String header = reader.readLine(); + String[] headerValues = header.split(" "); + int dictionarySize = Integer.parseInt(headerValues[0]); + int vectorDimension = Integer.parseInt(headerValues[1]); + + Word2VecModel model = new Word2VecModel(dictionarySize, vectorDimension); + String line = reader.readLine(); + boolean isTermB64Encoded = false; + if (line != null) { + String[] tokens = line.split(" "); + isTermB64Encoded = + tokens[0].substring(0, 3).toLowerCase(Locale.ROOT).compareTo("b64") == 0; + model.addTermAndVector(extractTermAndVector(tokens, vectorDimension, isTermB64Encoded)); + } + while ((line = reader.readLine()) != null) { + String[] tokens = line.split(" "); + model.addTermAndVector(extractTermAndVector(tokens, vectorDimension, isTermB64Encoded)); + } + return model; + } + } + throw new IllegalArgumentException( + "Cannot read Dl4j word2vec model - '" + + MODEL_FILE_NAME_PREFIX + + "' file is missing in the zip. 
'" + + MODEL_FILE_NAME_PREFIX + + "' is a mandatory file containing the mapping between terms and vectors generated by the DL4j library."); + } + + private static TermAndVector extractTermAndVector( + String[] tokens, int vectorDimension, boolean isTermB64Encoded) { + BytesRef term = isTermB64Encoded ? decodeB64Term(tokens[0]) : new BytesRef((tokens[0])); + + float[] vector = new float[tokens.length - 1]; + + if (vectorDimension != vector.length) { + throw new RuntimeException( + String.format( + Locale.ROOT, + "Word2Vec model file corrupted. " + + "Declared vectors of size %d but found vector of size %d for word %s (%s)", + vectorDimension, + vector.length, + tokens[0], + term.utf8ToString())); + } + + for (int i = 1; i < tokens.length; i++) { + vector[i - 1] = Float.parseFloat(tokens[i]); + } + return new TermAndVector(term, vector); + } + + static BytesRef decodeB64Term(String term) { + byte[] buffer = Base64.getDecoder().decode(term.substring(4)); + return new BytesRef(buffer, 0, buffer.length); + } + + @Override + public void close() throws IOException { + word2VecModelZipFile.close(); + } +} diff --git a/lucene/analysis/common/src/java/org/apache/lucene/analysis/synonym/word2vec/Word2VecModel.java b/lucene/analysis/common/src/java/org/apache/lucene/analysis/synonym/word2vec/Word2VecModel.java new file mode 100644 index 000000000000..6719639b67d9 --- /dev/null +++ b/lucene/analysis/common/src/java/org/apache/lucene/analysis/synonym/word2vec/Word2VecModel.java @@ -0,0 +1,95 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.lucene.analysis.synonym.word2vec; + +import java.io.IOException; +import org.apache.lucene.util.BytesRef; +import org.apache.lucene.util.BytesRefHash; +import org.apache.lucene.util.TermAndVector; +import org.apache.lucene.util.hnsw.RandomAccessVectorValues; + +/** + * Word2VecModel is a class representing the parsed Word2Vec model containing the vectors for each + * word in dictionary + * + * @lucene.experimental + */ +public class Word2VecModel implements RandomAccessVectorValues { + + private final int dictionarySize; + private final int vectorDimension; + private final TermAndVector[] termsAndVectors; + private final BytesRefHash word2Vec; + private int loadedCount = 0; + + public Word2VecModel(int dictionarySize, int vectorDimension) { + this.dictionarySize = dictionarySize; + this.vectorDimension = vectorDimension; + this.termsAndVectors = new TermAndVector[dictionarySize]; + this.word2Vec = new BytesRefHash(); + } + + private Word2VecModel( + int dictionarySize, + int vectorDimension, + TermAndVector[] termsAndVectors, + BytesRefHash word2Vec) { + this.dictionarySize = dictionarySize; + this.vectorDimension = vectorDimension; + this.termsAndVectors = termsAndVectors; + this.word2Vec = word2Vec; + } + + public void addTermAndVector(TermAndVector modelEntry) { + modelEntry.normalizeVector(); + this.termsAndVectors[loadedCount++] = modelEntry; + this.word2Vec.add(modelEntry.getTerm()); + } + + @Override + public float[] vectorValue(int targetOrd) { + return termsAndVectors[targetOrd].getVector(); + } + + public float[] 
vectorValue(BytesRef term) { + int termOrd = this.word2Vec.find(term); + if (termOrd < 0) return null; + TermAndVector entry = this.termsAndVectors[termOrd]; + return (entry == null) ? null : entry.getVector(); + } + + public BytesRef termValue(int targetOrd) { + return termsAndVectors[targetOrd].getTerm(); + } + + @Override + public int dimension() { + return vectorDimension; + } + + @Override + public int size() { + return dictionarySize; + } + + @Override + public RandomAccessVectorValues copy() throws IOException { + return new Word2VecModel( + this.dictionarySize, this.vectorDimension, this.termsAndVectors, this.word2Vec); + } +} diff --git a/lucene/analysis/common/src/java/org/apache/lucene/analysis/synonym/word2vec/Word2VecSynonymFilter.java b/lucene/analysis/common/src/java/org/apache/lucene/analysis/synonym/word2vec/Word2VecSynonymFilter.java new file mode 100644 index 000000000000..29f9d37d5fdf --- /dev/null +++ b/lucene/analysis/common/src/java/org/apache/lucene/analysis/synonym/word2vec/Word2VecSynonymFilter.java @@ -0,0 +1,108 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.lucene.analysis.synonym.word2vec; + +import java.io.IOException; +import java.util.LinkedList; +import java.util.List; +import org.apache.lucene.analysis.TokenFilter; +import org.apache.lucene.analysis.TokenStream; +import org.apache.lucene.analysis.synonym.SynonymGraphFilter; +import org.apache.lucene.analysis.tokenattributes.CharTermAttribute; +import org.apache.lucene.analysis.tokenattributes.PositionIncrementAttribute; +import org.apache.lucene.analysis.tokenattributes.PositionLengthAttribute; +import org.apache.lucene.analysis.tokenattributes.TypeAttribute; +import org.apache.lucene.util.BytesRef; +import org.apache.lucene.util.BytesRefBuilder; +import org.apache.lucene.util.TermAndBoost; + +/** + * Applies single-token synonyms from a Word2Vec trained network to an incoming {@link TokenStream}. + * + * @lucene.experimental + */ +public final class Word2VecSynonymFilter extends TokenFilter { + + private final CharTermAttribute termAtt = addAttribute(CharTermAttribute.class); + private final PositionIncrementAttribute posIncrementAtt = + addAttribute(PositionIncrementAttribute.class); + private final PositionLengthAttribute posLenAtt = addAttribute(PositionLengthAttribute.class); + private final TypeAttribute typeAtt = addAttribute(TypeAttribute.class); + + private final Word2VecSynonymProvider synonymProvider; + private final int maxSynonymsPerTerm; + private final float minAcceptedSimilarity; + private final LinkedList synonymBuffer = new LinkedList<>(); + private State lastState; + + /** + * Apply previously built synonymProvider to incoming tokens. 
+ * + * @param input input tokenstream + * @param synonymProvider synonym provider + * @param maxSynonymsPerTerm maximum number of result returned by the synonym search + * @param minAcceptedSimilarity minimal value of cosine similarity between the searched vector and + * the retrieved ones + */ + public Word2VecSynonymFilter( + TokenStream input, + Word2VecSynonymProvider synonymProvider, + int maxSynonymsPerTerm, + float minAcceptedSimilarity) { + super(input); + this.synonymProvider = synonymProvider; + this.maxSynonymsPerTerm = maxSynonymsPerTerm; + this.minAcceptedSimilarity = minAcceptedSimilarity; + } + + @Override + public boolean incrementToken() throws IOException { + + if (!synonymBuffer.isEmpty()) { + TermAndBoost synonym = synonymBuffer.pollFirst(); + clearAttributes(); + restoreState(this.lastState); + termAtt.setEmpty(); + termAtt.append(synonym.term.utf8ToString()); + typeAtt.setType(SynonymGraphFilter.TYPE_SYNONYM); + posLenAtt.setPositionLength(1); + posIncrementAtt.setPositionIncrement(0); + return true; + } + + if (input.incrementToken()) { + BytesRefBuilder bytesRefBuilder = new BytesRefBuilder(); + bytesRefBuilder.copyChars(termAtt.buffer(), 0, termAtt.length()); + BytesRef term = bytesRefBuilder.get(); + List synonyms = + this.synonymProvider.getSynonyms(term, maxSynonymsPerTerm, minAcceptedSimilarity); + if (synonyms.size() > 0) { + this.lastState = captureState(); + this.synonymBuffer.addAll(synonyms); + } + return true; + } + return false; + } + + @Override + public void reset() throws IOException { + super.reset(); + synonymBuffer.clear(); + } +} diff --git a/lucene/analysis/common/src/java/org/apache/lucene/analysis/synonym/word2vec/Word2VecSynonymFilterFactory.java b/lucene/analysis/common/src/java/org/apache/lucene/analysis/synonym/word2vec/Word2VecSynonymFilterFactory.java new file mode 100644 index 000000000000..32b6288926fc --- /dev/null +++ 
b/lucene/analysis/common/src/java/org/apache/lucene/analysis/synonym/word2vec/Word2VecSynonymFilterFactory.java @@ -0,0 +1,101 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.lucene.analysis.synonym.word2vec; + +import java.io.IOException; +import java.util.Locale; +import java.util.Map; +import org.apache.lucene.analysis.TokenFilterFactory; +import org.apache.lucene.analysis.TokenStream; +import org.apache.lucene.analysis.synonym.word2vec.Word2VecSynonymProviderFactory.Word2VecSupportedFormats; +import org.apache.lucene.util.ResourceLoader; +import org.apache.lucene.util.ResourceLoaderAware; + +/** + * Factory for {@link Word2VecSynonymFilter}. 
+ * + * @lucene.experimental + * @lucene.spi {@value #NAME} + */ +public class Word2VecSynonymFilterFactory extends TokenFilterFactory + implements ResourceLoaderAware { + + /** SPI name */ + public static final String NAME = "Word2VecSynonym"; + + public static final int DEFAULT_MAX_SYNONYMS_PER_TERM = 5; + public static final float DEFAULT_MIN_ACCEPTED_SIMILARITY = 0.8f; + + private final int maxSynonymsPerTerm; + private final float minAcceptedSimilarity; + private final Word2VecSupportedFormats format; + private final String word2vecModelFileName; + + private Word2VecSynonymProvider synonymProvider; + + public Word2VecSynonymFilterFactory(Map args) { + super(args); + this.maxSynonymsPerTerm = getInt(args, "maxSynonymsPerTerm", DEFAULT_MAX_SYNONYMS_PER_TERM); + this.minAcceptedSimilarity = + getFloat(args, "minAcceptedSimilarity", DEFAULT_MIN_ACCEPTED_SIMILARITY); + this.word2vecModelFileName = require(args, "model"); + + String modelFormat = get(args, "format", "dl4j").toUpperCase(Locale.ROOT); + try { + this.format = Word2VecSupportedFormats.valueOf(modelFormat); + } catch (IllegalArgumentException exc) { + throw new IllegalArgumentException("Model format '" + modelFormat + "' not supported", exc); + } + + if (!args.isEmpty()) { + throw new IllegalArgumentException("Unknown parameters: " + args); + } + if (minAcceptedSimilarity <= 0 || minAcceptedSimilarity > 1) { + throw new IllegalArgumentException( + "minAcceptedSimilarity must be in the range (0, 1]. Found: " + minAcceptedSimilarity); + } + if (maxSynonymsPerTerm <= 0) { + throw new IllegalArgumentException( + "maxSynonymsPerTerm must be a positive integer greater than 0. 
Found: " + + maxSynonymsPerTerm); + } + } + + /** Default ctor for compatibility with SPI */ + public Word2VecSynonymFilterFactory() { + throw defaultCtorException(); + } + + Word2VecSynonymProvider getSynonymProvider() { + return this.synonymProvider; + } + + @Override + public TokenStream create(TokenStream input) { + return synonymProvider == null + ? input + : new Word2VecSynonymFilter( + input, synonymProvider, maxSynonymsPerTerm, minAcceptedSimilarity); + } + + @Override + public void inform(ResourceLoader loader) throws IOException { + this.synonymProvider = + Word2VecSynonymProviderFactory.getSynonymProvider(loader, word2vecModelFileName, format); + } +} diff --git a/lucene/analysis/common/src/java/org/apache/lucene/analysis/synonym/word2vec/Word2VecSynonymProvider.java b/lucene/analysis/common/src/java/org/apache/lucene/analysis/synonym/word2vec/Word2VecSynonymProvider.java new file mode 100644 index 000000000000..1d1f06fc991f --- /dev/null +++ b/lucene/analysis/common/src/java/org/apache/lucene/analysis/synonym/word2vec/Word2VecSynonymProvider.java @@ -0,0 +1,104 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.lucene.analysis.synonym.word2vec; + +import static org.apache.lucene.util.hnsw.HnswGraphBuilder.DEFAULT_BEAM_WIDTH; +import static org.apache.lucene.util.hnsw.HnswGraphBuilder.DEFAULT_MAX_CONN; + +import java.io.IOException; +import java.util.LinkedList; +import java.util.List; +import org.apache.lucene.index.VectorEncoding; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.util.BytesRef; +import org.apache.lucene.util.TermAndBoost; +import org.apache.lucene.util.hnsw.HnswGraph; +import org.apache.lucene.util.hnsw.HnswGraphBuilder; +import org.apache.lucene.util.hnsw.HnswGraphSearcher; +import org.apache.lucene.util.hnsw.NeighborQueue; + +/** + * The Word2VecSynonymProvider generates the list of synonyms of a term. + * + * @lucene.experimental + */ +public class Word2VecSynonymProvider { + + private static final VectorSimilarityFunction SIMILARITY_FUNCTION = + VectorSimilarityFunction.DOT_PRODUCT; + private static final VectorEncoding VECTOR_ENCODING = VectorEncoding.FLOAT32; + private final Word2VecModel word2VecModel; + private final HnswGraph hnswGraph; + + /** + * Word2VecSynonymProvider constructor + * + * @param model containing the set of TermAndVector entries + */ + public Word2VecSynonymProvider(Word2VecModel model) throws IOException { + word2VecModel = model; + + HnswGraphBuilder builder = + HnswGraphBuilder.create( + word2VecModel, + VECTOR_ENCODING, + SIMILARITY_FUNCTION, + DEFAULT_MAX_CONN, + DEFAULT_BEAM_WIDTH, + HnswGraphBuilder.randSeed); + this.hnswGraph = builder.build(word2VecModel.copy()); + } + + public List getSynonyms( + BytesRef term, int maxSynonymsPerTerm, float minAcceptedSimilarity) throws IOException { + + if (term == null) { + throw new IllegalArgumentException("Term must not be null"); + } + + LinkedList result = new LinkedList<>(); + float[] query = word2VecModel.vectorValue(term); + if (query != null) { + NeighborQueue synonyms = + HnswGraphSearcher.search( + query, + // 
The query vector is in the model. When looking for the top-k + // it's always the nearest neighbour of itself so, we look for the top-k+1 + maxSynonymsPerTerm + 1, + word2VecModel, + VECTOR_ENCODING, + SIMILARITY_FUNCTION, + hnswGraph, + null, + word2VecModel.size()); + + int size = synonyms.size(); + for (int i = 0; i < size; i++) { + float similarity = synonyms.topScore(); + int id = synonyms.pop(); + + BytesRef synonym = word2VecModel.termValue(id); + // We remove the original query term + if (!synonym.equals(term) && similarity >= minAcceptedSimilarity) { + result.addFirst(new TermAndBoost(synonym, similarity)); + } + } + } + return result; + } +} diff --git a/lucene/analysis/common/src/java/org/apache/lucene/analysis/synonym/word2vec/Word2VecSynonymProviderFactory.java b/lucene/analysis/common/src/java/org/apache/lucene/analysis/synonym/word2vec/Word2VecSynonymProviderFactory.java new file mode 100644 index 000000000000..ea849e653cd6 --- /dev/null +++ b/lucene/analysis/common/src/java/org/apache/lucene/analysis/synonym/word2vec/Word2VecSynonymProviderFactory.java @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.lucene.analysis.synonym.word2vec; + +import java.io.IOException; +import java.io.InputStream; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import org.apache.lucene.util.ResourceLoader; + +/** + * Supply Word2Vec Word2VecSynonymProvider cache avoiding that multiple instances of + * Word2VecSynonymFilterFactory will instantiate multiple instances of the same SynonymProvider. + * Assumes synonymProvider implementations are thread-safe. + */ +public class Word2VecSynonymProviderFactory { + + enum Word2VecSupportedFormats { + DL4J + } + + private static Map word2vecSynonymProviders = + new ConcurrentHashMap<>(); + + public static Word2VecSynonymProvider getSynonymProvider( + ResourceLoader loader, String modelFileName, Word2VecSupportedFormats format) + throws IOException { + Word2VecSynonymProvider synonymProvider = word2vecSynonymProviders.get(modelFileName); + if (synonymProvider == null) { + try (InputStream stream = loader.openResource(modelFileName)) { + try (Dl4jModelReader reader = getModelReader(format, stream)) { + synonymProvider = new Word2VecSynonymProvider(reader.read()); + } + } + word2vecSynonymProviders.put(modelFileName, synonymProvider); + } + return synonymProvider; + } + + private static Dl4jModelReader getModelReader( + Word2VecSupportedFormats format, InputStream stream) { + switch (format) { + case DL4J: + return new Dl4jModelReader(stream); + } + return null; + } +} diff --git a/lucene/analysis/common/src/java/org/apache/lucene/analysis/synonym/word2vec/package-info.java b/lucene/analysis/common/src/java/org/apache/lucene/analysis/synonym/word2vec/package-info.java new file mode 100644 index 000000000000..e8d69ab3cf9b --- /dev/null +++ b/lucene/analysis/common/src/java/org/apache/lucene/analysis/synonym/word2vec/package-info.java @@ -0,0 +1,19 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** Analysis components for Synonyms using Word2Vec model. */ +package org.apache.lucene.analysis.synonym.word2vec; diff --git a/lucene/analysis/common/src/resources/META-INF/services/org.apache.lucene.analysis.TokenFilterFactory b/lucene/analysis/common/src/resources/META-INF/services/org.apache.lucene.analysis.TokenFilterFactory index 19a34b7840a8..1e4e17eaeadf 100644 --- a/lucene/analysis/common/src/resources/META-INF/services/org.apache.lucene.analysis.TokenFilterFactory +++ b/lucene/analysis/common/src/resources/META-INF/services/org.apache.lucene.analysis.TokenFilterFactory @@ -118,6 +118,7 @@ org.apache.lucene.analysis.sv.SwedishLightStemFilterFactory org.apache.lucene.analysis.sv.SwedishMinimalStemFilterFactory org.apache.lucene.analysis.synonym.SynonymFilterFactory org.apache.lucene.analysis.synonym.SynonymGraphFilterFactory +org.apache.lucene.analysis.synonym.word2vec.Word2VecSynonymFilterFactory org.apache.lucene.analysis.core.FlattenGraphFilterFactory org.apache.lucene.analysis.te.TeluguNormalizationFilterFactory org.apache.lucene.analysis.te.TeluguStemFilterFactory diff --git a/lucene/analysis/common/src/test/org/apache/lucene/analysis/synonym/word2vec/TestDl4jModelReader.java 
b/lucene/analysis/common/src/test/org/apache/lucene/analysis/synonym/word2vec/TestDl4jModelReader.java new file mode 100644 index 000000000000..213dcdaccd33 --- /dev/null +++ b/lucene/analysis/common/src/test/org/apache/lucene/analysis/synonym/word2vec/TestDl4jModelReader.java @@ -0,0 +1,98 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.lucene.analysis.synonym.word2vec; + +import java.io.InputStream; +import java.nio.charset.StandardCharsets; +import java.util.Base64; +import org.apache.lucene.tests.util.LuceneTestCase; +import org.apache.lucene.util.BytesRef; +import org.junit.Test; + +public class TestDl4jModelReader extends LuceneTestCase { + + private static final String MODEL_FILE = "word2vec-model.zip"; + private static final String MODEL_EMPTY_FILE = "word2vec-empty-model.zip"; + private static final String CORRUPTED_VECTOR_DIMENSION_MODEL_FILE = + "word2vec-corrupted-vector-dimension-model.zip"; + + InputStream stream = TestDl4jModelReader.class.getResourceAsStream(MODEL_FILE); + Dl4jModelReader unit = new Dl4jModelReader(stream); + + @Test + public void read_zipFileWithMetadata_shouldReturnDictionarySize() throws Exception { + Word2VecModel model = unit.read(); + long expectedDictionarySize = 235; + assertEquals(expectedDictionarySize, model.size()); + } + + @Test + public void read_zipFileWithMetadata_shouldReturnVectorLength() throws Exception { + Word2VecModel model = unit.read(); + int expectedVectorDimension = 100; + assertEquals(expectedVectorDimension, model.dimension()); + } + + @Test + public void read_zipFile_shouldReturnDecodedTerm() throws Exception { + Word2VecModel model = unit.read(); + BytesRef expectedDecodedFirstTerm = new BytesRef("it"); + assertEquals(expectedDecodedFirstTerm, model.termValue(0)); + } + + @Test + public void decodeTerm_encodedTerm_shouldReturnDecodedTerm() throws Exception { + byte[] originalInput = "lucene".getBytes(StandardCharsets.UTF_8); + String B64encodedLuceneTerm = Base64.getEncoder().encodeToString(originalInput); + String word2vecEncodedLuceneTerm = "B64:" + B64encodedLuceneTerm; + assertEquals(new BytesRef("lucene"), Dl4jModelReader.decodeB64Term(word2vecEncodedLuceneTerm)); + } + + @Test + public void read_EmptyZipFile_shouldThrowException() throws Exception { + try (InputStream stream = 
TestDl4jModelReader.class.getResourceAsStream(MODEL_EMPTY_FILE)) { + Dl4jModelReader unit = new Dl4jModelReader(stream); + expectThrows(IllegalArgumentException.class, unit::read); + } + } + + @Test + public void read_corruptedVectorDimensionModelFile_shouldThrowException() throws Exception { + try (InputStream stream = + TestDl4jModelReader.class.getResourceAsStream(CORRUPTED_VECTOR_DIMENSION_MODEL_FILE)) { + Dl4jModelReader unit = new Dl4jModelReader(stream); + expectThrows(RuntimeException.class, unit::read); + } + } +} diff --git a/lucene/analysis/common/src/test/org/apache/lucene/analysis/synonym/word2vec/TestWord2VecSynonymFilter.java b/lucene/analysis/common/src/test/org/apache/lucene/analysis/synonym/word2vec/TestWord2VecSynonymFilter.java new file mode 100644 index 000000000000..3999931dd758 --- /dev/null +++ b/lucene/analysis/common/src/test/org/apache/lucene/analysis/synonym/word2vec/TestWord2VecSynonymFilter.java @@ -0,0 +1,152 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.lucene.analysis.synonym.word2vec; + +import org.apache.lucene.analysis.Analyzer; +import org.apache.lucene.analysis.Tokenizer; +import org.apache.lucene.tests.analysis.BaseTokenStreamTestCase; +import org.apache.lucene.tests.analysis.MockTokenizer; +import org.apache.lucene.util.BytesRef; +import org.apache.lucene.util.TermAndVector; +import org.junit.Test; + +public class TestWord2VecSynonymFilter extends BaseTokenStreamTestCase { + + @Test + public void synonymExpansion_oneCandidate_shouldBeExpandedWithinThreshold() throws Exception { + int maxSynonymPerTerm = 10; + float minAcceptedSimilarity = 0.9f; + Word2VecModel model = new Word2VecModel(6, 2); + model.addTermAndVector(new TermAndVector(new BytesRef("a"), new float[] {10, 10})); + model.addTermAndVector(new TermAndVector(new BytesRef("b"), new float[] {10, 8})); + model.addTermAndVector(new TermAndVector(new BytesRef("c"), new float[] {9, 10})); + model.addTermAndVector(new TermAndVector(new BytesRef("d"), new float[] {1, 1})); + model.addTermAndVector(new TermAndVector(new BytesRef("e"), new float[] {99, 101})); + model.addTermAndVector(new TermAndVector(new BytesRef("f"), new float[] {-1, 10})); + + Word2VecSynonymProvider synonymProvider = new Word2VecSynonymProvider(model); + + Analyzer a = getAnalyzer(synonymProvider, maxSynonymPerTerm, minAcceptedSimilarity); + assertAnalyzesTo( + a, + "pre a post", // input + new String[] {"pre", "a", "d", "e", "c", "b", "post"}, // output + new int[] {0, 4, 4, 4, 4, 4, 6}, // start offset + new int[] {3, 5, 5, 5, 5, 5, 10}, // end offset + new String[] {"word", "word", "SYNONYM", "SYNONYM", "SYNONYM", "SYNONYM", "word"}, // types + new int[] {1, 1, 0, 0, 0, 0, 1}, // posIncrements + new int[] {1, 1, 1, 1, 1, 1, 1}); // posLengths + a.close(); + } + + @Test + public void synonymExpansion_oneCandidate_shouldBeExpandedWithTopKSynonyms() throws Exception { + int maxSynonymPerTerm = 2; + float minAcceptedSimilarity = 0.9f; + Word2VecModel model
= new Word2VecModel(5, 2); + model.addTermAndVector(new TermAndVector(new BytesRef("a"), new float[] {10, 10})); + model.addTermAndVector(new TermAndVector(new BytesRef("b"), new float[] {10, 8})); + model.addTermAndVector(new TermAndVector(new BytesRef("c"), new float[] {9, 10})); + model.addTermAndVector(new TermAndVector(new BytesRef("d"), new float[] {1, 1})); + model.addTermAndVector(new TermAndVector(new BytesRef("e"), new float[] {99, 101})); + + Word2VecSynonymProvider synonymProvider = new Word2VecSynonymProvider(model); + + Analyzer a = getAnalyzer(synonymProvider, maxSynonymPerTerm, minAcceptedSimilarity); + assertAnalyzesTo( + a, + "pre a post", // input + new String[] {"pre", "a", "d", "e", "post"}, // output + new int[] {0, 4, 4, 4, 6}, // start offset + new int[] {3, 5, 5, 5, 10}, // end offset + new String[] {"word", "word", "SYNONYM", "SYNONYM", "word"}, // types + new int[] {1, 1, 0, 0, 1}, // posIncrements + new int[] {1, 1, 1, 1, 1}); // posLengths + a.close(); + } + + @Test + public void synonymExpansion_twoCandidates_shouldBothBeExpanded() throws Exception { + Word2VecModel model = new Word2VecModel(8, 2); + model.addTermAndVector(new TermAndVector(new BytesRef("a"), new float[] {10, 10})); + model.addTermAndVector(new TermAndVector(new BytesRef("b"), new float[] {10, 8})); + model.addTermAndVector(new TermAndVector(new BytesRef("c"), new float[] {9, 10})); + model.addTermAndVector(new TermAndVector(new BytesRef("d"), new float[] {1, 1})); + model.addTermAndVector(new TermAndVector(new BytesRef("e"), new float[] {99, 101})); + model.addTermAndVector(new TermAndVector(new BytesRef("f"), new float[] {1, 10})); + model.addTermAndVector(new TermAndVector(new BytesRef("post"), new float[] {-10, -11})); + model.addTermAndVector(new TermAndVector(new BytesRef("after"), new float[] {-8, -10})); + + Word2VecSynonymProvider synonymProvider = new Word2VecSynonymProvider(model); + + Analyzer a = getAnalyzer(synonymProvider, 10, 0.9f); + assertAnalyzesTo(
+ a, + "pre a post", // input + new String[] {"pre", "a", "d", "e", "c", "b", "post", "after"}, // output + new int[] {0, 4, 4, 4, 4, 4, 6, 6}, // start offset + new int[] {3, 5, 5, 5, 5, 5, 10, 10}, // end offset + new String[] { // types + "word", "word", "SYNONYM", "SYNONYM", "SYNONYM", "SYNONYM", "word", "SYNONYM" + }, + new int[] {1, 1, 0, 0, 0, 0, 1, 0}, // posIncrements + new int[] {1, 1, 1, 1, 1, 1, 1, 1}); // posLengths + a.close(); + } + + @Test + public void synonymExpansion_forMinAcceptedSimilarity_shouldExpandToNoneSynonyms() + throws Exception { + Word2VecModel model = new Word2VecModel(4, 2); + model.addTermAndVector(new TermAndVector(new BytesRef("a"), new float[] {10, 10})); + model.addTermAndVector(new TermAndVector(new BytesRef("b"), new float[] {-10, -8})); + model.addTermAndVector(new TermAndVector(new BytesRef("c"), new float[] {-9, -10})); + model.addTermAndVector(new TermAndVector(new BytesRef("f"), new float[] {-1, -10})); + + Word2VecSynonymProvider synonymProvider = new Word2VecSynonymProvider(model); + + Analyzer a = getAnalyzer(synonymProvider, 10, 0.8f); + assertAnalyzesTo( + a, + "pre a post", // input + new String[] {"pre", "a", "post"}, // output + new int[] {0, 4, 6}, // start offset + new int[] {3, 5, 10}, // end offset + new String[] {"word", "word", "word"}, // types + new int[] {1, 1, 1}, // posIncrements + new int[] {1, 1, 1}); // posLengths + a.close(); + } + + private Analyzer getAnalyzer( + Word2VecSynonymProvider synonymProvider, + int maxSynonymsPerTerm, + float minAcceptedSimilarity) { + return new Analyzer() { + @Override + protected TokenStreamComponents createComponents(String fieldName) { + Tokenizer tokenizer = new MockTokenizer(MockTokenizer.WHITESPACE, false); + // Make a local variable so testRandomHuge doesn't share it across threads! 
+ Word2VecSynonymFilter synFilter = + new Word2VecSynonymFilter( + tokenizer, synonymProvider, maxSynonymsPerTerm, minAcceptedSimilarity); + return new TokenStreamComponents(tokenizer, synFilter); + } + }; + } +} diff --git a/lucene/analysis/common/src/test/org/apache/lucene/analysis/synonym/word2vec/TestWord2VecSynonymFilterFactory.java b/lucene/analysis/common/src/test/org/apache/lucene/analysis/synonym/word2vec/TestWord2VecSynonymFilterFactory.java new file mode 100644 index 000000000000..007fedf4abed --- /dev/null +++ b/lucene/analysis/common/src/test/org/apache/lucene/analysis/synonym/word2vec/TestWord2VecSynonymFilterFactory.java @@ -0,0 +1,159 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.lucene.analysis.synonym.word2vec; + +import org.apache.lucene.tests.analysis.BaseTokenStreamFactoryTestCase; +import org.apache.lucene.util.ClasspathResourceLoader; +import org.apache.lucene.util.ResourceLoader; +import org.junit.Test; + +public class TestWord2VecSynonymFilterFactory extends BaseTokenStreamFactoryTestCase { + + public static final String FACTORY_NAME = "Word2VecSynonym"; + private static final String WORD2VEC_MODEL_FILE = "word2vec-model.zip"; + + @Test + public void testInform() throws Exception { + ResourceLoader loader = new ClasspathResourceLoader(getClass()); + assertTrue("loader is null and it shouldn't be", loader != null); + Word2VecSynonymFilterFactory factory = + (Word2VecSynonymFilterFactory) + tokenFilterFactory( + FACTORY_NAME, "model", WORD2VEC_MODEL_FILE, "minAcceptedSimilarity", "0.7"); + + Word2VecSynonymProvider synonymProvider = factory.getSynonymProvider(); + assertNotEquals(null, synonymProvider); + } + + @Test + public void missingRequiredArgument_shouldThrowException() throws Exception { + IllegalArgumentException expected = + expectThrows( + IllegalArgumentException.class, + () -> { + tokenFilterFactory( + FACTORY_NAME, + "format", + "dl4j", + "minAcceptedSimilarity", + "0.7", + "maxSynonymsPerTerm", + "10"); + }); + assertTrue(expected.getMessage().contains("Configuration Error: missing parameter 'model'")); + } + + @Test + public void unsupportedModelFormat_shouldThrowException() throws Exception { + IllegalArgumentException expected = + expectThrows( + IllegalArgumentException.class, + () -> { + tokenFilterFactory( + FACTORY_NAME, "model", WORD2VEC_MODEL_FILE, "format", "bogusValue"); + }); + assertTrue(expected.getMessage().contains("Model format 'BOGUSVALUE' not supported")); + } + + @Test + public void bogusArgument_shouldThrowException() throws Exception { + IllegalArgumentException expected = + expectThrows( + IllegalArgumentException.class, + () -> { + tokenFilterFactory( + FACTORY_NAME, 
"model", WORD2VEC_MODEL_FILE, "bogusArg", "bogusValue"); + }); + assertTrue(expected.getMessage().contains("Unknown parameters")); + } + + @Test + public void illegalArguments_shouldThrowException() throws Exception { + IllegalArgumentException expected = + expectThrows( + IllegalArgumentException.class, + () -> { + tokenFilterFactory( + FACTORY_NAME, + "model", + WORD2VEC_MODEL_FILE, + "minAcceptedSimilarity", + "2", + "maxSynonymsPerTerm", + "10"); + }); + assertTrue( + expected + .getMessage() + .contains("minAcceptedSimilarity must be in the range (0, 1]. Found: 2")); + + expected = + expectThrows( + IllegalArgumentException.class, + () -> { + tokenFilterFactory( + FACTORY_NAME, + "model", + WORD2VEC_MODEL_FILE, + "minAcceptedSimilarity", + "0", + "maxSynonymsPerTerm", + "10"); + }); + assertTrue( + expected + .getMessage() + .contains("minAcceptedSimilarity must be in the range (0, 1]. Found: 0")); + + expected = + expectThrows( + IllegalArgumentException.class, + () -> { + tokenFilterFactory( + FACTORY_NAME, + "model", + WORD2VEC_MODEL_FILE, + "minAcceptedSimilarity", + "0.7", + "maxSynonymsPerTerm", + "-1"); + }); + assertTrue( + expected + .getMessage() + .contains("maxSynonymsPerTerm must be a positive integer greater than 0. Found: -1")); + + expected = + expectThrows( + IllegalArgumentException.class, + () -> { + tokenFilterFactory( + FACTORY_NAME, + "model", + WORD2VEC_MODEL_FILE, + "minAcceptedSimilarity", + "0.7", + "maxSynonymsPerTerm", + "0"); + }); + assertTrue( + expected + .getMessage() + .contains("maxSynonymsPerTerm must be a positive integer greater than 0. 
Found: 0")); + } +} diff --git a/lucene/analysis/common/src/test/org/apache/lucene/analysis/synonym/word2vec/TestWord2VecSynonymProvider.java b/lucene/analysis/common/src/test/org/apache/lucene/analysis/synonym/word2vec/TestWord2VecSynonymProvider.java new file mode 100644 index 000000000000..0803406bdafa --- /dev/null +++ b/lucene/analysis/common/src/test/org/apache/lucene/analysis/synonym/word2vec/TestWord2VecSynonymProvider.java @@ -0,0 +1,132 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.lucene.analysis.synonym.word2vec; + +import java.io.IOException; +import java.util.List; +import org.apache.lucene.tests.util.LuceneTestCase; +import org.apache.lucene.util.BytesRef; +import org.apache.lucene.util.TermAndBoost; +import org.apache.lucene.util.TermAndVector; +import org.junit.Test; + +public class TestWord2VecSynonymProvider extends LuceneTestCase { + + private static final int MAX_SYNONYMS_PER_TERM = 10; + private static final float MIN_ACCEPTED_SIMILARITY = 0.85f; + + private final Word2VecSynonymProvider unit; + + public TestWord2VecSynonymProvider() throws IOException { + Word2VecModel model = new Word2VecModel(2, 3); + model.addTermAndVector(new TermAndVector(new BytesRef("a"), new float[] {0.24f, 0.78f, 0.28f})); + model.addTermAndVector(new TermAndVector(new BytesRef("b"), new float[] {0.44f, 0.01f, 0.81f})); + unit = new Word2VecSynonymProvider(model); + } + + @Test + public void getSynonyms_nullToken_shouldThrowException() { + expectThrows( + IllegalArgumentException.class, + () -> unit.getSynonyms(null, MAX_SYNONYMS_PER_TERM, MIN_ACCEPTED_SIMILARITY)); + } + + @Test + public void getSynonyms_shouldReturnSynonymsBasedOnMinAcceptedSimilarity() throws Exception { + Word2VecModel model = new Word2VecModel(6, 2); + model.addTermAndVector(new TermAndVector(new BytesRef("a"), new float[] {10, 10})); + model.addTermAndVector(new TermAndVector(new BytesRef("b"), new float[] {10, 8})); + model.addTermAndVector(new TermAndVector(new BytesRef("c"), new float[] {9, 10})); + model.addTermAndVector(new TermAndVector(new BytesRef("d"), new float[] {1, 1})); + model.addTermAndVector(new TermAndVector(new BytesRef("e"), new float[] {99, 101})); + model.addTermAndVector(new TermAndVector(new BytesRef("f"), new float[] {-1, 10})); + + Word2VecSynonymProvider unit = new Word2VecSynonymProvider(model); + + BytesRef inputTerm = new BytesRef("a"); + String[] expectedSynonyms = {"d", "e", "c", "b"}; + List actualSynonymsResults = + 
unit.getSynonyms(inputTerm, MAX_SYNONYMS_PER_TERM, MIN_ACCEPTED_SIMILARITY); + + assertEquals(4, actualSynonymsResults.size()); + for (int i = 0; i < expectedSynonyms.length; i++) { + assertEquals(new BytesRef(expectedSynonyms[i]), actualSynonymsResults.get(i).term); + } + } + + @Test + public void getSynonyms_shouldReturnSynonymsBoost() throws Exception { + Word2VecModel model = new Word2VecModel(3, 2); + model.addTermAndVector(new TermAndVector(new BytesRef("a"), new float[] {10, 10})); + model.addTermAndVector(new TermAndVector(new BytesRef("b"), new float[] {1, 1})); + model.addTermAndVector(new TermAndVector(new BytesRef("c"), new float[] {99, 101})); + + Word2VecSynonymProvider unit = new Word2VecSynonymProvider(model); + + BytesRef inputTerm = new BytesRef("a"); + List actualSynonymsResults = + unit.getSynonyms(inputTerm, MAX_SYNONYMS_PER_TERM, MIN_ACCEPTED_SIMILARITY); + + BytesRef expectedFirstSynonymTerm = new BytesRef("b"); + double expectedFirstSynonymBoost = 1.0; + assertEquals(expectedFirstSynonymTerm, actualSynonymsResults.get(0).term); + assertEquals(expectedFirstSynonymBoost, actualSynonymsResults.get(0).boost, 0.001f); + } + + @Test + public void noSynonymsWithinAcceptedSimilarity_shouldReturnNoSynonyms() throws Exception { + Word2VecModel model = new Word2VecModel(4, 2); + model.addTermAndVector(new TermAndVector(new BytesRef("a"), new float[] {10, 10})); + model.addTermAndVector(new TermAndVector(new BytesRef("b"), new float[] {-10, -8})); + model.addTermAndVector(new TermAndVector(new BytesRef("c"), new float[] {-9, -10})); + model.addTermAndVector(new TermAndVector(new BytesRef("d"), new float[] {6, -6})); + + Word2VecSynonymProvider unit = new Word2VecSynonymProvider(model); + + BytesRef inputTerm = newBytesRef("a"); + List actualSynonymsResults = + unit.getSynonyms(inputTerm, MAX_SYNONYMS_PER_TERM, MIN_ACCEPTED_SIMILARITY); + assertEquals(0, actualSynonymsResults.size()); + } + + @Test + public void testModel_shouldReturnNormalizedVectors() 
{ + Word2VecModel model = new Word2VecModel(4, 2); + model.addTermAndVector(new TermAndVector(new BytesRef("a"), new float[] {10, 10})); + model.addTermAndVector(new TermAndVector(new BytesRef("b"), new float[] {10, 8})); + model.addTermAndVector(new TermAndVector(new BytesRef("c"), new float[] {9, 10})); + model.addTermAndVector(new TermAndVector(new BytesRef("f"), new float[] {-1, 10})); + + float[] vectorIdA = model.vectorValue(new BytesRef("a")); + float[] vectorIdF = model.vectorValue(new BytesRef("f")); + assertArrayEquals(new float[] {0.70710f, 0.70710f}, vectorIdA, 0.001f); + assertArrayEquals(new float[] {-0.0995f, 0.99503f}, vectorIdF, 0.001f); + } + + @Test + public void normalizedVector_shouldReturnModule1() { + TermAndVector synonymTerm = new TermAndVector(new BytesRef("a"), new float[] {10, 10}); + synonymTerm.normalizeVector(); + float[] vector = synonymTerm.getVector(); + float len = 0; + for (int i = 0; i < vector.length; i++) { + len += vector[i] * vector[i]; + } + assertEquals(1, Math.sqrt(len), 0.0001f); + } +} diff --git a/lucene/analysis/common/src/test/org/apache/lucene/analysis/synonym/word2vec/word2vec-corrupted-vector-dimension-model.zip b/lucene/analysis/common/src/test/org/apache/lucene/analysis/synonym/word2vec/word2vec-corrupted-vector-dimension-model.zip new file mode 100644 index 000000000000..e25693dd83cf Binary files /dev/null and b/lucene/analysis/common/src/test/org/apache/lucene/analysis/synonym/word2vec/word2vec-corrupted-vector-dimension-model.zip differ diff --git a/lucene/analysis/common/src/test/org/apache/lucene/analysis/synonym/word2vec/word2vec-empty-model.zip b/lucene/analysis/common/src/test/org/apache/lucene/analysis/synonym/word2vec/word2vec-empty-model.zip new file mode 100644 index 000000000000..57d7832dd787 Binary files /dev/null and b/lucene/analysis/common/src/test/org/apache/lucene/analysis/synonym/word2vec/word2vec-empty-model.zip differ diff --git 
a/lucene/analysis/common/src/test/org/apache/lucene/analysis/synonym/word2vec/word2vec-model.zip b/lucene/analysis/common/src/test/org/apache/lucene/analysis/synonym/word2vec/word2vec-model.zip new file mode 100644 index 000000000000..6d31b8d5a3fa Binary files /dev/null and b/lucene/analysis/common/src/test/org/apache/lucene/analysis/synonym/word2vec/word2vec-model.zip differ diff --git a/lucene/core/src/java/org/apache/lucene/util/QueryBuilder.java b/lucene/core/src/java/org/apache/lucene/util/QueryBuilder.java index 3ad17dca5ce9..adf30baec294 100644 --- a/lucene/core/src/java/org/apache/lucene/util/QueryBuilder.java +++ b/lucene/core/src/java/org/apache/lucene/util/QueryBuilder.java @@ -62,20 +62,6 @@ public class QueryBuilder { protected boolean enableGraphQueries = true; protected boolean autoGenerateMultiTermSynonymsPhraseQuery = false; - /** Wraps a term and boost */ - public static class TermAndBoost { - /** the term */ - public final BytesRef term; - /** the boost */ - public final float boost; - - /** Creates a new TermAndBoost */ - public TermAndBoost(BytesRef term, float boost) { - this.term = BytesRef.deepCopyOf(term); - this.boost = boost; - } - } - /** Creates a new QueryBuilder using the given analyzer. */ public QueryBuilder(Analyzer analyzer) { this.analyzer = analyzer; diff --git a/lucene/core/src/java/org/apache/lucene/util/TermAndBoost.java b/lucene/core/src/java/org/apache/lucene/util/TermAndBoost.java new file mode 100644 index 000000000000..0c7958a99a08 --- /dev/null +++ b/lucene/core/src/java/org/apache/lucene/util/TermAndBoost.java @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.util; + +/** Wraps a term and boost */ +public class TermAndBoost { + /** the term */ + public final BytesRef term; + /** the boost */ + public final float boost; + + /** Creates a new TermAndBoost */ + public TermAndBoost(BytesRef term, float boost) { + this.term = BytesRef.deepCopyOf(term); + this.boost = boost; + } +} diff --git a/lucene/core/src/java/org/apache/lucene/util/TermAndVector.java b/lucene/core/src/java/org/apache/lucene/util/TermAndVector.java new file mode 100644 index 000000000000..1ade19a19803 --- /dev/null +++ b/lucene/core/src/java/org/apache/lucene/util/TermAndVector.java @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.lucene.util; + +import java.util.Locale; + +/** + * Word2Vec unit composed by a term with the associated vector + * + * @lucene.experimental + */ +public class TermAndVector { + + private final BytesRef term; + private final float[] vector; + + public TermAndVector(BytesRef term, float[] vector) { + this.term = term; + this.vector = vector; + } + + public BytesRef getTerm() { + return this.term; + } + + public float[] getVector() { + return this.vector; + } + + public int size() { + return vector.length; + } + + public void normalizeVector() { + float vectorLength = 0; + for (int i = 0; i < vector.length; i++) { + vectorLength += vector[i] * vector[i]; + } + vectorLength = (float) Math.sqrt(vectorLength); + for (int i = 0; i < vector.length; i++) { + vector[i] /= vectorLength; + } + } + + @Override + public String toString() { + StringBuilder builder = new StringBuilder(this.term.utf8ToString()); + builder.append(" ["); + if (vector.length > 0) { + for (int i = 0; i < vector.length - 1; i++) { + builder.append(String.format(Locale.ROOT, "%.3f,", vector[i])); + } + builder.append(String.format(Locale.ROOT, "%.3f]", vector[vector.length - 1])); + } + return builder.toString(); + } +} diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java index 9f1e6c505254..edefb696b8c5 100644 --- a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java +++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java @@ -41,8 +41,17 @@ */ public final class HnswGraphBuilder { + /** Default number of maximum connections per node */ + public static final int DEFAULT_MAX_CONN = 16; + + /** + * Default number of the size of the queue maintained while searching during a graph construction. 
+ */ + public static final int DEFAULT_BEAM_WIDTH = 100; + /** Default random seed for level generation * */ private static final long DEFAULT_RAND_SEED = 42; + /** A name for the HNSW component for the info-stream * */ public static final String HNSW_COMPONENT = "HNSW"; diff --git a/lucene/core/src/test/org/apache/lucene/index/TestKnnGraph.java b/lucene/core/src/test/org/apache/lucene/index/TestKnnGraph.java index 08f089430ba5..372384df5572 100644 --- a/lucene/core/src/test/org/apache/lucene/index/TestKnnGraph.java +++ b/lucene/core/src/test/org/apache/lucene/index/TestKnnGraph.java @@ -54,6 +54,7 @@ import org.apache.lucene.util.VectorUtil; import org.apache.lucene.util.hnsw.HnswGraph; import org.apache.lucene.util.hnsw.HnswGraph.NodesIterator; +import org.apache.lucene.util.hnsw.HnswGraphBuilder; import org.junit.After; import org.junit.Before; @@ -62,7 +63,7 @@ public class TestKnnGraph extends LuceneTestCase { private static final String KNN_GRAPH_FIELD = "vector"; - private static int M = Lucene95HnswVectorsFormat.DEFAULT_MAX_CONN; + private static int M = HnswGraphBuilder.DEFAULT_MAX_CONN; private Codec codec; private Codec float32Codec; @@ -80,7 +81,7 @@ public void setup() { new Lucene95Codec() { @Override public KnnVectorsFormat getKnnVectorsFormatForField(String field) { - return new Lucene95HnswVectorsFormat(M, Lucene95HnswVectorsFormat.DEFAULT_BEAM_WIDTH); + return new Lucene95HnswVectorsFormat(M, HnswGraphBuilder.DEFAULT_BEAM_WIDTH); } }; @@ -92,7 +93,7 @@ public KnnVectorsFormat getKnnVectorsFormatForField(String field) { new Lucene95Codec() { @Override public KnnVectorsFormat getKnnVectorsFormatForField(String field) { - return new Lucene95HnswVectorsFormat(M, Lucene95HnswVectorsFormat.DEFAULT_BEAM_WIDTH); + return new Lucene95HnswVectorsFormat(M, HnswGraphBuilder.DEFAULT_BEAM_WIDTH); } }; @@ -103,7 +104,7 @@ public KnnVectorsFormat getKnnVectorsFormatForField(String field) { new Lucene95Codec() { @Override public KnnVectorsFormat 
getKnnVectorsFormatForField(String field) { - return new Lucene95HnswVectorsFormat(M, Lucene95HnswVectorsFormat.DEFAULT_BEAM_WIDTH); + return new Lucene95HnswVectorsFormat(M, HnswGraphBuilder.DEFAULT_BEAM_WIDTH); } }; } @@ -115,7 +116,7 @@ private VectorEncoding randomVectorEncoding() { @After public void cleanup() { - M = Lucene95HnswVectorsFormat.DEFAULT_MAX_CONN; + M = HnswGraphBuilder.DEFAULT_MAX_CONN; } /** Basic test of creating documents in a graph */ diff --git a/lucene/test-framework/src/java/org/apache/lucene/tests/analysis/BaseTokenStreamTestCase.java b/lucene/test-framework/src/java/org/apache/lucene/tests/analysis/BaseTokenStreamTestCase.java index b2ce16d80776..f9cd607ccacc 100644 --- a/lucene/test-framework/src/java/org/apache/lucene/tests/analysis/BaseTokenStreamTestCase.java +++ b/lucene/test-framework/src/java/org/apache/lucene/tests/analysis/BaseTokenStreamTestCase.java @@ -55,6 +55,7 @@ import org.apache.lucene.document.TextField; import org.apache.lucene.index.IndexOptions; import org.apache.lucene.index.IndexableFieldType; +import org.apache.lucene.search.BoostAttribute; import org.apache.lucene.store.Directory; import org.apache.lucene.tests.index.RandomIndexWriter; import org.apache.lucene.tests.util.LuceneTestCase; @@ -154,7 +155,8 @@ public static void assertTokenStreamContents( boolean[] keywordAtts, boolean graphOffsetsAreCorrect, byte[][] payloads, - int[] flags) + int[] flags, + float[] boost) throws IOException { assertNotNull(output); CheckClearAttributesAttribute checkClearAtt = @@ -221,6 +223,12 @@ public static void assertTokenStreamContents( flagsAtt = ts.getAttribute(FlagsAttribute.class); } + BoostAttribute boostAtt = null; + if (boost != null) { + assertTrue("has no BoostAttribute", ts.hasAttribute(BoostAttribute.class)); + boostAtt = ts.getAttribute(BoostAttribute.class); + } + // Maps position to the start/end offset: final Map posToStartOffset = new HashMap<>(); final Map posToEndOffset = new HashMap<>(); @@ -243,6 +251,7 
@@ public static void assertTokenStreamContents( if (payloadAtt != null) payloadAtt.setPayload(new BytesRef(new byte[] {0x00, -0x21, 0x12, -0x43, 0x24})); if (flagsAtt != null) flagsAtt.setFlags(~0); // all 1's + if (boostAtt != null) boostAtt.setBoost(-1f); checkClearAtt.getAndResetClearCalled(); // reset it, because we called clearAttribute() before assertTrue("token " + i + " does not exist", ts.incrementToken()); @@ -278,6 +287,9 @@ public static void assertTokenStreamContents( if (flagsAtt != null) { assertEquals("flagsAtt " + i + " term=" + termAtt, flags[i], flagsAtt.getFlags()); } + if (boostAtt != null) { + assertEquals("boostAtt " + i + " term=" + termAtt, boost[i], boostAtt.getBoost(), 0.001); + } if (payloads != null) { if (payloads[i] != null) { assertEquals("payloads " + i, new BytesRef(payloads[i]), payloadAtt.getPayload()); @@ -405,6 +417,7 @@ public static void assertTokenStreamContents( if (payloadAtt != null) payloadAtt.setPayload(new BytesRef(new byte[] {0x00, -0x21, 0x12, -0x43, 0x24})); if (flagsAtt != null) flagsAtt.setFlags(~0); // all 1's + if (boostAtt != null) boostAtt.setBoost(-1); checkClearAtt.getAndResetClearCalled(); // reset it, because we called clearAttribute() before @@ -426,6 +439,38 @@ public static void assertTokenStreamContents( ts.close(); } + public static void assertTokenStreamContents( + TokenStream ts, + String[] output, + int[] startOffsets, + int[] endOffsets, + String[] types, + int[] posIncrements, + int[] posLengths, + Integer finalOffset, + Integer finalPosInc, + boolean[] keywordAtts, + boolean graphOffsetsAreCorrect, + byte[][] payloads, + int[] flags) + throws IOException { + assertTokenStreamContents( + ts, + output, + startOffsets, + endOffsets, + types, + posIncrements, + posLengths, + finalOffset, + finalPosInc, + keywordAtts, + graphOffsetsAreCorrect, + payloads, + flags, + null); + } + public static void assertTokenStreamContents( TokenStream ts, String[] output, @@ -438,6 +483,33 @@ public static void 
assertTokenStreamContents( boolean[] keywordAtts, boolean graphOffsetsAreCorrect) throws IOException { + assertTokenStreamContents( + ts, + output, + startOffsets, + endOffsets, + types, + posIncrements, + posLengths, + finalOffset, + keywordAtts, + graphOffsetsAreCorrect, + null); + } + + public static void assertTokenStreamContents( + TokenStream ts, + String[] output, + int[] startOffsets, + int[] endOffsets, + String[] types, + int[] posIncrements, + int[] posLengths, + Integer finalOffset, + boolean[] keywordAtts, + boolean graphOffsetsAreCorrect, + float[] boost) + throws IOException { assertTokenStreamContents( ts, output, @@ -451,7 +523,8 @@ public static void assertTokenStreamContents( keywordAtts, graphOffsetsAreCorrect, null, - null); + null, + boost); } public static void assertTokenStreamContents( @@ -481,9 +554,36 @@ public static void assertTokenStreamContents( keywordAtts, graphOffsetsAreCorrect, payloads, + null, null); } + public static void assertTokenStreamContents( + TokenStream ts, + String[] output, + int[] startOffsets, + int[] endOffsets, + String[] types, + int[] posIncrements, + int[] posLengths, + Integer finalOffset, + boolean graphOffsetsAreCorrect, + float[] boost) + throws IOException { + assertTokenStreamContents( + ts, + output, + startOffsets, + endOffsets, + types, + posIncrements, + posLengths, + finalOffset, + null, + graphOffsetsAreCorrect, + boost); + } + public static void assertTokenStreamContents( TokenStream ts, String[] output, @@ -505,7 +605,8 @@ public static void assertTokenStreamContents( posLengths, finalOffset, null, - graphOffsetsAreCorrect); + graphOffsetsAreCorrect, + null); } public static void assertTokenStreamContents( @@ -522,6 +623,30 @@ public static void assertTokenStreamContents( ts, output, startOffsets, endOffsets, types, posIncrements, posLengths, finalOffset, true); } + public static void assertTokenStreamContents( + TokenStream ts, + String[] output, + int[] startOffsets, + int[] endOffsets, + 
String[] types, + int[] posIncrements, + int[] posLengths, + Integer finalOffset, + float[] boost) + throws IOException { + assertTokenStreamContents( + ts, + output, + startOffsets, + endOffsets, + types, + posIncrements, + posLengths, + finalOffset, + true, + boost); + } + public static void assertTokenStreamContents( TokenStream ts, String[] output, @@ -649,6 +774,21 @@ public static void assertAnalyzesTo( int[] posIncrements, int[] posLengths) throws IOException { + assertAnalyzesTo( + a, input, output, startOffsets, endOffsets, types, posIncrements, posLengths, null); + } + + public static void assertAnalyzesTo( + Analyzer a, + String input, + String[] output, + int[] startOffsets, + int[] endOffsets, + String[] types, + int[] posIncrements, + int[] posLengths, + float[] boost) + throws IOException { assertTokenStreamContents( a.tokenStream("dummy", input), output, @@ -657,7 +797,8 @@ public static void assertAnalyzesTo( types, posIncrements, posLengths, - input.length()); + input.length(), + boost); checkResetException(a, input); checkAnalysisConsistency(random(), a, true, input); }