Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add python tutorial on different indexs refinement and respect accuracy measurement #3480

Closed
wants to merge 1 commit into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 42 additions & 0 deletions tutorial/python/9-RefineComparison.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import faiss

from faiss.contrib.evaluation import knn_intersection_measure
from faiss.contrib import datasets

# 64-dim vectors, 50000 vectors in the training, 100000 in database,
# 10000 in queries, dtype ('float32')
ds = datasets.SyntheticDataset(64, 50000, 100000, 10000)
d = 64 # dimension

# Constructing the refine PQ index with SQfp16 with index factory
index_fp16 = faiss.index_factory(d, 'PQ32x4fs,Refine(SQfp16)')
index_fp16.train(ds.get_train())
index_fp16.add(ds.get_database())

# Constructing the refine PQ index with SQ8
index_sq8 = faiss.index_factory(d, 'PQ32x4fs,Refine(SQ8)')
index_sq8.train(ds.get_train())
index_sq8.add(ds.get_database())

# Parameterization on k factor while doing search for index refinement
k_factor = 3.0
params = faiss.IndexRefineSearchParameters(k_factor=k_factor)

# Perform index search using different index refinement
D_fp16, I_fp16 = index_fp16.search(ds.get_queries(), 100, params=params)
D_sq8, I_sq8 = index_sq8.search(ds.get_queries(), 100, params=params)

# Calculating knn intersection measure for different index types on refinement
KIM_fp16 = knn_intersection_measure(I_fp16, ds.get_groundtruth())
KIM_sq8 = knn_intersection_measure(I_sq8, ds.get_groundtruth())

# KNN intersection measure accuracy shows that choosing SQ8 impacts accuracy
assert (KIM_fp16 > KIM_sq8)

print(I_sq8[:5])
print(I_fp16[:5])