From e22d1f9f69764e559d1f04b20636ea30d676ed0f Mon Sep 17 00:00:00 2001
From: Alessandro Benedetti <a.benedetti@sease.io>
Date: Fri, 28 Apr 2023 11:50:04 +0200
Subject: [PATCH] reasoning about thread safety

---
 .../word2vec/Word2VecSynonymProvider.java     |  10 +-
 .../lucene/util/hnsw/HnswGraphSearcher.java   | 109 ++++++++++--------
 .../lucene/util/hnsw/OnHeapHnswGraph.java     |   1 +
 3 files changed, 60 insertions(+), 60 deletions(-)

diff --git a/lucene/analysis/common/src/java/org/apache/lucene/analysis/synonym/word2vec/Word2VecSynonymProvider.java b/lucene/analysis/common/src/java/org/apache/lucene/analysis/synonym/word2vec/Word2VecSynonymProvider.java
index 0f0394280528..c9227b90b919 100644
--- a/lucene/analysis/common/src/java/org/apache/lucene/analysis/synonym/word2vec/Word2VecSynonymProvider.java
+++ b/lucene/analysis/common/src/java/org/apache/lucene/analysis/synonym/word2vec/Word2VecSynonymProvider.java
@@ -65,15 +65,7 @@ public Word2VecSynonymProvider(Word2VecModel model) throws IOException {
     this.hnswGraph = builder.build(word2VecModel.copy());
   }
 
-  /**
-   * Returns the list of synonyms of a provided term. This method is synchronized because it uses
-   * the {@link org.apache.lucene.util.hnsw.OnHeapHnswGraph} that is not thread-safe.
-   *
-   * @param term term to search to find synonyms
-   * @param maxSynonymsPerTerm limit of synonyms returned
-   * @param minAcceptedSimilarity lower similarity threshold to consider another term as synonym
-   */
-  public synchronized List<TermAndBoost> getSynonyms(
+  public List<TermAndBoost> getSynonyms(
       BytesRef term, int maxSynonymsPerTerm, float minAcceptedSimilarity) throws IOException {
 
     if (term == null) {
diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java
index 4857d5b9d577..1f1a5f9349a7 100644
--- a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java
+++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java
@@ -208,6 +208,11 @@ public NeighborQueue searchLevel(
     return searchLevel(query, topK, level, eps, vectors, graph, null, Integer.MAX_VALUE);
   }
 
+  /**
+   * This method is not thread-safe at the moment because it relies on {@link HnswGraph#seek},
+   * which changes the state of the stateful {@link OnHeapHnswGraph}; for this reason its body is
+   * synchronized on the {@link HnswGraph}.
+   */
   private NeighborQueue searchLevel(
       T query,
       int topK,
@@ -218,69 +223,71 @@ private NeighborQueue searchLevel(
       Bits acceptOrds,
       int visitedLimit)
       throws IOException {
-    int size = graph.size();
-    NeighborQueue results = new NeighborQueue(topK, false);
-    prepareScratchState(vectors.size());
+    synchronized (graph) {
+      int size = graph.size();
+      NeighborQueue results = new NeighborQueue(topK, false);
+      prepareScratchState(vectors.size());
 
-    int numVisited = 0;
-    for (int ep : eps) {
-      if (visited.getAndSet(ep) == false) {
-        if (numVisited >= visitedLimit) {
-          results.markIncomplete();
-          break;
-        }
-        float score = compare(query, vectors, ep);
-        numVisited++;
-        candidates.add(ep, score);
-        if (acceptOrds == null || acceptOrds.get(ep)) {
-          results.add(ep, score);
+      int numVisited = 0;
+      for (int ep : eps) {
+        if (visited.getAndSet(ep) == false) {
+          if (numVisited >= visitedLimit) {
+            results.markIncomplete();
+            break;
+          }
+          float score = compare(query, vectors, ep);
+          numVisited++;
+          candidates.add(ep, score);
+          if (acceptOrds == null || acceptOrds.get(ep)) {
+            results.add(ep, score);
+          }
         }
       }
-    }
 
-    // A bound that holds the minimum similarity to the query vector that a candidate vector must
-    // have to be considered.
-    float minAcceptedSimilarity = Float.NEGATIVE_INFINITY;
-    if (results.size() >= topK) {
-      minAcceptedSimilarity = results.topScore();
-    }
-    while (candidates.size() > 0 && results.incomplete() == false) {
-      // get the best candidate (closest or best scoring)
-      float topCandidateSimilarity = candidates.topScore();
-      if (topCandidateSimilarity < minAcceptedSimilarity) {
-        break;
+      // A bound that holds the minimum similarity to the query vector that a candidate vector must
+      // have to be considered.
+      float minAcceptedSimilarity = Float.NEGATIVE_INFINITY;
+      if (results.size() >= topK) {
+        minAcceptedSimilarity = results.topScore();
       }
-
-      int topCandidateNode = candidates.pop();
-      graph.seek(level, topCandidateNode);
-      int friendOrd;
-      while ((friendOrd = graph.nextNeighbor()) != NO_MORE_DOCS) {
-        assert friendOrd < size : "friendOrd=" + friendOrd + "; size=" + size;
-        if (visited.getAndSet(friendOrd)) {
-          continue;
-        }
-
-        if (numVisited >= visitedLimit) {
-          results.markIncomplete();
+      while (candidates.size() > 0 && results.incomplete() == false) {
+        // get the best candidate (closest or best scoring)
+        float topCandidateSimilarity = candidates.topScore();
+        if (topCandidateSimilarity < minAcceptedSimilarity) {
           break;
         }
-        float friendSimilarity = compare(query, vectors, friendOrd);
-        numVisited++;
-        if (friendSimilarity >= minAcceptedSimilarity) {
-          candidates.add(friendOrd, friendSimilarity);
-          if (acceptOrds == null || acceptOrds.get(friendOrd)) {
-            if (results.insertWithOverflow(friendOrd, friendSimilarity) && results.size() >= topK) {
-              minAcceptedSimilarity = results.topScore();
+
+        int topCandidateNode = candidates.pop();
+        graph.seek(level, topCandidateNode);
+        int friendOrd;
+        while ((friendOrd = graph.nextNeighbor()) != NO_MORE_DOCS) {
+          assert friendOrd < size : "friendOrd=" + friendOrd + "; size=" + size;
+          if (visited.getAndSet(friendOrd)) {
+            continue;
+          }
+
+          if (numVisited >= visitedLimit) {
+            results.markIncomplete();
+            break;
+          }
+          float friendSimilarity = compare(query, vectors, friendOrd);
+          numVisited++;
+          if (friendSimilarity >= minAcceptedSimilarity) {
+            candidates.add(friendOrd, friendSimilarity);
+            if (acceptOrds == null || acceptOrds.get(friendOrd)) {
+              if (results.insertWithOverflow(friendOrd, friendSimilarity) && results.size() >= topK) {
+                minAcceptedSimilarity = results.topScore();
+              }
             }
           }
         }
       }
+      while (results.size() > topK) {
+        results.pop();
+      }
+      results.setVisitedCount(numVisited);
+      return results;
     }
-    while (results.size() > topK) {
-      results.pop();
-    }
-    results.setVisitedCount(numVisited);
-    return results;
   }
 
   private float compare(T query, RandomAccessVectorValues<T> vectors, int ord) throws IOException {
diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/OnHeapHnswGraph.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/OnHeapHnswGraph.java
index 9862536de08c..c473c9558224 100644
--- a/lucene/core/src/java/org/apache/lucene/util/hnsw/OnHeapHnswGraph.java
+++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/OnHeapHnswGraph.java
@@ -28,6 +28,7 @@
 /**
  * An {@link HnswGraph} where all nodes and connections are held in memory. This class is used to
  * construct the HNSW graph before it's written to the index.
+ * This class is stateful and not thread-safe: it keeps iterator state in fields (cur, upto).
  */
 public final class OnHeapHnswGraph extends HnswGraph implements Accountable {