@@ -805,24 +805,30 @@ mod tests {
805
805
test_index ( params, nlist, recall_requirement) . await ;
806
806
}
807
807
808
+ #[ rstest]
808
809
#[ tokio:: test]
809
- async fn test_index_stats ( ) {
810
+ async fn test_index_stats (
811
+ #[ values(
812
+ ( VectorIndexParams :: ivf_flat( 4 , DistanceType :: Hamming ) , IndexType :: IvfFlat ) ,
813
+ ( VectorIndexParams :: ivf_pq( 4 , 8 , 8 , DistanceType :: L2 , 10 ) , IndexType :: IvfPq ) ,
814
+ ( VectorIndexParams :: with_ivf_hnsw_sq_params(
815
+ DistanceType :: Cosine ,
816
+ IvfBuildParams :: new( 4 ) ,
817
+ Default :: default ( ) ,
818
+ Default :: default ( )
819
+ ) , IndexType :: IvfHnswSq ) ,
820
+ ) ]
821
+ index : ( VectorIndexParams , IndexType ) ,
822
+ ) {
823
+ let ( params, index_type) = index;
810
824
let test_dir = tempdir ( ) . unwrap ( ) ;
811
825
let test_uri = test_dir. path ( ) . to_str ( ) . unwrap ( ) ;
812
826
813
827
let nlist = 4 ;
814
- let ( mut dataset, _) = generate_test_dataset :: < Float32Type > ( test_uri, 0.0 ..1.0 ) . await ;
815
-
816
- let ivf_params = IvfBuildParams :: new ( nlist) ;
817
- let sq_params = SQBuildParams :: default ( ) ;
818
- let hnsw_params = HnswBuildParams :: default ( ) ;
819
- let params = VectorIndexParams :: with_ivf_hnsw_sq_params (
820
- DistanceType :: L2 ,
821
- ivf_params,
822
- hnsw_params,
823
- sq_params,
824
- ) ;
825
-
828
+ let ( mut dataset, _) = match params. metric_type {
829
+ DistanceType :: Hamming => generate_test_dataset :: < UInt8Type > ( test_uri, 0 ..2 ) . await ,
830
+ _ => generate_test_dataset :: < Float32Type > ( test_uri, 0.0 ..1.0 ) . await ,
831
+ } ;
826
832
dataset
827
833
. create_index (
828
834
& [ "vector" ] ,
@@ -837,14 +843,29 @@ mod tests {
837
843
let stats = dataset. index_statistics ( "test_index" ) . await . unwrap ( ) ;
838
844
let stats: serde_json:: Value = serde_json:: from_str ( stats. as_str ( ) ) . unwrap ( ) ;
839
845
840
- assert_eq ! ( stats[ "index_type" ] . as_str( ) . unwrap( ) , "IVF_HNSW_SQ" ) ;
846
+ assert_eq ! (
847
+ stats[ "index_type" ] . as_str( ) . unwrap( ) ,
848
+ index_type. to_string( )
849
+ ) ;
841
850
for index in stats[ "indices" ] . as_array ( ) . unwrap ( ) {
842
- assert_eq ! ( index[ "index_type" ] . as_str( ) . unwrap( ) , "IVF_HNSW_SQ" ) ;
851
+ assert_eq ! (
852
+ index[ "index_type" ] . as_str( ) . unwrap( ) ,
853
+ index_type. to_string( )
854
+ ) ;
843
855
assert_eq ! (
844
856
index[ "num_partitions" ] . as_number( ) . unwrap( ) ,
845
857
& serde_json:: Number :: from( nlist)
846
858
) ;
847
- assert_eq ! ( index[ "sub_index" ] [ "index_type" ] . as_str( ) . unwrap( ) , "HNSW" ) ;
859
+
860
+ let sub_index = match index_type {
861
+ IndexType :: IvfHnswPq | IndexType :: IvfHnswSq => "HNSW" ,
862
+ IndexType :: IvfPq => "PQ" ,
863
+ _ => "FLAT" ,
864
+ } ;
865
+ assert_eq ! (
866
+ index[ "sub_index" ] [ "index_type" ] . as_str( ) . unwrap( ) ,
867
+ sub_index
868
+ ) ;
848
869
}
849
870
}
850
871
0 commit comments