From 24923e9b95bbb3671bc09912028146f7e7941292 Mon Sep 17 00:00:00 2001
From: BubbleCal <bubble-cal@outlook.com>
Date: Wed, 18 Dec 2024 19:35:27 +0800
Subject: [PATCH 1/2] fix: panic when getting stats from an index over binary vectors

Signed-off-by: BubbleCal <bubble-cal@outlook.com>
---
 rust/lance/src/index/vector/ivf.rs    |  6 +++
 rust/lance/src/index/vector/ivf/v2.rs | 56 +++++++++++++++++++--------
 2 files changed, 45 insertions(+), 17 deletions(-)

diff --git a/rust/lance/src/index/vector/ivf.rs b/rust/lance/src/index/vector/ivf.rs
index 8b7fd6b62a..c20fb14062 100644
--- a/rust/lance/src/index/vector/ivf.rs
+++ b/rust/lance/src/index/vector/ivf.rs
@@ -729,6 +729,12 @@ fn centroids_to_vectors(centroids: &FixedSizeListArray) -> Result<Vec<Vec<f32>>>
                         .iter()
                         .map(|v| *v as f32)
                         .collect::<Vec<_>>()),
+                    DataType::UInt8 => Ok(row
+                        .as_primitive::<UInt8Type>()
+                        .values()
+                        .iter()
+                        .map(|v| *v as f32)
+                        .collect::<Vec<_>>()),
                     _ => Err(Error::Index {
                         message: format!(
                             "IVF centroids must be FixedSizeList of floating number, got: {}",
diff --git a/rust/lance/src/index/vector/ivf/v2.rs b/rust/lance/src/index/vector/ivf/v2.rs
index a20282842c..636ee51c46 100644
--- a/rust/lance/src/index/vector/ivf/v2.rs
+++ b/rust/lance/src/index/vector/ivf/v2.rs
@@ -531,8 +531,9 @@ mod tests {
     use lance_index::vector::ivf::IvfBuildParams;
     use lance_index::vector::pq::PQBuildParams;
     use lance_index::vector::sq::builder::SQBuildParams;
+    use lance_index::vector::v3::subindex::SubIndexType;
     use lance_index::vector::DIST_COL;
-    use lance_index::{DatasetIndexExt, IndexType};
+    use lance_index::{DatasetIndexExt, IndexParams, IndexType};
     use lance_linalg::distance::hamming::hamming;
     use lance_linalg::distance::DistanceType;
     use lance_testing::datagen::generate_random_array_with_range;
@@ -805,24 +806,30 @@ mod tests {
         test_index(params, nlist, recall_requirement).await;
     }
 
+    #[rstest]
     #[tokio::test]
-    async fn test_index_stats() {
+    async fn test_index_stats(
+        #[values(
+            (VectorIndexParams::ivf_flat(4, DistanceType::Hamming), IndexType::IvfFlat),
+            (VectorIndexParams::ivf_pq(4, 8, 8, DistanceType::L2, 10), IndexType::IvfPq),
+            (VectorIndexParams::with_ivf_hnsw_sq_params(
+                DistanceType::Cosine,
+                IvfBuildParams::new(4),
+                Default::default(),
+                Default::default()
+            ), IndexType::IvfHnswSq),
+        )]
+        index: (VectorIndexParams, IndexType),
+    ) {
+        let (params, index_type) = index;
         let test_dir = tempdir().unwrap();
         let test_uri = test_dir.path().to_str().unwrap();
 
         let nlist = 4;
-        let (mut dataset, _) = generate_test_dataset::<Float32Type>(test_uri, 0.0..1.0).await;
-
-        let ivf_params = IvfBuildParams::new(nlist);
-        let sq_params = SQBuildParams::default();
-        let hnsw_params = HnswBuildParams::default();
-        let params = VectorIndexParams::with_ivf_hnsw_sq_params(
-            DistanceType::L2,
-            ivf_params,
-            hnsw_params,
-            sq_params,
-        );
-
+        let (mut dataset, _) = match params.metric_type {
+            DistanceType::Hamming => generate_test_dataset::<UInt8Type>(test_uri, 0..2).await,
+            _ => generate_test_dataset::<Float32Type>(test_uri, 0.0..1.0).await,
+        };
         dataset
             .create_index(
                 &["vector"],
@@ -837,14 +844,29 @@ mod tests {
         let stats = dataset.index_statistics("test_index").await.unwrap();
         let stats: serde_json::Value = serde_json::from_str(stats.as_str()).unwrap();
 
-        assert_eq!(stats["index_type"].as_str().unwrap(), "IVF_HNSW_SQ");
+        assert_eq!(
+            stats["index_type"].as_str().unwrap(),
+            index_type.to_string()
+        );
         for index in stats["indices"].as_array().unwrap() {
-            assert_eq!(index["index_type"].as_str().unwrap(), "IVF_HNSW_SQ");
+            assert_eq!(
+                index["index_type"].as_str().unwrap(),
+                index_type.to_string()
+            );
             assert_eq!(
                 index["num_partitions"].as_number().unwrap(),
                 &serde_json::Number::from(nlist)
             );
-            assert_eq!(index["sub_index"]["index_type"].as_str().unwrap(), "HNSW");
+
+            let sub_index = match index_type {
+                IndexType::IvfHnswPq | IndexType::IvfHnswSq => "HNSW",
+                IndexType::IvfPq => "PQ",
+                _ => "FLAT",
+            };
+            assert_eq!(
+                index["sub_index"]["index_type"].as_str().unwrap(),
+                sub_index
+            );
         }
     }
 

From a394cce80dd495dcd7ce97c45c451e2443a2f2ad Mon Sep 17 00:00:00 2001
From: BubbleCal <bubble-cal@outlook.com>
Date: Wed, 18 Dec 2024 19:48:58 +0800
Subject: [PATCH 2/2] fmt

Signed-off-by: BubbleCal <bubble-cal@outlook.com>
---
 rust/lance/src/index/vector/ivf/v2.rs | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/rust/lance/src/index/vector/ivf/v2.rs b/rust/lance/src/index/vector/ivf/v2.rs
index 636ee51c46..df96885615 100644
--- a/rust/lance/src/index/vector/ivf/v2.rs
+++ b/rust/lance/src/index/vector/ivf/v2.rs
@@ -531,9 +531,8 @@ mod tests {
     use lance_index::vector::ivf::IvfBuildParams;
     use lance_index::vector::pq::PQBuildParams;
     use lance_index::vector::sq::builder::SQBuildParams;
-    use lance_index::vector::v3::subindex::SubIndexType;
     use lance_index::vector::DIST_COL;
-    use lance_index::{DatasetIndexExt, IndexParams, IndexType};
+    use lance_index::{DatasetIndexExt, IndexType};
     use lance_linalg::distance::hamming::hamming;
     use lance_linalg::distance::DistanceType;
     use lance_testing::datagen::generate_random_array_with_range;