Add two new "Seeded" Knn queries for seeded vector search (#14084)
### Description

In some vector search cases, users may already know some documents that are likely related to a query. This change supports seeding HNSW's scoring stage with those documents instead of running HNSW's hierarchical stage.

An example use case is hybrid search, where both a traditional search and a vector search are performed. The top results from the traditional search are likely reasonable seeds for the vector search. Even outside hybrid search, traditional matching is often faster than traversing the hierarchy, so it can be used to speed up the vector search (up to 2x faster for the same effectiveness), as demonstrated in [this article](https://arxiv.org/abs/2307.16779) (full disclosure: seanmacavaney is an author of the article).

The main changes are:
 - A new "seeded" knn collector and collector manager
 - Two new basic knn queries that use these specialized collectors to seed the search entry points
 - Changes to `HnswGraphSearcher`, which now bypasses the `findBestEntryPoint` step if seeds are provided
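
To make the idea in the description concrete, here is a minimal, self-contained sketch of a best-first search over a single-layer neighbor graph that starts from caller-supplied seeds. It is not Lucene's `HnswGraphSearcher`: the class `SeededSearchSketch`, its `search` method, the squared-Euclidean `distance`, and the toy line graph are all assumptions made for illustration. The only point is that supplying seeds replaces the step that would otherwise discover an entry point by descending a hierarchy.

```java
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashSet;
import java.util.List;
import java.util.PriorityQueue;
import java.util.Set;

// Hypothetical illustration only; this is not Lucene's HnswGraphSearcher.
final class SeededSearchSketch {

  // Squared Euclidean distance between two vectors of equal length.
  static float distance(float[] a, float[] b) {
    float sum = 0f;
    for (int i = 0; i < a.length; i++) {
      float d = a[i] - b[i];
      sum += d * d;
    }
    return sum;
  }

  // Best-first search over a flat neighbor graph, started from the given seeds
  // instead of an entry point found by walking a hierarchy.
  // Returns up to k node ids, nearest first.
  static List<Integer> search(
      float[][] vectors, int[][] neighbors, float[] query, int k, int[] seeds) {
    Comparator<Integer> byDistance =
        Comparator.comparingDouble(n -> distance(vectors[n], query));
    PriorityQueue<Integer> candidates = new PriorityQueue<>(byDistance);
    // Max-heap on distance: the farthest kept result sits on top.
    PriorityQueue<Integer> results = new PriorityQueue<>(byDistance.reversed());
    Set<Integer> visited = new HashSet<>();

    // The seeds are the entry points; no hierarchy is consulted.
    for (int seed : seeds) {
      if (visited.add(seed)) {
        candidates.add(seed);
        results.add(seed);
        if (results.size() > k) {
          results.poll();
        }
      }
    }

    while (!candidates.isEmpty()) {
      int current = candidates.poll();
      // Stop once the nearest open candidate is worse than the worst kept result.
      if (results.size() >= k
          && distance(vectors[current], query) > distance(vectors[results.peek()], query)) {
        break;
      }
      for (int next : neighbors[current]) {
        if (visited.add(next)) {
          candidates.add(next);
          results.add(next);
          if (results.size() > k) {
            results.poll();
          }
        }
      }
    }

    List<Integer> out = new ArrayList<>(results);
    out.sort(byDistance);
    return out;
  }

  public static void main(String[] args) {
    // Six points on a line, each linked to its immediate neighbors.
    float[][] vectors = {{0f}, {1f}, {2f}, {3f}, {4f}, {5f}};
    int[][] neighbors = {{1}, {0, 2}, {1, 3}, {2, 4}, {3, 5}, {4}};
    // Seed with node 3, as if a lexical search had suggested it.
    System.out.println(search(vectors, neighbors, new float[] {4.2f}, 2, new int[] {3}));
    // prints [4, 5]
  }
}
```

In this toy setup a seed close to the true neighbors lets the search converge after visiting only a handful of nodes; a poor seed would still work on a connected graph but would visit more of it.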


//cc @seanmacavaney

Co-authored-by: Sean MacAvaney <smacavaney@bloomberg.com>
Co-authored-by: Sean MacAvaney <sean.macavaney@gmail.com>
Co-authored-by: Christine Poerschke <cpoerschke@apache.org>
4 people authored Jan 15, 2025
1 parent 905efa9 commit 34f0453
Showing 17 changed files with 1,075 additions and 89 deletions.
6 changes: 5 additions & 1 deletion lucene/CHANGES.txt
@@ -46,7 +46,11 @@ API Changes

New Features
---------------------
(No changes)

* GITHUB#14084, GITHUB#13635, GITHUB#13634: Adds new `SeededKnnByteVectorQuery` and `SeededKnnFloatVectorQuery`
queries. These queries allow for the vector search entry points to be initialized via a `seed` query. This follows
the research provided via https://arxiv.org/abs/2307.16779. (Sean MacAvaney, Ben Trent).


Improvements
---------------------
@@ -46,7 +46,7 @@ public class KnnByteVectorQuery extends AbstractKnnVectorQuery {

private static final TopDocs NO_RESULTS = TopDocsCollector.EMPTY_TOPDOCS;

private final byte[] target;
protected final byte[] target;

/**
* Find the <code>k</code> nearest documents to the target vector according to the vectors in the
54 changes: 54 additions & 0 deletions lucene/core/src/java/org/apache/lucene/search/KnnCollector.java
@@ -85,4 +85,58 @@ public interface KnnCollector {
* @return The collected top documents
*/
TopDocs topDocs();

/**
* KnnCollector.Decorator is the base class for decorators of KnnCollector objects, which extend
* the object with new behaviors.
*
* @lucene.experimental
*/
abstract class Decorator implements KnnCollector {
private final KnnCollector collector;

public Decorator(KnnCollector collector) {
this.collector = collector;
}

@Override
public boolean earlyTerminated() {
return collector.earlyTerminated();
}

@Override
public void incVisitedCount(int count) {
collector.incVisitedCount(count);
}

@Override
public long visitedCount() {
return collector.visitedCount();
}

@Override
public long visitLimit() {
return collector.visitLimit();
}

@Override
public int k() {
return collector.k();
}

@Override
public boolean collect(int docId, float similarity) {
return collector.collect(docId, similarity);
}

@Override
public float minCompetitiveSimilarity() {
return collector.minCompetitiveSimilarity();
}

@Override
public TopDocs topDocs() {
return collector.topDocs();
}
}
}
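
The `Decorator` base class above forwards every call to the wrapped collector so that a subclass overrides only the behavior it changes. The sketch below shows the same pattern on a deliberately reduced interface. Everything in it is hypothetical: `MiniCollector` stands in for `KnnCollector` with only four methods, and `VisitBudgetCollector` is an invented subclass that changes `earlyTerminated()` alone, in the same spirit as the time-limiting collector touched by this commit.

```java
import java.util.ArrayList;
import java.util.List;

// All types below are hypothetical, reduced stand-ins; none of them is a Lucene class.
final class DecoratorSketch {
  public static void main(String[] args) {
    ListCollector base = new ListCollector();
    MiniCollector limited = new VisitBudgetCollector(base, 3);
    limited.incVisitedCount(3);
    limited.collect(7, 0.5f);
    // The wrapper reports termination; the wrapped collector still holds the hits.
    System.out.println(limited.earlyTerminated() + " " + base.docs());
  }
}

// Reduced collector contract used only for this illustration.
interface MiniCollector {
  boolean earlyTerminated();

  void incVisitedCount(int count);

  long visitedCount();

  boolean collect(int docId, float similarity);
}

// Plain collector that records every doc it is given and never terminates early.
final class ListCollector implements MiniCollector {
  private final List<Integer> docs = new ArrayList<>();
  private long visited;

  @Override
  public boolean earlyTerminated() {
    return false;
  }

  @Override
  public void incVisitedCount(int count) {
    visited += count;
  }

  @Override
  public long visitedCount() {
    return visited;
  }

  @Override
  public boolean collect(int docId, float similarity) {
    return docs.add(docId);
  }

  List<Integer> docs() {
    return docs;
  }
}

// Base decorator: pure delegation, so subclasses stay small.
abstract class MiniDecorator implements MiniCollector {
  private final MiniCollector delegate;

  MiniDecorator(MiniCollector delegate) {
    this.delegate = delegate;
  }

  @Override
  public boolean earlyTerminated() {
    return delegate.earlyTerminated();
  }

  @Override
  public void incVisitedCount(int count) {
    delegate.incVisitedCount(count);
  }

  @Override
  public long visitedCount() {
    return delegate.visitedCount();
  }

  @Override
  public boolean collect(int docId, float similarity) {
    return delegate.collect(docId, similarity);
  }
}

// Overrides a single method: terminate once a visit budget is used up.
final class VisitBudgetCollector extends MiniDecorator {
  private final long budget;

  VisitBudgetCollector(MiniCollector delegate, long budget) {
    super(delegate);
    this.budget = budget;
  }

  @Override
  public boolean earlyTerminated() {
    return visitedCount() >= budget || super.earlyTerminated();
  }
}
```

The design choice this illustrates is that the delegation boilerplate lives in one place, so each new wrapper carries only the lines that differ.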
@@ -47,7 +47,7 @@ public class KnnFloatVectorQuery extends AbstractKnnVectorQuery {

private static final TopDocs NO_RESULTS = TopDocsCollector.EMPTY_TOPDOCS;

private final float[] target;
protected final float[] target;

/**
* Find the <code>k</code> nearest documents to the target vector according to the vectors in the
@@ -0,0 +1,97 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.search;

import java.io.IOException;
import java.util.Objects;
import org.apache.lucene.index.ByteVectorValues;
import org.apache.lucene.search.knn.KnnCollectorManager;
import org.apache.lucene.search.knn.SeededKnnCollectorManager;

/**
* This is a version of knn byte vector query that provides a query seed to initiate the vector
* search. NOTE: The underlying format is free to ignore the provided seed.
*
* <p>See <a href="https://dl.acm.org/doi/10.1145/3539618.3591715">"Lexically-Accelerated Dense
* Retrieval"</a> (Kulkarni, Hrishikesh and MacAvaney, Sean and Goharian, Nazli and Frieder, Ophir).
* In SIGIR '23: Proceedings of the 46th International ACM SIGIR Conference on Research and
* Development in Information Retrieval Pages 152 - 162
*
* @lucene.experimental
*/
public class SeededKnnByteVectorQuery extends KnnByteVectorQuery {
final Query seed;
final Weight seedWeight;

/**
* Construct a new SeededKnnByteVectorQuery instance
*
* @param field knn byte vector field to query
* @param target the query vector
* @param k number of neighbors to return
* @param filter a filter on the neighbors to return
* @param seed a query seed to initiate the vector format search
*/
public SeededKnnByteVectorQuery(String field, byte[] target, int k, Query filter, Query seed) {
super(field, target, k, filter);
this.seed = Objects.requireNonNull(seed);
this.seedWeight = null;
}

SeededKnnByteVectorQuery(String field, byte[] target, int k, Query filter, Weight seedWeight) {
super(field, target, k, filter);
this.seed = null;
this.seedWeight = Objects.requireNonNull(seedWeight);
}

@Override
public Query rewrite(IndexSearcher indexSearcher) throws IOException {
if (seedWeight != null) {
return super.rewrite(indexSearcher);
}
BooleanQuery.Builder booleanSeedQueryBuilder =
new BooleanQuery.Builder()
.add(seed, BooleanClause.Occur.MUST)
.add(new FieldExistsQuery(field), BooleanClause.Occur.FILTER);
if (filter != null) {
booleanSeedQueryBuilder.add(filter, BooleanClause.Occur.FILTER);
}
Query seedRewritten = indexSearcher.rewrite(booleanSeedQueryBuilder.build());
Weight seedWeight = indexSearcher.createWeight(seedRewritten, ScoreMode.TOP_SCORES, 1f);
SeededKnnByteVectorQuery rewritten =
new SeededKnnByteVectorQuery(field, target, k, filter, seedWeight);
return rewritten.rewrite(indexSearcher);
}

@Override
protected KnnCollectorManager getKnnCollectorManager(int k, IndexSearcher searcher) {
if (seedWeight == null) {
throw new UnsupportedOperationException("must be rewritten before constructing manager");
}
return new SeededKnnCollectorManager(
super.getKnnCollectorManager(k, searcher),
seedWeight,
k,
leaf -> {
ByteVectorValues vv = leaf.getByteVectorValues(field);
if (vv == null) {
ByteVectorValues.checkField(leaf.getContext().reader(), field);
}
return vv;
});
}
}
@@ -0,0 +1,97 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.search;

import java.io.IOException;
import java.util.Objects;
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.search.knn.KnnCollectorManager;
import org.apache.lucene.search.knn.SeededKnnCollectorManager;

/**
* This is a version of knn float vector query that provides a query seed to initiate the vector
* search. NOTE: The underlying format is free to ignore the provided seed.
*
* <p>See <a href="https://dl.acm.org/doi/10.1145/3539618.3591715">"Lexically-Accelerated Dense
* Retrieval"</a> (Kulkarni, Hrishikesh and MacAvaney, Sean and Goharian, Nazli and Frieder, Ophir).
* In SIGIR '23: Proceedings of the 46th International ACM SIGIR Conference on Research and
* Development in Information Retrieval Pages 152 - 162
*
* @lucene.experimental
*/
public class SeededKnnFloatVectorQuery extends KnnFloatVectorQuery {
final Query seed;
final Weight seedWeight;

/**
* Construct a new SeededKnnFloatVectorQuery instance
*
* @param field knn float vector field to query
* @param target the query vector
* @param k number of neighbors to return
* @param filter a filter on the neighbors to return
* @param seed a query seed to initiate the vector format search
*/
public SeededKnnFloatVectorQuery(String field, float[] target, int k, Query filter, Query seed) {
super(field, target, k, filter);
this.seed = Objects.requireNonNull(seed);
this.seedWeight = null;
}

SeededKnnFloatVectorQuery(String field, float[] target, int k, Query filter, Weight seedWeight) {
super(field, target, k, filter);
this.seed = null;
this.seedWeight = Objects.requireNonNull(seedWeight);
}

@Override
public Query rewrite(IndexSearcher indexSearcher) throws IOException {
if (seedWeight != null) {
return super.rewrite(indexSearcher);
}
BooleanQuery.Builder booleanSeedQueryBuilder =
new BooleanQuery.Builder()
.add(seed, BooleanClause.Occur.MUST)
.add(new FieldExistsQuery(field), BooleanClause.Occur.FILTER);
if (filter != null) {
booleanSeedQueryBuilder.add(filter, BooleanClause.Occur.FILTER);
}
Query seedRewritten = indexSearcher.rewrite(booleanSeedQueryBuilder.build());
Weight seedWeight = indexSearcher.createWeight(seedRewritten, ScoreMode.TOP_SCORES, 1f);
SeededKnnFloatVectorQuery rewritten =
new SeededKnnFloatVectorQuery(field, target, k, filter, seedWeight);
return rewritten.rewrite(indexSearcher);
}

@Override
protected KnnCollectorManager getKnnCollectorManager(int k, IndexSearcher searcher) {
if (seedWeight == null) {
throw new UnsupportedOperationException("must be rewritten before constructing manager");
}
return new SeededKnnCollectorManager(
super.getKnnCollectorManager(k, searcher),
seedWeight,
k,
leaf -> {
FloatVectorValues vv = leaf.getFloatVectorValues(field);
if (vv == null) {
FloatVectorValues.checkField(leaf.getContext().reader(), field);
}
return vv;
});
}
}
@@ -45,51 +45,19 @@ public KnnCollector newCollector(int visitedLimit, LeafReaderContext context) th
return new TimeLimitingKnnCollector(collector);
}

class TimeLimitingKnnCollector implements KnnCollector {
private final KnnCollector collector;

TimeLimitingKnnCollector(KnnCollector collector) {
this.collector = collector;
class TimeLimitingKnnCollector extends KnnCollector.Decorator {
public TimeLimitingKnnCollector(KnnCollector collector) {
super(collector);
}

@Override
public boolean earlyTerminated() {
return queryTimeout.shouldExit() || collector.earlyTerminated();
}

@Override
public void incVisitedCount(int count) {
collector.incVisitedCount(count);
}

@Override
public long visitedCount() {
return collector.visitedCount();
}

@Override
public long visitLimit() {
return collector.visitLimit();
}

@Override
public int k() {
return collector.k();
}

@Override
public boolean collect(int docId, float similarity) {
return collector.collect(docId, similarity);
}

@Override
public float minCompetitiveSimilarity() {
return collector.minCompetitiveSimilarity();
return queryTimeout.shouldExit() || super.earlyTerminated();
}

@Override
public TopDocs topDocs() {
TopDocs docs = collector.topDocs();
TopDocs docs = super.topDocs();

// Mark results as partial if timeout is met
TotalHits.Relation relation =
@@ -0,0 +1,28 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.search.knn;

import org.apache.lucene.search.DocIdSetIterator;

/** Provides entry points for the kNN search */
public interface EntryPointProvider {
/** Iterator of valid entry points for the kNN search */
DocIdSetIterator entryPoints();

/** Number of valid entry points for the kNN search */
int numberOfEntryPoints();
}
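
The two-method shape of this interface can be sketched without Lucene on the classpath. In the example below, `MiniEntryPointProvider` and `ArrayEntryPoints` are hypothetical names, and `PrimitiveIterator.OfInt` is used as a stand-in for `DocIdSetIterator`; sorting and de-duplicating the ids is an assumption made here for tidiness, not something the interface above states.

```java
import java.util.Arrays;
import java.util.PrimitiveIterator;

// Hypothetical stand-ins only; none of these types is part of Lucene.
final class EntryPointSketch {
  public static void main(String[] args) {
    MiniEntryPointProvider provider = new ArrayEntryPoints(9, 2, 9, 5);
    PrimitiveIterator.OfInt it = provider.entryPoints();
    StringBuilder sb = new StringBuilder();
    while (it.hasNext()) {
      sb.append(it.nextInt()).append(' ');
    }
    System.out.println(provider.numberOfEntryPoints() + " entry points: " + sb.toString().trim());
  }
}

// Mirrors the shape of the interface above, with a JDK iterator type
// in place of DocIdSetIterator.
interface MiniEntryPointProvider {
  // Iterator over the ids to start the search from.
  PrimitiveIterator.OfInt entryPoints();

  // How many ids the iterator will yield.
  int numberOfEntryPoints();
}

// Backs the provider with a sorted, de-duplicated array of ids.
final class ArrayEntryPoints implements MiniEntryPointProvider {
  private final int[] sortedIds;

  ArrayEntryPoints(int... ids) {
    this.sortedIds = Arrays.stream(ids).sorted().distinct().toArray();
  }

  @Override
  public PrimitiveIterator.OfInt entryPoints() {
    // A fresh iterator on every call, so the provider can be consumed repeatedly.
    return Arrays.stream(sortedIds).iterator();
  }

  @Override
  public int numberOfEntryPoints() {
    return sortedIds.length;
  }
}
```

Exposing the count alongside the iterator lets a consumer size its data structures, or decide whether seeding is worthwhile, before it starts iterating.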