From af70c5bcce07eb419a7e6d33830f78b1e8413465 Mon Sep 17 00:00:00 2001 From: George Wang Date: Tue, 8 Oct 2024 17:18:13 -0700 Subject: [PATCH] 3893 - Fix index factory order of idmap and refinement (#3928) Summary: Pull Request resolved: https://github.com/facebookresearch/faiss/pull/3928 Fix issue in T203425107 Reviewed By: asadoughi Differential Revision: D64068971 fbshipit-source-id: 56db439793539570a102773ff2c7158d48feb7a9 --- faiss/index_factory.cpp | 36 ++++++++++++++++++------------------ tests/test_factory.py | 24 +++++++++++++++++++----- 2 files changed, 37 insertions(+), 23 deletions(-) diff --git a/faiss/index_factory.cpp b/faiss/index_factory.cpp index cc57fda4de..092df879bf 100644 --- a/faiss/index_factory.cpp +++ b/faiss/index_factory.cpp @@ -679,6 +679,24 @@ std::unique_ptr index_factory_sub( // for the current match std::smatch sm; + // IndexIDMap -- it turns out is was used both as a prefix and a suffix, so + // support both + if (re_match(description, "(.+),IDMap2", sm) || + re_match(description, "IDMap2,(.+)", sm)) { + IndexIDMap2* idmap2 = new IndexIDMap2( + index_factory_sub(d, sm[1].str(), metric).release()); + idmap2->own_fields = true; + return std::unique_ptr(idmap2); + } + + if (re_match(description, "(.+),IDMap", sm) || + re_match(description, "IDMap,(.+)", sm)) { + IndexIDMap* idmap = new IndexIDMap( + index_factory_sub(d, sm[1].str(), metric).release()); + idmap->own_fields = true; + return std::unique_ptr(idmap); + } + // handle refines if (re_match(description, "(.+),RFlat", sm) || re_match(description, "(.+),Refine\\((.+)\\)", sm)) { @@ -755,24 +773,6 @@ std::unique_ptr index_factory_sub( d); } - // IndexIDMap -- it turns out is was used both as a prefix and a suffix, so - // support both - if (re_match(description, "(.+),IDMap2", sm) || - re_match(description, "IDMap2,(.+)", sm)) { - IndexIDMap2* idmap2 = new IndexIDMap2( - index_factory_sub(d, sm[1].str(), metric).release()); - idmap2->own_fields = true; - return std::unique_ptr(idmap2); - } - - if (re_match(description, "(.+),IDMap", sm) || - re_match(description, "IDMap,(.+)", sm)) { - IndexIDMap* idmap = new IndexIDMap( - index_factory_sub(d, sm[1].str(), metric).release()); - idmap->own_fields = true; - return std::unique_ptr(idmap); - } - { // handle basic index types Index* index = parse_other_indexes(description, d, metric); if (index) { diff --git a/tests/test_factory.py b/tests/test_factory.py index 220ba77660..f16a60e772 100644 --- a/tests/test_factory.py +++ b/tests/test_factory.py @@ -12,6 +12,7 @@ from faiss.contrib import factory_tools from faiss.contrib import datasets + class TestFactory(unittest.TestCase): def test_factory_1(self): @@ -40,7 +41,6 @@ def test_factory_2(self): index = faiss.index_factory(12, "SQ8") assert index.code_size == 12 - def test_factory_3(self): index = faiss.index_factory(12, "IVF10,PQ4") @@ -73,7 +73,8 @@ def test_factory_HNSW(self): def test_factory_HNSW_newstyle(self): index = faiss.index_factory(12, "HNSW32,Flat") assert index.storage.sa_code_size() == 12 * 4 - index = faiss.index_factory(12, "HNSW32,SQ8", faiss.METRIC_INNER_PRODUCT) + index = faiss.index_factory(12, "HNSW32,SQ8", + faiss.METRIC_INNER_PRODUCT) assert index.storage.sa_code_size() == 12 assert index.metric_type == faiss.METRIC_INNER_PRODUCT index = faiss.index_factory(12, "HNSW,PQ4") @@ -131,7 +132,8 @@ def test_factory_fast_scan(self): self.assertEqual(index.pq.nbits, 4) index = faiss.index_factory(56, "PQ28x4fs_64") self.assertEqual(index.bbs, 64) - index = faiss.index_factory(56, "IVF50,PQ28x4fs_64", faiss.METRIC_INNER_PRODUCT) + index = faiss.index_factory(56, "IVF50,PQ28x4fs_64", + faiss.METRIC_INNER_PRODUCT) self.assertEqual(index.bbs, 64) self.assertEqual(index.nlist, 50) self.assertTrue(index.cp.spherical) @@ -158,7 +160,6 @@ def test_parenthesis_refine(self): self.assertEqual(rf.pq.M, 25) self.assertEqual(rf.pq.nbits, 12) - def test_parenthesis_refine_2(self): # Refine applies on the whole index including pre-transforms index = faiss.index_factory(50, "PCA32,IVF32,Flat,Refine(PQ25x12)") @@ -264,6 +265,19 @@ def test_idmap2_prefix(self): index = faiss.downcast_index(index) self.assertEqual(index.__class__, faiss.IndexIDMap2) + def test_idmap_refine(self): + index = faiss.index_factory(8, "IDMap,PQ4x4fs,RFlat") + self.assertEqual(index.__class__, faiss.IndexIDMap) + refine_index = faiss.downcast_index(index.index) + self.assertEqual(refine_index.__class__, faiss.IndexRefineFlat) + base_index = faiss.downcast_index(refine_index.base_index) + self.assertEqual(base_index.__class__, faiss.IndexPQFastScan) + + # Index now works with add_with_ids, but not with add + index.train(np.zeros((16, 8))) + index.add_with_ids(np.zeros((16, 8)), np.arange(16)) + self.assertRaises(RuntimeError, index.add, np.zeros((16, 8))) + def test_ivf_hnsw(self): index = faiss.index_factory(123, "IVF100_HNSW,Flat") quantizer = faiss.downcast_index(index.quantizer) @@ -337,4 +351,4 @@ def test_replace_vt(self): index = faiss.IndexIVFSpectralHash(faiss.IndexFlat(10), 10, 20, 10, 1) index.replace_vt(faiss.ITQTransform(10, 10)) gc.collect() - index.vt.d_out # this should not crash + index.vt.d_out # this should not crash