From 86fa0db34ec4aa3b8d63d16fa96663899fd37d65 Mon Sep 17 00:00:00 2001 From: Ali Safaya Date: Tue, 14 Jan 2025 13:51:02 -0800 Subject: [PATCH] Fix IndexIVFFastScan reconstruct_from_offset method (#4095) Summary: Resolves issue https://github.com/facebookresearch/faiss/issues/4089 - IndexIVFPQFastScan crashes with certain nlist values The `reconstruct_from_offset` method in `IndexIVFFastScan` was incorrectly reconstructing vectors, causing crashes when the `nlist` parameter was not byte-aligned (e.g. 100 instead of 256). The root cause was that the `list_no` (Voronoi cell number) was not being properly encoded into the `code` vector before passing it to the `sa_decode` function. This resulted in invalid `list_no` values being read in `sa_decode`, triggering the assertion failure `'list_no >= 0 && list_no < nlist'` for some values of `nlist`. This PR fixes the issue with the following changes to `reconstruct_from_offset`: 1. Encode the `list_no` into the beginning of the `code` vector using the existing `encode_listno` method 2. Start the `BitstringWriter` after the coarse code portion of `code` (shifted by `coarse_code_size()` bytes) 3. Remove the residual centroid addition logic, as it is already handled in `sa_decode` After these changes: - Crashes no longer occur for any `nlist` value - Reconstruction is now correct, matching the output of `IndexIVFPQ` Fixes https://github.com/facebookresearch/faiss/issues/4089 Please review and let me know if any changes are needed. Thanks! 
Pull Request resolved: https://github.com/facebookresearch/faiss/pull/4095 Reviewed By: asadoughi Differential Revision: D67937160 Pulled By: mdouze fbshipit-source-id: 4705106ba49c01c788b3c75c39c2260615f45764 --- faiss/IndexIVFFastScan.cpp | 20 ++++++++------------ faiss/IndexIVFPQFastScan.cpp | 1 + tests/test_fast_scan_ivf.py | 31 +++++++++++++++++++++++++++++++ 3 files changed, 40 insertions(+), 12 deletions(-) diff --git a/faiss/IndexIVFFastScan.cpp b/faiss/IndexIVFFastScan.cpp index f95ad354a7..2b6d7abc19 100644 --- a/faiss/IndexIVFFastScan.cpp +++ b/faiss/IndexIVFFastScan.cpp @@ -1353,34 +1353,30 @@ void IndexIVFFastScan::reconstruct_from_offset( int64_t offset, float* recons) const { // unpack codes + size_t coarse_size = coarse_code_size(); + std::vector code(coarse_size + code_size, 0); + encode_listno(list_no, code.data()); InvertedLists::ScopedCodes list_codes(invlists, list_no); - std::vector code(code_size, 0); - BitstringWriter bsw(code.data(), code_size); + BitstringWriter bsw(code.data() + coarse_size, code_size); + for (size_t m = 0; m < M; m++) { uint8_t c = pq4_get_packed_element(list_codes.get(), bbs, M2, offset, m); bsw.write(c, nbits); } - sa_decode(1, code.data(), recons); - // add centroid to it - if (by_residual) { - std::vector centroid(d); - quantizer->reconstruct(list_no, centroid.data()); - for (int i = 0; i < d; ++i) { - recons[i] += centroid[i]; - } - } + sa_decode(1, code.data(), recons); } void IndexIVFFastScan::reconstruct_orig_invlists() { FAISS_THROW_IF_NOT(orig_invlists != nullptr); FAISS_THROW_IF_NOT(orig_invlists->list_size(0) == 0); +#pragma omp parallel for if (nlist > 100) for (size_t list_no = 0; list_no < nlist; list_no++) { InvertedLists::ScopedCodes codes(invlists, list_no); InvertedLists::ScopedIds ids(invlists, list_no); - size_t list_size = orig_invlists->list_size(list_no); + size_t list_size = invlists->list_size(list_no); std::vector code(code_size, 0); for (size_t offset = 0; offset < list_size; offset++) { diff 
--git a/faiss/IndexIVFPQFastScan.cpp b/faiss/IndexIVFPQFastScan.cpp index 9d1cdfcae3..c1fd206ee2 100644 --- a/faiss/IndexIVFPQFastScan.cpp +++ b/faiss/IndexIVFPQFastScan.cpp @@ -76,6 +76,7 @@ IndexIVFPQFastScan::IndexIVFPQFastScan(const IndexIVFPQ& orig, int bbs) precomputed_table.nbytes()); } +#pragma omp parallel for if (nlist > 100) for (size_t i = 0; i < nlist; i++) { size_t nb = orig.invlists->list_size(i); size_t nb2 = roundup(nb, bbs); diff --git a/tests/test_fast_scan_ivf.py b/tests/test_fast_scan_ivf.py index 55de784ad6..63327e14c0 100644 --- a/tests/test_fast_scan_ivf.py +++ b/tests/test_fast_scan_ivf.py @@ -543,6 +543,37 @@ def test_by_residual_odd_dim(self): self.do_test(by_residual=True, d=30) +class TestReconstruct(unittest.TestCase): + + def do_test(self, by_residual=False): + d = 32 + metric = faiss.METRIC_L2 + + ds = datasets.SyntheticDataset(d, 2000, 5000, 200) + + index = faiss.IndexIVFPQFastScan(faiss.IndexFlatL2(d), d, 50, d // 2, 4, metric) + index.by_residual = by_residual + index.make_direct_map(True) + index.train(ds.get_train()) + index.add(ds.get_database()) + + # Test reconstruction + index.reconstruct(123) # single id + index.reconstruct_n(123, 10) # single id + index.reconstruct_batch(np.arange(10)) + + # Test original list reconstruction + index.orig_invlists = faiss.ArrayInvertedLists(index.nlist, index.code_size) + index.reconstruct_orig_invlists() + assert index.orig_invlists.compute_ntotal() == index.ntotal + + def test_no_residual(self): + self.do_test(by_residual=False) + + def test_by_residual(self): + self.do_test(by_residual=True) + + class TestIsTrained(unittest.TestCase): def test_issue_2019(self):