Commit ebd7efc

more tests

Signed-off-by: BubbleCal <bubble-cal@outlook.com>

1 parent de30b35 commit ebd7efc

File tree

5 files changed: +96 -46 lines changed

rust/lance-index/src/vector.rs

+1

```diff
@@ -49,6 +49,7 @@ pub const INDEX_UUID_COLUMN: &str = "__index_uuid";
 pub const PART_ID_COLUMN: &str = "__ivf_part_id";
 pub const PQ_CODE_COLUMN: &str = "__pq_code";
 pub const SQ_CODE_COLUMN: &str = "__sq_code";
+pub const LOSS_METADATA_KEY: &str = "_loss";
 
 lazy_static! {
     pub static ref VECTOR_RESULT_SCHEMA: arrow_schema::SchemaRef =
```
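For context on how this key is consumed downstream: the loss travels with each RecordBatch as schema-level metadata under `"_loss"` and is parsed back out as an `f64`. The `add_metadata` call seen later in this commit is a lance helper; the sketch below shows the same round trip with plain arrow-rs (`Schema::with_metadata`), and the `residual` column name is invented purely for illustration.

```rust
use std::{collections::HashMap, sync::Arc};

use arrow_array::{Float32Array, RecordBatch};
use arrow_schema::{DataType, Field, Schema};

const LOSS_METADATA_KEY: &str = "_loss";

fn main() {
    // Attach the loss by rebuilding the schema with an extra metadata entry.
    let loss: f64 = 42.5;
    let mut metadata = HashMap::new();
    metadata.insert(LOSS_METADATA_KEY.to_owned(), loss.to_string());
    let schema = Arc::new(
        Schema::new(vec![Field::new("residual", DataType::Float32, false)])
            .with_metadata(metadata),
    );

    let batch = RecordBatch::try_new(
        schema,
        vec![Arc::new(Float32Array::from(vec![0.1_f32, 0.2]))],
    )
    .unwrap();

    // Read it back the way the shuffler does: parse the string value as f64,
    // falling back to 0.0 when the key is missing or unparsable.
    let recovered: f64 = batch
        .schema()
        .metadata()
        .get(LOSS_METADATA_KEY)
        .map(|s| s.parse().unwrap_or_default())
        .unwrap_or_default();
    assert_eq!(recovered, loss);
}
```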

rust/lance-index/src/vector/ivf/transform.rs

+6 -4

```diff
@@ -20,13 +20,15 @@ use lance_linalg::distance::DistanceType;
 use lance_linalg::kmeans::compute_partitions_arrow_array;
 
 use crate::vector::transform::Transformer;
+use crate::vector::LOSS_METADATA_KEY;
 
 use super::PART_ID_COLUMN;
 
 /// PartitionTransformer
 ///
 /// It computes the partition ID for each row from the input batch,
-/// and adds the partition ID as a new column to the batch.
+/// and adds the partition ID as a new column to the batch,
+/// and adds the loss as a metadata to the batch.
 ///
 /// If the partition ID ("__ivf_part_id") column is already present in the Batch,
 /// this transform is a Noop.
@@ -75,7 +77,7 @@ impl Transformer for PartitionTransformer {
             .column_by_name(&self.input_column)
             .ok_or_else(|| lance_core::Error::Index {
                 message: format!(
-                    "IvfTransformer: column {} not found in the RecordBatch",
+                    "PartitionTransformer: column {} not found in the RecordBatch",
                     self.input_column
                 ),
                 location: location!(),
@@ -85,7 +87,7 @@ impl Transformer for PartitionTransformer {
             .as_fixed_size_list_opt()
            .ok_or_else(|| lance_core::Error::Index {
                 message: format!(
-                    "IvfTransformer: column {} is not a FixedSizeListArray: {}",
+                    "PartitionTransformer: column {} is not a FixedSizeListArray: {}",
                     self.input_column,
                     arr.data_type(),
                 ),
@@ -98,7 +100,7 @@ impl Transformer for PartitionTransformer {
         let field = Field::new(PART_ID_COLUMN, part_ids.data_type().clone(), true);
         Ok(batch
             .try_with_column(field, Arc::new(part_ids))?
-            .add_metadata("loss".to_owned(), loss.to_string())?)
+            .add_metadata(LOSS_METADATA_KEY.to_owned(), loss.to_string())?)
     }
 }
```
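Conceptually, the partition step (`compute_partitions_arrow_array` in the diff) assigns each vector to its nearest centroid and reports the accumulated distance as the loss. The toy scalar version below illustrates that idea only; it is not the real kernel, which operates on Arrow FixedSizeListArrays and is distance-type aware.

```rust
/// Toy illustration of the partition step: assign each vector to its
/// nearest centroid (squared L2) and accumulate the total distance as the loss.
fn compute_partitions(vectors: &[Vec<f32>], centroids: &[Vec<f32>]) -> (Vec<u32>, f64) {
    let mut total_loss = 0.0;
    let part_ids = vectors
        .iter()
        .map(|v| {
            let (best, dist) = centroids
                .iter()
                .enumerate()
                .map(|(i, c)| {
                    // Squared L2 distance between the vector and this centroid.
                    let d: f32 = v.iter().zip(c).map(|(a, b)| (a - b) * (a - b)).sum();
                    (i as u32, d)
                })
                .min_by(|a, b| a.1.total_cmp(&b.1))
                .unwrap();
            total_loss += dist as f64;
            best
        })
        .collect();
    (part_ids, total_loss)
}

fn main() {
    let centroids = vec![vec![0.0, 0.0], vec![10.0, 10.0]];
    let vectors = vec![vec![1.0, 1.0], vec![9.0, 11.0]];
    let (ids, loss) = compute_partitions(&vectors, &centroids);
    assert_eq!(ids, vec![0, 1]);
    println!("partition ids: {ids:?}, loss: {loss}");
}
```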

rust/lance-index/src/vector/v3/shuffler.rs

+29 -3

```diff
@@ -31,7 +31,7 @@ use object_store::path::Path;
 use snafu::location;
 use tokio::sync::Mutex;
 
-use crate::vector::PART_ID_COLUMN;
+use crate::vector::{LOSS_METADATA_KEY, PART_ID_COLUMN};
 
 #[async_trait::async_trait]
 /// A reader that can read the shuffled partitions.
@@ -46,6 +46,12 @@ pub trait ShuffleReader: Send + Sync {
 
     /// Get the size of the partition by partition_id
     fn partition_size(&self, partition_id: usize) -> Result<usize>;
+
+    /// Get the total loss,
+    /// if the loss is not available, return None,
+    /// in such case, the caller should sum up the losses from each batch's metadata.
+    /// Must be called after all partitions are read.
+    fn total_loss(&self) -> Option<f64>;
 }
 
 #[async_trait::async_trait]
@@ -105,6 +111,12 @@ impl Shuffler for IvfShuffler {
             spawn_cpu(move || {
                 let batch = batch?;
 
+                let loss = batch
+                    .metadata()
+                    .get(LOSS_METADATA_KEY)
+                    .map(|s| s.parse::<f64>().unwrap_or_default())
+                    .unwrap_or_default();
+
                 let part_ids: &UInt32Array = batch
                     .column_by_name(PART_ID_COLUMN)
                     .expect("Partition ID column not found")
@@ -134,7 +146,7 @@ impl Shuffler for IvfShuffler {
                     start = end;
                 }
 
-                Ok::<Vec<Vec<RecordBatch>>, Error>(partition_buffers)
+                Ok::<(Vec<Vec<RecordBatch>>, f64), Error>((partition_buffers, loss))
             })
         })
         .buffered(get_num_compute_intensive_cpus());
@@ -146,8 +158,10 @@ impl Shuffler for IvfShuffler {
             .collect::<Vec<_>>();
 
         let mut counter = 0;
+        let mut total_loss = 0.0;
         while let Some(shuffled) = parallel_sort_stream.next().await {
-            let shuffled = shuffled?;
+            let (shuffled, loss) = shuffled?;
+            total_loss += loss;
 
             for (part_id, batches) in shuffled.into_iter().enumerate() {
                 let part_batches = &mut partition_buffers[part_id];
@@ -218,6 +232,7 @@ impl Shuffler for IvfShuffler {
             self.object_store.clone(),
             self.output_dir.clone(),
             partition_sizes,
+            total_loss,
         )))
     }
 }
@@ -226,20 +241,23 @@ pub struct IvfShufflerReader {
     scheduler: Arc<ScanScheduler>,
     output_dir: Path,
     partition_sizes: Vec<usize>,
+    loss: f64,
 }
 
 impl IvfShufflerReader {
     pub fn new(
         object_store: Arc<ObjectStore>,
         output_dir: Path,
         partition_sizes: Vec<usize>,
+        loss: f64,
     ) -> Self {
         let scheduler_config = SchedulerConfig::max_bandwidth(&object_store);
         let scheduler = ScanScheduler::new(object_store, scheduler_config);
         Self {
             scheduler,
             output_dir,
             partition_sizes,
+            loss,
         }
     }
 }
@@ -275,6 +293,10 @@ impl ShuffleReader for IvfShufflerReader {
     fn partition_size(&self, partition_id: usize) -> Result<usize> {
         Ok(self.partition_sizes[partition_id])
     }
+
+    fn total_loss(&self) -> Option<f64> {
+        Some(self.loss)
+    }
 }
 
 pub struct SinglePartitionReader {
@@ -311,4 +333,8 @@ impl ShuffleReader for SinglePartitionReader {
         // so we just return 1 here
         Ok(1)
     }
+
+    fn total_loss(&self) -> Option<f64> {
+        None
+    }
 }
```
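Per the new trait doc, `total_loss` may return `None` (as `SinglePartitionReader` does here), in which case the caller is expected to sum the per-batch losses itself. A minimal sketch of that fallback, with the helper name `sum_batch_losses` invented for illustration; the real caller (`take_partition_batches` in builder.rs below) accumulates inline while streaming:

```rust
use arrow_array::RecordBatch;

const LOSS_METADATA_KEY: &str = "_loss";

/// Hypothetical fallback for readers whose `total_loss()` returns `None`:
/// sum the "_loss" metadata entry across all batches read from the reader.
/// Batches without the key (or with an unparsable value) contribute 0.0.
fn sum_batch_losses(batches: &[RecordBatch]) -> f64 {
    batches
        .iter()
        .map(|batch| {
            batch
                .schema()
                .metadata()
                .get(LOSS_METADATA_KEY)
                .and_then(|s| s.parse::<f64>().ok())
                .unwrap_or_default()
        })
        .sum()
}

fn main() {
    // No batches read yet: the fallback contributes nothing.
    assert_eq!(sum_batch_losses(&[]), 0.0);
}
```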

rust/lance/src/index/vector/builder.rs

+29 -13

```diff
@@ -24,7 +24,7 @@ use lance_index::vector::quantizer::{
 use lance_index::vector::storage::STORAGE_METADATA_KEY;
 use lance_index::vector::v3::shuffler::IvfShufflerReader;
 use lance_index::vector::v3::subindex::SubIndexType;
-use lance_index::vector::{VectorIndex, PART_ID_FIELD};
+use lance_index::vector::{VectorIndex, LOSS_METADATA_KEY, PART_ID_FIELD};
 use lance_index::{
     pb,
     vector::{
@@ -451,6 +451,7 @@ impl<S: IvfSubIndex + 'static, Q: Quantization + 'static> IvfIndexBuilder<S, Q>
                 Arc::new(self.store.clone()),
                 self.temp_dir.clone(),
                 vec![0; ivf.num_partitions()],
+                0.0,
             )));
             return Ok(self);
         }
@@ -474,7 +475,7 @@ impl<S: IvfSubIndex + 'static, Q: Quantization + 'static> IvfIndexBuilder<S, Q>
             "dataset not set before building partitions",
             location!(),
         ))?;
-        let ivf = self.ivf.as_ref().ok_or(Error::invalid_input(
+        let ivf = self.ivf.as_mut().ok_or(Error::invalid_input(
             "IVF not set before building partitions",
             location!(),
         ))?;
@@ -503,22 +504,22 @@ impl<S: IvfSubIndex + 'static, Q: Quantization + 'static> IvfIndexBuilder<S, Q>
 
         let dataset = Arc::new(dataset.clone());
         let reader = reader.clone();
-        let ivf = Arc::new(ivf.clone());
+        let ivf_model = Arc::new(ivf.clone());
         let existing_indices = Arc::new(self.existing_indices.clone());
         let distance_type = self.distance_type;
-        let mut partition_sizes = vec![(0, 0); ivf.num_partitions()];
+        let mut partition_sizes = vec![(0, 0); ivf_model.num_partitions()];
         let build_iter = partition_build_order.iter().map(|&partition| {
             let dataset = dataset.clone();
             let reader = reader.clone();
             let existing_indices = existing_indices.clone();
             let column = self.column.clone();
             let store = self.store.clone();
             let temp_dir = self.temp_dir.clone();
-            let ivf = ivf.clone();
+            let ivf = ivf_model.clone();
             let quantizer = quantizer.clone();
             let sub_index_params = sub_index_params.clone();
             async move {
-                let batches = Self::take_partition_batches(
+                let (batches, loss) = Self::take_partition_batches(
                     partition,
                     existing_indices.as_ref(),
                     reader.as_ref(),
@@ -530,7 +531,7 @@ impl<S: IvfSubIndex + 'static, Q: Quantization + 'static> IvfIndexBuilder<S, Q>
 
                 let num_rows = batches.iter().map(|b| b.num_rows()).sum::<usize>();
                 if num_rows == 0 {
-                    return Ok((0, 0));
+                    return Ok(((0, 0), 0.0));
                 }
                 let batch = arrow::compute::concat_batches(&batches[0].schema(), batches.iter())?;
 
@@ -545,6 +546,7 @@ impl<S: IvfSubIndex + 'static, Q: Quantization + 'static> IvfIndexBuilder<S, Q>
                     partition,
                 )
                 .await
+                .map(|res| (res, loss))
             }
         });
         let results = stream::iter(build_iter)
@@ -553,9 +555,15 @@ impl<S: IvfSubIndex + 'static, Q: Quantization + 'static> IvfIndexBuilder<S, Q>
             .boxed()
             .await?;
 
-        for (i, result) in results.into_iter().enumerate() {
-            partition_sizes[partition_build_order[i]] = result;
+        let mut total_loss = 0.0;
+        for (i, (res, loss)) in results.into_iter().enumerate() {
+            total_loss += loss;
+            partition_sizes[partition_build_order[i]] = res;
+        }
+        if let Some(loss) = reader.total_loss() {
+            total_loss += loss;
         }
+        ivf.loss = Some(total_loss);
 
         self.partition_sizes = partition_sizes;
         Ok(self)
@@ -617,7 +625,7 @@ impl<S: IvfSubIndex + 'static, Q: Quantization + 'static> IvfIndexBuilder<S, Q>
         dataset: &Arc<Dataset>,
         column: &str,
         store: &ObjectStore,
-    ) -> Result<Vec<RecordBatch>> {
+    ) -> Result<(Vec<RecordBatch>, f64)> {
         let mut batches = Vec::new();
         for existing_index in existing_indices.iter() {
             let existing_index = existing_index
@@ -648,15 +656,23 @@ impl<S: IvfSubIndex + 'static, Q: Quantization + 'static> IvfIndexBuilder<S, Q>
             batches.extend(part_batches);
         }
 
+        let mut loss = 0.0;
         if reader.partition_size(part_id)? > 0 {
-            let partition_data = reader.read_partition(part_id).await?.ok_or(Error::io(
+            let mut partition_data = reader.read_partition(part_id).await?.ok_or(Error::io(
                 format!("partition {} is empty", part_id).as_str(),
                 location!(),
             ))?;
-            batches.extend(partition_data.try_collect::<Vec<_>>().await?);
+            while let Some(batch) = partition_data.try_next().await? {
+                loss += batch
+                    .metadata()
+                    .get(LOSS_METADATA_KEY)
+                    .map(|s| s.parse::<f64>().unwrap_or(0.0))
+                    .unwrap_or(0.0);
+                batches.push(batch);
+            }
         }
 
-        Ok(batches)
+        Ok((batches, loss))
     }
 
     async fn merge_partitions(&mut self) -> Result<()> {
```
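Taken together, the loss stored on the IVF model is the sum of the per-partition losses recovered from batch metadata plus whatever the shuffle reader tracked itself (zero when the reader returns `None`). A stand-alone restatement of that aggregation, with `partition_losses` and `reader_loss` as stand-in inputs:

```rust
/// Mirrors the aggregation in build_partitions above: per-partition losses
/// plus the reader's own total; a `None` reader contributes nothing.
fn aggregate_loss(partition_losses: &[f64], reader_loss: Option<f64>) -> f64 {
    partition_losses.iter().sum::<f64>() + reader_loss.unwrap_or(0.0)
}

fn main() {
    assert_eq!(aggregate_loss(&[1.5, 2.5], Some(4.0)), 8.0);
    assert_eq!(aggregate_loss(&[1.5, 2.5], None), 4.0);
}
```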

rust/lance/src/index/vector/ivf/v2.rs

+31 -26

```diff
@@ -614,8 +614,8 @@ mod tests {
     use arrow::datatypes::{UInt64Type, UInt8Type};
     use arrow::{array::AsArray, datatypes::Float32Type};
     use arrow_array::{
-        Array, ArrayRef, ArrowPrimitiveType, FixedSizeListArray, ListArray, RecordBatch,
-        RecordBatchIterator, UInt64Array,
+        Array, ArrayRef, ArrowNativeTypeOp, ArrowPrimitiveType, FixedSizeListArray, ListArray,
+        RecordBatch, RecordBatchIterator, UInt64Array,
     };
     use arrow_buffer::OffsetBuffer;
     use arrow_schema::{DataType, Field, Schema, SchemaRef};
@@ -704,7 +704,7 @@ mod tests {
     where
         T::Native: SampleUniform,
     {
-        const VECTOR_NUM_PER_ROW: usize = 5;
+        const VECTOR_NUM_PER_ROW: usize = 3;
         let start_id = start_id.unwrap_or(0);
         let ids = Arc::new(UInt64Array::from_iter_values(
             start_id..start_id + num_rows as u64,
@@ -717,32 +717,20 @@ mod tests {
         let data_type = vectors.data_type().clone();
         let mut fields = vec![Field::new("id", DataType::UInt64, false)];
         let mut arrays: Vec<ArrayRef> = vec![ids];
-        let mut fsl = FixedSizeListArray::try_new_from_values(vectors, DIM as i32).unwrap();
-        if data_type != DataType::UInt8 {
-            fsl = lance_linalg::kernels::normalize_fsl(&fsl).unwrap();
-        }
+        let fsl = FixedSizeListArray::try_new_from_values(vectors, DIM as i32).unwrap();
         if is_multivector {
+            let vector_field = Arc::new(Field::new(
+                "item",
+                DataType::FixedSizeList(Arc::new(Field::new("item", data_type, true)), DIM as i32),
+                true,
+            ));
             fields.push(Field::new(
                 "vector",
-                DataType::List(Arc::new(Field::new(
-                    "item",
-                    DataType::FixedSizeList(
-                        Arc::new(Field::new("item", data_type.clone(), true)),
-                        DIM as i32,
-                    ),
-                    true,
-                ))),
+                DataType::List(vector_field.clone()),
                 true,
             ));
             let array = Arc::new(ListArray::new(
-                Arc::new(Field::new(
-                    "item",
-                    DataType::FixedSizeList(
-                        Arc::new(Field::new("item", data_type, true)),
-                        DIM as i32,
-                    ),
-                    true,
-                )),
+                vector_field,
                 OffsetBuffer::from_lengths(std::iter::repeat(VECTOR_NUM_PER_ROW).take(num_rows)),
                 Arc::new(fsl),
                 None,
@@ -978,7 +966,7 @@ mod tests {
         params: VectorIndexParams,
         range: Range<T::Native>,
     ) where
-        T::Native: SampleUniform + std::ops::Add<Output = T::Native>,
+        T::Native: SampleUniform,
     {
         let test_dir = tempdir().unwrap();
         let test_uri = test_dir.path().to_str().unwrap();
@@ -1019,7 +1007,10 @@ mod tests {
         let mut count = 0;
         // append more rows and make delta index until hitting the retrain threshold
         loop {
-            let range = range.start..range.end + range.end + range.end + range.end + range.end;
+            let range = match count {
+                0 => range.clone(),
+                _ => range.end.neg_wrapping().sub_wrapping(range.end)..range.end.neg_wrapping(),
+            };
             append_dataset::<T>(&mut dataset, 500, range).await;
             dataset
                 .optimize_indices(&OptimizeOptions {
@@ -1032,8 +1023,22 @@ mod tests {
 
             let new_avg_loss = get_avg_loss(&dataset).await;
             if new_avg_loss / original_avg_loss >= *AVG_LOSS_RETRAIN_THRESHOLD {
+                if count <= 1 {
+                    // the first append is with the same data distribution, so the loss should be
+                    // very close to the original loss, then it shouldn't hit the retrain threshold
+                    panic!(
+                        "retrain threshold {} should not be hit",
+                        *AVG_LOSS_RETRAIN_THRESHOLD
+                    );
+                }
                 break;
             }
+            if count >= 10 {
+                panic!(
+                    "failed to hit the retrain threshold {}",
+                    *AVG_LOSS_RETRAIN_THRESHOLD
+                );
+            }
 
             // all delta indices should have the same centroids as the original index
             let ivf_models = get_ivf_models(&dataset).await;
@@ -1052,7 +1057,7 @@ mod tests {
             .await
             .unwrap();
         let stats = dataset.index_statistics("vector_idx").await.unwrap();
-        let stats = serde_json::to_value(stats).unwrap();
+        let stats: serde_json::Value = serde_json::from_str(&stats).unwrap();
         assert_eq!(stats["num_indices"], 1);
 
         let ivf_models = get_ivf_models(&dataset).await;
```
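The shifted range in the test loop is the interesting bit: after the first append, the data moves to `-2*end..-end` via the wrapping arithmetic from `ArrowNativeTypeOp` (hence the new import), so the appended vectors are disjoint from the distribution the centroids were trained on and the average loss climbs past the retrain threshold. A quick check of that arithmetic, assuming `i32` and an original range of `0..end` for concreteness:

```rust
use arrow_array::ArrowNativeTypeOp;

fn main() {
    // end.neg_wrapping().sub_wrapping(end) .. end.neg_wrapping()
    // evaluates to -2*end .. -end: fully disjoint from 0..end, so the
    // stale centroids fit the new data poorly and the loss grows.
    let end: i32 = 100;
    let shifted = end.neg_wrapping().sub_wrapping(end)..end.neg_wrapping();
    assert_eq!(shifted, -200..-100);
}
```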
