Skip to content

Commit 1e08425

Browse files
authored
perf: avoid copying of creating memory dist calculator (lancedb#2219)
Signed-off-by: BubbleCal <bubble-cal@outlook.com>
1 parent 99008c6 commit 1e08425

File tree

13 files changed

+52
-43
lines changed

13 files changed

+52
-43
lines changed

rust/lance-index/benches/hnsw.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ fn bench_hnsw(c: &mut Criterion) {
4444
.await
4545
.unwrap();
4646
let uids: HashSet<u32> = hnsw
47-
.search(query, K, 300, None)
47+
.search(query.clone(), K, 300, None)
4848
.unwrap()
4949
.iter()
5050
.map(|node| node.id)

rust/lance-index/src/vector/graph/memory.rs

+11-9
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,9 @@
66
use std::sync::Arc;
77

88
use super::storage::{DistCalculator, VectorStorage};
9+
use arrow::array::AsArray;
910
use arrow_array::types::Float32Type;
11+
use arrow_array::ArrayRef;
1012
use lance_linalg::{distance::MetricType, MatrixView};
1113

1214
/// All data are stored in memory
@@ -26,7 +28,7 @@ impl InMemoryVectorStorage {
2628
}
2729
}
2830

29-
pub fn vector(&self, id: u32) -> &[f32] {
31+
pub fn vector(&self, id: u32) -> ArrayRef {
3032
self.vectors.row(id as usize).unwrap()
3133
}
3234
}
@@ -48,40 +50,40 @@ impl VectorStorage for InMemoryVectorStorage {
4850
self.metric_type
4951
}
5052

51-
fn dist_calculator(&self, query: &[f32]) -> Box<dyn DistCalculator> {
53+
fn dist_calculator(&self, query: ArrayRef) -> Box<dyn DistCalculator> {
5254
Box::new(InMemoryDistanceCal {
5355
vectors: self.vectors.clone(),
54-
query: query.to_vec(),
56+
query,
5557
metric_type: self.metric_type,
5658
})
5759
}
5860

5961
fn dist_calculator_from_id(&self, id: u32) -> Box<dyn DistCalculator> {
6062
Box::new(InMemoryDistanceCal {
6163
vectors: self.vectors.clone(),
62-
query: self.vectors.row(id as usize).unwrap().to_vec(),
64+
query: self.vectors.row(id as usize).unwrap(),
6365
metric_type: self.metric_type,
6466
})
6567
}
6668

6769
/// Distance between two vectors.
6870
fn distance_between(&self, a: u32, b: u32) -> f32 {
69-
let vector1 = self.vectors.row(a as usize).unwrap();
70-
let vector2 = self.vectors.row(b as usize).unwrap();
71+
let vector1 = self.vectors.row_ref(a as usize).unwrap();
72+
let vector2 = self.vectors.row_ref(b as usize).unwrap();
7173
self.metric_type.func()(vector1, vector2)
7274
}
7375
}
7476

7577
struct InMemoryDistanceCal {
7678
vectors: Arc<MatrixView<Float32Type>>,
77-
query: Vec<f32>,
79+
query: ArrayRef,
7880
metric_type: MetricType,
7981
}
8082

8183
impl DistCalculator for InMemoryDistanceCal {
8284
#[inline]
8385
fn distance(&self, id: u32) -> f32 {
84-
let vector = self.vectors.row(id as usize).unwrap();
85-
self.metric_type.func()(&self.query, vector)
86+
let vector = self.vectors.row_ref(id as usize).unwrap();
87+
self.metric_type.func()(self.query.as_primitive::<Float32Type>().values(), vector)
8688
}
8789
}

rust/lance-index/src/vector/graph/storage.rs

+2-1
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
use std::any::Any;
55

6+
use arrow_array::ArrayRef;
67
use lance_linalg::distance::MetricType;
78

89
pub trait DistCalculator {
@@ -35,7 +36,7 @@ pub trait VectorStorage: Send + Sync {
3536
///
3637
/// Using dist calcualtor can be more efficient as it can pre-compute some
3738
/// values.
38-
fn dist_calculator(&self, query: &[f32]) -> Box<dyn DistCalculator>;
39+
fn dist_calculator(&self, query: ArrayRef) -> Box<dyn DistCalculator>;
3940

4041
fn dist_calculator_from_id(&self, id: u32) -> Box<dyn DistCalculator>;
4142

rust/lance-index/src/vector/hnsw.rs

+5-4
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ use std::ops::Range;
1212
use std::sync::Arc;
1313

1414
use arrow::datatypes::UInt32Type;
15+
use arrow_array::ArrayRef;
1516
use arrow_array::{
1617
builder::{ListBuilder, UInt32Builder},
1718
cast::AsArray,
@@ -347,7 +348,7 @@ impl HNSW {
347348
/// A list of `(id_in_graph, distance)` pairs. Or Error if the search failed.
348349
pub fn search(
349350
&self,
350-
query: &[f32],
351+
query: ArrayRef,
351352
k: usize,
352353
ef: usize,
353354
bitset: Option<RoaringBitmap>,
@@ -593,7 +594,7 @@ mod tests {
593594
fn ground_truth(mat: &MatrixView<Float32Type>, query: &[f32], k: usize) -> HashSet<u32> {
594595
let mut dists = vec![];
595596
for i in 0..mat.num_rows() {
596-
let dist = lance_linalg::distance::l2_distance(query, mat.row(i).unwrap());
597+
let dist = lance_linalg::distance::l2_distance(query, mat.row_ref(i).unwrap());
597598
dists.push((dist, i as u32));
598599
}
599600
dists.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
@@ -625,12 +626,12 @@ mod tests {
625626
.unwrap();
626627

627628
let results: HashSet<u32> = hnsw
628-
.search(q, K, 128, None)
629+
.search(q.clone(), K, 128, None)
629630
.unwrap()
630631
.iter()
631632
.map(|node| node.id)
632633
.collect();
633-
let gt = ground_truth(&mat, q, K);
634+
let gt = ground_truth(&mat, q.as_primitive::<Float32Type>().values(), K);
634635
let recall = results.intersection(&gt).count() as f32 / K as f32;
635636
assert!(recall >= 0.9, "Recall: {}", recall);
636637
}

rust/lance-index/src/vector/ivf.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -458,7 +458,7 @@ impl<T: ArrowFloatType + Dot + L2 + ArrowPrimitiveType> Ivf for IvfImpl<T> {
458458
.chunks_exact(dim)
459459
.zip(part_ids.values())
460460
.flat_map(|(vector, &part_id)| {
461-
let centroid = self.centroids.row(part_id as usize).unwrap();
461+
let centroid = self.centroids.row_ref(part_id as usize).unwrap();
462462
vector.iter().zip(centroid.iter()).map(|(&v, &c)| v - c)
463463
})
464464
.collect::<Vec<_>>();

rust/lance-index/src/vector/pq/storage.rs

+3-2
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
88
use std::{cmp::min, collections::HashMap, sync::Arc};
99

10+
use arrow_array::ArrayRef;
1011
use arrow_array::{
1112
cast::AsArray,
1213
types::{Float32Type, UInt64Type, UInt8Type},
@@ -410,13 +411,13 @@ impl VectorStorage for ProductQuantizationStorage {
410411
self.metric_type
411412
}
412413

413-
fn dist_calculator(&self, query: &[f32]) -> Box<dyn DistCalculator> {
414+
fn dist_calculator(&self, query: ArrayRef) -> Box<dyn DistCalculator> {
414415
Box::new(PQDistCalculator::new(
415416
self.codebook.values(),
416417
self.num_bits,
417418
self.num_sub_vectors,
418419
self.pq_code.clone(),
419-
query,
420+
query.as_primitive::<Float32Type>().values(),
420421
self.metric_type(),
421422
))
422423
}

rust/lance-index/src/vector/pq/utils.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ pub(super) fn divide_to_subvectors<T: ArrowFloatType>(
2525
for i in 0..m {
2626
let mut builder = Vec::with_capacity(capacity);
2727
for j in 0..data.num_rows() {
28-
let row = data.row(j).unwrap();
28+
let row = data.row_ref(j).unwrap();
2929
let start = i * sub_vector_length;
3030
builder.extend_from_slice(&row[start..start + sub_vector_length]);
3131
}

rust/lance-index/src/vector/residual.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ impl<T: ArrowFloatType> Transformer for ResidualTransform<T> {
9797
.chunks_exact(dim as usize)
9898
.zip(part_ids.as_primitive::<UInt32Type>().values().iter())
9999
.for_each(|(vector, &part_id)| {
100-
let centroid = self.centroids.row(part_id as usize).unwrap();
100+
let centroid = self.centroids.row_ref(part_id as usize).unwrap();
101101
// TODO: SIMD
102102
residual_arr.extend(
103103
vector

rust/lance-index/src/vector/sq/storage.rs

+5-8
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
use std::{ops::Range, sync::Arc};
55

66
use arrow::{array::AsArray, datatypes::Float32Type};
7-
use arrow_array::{Array, FixedSizeListArray, RecordBatch, UInt64Array, UInt8Array};
7+
use arrow_array::{Array, ArrayRef, FixedSizeListArray, RecordBatch, UInt64Array, UInt8Array};
88
use async_trait::async_trait;
99
use lance_core::{Error, Result, ROW_ID};
1010
use lance_file::reader::FileReader;
@@ -212,7 +212,7 @@ impl VectorStorage for ScalarQuantizationStorage {
212212
///
213213
/// Using dist calcualtor can be more efficient as it can pre-compute some
214214
/// values.
215-
fn dist_calculator(&self, query: &[f32]) -> Box<dyn DistCalculator> {
215+
fn dist_calculator(&self, query: ArrayRef) -> Box<dyn DistCalculator> {
216216
Box::new(SQDistCalculator::new(
217217
query,
218218
self.sq_codes.clone(),
@@ -243,12 +243,9 @@ struct SQDistCalculator {
243243
}
244244

245245
impl SQDistCalculator {
246-
fn new(query: &[f32], sq_codes: Arc<FixedSizeListArray>, bounds: Range<f64>) -> Self {
247-
// TODO: support f16/f64
248-
let query_sq_code = scale_to_u8::<Float32Type>(query, bounds)
249-
.into_iter()
250-
.collect::<Vec<_>>();
251-
246+
fn new(query: ArrayRef, sq_codes: Arc<FixedSizeListArray>, bounds: Range<f64>) -> Self {
247+
let query_sq_code =
248+
scale_to_u8::<Float32Type>(query.as_primitive::<Float32Type>().values(), bounds);
252249
Self {
253250
query_sq_code,
254251
sq_codes,

rust/lance-linalg/src/matrix.rs

+16-3
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
77
use std::sync::Arc;
88

9-
use arrow_array::{Array, ArrowPrimitiveType, FixedSizeListArray};
9+
use arrow_array::{Array, ArrayRef, ArrowPrimitiveType, FixedSizeListArray};
1010
use arrow_schema::{ArrowError, DataType};
1111
use lance_arrow::{ArrowFloatType, FixedSizeListArrayExt, FloatArray, FloatType};
1212
use num_traits::{AsPrimitive, Float, FromPrimitive, ToPrimitive};
@@ -134,7 +134,7 @@ impl<T: ArrowFloatType> MatrixView<T> {
134134
/// Returns a row at index `i`. Returns `None` if the index is out of bound.
135135
///
136136
/// # Panics if the matrix is transposed.
137-
pub fn row(&self, i: usize) -> Option<&[T::Native]> {
137+
pub fn row_ref(&self, i: usize) -> Option<&[T::Native]> {
138138
assert!(
139139
!self.transpose,
140140
"Centroid is not defined for transposed matrix."
@@ -147,6 +147,19 @@ impl<T: ArrowFloatType> MatrixView<T> {
147147
}
148148
}
149149

150+
pub fn row(&self, i: usize) -> Option<ArrayRef> {
151+
assert!(
152+
!self.transpose,
153+
"Centroid is not defined for transposed matrix."
154+
);
155+
if i >= self.num_rows() {
156+
None
157+
} else {
158+
let dim = self.num_columns();
159+
Some(self.data.slice(i * dim, dim))
160+
}
161+
}
162+
150163
/// Compute the centroid from all the rows. Returns `None` if this matrix is empty.
151164
///
152165
/// # Panics if the matrix is transposed.
@@ -359,7 +372,7 @@ impl<'a, T: ArrowFloatType> Iterator for MatrixRowIter<'a, T> {
359372
fn next(&mut self) -> Option<Self::Item> {
360373
let cur_idx = self.cur_idx;
361374
self.cur_idx += 1;
362-
self.data.row(cur_idx)
375+
self.data.row_ref(cur_idx)
363376
}
364377
}
365378

rust/lance/examples/hnsw.rs

+3-3
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ struct Args {
4444
fn ground_truth(mat: &MatrixView<Float32Type>, query: &[f32], k: usize) -> HashSet<u32> {
4545
let mut dists = vec![];
4646
for i in 0..mat.num_rows() {
47-
let dist = lance_linalg::distance::l2_distance(query, mat.row(i).unwrap());
47+
let dist = lance_linalg::distance::l2_distance(query, mat.row_ref(i).unwrap());
4848
dists.push((dist, i as u32));
4949
}
5050
dists.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
@@ -80,7 +80,7 @@ async fn main() {
8080

8181
let q = mat.row(0).unwrap();
8282
let k = 10;
83-
let gt = ground_truth(&mat, q, k);
83+
let gt = ground_truth(&mat, q.as_primitive::<Float32Type>().values(), k);
8484

8585
for ef_construction in [15, 30, 50] {
8686
let now = std::time::Instant::now();
@@ -98,7 +98,7 @@ async fn main() {
9898
let construct_time = now.elapsed().as_secs_f32();
9999
let now = std::time::Instant::now();
100100
let results: HashSet<u32> = hnsw
101-
.search(q, k, args.ef, None)
101+
.search(q.clone(), k, args.ef, None)
102102
.unwrap()
103103
.iter()
104104
.map(|node| node.id)

rust/lance/src/index/vector/hnsw.rs

+2-8
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,10 @@ use std::{
88
sync::Arc,
99
};
1010

11-
use arrow_array::{cast::AsArray, types::Float32Type, Float32Array, RecordBatch, UInt64Array};
11+
use arrow_array::{Float32Array, RecordBatch, UInt64Array};
1212

1313
use arrow_schema::DataType;
1414
use async_trait::async_trait;
15-
use lance_arrow::*;
1615
use lance_core::{datatypes::Schema, Error, Result, ROW_ID};
1716
use lance_file::reader::FileReader;
1817
use lance_index::{
@@ -164,12 +163,7 @@ impl<Q: Quantization + Send + Sync + 'static> VectorIndex for HNSWIndex<Q> {
164163
});
165164
}
166165

167-
let results = self.hnsw.search(
168-
query.key.as_primitive::<Float32Type>().as_slice(),
169-
k,
170-
ef,
171-
bitmap,
172-
)?;
166+
let results = self.hnsw.search(query.key.clone(), k, ef, bitmap)?;
173167

174168
let row_ids = UInt64Array::from_iter_values(results.iter().map(|x| row_ids[x.id as usize]));
175169
let distances = Arc::new(Float32Array::from_iter_values(

rust/lance/src/index/vector/ivf.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -2210,7 +2210,7 @@ mod tests {
22102210
fn ground_truth(mat: &MatrixView<Float32Type>, query: &[f32], k: usize) -> HashSet<u32> {
22112211
let mut dists = vec![];
22122212
for i in 0..mat.num_rows() {
2213-
let dist = lance_linalg::distance::l2_distance(query, mat.row(i).unwrap());
2213+
let dist = lance_linalg::distance::l2_distance(query, mat.row_ref(i).unwrap());
22142214
dists.push((dist, i as u32));
22152215
}
22162216
dists.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());

0 commit comments

Comments
 (0)