-
Notifications
You must be signed in to change notification settings - Fork 265
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat: support 4bit PQ on new IVF_PQ #3144
Conversation
Signed-off-by: BubbleCal <bubble-cal@outlook.com>
Signed-off-by: BubbleCal <bubble-cal@outlook.com>
Signed-off-by: BubbleCal <bubble-cal@outlook.com>
Codecov ReportAttention: Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## main #3144 +/- ##
==========================================
+ Coverage 77.93% 77.94% +0.01%
==========================================
Files 242 242
Lines 81736 81904 +168
Branches 81736 81904 +168
==========================================
+ Hits 63698 63840 +142
- Misses 14849 14891 +42
+ Partials 3189 3173 -16
Flags with carried forward coverage won't be shown. Click here to find out more. ☔ View full report in Codecov by Sentry. 🚨 Try these New Features:
|
@@ -38,6 +38,9 @@ use super::DISTANCE_TYPE_KEY; | |||
/// </section> | |||
pub trait DistCalculator { | |||
fn distance(&self, id: u32) -> f32; | |||
fn distance_all(&self) -> Vec<f32> { | |||
unimplemented!("Implement this") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is there a GH issue for this?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's for IVF_PQ only for now, can implement this for all distance calculator
.zip(0..storage.len() as u32) | ||
.map(|(dist, id)| OrderedNode { | ||
id, | ||
dist: OrderedFloat(dist), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
OrderedFloat
is very slow, can we use f32::total_cmp
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We already fixed it, orderedfloat uses total cmp
.iter() | ||
.zip(distances.iter_mut()) | ||
.for_each(|(¢roid_idx, sum)| { | ||
// for 4bit PQ, `centroid_idx` is 2 index, each index is 4bit. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
How do we make sure the registers are not swap out in such case, giving we have very high PQ, i.e., PQ=96 for dim=768. Do we have enough registers?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
let's keep track of this as a potential bottleneck and optimize later? We could probably use SIMD types to force generate better code
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yeah I was thinking about this, calculating over single distance table may lead to better cache locality here instead of processing 2 at the same time.
for high PQ, it doesn't matter because the distance table is 2D matrix of shape (num_sub_vector, num_centroids), for the same sub vector, it's always 16 length (2^4)
@@ -109,6 +109,15 @@ impl ProductQuantizer { | |||
let num_sub_vectors = self.num_sub_vectors; | |||
let dim = self.dimension; | |||
let num_bits = self.num_bits; | |||
if num_bits == 4 && num_sub_vectors % 2 != 0 { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: can we pad instead of requireing this? not a big deal as I think it's almost always a multiple of 2
}) | ||
.collect::<Vec<_>>() | ||
.collect::<Vec<_>>(); | ||
if num_bits == 4 { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This feels like a potential bottleneck. Let's have a ticket tracking these and maybe turn it in to generics?
} | ||
} | ||
|
||
fn distance_all(&self) -> Vec<f32> { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
🔥
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let's add some benchmarking soon? This is amazing!
just added a ticket to track all potential bottlenecks |
also introduces
distance_all
methods to distance calculator, to improve the new IVF_PQ search performancewe store the 4bit PQ codes in half-bytes