From 24923e9b95bbb3671bc09912028146f7e7941292 Mon Sep 17 00:00:00 2001 From: BubbleCal <bubble-cal@outlook.com> Date: Wed, 18 Dec 2024 19:35:27 +0800 Subject: [PATCH 1/2] fix: panic when get stats from index over binary vectors Signed-off-by: BubbleCal <bubble-cal@outlook.com> --- rust/lance/src/index/vector/ivf.rs | 6 +++ rust/lance/src/index/vector/ivf/v2.rs | 56 +++++++++++++++++++-------- 2 files changed, 45 insertions(+), 17 deletions(-) diff --git a/rust/lance/src/index/vector/ivf.rs b/rust/lance/src/index/vector/ivf.rs index 8b7fd6b62a..c20fb14062 100644 --- a/rust/lance/src/index/vector/ivf.rs +++ b/rust/lance/src/index/vector/ivf.rs @@ -729,6 +729,12 @@ fn centroids_to_vectors(centroids: &FixedSizeListArray) -> Result<Vec<Vec<f32>>> .iter() .map(|v| *v as f32) .collect::<Vec<_>>()), + DataType::UInt8 => Ok(row + .as_primitive::<UInt8Type>() + .values() + .iter() + .map(|v| *v as f32) + .collect::<Vec<_>>()), _ => Err(Error::Index { message: format!( "IVF centroids must be FixedSizeList of floating number, got: {}", diff --git a/rust/lance/src/index/vector/ivf/v2.rs b/rust/lance/src/index/vector/ivf/v2.rs index a20282842c..636ee51c46 100644 --- a/rust/lance/src/index/vector/ivf/v2.rs +++ b/rust/lance/src/index/vector/ivf/v2.rs @@ -531,8 +531,9 @@ mod tests { use lance_index::vector::ivf::IvfBuildParams; use lance_index::vector::pq::PQBuildParams; use lance_index::vector::sq::builder::SQBuildParams; + use lance_index::vector::v3::subindex::SubIndexType; use lance_index::vector::DIST_COL; - use lance_index::{DatasetIndexExt, IndexType}; + use lance_index::{DatasetIndexExt, IndexParams, IndexType}; use lance_linalg::distance::hamming::hamming; use lance_linalg::distance::DistanceType; use lance_testing::datagen::generate_random_array_with_range; @@ -805,24 +806,30 @@ mod tests { test_index(params, nlist, recall_requirement).await; } + #[rstest] #[tokio::test] - async fn test_index_stats() { + async fn test_index_stats( + #[values( + (VectorIndexParams::ivf_flat(4, DistanceType::Hamming), IndexType::IvfFlat), + (VectorIndexParams::ivf_pq(4, 8, 8, DistanceType::L2, 10), IndexType::IvfPq), + (VectorIndexParams::with_ivf_hnsw_sq_params( + DistanceType::Cosine, + IvfBuildParams::new(4), + Default::default(), + Default::default() + ), IndexType::IvfHnswSq), + )] + index: (VectorIndexParams, IndexType), + ) { + let (params, index_type) = index; let test_dir = tempdir().unwrap(); let test_uri = test_dir.path().to_str().unwrap(); let nlist = 4; - let (mut dataset, _) = generate_test_dataset::<Float32Type>(test_uri, 0.0..1.0).await; - - let ivf_params = IvfBuildParams::new(nlist); - let sq_params = SQBuildParams::default(); - let hnsw_params = HnswBuildParams::default(); - let params = VectorIndexParams::with_ivf_hnsw_sq_params( - DistanceType::L2, - ivf_params, - hnsw_params, - sq_params, - ); - + let (mut dataset, _) = match params.metric_type { + DistanceType::Hamming => generate_test_dataset::<UInt8Type>(test_uri, 0..2).await, + _ => generate_test_dataset::<Float32Type>(test_uri, 0.0..1.0).await, + }; dataset .create_index( &["vector"], @@ -837,14 +844,29 @@ mod tests { let stats = dataset.index_statistics("test_index").await.unwrap(); let stats: serde_json::Value = serde_json::from_str(stats.as_str()).unwrap(); - assert_eq!(stats["index_type"].as_str().unwrap(), "IVF_HNSW_SQ"); + assert_eq!( + stats["index_type"].as_str().unwrap(), + index_type.to_string() + ); for index in stats["indices"].as_array().unwrap() { - assert_eq!(index["index_type"].as_str().unwrap(), "IVF_HNSW_SQ"); + assert_eq!( + index["index_type"].as_str().unwrap(), + index_type.to_string() + ); assert_eq!( index["num_partitions"].as_number().unwrap(), &serde_json::Number::from(nlist) ); - assert_eq!(index["sub_index"]["index_type"].as_str().unwrap(), "HNSW"); + + let sub_index = match index_type { + IndexType::IvfHnswPq | IndexType::IvfHnswSq => "HNSW", + IndexType::IvfPq => "PQ", + _ => "FLAT", + }; + assert_eq!( + index["sub_index"]["index_type"].as_str().unwrap(), + sub_index + ); } } From a394cce80dd495dcd7ce97c45c451e2443a2f2ad Mon Sep 17 00:00:00 2001 From: BubbleCal <bubble-cal@outlook.com> Date: Wed, 18 Dec 2024 19:48:58 +0800 Subject: [PATCH 2/2] fmt Signed-off-by: BubbleCal <bubble-cal@outlook.com> --- rust/lance/src/index/vector/ivf/v2.rs | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/rust/lance/src/index/vector/ivf/v2.rs b/rust/lance/src/index/vector/ivf/v2.rs index 636ee51c46..df96885615 100644 --- a/rust/lance/src/index/vector/ivf/v2.rs +++ b/rust/lance/src/index/vector/ivf/v2.rs @@ -531,9 +531,8 @@ mod tests { use lance_index::vector::ivf::IvfBuildParams; use lance_index::vector::pq::PQBuildParams; use lance_index::vector::sq::builder::SQBuildParams; - use lance_index::vector::v3::subindex::SubIndexType; use lance_index::vector::DIST_COL; - use lance_index::{DatasetIndexExt, IndexParams, IndexType}; + use lance_index::{DatasetIndexExt, IndexType}; use lance_linalg::distance::hamming::hamming; use lance_linalg::distance::DistanceType; use lance_testing::datagen::generate_random_array_with_range;