Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: detect the drift and retrain the index if hit threshold #3489

Open
wants to merge 22 commits into
base: main
Choose a base branch
from
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
feat: record loss for IVF and KMeans
Signed-off-by: BubbleCal <bubble-cal@outlook.com>
BubbleCal committed Feb 28, 2025
commit be1974d41e4d72a7b73a1e6e5adc0504224b6e01
14 changes: 14 additions & 0 deletions rust/lance-index/src/vector/ivf/storage.rs
Original file line number Diff line number Diff line change
@@ -34,6 +34,9 @@ pub struct IvfModel {

/// Number of vectors in each partition.
pub lengths: Vec<u32>,

/// Kmeans loss
pub loss: Option<f64>,
}

impl DeepSizeOf for IvfModel {
@@ -53,6 +56,7 @@ impl IvfModel {
centroids: None,
offsets: vec![],
lengths: vec![],
loss: None,
}
}

@@ -61,6 +65,7 @@ impl IvfModel {
centroids: Some(centroids),
offsets: vec![],
lengths: vec![],
loss: None,
}
}

@@ -88,6 +93,14 @@ impl IvfModel {
self.lengths[part] as usize
}

pub fn num_rows(&self) -> u64 {
self.lengths.iter().map(|x| *x as u64).sum()
}

pub fn avg_loss(&self) -> Option<f64> {
self.loss.map(|loss| loss / self.num_rows() as f64)
}

/// Use the query vector to find `nprobes` closest partitions.
pub fn find_partitions(
&self,
@@ -215,6 +228,7 @@ impl TryFrom<PbIvf> for IvfModel {
centroids,
offsets,
lengths: proto.lengths,
loss: None,
})
}
}
13 changes: 13 additions & 0 deletions rust/lance-linalg/src/kmeans.rs
Original file line number Diff line number Diff line change
@@ -103,6 +103,9 @@ pub struct KMeans {

/// How to calculate distance between two vectors.
pub distance_type: DistanceType,

/// The loss of the last training.
pub loss: f64,
}

/// Randomly initialize kmeans centroids.
@@ -127,6 +130,7 @@ fn kmeans_random_init<T: ArrowPrimitiveType>(
centroids: Arc::new(centroids),
dimension,
distance_type,
loss: f64::MAX,
}
}

@@ -191,6 +195,7 @@ pub trait KMeansAlgo<T: Num> {
k: usize,
membership: &[Option<u32>],
distance_type: DistanceType,
loss: f64,
) -> KMeans;
}

@@ -245,6 +250,7 @@ where
k: usize,
membership: &[Option<u32>],
distance_type: DistanceType,
loss: f64,
) -> KMeans {
let mut cluster_cnts = vec![0_u64; k];
let mut new_centroids = vec![T::Native::zero(); k * dimension];
@@ -293,6 +299,7 @@ where
centroids: Arc::new(PrimitiveArray::<T>::from_iter_values(new_centroids)),
dimension,
distance_type,
loss,
}
}
}
@@ -337,6 +344,7 @@ impl KMeansAlgo<u8> for KModeAlgo {
k: usize,
membership: &[Option<u32>],
distance_type: DistanceType,
loss: f64,
) -> KMeans {
assert_eq!(distance_type, DistanceType::Hamming);

@@ -379,6 +387,7 @@ impl KMeansAlgo<u8> for KModeAlgo {
centroids: Arc::new(UInt8Array::from(centroids)),
dimension,
distance_type,
loss,
}
}
}
@@ -389,6 +398,7 @@ impl KMeans {
centroids: arrow_array::array::new_empty_array(&DataType::Float32),
dimension,
distance_type,
loss: f64::MAX,
}
}

@@ -398,6 +408,7 @@ impl KMeans {
centroids: ArrayRef,
dimension: usize,
distance_type: DistanceType,
loss: f64,
) -> Self {
assert!(matches!(
centroids.data_type(),
@@ -407,6 +418,7 @@ impl KMeans {
centroids,
dimension,
distance_type,
loss,
}
}

@@ -496,6 +508,7 @@ impl KMeans {
k,
&membership,
params.distance_type,
last_loss,
);
last_membership = Some(membership);
if (loss - last_loss).abs() / last_loss < params.tolerance {
1 change: 1 addition & 0 deletions rust/lance/src/index/vector/ivf/v2.rs
Original file line number Diff line number Diff line change
@@ -781,6 +781,7 @@ mod tests {
let gt_set = gt.iter().map(|r| r.1).collect::<HashSet<_>>();

let recall = row_ids.intersection(&gt_set).count() as f32 / k as f32;
println!("recall: {}", recall);
assert!(
recall >= recall_requirement,
"recall: {}\n results: {:?}\n\ngt: {:?}",

Unchanged files with check annotations Beta

.map(|task| task.load_and_remap(reader.clone(), index, mapping))
.buffered(object_store.io_parallelism());
let mut ivf = IvfModel {

Check failure on line 1402 in rust/lance/src/index/vector/ivf.rs

GitHub Actions / linux-arm

missing field `loss` in initializer of `IvfModel`

Check failure on line 1402 in rust/lance/src/index/vector/ivf.rs

GitHub Actions / linux-build (stable)

missing field `loss` in initializer of `IvfModel`
centroids: index.ivf.centroids.clone(),
offsets: Vec::with_capacity(index.ivf.offsets.len()),
lengths: Vec::with_capacity(index.ivf.lengths.len()),