Skip to content

Commit

Permalink
3893 - Fix index factory order of idmap and refinement (#3928)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #3928

Fix issue in T203425107

Reviewed By: asadoughi

Differential Revision: D64068971

fbshipit-source-id: 56db439793539570a102773ff2c7158d48feb7a9
  • Loading branch information
gtwang01 authored and facebook-github-bot committed Oct 9, 2024
1 parent c5aed7c commit af70c5b
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 23 deletions.
36 changes: 18 additions & 18 deletions faiss/index_factory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -679,6 +679,24 @@ std::unique_ptr<Index> index_factory_sub(
// for the current match
std::smatch sm;

// IndexIDMap -- it turns out is was used both as a prefix and a suffix, so
// support both
if (re_match(description, "(.+),IDMap2", sm) ||
re_match(description, "IDMap2,(.+)", sm)) {
IndexIDMap2* idmap2 = new IndexIDMap2(
index_factory_sub(d, sm[1].str(), metric).release());
idmap2->own_fields = true;
return std::unique_ptr<Index>(idmap2);
}

if (re_match(description, "(.+),IDMap", sm) ||
re_match(description, "IDMap,(.+)", sm)) {
IndexIDMap* idmap = new IndexIDMap(
index_factory_sub(d, sm[1].str(), metric).release());
idmap->own_fields = true;
return std::unique_ptr<Index>(idmap);
}

// handle refines
if (re_match(description, "(.+),RFlat", sm) ||
re_match(description, "(.+),Refine\\((.+)\\)", sm)) {
Expand Down Expand Up @@ -755,24 +773,6 @@ std::unique_ptr<Index> index_factory_sub(
d);
}

// IndexIDMap -- it turns out is was used both as a prefix and a suffix, so
// support both
if (re_match(description, "(.+),IDMap2", sm) ||
re_match(description, "IDMap2,(.+)", sm)) {
IndexIDMap2* idmap2 = new IndexIDMap2(
index_factory_sub(d, sm[1].str(), metric).release());
idmap2->own_fields = true;
return std::unique_ptr<Index>(idmap2);
}

if (re_match(description, "(.+),IDMap", sm) ||
re_match(description, "IDMap,(.+)", sm)) {
IndexIDMap* idmap = new IndexIDMap(
index_factory_sub(d, sm[1].str(), metric).release());
idmap->own_fields = true;
return std::unique_ptr<Index>(idmap);
}

{ // handle basic index types
Index* index = parse_other_indexes(description, d, metric);
if (index) {
Expand Down
24 changes: 19 additions & 5 deletions tests/test_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from faiss.contrib import factory_tools
from faiss.contrib import datasets


class TestFactory(unittest.TestCase):

def test_factory_1(self):
Expand Down Expand Up @@ -40,7 +41,6 @@ def test_factory_2(self):
index = faiss.index_factory(12, "SQ8")
assert index.code_size == 12


def test_factory_3(self):

index = faiss.index_factory(12, "IVF10,PQ4")
Expand Down Expand Up @@ -73,7 +73,8 @@ def test_factory_HNSW(self):
def test_factory_HNSW_newstyle(self):
index = faiss.index_factory(12, "HNSW32,Flat")
assert index.storage.sa_code_size() == 12 * 4
index = faiss.index_factory(12, "HNSW32,SQ8", faiss.METRIC_INNER_PRODUCT)
index = faiss.index_factory(12, "HNSW32,SQ8",
faiss.METRIC_INNER_PRODUCT)
assert index.storage.sa_code_size() == 12
assert index.metric_type == faiss.METRIC_INNER_PRODUCT
index = faiss.index_factory(12, "HNSW,PQ4")
Expand Down Expand Up @@ -131,7 +132,8 @@ def test_factory_fast_scan(self):
self.assertEqual(index.pq.nbits, 4)
index = faiss.index_factory(56, "PQ28x4fs_64")
self.assertEqual(index.bbs, 64)
index = faiss.index_factory(56, "IVF50,PQ28x4fs_64", faiss.METRIC_INNER_PRODUCT)
index = faiss.index_factory(56, "IVF50,PQ28x4fs_64",
faiss.METRIC_INNER_PRODUCT)
self.assertEqual(index.bbs, 64)
self.assertEqual(index.nlist, 50)
self.assertTrue(index.cp.spherical)
Expand All @@ -158,7 +160,6 @@ def test_parenthesis_refine(self):
self.assertEqual(rf.pq.M, 25)
self.assertEqual(rf.pq.nbits, 12)


def test_parenthesis_refine_2(self):
# Refine applies on the whole index including pre-transforms
index = faiss.index_factory(50, "PCA32,IVF32,Flat,Refine(PQ25x12)")
Expand Down Expand Up @@ -264,6 +265,19 @@ def test_idmap2_prefix(self):
index = faiss.downcast_index(index)
self.assertEqual(index.__class__, faiss.IndexIDMap2)

def test_idmap_refine(self):
index = faiss.index_factory(8, "IDMap,PQ4x4fs,RFlat")
self.assertEqual(index.__class__, faiss.IndexIDMap)
refine_index = faiss.downcast_index(index.index)
self.assertEqual(refine_index.__class__, faiss.IndexRefineFlat)
base_index = faiss.downcast_index(refine_index.base_index)
self.assertEqual(base_index.__class__, faiss.IndexPQFastScan)

# Index now works with add_with_ids, but not with add
index.train(np.zeros((16, 8)))
index.add_with_ids(np.zeros((16, 8)), np.arange(16))
self.assertRaises(RuntimeError, index.add, np.zeros((16, 8)))

def test_ivf_hnsw(self):
index = faiss.index_factory(123, "IVF100_HNSW,Flat")
quantizer = faiss.downcast_index(index.quantizer)
Expand Down Expand Up @@ -337,4 +351,4 @@ def test_replace_vt(self):
index = faiss.IndexIVFSpectralHash(faiss.IndexFlat(10), 10, 20, 10, 1)
index.replace_vt(faiss.ITQTransform(10, 10))
gc.collect()
index.vt.d_out # this should not crash
index.vt.d_out # this should not crash

0 comments on commit af70c5b

Please sign in to comment.