perf: implement XTR for retrieving multivector #3437
Conversation
Signed-off-by: BubbleCal <bubble-cal@outlook.com>
Codecov Report

Attention: Patch coverage is

Additional details and impacted files:

```
@@           Coverage Diff            @@
##             main    #3437    +/-   ##
=========================================
- Coverage   78.48%   78.48%   -0.01%
=========================================
  Files         252      252
  Lines       94011    94220     +209
  Branches    94011    94220     +209
=========================================
+ Hits        73783    73947     +164
- Misses      17232    17279      +47
+ Partials     2996     2994       -2
```

Flags with carried forward coverage won't be shown. View full report in Codecov by Sentry.
Signed-off-by: BubbleCal <bubble-cal@outlook.com>
Something seems off in the algorithm, in how `missed_similarities` is handled. Could you address my comment, and also maybe write a unit test that shows we get correct results out of this?
```rust
let row_ids = batch[ROW_ID].as_primitive::<UInt64Type>();
let dists = batch[DIST_COL].as_primitive::<Float32Type>();
```
Since we are using values here, can we add a debug assert that there are no nulls?
added
```rust
// at most, we will have k * refine_factor results for each query
let mut results = HashMap::with_capacity(k * refactor);
let mut missed_similarities = 0.0;
while let Some((min_sim, batch)) = reduced_inputs.try_next().await? {
```
I don't understand the algorithm in the paper deeply, but it seems odd to me that the order of the ANN queries matters. It appears that later batches will be adding a higher `missed_similarities` value. Is that intentional?

It looks like the output order of `select_all` isn't deterministic: https://docs.rs/futures/latest/futures/stream/fn.select_all.html
it's intentional; this is a little bit complicated, will add more comments about this.

Considering we are updating the final `results` with a `batch` from a query vector, for a row `x`:

- if `x` exists in `results` but not in `batch`: take `min_sim` as the estimated similarity, so the contribution is `min_sim`
- if `x` exists in both, then the contribution is `sim` in `batch`
- if `x` exists in only `batch`, this means all queries before missed this row; this algo maintains `missed_similarities` as the sum of `min_sim` so far, so the contribution is `missed_similarities + sim`
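The three cases can be sketched as an order-insensitive merge. This is a minimal illustration with hypothetical names (`merge_xtr_scores`, plain `HashMap`s standing in for the streamed batches), not the PR's actual code:

```rust
use std::collections::HashMap;

// Each batch is one query vector's ANN result: (min_sim, row -> sim).
// Merging follows the three rules described above.
fn merge_xtr_scores(batches: &[(f32, HashMap<u64, f32>)]) -> HashMap<u64, f32> {
    let mut results: HashMap<u64, f32> = HashMap::new();
    // sum of `min_sim` over all batches merged so far
    let mut missed_similarities = 0.0_f32;
    for (min_sim, batch) in batches {
        // rows already in `results` but missed by this batch get `min_sim`
        for (row, score) in results.iter_mut() {
            if !batch.contains_key(row) {
                *score += *min_sim;
            }
        }
        for (&row, &sim) in batch {
            results
                .entry(row)
                // row seen again: add this batch's similarity
                .and_modify(|s| *s += sim)
                // row missed by every earlier batch: back-fill their `min_sim` sum
                .or_insert(missed_similarities + sim);
        }
        missed_similarities += *min_sim;
    }
    results
}

fn main() {
    let batch1 = (2.0_f32, HashMap::from([(1_u64, 2.0_f32), (2, 3.0)]));
    let batch2 = (0.5_f32, HashMap::from([(2_u64, 1.0_f32), (3, 0.5)]));
    // the merged scores do not depend on which batch arrives first
    let forward = merge_xtr_scores(&[batch1.clone(), batch2.clone()]);
    let backward = merge_xtr_scores(&[batch2, batch1]);
    assert_eq!(forward, backward);
    println!("{:?}", forward);
}
```

Because every row is eventually credited with either its real `sim` or that batch's `min_sim` for each query vector, the final totals come out the same in any arrival order, which is what makes the non-deterministic `select_all` ordering safe here.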
Ah I see. It makes more sense now that I know what part was missing.
we have tests here https://github.com/lancedb/lance/pull/3437/files#diff-6de816b72e7c722316243c57df4f809ad34dc8581367c72335154dada48c40edL993
Signed-off-by: BubbleCal <bubble-cal@outlook.com>
I meant more tests that the XTR algorithm itself was working as expected. Part of why I'm having a hard time understanding this PR is that there are no tests showing the expected behavior of the algorithm.
```rust
// for a row `r`:
// if `r` is in only `results`, then `results[r] += min_sim`
// if `r` is in only `query_results`, then `results[r] = query_results[r] + missed_similarities`,
// here `missed_similarities` is the sum of `min_sim` from previous iterations
// if `r` is in both, then `results[r] += query_results[r]`
```
I'm still having trouble understanding the logic.
So let's imagine batch 1 is:
- rowid: 1, distance: 2.0
- rowid: 2, distance: 3.0
And batch 2 is:
- rowid: 2, distance: 1.0
- rowid: 3, distance: 0.5
If batch 1 comes in first, then we should get:
- rowid 1, distance = 2.0 + 2.0 = 4.0 (distance in batch 1, plus min sim from batch 1)
- rowid 2, distance = 3.0 + 1.0 = 4.0 (sum of distances across batches)
- rowid 3, distance = 0.0 + 0.5 = 0.5 (min distance plus 0.5)
But if batch 2 comes in first, then we will get:
- rowid 1, distance = 2.0 + 0.5 = 2.5 (min distance of batch 2, plus distance in batch 1)
- rowid 2, distance = 1.0 + 3.0 = 4.0
- rowid 3, distance = 0.5 + 0.5 = 1.0
rowid 1 and rowid 3 seem to get completely different distances depending on the order in which the queries finish. And there doesn't seem to be any definition of which queries finish first.
`min_sim` is from the current batch; `missed_similarities` is the sum of the `min_sim` of all batches before.
so if batch 1 comes first:
- rowid 1, sim = 2.0 + 0.5 = 2.5 (sim in batch 1 + min_sim in batch 2)
- rowid 2, sim = 3.0 + 1.0 = 4.0 (sum of sim)
- rowid 3, sim = 2.0 + 0.5 = 2.5 (missed_sim_sum (min_sim in batch 1) + sim in batch 2)
if batch 2 comes first:
- rowid 1, sim = 0.5 + 2.0 = 2.5 (missed_sim_sum (min_sim in batch 2) + sim in batch 1)
- rowid 2, sim = 3.0 + 1.0 = 4.0
- rowid 3, sim = 0.5 + 2.0 = 2.5 (sim in batch 2 + min_sim in batch 1)
yeah it's very complicated... I should add some tests to verify that the scores come out the same in any order
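The numbers from the example can be replayed mechanically. Here is a minimal self-contained sketch (the `score` helper is illustrative, not the PR's test code) checking that both arrival orders produce identical results:

```rust
use std::collections::HashMap;

// Fold batches of (min_sim, row -> sim) using the rule from the comment above.
fn score(order: &[(f32, &HashMap<u64, f32>)]) -> HashMap<u64, f32> {
    let mut results: HashMap<u64, f32> = HashMap::new();
    let mut missed = 0.0_f32; // sum of earlier batches' min_sim
    for &(min_sim, batch) in order {
        // rows known so far but absent from this batch: estimate with min_sim
        for (row, s) in results.iter_mut() {
            if !batch.contains_key(row) {
                *s += min_sim;
            }
        }
        // rows in this batch: add sim; new rows back-fill the missed sum
        for (&row, &sim) in batch {
            *results.entry(row).or_insert(missed) += sim;
        }
        missed += min_sim;
    }
    results
}

fn main() {
    // batch 1: min_sim 2.0; batch 2: min_sim 0.5 (values from the discussion)
    let batch1 = HashMap::from([(1_u64, 2.0_f32), (2, 3.0)]);
    let batch2 = HashMap::from([(2_u64, 1.0_f32), (3, 0.5)]);
    let fwd = score(&[(2.0, &batch1), (0.5, &batch2)]);
    let rev = score(&[(0.5, &batch2), (2.0, &batch1)]);
    assert_eq!(fwd, rev);
    // rowid 1 -> 2.5, rowid 2 -> 4.0, rowid 3 -> 2.5, in either order
    assert_eq!(fwd.get(&1), Some(&2.5));
    assert_eq!(fwd.get(&2), Some(&4.0));
    assert_eq!(fwd.get(&3), Some(&2.5));
}
```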
Oh I think I see now. It wasn't clear initially which parts run per batch and which state is carried across batches.
Signed-off-by: BubbleCal <bubble-cal@outlook.com>
Signed-off-by: BubbleCal <bubble-cal@outlook.com>
```rust
.collect::<Vec<_>>();

let mut res: Option<HashMap<_, _>> = None;
for perm in batches.into_iter().permutations(3) {
```
I'm not sure whether `select_all` produces the items in the order of the streams when all of them are ready, but 6 runs should be enough to verify it even if it produces them in random order.
this PR introduces XTR, which can score the documents without the original multivectors, so we don't need any IO ops when searching on multivector columns.
it sets the minimum similarity as the estimated similarity for the documents missed by a single query vector.
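Viewed per document rather than per batch, the estimate is: for each query vector, use the retrieved similarity if that query vector retrieved the document, otherwise impute that query vector's minimum retrieved similarity. A minimal sketch under those assumptions (`xtr_estimate` is a hypothetical name, not the PR's API):

```rust
use std::collections::HashMap;

// `retrieved` holds one (min_sim, doc -> sim) entry per query vector.
// A doc missed by a query vector contributes that vector's min_sim,
// so no original multivectors need to be fetched to score it.
fn xtr_estimate(retrieved: &[(f32, HashMap<u64, f32>)], doc: u64) -> f32 {
    retrieved
        .iter()
        .map(|(min_sim, hits)| *hits.get(&doc).unwrap_or(min_sim))
        .sum()
}

fn main() {
    let retrieved = vec![
        (2.0_f32, HashMap::from([(1_u64, 2.0_f32), (2, 3.0)])),
        (0.5_f32, HashMap::from([(2_u64, 1.0_f32), (3, 0.5)])),
    ];
    // doc 2 was retrieved by both query vectors: 3.0 + 1.0
    assert_eq!(xtr_estimate(&retrieved, 2), 4.0);
    // doc 3 was missed by the first query vector: 2.0 (imputed) + 0.5
    assert_eq!(xtr_estimate(&retrieved, 3), 2.5);
}
```

This per-document view gives the same totals as the streaming merge discussed above, which is why the batch arrival order doesn't matter.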