-
Notifications
You must be signed in to change notification settings - Fork 1.1k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add support for similarity-based vector searches (#12679)
### Description Background in #12579 Add support for getting "all vectors within a radius" as opposed to getting the "topK closest vectors" in the current system ### Considerations I've tried to keep this change minimal and non-invasive by not modifying any APIs and re-using existing HNSW graphs -- changing the graph traversal and result collection criteria to: 1. Visit all nodes (reachable from the entry node in the last level) that are within an outer "traversal" radius 2. Collect all nodes that are within an inner "result" radius ### Advantages 1. Queries that have a high number of "relevant" results will get all of those (not limited by `topK`) 2. Conversely, arbitrary queries where many results are not "relevant" will not waste time in getting all `topK` (when some of them will be removed later) 3. Results of HNSW searches need not be sorted - and we can store them in a plain list as opposed to min-max heaps (saving on `heapify` calls). Merging results from segments is also cheaper, where we just concatenate results as opposed to calculating the index-level `topK` On a higher level, finding `topK` results needed HNSW searches to happen in `#rewrite` because of an interdependence of results between segments - where we want to find the index-level `topK` from multiple segment-level results. This is kind of against Lucene's concept of segments being independently searchable sub-indexes? Moreover, we needed explicit concurrency (#12160) to perform these in parallel, and these shortcomings would be naturally overcome with the new objective of finding "all vectors within a radius" - inherently independent of results from another segment (so we can move searches to a more fitting place?) ### Caveats I could not find much precedent in using HNSW graphs this way (or even the radius-based search for that matter - please add links to existing work if someone is aware) and consequently marked all classes as `@lucene.experimental` For now I have re-used lots of functionality from `AbstractKnnVectorQuery` to keep this minimal, but if the use-case is accepted more widely we can look into writing more suitable queries (as mentioned above briefly)
- Loading branch information
Showing
9 changed files
with
1,403 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
288 changes: 288 additions & 0 deletions
288
lucene/core/src/java/org/apache/lucene/search/AbstractVectorSimilarityQuery.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,288 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one or more | ||
* contributor license agreements. See the NOTICE file distributed with | ||
* this work for additional information regarding copyright ownership. | ||
* The ASF licenses this file to You under the Apache License, Version 2.0 | ||
* (the "License"); you may not use this file except in compliance with | ||
* the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
package org.apache.lucene.search; | ||
|
||
import java.io.IOException; | ||
import java.util.Arrays; | ||
import java.util.Comparator; | ||
import java.util.Objects; | ||
import org.apache.lucene.index.LeafReader; | ||
import org.apache.lucene.index.LeafReaderContext; | ||
import org.apache.lucene.util.BitSet; | ||
import org.apache.lucene.util.BitSetIterator; | ||
import org.apache.lucene.util.Bits; | ||
|
||
/** | ||
* Search for all (approximate) vectors above a similarity threshold. | ||
* | ||
* @lucene.experimental | ||
*/ | ||
abstract class AbstractVectorSimilarityQuery extends Query { | ||
protected final String field; | ||
protected final float traversalSimilarity, resultSimilarity; | ||
protected final Query filter; | ||
|
||
/** | ||
* Search for all (approximate) vectors above a similarity threshold using {@link | ||
* VectorSimilarityCollector}. If a filter is applied, it traverses as many nodes as the cost of | ||
* the filter, and then falls back to exact search if results are incomplete. | ||
* | ||
* @param field a field that has been indexed as a vector field. | ||
* @param traversalSimilarity (lower) similarity score for graph traversal. | ||
* @param resultSimilarity (higher) similarity score for result collection. | ||
* @param filter a filter applied before the vector search. | ||
*/ | ||
AbstractVectorSimilarityQuery( | ||
String field, float traversalSimilarity, float resultSimilarity, Query filter) { | ||
if (traversalSimilarity > resultSimilarity) { | ||
throw new IllegalArgumentException("traversalSimilarity should be <= resultSimilarity"); | ||
} | ||
this.field = Objects.requireNonNull(field, "field"); | ||
this.traversalSimilarity = traversalSimilarity; | ||
this.resultSimilarity = resultSimilarity; | ||
this.filter = filter; | ||
} | ||
|
||
abstract VectorScorer createVectorScorer(LeafReaderContext context) throws IOException; | ||
|
||
protected abstract TopDocs approximateSearch( | ||
LeafReaderContext context, Bits acceptDocs, int visitLimit) throws IOException; | ||
|
||
@Override | ||
public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) | ||
throws IOException { | ||
return new Weight(this) { | ||
final Weight filterWeight = | ||
filter == null | ||
? null | ||
: searcher.createWeight(searcher.rewrite(filter), ScoreMode.COMPLETE_NO_SCORES, 1); | ||
|
||
@Override | ||
public Explanation explain(LeafReaderContext context, int doc) throws IOException { | ||
if (filterWeight != null) { | ||
Scorer filterScorer = filterWeight.scorer(context); | ||
if (filterScorer == null || filterScorer.iterator().advance(doc) > doc) { | ||
return Explanation.noMatch("Doc does not match the filter"); | ||
} | ||
} | ||
|
||
VectorScorer scorer = createVectorScorer(context); | ||
if (scorer == null) { | ||
return Explanation.noMatch("Not indexed as the correct vector field"); | ||
} else if (scorer.advanceExact(doc)) { | ||
float score = scorer.score(); | ||
if (score >= resultSimilarity) { | ||
return Explanation.match(boost * score, "Score above threshold"); | ||
} else { | ||
return Explanation.noMatch("Score below threshold"); | ||
} | ||
} else { | ||
return Explanation.noMatch("No vector found for doc"); | ||
} | ||
} | ||
|
||
@Override | ||
public Scorer scorer(LeafReaderContext context) throws IOException { | ||
@SuppressWarnings("resource") | ||
LeafReader leafReader = context.reader(); | ||
Bits liveDocs = leafReader.getLiveDocs(); | ||
|
||
// If there is no filter | ||
if (filterWeight == null) { | ||
// Return exhaustive results | ||
TopDocs results = approximateSearch(context, liveDocs, Integer.MAX_VALUE); | ||
return VectorSimilarityScorer.fromScoreDocs(this, boost, results.scoreDocs); | ||
} | ||
|
||
Scorer scorer = filterWeight.scorer(context); | ||
if (scorer == null) { | ||
// If the filter does not match any documents | ||
return null; | ||
} | ||
|
||
BitSet acceptDocs; | ||
if (liveDocs == null && scorer.iterator() instanceof BitSetIterator bitSetIterator) { | ||
// If there are no deletions, and matching docs are already cached | ||
acceptDocs = bitSetIterator.getBitSet(); | ||
} else { | ||
// Else collect all matching docs | ||
FilteredDocIdSetIterator filtered = | ||
new FilteredDocIdSetIterator(scorer.iterator()) { | ||
@Override | ||
protected boolean match(int doc) { | ||
return liveDocs == null || liveDocs.get(doc); | ||
} | ||
}; | ||
acceptDocs = BitSet.of(filtered, leafReader.maxDoc()); | ||
} | ||
|
||
int cardinality = acceptDocs.cardinality(); | ||
if (cardinality == 0) { | ||
// If there are no live matching docs | ||
return null; | ||
} | ||
|
||
// Perform an approximate search | ||
TopDocs results = approximateSearch(context, acceptDocs, cardinality); | ||
|
||
// If the limit was exhausted | ||
if (results.totalHits.relation == TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO) { | ||
// Return a lazy-loading iterator | ||
return VectorSimilarityScorer.fromAcceptDocs( | ||
this, | ||
boost, | ||
createVectorScorer(context), | ||
new BitSetIterator(acceptDocs, cardinality), | ||
resultSimilarity); | ||
} else { | ||
// Return an iterator over the collected results | ||
return VectorSimilarityScorer.fromScoreDocs(this, boost, results.scoreDocs); | ||
} | ||
} | ||
|
||
@Override | ||
public boolean isCacheable(LeafReaderContext ctx) { | ||
return true; | ||
} | ||
}; | ||
} | ||
|
||
@Override | ||
public void visit(QueryVisitor visitor) { | ||
if (visitor.acceptField(field)) { | ||
visitor.visitLeaf(this); | ||
} | ||
} | ||
|
||
@Override | ||
public boolean equals(Object o) { | ||
return sameClassAs(o) | ||
&& Objects.equals(field, ((AbstractVectorSimilarityQuery) o).field) | ||
&& Float.compare( | ||
((AbstractVectorSimilarityQuery) o).traversalSimilarity, traversalSimilarity) | ||
== 0 | ||
&& Float.compare(((AbstractVectorSimilarityQuery) o).resultSimilarity, resultSimilarity) | ||
== 0 | ||
&& Objects.equals(filter, ((AbstractVectorSimilarityQuery) o).filter); | ||
} | ||
|
||
@Override | ||
public int hashCode() { | ||
return Objects.hash(field, traversalSimilarity, resultSimilarity, filter); | ||
} | ||
|
||
private static class VectorSimilarityScorer extends Scorer { | ||
final DocIdSetIterator iterator; | ||
final float[] cachedScore; | ||
|
||
VectorSimilarityScorer(Weight weight, DocIdSetIterator iterator, float[] cachedScore) { | ||
super(weight); | ||
this.iterator = iterator; | ||
this.cachedScore = cachedScore; | ||
} | ||
|
||
static VectorSimilarityScorer fromScoreDocs(Weight weight, float boost, ScoreDoc[] scoreDocs) { | ||
// Sort in ascending order of docid | ||
Arrays.sort(scoreDocs, Comparator.comparingInt(scoreDoc -> scoreDoc.doc)); | ||
|
||
float[] cachedScore = new float[1]; | ||
DocIdSetIterator iterator = | ||
new DocIdSetIterator() { | ||
int index = -1; | ||
|
||
@Override | ||
public int docID() { | ||
if (index < 0) { | ||
return -1; | ||
} else if (index >= scoreDocs.length) { | ||
return NO_MORE_DOCS; | ||
} else { | ||
cachedScore[0] = boost * scoreDocs[index].score; | ||
return scoreDocs[index].doc; | ||
} | ||
} | ||
|
||
@Override | ||
public int nextDoc() { | ||
index++; | ||
return docID(); | ||
} | ||
|
||
@Override | ||
public int advance(int target) { | ||
index = | ||
Arrays.binarySearch( | ||
scoreDocs, | ||
new ScoreDoc(target, 0), | ||
Comparator.comparingInt(scoreDoc -> scoreDoc.doc)); | ||
if (index < 0) { | ||
index = -1 - index; | ||
} | ||
return docID(); | ||
} | ||
|
||
@Override | ||
public long cost() { | ||
return scoreDocs.length; | ||
} | ||
}; | ||
|
||
return new VectorSimilarityScorer(weight, iterator, cachedScore); | ||
} | ||
|
||
static VectorSimilarityScorer fromAcceptDocs( | ||
Weight weight, | ||
float boost, | ||
VectorScorer scorer, | ||
DocIdSetIterator acceptDocs, | ||
float threshold) { | ||
float[] cachedScore = new float[1]; | ||
DocIdSetIterator iterator = | ||
new FilteredDocIdSetIterator(acceptDocs) { | ||
@Override | ||
protected boolean match(int doc) throws IOException { | ||
// Compute the dot product | ||
float score = scorer.score(); | ||
cachedScore[0] = score * boost; | ||
return score >= threshold; | ||
} | ||
}; | ||
|
||
return new VectorSimilarityScorer(weight, iterator, cachedScore); | ||
} | ||
|
||
@Override | ||
public int docID() { | ||
return iterator.docID(); | ||
} | ||
|
||
@Override | ||
public DocIdSetIterator iterator() { | ||
return iterator; | ||
} | ||
|
||
@Override | ||
public float getMaxScore(int upTo) { | ||
return Float.POSITIVE_INFINITY; | ||
} | ||
|
||
@Override | ||
public float score() { | ||
return cachedScore[0]; | ||
} | ||
} | ||
} |
Oops, something went wrong.