Skip to content

Commit

Permalink
reasoning about thread safety
Browse files Browse the repository at this point in the history
  • Loading branch information
alessandrobenedetti committed Apr 28, 2023
1 parent c15078f commit e22d1f9
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 60 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -65,15 +65,7 @@ public Word2VecSynonymProvider(Word2VecModel model) throws IOException {
this.hnswGraph = builder.build(word2VecModel.copy());
}

/**
* Returns the list of synonyms of a provided term. This method is synchronized because it uses
* the {@link org.apache.lucene.util.hnsw.OnHeapHnswGraph} that is not thread-safe.
*
* @param term term to search to find synonyms
* @param maxSynonymsPerTerm limit of synonyms returned
* @param minAcceptedSimilarity lower similarity threshold to consider another term as synonym
*/
public synchronized List<TermAndBoost> getSynonyms(
public List<TermAndBoost> getSynonyms(
BytesRef term, int maxSynonymsPerTerm, float minAcceptedSimilarity) throws IOException {

if (term == null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,11 @@ public NeighborQueue searchLevel(
return searchLevel(query, topK, level, eps, vectors, graph, null, Integer.MAX_VALUE);
}

/**
* <p>This method is not inherently thread-safe because it relies on {@link HnswGraph#seek},
* which changes the state of the stateful {@link OnHeapHnswGraph}. For this reason, the search
* is synchronized on the {@link HnswGraph} instance.
*
*/
private NeighborQueue searchLevel(
T query,
int topK,
Expand All @@ -218,69 +223,71 @@ private NeighborQueue searchLevel(
Bits acceptOrds,
int visitedLimit)
throws IOException {
int size = graph.size();
NeighborQueue results = new NeighborQueue(topK, false);
prepareScratchState(vectors.size());
synchronized (graph) {
int size = graph.size();
NeighborQueue results = new NeighborQueue(topK, false);
prepareScratchState(vectors.size());

int numVisited = 0;
for (int ep : eps) {
if (visited.getAndSet(ep) == false) {
if (numVisited >= visitedLimit) {
results.markIncomplete();
break;
}
float score = compare(query, vectors, ep);
numVisited++;
candidates.add(ep, score);
if (acceptOrds == null || acceptOrds.get(ep)) {
results.add(ep, score);
int numVisited = 0;
for (int ep : eps) {
if (visited.getAndSet(ep) == false) {
if (numVisited >= visitedLimit) {
results.markIncomplete();
break;
}
float score = compare(query, vectors, ep);
numVisited++;
candidates.add(ep, score);
if (acceptOrds == null || acceptOrds.get(ep)) {
results.add(ep, score);
}
}
}
}

// A bound that holds the minimum similarity to the query vector that a candidate vector must
// have to be considered.
float minAcceptedSimilarity = Float.NEGATIVE_INFINITY;
if (results.size() >= topK) {
minAcceptedSimilarity = results.topScore();
}
while (candidates.size() > 0 && results.incomplete() == false) {
// get the best candidate (closest or best scoring)
float topCandidateSimilarity = candidates.topScore();
if (topCandidateSimilarity < minAcceptedSimilarity) {
break;
// A bound that holds the minimum similarity to the query vector that a candidate vector must
// have to be considered.
float minAcceptedSimilarity = Float.NEGATIVE_INFINITY;
if (results.size() >= topK) {
minAcceptedSimilarity = results.topScore();
}

int topCandidateNode = candidates.pop();
graph.seek(level, topCandidateNode);
int friendOrd;
while ((friendOrd = graph.nextNeighbor()) != NO_MORE_DOCS) {
assert friendOrd < size : "friendOrd=" + friendOrd + "; size=" + size;
if (visited.getAndSet(friendOrd)) {
continue;
}

if (numVisited >= visitedLimit) {
results.markIncomplete();
while (candidates.size() > 0 && results.incomplete() == false) {
// get the best candidate (closest or best scoring)
float topCandidateSimilarity = candidates.topScore();
if (topCandidateSimilarity < minAcceptedSimilarity) {
break;
}
float friendSimilarity = compare(query, vectors, friendOrd);
numVisited++;
if (friendSimilarity >= minAcceptedSimilarity) {
candidates.add(friendOrd, friendSimilarity);
if (acceptOrds == null || acceptOrds.get(friendOrd)) {
if (results.insertWithOverflow(friendOrd, friendSimilarity) && results.size() >= topK) {
minAcceptedSimilarity = results.topScore();

int topCandidateNode = candidates.pop();
graph.seek(level, topCandidateNode);
int friendOrd;
while ((friendOrd = graph.nextNeighbor()) != NO_MORE_DOCS) {
assert friendOrd < size : "friendOrd=" + friendOrd + "; size=" + size;
if (visited.getAndSet(friendOrd)) {
continue;
}

if (numVisited >= visitedLimit) {
results.markIncomplete();
break;
}
float friendSimilarity = compare(query, vectors, friendOrd);
numVisited++;
if (friendSimilarity >= minAcceptedSimilarity) {
candidates.add(friendOrd, friendSimilarity);
if (acceptOrds == null || acceptOrds.get(friendOrd)) {
if (results.insertWithOverflow(friendOrd, friendSimilarity) && results.size() >= topK) {
minAcceptedSimilarity = results.topScore();
}
}
}
}
}
while (results.size() > topK) {
results.pop();
}
results.setVisitedCount(numVisited);
return results;
}
while (results.size() > topK) {
results.pop();
}
results.setVisitedCount(numVisited);
return results;
}

private float compare(T query, RandomAccessVectorValues<T> vectors, int ord) throws IOException {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
/**
* An {@link HnswGraph} where all nodes and connections are held in memory. This class is used to
* construct the HNSW graph before it's written to the index.
* This class is stateful and not thread-safe because it keeps its iteration state in member
* fields (cur, upto).
*/
public final class OnHeapHnswGraph extends HnswGraph implements Accountable {

Expand Down

0 comments on commit e22d1f9

Please sign in to comment.