diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt index 230c33b875f0..4df94417bdcd 100644 --- a/lucene/CHANGES.txt +++ b/lucene/CHANGES.txt @@ -29,6 +29,8 @@ Optimizations * GITHUB#12286 Toposort use iterator to avoid stackoverflow. (Tang Donghai) +* GITHUB#12235: Optimize HNSW diversity calculation. (Patrick Zhai) + Bug Fixes --------------------- (No changes) diff --git a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswGraphBuilder.java b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswGraphBuilder.java index e7f16b4f3fc9..4b1f7068a5f2 100644 --- a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswGraphBuilder.java +++ b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswGraphBuilder.java @@ -183,10 +183,10 @@ private void addDiverseNeighbors(int node, NeighborQueue candidates) throws IOEx int size = neighbors.size(); for (int i = 0; i < size; i++) { int nbr = neighbors.node()[i]; - Lucene90NeighborArray nbrNbr = hnsw.getNeighbors(nbr); - nbrNbr.add(node, neighbors.score()[i]); - if (nbrNbr.size() > maxConn) { - diversityUpdate(nbrNbr); + Lucene90NeighborArray nbrsOfNbr = hnsw.getNeighbors(nbr); + nbrsOfNbr.add(node, neighbors.score()[i]); + if (nbrsOfNbr.size() > maxConn) { + diversityUpdate(nbrsOfNbr); } } } diff --git a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswGraphBuilder.java b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswGraphBuilder.java index b0e9d160457b..c82920181cc4 100644 --- a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswGraphBuilder.java +++ b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswGraphBuilder.java @@ -204,10 +204,10 @@ private void addDiverseNeighbors(int level, int node, NeighborQueue candidates) int size = 
neighbors.size(); for (int i = 0; i < size; i++) { int nbr = neighbors.node[i]; - Lucene91NeighborArray nbrNbr = hnsw.getNeighbors(level, nbr); - nbrNbr.add(node, neighbors.score[i]); - if (nbrNbr.size() > maxConn) { - diversityUpdate(nbrNbr); + Lucene91NeighborArray nbrsOfNbr = hnsw.getNeighbors(level, nbr); + nbrsOfNbr.add(node, neighbors.score[i]); + if (nbrsOfNbr.size() > maxConn) { + diversityUpdate(nbrsOfNbr); } } } diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java index 2c5e84be2859..9686186945da 100644 --- a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java +++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java @@ -220,7 +220,9 @@ private void initializeFromGraph( binaryValue, (byte[]) vectorsCopy.vectorValue(newNeighbor)); break; } - newNeighbors.insertSorted(newNeighbor, score); + // we are not sure whether the previous graph contains + // unchecked nodes, so we have to assume they're all unchecked + newNeighbors.addOutOfOrder(newNeighbor, score); } } } @@ -316,11 +318,11 @@ private void addDiverseNeighbors(int level, int node, NeighborQueue candidates) int size = neighbors.size(); for (int i = 0; i < size; i++) { int nbr = neighbors.node[i]; - NeighborArray nbrNbr = hnsw.getNeighbors(level, nbr); - nbrNbr.insertSorted(node, neighbors.score[i]); - if (nbrNbr.size() > maxConnOnLevel) { - int indexToRemove = findWorstNonDiverse(nbrNbr); - nbrNbr.removeIndex(indexToRemove); + NeighborArray nbrsOfNbr = hnsw.getNeighbors(level, nbr); + nbrsOfNbr.addOutOfOrder(node, neighbors.score[i]); + if (nbrsOfNbr.size() > maxConnOnLevel) { + int indexToRemove = findWorstNonDiverse(nbrsOfNbr); + nbrsOfNbr.removeIndex(indexToRemove); } } } @@ -335,7 +337,7 @@ private void selectAndLinkDiverse( float cScore = candidates.score[i]; assert cNode < hnsw.size(); if (diversityCheck(cNode, cScore, neighbors)) { - neighbors.add(cNode, 
cScore); + neighbors.addInOrder(cNode, cScore); } } } @@ -347,7 +349,7 @@ private void popToScratch(NeighborQueue candidates) { // sorted from worst to best for (int i = 0; i < candidateCount; i++) { float maxSimilarity = candidates.topScore(); - scratch.add(candidates.pop(), maxSimilarity); + scratch.addInOrder(candidates.pop(), maxSimilarity); } } @@ -405,53 +407,119 @@ private boolean isDiverse(byte[] candidate, NeighborArray neighbors, float score * neighbours */ private int findWorstNonDiverse(NeighborArray neighbors) throws IOException { + int[] uncheckedIndexes = neighbors.sort(); + if (uncheckedIndexes == null) { + // all nodes are checked, we will directly return the most distant one + return neighbors.size() - 1; + } + int uncheckedCursor = uncheckedIndexes.length - 1; for (int i = neighbors.size() - 1; i > 0; i--) { - if (isWorstNonDiverse(i, neighbors)) { + if (uncheckedCursor < 0) { + // no unchecked node left + break; + } + if (isWorstNonDiverse(i, neighbors, uncheckedIndexes, uncheckedCursor)) { return i; } + if (i == uncheckedIndexes[uncheckedCursor]) { + uncheckedCursor--; + } } return neighbors.size() - 1; } - private boolean isWorstNonDiverse(int candidateIndex, NeighborArray neighbors) + private boolean isWorstNonDiverse( + int candidateIndex, NeighborArray neighbors, int[] uncheckedIndexes, int uncheckedCursor) throws IOException { int candidateNode = neighbors.node[candidateIndex]; switch (vectorEncoding) { case BYTE: return isWorstNonDiverse( - candidateIndex, (byte[]) vectors.vectorValue(candidateNode), neighbors); + candidateIndex, + (byte[]) vectors.vectorValue(candidateNode), + neighbors, + uncheckedIndexes, + uncheckedCursor); default: case FLOAT32: return isWorstNonDiverse( - candidateIndex, (float[]) vectors.vectorValue(candidateNode), neighbors); + candidateIndex, + (float[]) vectors.vectorValue(candidateNode), + neighbors, + uncheckedIndexes, + uncheckedCursor); } } private boolean isWorstNonDiverse( - int candidateIndex, float[] 
candidateVector, NeighborArray neighbors) throws IOException { + int candidateIndex, + float[] candidateVector, + NeighborArray neighbors, + int[] uncheckedIndexes, + int uncheckedCursor) + throws IOException { float minAcceptedSimilarity = neighbors.score[candidateIndex]; - for (int i = candidateIndex - 1; i >= 0; i--) { - float neighborSimilarity = - similarityFunction.compare( - candidateVector, (float[]) vectorsCopy.vectorValue(neighbors.node[i])); - // candidate node is too similar to node i given its score relative to the base node - if (neighborSimilarity >= minAcceptedSimilarity) { - return true; + if (candidateIndex == uncheckedIndexes[uncheckedCursor]) { + // the candidate itself is unchecked + for (int i = candidateIndex - 1; i >= 0; i--) { + float neighborSimilarity = + similarityFunction.compare( + candidateVector, (float[]) vectorsCopy.vectorValue(neighbors.node[i])); + // candidate node is too similar to node i given its score relative to the base node + if (neighborSimilarity >= minAcceptedSimilarity) { + return true; + } + } + } else { + // else we just need to make sure candidate does not violate diversity with the (newly + // inserted) unchecked nodes + assert candidateIndex > uncheckedIndexes[uncheckedCursor]; + for (int i = uncheckedCursor; i >= 0; i--) { + float neighborSimilarity = + similarityFunction.compare( + candidateVector, + (float[]) vectorsCopy.vectorValue(neighbors.node[uncheckedIndexes[i]])); + // candidate node is too similar to node i given its score relative to the base node + if (neighborSimilarity >= minAcceptedSimilarity) { + return true; + } } } return false; } private boolean isWorstNonDiverse( - int candidateIndex, byte[] candidateVector, NeighborArray neighbors) throws IOException { + int candidateIndex, + byte[] candidateVector, + NeighborArray neighbors, + int[] uncheckedIndexes, + int uncheckedCursor) + throws IOException { float minAcceptedSimilarity = neighbors.score[candidateIndex]; - for (int i = candidateIndex - 
1; i >= 0; i--) { - float neighborSimilarity = - similarityFunction.compare( - candidateVector, (byte[]) vectorsCopy.vectorValue(neighbors.node[i])); - // candidate node is too similar to node i given its score relative to the base node - if (neighborSimilarity >= minAcceptedSimilarity) { - return true; + if (candidateIndex == uncheckedIndexes[uncheckedCursor]) { + // the candidate itself is unchecked + for (int i = candidateIndex - 1; i >= 0; i--) { + float neighborSimilarity = + similarityFunction.compare( + candidateVector, (byte[]) vectorsCopy.vectorValue(neighbors.node[i])); + // candidate node is too similar to node i given its score relative to the base node + if (neighborSimilarity >= minAcceptedSimilarity) { + return true; + } + } + } else { + // else we just need to make sure candidate does not violate diversity with the (newly + // inserted) unchecked nodes + assert candidateIndex > uncheckedIndexes[uncheckedCursor]; + for (int i = uncheckedCursor; i >= 0; i--) { + float neighborSimilarity = + similarityFunction.compare( + candidateVector, + (byte[]) vectorsCopy.vectorValue(neighbors.node[uncheckedIndexes[i]])); + // candidate node is too similar to node i given its score relative to the base node + if (neighborSimilarity >= minAcceptedSimilarity) { + return true; + } } } return false; diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/NeighborArray.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/NeighborArray.java index ec1b5ec3e897..a23b9b5254ee 100644 --- a/lucene/core/src/java/org/apache/lucene/util/hnsw/NeighborArray.java +++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/NeighborArray.java @@ -34,6 +34,7 @@ public class NeighborArray { float[] score; int[] node; + private int sortedNodeSize; public NeighborArray(int maxSize, boolean descOrder) { node = new int[maxSize]; @@ -43,9 +44,10 @@ public NeighborArray(int maxSize, boolean descOrder) { /** * Add a new node to the NeighborArray. 
The new node must be worse than all previously stored - * nodes. + * nodes. This cannot be called after {@link #addOutOfOrder(int, float)} */ - public void add(int newNode, float newScore) { + public void addInOrder(int newNode, float newScore) { + assert size == sortedNodeSize : "cannot call addInOrder after addOutOfOrder"; if (size == node.length) { node = ArrayUtil.grow(node); score = ArrayUtil.growExact(score, node.length); @@ -59,23 +61,72 @@ public void add(int newNode, float newScore) { node[size] = newNode; score[size] = newScore; ++size; + ++sortedNodeSize; } - /** Add a new node to the NeighborArray into a correct sort position according to its score. */ - public void insertSorted(int newNode, float newScore) { + /** Add node and score but do not insert as sorted */ + public void addOutOfOrder(int newNode, float newScore) { if (size == node.length) { node = ArrayUtil.grow(node); score = ArrayUtil.growExact(score, node.length); } + node[size] = newNode; + score[size] = newScore; + size++; + } + + /** + * Sort the array according to scores, and return the sorted indexes of previously unsorted nodes + * (unchecked nodes) + * + * @return indexes of newly sorted (unchecked) nodes, in ascending order, or null if the array is + * already fully sorted + */ + public int[] sort() { + if (size == sortedNodeSize) { + // all nodes checked and sorted + return null; + } + assert sortedNodeSize < size; + int[] uncheckedIndexes = new int[size - sortedNodeSize]; + int count = 0; + while (sortedNodeSize != size) { + uncheckedIndexes[count] = insertSortedInternal(); // sortedNodeSize is increased inside + for (int i = 0; i < count; i++) { + if (uncheckedIndexes[i] >= uncheckedIndexes[count]) { + // the previously inserted nodes have been shifted + uncheckedIndexes[i]++; + } + } + count++; + } + Arrays.sort(uncheckedIndexes); + return uncheckedIndexes; + } + + /** insert the first unsorted node into its sorted position */ + private int insertSortedInternal() { + assert 
sortedNodeSize < size : "Call this method only when there's unsorted node"; + int tmpNode = node[sortedNodeSize]; + float tmpScore = score[sortedNodeSize]; int insertionPoint = scoresDescOrder - ? descSortFindRightMostInsertionPoint(newScore) - : ascSortFindRightMostInsertionPoint(newScore); - System.arraycopy(node, insertionPoint, node, insertionPoint + 1, size - insertionPoint); - System.arraycopy(score, insertionPoint, score, insertionPoint + 1, size - insertionPoint); - node[insertionPoint] = newNode; - score[insertionPoint] = newScore; - ++size; + ? descSortFindRightMostInsertionPoint(tmpScore, sortedNodeSize) + : ascSortFindRightMostInsertionPoint(tmpScore, sortedNodeSize); + System.arraycopy( + node, insertionPoint, node, insertionPoint + 1, sortedNodeSize - insertionPoint); + System.arraycopy( + score, insertionPoint, score, insertionPoint + 1, sortedNodeSize - insertionPoint); + node[insertionPoint] = tmpNode; + score[insertionPoint] = tmpScore; + ++sortedNodeSize; + return insertionPoint; + } + + /** This method is for test only. 
*/ + void insertSorted(int newNode, float newScore) { + addOutOfOrder(newNode, newScore); + insertSortedInternal(); } public int size() { @@ -97,15 +148,20 @@ public float[] score() { public void clear() { size = 0; + sortedNodeSize = 0; } public void removeLast() { size--; + sortedNodeSize = Math.min(sortedNodeSize, size); } public void removeIndex(int idx) { System.arraycopy(node, idx + 1, node, idx, size - idx - 1); System.arraycopy(score, idx + 1, score, idx, size - idx - 1); + if (idx < sortedNodeSize) { + sortedNodeSize--; + } size--; } @@ -114,11 +170,11 @@ public String toString() { return "NeighborArray[" + size + "]"; } - private int ascSortFindRightMostInsertionPoint(float newScore) { - int insertionPoint = Arrays.binarySearch(score, 0, size, newScore); + private int ascSortFindRightMostInsertionPoint(float newScore, int bound) { + int insertionPoint = Arrays.binarySearch(score, 0, bound, newScore); if (insertionPoint >= 0) { // find the right most position with the same score - while ((insertionPoint < size - 1) && (score[insertionPoint + 1] == score[insertionPoint])) { + while ((insertionPoint < bound - 1) && (score[insertionPoint + 1] == score[insertionPoint])) { insertionPoint++; } insertionPoint++; @@ -128,9 +184,9 @@ private int ascSortFindRightMostInsertionPoint(float newScore) { return insertionPoint; } - private int descSortFindRightMostInsertionPoint(float newScore) { + private int descSortFindRightMostInsertionPoint(float newScore, int bound) { int start = 0; - int end = size - 1; + int end = bound - 1; while (start <= end) { int mid = (start + end) / 2; if (score[mid] < newScore) end = mid - 1; diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/OnHeapHnswGraph.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/OnHeapHnswGraph.java index 9862536de08c..75a77af056a2 100644 --- a/lucene/core/src/java/org/apache/lucene/util/hnsw/OnHeapHnswGraph.java +++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/OnHeapHnswGraph.java @@ 
-170,14 +170,14 @@ public NodesIterator getNodesOnLevel(int level) { public long ramBytesUsed() { long neighborArrayBytes0 = nsize0 * (Integer.BYTES + Float.BYTES) - + RamUsageEstimator.NUM_BYTES_ARRAY_HEADER * 2 - + RamUsageEstimator.NUM_BYTES_OBJECT_REF - + Integer.BYTES * 2; + + RamUsageEstimator.NUM_BYTES_ARRAY_HEADER * 2 + + RamUsageEstimator.NUM_BYTES_OBJECT_REF * 2 + + Integer.BYTES * 3; long neighborArrayBytes = nsize * (Integer.BYTES + Float.BYTES) - + RamUsageEstimator.NUM_BYTES_ARRAY_HEADER * 2 - + RamUsageEstimator.NUM_BYTES_OBJECT_REF - + Integer.BYTES * 2; + + RamUsageEstimator.NUM_BYTES_ARRAY_HEADER * 2 + + RamUsageEstimator.NUM_BYTES_OBJECT_REF * 2 + + Integer.BYTES * 3; long total = 0; for (int l = 0; l < numLevels; l++) { if (l == 0) { diff --git a/lucene/core/src/test/org/apache/lucene/util/hnsw/TestNeighborArray.java b/lucene/core/src/test/org/apache/lucene/util/hnsw/TestNeighborArray.java index b8ae24f62009..039f69c9dc4c 100644 --- a/lucene/core/src/test/org/apache/lucene/util/hnsw/TestNeighborArray.java +++ b/lucene/core/src/test/org/apache/lucene/util/hnsw/TestNeighborArray.java @@ -23,100 +23,160 @@ public class TestNeighborArray extends LuceneTestCase { public void testScoresDescOrder() { NeighborArray neighbors = new NeighborArray(10, true); - neighbors.add(0, 1); - neighbors.add(1, 0.8f); + neighbors.addInOrder(0, 1); + neighbors.addInOrder(1, 0.8f); - AssertionError ex = expectThrows(AssertionError.class, () -> neighbors.add(2, 0.9f)); + AssertionError ex = expectThrows(AssertionError.class, () -> neighbors.addInOrder(2, 0.9f)); assertEquals("Nodes are added in the incorrect order!", ex.getMessage()); neighbors.insertSorted(3, 0.9f); assertScoresEqual(new float[] {1, 0.9f, 0.8f}, neighbors); - asserNodesEqual(new int[] {0, 3, 1}, neighbors); + assertNodesEqual(new int[] {0, 3, 1}, neighbors); neighbors.insertSorted(4, 1f); assertScoresEqual(new float[] {1, 1, 0.9f, 0.8f}, neighbors); - asserNodesEqual(new int[] {0, 4, 3, 1}, neighbors); + 
assertNodesEqual(new int[] {0, 4, 3, 1}, neighbors); neighbors.insertSorted(5, 1.1f); assertScoresEqual(new float[] {1.1f, 1, 1, 0.9f, 0.8f}, neighbors); - asserNodesEqual(new int[] {5, 0, 4, 3, 1}, neighbors); + assertNodesEqual(new int[] {5, 0, 4, 3, 1}, neighbors); neighbors.insertSorted(6, 0.8f); assertScoresEqual(new float[] {1.1f, 1, 1, 0.9f, 0.8f, 0.8f}, neighbors); - asserNodesEqual(new int[] {5, 0, 4, 3, 1, 6}, neighbors); + assertNodesEqual(new int[] {5, 0, 4, 3, 1, 6}, neighbors); neighbors.insertSorted(7, 0.8f); assertScoresEqual(new float[] {1.1f, 1, 1, 0.9f, 0.8f, 0.8f, 0.8f}, neighbors); - asserNodesEqual(new int[] {5, 0, 4, 3, 1, 6, 7}, neighbors); + assertNodesEqual(new int[] {5, 0, 4, 3, 1, 6, 7}, neighbors); neighbors.removeIndex(2); assertScoresEqual(new float[] {1.1f, 1, 0.9f, 0.8f, 0.8f, 0.8f}, neighbors); - asserNodesEqual(new int[] {5, 0, 3, 1, 6, 7}, neighbors); + assertNodesEqual(new int[] {5, 0, 3, 1, 6, 7}, neighbors); neighbors.removeIndex(0); assertScoresEqual(new float[] {1, 0.9f, 0.8f, 0.8f, 0.8f}, neighbors); - asserNodesEqual(new int[] {0, 3, 1, 6, 7}, neighbors); + assertNodesEqual(new int[] {0, 3, 1, 6, 7}, neighbors); neighbors.removeIndex(4); assertScoresEqual(new float[] {1, 0.9f, 0.8f, 0.8f}, neighbors); - asserNodesEqual(new int[] {0, 3, 1, 6}, neighbors); + assertNodesEqual(new int[] {0, 3, 1, 6}, neighbors); neighbors.removeLast(); assertScoresEqual(new float[] {1, 0.9f, 0.8f}, neighbors); - asserNodesEqual(new int[] {0, 3, 1}, neighbors); + assertNodesEqual(new int[] {0, 3, 1}, neighbors); neighbors.insertSorted(8, 0.9f); assertScoresEqual(new float[] {1, 0.9f, 0.9f, 0.8f}, neighbors); - asserNodesEqual(new int[] {0, 3, 8, 1}, neighbors); + assertNodesEqual(new int[] {0, 3, 8, 1}, neighbors); } public void testScoresAscOrder() { NeighborArray neighbors = new NeighborArray(10, false); - neighbors.add(0, 0.1f); - neighbors.add(1, 0.3f); + neighbors.addInOrder(0, 0.1f); + neighbors.addInOrder(1, 0.3f); - AssertionError ex = 
expectThrows(AssertionError.class, () -> neighbors.add(2, 0.15f)); + AssertionError ex = expectThrows(AssertionError.class, () -> neighbors.addInOrder(2, 0.15f)); assertEquals("Nodes are added in the incorrect order!", ex.getMessage()); neighbors.insertSorted(3, 0.3f); assertScoresEqual(new float[] {0.1f, 0.3f, 0.3f}, neighbors); - asserNodesEqual(new int[] {0, 1, 3}, neighbors); + assertNodesEqual(new int[] {0, 1, 3}, neighbors); neighbors.insertSorted(4, 0.2f); assertScoresEqual(new float[] {0.1f, 0.2f, 0.3f, 0.3f}, neighbors); - asserNodesEqual(new int[] {0, 4, 1, 3}, neighbors); + assertNodesEqual(new int[] {0, 4, 1, 3}, neighbors); neighbors.insertSorted(5, 0.05f); assertScoresEqual(new float[] {0.05f, 0.1f, 0.2f, 0.3f, 0.3f}, neighbors); - asserNodesEqual(new int[] {5, 0, 4, 1, 3}, neighbors); + assertNodesEqual(new int[] {5, 0, 4, 1, 3}, neighbors); neighbors.insertSorted(6, 0.2f); assertScoresEqual(new float[] {0.05f, 0.1f, 0.2f, 0.2f, 0.3f, 0.3f}, neighbors); - asserNodesEqual(new int[] {5, 0, 4, 6, 1, 3}, neighbors); + assertNodesEqual(new int[] {5, 0, 4, 6, 1, 3}, neighbors); neighbors.insertSorted(7, 0.2f); assertScoresEqual(new float[] {0.05f, 0.1f, 0.2f, 0.2f, 0.2f, 0.3f, 0.3f}, neighbors); - asserNodesEqual(new int[] {5, 0, 4, 6, 7, 1, 3}, neighbors); + assertNodesEqual(new int[] {5, 0, 4, 6, 7, 1, 3}, neighbors); neighbors.removeIndex(2); assertScoresEqual(new float[] {0.05f, 0.1f, 0.2f, 0.2f, 0.3f, 0.3f}, neighbors); - asserNodesEqual(new int[] {5, 0, 6, 7, 1, 3}, neighbors); + assertNodesEqual(new int[] {5, 0, 6, 7, 1, 3}, neighbors); neighbors.removeIndex(0); assertScoresEqual(new float[] {0.1f, 0.2f, 0.2f, 0.3f, 0.3f}, neighbors); - asserNodesEqual(new int[] {0, 6, 7, 1, 3}, neighbors); + assertNodesEqual(new int[] {0, 6, 7, 1, 3}, neighbors); neighbors.removeIndex(4); assertScoresEqual(new float[] {0.1f, 0.2f, 0.2f, 0.3f}, neighbors); - asserNodesEqual(new int[] {0, 6, 7, 1}, neighbors); + assertNodesEqual(new int[] {0, 6, 7, 1}, neighbors); 
neighbors.removeLast(); assertScoresEqual(new float[] {0.1f, 0.2f, 0.2f}, neighbors); - asserNodesEqual(new int[] {0, 6, 7}, neighbors); + assertNodesEqual(new int[] {0, 6, 7}, neighbors); neighbors.insertSorted(8, 0.01f); assertScoresEqual(new float[] {0.01f, 0.1f, 0.2f, 0.2f}, neighbors); - asserNodesEqual(new int[] {8, 0, 6, 7}, neighbors); + assertNodesEqual(new int[] {8, 0, 6, 7}, neighbors); + } + + public void testSortAsc() { + NeighborArray neighbors = new NeighborArray(10, false); + neighbors.addOutOfOrder(1, 2); + // we disallow calling addInOrder after addOutOfOrder even if they're actual in order + expectThrows(AssertionError.class, () -> neighbors.addInOrder(1, 2)); + neighbors.addOutOfOrder(2, 3); + neighbors.addOutOfOrder(5, 6); + neighbors.addOutOfOrder(3, 4); + neighbors.addOutOfOrder(7, 8); + neighbors.addOutOfOrder(6, 7); + neighbors.addOutOfOrder(4, 5); + int[] unchecked = neighbors.sort(); + assertArrayEquals(new int[] {0, 1, 2, 3, 4, 5, 6}, unchecked); + assertNodesEqual(new int[] {1, 2, 3, 4, 5, 6, 7}, neighbors); + assertScoresEqual(new float[] {2, 3, 4, 5, 6, 7, 8}, neighbors); + + NeighborArray neighbors2 = new NeighborArray(10, false); + neighbors2.addInOrder(0, 1); + neighbors2.addInOrder(1, 2); + neighbors2.addInOrder(4, 5); + neighbors2.addOutOfOrder(2, 3); + neighbors2.addOutOfOrder(6, 7); + neighbors2.addOutOfOrder(5, 6); + neighbors2.addOutOfOrder(3, 4); + unchecked = neighbors2.sort(); + assertArrayEquals(new int[] {2, 3, 5, 6}, unchecked); + assertNodesEqual(new int[] {0, 1, 2, 3, 4, 5, 6}, neighbors2); + assertScoresEqual(new float[] {1, 2, 3, 4, 5, 6, 7}, neighbors2); + } + + public void testSortDesc() { + NeighborArray neighbors = new NeighborArray(10, true); + neighbors.addOutOfOrder(1, 7); + // we disallow calling addInOrder after addOutOfOrder even if they're actual in order + expectThrows(AssertionError.class, () -> neighbors.addInOrder(1, 2)); + neighbors.addOutOfOrder(2, 6); + neighbors.addOutOfOrder(5, 3); + 
neighbors.addOutOfOrder(3, 5); + neighbors.addOutOfOrder(7, 1); + neighbors.addOutOfOrder(6, 2); + neighbors.addOutOfOrder(4, 4); + int[] unchecked = neighbors.sort(); + assertArrayEquals(new int[] {0, 1, 2, 3, 4, 5, 6}, unchecked); + assertNodesEqual(new int[] {1, 2, 3, 4, 5, 6, 7}, neighbors); + assertScoresEqual(new float[] {7, 6, 5, 4, 3, 2, 1}, neighbors); + + NeighborArray neighbors2 = new NeighborArray(10, true); + neighbors2.addInOrder(1, 7); + neighbors2.addInOrder(2, 6); + neighbors2.addInOrder(5, 3); + neighbors2.addOutOfOrder(3, 5); + neighbors2.addOutOfOrder(7, 1); + neighbors2.addOutOfOrder(6, 2); + neighbors2.addOutOfOrder(4, 4); + unchecked = neighbors2.sort(); + assertArrayEquals(new int[] {2, 3, 5, 6}, unchecked); + assertNodesEqual(new int[] {1, 2, 3, 4, 5, 6, 7}, neighbors2); + assertScoresEqual(new float[] {7, 6, 5, 4, 3, 2, 1}, neighbors2); } private void assertScoresEqual(float[] scores, NeighborArray neighbors) { @@ -125,7 +185,7 @@ private void assertScoresEqual(float[] scores, NeighborArray neighbors) { } } - private void asserNodesEqual(int[] nodes, NeighborArray neighbors) { + private void assertNodesEqual(int[] nodes, NeighborArray neighbors) { for (int i = 0; i < nodes.length; i++) { assertEquals(nodes[i], neighbors.node[i]); }