diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt index d100fc310dd5..88087525bc9a 100644 --- a/lucene/CHANGES.txt +++ b/lucene/CHANGES.txt @@ -171,7 +171,11 @@ API Changes New Features --------------------- -(No changes) + +* GITHUB#12679: Add support for similarity-based vector searches using [Byte|Float]VectorSimilarityQuery. Uses a new + VectorSimilarityCollector to find all vectors scoring above a `resultSimilarity` while traversing the HNSW graph till + better-scoring nodes are available, or the best candidate is below a score of `traversalSimilarity` in the lowest + level. (Aditya Prakash, Kaival Parikh) Improvements --------------------- diff --git a/lucene/core/src/java/org/apache/lucene/search/AbstractVectorSimilarityQuery.java b/lucene/core/src/java/org/apache/lucene/search/AbstractVectorSimilarityQuery.java new file mode 100644 index 000000000000..393d4b5731f6 --- /dev/null +++ b/lucene/core/src/java/org/apache/lucene/search/AbstractVectorSimilarityQuery.java @@ -0,0 +1,288 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.search; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Comparator; +import java.util.Objects; +import org.apache.lucene.index.LeafReader; +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.util.BitSet; +import org.apache.lucene.util.BitSetIterator; +import org.apache.lucene.util.Bits; + +/** + * Search for all (approximate) vectors above a similarity threshold. + * + * @lucene.experimental + */ +abstract class AbstractVectorSimilarityQuery extends Query { + protected final String field; + protected final float traversalSimilarity, resultSimilarity; + protected final Query filter; + + /** + * Search for all (approximate) vectors above a similarity threshold using {@link + * VectorSimilarityCollector}. If a filter is applied, it traverses as many nodes as the cost of + * the filter, and then falls back to exact search if results are incomplete. + * + * @param field a field that has been indexed as a vector field. + * @param traversalSimilarity (lower) similarity score for graph traversal. + * @param resultSimilarity (higher) similarity score for result collection. + * @param filter a filter applied before the vector search. + */ + AbstractVectorSimilarityQuery( + String field, float traversalSimilarity, float resultSimilarity, Query filter) { + if (traversalSimilarity > resultSimilarity) { + throw new IllegalArgumentException("traversalSimilarity should be <= resultSimilarity"); + } + this.field = Objects.requireNonNull(field, "field"); + this.traversalSimilarity = traversalSimilarity; + this.resultSimilarity = resultSimilarity; + this.filter = filter; + } + + abstract VectorScorer createVectorScorer(LeafReaderContext context) throws IOException; + + protected abstract TopDocs approximateSearch( + LeafReaderContext context, Bits acceptDocs, int visitLimit) throws IOException; + + @Override + public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) + throws IOException { + return new Weight(this) { + final Weight filterWeight = + filter == null + ? null + : searcher.createWeight(searcher.rewrite(filter), ScoreMode.COMPLETE_NO_SCORES, 1); + + @Override + public Explanation explain(LeafReaderContext context, int doc) throws IOException { + if (filterWeight != null) { + Scorer filterScorer = filterWeight.scorer(context); + if (filterScorer == null || filterScorer.iterator().advance(doc) > doc) { + return Explanation.noMatch("Doc does not match the filter"); + } + } + + VectorScorer scorer = createVectorScorer(context); + if (scorer == null) { + return Explanation.noMatch("Not indexed as the correct vector field"); + } else if (scorer.advanceExact(doc)) { + float score = scorer.score(); + if (score >= resultSimilarity) { + return Explanation.match(boost * score, "Score above threshold"); + } else { + return Explanation.noMatch("Score below threshold"); + } + } else { + return Explanation.noMatch("No vector found for doc"); + } + } + + @Override + public Scorer scorer(LeafReaderContext context) throws IOException { + @SuppressWarnings("resource") + LeafReader leafReader = context.reader(); + Bits liveDocs = leafReader.getLiveDocs(); + + // If there is no filter + if (filterWeight == null) { + // Return exhaustive results + TopDocs results = approximateSearch(context, liveDocs, Integer.MAX_VALUE); + return VectorSimilarityScorer.fromScoreDocs(this, boost, results.scoreDocs); + } + + Scorer scorer = filterWeight.scorer(context); + if (scorer == null) { + // If the filter does not match any documents + return null; + } + + BitSet acceptDocs; + if (liveDocs == null && scorer.iterator() instanceof BitSetIterator bitSetIterator) { + // If there are no deletions, and matching docs are already cached + acceptDocs = bitSetIterator.getBitSet(); + } else { + // Else collect all matching docs + FilteredDocIdSetIterator filtered = + new FilteredDocIdSetIterator(scorer.iterator()) { + @Override + protected boolean match(int doc) { + return liveDocs == null || liveDocs.get(doc); + } + }; + acceptDocs = BitSet.of(filtered, leafReader.maxDoc()); + } + + int cardinality = acceptDocs.cardinality(); + if (cardinality == 0) { + // If there are no live matching docs + return null; + } + + // Perform an approximate search + TopDocs results = approximateSearch(context, acceptDocs, cardinality); + + // If the limit was exhausted + if (results.totalHits.relation == TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO) { + // Return a lazy-loading iterator + return VectorSimilarityScorer.fromAcceptDocs( + this, + boost, + createVectorScorer(context), + new BitSetIterator(acceptDocs, cardinality), + resultSimilarity); + } else { + // Return an iterator over the collected results + return VectorSimilarityScorer.fromScoreDocs(this, boost, results.scoreDocs); + } + } + + @Override + public boolean isCacheable(LeafReaderContext ctx) { + return true; + } + }; + } + + @Override + public void visit(QueryVisitor visitor) { + if (visitor.acceptField(field)) { + visitor.visitLeaf(this); + } + } + + @Override + public boolean equals(Object o) { + return sameClassAs(o) + && Objects.equals(field, ((AbstractVectorSimilarityQuery) o).field) + && Float.compare( + ((AbstractVectorSimilarityQuery) o).traversalSimilarity, traversalSimilarity) + == 0 + && Float.compare(((AbstractVectorSimilarityQuery) o).resultSimilarity, resultSimilarity) + == 0 + && Objects.equals(filter, ((AbstractVectorSimilarityQuery) o).filter); + } + + @Override + public int hashCode() { + return Objects.hash(field, traversalSimilarity, resultSimilarity, filter); + } + + private static class VectorSimilarityScorer extends Scorer { + final DocIdSetIterator iterator; + final float[] cachedScore; + + VectorSimilarityScorer(Weight weight, DocIdSetIterator iterator, float[] cachedScore) { + super(weight); + this.iterator = iterator; + this.cachedScore = cachedScore; + } + + static VectorSimilarityScorer fromScoreDocs(Weight weight, float boost, ScoreDoc[] scoreDocs) { + // Sort in ascending order of docid + Arrays.sort(scoreDocs, Comparator.comparingInt(scoreDoc -> scoreDoc.doc)); + + float[] cachedScore = new float[1]; + DocIdSetIterator iterator = + new DocIdSetIterator() { + int index = -1; + + @Override + public int docID() { + if (index < 0) { + return -1; + } else if (index >= scoreDocs.length) { + return NO_MORE_DOCS; + } else { + cachedScore[0] = boost * scoreDocs[index].score; + return scoreDocs[index].doc; + } + } + + @Override + public int nextDoc() { + index++; + return docID(); + } + + @Override + public int advance(int target) { + index = + Arrays.binarySearch( + scoreDocs, + new ScoreDoc(target, 0), + Comparator.comparingInt(scoreDoc -> scoreDoc.doc)); + if (index < 0) { + index = -1 - index; + } + return docID(); + } + + @Override + public long cost() { + return scoreDocs.length; + } + }; + + return new VectorSimilarityScorer(weight, iterator, cachedScore); + } + + static VectorSimilarityScorer fromAcceptDocs( + Weight weight, + float boost, + VectorScorer scorer, + DocIdSetIterator acceptDocs, + float threshold) { + float[] cachedScore = new float[1]; + DocIdSetIterator iterator = + new FilteredDocIdSetIterator(acceptDocs) { + @Override + protected boolean match(int doc) throws IOException { + // Compute the dot product + float score = scorer.score(); + cachedScore[0] = score * boost; + return score >= threshold; + } + }; + + return new VectorSimilarityScorer(weight, iterator, cachedScore); + } + + @Override + public int docID() { + return iterator.docID(); + } + + @Override + public DocIdSetIterator iterator() { + return iterator; + } + + @Override + public float getMaxScore(int upTo) { + return Float.POSITIVE_INFINITY; + } + + @Override + public float score() { + return cachedScore[0]; + } + } +} diff --git a/lucene/core/src/java/org/apache/lucene/search/ByteVectorSimilarityQuery.java b/lucene/core/src/java/org/apache/lucene/search/ByteVectorSimilarityQuery.java new file mode 100644 index 000000000000..e410ad06343d --- /dev/null +++ b/lucene/core/src/java/org/apache/lucene/search/ByteVectorSimilarityQuery.java @@ -0,0 +1,145 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.search; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Locale; +import java.util.Objects; +import org.apache.lucene.document.KnnByteVectorField; +import org.apache.lucene.index.FieldInfo; +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.index.VectorEncoding; +import org.apache.lucene.util.Bits; + +/** + * Search for all (approximate) byte vectors above a similarity threshold. + * + * @lucene.experimental + */ +public class ByteVectorSimilarityQuery extends AbstractVectorSimilarityQuery { + private final byte[] target; + + /** + * Search for all (approximate) byte vectors above a similarity threshold using {@link + * VectorSimilarityCollector}. If a filter is applied, it traverses as many nodes as the cost of + * the filter, and then falls back to exact search if results are incomplete. + * + * @param field a field that has been indexed as a {@link KnnByteVectorField}. + * @param target the target of the search. + * @param traversalSimilarity (lower) similarity score for graph traversal. + * @param resultSimilarity (higher) similarity score for result collection. + * @param filter a filter applied before the vector search. + */ + public ByteVectorSimilarityQuery( + String field, + byte[] target, + float traversalSimilarity, + float resultSimilarity, + Query filter) { + super(field, traversalSimilarity, resultSimilarity, filter); + this.target = Objects.requireNonNull(target, "target"); + } + + /** + * Search for all (approximate) byte vectors above a similarity threshold using {@link + * VectorSimilarityCollector}. + * + * @param field a field that has been indexed as a {@link KnnByteVectorField}. + * @param target the target of the search. + * @param traversalSimilarity (lower) similarity score for graph traversal. + * @param resultSimilarity (higher) similarity score for result collection. + */ + public ByteVectorSimilarityQuery( + String field, byte[] target, float traversalSimilarity, float resultSimilarity) { + this(field, target, traversalSimilarity, resultSimilarity, null); + } + + /** + * Search for all (approximate) byte vectors above a similarity threshold using {@link + * VectorSimilarityCollector}. If a filter is applied, it traverses as many nodes as the cost of + * the filter, and then falls back to exact search if results are incomplete. + * + * @param field a field that has been indexed as a {@link KnnByteVectorField}. + * @param target the target of the search. + * @param resultSimilarity similarity score for result collection. + * @param filter a filter applied before the vector search. + */ + public ByteVectorSimilarityQuery( + String field, byte[] target, float resultSimilarity, Query filter) { + this(field, target, resultSimilarity, resultSimilarity, filter); + } + + /** + * Search for all (approximate) byte vectors above a similarity threshold using {@link + * VectorSimilarityCollector}. + * + * @param field a field that has been indexed as a {@link KnnByteVectorField}. + * @param target the target of the search. + * @param resultSimilarity similarity score for result collection. + */ + public ByteVectorSimilarityQuery(String field, byte[] target, float resultSimilarity) { + this(field, target, resultSimilarity, resultSimilarity, null); + } + + @Override + VectorScorer createVectorScorer(LeafReaderContext context) throws IOException { + @SuppressWarnings("resource") + FieldInfo fi = context.reader().getFieldInfos().fieldInfo(field); + if (fi == null || fi.getVectorEncoding() != VectorEncoding.BYTE) { + return null; + } + return VectorScorer.create(context, fi, target); + } + + @Override + @SuppressWarnings("resource") + protected TopDocs approximateSearch(LeafReaderContext context, Bits acceptDocs, int visitLimit) + throws IOException { + KnnCollector collector = + new VectorSimilarityCollector(traversalSimilarity, resultSimilarity, visitLimit); + context.reader().searchNearestVectors(field, target, collector, acceptDocs); + return collector.topDocs(); + } + + @Override + public String toString(String field) { + return String.format( + Locale.ROOT, + "%s[field=%s target=[%d...] traversalSimilarity=%f resultSimilarity=%f filter=%s]", + getClass().getSimpleName(), + field, + target[0], + traversalSimilarity, + resultSimilarity, + filter); + } + + @Override + public boolean equals(Object o) { + return sameClassAs(o) + && super.equals(o) + && Arrays.equals(target, ((ByteVectorSimilarityQuery) o).target); + } + + @Override + public int hashCode() { + int result = super.hashCode(); + result = 31 * result + Arrays.hashCode(target); + return result; + } +} diff --git a/lucene/core/src/java/org/apache/lucene/search/FloatVectorSimilarityQuery.java b/lucene/core/src/java/org/apache/lucene/search/FloatVectorSimilarityQuery.java new file mode 100644 index 000000000000..44d06c163c0f --- /dev/null +++ b/lucene/core/src/java/org/apache/lucene/search/FloatVectorSimilarityQuery.java @@ -0,0 +1,146 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.search; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Locale; +import java.util.Objects; +import org.apache.lucene.document.KnnFloatVectorField; +import org.apache.lucene.index.FieldInfo; +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.index.VectorEncoding; +import org.apache.lucene.util.Bits; +import org.apache.lucene.util.VectorUtil; + +/** + * Search for all (approximate) float vectors above a similarity threshold. + * + * @lucene.experimental + */ +public class FloatVectorSimilarityQuery extends AbstractVectorSimilarityQuery { + private final float[] target; + + /** + * Search for all (approximate) float vectors above a similarity threshold using {@link + * VectorSimilarityCollector}. If a filter is applied, it traverses as many nodes as the cost of + * the filter, and then falls back to exact search if results are incomplete. + * + * @param field a field that has been indexed as a {@link KnnFloatVectorField}. + * @param target the target of the search. + * @param traversalSimilarity (lower) similarity score for graph traversal. + * @param resultSimilarity (higher) similarity score for result collection. + * @param filter a filter applied before the vector search. + */ + public FloatVectorSimilarityQuery( + String field, + float[] target, + float traversalSimilarity, + float resultSimilarity, + Query filter) { + super(field, traversalSimilarity, resultSimilarity, filter); + this.target = VectorUtil.checkFinite(Objects.requireNonNull(target, "target")); + } + + /** + * Search for all (approximate) float vectors above a similarity threshold using {@link + * VectorSimilarityCollector}. + * + * @param field a field that has been indexed as a {@link KnnFloatVectorField}. + * @param target the target of the search. + * @param traversalSimilarity (lower) similarity score for graph traversal. + * @param resultSimilarity (higher) similarity score for result collection. + */ + public FloatVectorSimilarityQuery( + String field, float[] target, float traversalSimilarity, float resultSimilarity) { + this(field, target, traversalSimilarity, resultSimilarity, null); + } + + /** + * Search for all (approximate) float vectors above a similarity threshold using {@link + * VectorSimilarityCollector}. If a filter is applied, it traverses as many nodes as the cost of + * the filter, and then falls back to exact search if results are incomplete. + * + * @param field a field that has been indexed as a {@link KnnFloatVectorField}. + * @param target the target of the search. + * @param resultSimilarity similarity score for result collection. + * @param filter a filter applied before the vector search. + */ + public FloatVectorSimilarityQuery( + String field, float[] target, float resultSimilarity, Query filter) { + this(field, target, resultSimilarity, resultSimilarity, filter); + } + + /** + * Search for all (approximate) float vectors above a similarity threshold using {@link + * VectorSimilarityCollector}. + * + * @param field a field that has been indexed as a {@link KnnFloatVectorField}. + * @param target the target of the search. + * @param resultSimilarity similarity score for result collection. + */ + public FloatVectorSimilarityQuery(String field, float[] target, float resultSimilarity) { + this(field, target, resultSimilarity, resultSimilarity, null); + } + + @Override + VectorScorer createVectorScorer(LeafReaderContext context) throws IOException { + @SuppressWarnings("resource") + FieldInfo fi = context.reader().getFieldInfos().fieldInfo(field); + if (fi == null || fi.getVectorEncoding() != VectorEncoding.FLOAT32) { + return null; + } + return VectorScorer.create(context, fi, target); + } + + @Override + @SuppressWarnings("resource") + protected TopDocs approximateSearch(LeafReaderContext context, Bits acceptDocs, int visitLimit) + throws IOException { + KnnCollector collector = + new VectorSimilarityCollector(traversalSimilarity, resultSimilarity, visitLimit); + context.reader().searchNearestVectors(field, target, collector, acceptDocs); + return collector.topDocs(); + } + + @Override + public String toString(String field) { + return String.format( + Locale.ROOT, + "%s[field=%s target=[%f...] traversalSimilarity=%f resultSimilarity=%f filter=%s]", + getClass().getSimpleName(), + field, + target[0], + traversalSimilarity, + resultSimilarity, + filter); + } + + @Override + public boolean equals(Object o) { + return sameClassAs(o) + && super.equals(o) + && Arrays.equals(target, ((FloatVectorSimilarityQuery) o).target); + } + + @Override + public int hashCode() { + int result = super.hashCode(); + result = 31 * result + Arrays.hashCode(target); + return result; + } +} diff --git a/lucene/core/src/java/org/apache/lucene/search/VectorSimilarityCollector.java b/lucene/core/src/java/org/apache/lucene/search/VectorSimilarityCollector.java new file mode 100644 index 000000000000..6005f3ebef51 --- /dev/null +++ b/lucene/core/src/java/org/apache/lucene/search/VectorSimilarityCollector.java @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.search; + +import java.util.ArrayList; +import java.util.List; + +/** + * Perform a similarity-based graph search. + * + * @lucene.experimental + */ +class VectorSimilarityCollector extends AbstractKnnCollector { + private final float traversalSimilarity, resultSimilarity; + private float maxSimilarity; + private final List scoreDocList; + + /** + * Perform a similarity-based graph search. The graph is traversed till better scoring nodes are + * available, or the best candidate is below {@link #traversalSimilarity}. All traversed nodes + * above {@link #resultSimilarity} are collected. + * + * @param traversalSimilarity (lower) similarity score for graph traversal. + * @param resultSimilarity (higher) similarity score for result collection. + * @param visitLimit limit on number of nodes to visit. + */ + public VectorSimilarityCollector( + float traversalSimilarity, float resultSimilarity, long visitLimit) { + super(1, visitLimit); + if (traversalSimilarity > resultSimilarity) { + throw new IllegalArgumentException("traversalSimilarity should be <= resultSimilarity"); + } + this.traversalSimilarity = traversalSimilarity; + this.resultSimilarity = resultSimilarity; + this.maxSimilarity = Float.NEGATIVE_INFINITY; + this.scoreDocList = new ArrayList<>(); + } + + @Override + public boolean collect(int docId, float similarity) { + maxSimilarity = Math.max(maxSimilarity, similarity); + if (similarity >= resultSimilarity) { + scoreDocList.add(new ScoreDoc(docId, similarity)); + } + return true; + } + + @Override + public float minCompetitiveSimilarity() { + return Math.min(traversalSimilarity, maxSimilarity); + } + + @Override + public TopDocs topDocs() { + // Results are not returned in a sorted order to prevent unnecessary calculations (because we do + // not need to maintain the topK) + TotalHits.Relation relation = + earlyTerminated() + ? TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO + : TotalHits.Relation.EQUAL_TO; + return new TopDocs( + new TotalHits(visitedCount(), relation), scoreDocList.toArray(ScoreDoc[]::new)); + } +} diff --git a/lucene/core/src/test/org/apache/lucene/search/BaseVectorSimilarityQueryTestCase.java b/lucene/core/src/test/org/apache/lucene/search/BaseVectorSimilarityQueryTestCase.java new file mode 100644 index 000000000000..98a1e177948c --- /dev/null +++ b/lucene/core/src/test/org/apache/lucene/search/BaseVectorSimilarityQueryTestCase.java @@ -0,0 +1,516 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.search; + +import java.io.IOException; +import java.util.Arrays; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Locale; +import java.util.Map; +import java.util.Objects; +import java.util.Set; +import java.util.stream.IntStream; +import org.apache.lucene.document.Document; +import org.apache.lucene.document.Field; +import org.apache.lucene.document.IntField; +import org.apache.lucene.index.DirectoryReader; +import org.apache.lucene.index.IndexReader; +import org.apache.lucene.index.IndexWriter; +import org.apache.lucene.index.Term; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.store.Directory; +import org.apache.lucene.tests.index.RandomIndexWriter; +import org.apache.lucene.tests.util.LuceneTestCase; + +abstract class BaseVectorSimilarityQueryTestCase< + V, F extends Field, Q extends AbstractVectorSimilarityQuery> + extends LuceneTestCase { + String vectorField, idField; + VectorSimilarityFunction function; + int numDocs, dim; + + abstract V getRandomVector(int dim); + + abstract float compare(V vector1, V vector2); + + abstract boolean checkEquals(V vector1, V vector2); + + abstract F getVectorField(String name, V vector, VectorSimilarityFunction function); + + abstract Q getVectorQuery( + String field, V vector, float traversalSimilarity, float resultSimilarity, Query filter); + + abstract Q getThrowingVectorQuery( + String field, V vector, float traversalSimilarity, float resultSimilarity, Query filter); + + public void testEquals() { + String field1 = "f1", field2 = "f2"; + + V vector1 = getRandomVector(dim); + V vector2; + do { + vector2 = getRandomVector(dim); + } while (checkEquals(vector1, vector2)); + + float traversalSimilarity1 = 0.3f, traversalSimilarity2 = 0.4f; + float resultSimilarity1 = 0.4f, resultSimilarity2 = 0.5f; + + Query filter1 = new TermQuery(new Term("t1", "v1")); + Query filter2 = new TermQuery(new Term("t2", "v2")); + + Query query = getVectorQuery(field1, vector1, traversalSimilarity1, resultSimilarity1, filter1); + + // Everything is equal + assertEquals( + query, getVectorQuery(field1, vector1, traversalSimilarity1, resultSimilarity1, filter1)); + + // Null check + assertNotEquals(query, null); + + // Different field + assertNotEquals( + query, getVectorQuery(field2, vector1, traversalSimilarity1, resultSimilarity1, filter1)); + + // Different vector + assertNotEquals( + query, getVectorQuery(field1, vector2, traversalSimilarity1, resultSimilarity1, filter1)); + + // Different traversalSimilarity + assertNotEquals( + query, getVectorQuery(field1, vector1, traversalSimilarity2, resultSimilarity1, filter1)); + + // Different resultSimilarity + assertNotEquals( + query, getVectorQuery(field1, vector1, traversalSimilarity1, resultSimilarity2, filter1)); + + // Different filter + assertNotEquals( + query, getVectorQuery(field1, vector1, traversalSimilarity1, resultSimilarity1, filter2)); + } + + public void testEmptyIndex() throws IOException { + // Do not index any vectors + numDocs = 0; + + try (Directory indexStore = getIndexStore(getRandomVectors(numDocs, dim)); + IndexReader reader = DirectoryReader.open(indexStore)) { + IndexSearcher searcher = newSearcher(reader); + + Query query = + getVectorQuery( + vectorField, + getRandomVector(dim), + Float.NEGATIVE_INFINITY, + Float.NEGATIVE_INFINITY, + null); + + // Check that no vectors are found + assertEquals(0, searcher.count(query)); + } + } + + public void testExtremes() throws IOException { + try (Directory indexStore = getIndexStore(getRandomVectors(numDocs, dim)); + IndexReader reader = DirectoryReader.open(indexStore)) { + IndexSearcher searcher = newSearcher(reader); + + // All vectors are above -Infinity + Query query1 = + getVectorQuery( + vectorField, + getRandomVector(dim), + Float.NEGATIVE_INFINITY, + Float.NEGATIVE_INFINITY, + null); + + // Check that all vectors are found + assertEquals(numDocs, searcher.count(query1)); + + // No vectors are above +Infinity + Query query2 = + getVectorQuery( + vectorField, + getRandomVector(dim), + Float.POSITIVE_INFINITY, + Float.POSITIVE_INFINITY, + null); + + // Check that no vectors are found + assertEquals(0, searcher.count(query2)); + } + } + + public void testRandomFilter() throws IOException { + // Filter a sub-range from 0 to numDocs + int startIndex = random().nextInt(numDocs); + int endIndex = random().nextInt(startIndex, numDocs); + Query filter = IntField.newRangeQuery(idField, startIndex, endIndex); + + try (Directory indexStore = getIndexStore(getRandomVectors(numDocs, dim)); + IndexReader reader = DirectoryReader.open(indexStore)) { + IndexSearcher searcher = newSearcher(reader); + + Query query = + getVectorQuery( + vectorField, + getRandomVector(dim), + Float.NEGATIVE_INFINITY, + Float.NEGATIVE_INFINITY, + filter); + + ScoreDoc[] scoreDocs = searcher.search(query, numDocs).scoreDocs; + for (ScoreDoc scoreDoc : scoreDocs) { + int id = getId(searcher, scoreDoc.doc); + + // Check that returned document is in selected range + assertTrue(id >= startIndex && id <= endIndex); + } + // Check that all filtered vectors are found + assertEquals(endIndex - startIndex + 1, scoreDocs.length); + } + } + + public void testFilterWithNoMatches() throws IOException { + try (Directory indexStore = getIndexStore(getRandomVectors(numDocs, dim)); + IndexReader reader = DirectoryReader.open(indexStore)) { + IndexSearcher searcher = newSearcher(reader); + + // Non-existent field + Query filter1 = new TermQuery(new Term("random_field", "random_value")); + Query query1 = + getVectorQuery( + vectorField, + getRandomVector(dim), + Float.NEGATIVE_INFINITY, + Float.NEGATIVE_INFINITY, + filter1); + + // Check that no vectors are found + assertEquals(0, searcher.count(query1)); + + // Field exists, but value of -1 is not indexed + Query filter2 = IntField.newExactQuery(idField, -1); + Query query2 = + getVectorQuery( + vectorField, + getRandomVector(dim), + Float.NEGATIVE_INFINITY, + Float.NEGATIVE_INFINITY, + filter2); + + // Check that no vectors are found + assertEquals(0, searcher.count(query2)); + } + } + + public void testDimensionMismatch() throws IOException { + // Different dimension + int newDim = atLeast(dim + 1); + + try (Directory indexStore = getIndexStore(getRandomVectors(numDocs, dim)); + IndexReader reader = DirectoryReader.open(indexStore)) { + IndexSearcher searcher = newSearcher(reader); + + Query query = + getVectorQuery( + vectorField, + getRandomVector(newDim), + Float.NEGATIVE_INFINITY, + Float.NEGATIVE_INFINITY, + null); + + // Check that an exception for differing dimensions is thrown + IllegalArgumentException e = + expectThrows(IllegalArgumentException.class, () -> searcher.count(query)); + assertEquals( + String.format( + Locale.ROOT, + "vector query dimension: %d differs from field dimension: %d", + newDim, + dim), + e.getMessage()); + } + } + + public void testNonVectorsField() throws IOException { + try (Directory indexStore = getIndexStore(getRandomVectors(numDocs, dim)); + IndexReader reader = DirectoryReader.open(indexStore)) { + IndexSearcher searcher = newSearcher(reader); + + // Non-existent field + Query query1 = + getVectorQuery( + "random_field", + getRandomVector(dim), + Float.NEGATIVE_INFINITY, + Float.NEGATIVE_INFINITY, + null); + assertEquals(0, searcher.count(query1)); + + // Indexed as int field + Query query2 = + getVectorQuery( + idField, + getRandomVector(dim), + Float.NEGATIVE_INFINITY, + Float.NEGATIVE_INFINITY, + null); + assertEquals(0, searcher.count(query2)); + } + } + + public void testSomeDeletes() throws IOException { + // Delete a sub-range from 0 to numDocs + int startIndex = random().nextInt(numDocs); + int endIndex = random().nextInt(startIndex, numDocs); + Query delete = IntField.newRangeQuery(idField, startIndex, endIndex); + + try (Directory indexStore = getIndexStore(getRandomVectors(numDocs, dim)); + IndexWriter w = new IndexWriter(indexStore, newIndexWriterConfig())) { + + w.deleteDocuments(delete); + w.commit(); + + try (IndexReader reader = DirectoryReader.open(indexStore)) { + IndexSearcher searcher = newSearcher(reader); + + Query query = + getVectorQuery( + vectorField, + getRandomVector(dim), + Float.NEGATIVE_INFINITY, + Float.NEGATIVE_INFINITY, + null); + + ScoreDoc[] scoreDocs = searcher.search(query, numDocs).scoreDocs; + for (ScoreDoc scoreDoc : scoreDocs) { + int id = getId(searcher, scoreDoc.doc); + + // Check that returned document is not deleted + assertFalse(id >= startIndex && id <= endIndex); + } + // Check that all live docs are returned + assertEquals(numDocs - endIndex + startIndex - 1, scoreDocs.length); + } + } + } + + public void testAllDeletes() throws IOException { + try (Directory dir = getIndexStore(getRandomVectors(numDocs, dim)); + IndexWriter w = new IndexWriter(dir, newIndexWriterConfig())) { + // Delete all documents + w.deleteDocuments(new MatchAllDocsQuery()); + w.commit(); + + try (IndexReader reader = DirectoryReader.open(dir)) { + IndexSearcher searcher = newSearcher(reader); + + Query query = + getVectorQuery( + vectorField, + getRandomVector(dim), + Float.NEGATIVE_INFINITY, + Float.NEGATIVE_INFINITY, + null); + + // Check that no vectors are found + assertEquals(0, searcher.count(query)); + } + } + } + + public void testBoostQuery() throws IOException { + // Define the boost and allowed delta + float boost = random().nextFloat(5, 10); + float delta = 1e-3f; + + try (Directory indexStore = getIndexStore(getRandomVectors(numDocs, dim)); + IndexReader reader = DirectoryReader.open(indexStore)) { + IndexSearcher searcher = newSearcher(reader); + + Query query1 = + getVectorQuery( + vectorField, + getRandomVector(dim), + Float.NEGATIVE_INFINITY, + Float.NEGATIVE_INFINITY, + null); + ScoreDoc[] scoreDocs1 = searcher.search(query1, numDocs).scoreDocs; + + Query query2 = new BoostQuery(query1, boost); + ScoreDoc[] scoreDocs2 = searcher.search(query2, numDocs).scoreDocs; + + // Check that all docs are identical, with boosted scores + assertEquals(scoreDocs1.length, scoreDocs2.length); + for (int i = 0; i < scoreDocs1.length; i++) { + assertEquals(scoreDocs1[i].doc, scoreDocs2[i].doc); + assertEquals(boost * scoreDocs1[i].score, scoreDocs2[i].score, delta); + } + } + } + + public void testVectorsAboveSimilarity() throws IOException { + // Pick number of docs to accept + int numAccepted = random().nextInt(numDocs / 3, numDocs / 2); + float delta = 1e-3f; + + V[] vectors = getRandomVectors(numDocs, dim); + V queryVector = getRandomVector(dim); + + // Find score above which we get (at least) numAccepted vectors + float resultSimilarity = getSimilarity(vectors, queryVector, numAccepted); + + // Cache scores of vectors + Map scores = new HashMap<>(); + for (int i = 0; i < numDocs; i++) { + float score = compare(queryVector, vectors[i]); + if (score >= resultSimilarity) { + scores.put(i, score); + } + } + + try (Directory indexStore = getIndexStore(vectors); + IndexReader reader = DirectoryReader.open(indexStore)) { + IndexSearcher searcher = newSearcher(reader); + + Query query = + getVectorQuery(vectorField, queryVector, Float.NEGATIVE_INFINITY, resultSimilarity, null); + + ScoreDoc[] scoreDocs = searcher.search(query, numDocs).scoreDocs; + for (ScoreDoc scoreDoc : scoreDocs) { + int id = getId(searcher, scoreDoc.doc); + + // Check that the collected result is above accepted similarity + assertTrue(scores.containsKey(id)); + + // Check that the score is correct + assertEquals(scores.get(id), scoreDoc.score, delta); + } + + // Check that all results are collected + assertEquals(scores.size(), scoreDocs.length); + } + } + + public void testFallbackToExact() throws IOException { + // Restrictive filter, along with similarity to visit a large number of nodes + int numFiltered = random().nextInt(numDocs / 10, numDocs / 5); + int targetVisited = random().nextInt(numFiltered * 2, numDocs); + + V[] vectors = getRandomVectors(numDocs, dim); + V queryVector = getRandomVector(dim); + + float resultSimilarity = getSimilarity(vectors, queryVector, targetVisited); + Query filter = IntField.newSetQuery(idField, getFiltered(numFiltered)); + + try (Directory indexStore = getIndexStore(vectors); + IndexReader reader = DirectoryReader.open(indexStore)) { + IndexSearcher searcher = newSearcher(reader); + + Query query = + getThrowingVectorQuery( + vectorField, queryVector, resultSimilarity, resultSimilarity, filter); + + // Falls back to exact search + expectThrows(UnsupportedOperationException.class, () -> searcher.count(query)); + } + } + + public void testApproximate() throws IOException { + // Non-restrictive filter, along with similarity to visit a small number of nodes + int numFiltered = random().nextInt((numDocs * 4) / 5, numDocs); + int targetVisited = random().nextInt(numFiltered / 10, numFiltered / 8); + + V[] vectors = getRandomVectors(numDocs, dim); + V queryVector = getRandomVector(dim); + + float resultSimilarity = getSimilarity(vectors, queryVector, targetVisited); + Query filter = IntField.newSetQuery(idField, getFiltered(numFiltered)); + + try (Directory indexStore = getIndexStore(vectors); + IndexWriter w = new IndexWriter(indexStore, newIndexWriterConfig())) { + // Force merge because smaller segments have few filtered docs and often fall back to exact + // search, making this test flaky + w.forceMerge(1); + w.commit(); + + try (IndexReader reader = DirectoryReader.open(indexStore)) { + IndexSearcher searcher = newSearcher(reader); + + Query query = + getThrowingVectorQuery( + vectorField, queryVector, resultSimilarity, resultSimilarity, filter); + + // Does not fall back to exact search + assertTrue(searcher.count(query) <= numFiltered); + } + } + } + + private float getSimilarity(V[] vectors, V queryVector, int targetVisited) { + assertTrue(targetVisited >= 0 && targetVisited <= numDocs); + if (targetVisited == 0) { + return Float.POSITIVE_INFINITY; + } + + float[] scores = new float[numDocs]; + for (int i = 0; i < numDocs; i++) { + scores[i] = compare(queryVector, vectors[i]); + } + Arrays.sort(scores); + + return scores[numDocs - targetVisited]; + } + + private int[] getFiltered(int numFiltered) { + Set accepted = new HashSet<>(); + for (int i = 0; i < numFiltered; ) { + int index = random().nextInt(numDocs); + if (!accepted.contains(index)) { + accepted.add(index); + i++; + } + } + return accepted.stream().mapToInt(Integer::intValue).toArray(); + } + + private int getId(IndexSearcher searcher, int doc) throws IOException { + return Objects.requireNonNull(searcher.storedFields().document(doc).getField(idField)) + .numericValue() + .intValue(); + } + + @SuppressWarnings("unchecked") + V[] getRandomVectors(int numDocs, int dim) { + return (V[]) IntStream.range(0, numDocs).mapToObj(i -> getRandomVector(dim)).toArray(); + } + + @SafeVarargs + final Directory getIndexStore(V... vectors) throws IOException { + Directory dir = newDirectory(); + try (RandomIndexWriter writer = new RandomIndexWriter(random(), dir)) { + for (int i = 0; i < vectors.length; ++i) { + Document doc = new Document(); + doc.add(getVectorField(vectorField, vectors[i], function)); + doc.add(new IntField(idField, i, Field.Store.YES)); + writer.addDocument(doc); + } + } + return dir; + } +} diff --git a/lucene/core/src/test/org/apache/lucene/search/TestByteVectorSimilarityQuery.java b/lucene/core/src/test/org/apache/lucene/search/TestByteVectorSimilarityQuery.java new file mode 100644 index 000000000000..e2855999b44a --- /dev/null +++ b/lucene/core/src/test/org/apache/lucene/search/TestByteVectorSimilarityQuery.java @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.search; + +import java.util.Arrays; +import org.apache.lucene.document.KnnByteVectorField; +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.util.TestVectorUtil; +import org.junit.Before; + +public class TestByteVectorSimilarityQuery + extends BaseVectorSimilarityQueryTestCase< + byte[], KnnByteVectorField, ByteVectorSimilarityQuery> { + + @Before + public void setup() { + vectorField = getClass().getSimpleName() + ":VectorField"; + idField = getClass().getSimpleName() + ":IdField"; + function = VectorSimilarityFunction.EUCLIDEAN; + numDocs = atLeast(100); + dim = atLeast(50); + } + + @Override + byte[] getRandomVector(int dim) { + return TestVectorUtil.randomVectorBytes(dim); + } + + @Override + float compare(byte[] vector1, byte[] vector2) { + return function.compare(vector1, vector2); + } + + @Override + boolean checkEquals(byte[] vector1, byte[] vector2) { + return Arrays.equals(vector1, vector2); + } + + @Override + KnnByteVectorField getVectorField(String name, byte[] vector, VectorSimilarityFunction function) { + return new KnnByteVectorField(name, vector, function); + } + + @Override + ByteVectorSimilarityQuery getVectorQuery( + String field, + byte[] vector, + float traversalSimilarity, + float resultSimilarity, + Query filter) { + return new ByteVectorSimilarityQuery( + field, vector, traversalSimilarity, resultSimilarity, filter); + } + + @Override + ByteVectorSimilarityQuery getThrowingVectorQuery( + String field, + byte[] vector, + float traversalSimilarity, + float resultSimilarity, + Query filter) { + return new ByteVectorSimilarityQuery( + field, vector, traversalSimilarity, resultSimilarity, filter) { + @Override + VectorScorer createVectorScorer(LeafReaderContext context) { + throw new UnsupportedOperationException(); + } + }; + } +} diff --git a/lucene/core/src/test/org/apache/lucene/search/TestFloatVectorSimilarityQuery.java b/lucene/core/src/test/org/apache/lucene/search/TestFloatVectorSimilarityQuery.java new file mode 100644 index 000000000000..9fe2d5965d22 --- /dev/null +++ b/lucene/core/src/test/org/apache/lucene/search/TestFloatVectorSimilarityQuery.java @@ -0,0 +1,86 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.search; + +import java.util.Arrays; +import org.apache.lucene.document.KnnFloatVectorField; +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.util.TestVectorUtil; +import org.junit.Before; + +public class TestFloatVectorSimilarityQuery + extends BaseVectorSimilarityQueryTestCase< + float[], KnnFloatVectorField, FloatVectorSimilarityQuery> { + + @Before + public void setup() { + vectorField = getClass().getSimpleName() + ":VectorField"; + idField = getClass().getSimpleName() + ":IdField"; + function = VectorSimilarityFunction.EUCLIDEAN; + numDocs = atLeast(100); + dim = atLeast(50); + } + + @Override + float[] getRandomVector(int dim) { + return TestVectorUtil.randomVector(dim); + } + + @Override + float compare(float[] vector1, float[] vector2) { + return function.compare(vector1, vector2); + } + + @Override + boolean checkEquals(float[] vector1, float[] vector2) { + return Arrays.equals(vector1, vector2); + } + + @Override + KnnFloatVectorField getVectorField( + String name, float[] vector, VectorSimilarityFunction function) { + return new KnnFloatVectorField(name, vector, function); + } + + @Override + FloatVectorSimilarityQuery getVectorQuery( + String field, + float[] vector, + float traversalSimilarity, + float resultSimilarity, + Query filter) { + return new FloatVectorSimilarityQuery( + field, vector, traversalSimilarity, resultSimilarity, filter); + } + + @Override + FloatVectorSimilarityQuery getThrowingVectorQuery( + String field, + float[] vector, + float traversalSimilarity, + float resultSimilarity, + Query filter) { + return new FloatVectorSimilarityQuery( + field, vector, traversalSimilarity, resultSimilarity, filter) { + @Override + VectorScorer createVectorScorer(LeafReaderContext context) { + throw new UnsupportedOperationException(); + } + }; + } +} diff --git a/lucene/core/src/test/org/apache/lucene/search/TestVectorSimilarityCollector.java b/lucene/core/src/test/org/apache/lucene/search/TestVectorSimilarityCollector.java new file mode 100644 index 000000000000..b0a80803d684 --- /dev/null +++ b/lucene/core/src/test/org/apache/lucene/search/TestVectorSimilarityCollector.java @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.search; + +import org.apache.lucene.tests.util.LuceneTestCase; + +public class TestVectorSimilarityCollector extends LuceneTestCase { + public void testResultCollection() { + float traversalSimilarity = 0.3f, resultSimilarity = 0.5f; + + VectorSimilarityCollector collector = + new VectorSimilarityCollector(traversalSimilarity, resultSimilarity, Integer.MAX_VALUE); + int[] nodes = {1, 5, 10, 4, 8, 3, 2, 6, 7, 9}; + float[] scores = {0.1f, 0.2f, 0.3f, 0.5f, 0.2f, 0.6f, 0.9f, 0.3f, 0.7f, 0.8f}; + + float[] minCompetitiveSimilarities = new float[nodes.length]; + for (int i = 0; i < nodes.length; i++) { + collector.collect(nodes[i], scores[i]); + minCompetitiveSimilarities[i] = collector.minCompetitiveSimilarity(); + } + + ScoreDoc[] scoreDocs = collector.topDocs().scoreDocs; + int[] resultNodes = new int[scoreDocs.length]; + float[] resultScores = new float[scoreDocs.length]; + for (int i = 0; i < scoreDocs.length; i++) { + resultNodes[i] = scoreDocs[i].doc; + resultScores[i] = scoreDocs[i].score; + } + + // All nodes above resultSimilarity appear in order of collection + assertArrayEquals(new int[] {4, 3, 2, 7, 9}, resultNodes); + assertArrayEquals(new float[] {0.5f, 0.6f, 0.9f, 0.7f, 0.8f}, resultScores, 1e-3f); + + // Min competitive similarity is minimum of traversalSimilarity or best result encountered + assertArrayEquals( + new float[] {0.1f, 0.2f, 0.3f, 0.3f, 0.3f, 0.3f, 0.3f, 0.3f, 0.3f, 0.3f}, + minCompetitiveSimilarities, + 1e-3f); + } +}