@@ -24,7 +24,7 @@ use lance_index::vector::quantizer::{
24
24
use lance_index:: vector:: storage:: STORAGE_METADATA_KEY ;
25
25
use lance_index:: vector:: v3:: shuffler:: IvfShufflerReader ;
26
26
use lance_index:: vector:: v3:: subindex:: SubIndexType ;
27
- use lance_index:: vector:: { VectorIndex , PART_ID_FIELD } ;
27
+ use lance_index:: vector:: { VectorIndex , LOSS_METADATA_KEY , PART_ID_FIELD } ;
28
28
use lance_index:: {
29
29
pb,
30
30
vector:: {
@@ -451,6 +451,7 @@ impl<S: IvfSubIndex + 'static, Q: Quantization + 'static> IvfIndexBuilder<S, Q>
451
451
Arc :: new ( self . store . clone ( ) ) ,
452
452
self . temp_dir . clone ( ) ,
453
453
vec ! [ 0 ; ivf. num_partitions( ) ] ,
454
+ 0.0 ,
454
455
) ) ) ;
455
456
return Ok ( self ) ;
456
457
}
@@ -474,7 +475,7 @@ impl<S: IvfSubIndex + 'static, Q: Quantization + 'static> IvfIndexBuilder<S, Q>
474
475
"dataset not set before building partitions" ,
475
476
location ! ( ) ,
476
477
) ) ?;
477
- let ivf = self . ivf . as_ref ( ) . ok_or ( Error :: invalid_input (
478
+ let ivf = self . ivf . as_mut ( ) . ok_or ( Error :: invalid_input (
478
479
"IVF not set before building partitions" ,
479
480
location ! ( ) ,
480
481
) ) ?;
@@ -503,22 +504,22 @@ impl<S: IvfSubIndex + 'static, Q: Quantization + 'static> IvfIndexBuilder<S, Q>
503
504
504
505
let dataset = Arc :: new ( dataset. clone ( ) ) ;
505
506
let reader = reader. clone ( ) ;
506
- let ivf = Arc :: new ( ivf. clone ( ) ) ;
507
+ let ivf_model = Arc :: new ( ivf. clone ( ) ) ;
507
508
let existing_indices = Arc :: new ( self . existing_indices . clone ( ) ) ;
508
509
let distance_type = self . distance_type ;
509
- let mut partition_sizes = vec ! [ ( 0 , 0 ) ; ivf . num_partitions( ) ] ;
510
+ let mut partition_sizes = vec ! [ ( 0 , 0 ) ; ivf_model . num_partitions( ) ] ;
510
511
let build_iter = partition_build_order. iter ( ) . map ( |& partition| {
511
512
let dataset = dataset. clone ( ) ;
512
513
let reader = reader. clone ( ) ;
513
514
let existing_indices = existing_indices. clone ( ) ;
514
515
let column = self . column . clone ( ) ;
515
516
let store = self . store . clone ( ) ;
516
517
let temp_dir = self . temp_dir . clone ( ) ;
517
- let ivf = ivf . clone ( ) ;
518
+ let ivf = ivf_model . clone ( ) ;
518
519
let quantizer = quantizer. clone ( ) ;
519
520
let sub_index_params = sub_index_params. clone ( ) ;
520
521
async move {
521
- let batches = Self :: take_partition_batches (
522
+ let ( batches, loss ) = Self :: take_partition_batches (
522
523
partition,
523
524
existing_indices. as_ref ( ) ,
524
525
reader. as_ref ( ) ,
@@ -530,7 +531,7 @@ impl<S: IvfSubIndex + 'static, Q: Quantization + 'static> IvfIndexBuilder<S, Q>
530
531
531
532
let num_rows = batches. iter ( ) . map ( |b| b. num_rows ( ) ) . sum :: < usize > ( ) ;
532
533
if num_rows == 0 {
533
- return Ok ( ( 0 , 0 ) ) ;
534
+ return Ok ( ( ( 0 , 0 ) , 0. 0) ) ;
534
535
}
535
536
let batch = arrow:: compute:: concat_batches ( & batches[ 0 ] . schema ( ) , batches. iter ( ) ) ?;
536
537
@@ -545,6 +546,7 @@ impl<S: IvfSubIndex + 'static, Q: Quantization + 'static> IvfIndexBuilder<S, Q>
545
546
partition,
546
547
)
547
548
. await
549
+ . map ( |res| ( res, loss) )
548
550
}
549
551
} ) ;
550
552
let results = stream:: iter ( build_iter)
@@ -553,9 +555,15 @@ impl<S: IvfSubIndex + 'static, Q: Quantization + 'static> IvfIndexBuilder<S, Q>
553
555
. boxed ( )
554
556
. await ?;
555
557
556
- for ( i, result) in results. into_iter ( ) . enumerate ( ) {
557
- partition_sizes[ partition_build_order[ i] ] = result;
558
+ let mut total_loss = 0.0 ;
559
+ for ( i, ( res, loss) ) in results. into_iter ( ) . enumerate ( ) {
560
+ total_loss += loss;
561
+ partition_sizes[ partition_build_order[ i] ] = res;
562
+ }
563
+ if let Some ( loss) = reader. total_loss ( ) {
564
+ total_loss += loss;
558
565
}
566
+ ivf. loss = Some ( total_loss) ;
559
567
560
568
self . partition_sizes = partition_sizes;
561
569
Ok ( self )
@@ -617,7 +625,7 @@ impl<S: IvfSubIndex + 'static, Q: Quantization + 'static> IvfIndexBuilder<S, Q>
617
625
dataset : & Arc < Dataset > ,
618
626
column : & str ,
619
627
store : & ObjectStore ,
620
- ) -> Result < Vec < RecordBatch > > {
628
+ ) -> Result < ( Vec < RecordBatch > , f64 ) > {
621
629
let mut batches = Vec :: new ( ) ;
622
630
for existing_index in existing_indices. iter ( ) {
623
631
let existing_index = existing_index
@@ -648,15 +656,23 @@ impl<S: IvfSubIndex + 'static, Q: Quantization + 'static> IvfIndexBuilder<S, Q>
648
656
batches. extend ( part_batches) ;
649
657
}
650
658
659
+ let mut loss = 0.0 ;
651
660
if reader. partition_size ( part_id) ? > 0 {
652
- let partition_data = reader. read_partition ( part_id) . await ?. ok_or ( Error :: io (
661
+ let mut partition_data = reader. read_partition ( part_id) . await ?. ok_or ( Error :: io (
653
662
format ! ( "partition {} is empty" , part_id) . as_str ( ) ,
654
663
location ! ( ) ,
655
664
) ) ?;
656
- batches. extend ( partition_data. try_collect :: < Vec < _ > > ( ) . await ?) ;
665
+ while let Some ( batch) = partition_data. try_next ( ) . await ? {
666
+ loss += batch
667
+ . metadata ( )
668
+ . get ( LOSS_METADATA_KEY )
669
+ . map ( |s| s. parse :: < f64 > ( ) . unwrap_or ( 0.0 ) )
670
+ . unwrap_or ( 0.0 ) ;
671
+ batches. push ( batch) ;
672
+ }
657
673
}
658
674
659
- Ok ( batches)
675
+ Ok ( ( batches, loss ) )
660
676
}
661
677
662
678
async fn merge_partitions ( & mut self ) -> Result < ( ) > {
0 commit comments