Fix IndexIVFFastScan reconstruct_from_offset method (#4095)
Summary:
Resolves issue #4089 - IndexIVFPQFastScan crashes with certain nlist values

The `reconstruct_from_offset` method in `IndexIVFFastScan` was incorrectly reconstructing vectors, causing crashes when the `nlist` parameter was not byte-aligned (e.g. 100 instead of 256).

The root cause was that the `list_no` (Voronoi cell number) was not encoded into the `code` vector before the vector was passed to `sa_decode`. As a result, `sa_decode` read invalid `list_no` values, which in some cases triggered the assertion failure `'list_no >= 0 && list_no < nlist'`.
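
To illustrate the failure mode, here is a minimal Python sketch, not the Faiss source. It assumes the list number is stored as a little-endian prefix of `coarse_code_size()` bytes and that decoding checks `0 <= list_no < nlist`, as the assertion message suggests. The function names mirror the Faiss methods, but they are standalone stand-ins written for illustration.

```python
def coarse_code_size(nlist: int) -> int:
    """Bytes needed to store any list number in [0, nlist) (assumed layout)."""
    nl = nlist - 1
    nbyte = 0
    while nl > 0:
        nbyte += 1
        nl >>= 8
    return nbyte


def encode_listno(list_no: int, nlist: int) -> bytes:
    """Return the little-endian prefix that holds list_no."""
    return list_no.to_bytes(coarse_code_size(nlist), "little")


def decode_listno(code: bytes, nlist: int) -> int:
    """Read the prefix back and apply the range check from the assertion."""
    list_no = int.from_bytes(code[:coarse_code_size(nlist)], "little")
    if not 0 <= list_no < nlist:
        raise ValueError(f"list_no {list_no} out of range for nlist={nlist}")
    return list_no


# Hypothetical buffer that holds only PQ bytes, with no list_no prefix.
pq_only = bytes([0xF3, 0x21])

# With nlist=256 the first byte always passes the range check,
# even though it is a PQ byte and not a real list number.
print(decode_listno(pq_only, 256))

# With nlist=100 the same byte is out of range and is rejected.
try:
    decode_listno(pq_only, 100)
except ValueError as err:
    print("rejected:", err)
```

Under these assumptions, a buffer without the prefix has its first PQ byte misread as the list number. That would explain why the crash depends on `nlist`: with 256 lists every byte value is in range, so the error stays silent, while with 100 lists any byte of 100 or above fails the check.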

This PR fixes the issue with the following changes to `reconstruct_from_offset`:

1. Encode the `list_no` into the beginning of the `code` vector using the existing `encode_listno` method
2. Start the `BitstringWriter` after the coarse code portion of `code` (shifted by `coarse_code_size()` bytes)
3. Remove the residual centroid addition logic, as it is already handled in `sa_decode`
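
The buffer layout produced by steps 1 and 2 can be sketched in Python as follows. This is a hedged model and not the Faiss implementation: it assumes a little-endian `list_no` prefix and LSB-first bit packing of the PQ sub-codes, and the names `pack_codes` and `build_sa_code` are hypothetical helpers introduced only for this example.

```python
def pack_codes(sub_codes, nbits):
    """Pack integers LSB-first, nbits each, into a byte string (assumed order)."""
    out = bytearray((len(sub_codes) * nbits + 7) // 8)
    pos = 0  # current bit position in the output
    for c in sub_codes:
        for b in range(nbits):
            if (c >> b) & 1:
                out[(pos + b) >> 3] |= 1 << ((pos + b) & 7)
        pos += nbits
    return bytes(out)


def build_sa_code(list_no, nlist, sub_codes, nbits=4):
    """Return prefix + packed PQ codes, the layout a standalone decoder expects."""
    # Bytes needed for any list number in [0, nlist).
    prefix_len = ((nlist - 1).bit_length() + 7) // 8
    prefix = list_no.to_bytes(prefix_len, "little")
    # The PQ codes start after the prefix, mirroring the shifted writer in step 2.
    return prefix + pack_codes(sub_codes, nbits)


code = build_sa_code(list_no=7, nlist=100, sub_codes=[1, 2, 3, 15])
print(code.hex())
```

In this model the decoder can recover the list number from the prefix and look up the centroid itself, which is why the caller no longer needs to add the centroid separately (step 3).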

After these changes:
- Crashes no longer occur for any `nlist` value
- Reconstruction is now correct, matching the output of `IndexIVFPQ`

Fixes #4089

Please review and let me know if any changes are needed. Thanks!

Pull Request resolved: #4095

Reviewed By: asadoughi

Differential Revision: D67937160

Pulled By: mdouze

fbshipit-source-id: 4705106ba49c01c788b3c75c39c2260615f45764
alisafaya authored and facebook-github-bot committed Jan 14, 2025
1 parent b9fe1dc commit 86fa0db
Showing 3 changed files with 40 additions and 12 deletions.
20 changes: 8 additions & 12 deletions faiss/IndexIVFFastScan.cpp
@@ -1353,34 +1353,30 @@ void IndexIVFFastScan::reconstruct_from_offset(
int64_t offset,
float* recons) const {
// unpack codes
size_t coarse_size = coarse_code_size();
std::vector<uint8_t> code(coarse_size + code_size, 0);
encode_listno(list_no, code.data());
InvertedLists::ScopedCodes list_codes(invlists, list_no);
std::vector<uint8_t> code(code_size, 0);
BitstringWriter bsw(code.data(), code_size);
BitstringWriter bsw(code.data() + coarse_size, code_size);

for (size_t m = 0; m < M; m++) {
uint8_t c =
pq4_get_packed_element(list_codes.get(), bbs, M2, offset, m);
bsw.write(c, nbits);
}
sa_decode(1, code.data(), recons);

// add centroid to it
if (by_residual) {
std::vector<float> centroid(d);
quantizer->reconstruct(list_no, centroid.data());
for (int i = 0; i < d; ++i) {
recons[i] += centroid[i];
}
}
sa_decode(1, code.data(), recons);
}

void IndexIVFFastScan::reconstruct_orig_invlists() {
FAISS_THROW_IF_NOT(orig_invlists != nullptr);
FAISS_THROW_IF_NOT(orig_invlists->list_size(0) == 0);

#pragma omp parallel for if (nlist > 100)
for (size_t list_no = 0; list_no < nlist; list_no++) {
InvertedLists::ScopedCodes codes(invlists, list_no);
InvertedLists::ScopedIds ids(invlists, list_no);
size_t list_size = orig_invlists->list_size(list_no);
size_t list_size = invlists->list_size(list_no);
std::vector<uint8_t> code(code_size, 0);

for (size_t offset = 0; offset < list_size; offset++) {
1 change: 1 addition & 0 deletions faiss/IndexIVFPQFastScan.cpp
@@ -76,6 +76,7 @@ IndexIVFPQFastScan::IndexIVFPQFastScan(const IndexIVFPQ& orig, int bbs)
precomputed_table.nbytes());
}

#pragma omp parallel for if (nlist > 100)
for (size_t i = 0; i < nlist; i++) {
size_t nb = orig.invlists->list_size(i);
size_t nb2 = roundup(nb, bbs);
31 changes: 31 additions & 0 deletions tests/test_fast_scan_ivf.py
@@ -543,6 +543,37 @@ def test_by_residual_odd_dim(self):
self.do_test(by_residual=True, d=30)


class TestReconstruct(unittest.TestCase):

def do_test(self, by_residual=False):
d = 32
metric = faiss.METRIC_L2

ds = datasets.SyntheticDataset(d, 2000, 5000, 200)

index = faiss.IndexIVFPQFastScan(faiss.IndexFlatL2(d), d, 50, d // 2, 4, metric)
index.by_residual = by_residual
index.make_direct_map(True)
index.train(ds.get_train())
index.add(ds.get_database())

# Test reconstruction
index.reconstruct(123) # single id
index.reconstruct_n(123, 10) # range of 10 ids starting at 123
index.reconstruct_batch(np.arange(10))

# Test original list reconstruction
index.orig_invlists = faiss.ArrayInvertedLists(index.nlist, index.code_size)
index.reconstruct_orig_invlists()
assert index.orig_invlists.compute_ntotal() == index.ntotal

def test_no_residual(self):
self.do_test(by_residual=False)

def test_by_residual(self):
self.do_test(by_residual=True)


class TestIsTrained(unittest.TestCase):

def test_issue_2019(self):
