Skip to content

Commit f69480e

Browse files
authored
fix: scalar quantization can't work with NaNs (#3476)
Address potential out-of-range issues when scaling values to `u8` in the `ScalarQuantizer`, and add a test case verifying that the scaling function handles NaN values without panicking. --------- Signed-off-by: BubbleCal <bubble-cal@outlook.com>
1 parent 8f8b630 commit f69480e

File tree

1 file changed

+18
-9
lines changed
  • rust/lance-index/src/vector

1 file changed

+18
-9
lines changed

rust/lance-index/src/vector/sq.rs

+18-9
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ impl ScalarQuantizer {
8888
.as_slice();
8989

9090
self.bounds = data.iter().fold(self.bounds.clone(), |f, v| {
91-
f.start.min(v.to_f64().unwrap())..f.end.max(v.to_f64().unwrap())
91+
f.start.min(v.as_())..f.end.max(v.as_())
9292
});
9393

9494
Ok(self.bounds.clone())
@@ -233,19 +233,17 @@ impl Quantization for ScalarQuantizer {
233233
}
234234

235235
pub(crate) fn scale_to_u8<T: ArrowFloatType>(values: &[T::Native], bounds: &Range<f64>) -> Vec<u8> {
236+
if bounds.start == bounds.end {
237+
return vec![0; values.len()];
238+
}
239+
236240
let range = bounds.end - bounds.start;
237241
values
238242
.iter()
239243
.map(|&v| {
240244
let v = v.to_f64().unwrap();
241-
match v {
242-
v if v < bounds.start => 0,
243-
v if v > bounds.end => 255,
244-
_ => ((v - bounds.start) * f64::from_u32(255).unwrap() / range)
245-
.round()
246-
.to_u8()
247-
.unwrap(),
248-
}
245+
let v = ((v - bounds.start) * 255.0 / range).round();
246+
v as u8 // Rust's float-to-int `as` cast saturates (values below 0 become 0, values above 255 become 255, NaN becomes 0), so it's safe and expected here
249247
})
250248
.collect_vec()
251249
}
@@ -350,4 +348,15 @@ mod tests {
350348
assert_eq!(*v, (i * 17) as u8,);
351349
});
352350
}
351+
352+
#[tokio::test]
353+
async fn test_scale_to_u8_with_nan() {
354+
let values = vec![0.0, 1.0, 2.0, 3.0, f64::NAN];
355+
let bounds = Range::<f64> {
356+
start: 0.0,
357+
end: 3.0,
358+
};
359+
let u8_values = scale_to_u8::<Float64Type>(&values, &bounds);
360+
assert_eq!(u8_values, vec![0, 85, 170, 255, 0]);
361+
}
353362
}

0 commit comments

Comments
 (0)