Skip to content

Commit ae36abe

Browse files
authoredDec 18, 2024··
fix: panic when get stats from index over binary vectors (#3267)
Signed-off-by: BubbleCal <bubble-cal@outlook.com>
1 parent d038e34 commit ae36abe

File tree

2 files changed

+43
-16
lines changed

2 files changed

+43
-16
lines changed
 

‎rust/lance/src/index/vector/ivf.rs

+6
Original file line numberDiff line numberDiff line change
@@ -729,6 +729,12 @@ fn centroids_to_vectors(centroids: &FixedSizeListArray) -> Result<Vec<Vec<f32>>>
729729
.iter()
730730
.map(|v| *v as f32)
731731
.collect::<Vec<_>>()),
732+
DataType::UInt8 => Ok(row
733+
.as_primitive::<UInt8Type>()
734+
.values()
735+
.iter()
736+
.map(|v| *v as f32)
737+
.collect::<Vec<_>>()),
732738
_ => Err(Error::Index {
733739
message: format!(
734740
"IVF centroids must be FixedSizeList of floating number, got: {}",

‎rust/lance/src/index/vector/ivf/v2.rs

+37-16
Original file line numberDiff line numberDiff line change
@@ -805,24 +805,30 @@ mod tests {
805805
test_index(params, nlist, recall_requirement).await;
806806
}
807807

808+
#[rstest]
808809
#[tokio::test]
809-
async fn test_index_stats() {
810+
async fn test_index_stats(
811+
#[values(
812+
(VectorIndexParams::ivf_flat(4, DistanceType::Hamming), IndexType::IvfFlat),
813+
(VectorIndexParams::ivf_pq(4, 8, 8, DistanceType::L2, 10), IndexType::IvfPq),
814+
(VectorIndexParams::with_ivf_hnsw_sq_params(
815+
DistanceType::Cosine,
816+
IvfBuildParams::new(4),
817+
Default::default(),
818+
Default::default()
819+
), IndexType::IvfHnswSq),
820+
)]
821+
index: (VectorIndexParams, IndexType),
822+
) {
823+
let (params, index_type) = index;
810824
let test_dir = tempdir().unwrap();
811825
let test_uri = test_dir.path().to_str().unwrap();
812826

813827
let nlist = 4;
814-
let (mut dataset, _) = generate_test_dataset::<Float32Type>(test_uri, 0.0..1.0).await;
815-
816-
let ivf_params = IvfBuildParams::new(nlist);
817-
let sq_params = SQBuildParams::default();
818-
let hnsw_params = HnswBuildParams::default();
819-
let params = VectorIndexParams::with_ivf_hnsw_sq_params(
820-
DistanceType::L2,
821-
ivf_params,
822-
hnsw_params,
823-
sq_params,
824-
);
825-
828+
let (mut dataset, _) = match params.metric_type {
829+
DistanceType::Hamming => generate_test_dataset::<UInt8Type>(test_uri, 0..2).await,
830+
_ => generate_test_dataset::<Float32Type>(test_uri, 0.0..1.0).await,
831+
};
826832
dataset
827833
.create_index(
828834
&["vector"],
@@ -837,14 +843,29 @@ mod tests {
837843
let stats = dataset.index_statistics("test_index").await.unwrap();
838844
let stats: serde_json::Value = serde_json::from_str(stats.as_str()).unwrap();
839845

840-
assert_eq!(stats["index_type"].as_str().unwrap(), "IVF_HNSW_SQ");
846+
assert_eq!(
847+
stats["index_type"].as_str().unwrap(),
848+
index_type.to_string()
849+
);
841850
for index in stats["indices"].as_array().unwrap() {
842-
assert_eq!(index["index_type"].as_str().unwrap(), "IVF_HNSW_SQ");
851+
assert_eq!(
852+
index["index_type"].as_str().unwrap(),
853+
index_type.to_string()
854+
);
843855
assert_eq!(
844856
index["num_partitions"].as_number().unwrap(),
845857
&serde_json::Number::from(nlist)
846858
);
847-
assert_eq!(index["sub_index"]["index_type"].as_str().unwrap(), "HNSW");
859+
860+
let sub_index = match index_type {
861+
IndexType::IvfHnswPq | IndexType::IvfHnswSq => "HNSW",
862+
IndexType::IvfPq => "PQ",
863+
_ => "FLAT",
864+
};
865+
assert_eq!(
866+
index["sub_index"]["index_type"].as_str().unwrap(),
867+
sub_index
868+
);
848869
}
849870
}
850871

0 commit comments

Comments
 (0)
Please sign in to comment.