@@ -2223,42 +2223,102 @@ mod tests {
2223
2223
. await ;
2224
2224
}
2225
2225
2226
+ struct TestPqParams {
2227
+ num_sub_vectors : usize ,
2228
+ num_bits : usize ,
2229
+ }
2230
+
2231
+ impl TestPqParams {
2232
+ fn small ( ) -> Self {
2233
+ Self {
2234
+ num_sub_vectors : 2 ,
2235
+ num_bits : 8 ,
2236
+ }
2237
+ }
2238
+ }
2239
+
2240
+ // Clippy doesn't like that all start with Ivf but we might have some in the future
2241
+ // that _don't_ start with Ivf so I feel it is meaningful to keep the prefix
2242
+ #[ allow( clippy:: enum_variant_names) ]
2243
+ enum TestIndexType {
2244
+ IvfPq { pq : TestPqParams } ,
2245
+ IvfHnswPq { pq : TestPqParams , num_edges : usize } ,
2246
+ IvfHnswSq { num_edges : usize } ,
2247
+ IvfFlat ,
2248
+ }
2249
+
2250
+ struct CreateIndexCase {
2251
+ metric_type : MetricType ,
2252
+ num_partitions : usize ,
2253
+ dimension : usize ,
2254
+ index_type : TestIndexType ,
2255
+ }
2256
+
2226
2257
// We test L2 and Dot, because L2 PQ uses residuals while Dot doesn't,
2227
2258
// so they have slightly different code paths.
2228
2259
#[ tokio:: test]
2229
2260
#[ rstest]
2230
- #[ case:: ivf_pq_l2( VectorIndexParams :: with_ivf_pq_params(
2231
- MetricType :: L2 ,
2232
- IvfBuildParams :: new( 2 ) ,
2233
- PQBuildParams :: new( 2 , 8 ) ,
2234
- ) ) ]
2235
- #[ case:: ivf_pq_dot( VectorIndexParams :: with_ivf_pq_params(
2236
- MetricType :: Dot ,
2237
- IvfBuildParams :: new( 2 ) ,
2238
- PQBuildParams :: new( 2 , 8 ) ,
2239
- ) ) ]
2240
- #[ case:: ivf_flat( VectorIndexParams :: ivf_flat( 1 , MetricType :: Dot ) ) ]
2241
- #[ case:: ivf_hnsw_pq( VectorIndexParams :: with_ivf_hnsw_pq_params(
2242
- MetricType :: Dot ,
2243
- IvfBuildParams :: new( 2 ) ,
2244
- HnswBuildParams :: default ( ) . num_edges( 100 ) ,
2245
- PQBuildParams :: new( 2 , 8 )
2246
- ) ) ]
2247
- #[ case:: ivf_hnsw_sq( VectorIndexParams :: with_ivf_hnsw_sq_params(
2248
- MetricType :: Dot ,
2249
- IvfBuildParams :: new( 2 ) ,
2250
- HnswBuildParams :: default ( ) . num_edges( 100 ) ,
2251
- SQBuildParams :: default ( )
2252
- ) ) ]
2261
+ #[ case:: ivf_pq_l2( CreateIndexCase {
2262
+ metric_type: MetricType :: L2 ,
2263
+ num_partitions: 2 ,
2264
+ dimension: 16 ,
2265
+ index_type: TestIndexType :: IvfPq { pq: TestPqParams :: small( ) } ,
2266
+ } ) ]
2267
+ #[ case:: ivf_pq_dot( CreateIndexCase {
2268
+ metric_type: MetricType :: Dot ,
2269
+ num_partitions: 2 ,
2270
+ dimension: 2000 ,
2271
+ index_type: TestIndexType :: IvfPq { pq: TestPqParams :: small( ) } ,
2272
+ } ) ]
2273
+ #[ case:: ivf_flat( CreateIndexCase { num_partitions: 1 , metric_type: MetricType :: Dot , dimension: 16 , index_type: TestIndexType :: IvfFlat } ) ]
2274
+ #[ case:: ivf_hnsw_pq( CreateIndexCase {
2275
+ num_partitions: 2 ,
2276
+ metric_type: MetricType :: Dot ,
2277
+ dimension: 16 ,
2278
+ index_type: TestIndexType :: IvfHnswPq { pq: TestPqParams :: small( ) , num_edges: 100 } ,
2279
+ } ) ]
2280
+ #[ case:: ivf_hnsw_sq( CreateIndexCase {
2281
+ metric_type: MetricType :: Dot ,
2282
+ num_partitions: 2 ,
2283
+ dimension: 16 ,
2284
+ index_type: TestIndexType :: IvfHnswSq { num_edges: 100 } ,
2285
+ } ) ]
2253
2286
async fn test_create_index_nulls (
2254
- #[ case] mut index_params : VectorIndexParams ,
2287
+ #[ case] test_case : CreateIndexCase ,
2255
2288
#[ values( IndexFileVersion :: Legacy , IndexFileVersion :: V3 ) ] index_version : IndexFileVersion ,
2256
2289
) {
2290
+ let mut index_params = match test_case. index_type {
2291
+ TestIndexType :: IvfPq { pq } => VectorIndexParams :: with_ivf_pq_params (
2292
+ test_case. metric_type ,
2293
+ IvfBuildParams :: new ( test_case. num_partitions ) ,
2294
+ PQBuildParams :: new ( pq. num_sub_vectors , pq. num_bits ) ,
2295
+ ) ,
2296
+ TestIndexType :: IvfHnswPq { pq, num_edges } => {
2297
+ VectorIndexParams :: with_ivf_hnsw_pq_params (
2298
+ test_case. metric_type ,
2299
+ IvfBuildParams :: new ( test_case. num_partitions ) ,
2300
+ HnswBuildParams :: default ( ) . num_edges ( num_edges) ,
2301
+ PQBuildParams :: new ( pq. num_sub_vectors , pq. num_bits ) ,
2302
+ )
2303
+ }
2304
+ TestIndexType :: IvfFlat => {
2305
+ VectorIndexParams :: ivf_flat ( test_case. num_partitions , test_case. metric_type )
2306
+ }
2307
+ TestIndexType :: IvfHnswSq { num_edges } => VectorIndexParams :: with_ivf_hnsw_sq_params (
2308
+ test_case. metric_type ,
2309
+ IvfBuildParams :: new ( test_case. num_partitions ) ,
2310
+ HnswBuildParams :: default ( ) . num_edges ( num_edges) ,
2311
+ SQBuildParams :: default ( ) ,
2312
+ ) ,
2313
+ } ;
2257
2314
index_params. version ( index_version) ;
2258
2315
2259
2316
let nrows = 2_000 ;
2260
2317
let data = gen ( )
2261
- . col ( "vec" , array:: rand_vec :: < Float32Type > ( Dimension :: from ( 16 ) ) )
2318
+ . col (
2319
+ "vec" ,
2320
+ array:: rand_vec :: < Float32Type > ( Dimension :: from ( test_case. dimension as u32 ) ) ,
2321
+ )
2262
2322
. into_batch_rows ( RowCount :: from ( nrows) )
2263
2323
. unwrap ( ) ;
2264
2324
@@ -2287,7 +2347,9 @@ mod tests {
2287
2347
. await
2288
2348
. unwrap ( ) ;
2289
2349
2290
- let query = vec ! [ 0.0 ; 16 ] . into_iter ( ) . collect :: < Float32Array > ( ) ;
2350
+ let query = vec ! [ 0.0 ; test_case. dimension]
2351
+ . into_iter ( )
2352
+ . collect :: < Float32Array > ( ) ;
2291
2353
let results = dataset
2292
2354
. scan ( )
2293
2355
. nearest ( "vec" , & query, 2_000 )
0 commit comments