From e22d1f9f69764e559d1f04b20636ea30d676ed0f Mon Sep 17 00:00:00 2001 From: Alessandro Benedetti <a.benedetti@sease.io> Date: Fri, 28 Apr 2023 11:50:04 +0200 Subject: [PATCH] reasoning about thread safety --- .../word2vec/Word2VecSynonymProvider.java | 10 +- .../lucene/util/hnsw/HnswGraphSearcher.java | 109 ++++++++++-------- .../lucene/util/hnsw/OnHeapHnswGraph.java | 1 + 3 files changed, 60 insertions(+), 60 deletions(-) diff --git a/lucene/analysis/common/src/java/org/apache/lucene/analysis/synonym/word2vec/Word2VecSynonymProvider.java b/lucene/analysis/common/src/java/org/apache/lucene/analysis/synonym/word2vec/Word2VecSynonymProvider.java index 0f0394280528..c9227b90b919 100644 --- a/lucene/analysis/common/src/java/org/apache/lucene/analysis/synonym/word2vec/Word2VecSynonymProvider.java +++ b/lucene/analysis/common/src/java/org/apache/lucene/analysis/synonym/word2vec/Word2VecSynonymProvider.java @@ -65,15 +65,7 @@ public Word2VecSynonymProvider(Word2VecModel model) throws IOException { this.hnswGraph = builder.build(word2VecModel.copy()); } - /** - * Returns the list of synonyms of a provided term. This method is synchronized because it uses - * the {@link org.apache.lucene.util.hnsw.OnHeapHnswGraph} that is not thread-safe. - * - * @param term term to search to find synonyms - * @param maxSynonymsPerTerm limit of synonyms returned - * @param minAcceptedSimilarity lower similarity threshold to consider another term as synonym - */ - public synchronized List<TermAndBoost> getSynonyms( + public List<TermAndBoost> getSynonyms( BytesRef term, int maxSynonymsPerTerm, float minAcceptedSimilarity) throws IOException { if (term == null) { diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java index 4857d5b9d577..1f1a5f9349a7 100644 --- a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java +++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java @@ -208,6 +208,11 @@ public NeighborQueue searchLevel( return searchLevel(query, topK, level, eps, vectors, graph, null, Integer.MAX_VALUE); } + /** + * <p>This method is not thread-safe at the moment because it relies on {@link HnswGraph#seek} that changes the status of the + * stateful {@link OnHeapHnswGraph}, for this reason is synchronized on the {@link HnswGraph}. + * + */ private NeighborQueue searchLevel( T query, int topK, @@ -218,69 +223,71 @@ private NeighborQueue searchLevel( Bits acceptOrds, int visitedLimit) throws IOException { - int size = graph.size(); - NeighborQueue results = new NeighborQueue(topK, false); - prepareScratchState(vectors.size()); + synchronized (graph) { + int size = graph.size(); + NeighborQueue results = new NeighborQueue(topK, false); + prepareScratchState(vectors.size()); - int numVisited = 0; - for (int ep : eps) { - if (visited.getAndSet(ep) == false) { - if (numVisited >= visitedLimit) { - results.markIncomplete(); - break; - } - float score = compare(query, vectors, ep); - numVisited++; - candidates.add(ep, score); - if (acceptOrds == null || acceptOrds.get(ep)) { - results.add(ep, score); + int numVisited = 0; + for (int ep : eps) { + if (visited.getAndSet(ep) == false) { + if (numVisited >= visitedLimit) { + results.markIncomplete(); + break; + } + float score = compare(query, vectors, ep); + numVisited++; + candidates.add(ep, score); + if (acceptOrds == null || acceptOrds.get(ep)) { + results.add(ep, score); + } } } - } - // A bound that holds the minimum similarity to the query vector that a candidate vector must - // have to be considered. - float minAcceptedSimilarity = Float.NEGATIVE_INFINITY; - if (results.size() >= topK) { - minAcceptedSimilarity = results.topScore(); - } - while (candidates.size() > 0 && results.incomplete() == false) { - // get the best candidate (closest or best scoring) - float topCandidateSimilarity = candidates.topScore(); - if (topCandidateSimilarity < minAcceptedSimilarity) { - break; + // A bound that holds the minimum similarity to the query vector that a candidate vector must + // have to be considered. + float minAcceptedSimilarity = Float.NEGATIVE_INFINITY; + if (results.size() >= topK) { + minAcceptedSimilarity = results.topScore(); } - - int topCandidateNode = candidates.pop(); - graph.seek(level, topCandidateNode); - int friendOrd; - while ((friendOrd = graph.nextNeighbor()) != NO_MORE_DOCS) { - assert friendOrd < size : "friendOrd=" + friendOrd + "; size=" + size; - if (visited.getAndSet(friendOrd)) { - continue; - } - - if (numVisited >= visitedLimit) { - results.markIncomplete(); + while (candidates.size() > 0 && results.incomplete() == false) { + // get the best candidate (closest or best scoring) + float topCandidateSimilarity = candidates.topScore(); + if (topCandidateSimilarity < minAcceptedSimilarity) { break; } - float friendSimilarity = compare(query, vectors, friendOrd); - numVisited++; - if (friendSimilarity >= minAcceptedSimilarity) { - candidates.add(friendOrd, friendSimilarity); - if (acceptOrds == null || acceptOrds.get(friendOrd)) { - if (results.insertWithOverflow(friendOrd, friendSimilarity) && results.size() >= topK) { - minAcceptedSimilarity = results.topScore(); + + int topCandidateNode = candidates.pop(); + graph.seek(level, topCandidateNode); + int friendOrd; + while ((friendOrd = graph.nextNeighbor()) != NO_MORE_DOCS) { + assert friendOrd < size : "friendOrd=" + friendOrd + "; size=" + size; + if (visited.getAndSet(friendOrd)) { + continue; + } + + if (numVisited >= visitedLimit) { + results.markIncomplete(); + break; + } + float friendSimilarity = compare(query, vectors, friendOrd); + numVisited++; + if (friendSimilarity >= minAcceptedSimilarity) { + candidates.add(friendOrd, friendSimilarity); + if (acceptOrds == null || acceptOrds.get(friendOrd)) { + if (results.insertWithOverflow(friendOrd, friendSimilarity) && results.size() >= topK) { + minAcceptedSimilarity = results.topScore(); + } } } } } + while (results.size() > topK) { + results.pop(); + } + results.setVisitedCount(numVisited); + return results; } - while (results.size() > topK) { - results.pop(); - } - results.setVisitedCount(numVisited); - return results; } private float compare(T query, RandomAccessVectorValues<T> vectors, int ord) throws IOException { diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/OnHeapHnswGraph.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/OnHeapHnswGraph.java index 9862536de08c..c473c9558224 100644 --- a/lucene/core/src/java/org/apache/lucene/util/hnsw/OnHeapHnswGraph.java +++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/OnHeapHnswGraph.java @@ -28,6 +28,7 @@ /** * An {@link HnswGraph} where all nodes and connections are held in memory. This class is used to * construct the HNSW graph before it's written to the index. + * This class is stateful and not thread-safe as there are iterator members (cur, upto). */ public final class OnHeapHnswGraph extends HnswGraph implements Accountable {