diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt
index a7016627cd78..91f1ea039a71 100644
--- a/lucene/CHANGES.txt
+++ b/lucene/CHANGES.txt
@@ -135,6 +135,8 @@ New Features
crash the JVM. To disable this feature, pass the following sysprop on Java command line:
"-Dorg.apache.lucene.store.MMapDirectory.enableMemorySegments=false" (Uwe Schindler)
+* GITHUB#12169: Introduce a new token filter to expand synonyms based on Word2Vec DL4j models. (Daniele Antuzi, Ilaria Petreti, Alessandro Benedetti)
+
Improvements
---------------------
diff --git a/lucene/analysis.tests/src/test/org/apache/lucene/analysis/tests/TestRandomChains.java b/lucene/analysis.tests/src/test/org/apache/lucene/analysis/tests/TestRandomChains.java
index 8c245e7058c7..988deaf99e59 100644
--- a/lucene/analysis.tests/src/test/org/apache/lucene/analysis/tests/TestRandomChains.java
+++ b/lucene/analysis.tests/src/test/org/apache/lucene/analysis/tests/TestRandomChains.java
@@ -89,6 +89,8 @@
import org.apache.lucene.analysis.standard.StandardTokenizer;
import org.apache.lucene.analysis.stempel.StempelStemmer;
import org.apache.lucene.analysis.synonym.SynonymMap;
+import org.apache.lucene.analysis.synonym.word2vec.Word2VecModel;
+import org.apache.lucene.analysis.synonym.word2vec.Word2VecSynonymProvider;
import org.apache.lucene.store.ByteBuffersDirectory;
import org.apache.lucene.tests.analysis.BaseTokenStreamTestCase;
import org.apache.lucene.tests.analysis.MockTokenFilter;
@@ -99,8 +101,10 @@
import org.apache.lucene.tests.util.automaton.AutomatonTestUtil;
import org.apache.lucene.util.AttributeFactory;
import org.apache.lucene.util.AttributeSource;
+import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.CharsRef;
import org.apache.lucene.util.IgnoreRandomChains;
+import org.apache.lucene.util.TermAndVector;
import org.apache.lucene.util.Version;
import org.apache.lucene.util.automaton.Automaton;
import org.apache.lucene.util.automaton.CharacterRunAutomaton;
@@ -415,6 +419,27 @@ private String randomNonEmptyString(Random random) {
}
}
});
+ put(
+ Word2VecSynonymProvider.class,
+ random -> {
+ final int numEntries = atLeast(10);
+ final int vectorDimension = random.nextInt(99) + 1;
+ Word2VecModel model = new Word2VecModel(numEntries, vectorDimension);
+ for (int j = 0; j < numEntries; j++) {
+ String s = TestUtil.randomSimpleString(random, 10, 20);
+ float[] vec = new float[vectorDimension];
+ for (int i = 0; i < vectorDimension; i++) {
+ vec[i] = random.nextFloat();
+ }
+ model.addTermAndVector(new TermAndVector(new BytesRef(s), vec));
+ }
+ try {
+ return new Word2VecSynonymProvider(model);
+ } catch (IOException e) {
+ Rethrow.rethrow(e);
+ return null; // unreachable code
+ }
+ });
put(
DateFormat.class,
random -> {
diff --git a/lucene/analysis/common/src/java/module-info.java b/lucene/analysis/common/src/java/module-info.java
index 96ec0d293d3d..84654191961b 100644
--- a/lucene/analysis/common/src/java/module-info.java
+++ b/lucene/analysis/common/src/java/module-info.java
@@ -79,6 +79,7 @@
exports org.apache.lucene.analysis.sr;
exports org.apache.lucene.analysis.sv;
exports org.apache.lucene.analysis.synonym;
+ exports org.apache.lucene.analysis.synonym.word2vec;
exports org.apache.lucene.analysis.ta;
exports org.apache.lucene.analysis.te;
exports org.apache.lucene.analysis.th;
@@ -257,6 +258,7 @@
org.apache.lucene.analysis.sv.SwedishMinimalStemFilterFactory,
org.apache.lucene.analysis.synonym.SynonymFilterFactory,
org.apache.lucene.analysis.synonym.SynonymGraphFilterFactory,
+ org.apache.lucene.analysis.synonym.word2vec.Word2VecSynonymFilterFactory,
org.apache.lucene.analysis.core.FlattenGraphFilterFactory,
org.apache.lucene.analysis.te.TeluguNormalizationFilterFactory,
org.apache.lucene.analysis.te.TeluguStemFilterFactory,
diff --git a/lucene/analysis/common/src/java/org/apache/lucene/analysis/synonym/word2vec/Dl4jModelReader.java b/lucene/analysis/common/src/java/org/apache/lucene/analysis/synonym/word2vec/Dl4jModelReader.java
new file mode 100644
index 000000000000..f022dd8eca67
--- /dev/null
+++ b/lucene/analysis/common/src/java/org/apache/lucene/analysis/synonym/word2vec/Dl4jModelReader.java
@@ -0,0 +1,126 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.lucene.analysis.synonym.word2vec;
+
+import java.io.BufferedInputStream;
+import java.io.BufferedReader;
+import java.io.Closeable;
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.InputStreamReader;
+import java.nio.charset.StandardCharsets;
+import java.util.Base64;
+import java.util.Locale;
+import java.util.zip.ZipEntry;
+import java.util.zip.ZipInputStream;
+import org.apache.lucene.util.BytesRef;
+import org.apache.lucene.util.TermAndVector;
+
+/**
+ * Dl4jModelReader reads the file generated by the library Deeplearning4j and provides a
+ * Word2VecModel with normalized vectors
+ *
+ *
+ * <p>Dl4j Word2Vec documentation:
+ * https://deeplearning4j.konduit.ai/v/en-1.0.0-beta7/language-processing/word2vec Example to
+ * generate a model using dl4j:
+ * https://github.com/eclipse/deeplearning4j-examples/blob/master/dl4j-examples/src/main/java/org/deeplearning4j/examples/advanced/modelling/embeddingsfromcorpus/word2vec/Word2VecRawTextExample.java
+ *
+ * @lucene.experimental
+ */
+public class Dl4jModelReader implements Closeable {
+
+ private static final String MODEL_FILE_NAME_PREFIX = "syn0";
+
+ private final ZipInputStream word2VecModelZipFile;
+
+ public Dl4jModelReader(InputStream stream) {
+ this.word2VecModelZipFile = new ZipInputStream(new BufferedInputStream(stream));
+ }
+
+ public Word2VecModel read() throws IOException {
+
+ ZipEntry entry;
+ while ((entry = word2VecModelZipFile.getNextEntry()) != null) {
+ String fileName = entry.getName();
+ if (fileName.startsWith(MODEL_FILE_NAME_PREFIX)) {
+ BufferedReader reader =
+ new BufferedReader(new InputStreamReader(word2VecModelZipFile, StandardCharsets.UTF_8));
+
+ String header = reader.readLine();
+ String[] headerValues = header.split(" ");
+ int dictionarySize = Integer.parseInt(headerValues[0]);
+ int vectorDimension = Integer.parseInt(headerValues[1]);
+
+ Word2VecModel model = new Word2VecModel(dictionarySize, vectorDimension);
+ String line = reader.readLine();
+ boolean isTermB64Encoded = false;
+ if (line != null) {
+ String[] tokens = line.split(" ");
+ isTermB64Encoded =
+ tokens[0].substring(0, 3).toLowerCase(Locale.ROOT).compareTo("b64") == 0;
+ model.addTermAndVector(extractTermAndVector(tokens, vectorDimension, isTermB64Encoded));
+ }
+ while ((line = reader.readLine()) != null) {
+ String[] tokens = line.split(" ");
+ model.addTermAndVector(extractTermAndVector(tokens, vectorDimension, isTermB64Encoded));
+ }
+ return model;
+ }
+ }
+ throw new IllegalArgumentException(
+ "Cannot read Dl4j word2vec model - '"
+ + MODEL_FILE_NAME_PREFIX
+ + "' file is missing in the zip. '"
+ + MODEL_FILE_NAME_PREFIX
+ + "' is a mandatory file containing the mapping between terms and vectors generated by the DL4j library.");
+ }
+
+ private static TermAndVector extractTermAndVector(
+ String[] tokens, int vectorDimension, boolean isTermB64Encoded) {
+ BytesRef term = isTermB64Encoded ? decodeB64Term(tokens[0]) : new BytesRef((tokens[0]));
+
+ float[] vector = new float[tokens.length - 1];
+
+ if (vectorDimension != vector.length) {
+ throw new RuntimeException(
+ String.format(
+ Locale.ROOT,
+ "Word2Vec model file corrupted. "
+ + "Declared vectors of size %d but found vector of size %d for word %s (%s)",
+ vectorDimension,
+ vector.length,
+ tokens[0],
+ term.utf8ToString()));
+ }
+
+ for (int i = 1; i < tokens.length; i++) {
+ vector[i - 1] = Float.parseFloat(tokens[i]);
+ }
+ return new TermAndVector(term, vector);
+ }
+
+ static BytesRef decodeB64Term(String term) {
+ byte[] buffer = Base64.getDecoder().decode(term.substring(4));
+ return new BytesRef(buffer, 0, buffer.length);
+ }
+
+ @Override
+ public void close() throws IOException {
+ word2VecModelZipFile.close();
+ }
+}
diff --git a/lucene/analysis/common/src/java/org/apache/lucene/analysis/synonym/word2vec/Word2VecModel.java b/lucene/analysis/common/src/java/org/apache/lucene/analysis/synonym/word2vec/Word2VecModel.java
new file mode 100644
index 000000000000..6719639b67d9
--- /dev/null
+++ b/lucene/analysis/common/src/java/org/apache/lucene/analysis/synonym/word2vec/Word2VecModel.java
@@ -0,0 +1,95 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.lucene.analysis.synonym.word2vec;
+
+import java.io.IOException;
+import org.apache.lucene.util.BytesRef;
+import org.apache.lucene.util.BytesRefHash;
+import org.apache.lucene.util.TermAndVector;
+import org.apache.lucene.util.hnsw.RandomAccessVectorValues;
+
+/**
+ * Word2VecModel is a class representing the parsed Word2Vec model containing the vectors for each
+ * word in dictionary
+ *
+ * @lucene.experimental
+ */
+public class Word2VecModel implements RandomAccessVectorValues {
+
+ private final int dictionarySize;
+ private final int vectorDimension;
+ private final TermAndVector[] termsAndVectors;
+ private final BytesRefHash word2Vec;
+ private int loadedCount = 0;
+
+ public Word2VecModel(int dictionarySize, int vectorDimension) {
+ this.dictionarySize = dictionarySize;
+ this.vectorDimension = vectorDimension;
+ this.termsAndVectors = new TermAndVector[dictionarySize];
+ this.word2Vec = new BytesRefHash();
+ }
+
+ private Word2VecModel(
+ int dictionarySize,
+ int vectorDimension,
+ TermAndVector[] termsAndVectors,
+ BytesRefHash word2Vec) {
+ this.dictionarySize = dictionarySize;
+ this.vectorDimension = vectorDimension;
+ this.termsAndVectors = termsAndVectors;
+ this.word2Vec = word2Vec;
+ }
+
+ public void addTermAndVector(TermAndVector modelEntry) {
+ modelEntry.normalizeVector();
+ this.termsAndVectors[loadedCount++] = modelEntry;
+ this.word2Vec.add(modelEntry.getTerm());
+ }
+
+ @Override
+ public float[] vectorValue(int targetOrd) {
+ return termsAndVectors[targetOrd].getVector();
+ }
+
+ public float[] vectorValue(BytesRef term) {
+ int termOrd = this.word2Vec.find(term);
+ if (termOrd < 0) return null;
+ TermAndVector entry = this.termsAndVectors[termOrd];
+ return (entry == null) ? null : entry.getVector();
+ }
+
+ public BytesRef termValue(int targetOrd) {
+ return termsAndVectors[targetOrd].getTerm();
+ }
+
+ @Override
+ public int dimension() {
+ return vectorDimension;
+ }
+
+ @Override
+ public int size() {
+ return dictionarySize;
+ }
+
+ @Override
+ public RandomAccessVectorValues copy() throws IOException {
+ return new Word2VecModel(
+ this.dictionarySize, this.vectorDimension, this.termsAndVectors, this.word2Vec);
+ }
+}
diff --git a/lucene/analysis/common/src/java/org/apache/lucene/analysis/synonym/word2vec/Word2VecSynonymFilter.java b/lucene/analysis/common/src/java/org/apache/lucene/analysis/synonym/word2vec/Word2VecSynonymFilter.java
new file mode 100644
index 000000000000..29f9d37d5fdf
--- /dev/null
+++ b/lucene/analysis/common/src/java/org/apache/lucene/analysis/synonym/word2vec/Word2VecSynonymFilter.java
@@ -0,0 +1,108 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.lucene.analysis.synonym.word2vec;
+
+import java.io.IOException;
+import java.util.LinkedList;
+import java.util.List;
+import org.apache.lucene.analysis.TokenFilter;
+import org.apache.lucene.analysis.TokenStream;
+import org.apache.lucene.analysis.synonym.SynonymGraphFilter;
+import org.apache.lucene.analysis.tokenattributes.CharTermAttribute;
+import org.apache.lucene.analysis.tokenattributes.PositionIncrementAttribute;
+import org.apache.lucene.analysis.tokenattributes.PositionLengthAttribute;
+import org.apache.lucene.analysis.tokenattributes.TypeAttribute;
+import org.apache.lucene.util.BytesRef;
+import org.apache.lucene.util.BytesRefBuilder;
+import org.apache.lucene.util.TermAndBoost;
+
+/**
+ * Applies single-token synonyms from a Word2Vec trained network to an incoming {@link TokenStream}.
+ *
+ * @lucene.experimental
+ */
+public final class Word2VecSynonymFilter extends TokenFilter {
+
+ private final CharTermAttribute termAtt = addAttribute(CharTermAttribute.class);
+ private final PositionIncrementAttribute posIncrementAtt =
+ addAttribute(PositionIncrementAttribute.class);
+ private final PositionLengthAttribute posLenAtt = addAttribute(PositionLengthAttribute.class);
+ private final TypeAttribute typeAtt = addAttribute(TypeAttribute.class);
+
+ private final Word2VecSynonymProvider synonymProvider;
+ private final int maxSynonymsPerTerm;
+ private final float minAcceptedSimilarity;
+  private final LinkedList<TermAndBoost> synonymBuffer = new LinkedList<>();
+ private State lastState;
+
+ /**
+ * Apply previously built synonymProvider to incoming tokens.
+ *
+ * @param input input tokenstream
+ * @param synonymProvider synonym provider
+ * @param maxSynonymsPerTerm maximum number of result returned by the synonym search
+ * @param minAcceptedSimilarity minimal value of cosine similarity between the searched vector and
+ * the retrieved ones
+ */
+ public Word2VecSynonymFilter(
+ TokenStream input,
+ Word2VecSynonymProvider synonymProvider,
+ int maxSynonymsPerTerm,
+ float minAcceptedSimilarity) {
+ super(input);
+ this.synonymProvider = synonymProvider;
+ this.maxSynonymsPerTerm = maxSynonymsPerTerm;
+ this.minAcceptedSimilarity = minAcceptedSimilarity;
+ }
+
+ @Override
+ public boolean incrementToken() throws IOException {
+
+ if (!synonymBuffer.isEmpty()) {
+ TermAndBoost synonym = synonymBuffer.pollFirst();
+ clearAttributes();
+ restoreState(this.lastState);
+ termAtt.setEmpty();
+ termAtt.append(synonym.term.utf8ToString());
+ typeAtt.setType(SynonymGraphFilter.TYPE_SYNONYM);
+ posLenAtt.setPositionLength(1);
+ posIncrementAtt.setPositionIncrement(0);
+ return true;
+ }
+
+ if (input.incrementToken()) {
+ BytesRefBuilder bytesRefBuilder = new BytesRefBuilder();
+ bytesRefBuilder.copyChars(termAtt.buffer(), 0, termAtt.length());
+ BytesRef term = bytesRefBuilder.get();
+      List<TermAndBoost> synonyms =
+ this.synonymProvider.getSynonyms(term, maxSynonymsPerTerm, minAcceptedSimilarity);
+ if (synonyms.size() > 0) {
+ this.lastState = captureState();
+ this.synonymBuffer.addAll(synonyms);
+ }
+ return true;
+ }
+ return false;
+ }
+
+ @Override
+ public void reset() throws IOException {
+ super.reset();
+ synonymBuffer.clear();
+ }
+}
diff --git a/lucene/analysis/common/src/java/org/apache/lucene/analysis/synonym/word2vec/Word2VecSynonymFilterFactory.java b/lucene/analysis/common/src/java/org/apache/lucene/analysis/synonym/word2vec/Word2VecSynonymFilterFactory.java
new file mode 100644
index 000000000000..32b6288926fc
--- /dev/null
+++ b/lucene/analysis/common/src/java/org/apache/lucene/analysis/synonym/word2vec/Word2VecSynonymFilterFactory.java
@@ -0,0 +1,101 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.lucene.analysis.synonym.word2vec;
+
+import java.io.IOException;
+import java.util.Locale;
+import java.util.Map;
+import org.apache.lucene.analysis.TokenFilterFactory;
+import org.apache.lucene.analysis.TokenStream;
+import org.apache.lucene.analysis.synonym.word2vec.Word2VecSynonymProviderFactory.Word2VecSupportedFormats;
+import org.apache.lucene.util.ResourceLoader;
+import org.apache.lucene.util.ResourceLoaderAware;
+
+/**
+ * Factory for {@link Word2VecSynonymFilter}.
+ *
+ * @lucene.experimental
+ * @lucene.spi {@value #NAME}
+ */
+public class Word2VecSynonymFilterFactory extends TokenFilterFactory
+ implements ResourceLoaderAware {
+
+ /** SPI name */
+ public static final String NAME = "Word2VecSynonym";
+
+ public static final int DEFAULT_MAX_SYNONYMS_PER_TERM = 5;
+ public static final float DEFAULT_MIN_ACCEPTED_SIMILARITY = 0.8f;
+
+ private final int maxSynonymsPerTerm;
+ private final float minAcceptedSimilarity;
+ private final Word2VecSupportedFormats format;
+ private final String word2vecModelFileName;
+
+ private Word2VecSynonymProvider synonymProvider;
+
+  public Word2VecSynonymFilterFactory(Map<String, String> args) {
+ super(args);
+ this.maxSynonymsPerTerm = getInt(args, "maxSynonymsPerTerm", DEFAULT_MAX_SYNONYMS_PER_TERM);
+ this.minAcceptedSimilarity =
+ getFloat(args, "minAcceptedSimilarity", DEFAULT_MIN_ACCEPTED_SIMILARITY);
+ this.word2vecModelFileName = require(args, "model");
+
+ String modelFormat = get(args, "format", "dl4j").toUpperCase(Locale.ROOT);
+ try {
+ this.format = Word2VecSupportedFormats.valueOf(modelFormat);
+ } catch (IllegalArgumentException exc) {
+ throw new IllegalArgumentException("Model format '" + modelFormat + "' not supported", exc);
+ }
+
+ if (!args.isEmpty()) {
+ throw new IllegalArgumentException("Unknown parameters: " + args);
+ }
+ if (minAcceptedSimilarity <= 0 || minAcceptedSimilarity > 1) {
+ throw new IllegalArgumentException(
+ "minAcceptedSimilarity must be in the range (0, 1]. Found: " + minAcceptedSimilarity);
+ }
+ if (maxSynonymsPerTerm <= 0) {
+ throw new IllegalArgumentException(
+ "maxSynonymsPerTerm must be a positive integer greater than 0. Found: "
+ + maxSynonymsPerTerm);
+ }
+ }
+
+ /** Default ctor for compatibility with SPI */
+ public Word2VecSynonymFilterFactory() {
+ throw defaultCtorException();
+ }
+
+ Word2VecSynonymProvider getSynonymProvider() {
+ return this.synonymProvider;
+ }
+
+ @Override
+ public TokenStream create(TokenStream input) {
+ return synonymProvider == null
+ ? input
+ : new Word2VecSynonymFilter(
+ input, synonymProvider, maxSynonymsPerTerm, minAcceptedSimilarity);
+ }
+
+ @Override
+ public void inform(ResourceLoader loader) throws IOException {
+ this.synonymProvider =
+ Word2VecSynonymProviderFactory.getSynonymProvider(loader, word2vecModelFileName, format);
+ }
+}
diff --git a/lucene/analysis/common/src/java/org/apache/lucene/analysis/synonym/word2vec/Word2VecSynonymProvider.java b/lucene/analysis/common/src/java/org/apache/lucene/analysis/synonym/word2vec/Word2VecSynonymProvider.java
new file mode 100644
index 000000000000..1d1f06fc991f
--- /dev/null
+++ b/lucene/analysis/common/src/java/org/apache/lucene/analysis/synonym/word2vec/Word2VecSynonymProvider.java
@@ -0,0 +1,104 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.lucene.analysis.synonym.word2vec;
+
+import static org.apache.lucene.util.hnsw.HnswGraphBuilder.DEFAULT_BEAM_WIDTH;
+import static org.apache.lucene.util.hnsw.HnswGraphBuilder.DEFAULT_MAX_CONN;
+
+import java.io.IOException;
+import java.util.LinkedList;
+import java.util.List;
+import org.apache.lucene.index.VectorEncoding;
+import org.apache.lucene.index.VectorSimilarityFunction;
+import org.apache.lucene.util.BytesRef;
+import org.apache.lucene.util.TermAndBoost;
+import org.apache.lucene.util.hnsw.HnswGraph;
+import org.apache.lucene.util.hnsw.HnswGraphBuilder;
+import org.apache.lucene.util.hnsw.HnswGraphSearcher;
+import org.apache.lucene.util.hnsw.NeighborQueue;
+
+/**
+ * The Word2VecSynonymProvider generates the list of synonyms of a term.
+ *
+ * @lucene.experimental
+ */
+public class Word2VecSynonymProvider {
+
+ private static final VectorSimilarityFunction SIMILARITY_FUNCTION =
+ VectorSimilarityFunction.DOT_PRODUCT;
+ private static final VectorEncoding VECTOR_ENCODING = VectorEncoding.FLOAT32;
+ private final Word2VecModel word2VecModel;
+ private final HnswGraph hnswGraph;
+
+ /**
+ * Word2VecSynonymProvider constructor
+ *
+ * @param model containing the set of TermAndVector entries
+ */
+ public Word2VecSynonymProvider(Word2VecModel model) throws IOException {
+ word2VecModel = model;
+
+ HnswGraphBuilder builder =
+ HnswGraphBuilder.create(
+ word2VecModel,
+ VECTOR_ENCODING,
+ SIMILARITY_FUNCTION,
+ DEFAULT_MAX_CONN,
+ DEFAULT_BEAM_WIDTH,
+ HnswGraphBuilder.randSeed);
+ this.hnswGraph = builder.build(word2VecModel.copy());
+ }
+
+  public List<TermAndBoost> getSynonyms(
+ BytesRef term, int maxSynonymsPerTerm, float minAcceptedSimilarity) throws IOException {
+
+ if (term == null) {
+ throw new IllegalArgumentException("Term must not be null");
+ }
+
+    LinkedList<TermAndBoost> result = new LinkedList<>();
+ float[] query = word2VecModel.vectorValue(term);
+ if (query != null) {
+ NeighborQueue synonyms =
+ HnswGraphSearcher.search(
+ query,
+ // The query vector is in the model. When looking for the top-k
+ // it's always the nearest neighbour of itself so, we look for the top-k+1
+ maxSynonymsPerTerm + 1,
+ word2VecModel,
+ VECTOR_ENCODING,
+ SIMILARITY_FUNCTION,
+ hnswGraph,
+ null,
+ word2VecModel.size());
+
+ int size = synonyms.size();
+ for (int i = 0; i < size; i++) {
+ float similarity = synonyms.topScore();
+ int id = synonyms.pop();
+
+ BytesRef synonym = word2VecModel.termValue(id);
+ // We remove the original query term
+ if (!synonym.equals(term) && similarity >= minAcceptedSimilarity) {
+ result.addFirst(new TermAndBoost(synonym, similarity));
+ }
+ }
+ }
+ return result;
+ }
+}
diff --git a/lucene/analysis/common/src/java/org/apache/lucene/analysis/synonym/word2vec/Word2VecSynonymProviderFactory.java b/lucene/analysis/common/src/java/org/apache/lucene/analysis/synonym/word2vec/Word2VecSynonymProviderFactory.java
new file mode 100644
index 000000000000..ea849e653cd6
--- /dev/null
+++ b/lucene/analysis/common/src/java/org/apache/lucene/analysis/synonym/word2vec/Word2VecSynonymProviderFactory.java
@@ -0,0 +1,63 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.lucene.analysis.synonym.word2vec;
+
+import java.io.IOException;
+import java.io.InputStream;
+import java.util.Map;
+import java.util.concurrent.ConcurrentHashMap;
+import org.apache.lucene.util.ResourceLoader;
+
+/**
+ * Supply Word2Vec Word2VecSynonymProvider cache avoiding that multiple instances of
+ * Word2VecSynonymFilterFactory will instantiate multiple instances of the same SynonymProvider.
+ * Assumes synonymProvider implementations are thread-safe.
+ */
+public class Word2VecSynonymProviderFactory {
+
+ enum Word2VecSupportedFormats {
+ DL4J
+ }
+
+  private static Map<String, Word2VecSynonymProvider> word2vecSynonymProviders =
+ new ConcurrentHashMap<>();
+
+ public static Word2VecSynonymProvider getSynonymProvider(
+ ResourceLoader loader, String modelFileName, Word2VecSupportedFormats format)
+ throws IOException {
+ Word2VecSynonymProvider synonymProvider = word2vecSynonymProviders.get(modelFileName);
+ if (synonymProvider == null) {
+ try (InputStream stream = loader.openResource(modelFileName)) {
+ try (Dl4jModelReader reader = getModelReader(format, stream)) {
+ synonymProvider = new Word2VecSynonymProvider(reader.read());
+ }
+ }
+ word2vecSynonymProviders.put(modelFileName, synonymProvider);
+ }
+ return synonymProvider;
+ }
+
+ private static Dl4jModelReader getModelReader(
+ Word2VecSupportedFormats format, InputStream stream) {
+ switch (format) {
+ case DL4J:
+ return new Dl4jModelReader(stream);
+ }
+ return null;
+ }
+}
diff --git a/lucene/analysis/common/src/java/org/apache/lucene/analysis/synonym/word2vec/package-info.java b/lucene/analysis/common/src/java/org/apache/lucene/analysis/synonym/word2vec/package-info.java
new file mode 100644
index 000000000000..e8d69ab3cf9b
--- /dev/null
+++ b/lucene/analysis/common/src/java/org/apache/lucene/analysis/synonym/word2vec/package-info.java
@@ -0,0 +1,19 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/** Analysis components for Synonyms using Word2Vec model. */
+package org.apache.lucene.analysis.synonym.word2vec;
diff --git a/lucene/analysis/common/src/resources/META-INF/services/org.apache.lucene.analysis.TokenFilterFactory b/lucene/analysis/common/src/resources/META-INF/services/org.apache.lucene.analysis.TokenFilterFactory
index 19a34b7840a8..1e4e17eaeadf 100644
--- a/lucene/analysis/common/src/resources/META-INF/services/org.apache.lucene.analysis.TokenFilterFactory
+++ b/lucene/analysis/common/src/resources/META-INF/services/org.apache.lucene.analysis.TokenFilterFactory
@@ -118,6 +118,7 @@ org.apache.lucene.analysis.sv.SwedishLightStemFilterFactory
org.apache.lucene.analysis.sv.SwedishMinimalStemFilterFactory
org.apache.lucene.analysis.synonym.SynonymFilterFactory
org.apache.lucene.analysis.synonym.SynonymGraphFilterFactory
+org.apache.lucene.analysis.synonym.word2vec.Word2VecSynonymFilterFactory
org.apache.lucene.analysis.core.FlattenGraphFilterFactory
org.apache.lucene.analysis.te.TeluguNormalizationFilterFactory
org.apache.lucene.analysis.te.TeluguStemFilterFactory
diff --git a/lucene/analysis/common/src/test/org/apache/lucene/analysis/synonym/word2vec/TestDl4jModelReader.java b/lucene/analysis/common/src/test/org/apache/lucene/analysis/synonym/word2vec/TestDl4jModelReader.java
new file mode 100644
index 000000000000..213dcdaccd33
--- /dev/null
+++ b/lucene/analysis/common/src/test/org/apache/lucene/analysis/synonym/word2vec/TestDl4jModelReader.java
@@ -0,0 +1,98 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.lucene.analysis.synonym.word2vec;
+
+import java.io.InputStream;
+import java.nio.charset.StandardCharsets;
+import java.util.Base64;
+import org.apache.lucene.tests.util.LuceneTestCase;
+import org.apache.lucene.util.BytesRef;
+import org.junit.Test;
+
+/** Tests for {@link Dl4jModelReader}: reading DL4J word2vec zip models, including failure modes. */
+public class TestDl4jModelReader extends LuceneTestCase {
+
+  private static final String MODEL_FILE = "word2vec-model.zip";
+  private static final String MODEL_EMPTY_FILE = "word2vec-empty-model.zip";
+  private static final String CORRUPTED_VECTOR_DIMENSION_MODEL_FILE =
+      "word2vec-corrupted-vector-dimension-model.zip";
+
+  // A fresh test-class instance is created per test method, so each test gets its own
+  // stream/reader pair.
+  // NOTE(review): this stream is never explicitly closed; acceptable for a test resource,
+  // but consider closing it in an @After method.
+  InputStream stream = TestDl4jModelReader.class.getResourceAsStream(MODEL_FILE);
+  Dl4jModelReader unit = new Dl4jModelReader(stream);
+
+  @Test
+  public void read_zipFileWithMetadata_shouldReturnDictionarySize() throws Exception {
+    Word2VecModel model = unit.read();
+    // the word2vec-model.zip test resource contains 235 terms
+    long expectedDictionarySize = 235;
+    assertEquals(expectedDictionarySize, model.size());
+  }
+
+  @Test
+  public void read_zipFileWithMetadata_shouldReturnVectorLength() throws Exception {
+    Word2VecModel model = unit.read();
+    // the word2vec-model.zip test resource stores 100-dimensional vectors
+    int expectedVectorDimension = 100;
+    assertEquals(expectedVectorDimension, model.dimension());
+  }
+
+  @Test
+  public void read_zipFile_shouldReturnDecodedTerm() throws Exception {
+    Word2VecModel model = unit.read();
+    // first term in the model file is Base64-encoded; the reader must decode it back to "it"
+    BytesRef expectedDecodedFirstTerm = new BytesRef("it");
+    assertEquals(expectedDecodedFirstTerm, model.termValue(0));
+  }
+
+  @Test
+  public void decodeTerm_encodedTerm_shouldReturnDecodedTerm() throws Exception {
+    byte[] originalInput = "lucene".getBytes(StandardCharsets.UTF_8);
+    String B64encodedLuceneTerm = Base64.getEncoder().encodeToString(originalInput);
+    // DL4J stores Base64-encoded terms with a "B64:" prefix
+    String word2vecEncodedLuceneTerm = "B64:" + B64encodedLuceneTerm;
+    assertEquals(new BytesRef("lucene"), Dl4jModelReader.decodeB64Term(word2vecEncodedLuceneTerm));
+  }
+
+  @Test
+  public void read_EmptyZipFile_shouldThrowException() throws Exception {
+    try (InputStream stream = TestDl4jModelReader.class.getResourceAsStream(MODEL_EMPTY_FILE)) {
+      Dl4jModelReader unit = new Dl4jModelReader(stream);
+      expectThrows(IllegalArgumentException.class, unit::read);
+    }
+  }
+
+  @Test
+  public void read_corruptedVectorDimensionModelFile_shouldThrowException() throws Exception {
+    // a model whose declared vector dimension disagrees with the stored vectors must be rejected
+    try (InputStream stream =
+        TestDl4jModelReader.class.getResourceAsStream(CORRUPTED_VECTOR_DIMENSION_MODEL_FILE)) {
+      Dl4jModelReader unit = new Dl4jModelReader(stream);
+      expectThrows(RuntimeException.class, unit::read);
+    }
+  }
+}
diff --git a/lucene/analysis/common/src/test/org/apache/lucene/analysis/synonym/word2vec/TestWord2VecSynonymFilter.java b/lucene/analysis/common/src/test/org/apache/lucene/analysis/synonym/word2vec/TestWord2VecSynonymFilter.java
new file mode 100644
index 000000000000..3999931dd758
--- /dev/null
+++ b/lucene/analysis/common/src/test/org/apache/lucene/analysis/synonym/word2vec/TestWord2VecSynonymFilter.java
@@ -0,0 +1,152 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.lucene.analysis.synonym.word2vec;
+
+import org.apache.lucene.analysis.Analyzer;
+import org.apache.lucene.analysis.Tokenizer;
+import org.apache.lucene.tests.analysis.BaseTokenStreamTestCase;
+import org.apache.lucene.tests.analysis.MockTokenizer;
+import org.apache.lucene.util.BytesRef;
+import org.apache.lucene.util.TermAndVector;
+import org.junit.Test;
+
+/**
+ * Tests for {@link Word2VecSynonymFilter}: expansion of tokens with model-provided synonyms,
+ * honoring the max-synonyms-per-term cap and the minimum accepted cosine similarity.
+ */
+public class TestWord2VecSynonymFilter extends BaseTokenStreamTestCase {
+
+  @Test
+  public void synonymExpansion_oneCandidate_shouldBeExpandedWithinThreshold() throws Exception {
+    int maxSynonymPerTerm = 10;
+    float minAcceptedSimilarity = 0.9f;
+    Word2VecModel model = new Word2VecModel(6, 2);
+    model.addTermAndVector(new TermAndVector(new BytesRef("a"), new float[] {10, 10}));
+    model.addTermAndVector(new TermAndVector(new BytesRef("b"), new float[] {10, 8}));
+    model.addTermAndVector(new TermAndVector(new BytesRef("c"), new float[] {9, 10}));
+    model.addTermAndVector(new TermAndVector(new BytesRef("d"), new float[] {1, 1}));
+    model.addTermAndVector(new TermAndVector(new BytesRef("e"), new float[] {99, 101}));
+    model.addTermAndVector(new TermAndVector(new BytesRef("f"), new float[] {-1, 10}));
+
+    Word2VecSynonymProvider synonymProvider = new Word2VecSynonymProvider(model);
+
+    Analyzer a = getAnalyzer(synonymProvider, maxSynonymPerTerm, minAcceptedSimilarity);
+    // "a" expands to every term within similarity 0.9 ("f" is excluded); synonyms share
+    // the offsets of the original token and get position increment 0
+    assertAnalyzesTo(
+        a,
+        "pre a post", // input
+        new String[] {"pre", "a", "d", "e", "c", "b", "post"}, // output
+        new int[] {0, 4, 4, 4, 4, 4, 6}, // start offset
+        new int[] {3, 5, 5, 5, 5, 5, 10}, // end offset
+        new String[] {"word", "word", "SYNONYM", "SYNONYM", "SYNONYM", "SYNONYM", "word"}, // types
+        new int[] {1, 1, 0, 0, 0, 0, 1}, // posIncrements
+        new int[] {1, 1, 1, 1, 1, 1, 1}); // posLengths
+    a.close();
+  }
+
+  @Test
+  public void synonymExpansion_oneCandidate_shouldBeExpandedWithTopKSynonyms() throws Exception {
+    // cap at 2 synonyms: only the top-2 most similar terms ("d" and "e") are emitted
+    int maxSynonymPerTerm = 2;
+    float minAcceptedSimilarity = 0.9f;
+    Word2VecModel model = new Word2VecModel(5, 2);
+    model.addTermAndVector(new TermAndVector(new BytesRef("a"), new float[] {10, 10}));
+    model.addTermAndVector(new TermAndVector(new BytesRef("b"), new float[] {10, 8}));
+    model.addTermAndVector(new TermAndVector(new BytesRef("c"), new float[] {9, 10}));
+    model.addTermAndVector(new TermAndVector(new BytesRef("d"), new float[] {1, 1}));
+    model.addTermAndVector(new TermAndVector(new BytesRef("e"), new float[] {99, 101}));
+
+    Word2VecSynonymProvider synonymProvider = new Word2VecSynonymProvider(model);
+
+    Analyzer a = getAnalyzer(synonymProvider, maxSynonymPerTerm, minAcceptedSimilarity);
+    assertAnalyzesTo(
+        a,
+        "pre a post", // input
+        new String[] {"pre", "a", "d", "e", "post"}, // output
+        new int[] {0, 4, 4, 4, 6}, // start offset
+        new int[] {3, 5, 5, 5, 10}, // end offset
+        new String[] {"word", "word", "SYNONYM", "SYNONYM", "word"}, // types
+        new int[] {1, 1, 0, 0, 1}, // posIncrements
+        new int[] {1, 1, 1, 1, 1}); // posLengths
+    a.close();
+  }
+
+  @Test
+  public void synonymExpansion_twoCandidates_shouldBothBeExpanded() throws Exception {
+    // both "a" and "post" are present in the model, so both tokens get expanded
+    Word2VecModel model = new Word2VecModel(8, 2);
+    model.addTermAndVector(new TermAndVector(new BytesRef("a"), new float[] {10, 10}));
+    model.addTermAndVector(new TermAndVector(new BytesRef("b"), new float[] {10, 8}));
+    model.addTermAndVector(new TermAndVector(new BytesRef("c"), new float[] {9, 10}));
+    model.addTermAndVector(new TermAndVector(new BytesRef("d"), new float[] {1, 1}));
+    model.addTermAndVector(new TermAndVector(new BytesRef("e"), new float[] {99, 101}));
+    model.addTermAndVector(new TermAndVector(new BytesRef("f"), new float[] {1, 10}));
+    model.addTermAndVector(new TermAndVector(new BytesRef("post"), new float[] {-10, -11}));
+    model.addTermAndVector(new TermAndVector(new BytesRef("after"), new float[] {-8, -10}));
+
+    Word2VecSynonymProvider synonymProvider = new Word2VecSynonymProvider(model);
+
+    Analyzer a = getAnalyzer(synonymProvider, 10, 0.9f);
+    assertAnalyzesTo(
+        a,
+        "pre a post", // input
+        new String[] {"pre", "a", "d", "e", "c", "b", "post", "after"}, // output
+        new int[] {0, 4, 4, 4, 4, 4, 6, 6}, // start offset
+        new int[] {3, 5, 5, 5, 5, 5, 10, 10}, // end offset
+        new String[] { // types
+          "word", "word", "SYNONYM", "SYNONYM", "SYNONYM", "SYNONYM", "word", "SYNONYM"
+        },
+        new int[] {1, 1, 0, 0, 0, 0, 1, 0}, // posIncrements
+        new int[] {1, 1, 1, 1, 1, 1, 1, 1}); // posLengths
+    a.close();
+  }
+
+  @Test
+  public void synonymExpansion_forMinAcceptedSimilarity_shouldExpandToNoneSynonyms()
+      throws Exception {
+    // no candidate reaches similarity 0.8, so the token stream passes through unchanged
+    Word2VecModel model = new Word2VecModel(4, 2);
+    model.addTermAndVector(new TermAndVector(new BytesRef("a"), new float[] {10, 10}));
+    model.addTermAndVector(new TermAndVector(new BytesRef("b"), new float[] {-10, -8}));
+    model.addTermAndVector(new TermAndVector(new BytesRef("c"), new float[] {-9, -10}));
+    model.addTermAndVector(new TermAndVector(new BytesRef("f"), new float[] {-1, -10}));
+
+    Word2VecSynonymProvider synonymProvider = new Word2VecSynonymProvider(model);
+
+    Analyzer a = getAnalyzer(synonymProvider, 10, 0.8f);
+    assertAnalyzesTo(
+        a,
+        "pre a post", // input
+        new String[] {"pre", "a", "post"}, // output
+        new int[] {0, 4, 6}, // start offset
+        new int[] {3, 5, 10}, // end offset
+        new String[] {"word", "word", "word"}, // types
+        new int[] {1, 1, 1}, // posIncrements
+        new int[] {1, 1, 1}); // posLengths
+    a.close();
+  }
+
+  /** Builds a whitespace-tokenizing analyzer that applies a {@link Word2VecSynonymFilter}. */
+  private Analyzer getAnalyzer(
+      Word2VecSynonymProvider synonymProvider,
+      int maxSynonymsPerTerm,
+      float minAcceptedSimilarity) {
+    return new Analyzer() {
+      @Override
+      protected TokenStreamComponents createComponents(String fieldName) {
+        Tokenizer tokenizer = new MockTokenizer(MockTokenizer.WHITESPACE, false);
+        // Make a local variable so testRandomHuge doesn't share it across threads!
+        Word2VecSynonymFilter synFilter =
+            new Word2VecSynonymFilter(
+                tokenizer, synonymProvider, maxSynonymsPerTerm, minAcceptedSimilarity);
+        return new TokenStreamComponents(tokenizer, synFilter);
+      }
+    };
+  }
+}
diff --git a/lucene/analysis/common/src/test/org/apache/lucene/analysis/synonym/word2vec/TestWord2VecSynonymFilterFactory.java b/lucene/analysis/common/src/test/org/apache/lucene/analysis/synonym/word2vec/TestWord2VecSynonymFilterFactory.java
new file mode 100644
index 000000000000..007fedf4abed
--- /dev/null
+++ b/lucene/analysis/common/src/test/org/apache/lucene/analysis/synonym/word2vec/TestWord2VecSynonymFilterFactory.java
@@ -0,0 +1,159 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.lucene.analysis.synonym.word2vec;
+
+import org.apache.lucene.tests.analysis.BaseTokenStreamFactoryTestCase;
+import org.apache.lucene.util.ClasspathResourceLoader;
+import org.apache.lucene.util.ResourceLoader;
+import org.junit.Test;
+
+public class TestWord2VecSynonymFilterFactory extends BaseTokenStreamFactoryTestCase {
+
+ public static final String FACTORY_NAME = "Word2VecSynonym";
+ private static final String WORD2VEC_MODEL_FILE = "word2vec-model.zip";
+
+ @Test
+ public void testInform() throws Exception {
+ ResourceLoader loader = new ClasspathResourceLoader(getClass());
+ assertTrue("loader is null and it shouldn't be", loader != null);
+ Word2VecSynonymFilterFactory factory =
+ (Word2VecSynonymFilterFactory)
+ tokenFilterFactory(
+ FACTORY_NAME, "model", WORD2VEC_MODEL_FILE, "minAcceptedSimilarity", "0.7");
+
+ Word2VecSynonymProvider synonymProvider = factory.getSynonymProvider();
+ assertNotEquals(null, synonymProvider);
+ }
+
+ @Test
+ public void missingRequiredArgument_shouldThrowException() throws Exception {
+ IllegalArgumentException expected =
+ expectThrows(
+ IllegalArgumentException.class,
+ () -> {
+ tokenFilterFactory(
+ FACTORY_NAME,
+ "format",
+ "dl4j",
+ "minAcceptedSimilarity",
+ "0.7",
+ "maxSynonymsPerTerm",
+ "10");
+ });
+ assertTrue(expected.getMessage().contains("Configuration Error: missing parameter 'model'"));
+ }
+
+ @Test
+ public void unsupportedModelFormat_shouldThrowException() throws Exception {
+ IllegalArgumentException expected =
+ expectThrows(
+ IllegalArgumentException.class,
+ () -> {
+ tokenFilterFactory(
+ FACTORY_NAME, "model", WORD2VEC_MODEL_FILE, "format", "bogusValue");
+ });
+ assertTrue(expected.getMessage().contains("Model format 'BOGUSVALUE' not supported"));
+ }
+
+ @Test
+ public void bogusArgument_shouldThrowException() throws Exception {
+ IllegalArgumentException expected =
+ expectThrows(
+ IllegalArgumentException.class,
+ () -> {
+ tokenFilterFactory(
+ FACTORY_NAME, "model", WORD2VEC_MODEL_FILE, "bogusArg", "bogusValue");
+ });
+ assertTrue(expected.getMessage().contains("Unknown parameters"));
+ }
+
+ @Test
+ public void illegalArguments_shouldThrowException() throws Exception {
+ IllegalArgumentException expected =
+ expectThrows(
+ IllegalArgumentException.class,
+ () -> {
+ tokenFilterFactory(
+ FACTORY_NAME,
+ "model",
+ WORD2VEC_MODEL_FILE,
+ "minAcceptedSimilarity",
+ "2",
+ "maxSynonymsPerTerm",
+ "10");
+ });
+ assertTrue(
+ expected
+ .getMessage()
+ .contains("minAcceptedSimilarity must be in the range (0, 1]. Found: 2"));
+
+ expected =
+ expectThrows(
+ IllegalArgumentException.class,
+ () -> {
+ tokenFilterFactory(
+ FACTORY_NAME,
+ "model",
+ WORD2VEC_MODEL_FILE,
+ "minAcceptedSimilarity",
+ "0",
+ "maxSynonymsPerTerm",
+ "10");
+ });
+ assertTrue(
+ expected
+ .getMessage()
+ .contains("minAcceptedSimilarity must be in the range (0, 1]. Found: 0"));
+
+ expected =
+ expectThrows(
+ IllegalArgumentException.class,
+ () -> {
+ tokenFilterFactory(
+ FACTORY_NAME,
+ "model",
+ WORD2VEC_MODEL_FILE,
+ "minAcceptedSimilarity",
+ "0.7",
+ "maxSynonymsPerTerm",
+ "-1");
+ });
+ assertTrue(
+ expected
+ .getMessage()
+ .contains("maxSynonymsPerTerm must be a positive integer greater than 0. Found: -1"));
+
+ expected =
+ expectThrows(
+ IllegalArgumentException.class,
+ () -> {
+ tokenFilterFactory(
+ FACTORY_NAME,
+ "model",
+ WORD2VEC_MODEL_FILE,
+ "minAcceptedSimilarity",
+ "0.7",
+ "maxSynonymsPerTerm",
+ "0");
+ });
+ assertTrue(
+ expected
+ .getMessage()
+ .contains("maxSynonymsPerTerm must be a positive integer greater than 0. Found: 0"));
+ }
+}
diff --git a/lucene/analysis/common/src/test/org/apache/lucene/analysis/synonym/word2vec/TestWord2VecSynonymProvider.java b/lucene/analysis/common/src/test/org/apache/lucene/analysis/synonym/word2vec/TestWord2VecSynonymProvider.java
new file mode 100644
index 000000000000..0803406bdafa
--- /dev/null
+++ b/lucene/analysis/common/src/test/org/apache/lucene/analysis/synonym/word2vec/TestWord2VecSynonymProvider.java
@@ -0,0 +1,132 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.lucene.analysis.synonym.word2vec;
+
+import java.io.IOException;
+import java.util.List;
+import org.apache.lucene.tests.util.LuceneTestCase;
+import org.apache.lucene.util.BytesRef;
+import org.apache.lucene.util.TermAndBoost;
+import org.apache.lucene.util.TermAndVector;
+import org.junit.Test;
+
+/**
+ * Tests for {@link Word2VecSynonymProvider}: synonym retrieval by cosine similarity, top-k
+ * truncation, boost values, and vector normalization.
+ */
+public class TestWord2VecSynonymProvider extends LuceneTestCase {
+
+  private static final int MAX_SYNONYMS_PER_TERM = 10;
+  private static final float MIN_ACCEPTED_SIMILARITY = 0.85f;
+
+  private final Word2VecSynonymProvider unit;
+
+  public TestWord2VecSynonymProvider() throws IOException {
+    Word2VecModel model = new Word2VecModel(2, 3);
+    model.addTermAndVector(new TermAndVector(new BytesRef("a"), new float[] {0.24f, 0.78f, 0.28f}));
+    model.addTermAndVector(new TermAndVector(new BytesRef("b"), new float[] {0.44f, 0.01f, 0.81f}));
+    unit = new Word2VecSynonymProvider(model);
+  }
+
+  @Test
+  public void getSynonyms_nullToken_shouldThrowException() {
+    expectThrows(
+        IllegalArgumentException.class,
+        () -> unit.getSynonyms(null, MAX_SYNONYMS_PER_TERM, MIN_ACCEPTED_SIMILARITY));
+  }
+
+  @Test
+  public void getSynonyms_shouldReturnSynonymsBasedOnMinAcceptedSimilarity() throws Exception {
+    Word2VecModel model = new Word2VecModel(6, 2);
+    model.addTermAndVector(new TermAndVector(new BytesRef("a"), new float[] {10, 10}));
+    model.addTermAndVector(new TermAndVector(new BytesRef("b"), new float[] {10, 8}));
+    model.addTermAndVector(new TermAndVector(new BytesRef("c"), new float[] {9, 10}));
+    model.addTermAndVector(new TermAndVector(new BytesRef("d"), new float[] {1, 1}));
+    model.addTermAndVector(new TermAndVector(new BytesRef("e"), new float[] {99, 101}));
+    model.addTermAndVector(new TermAndVector(new BytesRef("f"), new float[] {-1, 10}));
+
+    Word2VecSynonymProvider unit = new Word2VecSynonymProvider(model);
+
+    BytesRef inputTerm = new BytesRef("a");
+    // "f" falls below the similarity threshold; the others are returned by decreasing similarity
+    String[] expectedSynonyms = {"d", "e", "c", "b"};
+    // use the parameterized type instead of a raw List
+    List<TermAndBoost> actualSynonymsResults =
+        unit.getSynonyms(inputTerm, MAX_SYNONYMS_PER_TERM, MIN_ACCEPTED_SIMILARITY);
+
+    assertEquals(4, actualSynonymsResults.size());
+    for (int i = 0; i < expectedSynonyms.length; i++) {
+      assertEquals(new BytesRef(expectedSynonyms[i]), actualSynonymsResults.get(i).term);
+    }
+  }
+
+  @Test
+  public void getSynonyms_shouldReturnSynonymsBoost() throws Exception {
+    Word2VecModel model = new Word2VecModel(3, 2);
+    model.addTermAndVector(new TermAndVector(new BytesRef("a"), new float[] {10, 10}));
+    model.addTermAndVector(new TermAndVector(new BytesRef("b"), new float[] {1, 1}));
+    model.addTermAndVector(new TermAndVector(new BytesRef("c"), new float[] {99, 101}));
+
+    Word2VecSynonymProvider unit = new Word2VecSynonymProvider(model);
+
+    BytesRef inputTerm = new BytesRef("a");
+    List<TermAndBoost> actualSynonymsResults =
+        unit.getSynonyms(inputTerm, MAX_SYNONYMS_PER_TERM, MIN_ACCEPTED_SIMILARITY);
+
+    // "b" is collinear with "a", so its boost (cosine similarity) is exactly 1
+    BytesRef expectedFirstSynonymTerm = new BytesRef("b");
+    double expectedFirstSynonymBoost = 1.0;
+    assertEquals(expectedFirstSynonymTerm, actualSynonymsResults.get(0).term);
+    assertEquals(expectedFirstSynonymBoost, actualSynonymsResults.get(0).boost, 0.001f);
+  }
+
+  @Test
+  public void noSynonymsWithinAcceptedSimilarity_shouldReturnNoSynonyms() throws Exception {
+    Word2VecModel model = new Word2VecModel(4, 2);
+    model.addTermAndVector(new TermAndVector(new BytesRef("a"), new float[] {10, 10}));
+    model.addTermAndVector(new TermAndVector(new BytesRef("b"), new float[] {-10, -8}));
+    model.addTermAndVector(new TermAndVector(new BytesRef("c"), new float[] {-9, -10}));
+    model.addTermAndVector(new TermAndVector(new BytesRef("d"), new float[] {6, -6}));
+
+    Word2VecSynonymProvider unit = new Word2VecSynonymProvider(model);
+
+    BytesRef inputTerm = newBytesRef("a");
+    List<TermAndBoost> actualSynonymsResults =
+        unit.getSynonyms(inputTerm, MAX_SYNONYMS_PER_TERM, MIN_ACCEPTED_SIMILARITY);
+    assertEquals(0, actualSynonymsResults.size());
+  }
+
+  @Test
+  public void testModel_shouldReturnNormalizedVectors() {
+    Word2VecModel model = new Word2VecModel(4, 2);
+    model.addTermAndVector(new TermAndVector(new BytesRef("a"), new float[] {10, 10}));
+    model.addTermAndVector(new TermAndVector(new BytesRef("b"), new float[] {10, 8}));
+    model.addTermAndVector(new TermAndVector(new BytesRef("c"), new float[] {9, 10}));
+    model.addTermAndVector(new TermAndVector(new BytesRef("f"), new float[] {-1, 10}));
+
+    // vectorValue must return unit-length vectors
+    float[] vectorIdA = model.vectorValue(new BytesRef("a"));
+    float[] vectorIdF = model.vectorValue(new BytesRef("f"));
+    assertArrayEquals(new float[] {0.70710f, 0.70710f}, vectorIdA, 0.001f);
+    assertArrayEquals(new float[] {-0.0995f, 0.99503f}, vectorIdF, 0.001f);
+  }
+
+  @Test
+  public void normalizedVector_shouldReturnModule1() {
+    TermAndVector synonymTerm = new TermAndVector(new BytesRef("a"), new float[] {10, 10});
+    synonymTerm.normalizeVector();
+    float[] vector = synonymTerm.getVector();
+    float len = 0;
+    for (int i = 0; i < vector.length; i++) {
+      len += vector[i] * vector[i];
+    }
+    assertEquals(1, Math.sqrt(len), 0.0001f);
+  }
+}
diff --git a/lucene/analysis/common/src/test/org/apache/lucene/analysis/synonym/word2vec/word2vec-corrupted-vector-dimension-model.zip b/lucene/analysis/common/src/test/org/apache/lucene/analysis/synonym/word2vec/word2vec-corrupted-vector-dimension-model.zip
new file mode 100644
index 000000000000..e25693dd83cf
Binary files /dev/null and b/lucene/analysis/common/src/test/org/apache/lucene/analysis/synonym/word2vec/word2vec-corrupted-vector-dimension-model.zip differ
diff --git a/lucene/analysis/common/src/test/org/apache/lucene/analysis/synonym/word2vec/word2vec-empty-model.zip b/lucene/analysis/common/src/test/org/apache/lucene/analysis/synonym/word2vec/word2vec-empty-model.zip
new file mode 100644
index 000000000000..57d7832dd787
Binary files /dev/null and b/lucene/analysis/common/src/test/org/apache/lucene/analysis/synonym/word2vec/word2vec-empty-model.zip differ
diff --git a/lucene/analysis/common/src/test/org/apache/lucene/analysis/synonym/word2vec/word2vec-model.zip b/lucene/analysis/common/src/test/org/apache/lucene/analysis/synonym/word2vec/word2vec-model.zip
new file mode 100644
index 000000000000..6d31b8d5a3fa
Binary files /dev/null and b/lucene/analysis/common/src/test/org/apache/lucene/analysis/synonym/word2vec/word2vec-model.zip differ
diff --git a/lucene/core/src/java/org/apache/lucene/util/QueryBuilder.java b/lucene/core/src/java/org/apache/lucene/util/QueryBuilder.java
index 3ad17dca5ce9..adf30baec294 100644
--- a/lucene/core/src/java/org/apache/lucene/util/QueryBuilder.java
+++ b/lucene/core/src/java/org/apache/lucene/util/QueryBuilder.java
@@ -62,20 +62,6 @@ public class QueryBuilder {
protected boolean enableGraphQueries = true;
protected boolean autoGenerateMultiTermSynonymsPhraseQuery = false;
- /** Wraps a term and boost */
- public static class TermAndBoost {
- /** the term */
- public final BytesRef term;
- /** the boost */
- public final float boost;
-
- /** Creates a new TermAndBoost */
- public TermAndBoost(BytesRef term, float boost) {
- this.term = BytesRef.deepCopyOf(term);
- this.boost = boost;
- }
- }
-
/** Creates a new QueryBuilder using the given analyzer. */
public QueryBuilder(Analyzer analyzer) {
this.analyzer = analyzer;
diff --git a/lucene/core/src/java/org/apache/lucene/util/TermAndBoost.java b/lucene/core/src/java/org/apache/lucene/util/TermAndBoost.java
new file mode 100644
index 000000000000..0c7958a99a08
--- /dev/null
+++ b/lucene/core/src/java/org/apache/lucene/util/TermAndBoost.java
@@ -0,0 +1,31 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.lucene.util;
+
+/** Wraps a term together with its boost (e.g. a synonym weighted by its similarity score). */
+public class TermAndBoost {
+  /** the wrapped term; a deep copy, so callers may reuse their BytesRef after construction */
+  public final BytesRef term;
+  /** the boost associated with the term */
+  public final float boost;
+
+  /**
+   * Creates a new TermAndBoost.
+   *
+   * @param term the term; deep-copied so later mutation of the argument does not affect this
+   *     instance
+   * @param boost the boost to associate with the term
+   */
+  public TermAndBoost(BytesRef term, float boost) {
+    this.term = BytesRef.deepCopyOf(term);
+    this.boost = boost;
+  }
+}
diff --git a/lucene/core/src/java/org/apache/lucene/util/TermAndVector.java b/lucene/core/src/java/org/apache/lucene/util/TermAndVector.java
new file mode 100644
index 000000000000..1ade19a19803
--- /dev/null
+++ b/lucene/core/src/java/org/apache/lucene/util/TermAndVector.java
@@ -0,0 +1,72 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.lucene.util;
+
+import java.util.Locale;
+
+/**
+ * Word2Vec unit composed by a term with the associated vector.
+ *
+ * <p>The vector array is stored by reference (not copied) and {@link #normalizeVector()} mutates
+ * it in place.
+ *
+ * @lucene.experimental
+ */
+public class TermAndVector {
+
+  private final BytesRef term;
+  private final float[] vector;
+
+  /** Creates a new term/vector pair; the vector array is kept by reference, not copied. */
+  public TermAndVector(BytesRef term, float[] vector) {
+    this.term = term;
+    this.vector = vector;
+  }
+
+  /** Returns the term. */
+  public BytesRef getTerm() {
+    return this.term;
+  }
+
+  /** Returns the (possibly normalized) vector; this is the internal array, not a copy. */
+  public float[] getVector() {
+    return this.vector;
+  }
+
+  /** Returns the vector dimension. */
+  public int size() {
+    return vector.length;
+  }
+
+  /**
+   * Normalizes the vector in place to unit (L2) length.
+   *
+   * <p>NOTE(review): an all-zero vector has length 0, so this division would fill the vector with
+   * NaN — callers presumably never pass zero vectors; confirm upstream.
+   */
+  public void normalizeVector() {
+    float vectorLength = 0;
+    for (int i = 0; i < vector.length; i++) {
+      vectorLength += vector[i] * vector[i];
+    }
+    vectorLength = (float) Math.sqrt(vectorLength);
+    for (int i = 0; i < vector.length; i++) {
+      vector[i] /= vectorLength;
+    }
+  }
+
+  /** Renders the term followed by its components, e.g. {@code term [0.100,0.200]}. */
+  @Override
+  public String toString() {
+    StringBuilder builder = new StringBuilder(this.term.utf8ToString());
+    builder.append(" [");
+    if (vector.length > 0) {
+      for (int i = 0; i < vector.length - 1; i++) {
+        // Locale.ROOT keeps the decimal separator stable regardless of default locale
+        builder.append(String.format(Locale.ROOT, "%.3f,", vector[i]));
+      }
+      builder.append(String.format(Locale.ROOT, "%.3f]", vector[vector.length - 1]));
+    }
+    return builder.toString();
+  }
+}
diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java
index 9f1e6c505254..edefb696b8c5 100644
--- a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java
+++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java
@@ -41,8 +41,17 @@
*/
public final class HnswGraphBuilder {
+ /** Default number of maximum connections per node */
+ public static final int DEFAULT_MAX_CONN = 16;
+
+ /**
+ * Default number of the size of the queue maintained while searching during a graph construction.
+ */
+ public static final int DEFAULT_BEAM_WIDTH = 100;
+
/** Default random seed for level generation * */
private static final long DEFAULT_RAND_SEED = 42;
+
/** A name for the HNSW component for the info-stream * */
public static final String HNSW_COMPONENT = "HNSW";
diff --git a/lucene/core/src/test/org/apache/lucene/index/TestKnnGraph.java b/lucene/core/src/test/org/apache/lucene/index/TestKnnGraph.java
index 08f089430ba5..372384df5572 100644
--- a/lucene/core/src/test/org/apache/lucene/index/TestKnnGraph.java
+++ b/lucene/core/src/test/org/apache/lucene/index/TestKnnGraph.java
@@ -54,6 +54,7 @@
import org.apache.lucene.util.VectorUtil;
import org.apache.lucene.util.hnsw.HnswGraph;
import org.apache.lucene.util.hnsw.HnswGraph.NodesIterator;
+import org.apache.lucene.util.hnsw.HnswGraphBuilder;
import org.junit.After;
import org.junit.Before;
@@ -62,7 +63,7 @@ public class TestKnnGraph extends LuceneTestCase {
private static final String KNN_GRAPH_FIELD = "vector";
- private static int M = Lucene95HnswVectorsFormat.DEFAULT_MAX_CONN;
+ private static int M = HnswGraphBuilder.DEFAULT_MAX_CONN;
private Codec codec;
private Codec float32Codec;
@@ -80,7 +81,7 @@ public void setup() {
new Lucene95Codec() {
@Override
public KnnVectorsFormat getKnnVectorsFormatForField(String field) {
- return new Lucene95HnswVectorsFormat(M, Lucene95HnswVectorsFormat.DEFAULT_BEAM_WIDTH);
+ return new Lucene95HnswVectorsFormat(M, HnswGraphBuilder.DEFAULT_BEAM_WIDTH);
}
};
@@ -92,7 +93,7 @@ public KnnVectorsFormat getKnnVectorsFormatForField(String field) {
new Lucene95Codec() {
@Override
public KnnVectorsFormat getKnnVectorsFormatForField(String field) {
- return new Lucene95HnswVectorsFormat(M, Lucene95HnswVectorsFormat.DEFAULT_BEAM_WIDTH);
+ return new Lucene95HnswVectorsFormat(M, HnswGraphBuilder.DEFAULT_BEAM_WIDTH);
}
};
@@ -103,7 +104,7 @@ public KnnVectorsFormat getKnnVectorsFormatForField(String field) {
new Lucene95Codec() {
@Override
public KnnVectorsFormat getKnnVectorsFormatForField(String field) {
- return new Lucene95HnswVectorsFormat(M, Lucene95HnswVectorsFormat.DEFAULT_BEAM_WIDTH);
+ return new Lucene95HnswVectorsFormat(M, HnswGraphBuilder.DEFAULT_BEAM_WIDTH);
}
};
}
@@ -115,7 +116,7 @@ private VectorEncoding randomVectorEncoding() {
@After
public void cleanup() {
- M = Lucene95HnswVectorsFormat.DEFAULT_MAX_CONN;
+ M = HnswGraphBuilder.DEFAULT_MAX_CONN;
}
/** Basic test of creating documents in a graph */
diff --git a/lucene/test-framework/src/java/org/apache/lucene/tests/analysis/BaseTokenStreamTestCase.java b/lucene/test-framework/src/java/org/apache/lucene/tests/analysis/BaseTokenStreamTestCase.java
index b2ce16d80776..f9cd607ccacc 100644
--- a/lucene/test-framework/src/java/org/apache/lucene/tests/analysis/BaseTokenStreamTestCase.java
+++ b/lucene/test-framework/src/java/org/apache/lucene/tests/analysis/BaseTokenStreamTestCase.java
@@ -55,6 +55,7 @@
import org.apache.lucene.document.TextField;
import org.apache.lucene.index.IndexOptions;
import org.apache.lucene.index.IndexableFieldType;
+import org.apache.lucene.search.BoostAttribute;
import org.apache.lucene.store.Directory;
import org.apache.lucene.tests.index.RandomIndexWriter;
import org.apache.lucene.tests.util.LuceneTestCase;
@@ -154,7 +155,8 @@ public static void assertTokenStreamContents(
boolean[] keywordAtts,
boolean graphOffsetsAreCorrect,
byte[][] payloads,
- int[] flags)
+ int[] flags,
+ float[] boost)
throws IOException {
assertNotNull(output);
CheckClearAttributesAttribute checkClearAtt =
@@ -221,6 +223,12 @@ public static void assertTokenStreamContents(
flagsAtt = ts.getAttribute(FlagsAttribute.class);
}
+ BoostAttribute boostAtt = null;
+ if (boost != null) {
+ assertTrue("has no BoostAttribute", ts.hasAttribute(BoostAttribute.class));
+ boostAtt = ts.getAttribute(BoostAttribute.class);
+ }
+
// Maps position to the start/end offset:
final Map posToStartOffset = new HashMap<>();
final Map posToEndOffset = new HashMap<>();
@@ -243,6 +251,7 @@ public static void assertTokenStreamContents(
if (payloadAtt != null)
payloadAtt.setPayload(new BytesRef(new byte[] {0x00, -0x21, 0x12, -0x43, 0x24}));
if (flagsAtt != null) flagsAtt.setFlags(~0); // all 1's
+ if (boostAtt != null) boostAtt.setBoost(-1f);
checkClearAtt.getAndResetClearCalled(); // reset it, because we called clearAttribute() before
assertTrue("token " + i + " does not exist", ts.incrementToken());
@@ -278,6 +287,9 @@ public static void assertTokenStreamContents(
if (flagsAtt != null) {
assertEquals("flagsAtt " + i + " term=" + termAtt, flags[i], flagsAtt.getFlags());
}
+ if (boostAtt != null) {
+ assertEquals("boostAtt " + i + " term=" + termAtt, boost[i], boostAtt.getBoost(), 0.001);
+ }
if (payloads != null) {
if (payloads[i] != null) {
assertEquals("payloads " + i, new BytesRef(payloads[i]), payloadAtt.getPayload());
@@ -405,6 +417,7 @@ public static void assertTokenStreamContents(
if (payloadAtt != null)
payloadAtt.setPayload(new BytesRef(new byte[] {0x00, -0x21, 0x12, -0x43, 0x24}));
if (flagsAtt != null) flagsAtt.setFlags(~0); // all 1's
+ if (boostAtt != null) boostAtt.setBoost(-1);
checkClearAtt.getAndResetClearCalled(); // reset it, because we called clearAttribute() before
@@ -426,6 +439,38 @@ public static void assertTokenStreamContents(
ts.close();
}
+ public static void assertTokenStreamContents(
+ TokenStream ts,
+ String[] output,
+ int[] startOffsets,
+ int[] endOffsets,
+ String[] types,
+ int[] posIncrements,
+ int[] posLengths,
+ Integer finalOffset,
+ Integer finalPosInc,
+ boolean[] keywordAtts,
+ boolean graphOffsetsAreCorrect,
+ byte[][] payloads,
+ int[] flags)
+ throws IOException {
+ assertTokenStreamContents(
+ ts,
+ output,
+ startOffsets,
+ endOffsets,
+ types,
+ posIncrements,
+ posLengths,
+ finalOffset,
+ finalPosInc,
+ keywordAtts,
+ graphOffsetsAreCorrect,
+ payloads,
+ flags,
+ null);
+ }
+
public static void assertTokenStreamContents(
TokenStream ts,
String[] output,
@@ -438,6 +483,33 @@ public static void assertTokenStreamContents(
boolean[] keywordAtts,
boolean graphOffsetsAreCorrect)
throws IOException {
+ assertTokenStreamContents(
+ ts,
+ output,
+ startOffsets,
+ endOffsets,
+ types,
+ posIncrements,
+ posLengths,
+ finalOffset,
+ keywordAtts,
+ graphOffsetsAreCorrect,
+ null);
+ }
+
+ public static void assertTokenStreamContents(
+ TokenStream ts,
+ String[] output,
+ int[] startOffsets,
+ int[] endOffsets,
+ String[] types,
+ int[] posIncrements,
+ int[] posLengths,
+ Integer finalOffset,
+ boolean[] keywordAtts,
+ boolean graphOffsetsAreCorrect,
+ float[] boost)
+ throws IOException {
assertTokenStreamContents(
ts,
output,
@@ -451,7 +523,8 @@ public static void assertTokenStreamContents(
keywordAtts,
graphOffsetsAreCorrect,
null,
- null);
+ null,
+ boost);
}
public static void assertTokenStreamContents(
@@ -481,9 +554,36 @@ public static void assertTokenStreamContents(
keywordAtts,
graphOffsetsAreCorrect,
payloads,
+ null,
null);
}
+ public static void assertTokenStreamContents(
+ TokenStream ts,
+ String[] output,
+ int[] startOffsets,
+ int[] endOffsets,
+ String[] types,
+ int[] posIncrements,
+ int[] posLengths,
+ Integer finalOffset,
+ boolean graphOffsetsAreCorrect,
+ float[] boost)
+ throws IOException {
+ assertTokenStreamContents(
+ ts,
+ output,
+ startOffsets,
+ endOffsets,
+ types,
+ posIncrements,
+ posLengths,
+ finalOffset,
+ null,
+ graphOffsetsAreCorrect,
+ boost);
+ }
+
public static void assertTokenStreamContents(
TokenStream ts,
String[] output,
@@ -505,7 +605,8 @@ public static void assertTokenStreamContents(
posLengths,
finalOffset,
null,
- graphOffsetsAreCorrect);
+ graphOffsetsAreCorrect,
+ null);
}
public static void assertTokenStreamContents(
@@ -522,6 +623,30 @@ public static void assertTokenStreamContents(
ts, output, startOffsets, endOffsets, types, posIncrements, posLengths, finalOffset, true);
}
+ public static void assertTokenStreamContents(
+ TokenStream ts,
+ String[] output,
+ int[] startOffsets,
+ int[] endOffsets,
+ String[] types,
+ int[] posIncrements,
+ int[] posLengths,
+ Integer finalOffset,
+ float[] boost)
+ throws IOException {
+ assertTokenStreamContents(
+ ts,
+ output,
+ startOffsets,
+ endOffsets,
+ types,
+ posIncrements,
+ posLengths,
+ finalOffset,
+ true,
+ boost);
+ }
+
public static void assertTokenStreamContents(
TokenStream ts,
String[] output,
@@ -649,6 +774,21 @@ public static void assertAnalyzesTo(
int[] posIncrements,
int[] posLengths)
throws IOException {
+ assertAnalyzesTo(
+ a, input, output, startOffsets, endOffsets, types, posIncrements, posLengths, null);
+ }
+
+ public static void assertAnalyzesTo(
+ Analyzer a,
+ String input,
+ String[] output,
+ int[] startOffsets,
+ int[] endOffsets,
+ String[] types,
+ int[] posIncrements,
+ int[] posLengths,
+ float[] boost)
+ throws IOException {
assertTokenStreamContents(
a.tokenStream("dummy", input),
output,
@@ -657,7 +797,8 @@ public static void assertAnalyzesTo(
types,
posIncrements,
posLengths,
- input.length());
+ input.length(),
+ boost);
checkResetException(a, input);
checkAnalysisConsistency(random(), a, true, input);
}