diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt index cfa640dca018..a1d2ed19d44e 100644 --- a/lucene/CHANGES.txt +++ b/lucene/CHANGES.txt @@ -32,8 +32,6 @@ New Features crash the JVM. To disable this feature, pass the following sysprop on Java command line: "-Dorg.apache.lucene.store.MMapDirectory.enableMemorySegments=false" (Uwe Schindler) -* GITHUB#12169: Introduce a new token filter to expand synonyms based on Word2Vec DL4j models. (Daniele Antuzi, Ilaria Petreti, Alessandro Benedetti) - Improvements --------------------- diff --git a/lucene/analysis.tests/src/test/org/apache/lucene/analysis/tests/TestRandomChains.java b/lucene/analysis.tests/src/test/org/apache/lucene/analysis/tests/TestRandomChains.java index 988deaf99e59..8c245e7058c7 100644 --- a/lucene/analysis.tests/src/test/org/apache/lucene/analysis/tests/TestRandomChains.java +++ b/lucene/analysis.tests/src/test/org/apache/lucene/analysis/tests/TestRandomChains.java @@ -89,8 +89,6 @@ import org.apache.lucene.analysis.standard.StandardTokenizer; import org.apache.lucene.analysis.stempel.StempelStemmer; import org.apache.lucene.analysis.synonym.SynonymMap; -import org.apache.lucene.analysis.synonym.word2vec.Word2VecModel; -import org.apache.lucene.analysis.synonym.word2vec.Word2VecSynonymProvider; import org.apache.lucene.store.ByteBuffersDirectory; import org.apache.lucene.tests.analysis.BaseTokenStreamTestCase; import org.apache.lucene.tests.analysis.MockTokenFilter; @@ -101,10 +99,8 @@ import org.apache.lucene.tests.util.automaton.AutomatonTestUtil; import org.apache.lucene.util.AttributeFactory; import org.apache.lucene.util.AttributeSource; -import org.apache.lucene.util.BytesRef; import org.apache.lucene.util.CharsRef; import org.apache.lucene.util.IgnoreRandomChains; -import org.apache.lucene.util.TermAndVector; import org.apache.lucene.util.Version; import org.apache.lucene.util.automaton.Automaton; import org.apache.lucene.util.automaton.CharacterRunAutomaton; @@ -419,27 +415,6 @@ private String randomNonEmptyString(Random random) { } } }); - put( - Word2VecSynonymProvider.class, - random -> { - final int numEntries = atLeast(10); - final int vectorDimension = random.nextInt(99) + 1; - Word2VecModel model = new Word2VecModel(numEntries, vectorDimension); - for (int j = 0; j < numEntries; j++) { - String s = TestUtil.randomSimpleString(random, 10, 20); - float[] vec = new float[vectorDimension]; - for (int i = 0; i < vectorDimension; i++) { - vec[i] = random.nextFloat(); - } - model.addTermAndVector(new TermAndVector(new BytesRef(s), vec)); - } - try { - return new Word2VecSynonymProvider(model); - } catch (IOException e) { - Rethrow.rethrow(e); - return null; // unreachable code - } - }); put( DateFormat.class, random -> { diff --git a/lucene/analysis/common/src/java/module-info.java b/lucene/analysis/common/src/java/module-info.java index 15ad5a2b1af0..5679f0dde295 100644 --- a/lucene/analysis/common/src/java/module-info.java +++ b/lucene/analysis/common/src/java/module-info.java @@ -78,7 +78,6 @@ exports org.apache.lucene.analysis.sr; exports org.apache.lucene.analysis.sv; exports org.apache.lucene.analysis.synonym; - exports org.apache.lucene.analysis.synonym.word2vec; exports org.apache.lucene.analysis.ta; exports org.apache.lucene.analysis.te; exports org.apache.lucene.analysis.th; @@ -257,7 +256,6 @@ org.apache.lucene.analysis.sv.SwedishMinimalStemFilterFactory, org.apache.lucene.analysis.synonym.SynonymFilterFactory, org.apache.lucene.analysis.synonym.SynonymGraphFilterFactory, - org.apache.lucene.analysis.synonym.word2vec.Word2VecSynonymFilterFactory, org.apache.lucene.analysis.core.FlattenGraphFilterFactory, org.apache.lucene.analysis.te.TeluguNormalizationFilterFactory, org.apache.lucene.analysis.te.TeluguStemFilterFactory, diff --git a/lucene/analysis/common/src/java/org/apache/lucene/analysis/synonym/word2vec/Dl4jModelReader.java b/lucene/analysis/common/src/java/org/apache/lucene/analysis/synonym/word2vec/Dl4jModelReader.java deleted file mode 100644 index f022dd8eca67..000000000000 --- a/lucene/analysis/common/src/java/org/apache/lucene/analysis/synonym/word2vec/Dl4jModelReader.java +++ /dev/null @@ -1,126 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.lucene.analysis.synonym.word2vec; - -import java.io.BufferedInputStream; -import java.io.BufferedReader; -import java.io.Closeable; -import java.io.IOException; -import java.io.InputStream; -import java.io.InputStreamReader; -import java.nio.charset.StandardCharsets; -import java.util.Base64; -import java.util.Locale; -import java.util.zip.ZipEntry; -import java.util.zip.ZipInputStream; -import org.apache.lucene.util.BytesRef; -import org.apache.lucene.util.TermAndVector; - -/** - * Dl4jModelReader reads the file generated by the library Deeplearning4j and provide a - * Word2VecModel with normalized vectors - * - *

Dl4j Word2Vec documentation: - * https://deeplearning4j.konduit.ai/v/en-1.0.0-beta7/language-processing/word2vec Example to - * generate a model using dl4j: - * https://github.com/eclipse/deeplearning4j-examples/blob/master/dl4j-examples/src/main/java/org/deeplearning4j/examples/advanced/modelling/embeddingsfromcorpus/word2vec/Word2VecRawTextExample.java - * - * @lucene.experimental - */ -public class Dl4jModelReader implements Closeable { - - private static final String MODEL_FILE_NAME_PREFIX = "syn0"; - - private final ZipInputStream word2VecModelZipFile; - - public Dl4jModelReader(InputStream stream) { - this.word2VecModelZipFile = new ZipInputStream(new BufferedInputStream(stream)); - } - - public Word2VecModel read() throws IOException { - - ZipEntry entry; - while ((entry = word2VecModelZipFile.getNextEntry()) != null) { - String fileName = entry.getName(); - if (fileName.startsWith(MODEL_FILE_NAME_PREFIX)) { - BufferedReader reader = - new BufferedReader(new InputStreamReader(word2VecModelZipFile, StandardCharsets.UTF_8)); - - String header = reader.readLine(); - String[] headerValues = header.split(" "); - int dictionarySize = Integer.parseInt(headerValues[0]); - int vectorDimension = Integer.parseInt(headerValues[1]); - - Word2VecModel model = new Word2VecModel(dictionarySize, vectorDimension); - String line = reader.readLine(); - boolean isTermB64Encoded = false; - if (line != null) { - String[] tokens = line.split(" "); - isTermB64Encoded = - tokens[0].substring(0, 3).toLowerCase(Locale.ROOT).compareTo("b64") == 0; - model.addTermAndVector(extractTermAndVector(tokens, vectorDimension, isTermB64Encoded)); - } - while ((line = reader.readLine()) != null) { - String[] tokens = line.split(" "); - model.addTermAndVector(extractTermAndVector(tokens, vectorDimension, isTermB64Encoded)); - } - return model; - } - } - throw new IllegalArgumentException( - "Cannot read Dl4j word2vec model - '" - + MODEL_FILE_NAME_PREFIX - + "' file is missing in the zip. '" - + MODEL_FILE_NAME_PREFIX - + "' is a mandatory file containing the mapping between terms and vectors generated by the DL4j library."); - } - - private static TermAndVector extractTermAndVector( - String[] tokens, int vectorDimension, boolean isTermB64Encoded) { - BytesRef term = isTermB64Encoded ? decodeB64Term(tokens[0]) : new BytesRef((tokens[0])); - - float[] vector = new float[tokens.length - 1]; - - if (vectorDimension != vector.length) { - throw new RuntimeException( - String.format( - Locale.ROOT, - "Word2Vec model file corrupted. " - + "Declared vectors of size %d but found vector of size %d for word %s (%s)", - vectorDimension, - vector.length, - tokens[0], - term.utf8ToString())); - } - - for (int i = 1; i < tokens.length; i++) { - vector[i - 1] = Float.parseFloat(tokens[i]); - } - return new TermAndVector(term, vector); - } - - static BytesRef decodeB64Term(String term) { - byte[] buffer = Base64.getDecoder().decode(term.substring(4)); - return new BytesRef(buffer, 0, buffer.length); - } - - @Override - public void close() throws IOException { - word2VecModelZipFile.close(); - } -} diff --git a/lucene/analysis/common/src/java/org/apache/lucene/analysis/synonym/word2vec/Word2VecModel.java b/lucene/analysis/common/src/java/org/apache/lucene/analysis/synonym/word2vec/Word2VecModel.java deleted file mode 100644 index 6719639b67d9..000000000000 --- a/lucene/analysis/common/src/java/org/apache/lucene/analysis/synonym/word2vec/Word2VecModel.java +++ /dev/null @@ -1,95 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.lucene.analysis.synonym.word2vec; - -import java.io.IOException; -import org.apache.lucene.util.BytesRef; -import org.apache.lucene.util.BytesRefHash; -import org.apache.lucene.util.TermAndVector; -import org.apache.lucene.util.hnsw.RandomAccessVectorValues; - -/** - * Word2VecModel is a class representing the parsed Word2Vec model containing the vectors for each - * word in dictionary - * - * @lucene.experimental - */ -public class Word2VecModel implements RandomAccessVectorValues { - - private final int dictionarySize; - private final int vectorDimension; - private final TermAndVector[] termsAndVectors; - private final BytesRefHash word2Vec; - private int loadedCount = 0; - - public Word2VecModel(int dictionarySize, int vectorDimension) { - this.dictionarySize = dictionarySize; - this.vectorDimension = vectorDimension; - this.termsAndVectors = new TermAndVector[dictionarySize]; - this.word2Vec = new BytesRefHash(); - } - - private Word2VecModel( - int dictionarySize, - int vectorDimension, - TermAndVector[] termsAndVectors, - BytesRefHash word2Vec) { - this.dictionarySize = dictionarySize; - this.vectorDimension = vectorDimension; - this.termsAndVectors = termsAndVectors; - this.word2Vec = word2Vec; - } - - public void addTermAndVector(TermAndVector modelEntry) { - modelEntry.normalizeVector(); - this.termsAndVectors[loadedCount++] = modelEntry; - this.word2Vec.add(modelEntry.getTerm()); - } - - @Override - public float[] vectorValue(int targetOrd) { - return termsAndVectors[targetOrd].getVector(); - } - - public float[] vectorValue(BytesRef term) { - int termOrd = this.word2Vec.find(term); - if (termOrd < 0) return null; - TermAndVector entry = this.termsAndVectors[termOrd]; - return (entry == null) ? null : entry.getVector(); - } - - public BytesRef termValue(int targetOrd) { - return termsAndVectors[targetOrd].getTerm(); - } - - @Override - public int dimension() { - return vectorDimension; - } - - @Override - public int size() { - return dictionarySize; - } - - @Override - public RandomAccessVectorValues copy() throws IOException { - return new Word2VecModel( - this.dictionarySize, this.vectorDimension, this.termsAndVectors, this.word2Vec); - } -} diff --git a/lucene/analysis/common/src/java/org/apache/lucene/analysis/synonym/word2vec/Word2VecSynonymFilter.java b/lucene/analysis/common/src/java/org/apache/lucene/analysis/synonym/word2vec/Word2VecSynonymFilter.java deleted file mode 100644 index 29f9d37d5fdf..000000000000 --- a/lucene/analysis/common/src/java/org/apache/lucene/analysis/synonym/word2vec/Word2VecSynonymFilter.java +++ /dev/null @@ -1,108 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.lucene.analysis.synonym.word2vec; - -import java.io.IOException; -import java.util.LinkedList; -import java.util.List; -import org.apache.lucene.analysis.TokenFilter; -import org.apache.lucene.analysis.TokenStream; -import org.apache.lucene.analysis.synonym.SynonymGraphFilter; -import org.apache.lucene.analysis.tokenattributes.CharTermAttribute; -import org.apache.lucene.analysis.tokenattributes.PositionIncrementAttribute; -import org.apache.lucene.analysis.tokenattributes.PositionLengthAttribute; -import org.apache.lucene.analysis.tokenattributes.TypeAttribute; -import org.apache.lucene.util.BytesRef; -import org.apache.lucene.util.BytesRefBuilder; -import org.apache.lucene.util.TermAndBoost; - -/** - * Applies single-token synonyms from a Word2Vec trained network to an incoming {@link TokenStream}. - * - * @lucene.experimental - */ -public final class Word2VecSynonymFilter extends TokenFilter { - - private final CharTermAttribute termAtt = addAttribute(CharTermAttribute.class); - private final PositionIncrementAttribute posIncrementAtt = - addAttribute(PositionIncrementAttribute.class); - private final PositionLengthAttribute posLenAtt = addAttribute(PositionLengthAttribute.class); - private final TypeAttribute typeAtt = addAttribute(TypeAttribute.class); - - private final Word2VecSynonymProvider synonymProvider; - private final int maxSynonymsPerTerm; - private final float minAcceptedSimilarity; - private final LinkedList synonymBuffer = new LinkedList<>(); - private State lastState; - - /** - * Apply previously built synonymProvider to incoming tokens. - * - * @param input input tokenstream - * @param synonymProvider synonym provider - * @param maxSynonymsPerTerm maximum number of result returned by the synonym search - * @param minAcceptedSimilarity minimal value of cosine similarity between the searched vector and - * the retrieved ones - */ - public Word2VecSynonymFilter( - TokenStream input, - Word2VecSynonymProvider synonymProvider, - int maxSynonymsPerTerm, - float minAcceptedSimilarity) { - super(input); - this.synonymProvider = synonymProvider; - this.maxSynonymsPerTerm = maxSynonymsPerTerm; - this.minAcceptedSimilarity = minAcceptedSimilarity; - } - - @Override - public boolean incrementToken() throws IOException { - - if (!synonymBuffer.isEmpty()) { - TermAndBoost synonym = synonymBuffer.pollFirst(); - clearAttributes(); - restoreState(this.lastState); - termAtt.setEmpty(); - termAtt.append(synonym.term.utf8ToString()); - typeAtt.setType(SynonymGraphFilter.TYPE_SYNONYM); - posLenAtt.setPositionLength(1); - posIncrementAtt.setPositionIncrement(0); - return true; - } - - if (input.incrementToken()) { - BytesRefBuilder bytesRefBuilder = new BytesRefBuilder(); - bytesRefBuilder.copyChars(termAtt.buffer(), 0, termAtt.length()); - BytesRef term = bytesRefBuilder.get(); - List synonyms = - this.synonymProvider.getSynonyms(term, maxSynonymsPerTerm, minAcceptedSimilarity); - if (synonyms.size() > 0) { - this.lastState = captureState(); - this.synonymBuffer.addAll(synonyms); - } - return true; - } - return false; - } - - @Override - public void reset() throws IOException { - super.reset(); - synonymBuffer.clear(); - } -} diff --git a/lucene/analysis/common/src/java/org/apache/lucene/analysis/synonym/word2vec/Word2VecSynonymFilterFactory.java b/lucene/analysis/common/src/java/org/apache/lucene/analysis/synonym/word2vec/Word2VecSynonymFilterFactory.java deleted file mode 100644 index 32b6288926fc..000000000000 --- a/lucene/analysis/common/src/java/org/apache/lucene/analysis/synonym/word2vec/Word2VecSynonymFilterFactory.java +++ /dev/null @@ -1,101 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.lucene.analysis.synonym.word2vec; - -import java.io.IOException; -import java.util.Locale; -import java.util.Map; -import org.apache.lucene.analysis.TokenFilterFactory; -import org.apache.lucene.analysis.TokenStream; -import org.apache.lucene.analysis.synonym.word2vec.Word2VecSynonymProviderFactory.Word2VecSupportedFormats; -import org.apache.lucene.util.ResourceLoader; -import org.apache.lucene.util.ResourceLoaderAware; - -/** - * Factory for {@link Word2VecSynonymFilter}. - * - * @lucene.experimental - * @lucene.spi {@value #NAME} - */ -public class Word2VecSynonymFilterFactory extends TokenFilterFactory - implements ResourceLoaderAware { - - /** SPI name */ - public static final String NAME = "Word2VecSynonym"; - - public static final int DEFAULT_MAX_SYNONYMS_PER_TERM = 5; - public static final float DEFAULT_MIN_ACCEPTED_SIMILARITY = 0.8f; - - private final int maxSynonymsPerTerm; - private final float minAcceptedSimilarity; - private final Word2VecSupportedFormats format; - private final String word2vecModelFileName; - - private Word2VecSynonymProvider synonymProvider; - - public Word2VecSynonymFilterFactory(Map args) { - super(args); - this.maxSynonymsPerTerm = getInt(args, "maxSynonymsPerTerm", DEFAULT_MAX_SYNONYMS_PER_TERM); - this.minAcceptedSimilarity = - getFloat(args, "minAcceptedSimilarity", DEFAULT_MIN_ACCEPTED_SIMILARITY); - this.word2vecModelFileName = require(args, "model"); - - String modelFormat = get(args, "format", "dl4j").toUpperCase(Locale.ROOT); - try { - this.format = Word2VecSupportedFormats.valueOf(modelFormat); - } catch (IllegalArgumentException exc) { - throw new IllegalArgumentException("Model format '" + modelFormat + "' not supported", exc); - } - - if (!args.isEmpty()) { - throw new IllegalArgumentException("Unknown parameters: " + args); - } - if (minAcceptedSimilarity <= 0 || minAcceptedSimilarity > 1) { - throw new IllegalArgumentException( - "minAcceptedSimilarity must be in the range (0, 1]. Found: " + minAcceptedSimilarity); - } - if (maxSynonymsPerTerm <= 0) { - throw new IllegalArgumentException( - "maxSynonymsPerTerm must be a positive integer greater than 0. Found: " - + maxSynonymsPerTerm); - } - } - - /** Default ctor for compatibility with SPI */ - public Word2VecSynonymFilterFactory() { - throw defaultCtorException(); - } - - Word2VecSynonymProvider getSynonymProvider() { - return this.synonymProvider; - } - - @Override - public TokenStream create(TokenStream input) { - return synonymProvider == null - ? input - : new Word2VecSynonymFilter( - input, synonymProvider, maxSynonymsPerTerm, minAcceptedSimilarity); - } - - @Override - public void inform(ResourceLoader loader) throws IOException { - this.synonymProvider = - Word2VecSynonymProviderFactory.getSynonymProvider(loader, word2vecModelFileName, format); - } -} diff --git a/lucene/analysis/common/src/java/org/apache/lucene/analysis/synonym/word2vec/Word2VecSynonymProvider.java b/lucene/analysis/common/src/java/org/apache/lucene/analysis/synonym/word2vec/Word2VecSynonymProvider.java deleted file mode 100644 index 1d1f06fc991f..000000000000 --- a/lucene/analysis/common/src/java/org/apache/lucene/analysis/synonym/word2vec/Word2VecSynonymProvider.java +++ /dev/null @@ -1,104 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.lucene.analysis.synonym.word2vec; - -import static org.apache.lucene.util.hnsw.HnswGraphBuilder.DEFAULT_BEAM_WIDTH; -import static org.apache.lucene.util.hnsw.HnswGraphBuilder.DEFAULT_MAX_CONN; - -import java.io.IOException; -import java.util.LinkedList; -import java.util.List; -import org.apache.lucene.index.VectorEncoding; -import org.apache.lucene.index.VectorSimilarityFunction; -import org.apache.lucene.util.BytesRef; -import org.apache.lucene.util.TermAndBoost; -import org.apache.lucene.util.hnsw.HnswGraph; -import org.apache.lucene.util.hnsw.HnswGraphBuilder; -import org.apache.lucene.util.hnsw.HnswGraphSearcher; -import org.apache.lucene.util.hnsw.NeighborQueue; - -/** - * The Word2VecSynonymProvider generates the list of sysnonyms of a term. - * - * @lucene.experimental - */ -public class Word2VecSynonymProvider { - - private static final VectorSimilarityFunction SIMILARITY_FUNCTION = - VectorSimilarityFunction.DOT_PRODUCT; - private static final VectorEncoding VECTOR_ENCODING = VectorEncoding.FLOAT32; - private final Word2VecModel word2VecModel; - private final HnswGraph hnswGraph; - - /** - * Word2VecSynonymProvider constructor - * - * @param model containing the set of TermAndVector entries - */ - public Word2VecSynonymProvider(Word2VecModel model) throws IOException { - word2VecModel = model; - - HnswGraphBuilder builder = - HnswGraphBuilder.create( - word2VecModel, - VECTOR_ENCODING, - SIMILARITY_FUNCTION, - DEFAULT_MAX_CONN, - DEFAULT_BEAM_WIDTH, - HnswGraphBuilder.randSeed); - this.hnswGraph = builder.build(word2VecModel.copy()); - } - - public List getSynonyms( - BytesRef term, int maxSynonymsPerTerm, float minAcceptedSimilarity) throws IOException { - - if (term == null) { - throw new IllegalArgumentException("Term must not be null"); - } - - LinkedList result = new LinkedList<>(); - float[] query = word2VecModel.vectorValue(term); - if (query != null) { - NeighborQueue synonyms = - HnswGraphSearcher.search( - query, - // The query vector is in the model. When looking for the top-k - // it's always the nearest neighbour of itself so, we look for the top-k+1 - maxSynonymsPerTerm + 1, - word2VecModel, - VECTOR_ENCODING, - SIMILARITY_FUNCTION, - hnswGraph, - null, - word2VecModel.size()); - - int size = synonyms.size(); - for (int i = 0; i < size; i++) { - float similarity = synonyms.topScore(); - int id = synonyms.pop(); - - BytesRef synonym = word2VecModel.termValue(id); - // We remove the original query term - if (!synonym.equals(term) && similarity >= minAcceptedSimilarity) { - result.addFirst(new TermAndBoost(synonym, similarity)); - } - } - } - return result; - } -} diff --git a/lucene/analysis/common/src/java/org/apache/lucene/analysis/synonym/word2vec/Word2VecSynonymProviderFactory.java b/lucene/analysis/common/src/java/org/apache/lucene/analysis/synonym/word2vec/Word2VecSynonymProviderFactory.java deleted file mode 100644 index ea849e653cd6..000000000000 --- a/lucene/analysis/common/src/java/org/apache/lucene/analysis/synonym/word2vec/Word2VecSynonymProviderFactory.java +++ /dev/null @@ -1,63 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.lucene.analysis.synonym.word2vec; - -import java.io.IOException; -import java.io.InputStream; -import java.util.Map; -import java.util.concurrent.ConcurrentHashMap; -import org.apache.lucene.util.ResourceLoader; - -/** - * Supply Word2Vec Word2VecSynonymProvider cache avoiding that multiple instances of - * Word2VecSynonymFilterFactory will instantiate multiple instances of the same SynonymProvider. - * Assumes synonymProvider implementations are thread-safe. - */ -public class Word2VecSynonymProviderFactory { - - enum Word2VecSupportedFormats { - DL4J - } - - private static Map word2vecSynonymProviders = - new ConcurrentHashMap<>(); - - public static Word2VecSynonymProvider getSynonymProvider( - ResourceLoader loader, String modelFileName, Word2VecSupportedFormats format) - throws IOException { - Word2VecSynonymProvider synonymProvider = word2vecSynonymProviders.get(modelFileName); - if (synonymProvider == null) { - try (InputStream stream = loader.openResource(modelFileName)) { - try (Dl4jModelReader reader = getModelReader(format, stream)) { - synonymProvider = new Word2VecSynonymProvider(reader.read()); - } - } - word2vecSynonymProviders.put(modelFileName, synonymProvider); - } - return synonymProvider; - } - - private static Dl4jModelReader getModelReader( - Word2VecSupportedFormats format, InputStream stream) { - switch (format) { - case DL4J: - return new Dl4jModelReader(stream); - } - return null; - } -} diff --git a/lucene/analysis/common/src/java/org/apache/lucene/analysis/synonym/word2vec/package-info.java b/lucene/analysis/common/src/java/org/apache/lucene/analysis/synonym/word2vec/package-info.java deleted file mode 100644 index e8d69ab3cf9b..000000000000 --- a/lucene/analysis/common/src/java/org/apache/lucene/analysis/synonym/word2vec/package-info.java +++ /dev/null @@ -1,19 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/** Analysis components for Synonyms using Word2Vec model. */ -package org.apache.lucene.analysis.synonym.word2vec; diff --git a/lucene/analysis/common/src/resources/META-INF/services/org.apache.lucene.analysis.TokenFilterFactory b/lucene/analysis/common/src/resources/META-INF/services/org.apache.lucene.analysis.TokenFilterFactory index 1e4e17eaeadf..19a34b7840a8 100644 --- a/lucene/analysis/common/src/resources/META-INF/services/org.apache.lucene.analysis.TokenFilterFactory +++ b/lucene/analysis/common/src/resources/META-INF/services/org.apache.lucene.analysis.TokenFilterFactory @@ -118,7 +118,6 @@ org.apache.lucene.analysis.sv.SwedishLightStemFilterFactory org.apache.lucene.analysis.sv.SwedishMinimalStemFilterFactory org.apache.lucene.analysis.synonym.SynonymFilterFactory org.apache.lucene.analysis.synonym.SynonymGraphFilterFactory -org.apache.lucene.analysis.synonym.word2vec.Word2VecSynonymFilterFactory org.apache.lucene.analysis.core.FlattenGraphFilterFactory org.apache.lucene.analysis.te.TeluguNormalizationFilterFactory org.apache.lucene.analysis.te.TeluguStemFilterFactory diff --git a/lucene/analysis/common/src/test/org/apache/lucene/analysis/synonym/word2vec/TestDl4jModelReader.java b/lucene/analysis/common/src/test/org/apache/lucene/analysis/synonym/word2vec/TestDl4jModelReader.java deleted file mode 100644 index 213dcdaccd33..000000000000 --- a/lucene/analysis/common/src/test/org/apache/lucene/analysis/synonym/word2vec/TestDl4jModelReader.java +++ /dev/null @@ -1,98 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.lucene.analysis.synonym.word2vec; - -import java.io.InputStream; -import java.nio.charset.StandardCharsets; -import java.util.Base64; -import org.apache.lucene.tests.util.LuceneTestCase; -import org.apache.lucene.util.BytesRef; -import org.junit.Test; - -public class TestDl4jModelReader extends LuceneTestCase { - - private static final String MODEL_FILE = "word2vec-model.zip"; - private static final String MODEL_EMPTY_FILE = "word2vec-empty-model.zip"; - private static final String CORRUPTED_VECTOR_DIMENSION_MODEL_FILE = - "word2vec-corrupted-vector-dimension-model.zip"; - - InputStream stream = TestDl4jModelReader.class.getResourceAsStream(MODEL_FILE); - Dl4jModelReader unit = new Dl4jModelReader(stream); - - @Test - public void read_zipFileWithMetadata_shouldReturnDictionarySize() throws Exception { - Word2VecModel model = unit.read(); - long expectedDictionarySize = 235; - assertEquals(expectedDictionarySize, model.size()); - } - - @Test - public void read_zipFileWithMetadata_shouldReturnVectorLength() throws Exception { - Word2VecModel model = unit.read(); - int expectedVectorDimension = 100; - assertEquals(expectedVectorDimension, model.dimension()); - } - - @Test - public void read_zipFile_shouldReturnDecodedTerm() throws Exception { - Word2VecModel model = unit.read(); - BytesRef expectedDecodedFirstTerm = new BytesRef("it"); - assertEquals(expectedDecodedFirstTerm, model.termValue(0)); - } - - @Test - public void decodeTerm_encodedTerm_shouldReturnDecodedTerm() throws Exception { - byte[] originalInput = "lucene".getBytes(StandardCharsets.UTF_8); - String B64encodedLuceneTerm = Base64.getEncoder().encodeToString(originalInput); - String word2vecEncodedLuceneTerm = "B64:" + B64encodedLuceneTerm; - assertEquals(new BytesRef("lucene"), Dl4jModelReader.decodeB64Term(word2vecEncodedLuceneTerm)); - } - - @Test - public void read_EmptyZipFile_shouldThrowException() throws Exception { - try (InputStream stream = TestDl4jModelReader.class.getResourceAsStream(MODEL_EMPTY_FILE)) { - Dl4jModelReader unit = new Dl4jModelReader(stream); - expectThrows(IllegalArgumentException.class, unit::read); - } - } - - @Test - public void read_corruptedVectorDimensionModelFile_shouldThrowException() throws Exception { - try (InputStream stream = - TestDl4jModelReader.class.getResourceAsStream(CORRUPTED_VECTOR_DIMENSION_MODEL_FILE)) { - Dl4jModelReader unit = new Dl4jModelReader(stream); - expectThrows(RuntimeException.class, unit::read); - } - } -} diff --git a/lucene/analysis/common/src/test/org/apache/lucene/analysis/synonym/word2vec/TestWord2VecSynonymFilter.java b/lucene/analysis/common/src/test/org/apache/lucene/analysis/synonym/word2vec/TestWord2VecSynonymFilter.java deleted file mode 100644 index 3999931dd758..000000000000 --- a/lucene/analysis/common/src/test/org/apache/lucene/analysis/synonym/word2vec/TestWord2VecSynonymFilter.java +++ /dev/null @@ -1,152 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.lucene.analysis.synonym.word2vec; - -import org.apache.lucene.analysis.Analyzer; -import org.apache.lucene.analysis.Tokenizer; -import org.apache.lucene.tests.analysis.BaseTokenStreamTestCase; -import org.apache.lucene.tests.analysis.MockTokenizer; -import org.apache.lucene.util.BytesRef; -import org.apache.lucene.util.TermAndVector; -import org.junit.Test; - -public class TestWord2VecSynonymFilter extends BaseTokenStreamTestCase { - - @Test - public void synonymExpansion_oneCandidate_shouldBeExpandedWithinThreshold() throws Exception { - int maxSynonymPerTerm = 10; - float minAcceptedSimilarity = 0.9f; - Word2VecModel model = new Word2VecModel(6, 2); - model.addTermAndVector(new TermAndVector(new BytesRef("a"), new float[] {10, 10})); - model.addTermAndVector(new TermAndVector(new BytesRef("b"), new float[] {10, 8})); - model.addTermAndVector(new TermAndVector(new BytesRef("c"), new float[] {9, 10})); - model.addTermAndVector(new TermAndVector(new BytesRef("d"), new float[] {1, 1})); - model.addTermAndVector(new TermAndVector(new BytesRef("e"), new float[] {99, 101})); - model.addTermAndVector(new TermAndVector(new BytesRef("f"), new float[] {-1, 10})); - - Word2VecSynonymProvider synonymProvider = new Word2VecSynonymProvider(model); - - Analyzer a = getAnalyzer(synonymProvider, maxSynonymPerTerm, minAcceptedSimilarity); - assertAnalyzesTo( - a, - "pre a post", // input - new String[] {"pre", "a", "d", "e", "c", "b", "post"}, // output - new int[] {0, 4, 4, 4, 4, 4, 6}, // start offset - new int[] {3, 5, 5, 5, 5, 5, 10}, // end offset - new String[] {"word", "word", "SYNONYM", "SYNONYM", "SYNONYM", "SYNONYM", "word"}, // types - new int[] {1, 1, 0, 0, 0, 0, 1}, // posIncrements - new int[] {1, 1, 1, 1, 1, 1, 1}); // posLenghts - a.close(); - } - - @Test - public void synonymExpansion_oneCandidate_shouldBeExpandedWithTopKSynonyms() throws Exception { - int maxSynonymPerTerm = 2; - float minAcceptedSimilarity = 0.9f; - Word2VecModel model = new Word2VecModel(5, 2); - model.addTermAndVector(new TermAndVector(new BytesRef("a"), new float[] {10, 10})); - model.addTermAndVector(new TermAndVector(new BytesRef("b"), new float[] {10, 8})); - model.addTermAndVector(new TermAndVector(new BytesRef("c"), new float[] {9, 10})); - model.addTermAndVector(new TermAndVector(new BytesRef("d"), new float[] {1, 1})); - model.addTermAndVector(new TermAndVector(new BytesRef("e"), new float[] {99, 101})); - - Word2VecSynonymProvider synonymProvider = new Word2VecSynonymProvider(model); - - Analyzer a = getAnalyzer(synonymProvider, maxSynonymPerTerm, minAcceptedSimilarity); - assertAnalyzesTo( - a, - "pre a post", // input - new String[] {"pre", "a", "d", "e", "post"}, // output - new int[] {0, 4, 4, 4, 6}, // start offset - new int[] {3, 5, 5, 5, 10}, // end offset - new String[] {"word", "word", "SYNONYM", "SYNONYM", "word"}, // types - new int[] {1, 1, 0, 0, 1}, // posIncrements - new int[] {1, 1, 1, 1, 1}); // posLenghts - a.close(); - } - - @Test - public void synonymExpansion_twoCandidates_shouldBothBeExpanded() throws Exception { - Word2VecModel model = new Word2VecModel(8, 2); - model.addTermAndVector(new TermAndVector(new BytesRef("a"), new float[] {10, 10})); - model.addTermAndVector(new TermAndVector(new BytesRef("b"), new float[] {10, 8})); - model.addTermAndVector(new TermAndVector(new BytesRef("c"), new float[] {9, 10})); - model.addTermAndVector(new TermAndVector(new BytesRef("d"), new float[] {1, 1})); - model.addTermAndVector(new TermAndVector(new BytesRef("e"), new float[] {99, 101})); - model.addTermAndVector(new TermAndVector(new BytesRef("f"), new float[] {1, 10})); - model.addTermAndVector(new TermAndVector(new BytesRef("post"), new float[] {-10, -11})); - model.addTermAndVector(new TermAndVector(new BytesRef("after"), new float[] {-8, -10})); - - Word2VecSynonymProvider synonymProvider = new Word2VecSynonymProvider(model); - - Analyzer a = getAnalyzer(synonymProvider, 10, 0.9f); - assertAnalyzesTo( - a, - "pre a post", // input - new String[] {"pre", "a", "d", "e", "c", "b", "post", "after"}, // output - new int[] {0, 4, 4, 4, 4, 4, 6, 6}, // start offset - new int[] {3, 5, 5, 5, 5, 5, 10, 10}, // end offset - new String[] { // types - "word", "word", "SYNONYM", "SYNONYM", "SYNONYM", "SYNONYM", "word", "SYNONYM" - }, - new int[] {1, 1, 0, 0, 0, 0, 1, 0}, // posIncrements - new int[] {1, 1, 1, 1, 1, 1, 1, 1}); // posLengths - a.close(); - } - - @Test - public void synonymExpansion_forMinAcceptedSimilarity_shouldExpandToNoneSynonyms() - throws Exception { - Word2VecModel model = new Word2VecModel(4, 2); - model.addTermAndVector(new TermAndVector(new BytesRef("a"), new float[] {10, 10})); - model.addTermAndVector(new TermAndVector(new BytesRef("b"), new float[] {-10, -8})); - model.addTermAndVector(new TermAndVector(new BytesRef("c"), new float[] {-9, -10})); - model.addTermAndVector(new TermAndVector(new BytesRef("f"), new float[] {-1, -10})); - - Word2VecSynonymProvider synonymProvider = new Word2VecSynonymProvider(model); - - Analyzer a = getAnalyzer(synonymProvider, 10, 0.8f); - assertAnalyzesTo( - a, - "pre a post", // input - new String[] {"pre", "a", "post"}, // output - new int[] {0, 4, 6}, // start offset - new int[] {3, 5, 10}, // end offset - new String[] {"word", "word", "word"}, // types - new int[] {1, 1, 1}, // posIncrements - new int[] {1, 1, 1}); // posLengths - a.close(); - } - - private Analyzer getAnalyzer( - Word2VecSynonymProvider synonymProvider, - int maxSynonymsPerTerm, - float minAcceptedSimilarity) { - return new Analyzer() { - @Override - protected TokenStreamComponents createComponents(String fieldName) { - Tokenizer tokenizer = new MockTokenizer(MockTokenizer.WHITESPACE, false); - // Make a local variable so testRandomHuge doesn't share it across threads! - Word2VecSynonymFilter synFilter = - new Word2VecSynonymFilter( - tokenizer, synonymProvider, maxSynonymsPerTerm, minAcceptedSimilarity); - return new TokenStreamComponents(tokenizer, synFilter); - } - }; - } -} diff --git a/lucene/analysis/common/src/test/org/apache/lucene/analysis/synonym/word2vec/TestWord2VecSynonymFilterFactory.java b/lucene/analysis/common/src/test/org/apache/lucene/analysis/synonym/word2vec/TestWord2VecSynonymFilterFactory.java deleted file mode 100644 index 007fedf4abed..000000000000 --- a/lucene/analysis/common/src/test/org/apache/lucene/analysis/synonym/word2vec/TestWord2VecSynonymFilterFactory.java +++ /dev/null @@ -1,159 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.lucene.analysis.synonym.word2vec; - -import org.apache.lucene.tests.analysis.BaseTokenStreamFactoryTestCase; -import org.apache.lucene.util.ClasspathResourceLoader; -import org.apache.lucene.util.ResourceLoader; -import org.junit.Test; - -public class TestWord2VecSynonymFilterFactory extends BaseTokenStreamFactoryTestCase { - - public static final String FACTORY_NAME = "Word2VecSynonym"; - private static final String WORD2VEC_MODEL_FILE = "word2vec-model.zip"; - - @Test - public void testInform() throws Exception { - ResourceLoader loader = new ClasspathResourceLoader(getClass()); - assertTrue("loader is null and it shouldn't be", loader != null); - Word2VecSynonymFilterFactory factory = - (Word2VecSynonymFilterFactory) - tokenFilterFactory( - FACTORY_NAME, "model", WORD2VEC_MODEL_FILE, "minAcceptedSimilarity", "0.7"); - - Word2VecSynonymProvider synonymProvider = factory.getSynonymProvider(); - assertNotEquals(null, synonymProvider); - } - - @Test - public void missingRequiredArgument_shouldThrowException() throws Exception { - IllegalArgumentException expected = - expectThrows( - IllegalArgumentException.class, - () -> { - tokenFilterFactory( - FACTORY_NAME, - "format", - "dl4j", - "minAcceptedSimilarity", - "0.7", - "maxSynonymsPerTerm", - "10"); - }); - assertTrue(expected.getMessage().contains("Configuration Error: missing parameter 'model'")); - } - - @Test - public void unsupportedModelFormat_shouldThrowException() throws Exception { - IllegalArgumentException expected = - expectThrows( - IllegalArgumentException.class, - () -> { - tokenFilterFactory( - FACTORY_NAME, "model", WORD2VEC_MODEL_FILE, "format", "bogusValue"); - }); - assertTrue(expected.getMessage().contains("Model format 'BOGUSVALUE' not supported")); - } - - @Test - public void bogusArgument_shouldThrowException() throws Exception { - IllegalArgumentException expected = - expectThrows( - IllegalArgumentException.class, - () -> { - tokenFilterFactory( - FACTORY_NAME, "model", WORD2VEC_MODEL_FILE, "bogusArg", "bogusValue"); - }); - assertTrue(expected.getMessage().contains("Unknown parameters")); - } - - @Test - public void illegalArguments_shouldThrowException() throws Exception { - IllegalArgumentException expected = - expectThrows( - IllegalArgumentException.class, - () -> { - tokenFilterFactory( - FACTORY_NAME, - "model", - WORD2VEC_MODEL_FILE, - "minAcceptedSimilarity", - "2", - "maxSynonymsPerTerm", - "10"); - }); - assertTrue( - expected - .getMessage() - .contains("minAcceptedSimilarity must be in the range (0, 1]. Found: 2")); - - expected = - expectThrows( - IllegalArgumentException.class, - () -> { - tokenFilterFactory( - FACTORY_NAME, - "model", - WORD2VEC_MODEL_FILE, - "minAcceptedSimilarity", - "0", - "maxSynonymsPerTerm", - "10"); - }); - assertTrue( - expected - .getMessage() - .contains("minAcceptedSimilarity must be in the range (0, 1]. Found: 0")); - - expected = - expectThrows( - IllegalArgumentException.class, - () -> { - tokenFilterFactory( - FACTORY_NAME, - "model", - WORD2VEC_MODEL_FILE, - "minAcceptedSimilarity", - "0.7", - "maxSynonymsPerTerm", - "-1"); - }); - assertTrue( - expected - .getMessage() - .contains("maxSynonymsPerTerm must be a positive integer greater than 0. Found: -1")); - - expected = - expectThrows( - IllegalArgumentException.class, - () -> { - tokenFilterFactory( - FACTORY_NAME, - "model", - WORD2VEC_MODEL_FILE, - "minAcceptedSimilarity", - "0.7", - "maxSynonymsPerTerm", - "0"); - }); - assertTrue( - expected - .getMessage() - .contains("maxSynonymsPerTerm must be a positive integer greater than 0. Found: 0")); - } -} diff --git a/lucene/analysis/common/src/test/org/apache/lucene/analysis/synonym/word2vec/TestWord2VecSynonymProvider.java b/lucene/analysis/common/src/test/org/apache/lucene/analysis/synonym/word2vec/TestWord2VecSynonymProvider.java deleted file mode 100644 index 0803406bdafa..000000000000 --- a/lucene/analysis/common/src/test/org/apache/lucene/analysis/synonym/word2vec/TestWord2VecSynonymProvider.java +++ /dev/null @@ -1,132 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.lucene.analysis.synonym.word2vec; - -import java.io.IOException; -import java.util.List; -import org.apache.lucene.tests.util.LuceneTestCase; -import org.apache.lucene.util.BytesRef; -import org.apache.lucene.util.TermAndBoost; -import org.apache.lucene.util.TermAndVector; -import org.junit.Test; - -public class TestWord2VecSynonymProvider extends LuceneTestCase { - - private static final int MAX_SYNONYMS_PER_TERM = 10; - private static final float MIN_ACCEPTED_SIMILARITY = 0.85f; - - private final Word2VecSynonymProvider unit; - - public TestWord2VecSynonymProvider() throws IOException { - Word2VecModel model = new Word2VecModel(2, 3); - model.addTermAndVector(new TermAndVector(new BytesRef("a"), new float[] {0.24f, 0.78f, 0.28f})); - model.addTermAndVector(new TermAndVector(new BytesRef("b"), new float[] {0.44f, 0.01f, 0.81f})); - unit = new Word2VecSynonymProvider(model); - } - - @Test - public void getSynonyms_nullToken_shouldThrowException() { - expectThrows( - IllegalArgumentException.class, - () -> unit.getSynonyms(null, MAX_SYNONYMS_PER_TERM, MIN_ACCEPTED_SIMILARITY)); - } - - @Test - public void getSynonyms_shouldReturnSynonymsBasedOnMinAcceptedSimilarity() throws Exception { - Word2VecModel model = new Word2VecModel(6, 2); - model.addTermAndVector(new TermAndVector(new BytesRef("a"), new float[] {10, 10})); - model.addTermAndVector(new TermAndVector(new BytesRef("b"), new float[] {10, 8})); - model.addTermAndVector(new TermAndVector(new BytesRef("c"), new float[] {9, 10})); - model.addTermAndVector(new TermAndVector(new BytesRef("d"), new float[] {1, 1})); - model.addTermAndVector(new TermAndVector(new BytesRef("e"), new float[] {99, 101})); - model.addTermAndVector(new TermAndVector(new BytesRef("f"), new float[] {-1, 10})); - - Word2VecSynonymProvider unit = new Word2VecSynonymProvider(model); - - BytesRef inputTerm = new BytesRef("a"); - String[] expectedSynonyms = {"d", "e", "c", "b"}; - List actualSynonymsResults = - unit.getSynonyms(inputTerm, MAX_SYNONYMS_PER_TERM, MIN_ACCEPTED_SIMILARITY); - - assertEquals(4, actualSynonymsResults.size()); - for (int i = 0; i < expectedSynonyms.length; i++) { - assertEquals(new BytesRef(expectedSynonyms[i]), actualSynonymsResults.get(i).term); - } - } - - @Test - public void getSynonyms_shouldReturnSynonymsBoost() throws Exception { - Word2VecModel model = new Word2VecModel(3, 2); - model.addTermAndVector(new TermAndVector(new BytesRef("a"), new float[] {10, 10})); - model.addTermAndVector(new TermAndVector(new BytesRef("b"), new float[] {1, 1})); - model.addTermAndVector(new TermAndVector(new BytesRef("c"), new float[] {99, 101})); - - Word2VecSynonymProvider unit = new Word2VecSynonymProvider(model); - - BytesRef inputTerm = new BytesRef("a"); - List actualSynonymsResults = - unit.getSynonyms(inputTerm, MAX_SYNONYMS_PER_TERM, MIN_ACCEPTED_SIMILARITY); - - BytesRef expectedFirstSynonymTerm = new BytesRef("b"); - double expectedFirstSynonymBoost = 1.0; - assertEquals(expectedFirstSynonymTerm, actualSynonymsResults.get(0).term); - assertEquals(expectedFirstSynonymBoost, actualSynonymsResults.get(0).boost, 0.001f); - } - - @Test - public void noSynonymsWithinAcceptedSimilarity_shouldReturnNoSynonyms() throws Exception { - Word2VecModel model = new Word2VecModel(4, 2); - model.addTermAndVector(new TermAndVector(new BytesRef("a"), new float[] {10, 10})); - model.addTermAndVector(new TermAndVector(new BytesRef("b"), new float[] {-10, -8})); - model.addTermAndVector(new TermAndVector(new BytesRef("c"), new float[] {-9, -10})); - model.addTermAndVector(new TermAndVector(new BytesRef("d"), new float[] {6, -6})); - - Word2VecSynonymProvider unit = new Word2VecSynonymProvider(model); - - BytesRef inputTerm = newBytesRef("a"); - List actualSynonymsResults = - unit.getSynonyms(inputTerm, MAX_SYNONYMS_PER_TERM, MIN_ACCEPTED_SIMILARITY); - assertEquals(0, actualSynonymsResults.size()); - } - - @Test - public void testModel_shouldReturnNormalizedVectors() { - Word2VecModel model = new Word2VecModel(4, 2); - model.addTermAndVector(new TermAndVector(new BytesRef("a"), new float[] {10, 10})); - model.addTermAndVector(new TermAndVector(new BytesRef("b"), new float[] {10, 8})); - model.addTermAndVector(new TermAndVector(new BytesRef("c"), new float[] {9, 10})); - model.addTermAndVector(new TermAndVector(new BytesRef("f"), new float[] {-1, 10})); - - float[] vectorIdA = model.vectorValue(new BytesRef("a")); - float[] vectorIdF = model.vectorValue(new BytesRef("f")); - assertArrayEquals(new float[] {0.70710f, 0.70710f}, vectorIdA, 0.001f); - assertArrayEquals(new float[] {-0.0995f, 0.99503f}, vectorIdF, 0.001f); - } - - @Test - public void normalizedVector_shouldReturnModule1() { - TermAndVector synonymTerm = new TermAndVector(new BytesRef("a"), new float[] {10, 10}); - synonymTerm.normalizeVector(); - float[] vector = synonymTerm.getVector(); - float len = 0; - for (int i = 0; i < vector.length; i++) { - len += vector[i] * vector[i]; - } - assertEquals(1, Math.sqrt(len), 0.0001f); - } -} diff --git a/lucene/analysis/common/src/test/org/apache/lucene/analysis/synonym/word2vec/word2vec-corrupted-vector-dimension-model.zip b/lucene/analysis/common/src/test/org/apache/lucene/analysis/synonym/word2vec/word2vec-corrupted-vector-dimension-model.zip deleted file mode 100644 index e25693dd83cf..000000000000 Binary files a/lucene/analysis/common/src/test/org/apache/lucene/analysis/synonym/word2vec/word2vec-corrupted-vector-dimension-model.zip and /dev/null differ diff --git a/lucene/analysis/common/src/test/org/apache/lucene/analysis/synonym/word2vec/word2vec-empty-model.zip b/lucene/analysis/common/src/test/org/apache/lucene/analysis/synonym/word2vec/word2vec-empty-model.zip deleted file mode 100644 index 57d7832dd787..000000000000 Binary files a/lucene/analysis/common/src/test/org/apache/lucene/analysis/synonym/word2vec/word2vec-empty-model.zip and /dev/null differ diff --git a/lucene/analysis/common/src/test/org/apache/lucene/analysis/synonym/word2vec/word2vec-model.zip b/lucene/analysis/common/src/test/org/apache/lucene/analysis/synonym/word2vec/word2vec-model.zip deleted file mode 100644 index 6d31b8d5a3fa..000000000000 Binary files a/lucene/analysis/common/src/test/org/apache/lucene/analysis/synonym/word2vec/word2vec-model.zip and /dev/null differ diff --git a/lucene/core/src/java/org/apache/lucene/util/QueryBuilder.java b/lucene/core/src/java/org/apache/lucene/util/QueryBuilder.java index adf30baec294..3ad17dca5ce9 100644 --- a/lucene/core/src/java/org/apache/lucene/util/QueryBuilder.java +++ b/lucene/core/src/java/org/apache/lucene/util/QueryBuilder.java @@ -62,6 +62,20 @@ public class QueryBuilder { protected boolean enableGraphQueries = true; protected boolean autoGenerateMultiTermSynonymsPhraseQuery = false; + /** Wraps a term and boost */ + public static class TermAndBoost { + /** the term */ + public final BytesRef term; + /** the boost */ + public final float boost; + + /** Creates a new TermAndBoost */ + public TermAndBoost(BytesRef term, float boost) { + this.term = BytesRef.deepCopyOf(term); + this.boost = boost; + } + } + /** Creates a new QueryBuilder using the given analyzer. */ public QueryBuilder(Analyzer analyzer) { this.analyzer = analyzer; diff --git a/lucene/core/src/java/org/apache/lucene/util/TermAndBoost.java b/lucene/core/src/java/org/apache/lucene/util/TermAndBoost.java deleted file mode 100644 index 0c7958a99a08..000000000000 --- a/lucene/core/src/java/org/apache/lucene/util/TermAndBoost.java +++ /dev/null @@ -1,31 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.lucene.util; - -/** Wraps a term and boost */ -public class TermAndBoost { - /** the term */ - public final BytesRef term; - /** the boost */ - public final float boost; - - /** Creates a new TermAndBoost */ - public TermAndBoost(BytesRef term, float boost) { - this.term = BytesRef.deepCopyOf(term); - this.boost = boost; - } -} diff --git a/lucene/core/src/java/org/apache/lucene/util/TermAndVector.java b/lucene/core/src/java/org/apache/lucene/util/TermAndVector.java deleted file mode 100644 index 1ade19a19803..000000000000 --- a/lucene/core/src/java/org/apache/lucene/util/TermAndVector.java +++ /dev/null @@ -1,72 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.lucene.util; - -import java.util.Locale; - -/** - * Word2Vec unit composed by a term with the associated vector - * - * @lucene.experimental - */ -public class TermAndVector { - - private final BytesRef term; - private final float[] vector; - - public TermAndVector(BytesRef term, float[] vector) { - this.term = term; - this.vector = vector; - } - - public BytesRef getTerm() { - return this.term; - } - - public float[] getVector() { - return this.vector; - } - - public int size() { - return vector.length; - } - - public void normalizeVector() { - float vectorLength = 0; - for (int i = 0; i < vector.length; i++) { - vectorLength += vector[i] * vector[i]; - } - vectorLength = (float) Math.sqrt(vectorLength); - for (int i = 0; i < vector.length; i++) { - vector[i] /= vectorLength; - } - } - - @Override - public String toString() { - StringBuilder builder = new StringBuilder(this.term.utf8ToString()); - builder.append(" ["); - if (vector.length > 0) { - for (int i = 0; i < vector.length - 1; i++) { - builder.append(String.format(Locale.ROOT, "%.3f,", vector[i])); - } - builder.append(String.format(Locale.ROOT, "%.3f]", vector[vector.length - 1])); - } - return builder.toString(); - } -} diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java index b01be4dbaf01..2c5e84be2859 100644 --- a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java +++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java @@ -41,17 +41,8 @@ */ public final class HnswGraphBuilder { - /** Default number of maximum connections per node */ - public static final int DEFAULT_MAX_CONN = 16; - - /** - * Default number of the size of the queue maintained while searching during a graph construction. - */ - public static final int DEFAULT_BEAM_WIDTH = 100; - /** Default random seed for level generation * */ private static final long DEFAULT_RAND_SEED = 42; - /** A name for the HNSW component for the info-stream * */ public static final String HNSW_COMPONENT = "HNSW"; diff --git a/lucene/core/src/test/org/apache/lucene/index/TestKnnGraph.java b/lucene/core/src/test/org/apache/lucene/index/TestKnnGraph.java index 372384df5572..08f089430ba5 100644 --- a/lucene/core/src/test/org/apache/lucene/index/TestKnnGraph.java +++ b/lucene/core/src/test/org/apache/lucene/index/TestKnnGraph.java @@ -54,7 +54,6 @@ import org.apache.lucene.util.VectorUtil; import org.apache.lucene.util.hnsw.HnswGraph; import org.apache.lucene.util.hnsw.HnswGraph.NodesIterator; -import org.apache.lucene.util.hnsw.HnswGraphBuilder; import org.junit.After; import org.junit.Before; @@ -63,7 +62,7 @@ public class TestKnnGraph extends LuceneTestCase { private static final String KNN_GRAPH_FIELD = "vector"; - private static int M = HnswGraphBuilder.DEFAULT_MAX_CONN; + private static int M = Lucene95HnswVectorsFormat.DEFAULT_MAX_CONN; private Codec codec; private Codec float32Codec; @@ -81,7 +80,7 @@ public void setup() { new Lucene95Codec() { @Override public KnnVectorsFormat getKnnVectorsFormatForField(String field) { - return new Lucene95HnswVectorsFormat(M, HnswGraphBuilder.DEFAULT_BEAM_WIDTH); + return new Lucene95HnswVectorsFormat(M, Lucene95HnswVectorsFormat.DEFAULT_BEAM_WIDTH); } }; @@ -93,7 +92,7 @@ public KnnVectorsFormat getKnnVectorsFormatForField(String field) { new Lucene95Codec() { @Override public KnnVectorsFormat getKnnVectorsFormatForField(String field) { - return new Lucene95HnswVectorsFormat(M, HnswGraphBuilder.DEFAULT_BEAM_WIDTH); + return new Lucene95HnswVectorsFormat(M, Lucene95HnswVectorsFormat.DEFAULT_BEAM_WIDTH); } }; @@ -104,7 +103,7 @@ public KnnVectorsFormat getKnnVectorsFormatForField(String field) { new Lucene95Codec() { @Override public KnnVectorsFormat getKnnVectorsFormatForField(String field) { - return new Lucene95HnswVectorsFormat(M, HnswGraphBuilder.DEFAULT_BEAM_WIDTH); + return new Lucene95HnswVectorsFormat(M, Lucene95HnswVectorsFormat.DEFAULT_BEAM_WIDTH); } }; } @@ -116,7 +115,7 @@ private VectorEncoding randomVectorEncoding() { @After public void cleanup() { - M = HnswGraphBuilder.DEFAULT_MAX_CONN; + M = Lucene95HnswVectorsFormat.DEFAULT_MAX_CONN; } /** Basic test of creating documents in a graph */ diff --git a/lucene/test-framework/src/java/org/apache/lucene/tests/analysis/BaseTokenStreamTestCase.java b/lucene/test-framework/src/java/org/apache/lucene/tests/analysis/BaseTokenStreamTestCase.java index f9cd607ccacc..b2ce16d80776 100644 --- a/lucene/test-framework/src/java/org/apache/lucene/tests/analysis/BaseTokenStreamTestCase.java +++ b/lucene/test-framework/src/java/org/apache/lucene/tests/analysis/BaseTokenStreamTestCase.java @@ -55,7 +55,6 @@ import org.apache.lucene.document.TextField; import org.apache.lucene.index.IndexOptions; import org.apache.lucene.index.IndexableFieldType; -import org.apache.lucene.search.BoostAttribute; import org.apache.lucene.store.Directory; import org.apache.lucene.tests.index.RandomIndexWriter; import org.apache.lucene.tests.util.LuceneTestCase; @@ -155,8 +154,7 @@ public static void assertTokenStreamContents( boolean[] keywordAtts, boolean graphOffsetsAreCorrect, byte[][] payloads, - int[] flags, - float[] boost) + int[] flags) throws IOException { assertNotNull(output); CheckClearAttributesAttribute checkClearAtt = @@ -223,12 +221,6 @@ public static void assertTokenStreamContents( flagsAtt = ts.getAttribute(FlagsAttribute.class); } - BoostAttribute boostAtt = null; - if (boost != null) { - assertTrue("has no BoostAttribute", ts.hasAttribute(BoostAttribute.class)); - boostAtt = ts.getAttribute(BoostAttribute.class); - } - // Maps position to the start/end offset: final Map posToStartOffset = new HashMap<>(); final Map posToEndOffset = new HashMap<>(); @@ -251,7 +243,6 @@ public static void assertTokenStreamContents( if (payloadAtt != null) payloadAtt.setPayload(new BytesRef(new byte[] {0x00, -0x21, 0x12, -0x43, 0x24})); if (flagsAtt != null) flagsAtt.setFlags(~0); // all 1's - if (boostAtt != null) boostAtt.setBoost(-1f); checkClearAtt.getAndResetClearCalled(); // reset it, because we called clearAttribute() before assertTrue("token " + i + " does not exist", ts.incrementToken()); @@ -287,9 +278,6 @@ public static void assertTokenStreamContents( if (flagsAtt != null) { assertEquals("flagsAtt " + i + " term=" + termAtt, flags[i], flagsAtt.getFlags()); } - if (boostAtt != null) { - assertEquals("boostAtt " + i + " term=" + termAtt, boost[i], boostAtt.getBoost(), 0.001); - } if (payloads != null) { if (payloads[i] != null) { assertEquals("payloads " + i, new BytesRef(payloads[i]), payloadAtt.getPayload()); @@ -417,7 +405,6 @@ public static void assertTokenStreamContents( if (payloadAtt != null) payloadAtt.setPayload(new BytesRef(new byte[] {0x00, -0x21, 0x12, -0x43, 0x24})); if (flagsAtt != null) flagsAtt.setFlags(~0); // all 1's - if (boostAtt != null) boostAtt.setBoost(-1); checkClearAtt.getAndResetClearCalled(); // reset it, because we called clearAttribute() before @@ -439,38 +426,6 @@ public static void assertTokenStreamContents( ts.close(); } - public static void assertTokenStreamContents( - TokenStream ts, - String[] output, - int[] startOffsets, - int[] endOffsets, - String[] types, - int[] posIncrements, - int[] posLengths, - Integer finalOffset, - Integer finalPosInc, - boolean[] keywordAtts, - boolean graphOffsetsAreCorrect, - byte[][] payloads, - int[] flags) - throws IOException { - assertTokenStreamContents( - ts, - output, - startOffsets, - endOffsets, - types, - posIncrements, - posLengths, - finalOffset, - finalPosInc, - keywordAtts, - graphOffsetsAreCorrect, - payloads, - flags, - null); - } - public static void assertTokenStreamContents( TokenStream ts, String[] output, @@ -483,33 +438,6 @@ public static void assertTokenStreamContents( boolean[] keywordAtts, boolean graphOffsetsAreCorrect) throws IOException { - assertTokenStreamContents( - ts, - output, - startOffsets, - endOffsets, - types, - posIncrements, - posLengths, - finalOffset, - keywordAtts, - graphOffsetsAreCorrect, - null); - } - - public static void assertTokenStreamContents( - TokenStream ts, - String[] output, - int[] startOffsets, - int[] endOffsets, - String[] types, - int[] posIncrements, - int[] posLengths, - Integer finalOffset, - boolean[] keywordAtts, - boolean graphOffsetsAreCorrect, - float[] boost) - throws IOException { assertTokenStreamContents( ts, output, @@ -523,8 +451,7 @@ public static void assertTokenStreamContents( keywordAtts, graphOffsetsAreCorrect, null, - null, - boost); + null); } public static void assertTokenStreamContents( @@ -554,36 +481,9 @@ public static void assertTokenStreamContents( keywordAtts, graphOffsetsAreCorrect, payloads, - null, null); } - public static void assertTokenStreamContents( - TokenStream ts, - String[] output, - int[] startOffsets, - int[] endOffsets, - String[] types, - int[] posIncrements, - int[] posLengths, - Integer finalOffset, - boolean graphOffsetsAreCorrect, - float[] boost) - throws IOException { - assertTokenStreamContents( - ts, - output, - startOffsets, - endOffsets, - types, - posIncrements, - posLengths, - finalOffset, - null, - graphOffsetsAreCorrect, - boost); - } - public static void assertTokenStreamContents( TokenStream ts, String[] output, @@ -605,8 +505,7 @@ public static void assertTokenStreamContents( posLengths, finalOffset, null, - graphOffsetsAreCorrect, - null); + graphOffsetsAreCorrect); } public static void assertTokenStreamContents( @@ -623,30 +522,6 @@ public static void assertTokenStreamContents( ts, output, startOffsets, endOffsets, types, posIncrements, posLengths, finalOffset, true); } - public static void assertTokenStreamContents( - TokenStream ts, - String[] output, - int[] startOffsets, - int[] endOffsets, - String[] types, - int[] posIncrements, - int[] posLengths, - Integer finalOffset, - float[] boost) - throws IOException { - assertTokenStreamContents( - ts, - output, - startOffsets, - endOffsets, - types, - posIncrements, - posLengths, - finalOffset, - true, - boost); - } - public static void assertTokenStreamContents( TokenStream ts, String[] output, @@ -774,21 +649,6 @@ public static void assertAnalyzesTo( int[] posIncrements, int[] posLengths) throws IOException { - assertAnalyzesTo( - a, input, output, startOffsets, endOffsets, types, posIncrements, posLengths, null); - } - - public static void assertAnalyzesTo( - Analyzer a, - String input, - String[] output, - int[] startOffsets, - int[] endOffsets, - String[] types, - int[] posIncrements, - int[] posLengths, - float[] boost) - throws IOException { assertTokenStreamContents( a.tokenStream("dummy", input), output, @@ -797,8 +657,7 @@ public static void assertAnalyzesTo( types, posIncrements, posLengths, - input.length(), - boost); + input.length()); checkResetException(a, input); checkAnalysisConsistency(random(), a, true, input); }