diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt index 39d7403a11c..de350ab0406 100644 --- a/lucene/CHANGES.txt +++ b/lucene/CHANGES.txt @@ -76,6 +76,8 @@ Optimizations * GITHUB#14133: Dense blocks of postings are now encoded as bit sets. (Adrien Grand) +* GITHUB#14094: Early terminate when HNSW nearest neighbor queue saturates (Tommaso Teofili) + Bug Fixes --------------------- diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswVectorsReader.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswVectorsReader.java index b29f9da5b41..3ce538eccbb 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswVectorsReader.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswVectorsReader.java @@ -36,6 +36,7 @@ import org.apache.lucene.index.VectorEncoding; import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.internal.hppc.IntObjectHashMap; +import org.apache.lucene.search.HnswQueueSaturationCollector; import org.apache.lucene.search.KnnCollector; import org.apache.lucene.store.ChecksumIndexInput; import org.apache.lucene.store.DataInput; @@ -314,10 +315,16 @@ private void search( return; } final RandomVectorScorer scorer = scorerSupplier.get(); - final KnnCollector collector = - new OrdinalTranslatedKnnCollector(knnCollector, scorer::ordToDoc); final Bits acceptedOrds = scorer.getAcceptOrds(acceptDocs); if (knnCollector.k() < scorer.maxOrd()) { + final KnnCollector collector; + OrdinalTranslatedKnnCollector ordinalTranslatedKnnCollector = + new OrdinalTranslatedKnnCollector(knnCollector, scorer::ordToDoc); + if (scorer.maxOrd() > 1000) { + collector = new HnswQueueSaturationCollector(ordinalTranslatedKnnCollector); + } else { + collector = ordinalTranslatedKnnCollector; + } HnswGraphSearcher.search(scorer, collector, getGraph(fieldEntry), acceptedOrds); } else { // if k is larger than the number of vectors, we can just iterate over all vectors diff --git a/lucene/core/src/java/org/apache/lucene/search/HnswKnnCollector.java b/lucene/core/src/java/org/apache/lucene/search/HnswKnnCollector.java new file mode 100644 index 00000000000..e145ea99dd6 --- /dev/null +++ b/lucene/core/src/java/org/apache/lucene/search/HnswKnnCollector.java @@ -0,0 +1,24 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.search; + +/** {@link KnnCollector} that exposes methods to hook into specific parts of the HNSW algorithm. */ +public interface HnswKnnCollector extends KnnCollector { + + /** Indicates exploration of the next HNSW candidate graph node. */ + void nextCandidate(); +} diff --git a/lucene/core/src/java/org/apache/lucene/search/HnswQueueSaturationCollector.java b/lucene/core/src/java/org/apache/lucene/search/HnswQueueSaturationCollector.java new file mode 100644 index 00000000000..b2af5cbd685 --- /dev/null +++ b/lucene/core/src/java/org/apache/lucene/search/HnswQueueSaturationCollector.java @@ -0,0 +1,130 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.lucene.search; + +/** + * A {@link HnswKnnCollector} that early exits when nearest neighbor queue keeps saturating beyond a + * 'patience' parameter. This records the rate of collection of new nearest neighbors in the {@code + * delegate} KnnCollector queue, at each HNSW node candidate visit. Once it saturates for a number + * of consecutive node visits (e.g., the patience parameter), this early terminates. + */ +public class HnswQueueSaturationCollector implements HnswKnnCollector { + + private static final double DEFAULT_SATURATION_THRESHOLD = 0.995d; + + private final KnnCollector delegate; + private final double saturationThreshold; + private final int patience; + private boolean patienceFinished; + private int countSaturated; + private int previousQueueSize; + private int currentQueueSize; + + HnswQueueSaturationCollector( + KnnCollector delegate, double saturationThreshold, int patience) { + this.delegate = delegate; + this.previousQueueSize = 0; + this.currentQueueSize = 0; + this.countSaturated = 0; + this.patienceFinished = false; + this.saturationThreshold = saturationThreshold; + this.patience = patience; + } + + public HnswQueueSaturationCollector(KnnCollector delegate) { + this.delegate = delegate; + this.previousQueueSize = 0; + this.currentQueueSize = 0; + this.countSaturated = 0; + this.patienceFinished = false; + this.saturationThreshold = DEFAULT_SATURATION_THRESHOLD; + this.patience = defaultPatience(); + } + + private int defaultPatience() { + return Math.max(7, (int) (k() * 0.3)); + } + + @Override + public boolean earlyTerminated() { + return delegate.earlyTerminated() || patienceFinished; + } + + @Override + public void incVisitedCount(int count) { + delegate.incVisitedCount(count); + } + + @Override + public long visitedCount() { + return delegate.visitedCount(); + } + + @Override + public long visitLimit() { + return delegate.visitLimit(); + } + + @Override + public int k() { + return delegate.k(); + } + + @Override + public boolean collect(int docId, float similarity) { + boolean collect = delegate.collect(docId, similarity); + if (collect) { + currentQueueSize++; + } + return collect; + } + + @Override + public float minCompetitiveSimilarity() { + return delegate.minCompetitiveSimilarity(); + } + + @Override + public TopDocs topDocs() { + TopDocs topDocs; + if (patienceFinished && delegate.earlyTerminated() == false) { + TopDocs delegateDocs = delegate.topDocs(); + TotalHits totalHits = + new TotalHits(delegateDocs.totalHits.value(), TotalHits.Relation.EQUAL_TO); + topDocs = new TopDocs(totalHits, delegateDocs.scoreDocs); + } else { + topDocs = delegate.topDocs(); + } + return topDocs; + } + + @Override + public void nextCandidate() { + double queueSaturation = + (double) Math.min(currentQueueSize, previousQueueSize) / currentQueueSize; + previousQueueSize = currentQueueSize; + if (queueSaturation >= saturationThreshold) { + countSaturated++; + } else { + countSaturated = 0; + } + if (countSaturated > patience) { + patienceFinished = true; + } + } +} diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java index e8f0d316fd8..fb93398d159 100644 --- a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java +++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java @@ -21,6 +21,7 @@ import java.io.IOException; import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.search.HnswKnnCollector; import org.apache.lucene.search.KnnCollector; import org.apache.lucene.search.TopKnnCollector; import org.apache.lucene.search.knn.EntryPointProvider; @@ -272,6 +273,9 @@ void searchLevel( } } } + if (results instanceof HnswKnnCollector hnswKnnCollector) { + hnswKnnCollector.nextCandidate(); + } } } diff --git a/lucene/core/src/test/org/apache/lucene/search/TestHnswQueueSaturationCollector.java b/lucene/core/src/test/org/apache/lucene/search/TestHnswQueueSaturationCollector.java new file mode 100644 index 00000000000..1ef29845016 --- /dev/null +++ b/lucene/core/src/test/org/apache/lucene/search/TestHnswQueueSaturationCollector.java @@ -0,0 +1,108 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.search; + +import java.util.Random; +import org.apache.lucene.tests.util.LuceneTestCase; +import org.junit.Test; + +/** Tests for {@link HnswQueueSaturationCollector} */ +public class TestHnswQueueSaturationCollector extends LuceneTestCase { + + @Test + public void testDelegate() { + Random random = random(); + int numDocs = 100; + int k = random.nextInt(1, 10); + KnnCollector delegate = new TopKnnCollector(k, numDocs); + HnswQueueSaturationCollector queueSaturationCollector = + new HnswQueueSaturationCollector(delegate); + for (int i = 0; i < random.nextInt(numDocs); i++) { + queueSaturationCollector.collect(random.nextInt(numDocs), random.nextFloat(1.0f)); + } + assertEquals(delegate.k(), queueSaturationCollector.k()); + assertEquals(delegate.visitedCount(), queueSaturationCollector.visitedCount()); + assertEquals(delegate.visitLimit(), queueSaturationCollector.visitLimit()); + assertEquals( + delegate.minCompetitiveSimilarity(), + queueSaturationCollector.minCompetitiveSimilarity(), + 1e-3); + } + + @Test + public void testEarlyExpectedExit() { + int numDocs = 1000; + int k = 10; + KnnCollector delegate = new TopKnnCollector(k, numDocs); + HnswQueueSaturationCollector queueSaturationCollector = + new HnswQueueSaturationCollector(delegate, 0.9, 10); + for (int i = 0; i < numDocs; i++) { + queueSaturationCollector.collect(i, 1.0f - i * 1e-3f); + if (i % 10 == 0) { + queueSaturationCollector.nextCandidate(); + } + if (queueSaturationCollector.earlyTerminated()) { + assertEquals(120, i); + break; + } + } + } + + @Test + public void testDelegateVsSaturateEarlyExit() { + Random random = random(); + int numDocs = 10000; + int k = random.nextInt(1, 100); + KnnCollector delegate = new TopKnnCollector(k, numDocs); + HnswQueueSaturationCollector queueSaturationCollector = + new HnswQueueSaturationCollector(delegate); + for (int i = 0; i < random.nextInt(numDocs); i++) { + queueSaturationCollector.collect(random.nextInt(numDocs), random.nextFloat(1.0f)); + if (i % 10 == 0) { + queueSaturationCollector.nextCandidate(); + } + boolean earlyTerminatedSaturation = queueSaturationCollector.earlyTerminated(); + boolean earlyTerminatedDelegate = delegate.earlyTerminated(); + assertTrue(earlyTerminatedSaturation || !earlyTerminatedDelegate); + } + } + + @Test + public void testEarlyExitRelation() { + Random random = random(); + int numDocs = 10000; + int k = random.nextInt(100); + KnnCollector delegate = new TopKnnCollector(k, random.nextInt(numDocs)); + HnswQueueSaturationCollector queueSaturationCollector = + new HnswQueueSaturationCollector(delegate); + for (int i = 0; i < random.nextInt(numDocs); i++) { + queueSaturationCollector.collect(random.nextInt(numDocs), random.nextFloat(1.0f)); + if (i % 10 == 0) { + queueSaturationCollector.nextCandidate(); + } + if (delegate.earlyTerminated()) { + TopDocs topDocs = queueSaturationCollector.topDocs(); + assertEquals(TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO, topDocs.totalHits.relation()); + } + if (queueSaturationCollector.earlyTerminated()) { + TopDocs topDocs = queueSaturationCollector.topDocs(); + assertEquals(TotalHits.Relation.EQUAL_TO, topDocs.totalHits.relation()); + break; + } + } + } +}