Skip to content

Commit 7aa7d94

Browse files
authored
fix: handle null vectors in flat search (#3422)
1 parent 7c34f14 commit 7aa7d94

File tree

3 files changed

+81
-8
lines changed

3 files changed

+81
-8
lines changed

rust/lance-index/src/vector/flat.rs

+4-4
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
77
use std::sync::Arc;
88

9-
use arrow::array::AsArray;
9+
use arrow::{array::AsArray, buffer::NullBuffer};
1010
use arrow_array::{make_array, Array, ArrayRef, Float32Array, RecordBatch};
1111
use arrow_schema::{DataType, Field as ArrowField};
1212
use lance_arrow::*;
@@ -44,9 +44,9 @@ pub async fn compute_distance(
4444
.clone();
4545

4646
let validity_buffer = if let Some(rowids) = batch.column_by_name(ROW_ID) {
47-
rowids.nulls().map(|nulls| nulls.buffer().clone())
47+
NullBuffer::union(rowids.nulls(), vectors.nulls())
4848
} else {
49-
None
49+
vectors.nulls().cloned()
5050
};
5151

5252
tokio::task::spawn_blocking(move || {
@@ -56,7 +56,7 @@ pub async fn compute_distance(
5656
let vectors = vectors
5757
.into_data()
5858
.into_builder()
59-
.null_bit_buffer(validity_buffer)
59+
.null_bit_buffer(validity_buffer.map(|b| b.buffer().clone()))
6060
.build()
6161
.map(make_array)?;
6262
let distances = match vectors.data_type() {

rust/lance-index/src/vector/pq/distance.rs

+3
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,9 @@ pub(super) fn compute_pq_distance(
105105
num_sub_vectors: usize,
106106
code: &[u8],
107107
) -> Vec<f32> {
108+
if code.is_empty() {
109+
return Vec::new();
110+
}
108111
if num_bits == 4 {
109112
return compute_pq_distance_4bit(distance_table, num_sub_vectors, code);
110113
}

rust/lance/src/index/vector/ivf.rs

+74-4
Original file line numberDiff line numberDiff line change
@@ -1742,14 +1742,15 @@ mod tests {
17421742

17431743
use arrow_array::types::UInt64Type;
17441744
use arrow_array::{
1745-
make_array, Float32Array, RecordBatchIterator, RecordBatchReader, UInt64Array,
1745+
make_array, FixedSizeListArray, Float32Array, RecordBatch, RecordBatchIterator,
1746+
RecordBatchReader, UInt64Array,
17461747
};
17471748
use arrow_buffer::{BooleanBuffer, NullBuffer};
1748-
use arrow_schema::Field;
1749+
use arrow_schema::{DataType, Field, Schema};
17491750
use itertools::Itertools;
17501751
use lance_core::utils::address::RowAddress;
17511752
use lance_core::ROW_ID;
1752-
use lance_datagen::{array, gen, Dimension, RowCount};
1753+
use lance_datagen::{array, gen, ArrayGeneratorExt, Dimension, RowCount};
17531754
use lance_index::vector::sq::builder::SQBuildParams;
17541755
use lance_linalg::distance::l2_distance_batch;
17551756
use lance_testing::datagen::{
@@ -1760,7 +1761,7 @@ mod tests {
17601761
use rstest::rstest;
17611762
use tempfile::tempdir;
17621763

1763-
use crate::dataset::InsertBuilder;
1764+
use crate::dataset::{InsertBuilder, WriteMode, WriteParams};
17641765
use crate::index::prefilter::DatasetPreFilter;
17651766
use crate::index::vector::IndexFileVersion;
17661767
use crate::index::vector_index_details;
@@ -2300,6 +2301,75 @@ mod tests {
23002301
assert_eq!(results["vec"].logical_null_count(), 0);
23012302
}
23022303

2304+
#[tokio::test]
2305+
async fn test_index_lifecycle_nulls() {
2306+
// Generate random data with nulls
2307+
let nrows = 2_000;
2308+
let dims = 32;
2309+
let data = gen()
2310+
.col(
2311+
"vec",
2312+
array::rand_vec::<Float32Type>(Dimension::from(dims as u32)).with_random_nulls(0.5),
2313+
)
2314+
.into_batch_rows(RowCount::from(nrows))
2315+
.unwrap();
2316+
let num_non_null = data["vec"].len() - data["vec"].logical_null_count();
2317+
2318+
let mut dataset = InsertBuilder::new("memory://")
2319+
.execute(vec![data])
2320+
.await
2321+
.unwrap();
2322+
2323+
// Create index
2324+
let index_params = VectorIndexParams::with_ivf_pq_params(
2325+
MetricType::L2,
2326+
IvfBuildParams::new(2),
2327+
PQBuildParams::new(2, 8),
2328+
);
2329+
dataset
2330+
.create_index(&["vec"], IndexType::Vector, None, &index_params, false)
2331+
.await
2332+
.unwrap();
2333+
2334+
// Check that the index is working
2335+
async fn check_index(dataset: &Dataset, num_non_null: usize, dims: usize) {
2336+
let query = vec![0.0; dims].into_iter().collect::<Float32Array>();
2337+
let results = dataset
2338+
.scan()
2339+
.nearest("vec", &query, 2_000)
2340+
.unwrap()
2341+
.nprobs(2)
2342+
.try_into_batch()
2343+
.await
2344+
.unwrap();
2345+
assert_eq!(results.num_rows(), num_non_null);
2346+
}
2347+
check_index(&dataset, num_non_null, dims).await;
2348+
2349+
// Append more data
2350+
let data = gen()
2351+
.col(
2352+
"vec",
2353+
array::rand_vec::<Float32Type>(Dimension::from(dims as u32)).with_random_nulls(0.5),
2354+
)
2355+
.into_batch_rows(RowCount::from(500))
2356+
.unwrap();
2357+
let num_non_null = data["vec"].len() - data["vec"].logical_null_count() + num_non_null;
2358+
let mut dataset = InsertBuilder::new(Arc::new(dataset))
2359+
.with_params(&WriteParams {
2360+
mode: WriteMode::Append,
2361+
..Default::default()
2362+
})
2363+
.execute(vec![data])
2364+
.await
2365+
.unwrap();
2366+
check_index(&dataset, num_non_null, dims).await;
2367+
2368+
// Optimize the index
2369+
dataset.optimize_indices(&Default::default()).await.unwrap();
2370+
check_index(&dataset, num_non_null, dims).await;
2371+
}
2372+
23032373
#[tokio::test]
23042374
async fn test_create_ivf_pq_cosine() {
23052375
let test_dir = tempdir().unwrap();

0 commit comments

Comments
 (0)