Skip to content

Commit

Permalink
Moved add_sa_codes, sa_code_size to Index, IndexBinary base classes (#…
Browse files Browse the repository at this point in the history
…3989)

Summary:

Moved add_sa_codes and sa_code_size from IndexIVF to the Index and IndexBinary base classes, to support adding coded vectors with ids using an IDMap2,PQ index.

For an alternative approach, see previous attempt with merge_ids and merge_codes: D64941798

Reviewed By: mnorris11

Differential Revision: D64972587
  • Loading branch information
Amir Sadoughi authored and facebook-github-bot committed Oct 28, 2024
1 parent 5539039 commit db28e9b
Show file tree
Hide file tree
Showing 12 changed files with 97 additions and 23 deletions.
33 changes: 13 additions & 20 deletions demos/index_pq_flat_separate_codes_from_codebook.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
#!/usr/bin/env -S grimaldi --kernel bento_kernel_faiss
#!/usr/bin/env -S grimaldi --kernel faiss_binary_local
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
Expand Down Expand Up @@ -39,7 +39,6 @@ def read_ids_codes():


def write_ids_codes(ids, codes):
# print(ids, codes)
np.save("/tmp/ids.npy", ids)
np.save("/tmp/codes.npy", codes.reshape(len(ids), -1))

Expand All @@ -49,46 +48,40 @@ def write_template_index(template_index):


def read_template_index_instance():
pq_index = faiss.read_index("/tmp/template.index")
return pq_index, faiss.IndexIDMap2(pq_index)
return faiss.read_index("/tmp/template.index")

""":py"""
# at train time

template_index = faiss.IndexPQ(d, M, nbits)
template_index = faiss.index_factory(d, f"IDMap2,PQ{M}x{nbits}")
template_index.train(training_data)
write_template_index(template_index)

""":py"""
# New database vector

template_instance_index, id_wrapper_index = read_template_index_instance()
database_vector_id, database_vector_float32 = np.int64(
np.random.rand() * 10000
), np.random.rand(1, d).astype("float32")
index = read_template_index_instance()
database_vector_id, database_vector_float32 = np.random.randint(10000), np.random.rand(1, d).astype(np.float32)
ids, codes = read_ids_codes()
# print(ids, codes)
code = template_instance_index.sa_encode(database_vector_float32)

code = index.index.sa_encode(database_vector_float32)

if ids is not None and codes is not None:
ids = np.concatenate((ids, [database_vector_id]))
codes = np.vstack((codes, code))
else:
ids = np.array([database_vector_id])
codes = np.array([code])

write_ids_codes(ids, codes)

""":py '1545041403561975'"""
""":py '331546060044009'"""
# then at query time
query_vector_float32 = np.random.rand(1, d).astype("float32")
template_index_instance, id_wrapper_index = read_template_index_instance()
query_vector_float32 = np.random.rand(1, d).astype(np.float32)
id_wrapper_index = read_template_index_instance()
ids, codes = read_ids_codes()

for code in codes:
for c in code:
template_index_instance.codes.push_back(int(c))
template_index_instance.ntotal = len(codes)
for i in ids:
id_wrapper_index.id_map.push_back(int(i))
id_wrapper_index.add_sa_codes(codes, ids)

id_wrapper_index.search(query_vector_float32, k=5)

Expand Down
4 changes: 4 additions & 0 deletions faiss/Index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,10 @@ void Index::sa_decode(idx_t, const uint8_t*, float*) const {
FAISS_THROW_MSG("standalone codec not implemented for this type of index");
}

/// Default implementation: always throws. Index types that can accept
/// vectors already encoded with the standalone codec override this.
void Index::add_sa_codes(
        idx_t /* n */,
        const uint8_t* /* codes */,
        const idx_t* /* xids */) {
    FAISS_THROW_MSG("add_sa_codes not implemented for this type of index");
}

namespace {

// storage that explicitly reconstructs vectors before computing distances
Expand Down
7 changes: 7 additions & 0 deletions faiss/Index.h
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,13 @@ struct Index {
* trained in the same way and have the same
* parameters). Otherwise throw. */
virtual void check_compatible_for_merge(const Index& otherIndex) const;

/** Add vectors that are computed with the standalone codec
*
* The base-class implementation throws; index types that support adding
* pre-encoded vectors override it.
*
* @param n number of codes to add
* @param codes codes to add size n * sa_code_size()
* @param xids corresponding ids, size n
*/
virtual void add_sa_codes(idx_t n, const uint8_t* codes, const idx_t* xids);
};

} // namespace faiss
Expand Down
11 changes: 11 additions & 0 deletions faiss/IndexBinary.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -103,4 +103,15 @@ void IndexBinary::check_compatible_for_merge(
FAISS_THROW_MSG("check_compatible_for_merge() not implemented");
}

/// Size of one code in bytes, taken from the code_size member.
size_t IndexBinary::sa_code_size() const {
    const size_t bytes_per_code = code_size;
    return bytes_per_code;
}

/// Adds n codes with their ids by forwarding the arguments unchanged to
/// add_with_ids().
void IndexBinary::add_sa_codes(
        idx_t n,
        const uint8_t* codes,
        const idx_t* xids) {
    this->add_with_ids(n, codes, xids);
}

} // namespace faiss
6 changes: 6 additions & 0 deletions faiss/IndexBinary.h
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,12 @@ struct IndexBinary {
* parameters). Otherwise throw. */
virtual void check_compatible_for_merge(
const IndexBinary& otherIndex) const;

/** Size of the produced codes in bytes.
*
* The base implementation returns code_size.
*/
virtual size_t sa_code_size() const;

/** Same as add_with_ids for IndexBinary.
*
* The base implementation forwards all three arguments to add_with_ids().
*
* @param n number of codes to add
* @param codes codes to add
* @param xids corresponding ids
*/
virtual void add_sa_codes(idx_t n, const uint8_t* codes, const idx_t* xids);
};

} // namespace faiss
Expand Down
9 changes: 9 additions & 0 deletions faiss/IndexFlatCodes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,15 @@ void IndexFlatCodes::add(idx_t n, const float* x) {
ntotal += n;
}

/// Append n already-encoded vectors to the flat code storage.
///
/// @param n         number of codes to append
/// @param codes_in  n * code_size bytes of codes, stored contiguously
/// @param xids      ignored: this index does not store ids itself
void IndexFlatCodes::add_sa_codes(
        idx_t n,
        const uint8_t* codes_in,
        const idx_t* /* xids */) {
    // Nothing to append. Returning early also avoids calling memcpy with a
    // possibly-null source pointer, which is undefined even for 0 bytes.
    if (n == 0) {
        return;
    }
    const size_t old_bytes = ntotal * code_size;
    const size_t added_bytes = n * code_size;
    codes.resize(old_bytes + added_bytes);
    memcpy(codes.data() + old_bytes, codes_in, added_bytes);
    ntotal += n;
}

void IndexFlatCodes::reset() {
codes.clear();
ntotal = 0;
Expand Down
3 changes: 3 additions & 0 deletions faiss/IndexFlatCodes.h
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,9 @@ struct IndexFlatCodes : Index {

virtual void merge_from(Index& otherIndex, idx_t add_id = 0) override;

/** Append n pre-encoded vectors to the stored codes.
*
* @param n number of codes to add
* @param x codes to add, n * code_size bytes
* @param xids ignored by this implementation
*/
virtual void add_sa_codes(idx_t n, const uint8_t* x, const idx_t* xids)
override;

// permute_entries. perm of size ntotal maps new to old positions
void permute_entries(const idx_t* perm);
};
Expand Down
17 changes: 17 additions & 0 deletions faiss/IndexIDMap.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,23 @@ void IndexIDMapTemplate<IndexT>::add_with_ids(
this->ntotal = index->ntotal;
}

/// Code size is whatever the wrapped index reports.
template <typename IndexT>
size_t IndexIDMapTemplate<IndexT>::sa_code_size() const {
    return this->index->sa_code_size();
}

/// Add n pre-encoded vectors to the wrapped index and record their ids.
template <typename IndexT>
void IndexIDMapTemplate<IndexT>::add_sa_codes(
        idx_t n,
        const uint8_t* codes,
        const idx_t* xids) {
    // Hand the codes to the wrapped index first; ids are only recorded
    // once that call has returned.
    this->index->add_sa_codes(n, codes, xids);

    idx_t pos = 0;
    while (pos < n) {
        this->id_map.push_back(xids[pos]);
        ++pos;
    }

    // Keep this wrapper's vector count in sync with the wrapped index.
    this->ntotal = this->index->ntotal;
}

namespace {

/// RAII object to reset the IDSelector in the params object
Expand Down
3 changes: 3 additions & 0 deletions faiss/IndexIDMap.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,9 @@ struct IndexIDMapTemplate : IndexT {
void merge_from(IndexT& otherIndex, idx_t add_id = 0) override;
void check_compatible_for_merge(const IndexT& otherIndex) const override;

/// Code size in bytes, as reported by the wrapped index.
size_t sa_code_size() const override;
/// Add n pre-encoded vectors (x) to the wrapped index and append the n
/// ids in xids to id_map.
void add_sa_codes(idx_t n, const uint8_t* x, const idx_t* xids) override;

~IndexIDMapTemplate() override;
IndexIDMapTemplate() {
own_fields = false;
Expand Down
3 changes: 2 additions & 1 deletion faiss/IndexIVF.h
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,8 @@ struct IndexIVF : Index, IndexIVFInterface {
* @param codes codes to add size n * sa_code_size()
* @param xids corresponding ids, size n
*/
void add_sa_codes(idx_t n, const uint8_t* codes, const idx_t* xids);
void add_sa_codes(idx_t n, const uint8_t* codes, const idx_t* xids)
override;

/** Train the encoder for the vectors.
*
Expand Down
3 changes: 1 addition & 2 deletions faiss/python/class_wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -812,8 +812,7 @@ def replacement_permute_entries(self, perm):
replacement_range_search_preassigned, ignore_missing=True)
replace_method(the_class, 'sa_encode', replacement_sa_encode)
replace_method(the_class, 'sa_decode', replacement_sa_decode)
replace_method(the_class, 'add_sa_codes', replacement_add_sa_codes,
ignore_missing=True)
replace_method(the_class, 'add_sa_codes', replacement_add_sa_codes)
replace_method(the_class, 'permute_entries', replacement_permute_entries,
ignore_missing=True)

Expand Down
21 changes: 21 additions & 0 deletions tests/test_standalone_codec.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,6 +360,27 @@ def test_transfer(self):
np.testing.assert_array_equal(Dref, Dnew)


class TestIDMap(unittest.TestCase):
    """add_sa_codes on an IDMap2,PQ index must match add_with_ids."""

    def test_idmap(self):
        dataset = SyntheticDataset(32, 2000, 200, 100)
        vector_ids = np.random.randint(10000, size=dataset.nb, dtype='int64')
        idmap_index = faiss.index_factory(dataset.d, "IDMap2,PQ8x2")

        # Reference results: add the raw vectors together with their ids.
        idmap_index.train(dataset.get_train())
        idmap_index.add_with_ids(dataset.get_database(), vector_ids)
        D_expected, I_expected = idmap_index.search(dataset.get_queries(), 10)

        idmap_index.reset()

        # Same data, but encoded by the wrapped index and added as codes.
        idmap_index.train(dataset.get_train())
        encoded = idmap_index.index.sa_encode(dataset.get_database())
        idmap_index.add_sa_codes(encoded, vector_ids)
        D_actual, I_actual = idmap_index.search(dataset.get_queries(), 10)

        np.testing.assert_array_equal(I_expected, I_actual)
        np.testing.assert_array_equal(D_expected, D_actual)



class TestRefine(unittest.TestCase):

def test_refine(self):
Expand Down

0 comments on commit db28e9b

Please sign in to comment.