Skip to content

Commit d79e870

Browse files
authored
perf: improve PQ computing distances (lancedb#3150)
this is done by make the compiler know the size of distance table slice ``` 5242880,L2,PQ=96,DIM=1536 time: [148.44 ms 149.47 ms 150.50 ms] change: [-53.716% -53.486% -53.252%] (p = 0.00 < 0.10) Performance has improved. 5242880,Cosine,PQ=96,DIM=1536 time: [191.84 ms 192.21 ms 192.75 ms] change: [-46.738% -46.621% -46.461%] (p = 0.00 < 0.10) Performance has improved. ``` --------- Signed-off-by: BubbleCal <bubble-cal@outlook.com>
1 parent bf2ce1f commit d79e870

File tree

4 files changed

+100
-38
lines changed

4 files changed

+100
-38
lines changed

rust/lance-index/Cargo.toml

+4
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,10 @@ harness = false
8282
name = "pq_dist_table"
8383
harness = false
8484

85+
[[bench]]
86+
name = "4bitpq_dist_table"
87+
harness = false
88+
8589
[[bench]]
8690
name = "pq_assignment"
8791
harness = false
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
// SPDX-License-Identifier: Apache-2.0
2+
// SPDX-FileCopyrightText: Copyright The Lance Authors
3+
4+
//! Benchmark of building PQ distance table.
5+
6+
use std::iter::repeat;
7+
8+
use arrow_array::types::Float32Type;
9+
use arrow_array::{FixedSizeListArray, UInt8Array};
10+
use criterion::{black_box, criterion_group, criterion_main, Criterion};
11+
use lance_arrow::FixedSizeListArrayExt;
12+
use lance_index::vector::pq::ProductQuantizer;
13+
use lance_linalg::distance::DistanceType;
14+
use lance_testing::datagen::generate_random_array_with_seed;
15+
use rand::{prelude::StdRng, Rng, SeedableRng};
16+
17+
#[cfg(target_os = "linux")]
18+
use pprof::criterion::{Output, PProfProfiler};
19+
20+
const PQ: usize = 96;
21+
const DIM: usize = 1536;
22+
const TOTAL: usize = 16 * 1000;
23+
24+
fn dist_table(c: &mut Criterion) {
25+
let codebook = generate_random_array_with_seed::<Float32Type>(256 * DIM, [88; 32]);
26+
let query = generate_random_array_with_seed::<Float32Type>(DIM, [32; 32]);
27+
28+
let mut rnd = StdRng::from_seed([32; 32]);
29+
let code = UInt8Array::from_iter_values(repeat(rnd.gen::<u8>()).take(TOTAL * PQ));
30+
31+
for dt in [DistanceType::L2, DistanceType::Cosine, DistanceType::Dot].iter() {
32+
let pq = ProductQuantizer::new(
33+
PQ,
34+
4,
35+
DIM,
36+
FixedSizeListArray::try_new_from_values(codebook.clone(), DIM as i32).unwrap(),
37+
*dt,
38+
);
39+
40+
c.bench_function(
41+
format!("{},{},PQ={},DIM={}", TOTAL, dt, PQ, DIM).as_str(),
42+
|b| {
43+
b.iter(|| {
44+
black_box(pq.compute_distances(&query, &code).unwrap());
45+
})
46+
},
47+
);
48+
}
49+
}
50+
51+
#[cfg(target_os = "linux")]
52+
criterion_group!(
53+
name=benches;
54+
config = Criterion::default().significance_level(0.1).sample_size(10)
55+
.with_profiler(PProfProfiler::new(100, Output::Flamegraph(None)));
56+
targets = dist_table);
57+
58+
#[cfg(not(target_os = "linux"))]
59+
criterion_group!(
60+
name=benches;
61+
config = Criterion::default().significance_level(0.1).sample_size(10);
62+
targets = dist_table);
63+
64+
criterion_main!(benches);

rust/lance-index/benches/pq_dist_table.rs

+18-33
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ use pprof::criterion::{Output, PProfProfiler};
1919

2020
const PQ: usize = 96;
2121
const DIM: usize = 1536;
22-
const TOTAL: usize = 5 * 1024 * 1024;
22+
const TOTAL: usize = 16 * 1000;
2323

2424
fn dist_table(c: &mut Criterion) {
2525
let codebook = generate_random_array_with_seed::<Float32Type>(256 * DIM, [88; 32]);
@@ -28,39 +28,24 @@ fn dist_table(c: &mut Criterion) {
2828
let mut rnd = StdRng::from_seed([32; 32]);
2929
let code = UInt8Array::from_iter_values(repeat(rnd.gen::<u8>()).take(TOTAL * PQ));
3030

31-
let l2_pq = ProductQuantizer::new(
32-
PQ,
33-
8,
34-
DIM,
35-
FixedSizeListArray::try_new_from_values(codebook.clone(), DIM as i32).unwrap(),
36-
DistanceType::L2,
37-
);
31+
for dt in [DistanceType::L2, DistanceType::Cosine, DistanceType::Dot].iter() {
32+
let pq = ProductQuantizer::new(
33+
PQ,
34+
8,
35+
DIM,
36+
FixedSizeListArray::try_new_from_values(codebook.clone(), DIM as i32).unwrap(),
37+
*dt,
38+
);
3839

39-
c.bench_function(
40-
format!("{},L2,PQ={},DIM={}", TOTAL, PQ, DIM).as_str(),
41-
|b| {
42-
b.iter(|| {
43-
black_box(l2_pq.compute_distances(&query, &code).unwrap().len());
44-
})
45-
},
46-
);
47-
48-
let cosine_pq = ProductQuantizer::new(
49-
PQ,
50-
8,
51-
DIM,
52-
FixedSizeListArray::try_new_from_values(codebook, DIM as i32).unwrap(),
53-
DistanceType::Cosine,
54-
);
55-
56-
c.bench_function(
57-
format!("{},Cosine,PQ={},DIM={}", TOTAL, PQ, DIM).as_str(),
58-
|b| {
59-
b.iter(|| {
60-
black_box(cosine_pq.compute_distances(&query, &code).unwrap());
61-
})
62-
},
63-
);
40+
c.bench_function(
41+
format!("{},{},PQ={},DIM={}", TOTAL, dt, PQ, DIM).as_str(),
42+
|b| {
43+
b.iter(|| {
44+
black_box(pq.compute_distances(&query, &code).unwrap());
45+
})
46+
},
47+
);
48+
}
6449
}
6550

6651
#[cfg(target_os = "linux")]

rust/lance-index/src/vector/pq/distance.rs

+14-5
Original file line numberDiff line numberDiff line change
@@ -80,9 +80,11 @@ pub(super) fn compute_l2_distance(
8080
// so code[i * num_vectors + j] is the code of i-th sub-vector of the j-th vector.
8181
let num_vectors = code.len() / num_sub_vectors;
8282
let mut distances = vec![0.0_f32; num_vectors];
83-
let num_centroids = 2_usize.pow(num_bits);
83+
// it must be 8
84+
const NUM_CENTROIDS: usize = 2_usize.pow(8);
8485
for (sub_vec_idx, vec_indices) in code.chunks_exact(num_vectors).enumerate() {
85-
let dist_table = &distance_table[sub_vec_idx * num_centroids..];
86+
let dist_table =
87+
&distance_table[sub_vec_idx * NUM_CENTROIDS..(sub_vec_idx + 1) * NUM_CENTROIDS];
8688
debug_assert_eq!(vec_indices.len(), distances.len());
8789
vec_indices
8890
.iter()
@@ -103,9 +105,16 @@ pub(super) fn compute_l2_distance_4bit(
103105
) -> Vec<f32> {
104106
let num_vectors = code.len() * 2 / num_sub_vectors;
105107
let mut distances = vec![0.0_f32; num_vectors];
106-
let num_centroids = 2_usize.pow(4);
108+
const NUM_CENTROIDS: usize = 2_usize.pow(4);
107109
for (sub_vec_idx, vec_indices) in code.chunks_exact(num_vectors).enumerate() {
108-
let dist_table = &distance_table[sub_vec_idx * 2 * num_centroids..];
110+
let dist_table: &[f32; NUM_CENTROIDS] = &distance_table
111+
[sub_vec_idx * 2 * NUM_CENTROIDS..(sub_vec_idx * 2 + 1) * NUM_CENTROIDS]
112+
.try_into()
113+
.unwrap();
114+
let dist_table_next: &[f32; NUM_CENTROIDS] = &distance_table
115+
[(sub_vec_idx * 2 + 1) * NUM_CENTROIDS..(sub_vec_idx * 2 + 2) * NUM_CENTROIDS]
116+
.try_into()
117+
.unwrap();
109118
debug_assert_eq!(vec_indices.len(), distances.len());
110119
vec_indices
111120
.iter()
@@ -115,7 +124,7 @@ pub(super) fn compute_l2_distance_4bit(
115124
let current_idx = centroid_idx & 0xF;
116125
let next_idx = centroid_idx >> 4;
117126
*sum += dist_table[current_idx as usize];
118-
*sum += dist_table[num_centroids + next_idx as usize];
127+
*sum += dist_table_next[next_idx as usize];
119128
});
120129
}
121130

0 commit comments

Comments
 (0)