Skip to content

Commit c58814a

Browse files
authored
fix: avoid divide-by-zero when training an index with a large dimension (#3426)
1 parent a7c5216 commit c58814a

File tree

3 files changed

+92
-30
lines changed

3 files changed

+92
-30
lines changed

rust/lance-io/src/object_store.rs

+3-3
Original file line numberDiff line numberDiff line change
@@ -501,7 +501,7 @@ impl ObjectStore {
501501
Self {
502502
inner: Arc::new(InMemory::new()).traced(),
503503
scheme: String::from("memory"),
504-
block_size: 64 * 1024,
504+
block_size: 4 * 1024,
505505
use_constant_size_upload_parts: false,
506506
list_is_lexically_ordered: true,
507507
io_parallelism: get_num_compute_intensive_cpus(),
@@ -977,7 +977,7 @@ async fn configure_store(
977977
"memory" => Ok(ObjectStore {
978978
inner: Arc::new(InMemory::new()).traced(),
979979
scheme: String::from("memory"),
980-
block_size: cloud_block_size,
980+
block_size: file_block_size,
981981
use_constant_size_upload_parts: false,
982982
list_is_lexically_ordered: true,
983983
io_parallelism: get_num_compute_intensive_cpus(),
@@ -1219,7 +1219,6 @@ mod tests {
12191219
#[rstest]
12201220
#[case("s3://bucket/foo.lance", None)]
12211221
#[case("gs://bucket/foo.lance", None)]
1222-
#[case("memory:///bucket/foo.lance", None)]
12231222
#[case("az://account/bucket/foo.lance",
12241223
Some(HashMap::from([
12251224
(String::from("account_name"), String::from("account")),
@@ -1236,6 +1235,7 @@ mod tests {
12361235
#[rstest]
12371236
#[case("file")]
12381237
#[case("file-object-store")]
1238+
#[case("memory:///bucket/foo.lance")]
12391239
#[tokio::test]
12401240
async fn test_block_size_used_file(#[case] prefix: &str) {
12411241
let tmp_dir = tempfile::tempdir().unwrap();

rust/lance/src/index/vector/ivf.rs

+88-26
Original file line numberDiff line numberDiff line change
@@ -2223,42 +2223,102 @@ mod tests {
22232223
.await;
22242224
}
22252225

2226+
struct TestPqParams {
2227+
num_sub_vectors: usize,
2228+
num_bits: usize,
2229+
}
2230+
2231+
impl TestPqParams {
2232+
fn small() -> Self {
2233+
Self {
2234+
num_sub_vectors: 2,
2235+
num_bits: 8,
2236+
}
2237+
}
2238+
}
2239+
2240+
// Clippy doesn't like that all start with Ivf but we might have some in the future
2241+
// that _don't_ start with Ivf so I feel it is meaningful to keep the prefix
2242+
#[allow(clippy::enum_variant_names)]
2243+
enum TestIndexType {
2244+
IvfPq { pq: TestPqParams },
2245+
IvfHnswPq { pq: TestPqParams, num_edges: usize },
2246+
IvfHnswSq { num_edges: usize },
2247+
IvfFlat,
2248+
}
2249+
2250+
struct CreateIndexCase {
2251+
metric_type: MetricType,
2252+
num_partitions: usize,
2253+
dimension: usize,
2254+
index_type: TestIndexType,
2255+
}
2256+
22262257
// We test L2 and Dot, because L2 PQ uses residuals while Dot doesn't,
22272258
// so they have slightly different code paths.
22282259
#[tokio::test]
22292260
#[rstest]
2230-
#[case::ivf_pq_l2(VectorIndexParams::with_ivf_pq_params(
2231-
MetricType::L2,
2232-
IvfBuildParams::new(2),
2233-
PQBuildParams::new(2, 8),
2234-
))]
2235-
#[case::ivf_pq_dot(VectorIndexParams::with_ivf_pq_params(
2236-
MetricType::Dot,
2237-
IvfBuildParams::new(2),
2238-
PQBuildParams::new(2, 8),
2239-
))]
2240-
#[case::ivf_flat(VectorIndexParams::ivf_flat(1, MetricType::Dot))]
2241-
#[case::ivf_hnsw_pq(VectorIndexParams::with_ivf_hnsw_pq_params(
2242-
MetricType::Dot,
2243-
IvfBuildParams::new(2),
2244-
HnswBuildParams::default().num_edges(100),
2245-
PQBuildParams::new(2, 8)
2246-
))]
2247-
#[case::ivf_hnsw_sq(VectorIndexParams::with_ivf_hnsw_sq_params(
2248-
MetricType::Dot,
2249-
IvfBuildParams::new(2),
2250-
HnswBuildParams::default().num_edges(100),
2251-
SQBuildParams::default()
2252-
))]
2261+
#[case::ivf_pq_l2(CreateIndexCase {
2262+
metric_type: MetricType::L2,
2263+
num_partitions: 2,
2264+
dimension: 16,
2265+
index_type: TestIndexType::IvfPq { pq: TestPqParams::small() },
2266+
})]
2267+
#[case::ivf_pq_dot(CreateIndexCase {
2268+
metric_type: MetricType::Dot,
2269+
num_partitions: 2,
2270+
dimension: 2000,
2271+
index_type: TestIndexType::IvfPq { pq: TestPqParams::small() },
2272+
})]
2273+
#[case::ivf_flat(CreateIndexCase { num_partitions: 1, metric_type: MetricType::Dot, dimension: 16, index_type: TestIndexType::IvfFlat })]
2274+
#[case::ivf_hnsw_pq(CreateIndexCase {
2275+
num_partitions: 2,
2276+
metric_type: MetricType::Dot,
2277+
dimension: 16,
2278+
index_type: TestIndexType::IvfHnswPq { pq: TestPqParams::small(), num_edges: 100 },
2279+
})]
2280+
#[case::ivf_hnsw_sq(CreateIndexCase {
2281+
metric_type: MetricType::Dot,
2282+
num_partitions: 2,
2283+
dimension: 16,
2284+
index_type: TestIndexType::IvfHnswSq { num_edges: 100 },
2285+
})]
22532286
async fn test_create_index_nulls(
2254-
#[case] mut index_params: VectorIndexParams,
2287+
#[case] test_case: CreateIndexCase,
22552288
#[values(IndexFileVersion::Legacy, IndexFileVersion::V3)] index_version: IndexFileVersion,
22562289
) {
2290+
let mut index_params = match test_case.index_type {
2291+
TestIndexType::IvfPq { pq } => VectorIndexParams::with_ivf_pq_params(
2292+
test_case.metric_type,
2293+
IvfBuildParams::new(test_case.num_partitions),
2294+
PQBuildParams::new(pq.num_sub_vectors, pq.num_bits),
2295+
),
2296+
TestIndexType::IvfHnswPq { pq, num_edges } => {
2297+
VectorIndexParams::with_ivf_hnsw_pq_params(
2298+
test_case.metric_type,
2299+
IvfBuildParams::new(test_case.num_partitions),
2300+
HnswBuildParams::default().num_edges(num_edges),
2301+
PQBuildParams::new(pq.num_sub_vectors, pq.num_bits),
2302+
)
2303+
}
2304+
TestIndexType::IvfFlat => {
2305+
VectorIndexParams::ivf_flat(test_case.num_partitions, test_case.metric_type)
2306+
}
2307+
TestIndexType::IvfHnswSq { num_edges } => VectorIndexParams::with_ivf_hnsw_sq_params(
2308+
test_case.metric_type,
2309+
IvfBuildParams::new(test_case.num_partitions),
2310+
HnswBuildParams::default().num_edges(num_edges),
2311+
SQBuildParams::default(),
2312+
),
2313+
};
22572314
index_params.version(index_version);
22582315

22592316
let nrows = 2_000;
22602317
let data = gen()
2261-
.col("vec", array::rand_vec::<Float32Type>(Dimension::from(16)))
2318+
.col(
2319+
"vec",
2320+
array::rand_vec::<Float32Type>(Dimension::from(test_case.dimension as u32)),
2321+
)
22622322
.into_batch_rows(RowCount::from(nrows))
22632323
.unwrap();
22642324

@@ -2287,7 +2347,9 @@ mod tests {
22872347
.await
22882348
.unwrap();
22892349

2290-
let query = vec![0.0; 16].into_iter().collect::<Float32Array>();
2350+
let query = vec![0.0; test_case.dimension]
2351+
.into_iter()
2352+
.collect::<Float32Array>();
22912353
let results = dataset
22922354
.scan()
22932355
.nearest("vec", &query, 2_000)

rust/lance/src/index/vector/utils.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -274,7 +274,7 @@ fn random_ranges(
274274
block_size: usize,
275275
byte_width: usize,
276276
) -> impl Iterator<Item = std::ops::Range<u64>> + Send {
277-
let rows_per_batch = block_size / byte_width;
277+
let rows_per_batch = 1.max(block_size / byte_width);
278278
let mut rng = SmallRng::from_entropy();
279279
let num_bins = num_rows.div_ceil(rows_per_batch);
280280

0 commit comments

Comments
 (0)