From d885bd58fadb01d1e0e33494cc09beb093b8eba7 Mon Sep 17 00:00:00 2001 From: Michael Sokolov Date: Wed, 3 Apr 2024 12:32:19 -0400 Subject: [PATCH 1/4] Add test assumption of fully-connected hnsw graph to BaseVectorSimilarityQueryTestCase --- .../BaseVectorSimilarityQueryTestCase.java | 28 +++++ .../test-framework/src/java/module-info.java | 1 + .../asserting/AssertingKnnVectorsFormat.java | 9 +- .../lucene/tests/util/hnsw/HnswTestUtil.java | 110 ++++++++++++++++++ .../lucene/tests/util/hnsw/package-info.java | 19 +++ 5 files changed, 166 insertions(+), 1 deletion(-) create mode 100644 lucene/test-framework/src/java/org/apache/lucene/tests/util/hnsw/HnswTestUtil.java create mode 100644 lucene/test-framework/src/java/org/apache/lucene/tests/util/hnsw/package-info.java diff --git a/lucene/core/src/test/org/apache/lucene/search/BaseVectorSimilarityQueryTestCase.java b/lucene/core/src/test/org/apache/lucene/search/BaseVectorSimilarityQueryTestCase.java index 5ad66b90d59d..2dcb8ad3bde8 100644 --- a/lucene/core/src/test/org/apache/lucene/search/BaseVectorSimilarityQueryTestCase.java +++ b/lucene/core/src/test/org/apache/lucene/search/BaseVectorSimilarityQueryTestCase.java @@ -26,17 +26,23 @@ import java.util.Optional; import java.util.Set; import java.util.stream.IntStream; +import org.apache.lucene.codecs.HnswGraphProvider; +import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat; import org.apache.lucene.document.Document; import org.apache.lucene.document.Field; import org.apache.lucene.document.IntField; +import org.apache.lucene.index.CodecReader; import org.apache.lucene.index.DirectoryReader; import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.IndexWriter; +import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.index.Term; import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.store.Directory; import org.apache.lucene.tests.index.RandomIndexWriter; import org.apache.lucene.tests.util.LuceneTestCase; +import org.apache.lucene.tests.util.hnsw.HnswTestUtil; +import org.apache.lucene.util.hnsw.HnswGraph; @LuceneTestCase.SuppressCodecs("SimpleText") abstract class BaseVectorSimilarityQueryTestCase< @@ -165,6 +171,9 @@ public void testRandomFilter() throws IOException { try (Directory indexStore = getIndexStore(getRandomVectors(numDocs, dim)); IndexReader reader = DirectoryReader.open(indexStore)) { + if (graphIsDisconnected(reader)) { + return; + } IndexSearcher searcher = newSearcher(reader); Query query = @@ -289,6 +298,9 @@ public void testSomeDeletes() throws IOException { w.commit(); try (IndexReader reader = DirectoryReader.open(indexStore)) { + if (graphIsDisconnected(reader)) { + return; + } IndexSearcher searcher = newSearcher(reader); Query query = @@ -522,4 +534,20 @@ final Directory getIndexStore(V... vectors) throws IOException { } return dir; } + + private boolean graphIsDisconnected(IndexReader reader) throws IOException { + for (LeafReaderContext ctx : reader.leaves()) { + HnswGraph graph = + ((HnswGraphProvider) + ((PerFieldKnnVectorsFormat.FieldsReader) + ((CodecReader) ctx.reader()).getVectorReader()) + .getFieldReader(vectorField)) + .getGraph(vectorField); + if (HnswTestUtil.isFullyConnected(graph) == false) { + // for now just bail out. + return true; + } + } + return false; + } } diff --git a/lucene/test-framework/src/java/module-info.java b/lucene/test-framework/src/java/module-info.java index 2af42e6b12dd..a6971845608b 100644 --- a/lucene/test-framework/src/java/module-info.java +++ b/lucene/test-framework/src/java/module-info.java @@ -51,6 +51,7 @@ exports org.apache.lucene.tests.store; exports org.apache.lucene.tests.util.automaton; exports org.apache.lucene.tests.util.fst; + exports org.apache.lucene.tests.util.hnsw; exports org.apache.lucene.tests.util; provides org.apache.lucene.codecs.Codec with diff --git a/lucene/test-framework/src/java/org/apache/lucene/tests/codecs/asserting/AssertingKnnVectorsFormat.java b/lucene/test-framework/src/java/org/apache/lucene/tests/codecs/asserting/AssertingKnnVectorsFormat.java index e93260d39b84..62aca64b57e2 100644 --- a/lucene/test-framework/src/java/org/apache/lucene/tests/codecs/asserting/AssertingKnnVectorsFormat.java +++ b/lucene/test-framework/src/java/org/apache/lucene/tests/codecs/asserting/AssertingKnnVectorsFormat.java @@ -18,6 +18,7 @@ package org.apache.lucene.tests.codecs.asserting; import java.io.IOException; +import org.apache.lucene.codecs.HnswGraphProvider; import org.apache.lucene.codecs.KnnFieldVectorsWriter; import org.apache.lucene.codecs.KnnVectorsFormat; import org.apache.lucene.codecs.KnnVectorsReader; @@ -34,6 +35,7 @@ import org.apache.lucene.search.KnnCollector; import org.apache.lucene.tests.util.TestUtil; import org.apache.lucene.util.Bits; +import org.apache.lucene.util.hnsw.HnswGraph; /** Wraps the default KnnVectorsFormat and provides additional assertions. */ public class AssertingKnnVectorsFormat extends KnnVectorsFormat { @@ -100,7 +102,7 @@ public long ramBytesUsed() { } } - static class AssertingKnnVectorsReader extends KnnVectorsReader { + static class AssertingKnnVectorsReader extends KnnVectorsReader implements HnswGraphProvider { final KnnVectorsReader delegate; final FieldInfos fis; @@ -173,5 +175,10 @@ public void close() throws IOException { public long ramBytesUsed() { return delegate.ramBytesUsed(); } + + @Override + public HnswGraph getGraph(String field) throws IOException { + return ((HnswGraphProvider) delegate).getGraph(field); + } } } diff --git a/lucene/test-framework/src/java/org/apache/lucene/tests/util/hnsw/HnswTestUtil.java b/lucene/test-framework/src/java/org/apache/lucene/tests/util/hnsw/HnswTestUtil.java new file mode 100644 index 000000000000..6164a1a35b48 --- /dev/null +++ b/lucene/test-framework/src/java/org/apache/lucene/tests/util/hnsw/HnswTestUtil.java @@ -0,0 +1,110 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.tests.util.hnsw; + +import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; + +import java.io.IOException; +import java.util.ArrayDeque; +import java.util.ArrayList; +import java.util.Deque; +import java.util.List; +import org.apache.lucene.util.FixedBitSet; +import org.apache.lucene.util.hnsw.HnswGraph; + +public class HnswTestUtil { + + /** + * Returns true iff level 0 of the graph is fully connected - that is every node is reachable from + * any entry point. + */ + public static boolean isFullyConnected(HnswGraph knnValues) throws IOException { + return componentSizes(knnValues).size() < 2; + } + + /** + * Returns the sizes of the distinct graph components on level 0. If the graph is fully-connected + * there will only be a single component. If the graph is empty, the returned list will be empty. + */ + public static List componentSizes(HnswGraph hnsw) throws IOException { + List sizes = new ArrayList<>(); + FixedBitSet connectedNodes = new FixedBitSet(hnsw.size()); + assert hnsw.size() == hnsw.getNodesOnLevel(0).size(); + System.out.println("size=" + hnsw.size()); + int total = 0; + while (total < connectedNodes.length()) { + int componentSize = traverseConnectedNodes(hnsw, connectedNodes); + assert componentSize > 0; + sizes.add(componentSize); + total += componentSize; + } + return sizes; + } + + // count the nodes in a connected component of the graph and set the bits of its nodes in + // connectedNodes bitset + private static int traverseConnectedNodes(HnswGraph hnswGraph, FixedBitSet connectedNodes) + throws IOException { + // Start at entry point and search all nodes on this level + int entryPoint = nextClearBit(connectedNodes, 0); + if (entryPoint == NO_MORE_DOCS) { + return 0; + } + Deque stack = new ArrayDeque<>(); + stack.push(entryPoint); + int count = 0; + while (!stack.isEmpty()) { + int node = stack.pop(); + if (connectedNodes.get(node)) { + continue; + } + count++; + connectedNodes.set(node); + hnswGraph.seek(0, node); + int friendOrd; + while ((friendOrd = hnswGraph.nextNeighbor()) != NO_MORE_DOCS) { + stack.push(friendOrd); + } + } + return count; + } + + private static int nextClearBit(FixedBitSet bits, int index) { + // Depends on the ghost bits being clear! + long[] barray = bits.getBits(); + assert index >= 0 && index < bits.length() : "index=" + index + ", numBits=" + bits.length(); + int i = index >> 6; + long word = ~(barray[i] >> index); // skip all the bits to the right of index + + if (word != 0) { + return index + Long.numberOfTrailingZeros(word); + } + + while (++i < barray.length) { + word = ~barray[i]; + if (word != 0) { + int next = (i << 6) + Long.numberOfTrailingZeros(word); + if (next >= bits.length()) { + return NO_MORE_DOCS; + } else { + return next; + } + } + } + return NO_MORE_DOCS; + } +} diff --git a/lucene/test-framework/src/java/org/apache/lucene/tests/util/hnsw/package-info.java b/lucene/test-framework/src/java/org/apache/lucene/tests/util/hnsw/package-info.java new file mode 100644 index 000000000000..05966f9df0a6 --- /dev/null +++ b/lucene/test-framework/src/java/org/apache/lucene/tests/util/hnsw/package-info.java @@ -0,0 +1,19 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** Support for HNSW testing. */ +package org.apache.lucene.tests.util.hnsw; From f289cb7b3dff3d8691097197cfbeb2ff3f003ae8 Mon Sep 17 00:00:00 2001 From: Michael Sokolov Date: Wed, 3 Apr 2024 13:15:18 -0400 Subject: [PATCH 2/4] refactored; use assumeTrue; add class javadoc --- .../BaseVectorSimilarityQueryTestCase.java | 29 ++----------------- .../lucene/tests/util/hnsw/HnswTestUtil.java | 22 ++++++++++++++ 2 files changed, 24 insertions(+), 27 deletions(-) diff --git a/lucene/core/src/test/org/apache/lucene/search/BaseVectorSimilarityQueryTestCase.java b/lucene/core/src/test/org/apache/lucene/search/BaseVectorSimilarityQueryTestCase.java index 2dcb8ad3bde8..f0bce64acd07 100644 --- a/lucene/core/src/test/org/apache/lucene/search/BaseVectorSimilarityQueryTestCase.java +++ b/lucene/core/src/test/org/apache/lucene/search/BaseVectorSimilarityQueryTestCase.java @@ -26,23 +26,18 @@ import java.util.Optional; import java.util.Set; import java.util.stream.IntStream; -import org.apache.lucene.codecs.HnswGraphProvider; -import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat; import org.apache.lucene.document.Document; import org.apache.lucene.document.Field; import org.apache.lucene.document.IntField; -import org.apache.lucene.index.CodecReader; import org.apache.lucene.index.DirectoryReader; import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.IndexWriter; -import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.index.Term; import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.store.Directory; import org.apache.lucene.tests.index.RandomIndexWriter; import org.apache.lucene.tests.util.LuceneTestCase; import org.apache.lucene.tests.util.hnsw.HnswTestUtil; -import org.apache.lucene.util.hnsw.HnswGraph; @LuceneTestCase.SuppressCodecs("SimpleText") abstract class BaseVectorSimilarityQueryTestCase< @@ -171,9 +166,7 @@ public void testRandomFilter() throws IOException { try (Directory indexStore = getIndexStore(getRandomVectors(numDocs, dim)); IndexReader reader = DirectoryReader.open(indexStore)) { - if (graphIsDisconnected(reader)) { - return; - } + assumeTrue("graph is disconnected", HnswTestUtil.graphIsConnected(reader, vectorField)); IndexSearcher searcher = newSearcher(reader); Query query = @@ -298,9 +291,7 @@ public void testSomeDeletes() throws IOException { w.commit(); try (IndexReader reader = DirectoryReader.open(indexStore)) { - if (graphIsDisconnected(reader)) { - return; - } + assumeTrue("graph is disconnected", HnswTestUtil.graphIsConnected(reader, vectorField)); IndexSearcher searcher = newSearcher(reader); Query query = @@ -534,20 +525,4 @@ final Directory getIndexStore(V... vectors) throws IOException { } return dir; } - - private boolean graphIsDisconnected(IndexReader reader) throws IOException { - for (LeafReaderContext ctx : reader.leaves()) { - HnswGraph graph = - ((HnswGraphProvider) - ((PerFieldKnnVectorsFormat.FieldsReader) - ((CodecReader) ctx.reader()).getVectorReader()) - .getFieldReader(vectorField)) - .getGraph(vectorField); - if (HnswTestUtil.isFullyConnected(graph) == false) { - // for now just bail out. - return true; - } - } - return false; - } } diff --git a/lucene/test-framework/src/java/org/apache/lucene/tests/util/hnsw/HnswTestUtil.java b/lucene/test-framework/src/java/org/apache/lucene/tests/util/hnsw/HnswTestUtil.java index 6164a1a35b48..21e22202c359 100644 --- a/lucene/test-framework/src/java/org/apache/lucene/tests/util/hnsw/HnswTestUtil.java +++ b/lucene/test-framework/src/java/org/apache/lucene/tests/util/hnsw/HnswTestUtil.java @@ -23,9 +23,15 @@ import java.util.ArrayList; import java.util.Deque; import java.util.List; +import org.apache.lucene.codecs.HnswGraphProvider; +import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat; +import org.apache.lucene.index.CodecReader; +import org.apache.lucene.index.IndexReader; +import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.util.FixedBitSet; import org.apache.lucene.util.hnsw.HnswGraph; +/** Utilities for use in tests involving HNSW graphs */ public class HnswTestUtil { /** @@ -107,4 +113,20 @@ private static int nextClearBit(FixedBitSet bits, int index) { } return NO_MORE_DOCS; } + + public static boolean graphIsConnected(IndexReader reader, String vectorField) + throws IOException { + for (LeafReaderContext ctx : reader.leaves()) { + HnswGraph graph = + ((HnswGraphProvider) + ((PerFieldKnnVectorsFormat.FieldsReader) + ((CodecReader) ctx.reader()).getVectorReader()) + .getFieldReader(vectorField)) + .getGraph(vectorField); + if (HnswTestUtil.isFullyConnected(graph) == false) { + return false; + } + } + return true; + } } From 841287feac2c5d8493a1adde220a97fc324a8329 Mon Sep 17 00:00:00 2001 From: Michael Sokolov Date: Wed, 3 Apr 2024 15:49:30 -0400 Subject: [PATCH 3/4] added safe unwrapping; removed stray print --- .../apache/lucene/tests/util/hnsw/HnswTestUtil.java | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/lucene/test-framework/src/java/org/apache/lucene/tests/util/hnsw/HnswTestUtil.java b/lucene/test-framework/src/java/org/apache/lucene/tests/util/hnsw/HnswTestUtil.java index 21e22202c359..a6284bdecb1e 100644 --- a/lucene/test-framework/src/java/org/apache/lucene/tests/util/hnsw/HnswTestUtil.java +++ b/lucene/test-framework/src/java/org/apache/lucene/tests/util/hnsw/HnswTestUtil.java @@ -29,6 +29,7 @@ import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.util.FixedBitSet; +import org.apache.lucene.util.Unwrappable; import org.apache.lucene.util.hnsw.HnswGraph; /** Utilities for use in tests involving HNSW graphs */ @@ -50,7 +51,6 @@ public static List componentSizes(HnswGraph hnsw) throws IOException { List sizes = new ArrayList<>(); FixedBitSet connectedNodes = new FixedBitSet(hnsw.size()); assert hnsw.size() == hnsw.getNodesOnLevel(0).size(); - System.out.println("size=" + hnsw.size()); int total = 0; while (total < connectedNodes.length()) { int componentSize = traverseConnectedNodes(hnsw, connectedNodes); @@ -90,7 +90,7 @@ private static int traverseConnectedNodes(HnswGraph hnswGraph, FixedBitSet conne } private static int nextClearBit(FixedBitSet bits, int index) { - // Depends on the ghost bits being clear! + // Does not depend on the ghost bits being clear! long[] barray = bits.getBits(); assert index >= 0 && index < bits.length() : "index=" + index + ", numBits=" + bits.length(); int i = index >> 6; @@ -117,13 +117,13 @@ private static int nextClearBit(FixedBitSet bits, int index) { public static boolean graphIsConnected(IndexReader reader, String vectorField) throws IOException { for (LeafReaderContext ctx : reader.leaves()) { + CodecReader codecReader = (CodecReader) Unwrappable.unwrapAll(ctx.reader()); HnswGraph graph = ((HnswGraphProvider) - ((PerFieldKnnVectorsFormat.FieldsReader) - ((CodecReader) ctx.reader()).getVectorReader()) + ((PerFieldKnnVectorsFormat.FieldsReader) codecReader.getVectorReader()) .getFieldReader(vectorField)) .getGraph(vectorField); - if (HnswTestUtil.isFullyConnected(graph) == false) { + if (isFullyConnected(graph) == false) { return false; } } From 08f65a4d0471270f741dcb372da08115063c34db Mon Sep 17 00:00:00 2001 From: Michael Sokolov Date: Wed, 3 Apr 2024 16:07:45 -0400 Subject: [PATCH 4/4] use FilterLeafReader.unwrap --- .../java/org/apache/lucene/tests/util/hnsw/HnswTestUtil.java | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lucene/test-framework/src/java/org/apache/lucene/tests/util/hnsw/HnswTestUtil.java b/lucene/test-framework/src/java/org/apache/lucene/tests/util/hnsw/HnswTestUtil.java index a6284bdecb1e..ddd85a68d562 100644 --- a/lucene/test-framework/src/java/org/apache/lucene/tests/util/hnsw/HnswTestUtil.java +++ b/lucene/test-framework/src/java/org/apache/lucene/tests/util/hnsw/HnswTestUtil.java @@ -26,10 +26,10 @@ import org.apache.lucene.codecs.HnswGraphProvider; import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat; import org.apache.lucene.index.CodecReader; +import org.apache.lucene.index.FilterLeafReader; import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.util.FixedBitSet; -import org.apache.lucene.util.Unwrappable; import org.apache.lucene.util.hnsw.HnswGraph; /** Utilities for use in tests involving HNSW graphs */ @@ -117,7 +117,7 @@ private static int nextClearBit(FixedBitSet bits, int index) { public static boolean graphIsConnected(IndexReader reader, String vectorField) throws IOException { for (LeafReaderContext ctx : reader.leaves()) { - CodecReader codecReader = (CodecReader) Unwrappable.unwrapAll(ctx.reader()); + CodecReader codecReader = (CodecReader) FilterLeafReader.unwrap(ctx.reader()); HnswGraph graph = ((HnswGraphProvider) ((PerFieldKnnVectorsFormat.FieldsReader) codecReader.getVectorReader())