From c047b06dfeda64ff237c58ed2cf18c384f412b14 Mon Sep 17 00:00:00 2001 From: Jonathan Ellis Date: Fri, 28 Apr 2023 08:27:50 -0500 Subject: [PATCH] add ConcurrentOnHeapHnswGraph and Builder Tests run against the new Concurrent classes, except where that doesn't make sense --- .../org/apache/lucene/util/AtomicBitSet.java | 165 +++++++++ .../util/hnsw/ConcurrentHnswGraphBuilder.java | 337 ++++++++++++++++++ .../util/hnsw/ConcurrentNeighborSet.java | 161 +++++++++ .../util/hnsw/ConcurrentOnHeapHnswGraph.java | 291 +++++++++++++++ .../apache/lucene/util/hnsw/HnswGraph.java | 5 +- .../lucene/util/hnsw/HnswGraphSearcher.java | 2 - .../lucene/util/hnsw/OnHeapHnswGraph.java | 2 + .../lucene/util/hnsw/HnswGraphTestCase.java | 97 +++-- .../util/hnsw/TestConcurrentNeighborSet.java | 78 ++++ .../util/hnsw/TestHnswFloatVectorGraph.java | 7 +- 10 files changed, 1085 insertions(+), 60 deletions(-) create mode 100644 lucene/core/src/java/org/apache/lucene/util/AtomicBitSet.java create mode 100644 lucene/core/src/java/org/apache/lucene/util/hnsw/ConcurrentHnswGraphBuilder.java create mode 100644 lucene/core/src/java/org/apache/lucene/util/hnsw/ConcurrentNeighborSet.java create mode 100644 lucene/core/src/java/org/apache/lucene/util/hnsw/ConcurrentOnHeapHnswGraph.java create mode 100644 lucene/core/src/test/org/apache/lucene/util/hnsw/TestConcurrentNeighborSet.java diff --git a/lucene/core/src/java/org/apache/lucene/util/AtomicBitSet.java b/lucene/core/src/java/org/apache/lucene/util/AtomicBitSet.java new file mode 100644 index 000000000000..92dbec8747a7 --- /dev/null +++ b/lucene/core/src/java/org/apache/lucene/util/AtomicBitSet.java @@ -0,0 +1,165 @@ +package org.apache.lucene.util; + +import java.util.concurrent.atomic.AtomicLongArray; +import org.apache.lucene.search.DocIdSetIterator; + +/** + * A {@link BitSet} implementation that offers concurrent, lock-free access through an {@link + * AtomicLongArray} as bit storage. + */ +public class AtomicBitSet extends BitSet { + private static final long BASE_RAM_BYTES_USED = + RamUsageEstimator.shallowSizeOfInstance(AtomicBitSet.class); + + private AtomicLongArray storage; + private int numBits; + + public AtomicBitSet(int numBits) { + this.numBits = numBits; + int numLongs = (numBits + 63) >>> 6; + storage = new AtomicLongArray(numLongs); + } + + private static int index(int bit) { + return bit >>> 6; + } + + private static long mask(int bit) { + return 1L << (bit & 63); + } + + @Override + public int length() { + return numBits; + } + + private void expandStorage(int minCapacity) { + int numLongs = (minCapacity + 63) >>> 6; + AtomicLongArray newStorage = new AtomicLongArray(numLongs); + for (int i = 0; i < storage.length(); i++) { + newStorage.set(i, storage.get(i)); + } + storage = newStorage; + numBits = numLongs << 6; + } + + @Override + public void set(int i) { + if (i >= numBits) { + expandStorage(i + 1); + } + int idx = index(i); + long mask = mask(i); + long prev, next; + do { + prev = storage.get(idx); + next = prev | mask; + } while (!storage.compareAndSet(idx, prev, next)); + } + + @Override + public boolean get(int i) { + if (i >= numBits) { + return false; + } + int idx = index(i); + long mask = mask(i); + long value = storage.get(idx); + return (value & mask) != 0; + } + + @Override + public boolean getAndSet(int i) { + if (i >= numBits) { + expandStorage(i + 1); + } + int idx = index(i); + long mask = mask(i); + long prev, next; + do { + prev = storage.get(idx); + next = prev | mask; + } while (!storage.compareAndSet(idx, prev, next)); + return (prev & mask) != 0; + } + + @Override + public void clear(int i) { + if (i >= numBits) { + return; + } + int idx = index(i); + long mask = mask(i); + long prev, next; + do { + prev = storage.get(idx); + next = prev & ~mask; + } while (!storage.compareAndSet(idx, prev, next)); + } + + @Override + public void clear(int startIndex, int endIndex) { + int startIdx = index(startIndex); + int endIdx = index(endIndex - 1); + for (int i = startIdx; i <= endIdx; i++) { + long prev, next; + do { + prev = storage.get(i); + next = (i == startIdx) ? prev & ~(-1L << (startIndex & 63)) : prev; + next = (i == endIdx) ? next & (-1L << (endIndex & 63)) : next; + } while (!storage.compareAndSet(i, prev, next)); + } + } + + @Override + public int cardinality() { + int count = 0; + for (int i = 0; i < storage.length(); i++) { + count += Long.bitCount(storage.get(i)); + } + return count; + } + + @Override + public int approximateCardinality() { + return cardinality(); + } + + @Override + public int prevSetBit(int index) { + int idx = index(index); + long mask = (1L << (index & 63)) - 1; + + for (int i = idx; i >= 0; i--) { + long word = storage.get(i) & mask; + if (word != 0) { + return (i << 6) + Long.numberOfTrailingZeros(Long.lowestOneBit(word)); + } + mask = -1L; + } + return -1; + } + + @Override + public int nextSetBit(int index) { + int idx = index(index); + long mask = -1L >>> (63 - (index & 63)); + + for (int i = idx; i < storage.length(); i++) { + long word = storage.get(i) & mask; + if (word != 0) { + return (i << 6) + Long.numberOfTrailingZeros(Long.lowestOneBit(word)); + } + mask = -1L; + } + return DocIdSetIterator.NO_MORE_DOCS; + } + + @Override + public long ramBytesUsed() { + final int longSizeInBytes = Long.BYTES; + final int arrayOverhead = 16; // Estimated overhead of AtomicLongArray object in bytes + long storageSize = (long) storage.length() * longSizeInBytes + arrayOverhead; + return BASE_RAM_BYTES_USED + storageSize; + } +} diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/ConcurrentHnswGraphBuilder.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/ConcurrentHnswGraphBuilder.java new file mode 100644 index 000000000000..46243720bd3f --- /dev/null +++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/ConcurrentHnswGraphBuilder.java @@ -0,0 +1,337 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.lucene.util.hnsw; + +import static java.lang.Math.log; + +import java.io.IOException; +import java.util.Locale; +import java.util.Objects; +import java.util.concurrent.ConcurrentSkipListSet; +import java.util.concurrent.ThreadLocalRandom; +import java.util.concurrent.TimeUnit; +import java.util.stream.IntStream; +import org.apache.lucene.index.VectorEncoding; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.util.AtomicBitSet; +import org.apache.lucene.util.FixedBitSet; +import org.apache.lucene.util.InfoStream; +import org.apache.lucene.util.hnsw.ConcurrentOnHeapHnswGraph.NodeAtLevel; + +/** + * Builder for Concurrent HNSW graph. See {@link HnswGraph} for a high level overview, + * and the comments to `addGraphNode` for details on the concurrent building approach. + * + * @param the type of vector + */ +public final class ConcurrentHnswGraphBuilder { + + /** Default number of maximum connections per node */ + public static final int DEFAULT_MAX_CONN = 16; + + /** + * Default number of the size of the queue maintained while searching during a graph construction. + */ + public static final int DEFAULT_BEAM_WIDTH = 100; + + /** A name for the HNSW component for the info-stream * */ + public static final String HNSW_COMPONENT = "HNSW"; + + private final int beamWidth; + private final double ml; + private final ThreadLocal scratchNeighbors; + + private final VectorSimilarityFunction similarityFunction; + private final boolean parallelBuild; + private final VectorEncoding vectorEncoding; + private final RandomAccessVectorValues vectors; + private final ThreadLocal> graphSearcher; + + final ConcurrentOnHeapHnswGraph hnsw; + private final ConcurrentSkipListSet insertionsInProgress = + new ConcurrentSkipListSet<>(); + + private InfoStream infoStream = InfoStream.getDefault(); + + // we need two sources of vectors in order to perform diversity check comparisons without + // colliding + private final RandomAccessVectorValues vectorsCopy; + private final AtomicBitSet initializedNodes; + + /** + * This factory matches HnswGraphBuilder's signature for convenience. "_seed" is ignored since the + * Concurrent classes use ThreadLocalRandom. Building will be done in parallel. + */ + public static ConcurrentHnswGraphBuilder create( + RandomAccessVectorValues vectors, + VectorEncoding vectorEncoding, + VectorSimilarityFunction similarityFunction, + int M, + int beamWidth, + long _seed) + throws IOException { + return new ConcurrentHnswGraphBuilder<>( + vectors, vectorEncoding, similarityFunction, M, beamWidth, true); + } + + /** This is the "native" factory for ConcurrentHnswGraphBuilder. */ + public static ConcurrentHnswGraphBuilder create( + RandomAccessVectorValues vectors, + VectorEncoding vectorEncoding, + VectorSimilarityFunction similarityFunction, + int M, + int beamWidth, + boolean parallelize) + throws IOException { + return new ConcurrentHnswGraphBuilder<>( + vectors, vectorEncoding, similarityFunction, M, beamWidth, parallelize); + } + + /** + * Reads all the vectors from vector values, builds a graph connecting them by their dense + * ordinals, using the given hyperparameter settings, and returns the resulting graph. + * + * @param vectors the vectors whose relations are represented by the graph - must provide a + * different view over those vectors than the one used to add via addGraphNode. + * @param M – graph fanout parameter used to calculate the maximum number of connections a node + * can have – M on upper layers, and M * 2 on the lowest level. + * @param beamWidth the size of the beam search to use when finding nearest neighbors. + * @param parallelize use multiple threads to build the graph. + */ + private ConcurrentHnswGraphBuilder( + RandomAccessVectorValues vectors, + VectorEncoding vectorEncoding, + VectorSimilarityFunction similarityFunction, + int M, + int beamWidth, + boolean parallelize) + throws IOException { + this.vectors = vectors; + this.vectorsCopy = vectors.copy(); + this.vectorEncoding = Objects.requireNonNull(vectorEncoding); + this.similarityFunction = Objects.requireNonNull(similarityFunction); + this.parallelBuild = parallelize; + if (M <= 0) { + throw new IllegalArgumentException("maxConn must be positive"); + } + if (beamWidth <= 0) { + throw new IllegalArgumentException("beamWidth must be positive"); + } + this.beamWidth = beamWidth; + // normalization factor for level generation; currently not configurable + this.ml = M == 1 ? 1 : 1 / Math.log(1.0 * M); + this.hnsw = new ConcurrentOnHeapHnswGraph(M); + this.graphSearcher = + ThreadLocal.withInitial( + () -> { + return new HnswGraphSearcher<>( + vectorEncoding, + similarityFunction, + new NeighborQueue(beamWidth, true), + new FixedBitSet(this.vectors.size())); + }); + // in scratch we store candidates in reverse order: worse candidates are first + scratchNeighbors = + ThreadLocal.withInitial(() -> new NeighborArray(Math.max(beamWidth, M + 1), false)); + this.initializedNodes = new AtomicBitSet(vectors.size()); + } + + /** + * Reads all the vectors from two copies of a {@link RandomAccessVectorValues}. Providing two + * copies enables efficient retrieval without extra data copying, while avoiding collision of the + * returned values. + * + * @param vectorsToAdd the vectors for which to build a nearest neighbors graph. Must be an + * independent accessor for the vectors + */ + public ConcurrentOnHeapHnswGraph build(RandomAccessVectorValues vectorsToAdd) + throws IOException { + if (vectorsToAdd == this.vectors) { + throw new IllegalArgumentException( + "Vectors to build must be independent of the source of vectors provided to HnswGraphBuilder()"); + } + if (infoStream.isEnabled(HNSW_COMPONENT)) { + infoStream.message(HNSW_COMPONENT, "build graph from " + vectorsToAdd.size() + " vectors"); + } + addVectors(vectorsToAdd); + return hnsw; + } + + private void addVectors(RandomAccessVectorValues vectorsToAdd) throws IOException { + var stream = IntStream.range(0, vectorsToAdd.size()); + if (parallelBuild) { // TODO + stream = stream.parallel(); + } + stream.forEach( + node -> { + try { + addGraphNode(node, vectorsToAdd); + } catch (IOException e) { + throw new RuntimeException(e); + } + }); + } + + /** Set info-stream to output debugging information * */ + public void setInfoStream(InfoStream infoStream) { + this.infoStream = infoStream; + } + + public ConcurrentOnHeapHnswGraph getGraph() { + return hnsw; + } + + /** + * Inserts a doc with vector value to the graph. + *

+ * To allow correctness under concurrency, we track in-progress updates in a + * ConcurrentSkipListSet. After adding ourselves, we take a snapshot of this set, and consider all + * other in-progress updates as neighbor candidates (subject to normal level constraints). + */ + public void addGraphNode(int node, T value) throws IOException { + if (initializedNodes.getAndSet(node)) { + return; // already initialized + } + + // do this before adding to in-progress, so a concurrent writer checking + // the in-progress set doesn't have to worry about uninitialized neighbor sets + final int nodeLevel = getRandomGraphLevel(ml); + for (int level = nodeLevel; level >= 0; level--) { + hnsw.addNode(level, node); + } + + var progressMarker = new NodeAtLevel(nodeLevel, node); + insertionsInProgress.add(progressMarker); + var inProgressBefore = insertionsInProgress.clone(); + try { + NeighborQueue candidates; + int curMaxLevel = hnsw.numLevels() - 1; + + // find ANN of the new node by searching the graph + int ep = hnsw.entryNode(); + int[] entryPoints = ep >= 0 ? new int[] {ep} : new int[0]; + // for levels > nodeLevel search with topk = 1 + for (int level = curMaxLevel; level > nodeLevel; level--) { + candidates = + graphSearcher.get().searchLevel(value, 1, level, entryPoints, vectors, hnsw.getView()); + entryPoints = new int[] {candidates.pop()}; + } + // for levels <= nodeLevel search with topk = beamWidth, and add connections + for (int level = Math.min(nodeLevel, curMaxLevel); level >= 0; level--) { + // find best candidates at this level with a beam search + candidates = + graphSearcher + .get() + .searchLevel(value, beamWidth, level, entryPoints, vectors, hnsw.getView()); + // any nodes that are being added concurrently at this level are also candidates + for (var concurrentCandidate : inProgressBefore) { + if (concurrentCandidate.level < level || concurrentCandidate == progressMarker) { + continue; + } + float score = scoreBetween(value, vectors.vectorValue(concurrentCandidate.node)); + candidates.add(concurrentCandidate.node, score); + if (candidates.size() > beamWidth) { + candidates.pop(); + } + } + // update entry points and neighbors with these candidates + entryPoints = candidates.nodes(); + addDiverseNeighbors(level, node, candidates); + } + + // update entry node last, once everything is wired together + hnsw.maybeUpdateEntryNode(nodeLevel, node); + } finally { + insertionsInProgress.remove(progressMarker); + } + } + + public void addGraphNode(int node, RandomAccessVectorValues values) throws IOException { + addGraphNode(node, values.vectorValue(node)); + } + + private long printGraphBuildStatus(int node, long start, long t) { + long now = System.nanoTime(); + infoStream.message( + HNSW_COMPONENT, + String.format( + Locale.ROOT, + "built %d in %d/%d ms", + node, + TimeUnit.NANOSECONDS.toMillis(now - t), + TimeUnit.NANOSECONDS.toMillis(now - start))); + return now; + } + + private void addDiverseNeighbors(int level, int newNode, NeighborQueue candidates) { + // Add links from new node -> candidates. See ConcurrentNeighborSet for an explanation of + // "diverse." + var neighbors = hnsw.getNeighbors(level, newNode); + var scratch = popToScratch(candidates); + neighbors.insertDiverse(scratch, this::scoreBetween); + + // Add links from candidates -> new node (again applying diversity heuristic) + neighbors.stream() + .forEach( + entry -> { + var nbr = entry.getValue(); + var nbrScore = entry.getKey(); + var nbrNbr = hnsw.getNeighbors(level, nbr); + nbrNbr.insert(newNode, nbrScore, this::scoreBetween); + }); + } + + private NeighborArray popToScratch(NeighborQueue candidates) { + var scratch = this.scratchNeighbors.get(); + scratch.clear(); + int candidateCount = candidates.size(); + // extract all the Neighbors from the queue into an array; these will now be + // sorted from worst to best + for (int i = 0; i < candidateCount; i++) { + float maxSimilarity = candidates.topScore(); + scratch.add(candidates.pop(), maxSimilarity); + } + return scratch; + } + + private float scoreBetween(int i, int j) { + try { + final T v1 = vectors.vectorValue(i); + final T v2 = vectors.vectorValue(j); + return scoreBetween(v1, v2); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + private float scoreBetween(T v1, T v2) { + return switch (vectorEncoding) { + case BYTE -> similarityFunction.compare((byte[]) v1, (byte[]) v2); + case FLOAT32 -> similarityFunction.compare((float[]) v1, (float[]) v2); + }; + } + + private static int getRandomGraphLevel(double ml) { + double randDouble; + do { + randDouble = + ThreadLocalRandom.current().nextDouble(); // avoid 0 value, as log(0) is undefined + } while (randDouble == 0.0); + return ((int) (-log(randDouble) * ml)); + } +} diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/ConcurrentNeighborSet.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/ConcurrentNeighborSet.java new file mode 100644 index 000000000000..1b373eb9aaa2 --- /dev/null +++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/ConcurrentNeighborSet.java @@ -0,0 +1,161 @@ +package org.apache.lucene.util.hnsw; + +import java.util.*; +import java.util.Map.Entry; +import java.util.concurrent.ConcurrentSkipListSet; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.BiFunction; +import java.util.stream.Stream; +import org.apache.lucene.util.NumericUtils; + +/** + * A concurrent set of neighbors + * + *

Neighbors are stored in a concurrent navigable set by encoding ordinal and score together in a + * long. This means we can quickly iterate either forwards, or backwards. + * + *

The maximum connection count is loosely maintained -- meaning, we tolerate temporarily + * exceeding the max size by a number of elements up to the number of threads performing concurrent + * inserts, but it will always be reduced back to the cap immediately afterwards. This avoids taking + * out a Big Lock to impose a strict cap. + */ +public class ConcurrentNeighborSet { + private final ConcurrentSkipListSet neighbors; + private final int maxConnections; + private final AtomicInteger size; + + public ConcurrentNeighborSet(int maxConnections) { + this.maxConnections = maxConnections; + neighbors = new ConcurrentSkipListSet<>(Comparator.naturalOrder().reversed()); + size = new AtomicInteger(); + } + + public Iterator nodeIterator() { + return neighbors.stream().map(ConcurrentNeighborSet::decodeNodeId).iterator(); + } + + public int size() { + return size.get(); + } + + public Stream> stream() { + return neighbors.stream() + .map(encoded -> Map.entry(decodeScore(encoded), decodeNodeId(encoded))); + } + + /** + * For each candidate (going from best to worst), select it only if it is closer to target than it + * is to any of the already-selected neighbors. This is maintained whether those other neighbors + * were selected by this method, or were added as a "backlink" to a node inserted concurrently + * that chose this one as a neighbor. + */ + public void insertDiverse( + NeighborArray candidates, BiFunction scoreBetween) { + for (int i = candidates.size() - 1; neighbors.size() < maxConnections && i >= 0; i--) { + int cNode = candidates.node[i]; + float cScore = candidates.score[i]; + // TODO in the paper, the diversity requirement is only enforced when there are more than + // maxConn + if (isDiverse(cNode, cScore, scoreBetween)) { + // raw inserts (invoked by other threads inserting neighbors) could happen concurrently, + // so don't "cheat" and do a raw put() + insert(cNode, cScore, scoreBetween); + } + } + // TODO follow the paper's suggestion and fill up the rest of the neighbors with non-diverse + // candidates? + } + + /** + * Insert a new neighbor, maintaining our size cap by removing the least diverse neighbor if + * necessary. + */ + public void insert(int node, float score, BiFunction scoreBetween) { + neighbors.add(encode(node, score)); + if (size.incrementAndGet() > maxConnections) { + removeLeastDiverse(scoreBetween); + size.decrementAndGet(); + } + } + + // is the candidate node with the given score closer to the base node than it is to any of the + // existing neighbors + private boolean isDiverse( + int node, float score, BiFunction scoreBetween) { + return stream().noneMatch(e -> scoreBetween.apply(e.getValue(), node) > score); + } + + /** + * find the first node e1 starting with the last neighbor (i.e. least similar to the base node), + * look at all nodes e2 that are closer to the base node than e1 is. if any e2 is closer to e1 + * than e1 is to the base node, remove e1. + */ + private void removeLeastDiverse(BiFunction scoreBetween) { + for (var e1 : neighbors.descendingSet()) { + var e1Id = decodeNodeId(e1); + var baseScore = decodeScore(e1); + + var e2Iterator = iteratorStartingAfter(neighbors, e1); + while (e2Iterator.hasNext()) { + var e2 = e2Iterator.next(); + var e2Id = decodeNodeId(e2); + var e1e2Score = scoreBetween.apply(e1Id, e2Id); + if (e1e2Score >= baseScore) { + if (neighbors.remove(e1)) { + return; + } + // else another thread already removed it, keep looking + } + } + } + // couldn't find any "non-diverse" neighbors, so remove the one farthest from the base node + neighbors.remove(neighbors.last()); + } + + /** + * Returns an iterator over the entries in the set, starting at the entry *after* the given key. + * So iteratorStartingAfter(map, 2) invoked on a set with keys [1, 2, 3, 4] would return an + * iterator over the entries [3, 4]. + */ + private static Iterator iteratorStartingAfter(NavigableSet set, K key) { + // this isn't ideal, since the iteration will be worst case O(N log N), but since the worst + // scores will usually be the first ones we iterate through, the average case is much better + return new Iterator<>() { + private K nextItem = set.lower(key); + + @Override + public boolean hasNext() { + return nextItem != null; + } + + @Override + public K next() { + K current = nextItem; + nextItem = set.lower(nextItem); + return current; + } + }; + } + + public boolean contains(int i) { + for (var e : neighbors) { + if (decodeNodeId(e) == i) { + return true; + } + } + return false; + } + + // as found in NeighborQueue + static long encode(int node, float score) { + return (((long) NumericUtils.floatToSortableInt(score)) << 32) | (0xFFFFFFFFL & ~node); + } + + static float decodeScore(long heapValue) { + return NumericUtils.sortableIntToFloat((int) (heapValue >> 32)); + } + + static int decodeNodeId(long heapValue) { + return (int) ~(heapValue); + } +} diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/ConcurrentOnHeapHnswGraph.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/ConcurrentOnHeapHnswGraph.java new file mode 100644 index 000000000000..ffbd30ab2321 --- /dev/null +++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/ConcurrentOnHeapHnswGraph.java @@ -0,0 +1,291 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.lucene.util.hnsw; + +import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; + +import java.io.IOException; +import java.util.Iterator; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; +import java.util.concurrent.atomic.AtomicReference; +import org.apache.lucene.util.Accountable; +import org.apache.lucene.util.RamUsageEstimator; + +/** + * An {@link HnswGraph} that offers concurrent access; for typical graphs you will get significant + * speedups in construction and searching as you add threads. + * + *

To search this graph, you should use a View obtained from {@link #getView()} to perform `seek` + * and `nextNeighbor` operations. For convenience, you can use these methods directly on the graph + * instance, which will give you a ThreadLocal View, but you can call `getView` directly if you need + * more control, e.g. for performing a second search in the same thread while the first is still in + * progress. + */ +public final class ConcurrentOnHeapHnswGraph extends HnswGraph implements Accountable { + private final AtomicReference + entryPoint; // the current graph entry node on the top level. -1 if not set + + // views for compatibility with HnswGraph interface; prefer creating views explicitly + private final ThreadLocal views = + ThreadLocal.withInitial(ConcurrentHnswGraphView::new); + + // Unlike OnHeapHnswGraph (OHHG), we use the same data structure for Level 0 and higher node + // lists, + // a ConcurrentHashMap. While the ArrayList used for L0 in OHHG is faster for single-threaded + // workloads, it imposes an unacceptable contention burden for concurrent workloads. + private final ConcurrentMap graphLevel0; + private final ConcurrentMap> + graphUpperLevels; + + // Neighbours' size on upper levels (nsize) and level 0 (nsize0) + private final int nsize; + private final int nsize0; + + ConcurrentOnHeapHnswGraph(int M) { + this.graphLevel0 = new ConcurrentHashMap<>(); + this.entryPoint = + new AtomicReference<>( + new NodeAtLevel(0, -1)); // Entry node should be negative until a node is added + this.nsize = M; + this.nsize0 = 2 * M; + + this.graphUpperLevels = new ConcurrentHashMap<>(); + } + + /** + * Returns the neighbors connected to the given node. + * + * @param level level of the graph + * @param node the node whose neighbors are returned, represented as an ordinal on the level 0. + */ + public ConcurrentNeighborSet getNeighbors(int level, int node) { + if (level == 0) return graphLevel0.get(node); + return graphUpperLevels.get(level).get(node); + } + + @Override + public synchronized int size() { + return graphLevel0.size(); // all nodes are located on the 0th level + } + + /** + * Add node on the given level with an empty set of neighbors. + * + *

Nodes can be inserted out of order, but it requires that the nodes preceded by the node + * inserted out of order are eventually added. + * + *

Actually populating the neighbors, and establishing bidirectional links, is the + * responsibility of the caller. + * + * @param level level to add a node on + * @param node the node to add, represented as an ordinal on the level 0. + */ + public void addNode(int level, int node) { + if (level > 0) { + if (level >= graphUpperLevels.size()) { + for (int i = graphUpperLevels.size(); i <= level; i++) { + graphUpperLevels.putIfAbsent(i, new ConcurrentHashMap<>()); + } + } + + graphUpperLevels.get(level).put(node, new ConcurrentNeighborSet(connectionsOnLevel(level))); + } else { + // Add nodes all the way up to and including "node" in the new graph on level 0. This will + // cause the size of the + // graph to differ from the number of nodes added to the graph. The size of the graph and the + // number of nodes + // added will only be in sync once all nodes from 0...last_node are added into the graph. + if (node >= graphLevel0.size()) { + for (int i = graphLevel0.size(); i <= node; i++) { + graphLevel0.putIfAbsent(i, new ConcurrentNeighborSet(nsize0)); + } + } + } + } + + /** + * must be called after addNode to a level > 0 + * + *

we don't do this as part of addNode itself, since it may not yet have been added to all the + * levels + */ + void maybeUpdateEntryNode(int level, int node) { + while (true) { + var oldEntry = entryPoint.get(); + if (oldEntry.node >= 0 && oldEntry.level >= level) { + break; + } + entryPoint.compareAndSet(oldEntry, new NodeAtLevel(level, node)); + } + } + + private int connectionsOnLevel(int level) { + return level == 0 ? nsize0 : nsize; + } + + @Override + public void seek(int level, int target) throws IOException { + views.get().seek(level, target); + } + + @Override + public int nextNeighbor() throws IOException { + return views.get().nextNeighbor(); + } + + /** + * @return the current number of levels in the graph where nodes have been added and we have a + * valid entry point. + */ + @Override + public int numLevels() { + return entryPoint.get().level + 1; + } + + /** + * Returns the graph's current entry node on the top level shown as ordinals of the nodes on 0th + * level + * + * @return the graph's current entry node on the top level + */ + @Override + public int entryNode() { + return entryPoint.get().node; + } + + @Override + public NodesIterator getNodesOnLevel(int level) { + if (level == 0) { + return new ArrayNodesIterator(size()); + } else { + return new CollectionNodesIterator(graphUpperLevels.get(level).keySet()); + } + } + + @Override + public long ramBytesUsed() { + long neighborArrayBytes0 = + nsize0 * (Integer.BYTES + Float.BYTES) + + RamUsageEstimator.NUM_BYTES_ARRAY_HEADER * 2 + + RamUsageEstimator.NUM_BYTES_OBJECT_REF + + Integer.BYTES * 2; + long neighborArrayBytes = + nsize * (Integer.BYTES + Float.BYTES) + + RamUsageEstimator.NUM_BYTES_ARRAY_HEADER * 2 + + RamUsageEstimator.NUM_BYTES_OBJECT_REF + + Integer.BYTES * 2; + long total = 0; + for (int l = 0; l <= entryPoint.get().level; l++) { + if (l == 0) { + total += + size() * neighborArrayBytes0 + RamUsageEstimator.NUM_BYTES_OBJECT_REF; // for graph; + } else { + long numNodesOnLevel = graphUpperLevels.get(l).size(); + + // For levels > 0, we represent the graph structure with a tree map. + // A single node in the tree contains 3 references (left root, right root, value) as well + // as an Integer for the key and 1 extra byte for the color of the node (this is actually 1 + // bit, but + // because we do not have that granularity, we set to 1 byte). In addition, we include 1 + // more reference for + // the tree map itself. + total += + numNodesOnLevel * (3L * RamUsageEstimator.NUM_BYTES_OBJECT_REF + Integer.BYTES + 1) + + RamUsageEstimator.NUM_BYTES_OBJECT_REF; + + // Add the size neighbor of each node + total += numNodesOnLevel * neighborArrayBytes; + } + } + return total; + } + + @Override + public String toString() { + return "ConcurrentOnHeapHnswGraph(size=%d, entryPoint=%s)".formatted(size(), entryPoint.get()); + } + + /** + * Returns a view of the graph that is safe to use concurrently with updates performed on the + * underlying graph. + * + *

Multiple Views may be searched concurrently. + */ + public HnswGraph getView() { + return new ConcurrentHnswGraphView(); + } + + private class ConcurrentHnswGraphView extends HnswGraph { + private Iterator remainingNeighbors; + + @Override + public int size() { + return ConcurrentOnHeapHnswGraph.this.size(); + } + + @Override + public int numLevels() { + return ConcurrentOnHeapHnswGraph.this.numLevels(); + } + + @Override + public int entryNode() { + return ConcurrentOnHeapHnswGraph.this.entryNode(); + } + + @Override + public NodesIterator getNodesOnLevel(int level) { + return ConcurrentOnHeapHnswGraph.this.getNodesOnLevel(level); + } + + @Override + public void seek(int level, int targetNode) { + remainingNeighbors = getNeighbors(level, targetNode).nodeIterator(); + } + + @Override + public int nextNeighbor() { + return remainingNeighbors.hasNext() ? remainingNeighbors.next() : NO_MORE_DOCS; + } + } + + static final class NodeAtLevel implements Comparable { + public final int level; + public final int node; + + public NodeAtLevel(int level, int node) { + this.level = level; + this.node = node; + } + + @Override + public int compareTo(NodeAtLevel o) { + int cmp = Integer.compare(level, o.level); + if (cmp == 0) { + cmp = Integer.compare(node, o.node); + } + return cmp; + } + + @Override + public String toString() { + return "NodeAtLevel [level=%d, node=%d]".formatted(level, node); + } + } +} diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraph.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraph.java index e708cdfbd76a..030efc05ad97 100644 --- a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraph.java +++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraph.java @@ -44,9 +44,8 @@ * many of the efConst neighbors are connected to the new node * * - *

Note: The graph may be searched by multiple threads concurrently, but updates are not - * thread-safe. The search method optionally takes a set of "accepted nodes", which can be used to - * exclude deleted documents. + *

Note: The search method optionally takes a set of "accepted nodes", which can be used to + * exclude deleted documents. Thread safety of searches depends on the implementation. */ public abstract class HnswGraph { diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java index 9716e62f8f52..1f9963f02e08 100644 --- a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java +++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java @@ -224,7 +224,6 @@ private NeighborQueue searchLevel( Bits acceptOrds, int visitedLimit) throws IOException { - int size = graph.size(); NeighborQueue results = new NeighborQueue(topK, false); prepareScratchState(vectors.size()); @@ -261,7 +260,6 @@ private NeighborQueue searchLevel( graph.seek(level, topCandidateNode); int friendOrd; while ((friendOrd = graph.nextNeighbor()) != NO_MORE_DOCS) { - assert friendOrd < size : "friendOrd=" + friendOrd + "; size=" + size; if (visited.getAndSet(friendOrd)) { continue; } diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/OnHeapHnswGraph.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/OnHeapHnswGraph.java index a49fbe46bbe8..0ba73ba13372 100644 --- a/lucene/core/src/java/org/apache/lucene/util/hnsw/OnHeapHnswGraph.java +++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/OnHeapHnswGraph.java @@ -29,6 +29,8 @@ /** * An {@link HnswGraph} where all nodes and connections are held in memory. This class is used to * construct the HNSW graph before it's written to the index. + * + *

This implementation is NOT threadsafe for insertion or for searching. */ public final class OnHeapHnswGraph extends HnswGraph implements Accountable { diff --git a/lucene/core/src/test/org/apache/lucene/util/hnsw/HnswGraphTestCase.java b/lucene/core/src/test/org/apache/lucene/util/hnsw/HnswGraphTestCase.java index 8dc368123491..65d42b349e90 100644 --- a/lucene/core/src/test/org/apache/lucene/util/hnsw/HnswGraphTestCase.java +++ b/lucene/core/src/test/org/apache/lucene/util/hnsw/HnswGraphTestCase.java @@ -54,7 +54,6 @@ import org.apache.lucene.search.TopDocs; import org.apache.lucene.store.Directory; import org.apache.lucene.tests.util.LuceneTestCase; -import org.apache.lucene.util.ArrayUtil; import org.apache.lucene.util.BitSet; import org.apache.lucene.util.Bits; import org.apache.lucene.util.FixedBitSet; @@ -166,8 +165,6 @@ public void testSortedAndUnsortedIndicesReturnSameResults() throws IOException { int beamWidth = random().nextInt(10) + 5; VectorSimilarityFunction similarityFunction = RandomizedTest.randomFrom(VectorSimilarityFunction.values()); - long seed = random().nextLong(); - HnswGraphBuilder.randSeed = seed; IndexWriterConfig iwc = new IndexWriterConfig() .setCodec( @@ -309,10 +306,10 @@ public void testAknnDiverse() throws IOException { int nDoc = 100; similarityFunction = VectorSimilarityFunction.DOT_PRODUCT; RandomAccessVectorValues vectors = circularVectorValues(nDoc); - HnswGraphBuilder builder = - HnswGraphBuilder.create( + ConcurrentHnswGraphBuilder builder = + ConcurrentHnswGraphBuilder.create( vectors, getVectorEncoding(), similarityFunction, 10, 100, random().nextInt()); - OnHeapHnswGraph hnsw = builder.build(vectors.copy()); + ConcurrentOnHeapHnswGraph hnsw = builder.build(vectors.copy()); // run some searches NeighborQueue nn = switch (getVectorEncoding()) { @@ -347,11 +344,11 @@ public void testAknnDiverse() throws IOException { assertTrue("sum(result docs)=" + sum, sum < 75); for (int i = 0; i < nDoc; i++) { - NeighborArray neighbors = hnsw.getNeighbors(0, i); - int[] nnodes = neighbors.node; - for (int j = 0; j < neighbors.size(); j++) { + var neighbors = hnsw.getNeighbors(0, i); + var it = neighbors.nodeIterator(); + while (it.hasNext()) { // all neighbors should be valid node ids. - assertTrue(nnodes[j] < nDoc); + assertTrue(it.next() < nDoc); } } } @@ -361,10 +358,10 @@ public void testSearchWithAcceptOrds() throws IOException { int nDoc = 100; RandomAccessVectorValues vectors = circularVectorValues(nDoc); similarityFunction = VectorSimilarityFunction.DOT_PRODUCT; - HnswGraphBuilder builder = - HnswGraphBuilder.create( + ConcurrentHnswGraphBuilder builder = + ConcurrentHnswGraphBuilder.create( vectors, getVectorEncoding(), similarityFunction, 16, 100, random().nextInt()); - OnHeapHnswGraph hnsw = builder.build(vectors.copy()); + ConcurrentOnHeapHnswGraph hnsw = builder.build(vectors.copy()); // the first 10 docs must not be deleted to ensure the expected recall Bits acceptOrds = createRandomAcceptOrds(10, nDoc); NeighborQueue nn = @@ -405,10 +402,10 @@ public void testSearchWithSelectiveAcceptOrds() throws IOException { int nDoc = 100; RandomAccessVectorValues vectors = circularVectorValues(nDoc); similarityFunction = VectorSimilarityFunction.DOT_PRODUCT; - HnswGraphBuilder builder = - HnswGraphBuilder.create( + ConcurrentHnswGraphBuilder builder = + ConcurrentHnswGraphBuilder.create( vectors, getVectorEncoding(), similarityFunction, 16, 100, random().nextInt()); - OnHeapHnswGraph hnsw = builder.build(vectors.copy()); + ConcurrentOnHeapHnswGraph hnsw = builder.build(vectors.copy()); // Only mark a few vectors as accepted BitSet acceptOrds = new FixedBitSet(nDoc); for (int i = 0; i < nDoc; i += random().nextInt(15, 20)) { @@ -464,7 +461,7 @@ public void testBuildOnHeapHnswGraphOutOfOrder() throws IOException { } } - OnHeapHnswGraph topDownOrderReversedHnsw = new OnHeapHnswGraph(10); + ConcurrentOnHeapHnswGraph topDownOrderReversedHnsw = new ConcurrentOnHeapHnswGraph(10); for (int currLevel = numLevels - 1; currLevel >= 0; currLevel--) { List currLevelNodes = nodesPerLevel.get(currLevel); int currLevelNodesSize = currLevelNodes.size(); @@ -473,7 +470,7 @@ public void testBuildOnHeapHnswGraphOutOfOrder() throws IOException { } } - OnHeapHnswGraph bottomUpOrderReversedHnsw = new OnHeapHnswGraph(10); + ConcurrentOnHeapHnswGraph bottomUpOrderReversedHnsw = new ConcurrentOnHeapHnswGraph(10); for (int currLevel = 0; currLevel < numLevels; currLevel++) { List currLevelNodes = nodesPerLevel.get(currLevel); int currLevelNodesSize = currLevelNodes.size(); @@ -482,7 +479,7 @@ public void testBuildOnHeapHnswGraphOutOfOrder() throws IOException { } } - OnHeapHnswGraph topDownOrderRandomHnsw = new OnHeapHnswGraph(10); + ConcurrentOnHeapHnswGraph topDownOrderRandomHnsw = new ConcurrentOnHeapHnswGraph(10); for (int currLevel = numLevels - 1; currLevel >= 0; currLevel--) { List currLevelNodes = new ArrayList<>(nodesPerLevel.get(currLevel)); Collections.shuffle(currLevelNodes, random()); @@ -491,7 +488,7 @@ public void testBuildOnHeapHnswGraphOutOfOrder() throws IOException { } } - OnHeapHnswGraph bottomUpExpectedHnsw = new OnHeapHnswGraph(10); + ConcurrentOnHeapHnswGraph bottomUpExpectedHnsw = new ConcurrentOnHeapHnswGraph(10); for (int currLevel = 0; currLevel < numLevels; currLevel++) { for (Integer currNode : nodesPerLevel.get(currLevel)) { bottomUpExpectedHnsw.addNode(currLevel, currNode); @@ -679,10 +676,10 @@ public void testVisitedLimit() throws IOException { int nDoc = 500; similarityFunction = VectorSimilarityFunction.DOT_PRODUCT; RandomAccessVectorValues vectors = circularVectorValues(nDoc); - HnswGraphBuilder builder = - HnswGraphBuilder.create( + ConcurrentHnswGraphBuilder builder = + ConcurrentHnswGraphBuilder.create( vectors, getVectorEncoding(), similarityFunction, 16, 100, random().nextInt()); - OnHeapHnswGraph hnsw = builder.build(vectors.copy()); + ConcurrentOnHeapHnswGraph hnsw = builder.build(vectors.copy()); int topK = 50; int visitedLimit = topK + random().nextInt(5); @@ -715,12 +712,13 @@ public void testVisitedLimit() throws IOException { public void testHnswGraphBuilderInvalid() { expectThrows( - NullPointerException.class, () -> HnswGraphBuilder.create(null, null, null, 0, 0, 0)); + NullPointerException.class, + () -> ConcurrentHnswGraphBuilder.create(null, null, null, 0, 0, 0)); // M must be > 0 expectThrows( IllegalArgumentException.class, () -> - HnswGraphBuilder.create( + ConcurrentHnswGraphBuilder.create( vectorValues(1, 1), getVectorEncoding(), VectorSimilarityFunction.EUCLIDEAN, @@ -731,7 +729,7 @@ public void testHnswGraphBuilderInvalid() { expectThrows( IllegalArgumentException.class, () -> - HnswGraphBuilder.create( + ConcurrentHnswGraphBuilder.create( vectorValues(1, 1), getVectorEncoding(), VectorSimilarityFunction.EUCLIDEAN, @@ -749,10 +747,10 @@ public void testRamUsageEstimate() throws IOException { RandomizedTest.randomFrom(VectorSimilarityFunction.values()); RandomAccessVectorValues vectors = vectorValues(size, dim); - HnswGraphBuilder builder = - HnswGraphBuilder.create( + ConcurrentHnswGraphBuilder builder = + ConcurrentHnswGraphBuilder.create( vectors, getVectorEncoding(), similarityFunction, M, M * 2, random().nextLong()); - OnHeapHnswGraph hnsw = builder.build(vectors.copy()); + ConcurrentOnHeapHnswGraph hnsw = builder.build(vectors.copy()); long estimated = RamUsageEstimator.sizeOfObject(hnsw); long actual = ramUsed(hnsw); @@ -774,8 +772,8 @@ public void testDiversity() throws IOException { }; AbstractMockVectorValues vectors = vectorValues(values); // First add nodes until everybody gets a full neighbor list - HnswGraphBuilder builder = - HnswGraphBuilder.create( + ConcurrentHnswGraphBuilder builder = + ConcurrentHnswGraphBuilder.create( vectors, getVectorEncoding(), similarityFunction, 2, 10, random().nextInt()); // node 0 is added by the builder constructor RandomAccessVectorValues vectorsCopy = vectors.copy(); @@ -830,8 +828,8 @@ public void testDiversityFallback() throws IOException { }; AbstractMockVectorValues vectors = vectorValues(values); // First add nodes until everybody gets a full neighbor list - HnswGraphBuilder builder = - HnswGraphBuilder.create( + ConcurrentHnswGraphBuilder builder = + ConcurrentHnswGraphBuilder.create( vectors, getVectorEncoding(), similarityFunction, 1, 10, random().nextInt()); RandomAccessVectorValues vectorsCopy = vectors.copy(); builder.addGraphNode(0, vectorsCopy); @@ -862,8 +860,8 @@ public void testDiversity3d() throws IOException { }; AbstractMockVectorValues vectors = vectorValues(values); // First add nodes until everybody gets a full neighbor list - HnswGraphBuilder builder = - HnswGraphBuilder.create( + ConcurrentHnswGraphBuilder builder = + ConcurrentHnswGraphBuilder.create( vectors, getVectorEncoding(), similarityFunction, 1, 10, random().nextInt()); RandomAccessVectorValues vectorsCopy = vectors.copy(); builder.addGraphNode(0, vectorsCopy); @@ -883,10 +881,10 @@ public void testDiversity3d() throws IOException { assertLevel0Neighbors(builder.hnsw, 3, 0, 1); } - private void assertLevel0Neighbors(OnHeapHnswGraph graph, int node, int... expected) { + private void assertLevel0Neighbors(ConcurrentOnHeapHnswGraph graph, int node, int... expected) { Arrays.sort(expected); - NeighborArray nn = graph.getNeighbors(0, node); - int[] actual = ArrayUtil.copyOfSubArray(nn.node, 0, nn.size()); + var nn = graph.getNeighbors(0, node); + var actual = nn.stream().mapToInt(e -> e.getValue()).toArray(); Arrays.sort(actual); assertArrayEquals( "expected: " + Arrays.toString(expected) + " actual: " + Arrays.toString(actual), @@ -900,10 +898,10 @@ public void testRandom() throws IOException { int dim = atLeast(10); AbstractMockVectorValues vectors = vectorValues(size, dim); int topK = 5; - HnswGraphBuilder builder = - HnswGraphBuilder.create( + ConcurrentHnswGraphBuilder builder = + ConcurrentHnswGraphBuilder.create( vectors, getVectorEncoding(), similarityFunction, 10, 30, random().nextLong()); - OnHeapHnswGraph hnsw = builder.build(vectors.copy()); + ConcurrentOnHeapHnswGraph hnsw = builder.build(vectors.copy()); Bits acceptOrds = random().nextBoolean() ? null : createRandomAcceptOrds(0, size); int totalMatches = 0; @@ -982,13 +980,11 @@ private int computeOverlap(int[] a, int[] b) { static class CircularFloatVectorValues extends FloatVectorValues implements RandomAccessVectorValues { private final int size; - private final float[] value; int doc = -1; CircularFloatVectorValues(int size) { this.size = size; - value = new float[2]; } @Override @@ -1033,7 +1029,7 @@ public int advance(int target) { @Override public float[] vectorValue(int ord) { - return unitVector2d(ord / (double) size, value); + return unitVector2d(ord / (double) size); } } @@ -1041,15 +1037,11 @@ public float[] vectorValue(int ord) { static class CircularByteVectorValues extends ByteVectorValues implements RandomAccessVectorValues { private final int size; - private final float[] value; - private final byte[] bValue; int doc = -1; CircularByteVectorValues(int size) { this.size = size; - value = new float[2]; - bValue = new byte[2]; } @Override @@ -1094,7 +1086,8 @@ public int advance(int target) { @Override public byte[] vectorValue(int ord) { - unitVector2d(ord / (double) size, value); + var value = unitVector2d(ord / (double) size); + var bValue = new byte[value.length]; for (int i = 0; i < value.length; i++) { bValue[i] = (byte) (value[i] * 127); } @@ -1107,9 +1100,9 @@ private static float[] unitVector2d(double piRadians) { } private static float[] unitVector2d(double piRadians, float[] value) { - value[0] = (float) Math.cos(Math.PI * piRadians); - value[1] = (float) Math.sin(Math.PI * piRadians); - return value; + return new float[] { + (float) Math.cos(Math.PI * piRadians), (float) Math.sin(Math.PI * piRadians) + }; } private Set getNeighborNodes(HnswGraph g) throws IOException { diff --git a/lucene/core/src/test/org/apache/lucene/util/hnsw/TestConcurrentNeighborSet.java b/lucene/core/src/test/org/apache/lucene/util/hnsw/TestConcurrentNeighborSet.java new file mode 100644 index 000000000000..a8f0fb1fce91 --- /dev/null +++ b/lucene/core/src/test/org/apache/lucene/util/hnsw/TestConcurrentNeighborSet.java @@ -0,0 +1,78 @@ +package org.apache.lucene.util.hnsw; + +import static org.apache.lucene.util.hnsw.ConcurrentNeighborSet.*; + +import java.util.Arrays; +import java.util.Iterator; +import java.util.List; +import java.util.function.BiFunction; +import java.util.stream.IntStream; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.tests.util.LuceneTestCase; + +public class TestConcurrentNeighborSet extends LuceneTestCase { + private static final BiFunction simpleScore = + (a, b) -> { + return VectorSimilarityFunction.EUCLIDEAN.compare(new float[] {a}, new float[] {b}); + }; + + private static float baseScore(int neighbor) { + return simpleScore.apply(0, neighbor); + } + + public void testInsertAndSize() { + ConcurrentNeighborSet neighbors = new ConcurrentNeighborSet(2); + neighbors.insert(1, baseScore(1), simpleScore); + neighbors.insert(2, baseScore(2), simpleScore); + assertEquals(2, neighbors.size()); + + neighbors.insert(3, baseScore(3), simpleScore); + assertEquals(2, neighbors.size()); + } + + public void testRemoveLeastDiverseFromEnd() { + ConcurrentNeighborSet neighbors = new ConcurrentNeighborSet(3); + neighbors.insert(1, baseScore(1), simpleScore); + neighbors.insert(2, baseScore(2), simpleScore); + neighbors.insert(3, baseScore(3), simpleScore); + assertEquals(3, neighbors.size()); + + neighbors.insert(4, baseScore(4), simpleScore); + assertEquals(3, neighbors.size()); + + List expectedValues = Arrays.asList(1, 2, 3); + Iterator iterator = neighbors.nodeIterator(); + for (Integer expectedValue : expectedValues) { + assertTrue(iterator.hasNext()); + assertEquals(expectedValue, iterator.next()); + } + assertFalse(iterator.hasNext()); + } + + public void testInsertDiverse() { + var similarityFunction = VectorSimilarityFunction.DOT_PRODUCT; + var vectors = new HnswGraphTestCase.CircularFloatVectorValues(10); + var candidates = new NeighborArray(10, false); + BiFunction scoreBetween = + (a, b) -> { + return similarityFunction.compare(vectors.vectorValue(a), vectors.vectorValue(b)); + }; + var L = + IntStream.range(0, 10) + .filter(i -> i != 7) + .mapToLong(i -> encode(i, scoreBetween.apply(7, i))) + .sorted() + .toArray(); + for (int i = 0; i < L.length; i++) { + var encoded = L[i]; + candidates.add(decodeNodeId(encoded), decodeScore(encoded)); + } + assert candidates.size() == 9; + + var neighbors = new ConcurrentNeighborSet(3); + neighbors.insertDiverse(candidates, scoreBetween); + assert neighbors.size() == 2; + assert neighbors.contains(8); + assert neighbors.contains(6); + } +} diff --git a/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnswFloatVectorGraph.java b/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnswFloatVectorGraph.java index 5dda5bf0a838..379722beb97a 100644 --- a/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnswFloatVectorGraph.java +++ b/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnswFloatVectorGraph.java @@ -21,6 +21,7 @@ import com.carrotsearch.randomizedtesting.RandomizedTest; import java.io.IOException; +import java.util.*; import org.apache.lucene.document.Field; import org.apache.lucene.document.KnnFloatVectorField; import org.apache.lucene.index.FloatVectorValues; @@ -127,10 +128,10 @@ public void testSearchWithSkewedAcceptOrds() throws IOException { int nDoc = 1000; similarityFunction = VectorSimilarityFunction.EUCLIDEAN; RandomAccessVectorValues vectors = circularVectorValues(nDoc); - HnswGraphBuilder builder = - HnswGraphBuilder.create( + ConcurrentHnswGraphBuilder builder = + ConcurrentHnswGraphBuilder.create( vectors, getVectorEncoding(), similarityFunction, 16, 100, random().nextInt()); - OnHeapHnswGraph hnsw = builder.build(vectors.copy()); + ConcurrentOnHeapHnswGraph hnsw = builder.build(vectors.copy()); // Skip over half of the documents that are closest to the query vector FixedBitSet acceptOrds = new FixedBitSet(nDoc);