Skip to content

Commit bcf9e09

Browse files
authored
fix: the distance for multivector query is not correct (#3522)
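For a multivector query with `n` query vectors, the distance should be `dist = sum(1 - sim) = n - sum(sim)`, but we were computing `dist = 1 - sum(sim)`. The two differ only by the constant `n - 1`, so the order of results was already correct; this change makes the reported distance consistent. Signed-off-by: BubbleCal <bubble-cal@outlook.com>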
1 parent b8a74ce commit bcf9e09

File tree

2 files changed

+13
-1
lines changed

2 files changed

+13
-1
lines changed

python/python/tests/test_vector_index.py

+11
Original file line numberDiff line numberDiff line change
@@ -558,6 +558,17 @@ def test_multivec_ann(indexed_multivec_dataset: lance.LanceDataset):
558558
assert results["vector"].type == pa.list_(pa.list_(pa.float32(), 128))
559559
assert len(results["vector"][0]) == 5
560560

561+
query = [query, query]
562+
doubled_results = indexed_multivec_dataset.to_table(
563+
nearest={"column": "vector", "q": query, "k": 100}
564+
)
565+
assert len(results) == len(doubled_results)
566+
for i in range(len(results)):
567+
assert (
568+
results["_distance"][i].as_py() * 2
569+
== doubled_results["_distance"][i].as_py()
570+
)
571+
561572
# query with a vector that dim not match
562573
query = np.random.rand(256)
563574
with pytest.raises(ValueError, match="does not match index column size"):

rust/lance/src/io/exec/knn.rs

+2-1
Original file line numberDiff line numberDiff line change
@@ -813,6 +813,7 @@ impl ExecutionPlan for MultivectorScoringExec {
813813

814814
let k = self.query.k;
815815
let refactor = self.query.refine_factor.unwrap_or(1) as usize;
816+
let num_queries = self.inputs.len() as f32;
816817
let stream = stream::once(async move {
817818
// at most, we will have k * refine_factor results for each query
818819
let mut results = HashMap::with_capacity(k * refactor);
@@ -850,7 +851,7 @@ impl ExecutionPlan for MultivectorScoringExec {
850851
let dists = sims
851852
.into_iter()
852853
// it's similarity, so we need to convert it back to distance
853-
.map(|sim| 1.0 - sim)
854+
.map(|sim| num_queries - sim)
854855
.collect::<Vec<_>>();
855856
let row_ids = UInt64Array::from(row_ids);
856857
let dists = Float32Array::from(dists);

0 commit comments

Comments
 (0)