Skip to content

Commit

Permalink
Allow search Index without Gt
Browse files Browse the repository at this point in the history
Summary: There are few fixes in this diff which allows us to execute search on an existing index without needing to compare it with ground truth. This has been currently added to only knn search (not range)

Differential Revision: D61825365
  • Loading branch information
kuarora authored and facebook-github-bot committed Sep 3, 2024
1 parent 501a8be commit 1bf383c
Show file tree
Hide file tree
Showing 2 changed files with 89 additions and 39 deletions.
66 changes: 48 additions & 18 deletions benchs/bench_fw/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,7 +348,7 @@ def build_index_wrapper(self, knn_desc: KnnDescriptor):
if hasattr(knn_desc, "index"):
return

if knn_desc.index_desc.index is not None:
if hasattr(knn_desc.index_desc, "index"):
knn_desc.index = knn_desc.index_desc.index
knn_desc.index.knn_name = knn_desc.get_name()
knn_desc.index.search_params = knn_desc.search_params
Expand All @@ -359,6 +359,7 @@ def build_index_wrapper(self, knn_desc: KnnDescriptor):
metric=self.distance_metric,
bucket=knn_desc.index_desc.bucket,
index_path=knn_desc.index_desc.path,
index_name=knn_desc.index_desc.get_name(),
# knn_name=knn_desc.get_name(),
search_params=knn_desc.search_params,
)
Expand Down Expand Up @@ -544,15 +545,39 @@ def experiment(parameters, cost_metric, perf_metric):
def knn_search_benchmark(
self, dry_run, results: Dict[str, Any], knn_desc: KnnDescriptor
):
gt_knn_D = None
gt_knn_I = None
if hasattr(self, "gt_knn_D"):
gt_knn_D = self.gt_knn_D
gt_knn_I = self.gt_knn_I

if not knn_desc.index.is_flat_index() and gt_knn_I is None:
key = knn_desc.index.get_knn_search_name(
search_parameters=knn_desc.search_params,
query_vectors=knn_desc.query_dataset,
k=knn_desc.k,
reconstruct=False,
)
metrics, requires = knn_desc.index.knn_search(
dry_run,
knn_desc.search_params,
knn_desc.query_dataset,
knn_desc.k,
)[3:]
if requires is not None:
return results, requires
results["experiments"][key] = metrics
return results, requires

return self.search_benchmark(
name="knn_search",
search_func=lambda parameters: knn_desc.index.knn_search(
dry_run,
parameters,
knn_desc.query_dataset,
knn_desc.k,
self.gt_knn_I,
self.gt_knn_D,
gt_knn_I,
gt_knn_D,
)[3:],
key_func=lambda parameters: knn_desc.index.get_knn_search_name(
search_parameters=parameters,
Expand Down Expand Up @@ -634,6 +659,7 @@ class ExecutionOperator:
train_op: Optional[TrainOperator] = None
build_op: Optional[BuildOperator] = None
search_op: Optional[SearchOperator] = None
compute_gt: bool = True

def __post_init__(self):
if self.distance_metric == "IP":
Expand Down Expand Up @@ -698,16 +724,11 @@ def search_one(
faiss.omp_set_num_threads(self.num_threads)
assert self.search_op is not None

if not dry_run:
if not dry_run and self.compute_gt:
self.create_gt_knn(knn_desc)
self.create_range_ref_knn(knn_desc)

self.search_op.build_index_wrapper(knn_desc)
meta, requires = knn_desc.index.fetch_meta(dry_run=dry_run)
if requires is not None:
# return results, (requires if train else None)
return results, requires
results["indices"][knn_desc.index.get_codec_name()] = meta

# results, requires = self.reconstruct_benchmark(
# dry_run=True,
Expand Down Expand Up @@ -741,8 +762,8 @@ def search_one(
assert requires is None

if (
knn_desc.range_ref_index_desc is None or
not knn_desc.index.supports_range_search()
knn_desc.range_ref_index_desc is None
or not knn_desc.index.supports_range_search()
):
return results, None

Expand All @@ -766,9 +787,11 @@ def search_one(
ref_index_desc.search_params,
range_metric,
)
gt_rsm = self.search_op.range_ground_truth(
gt_radius, range_search_metric_function
)
gt_rsm = None
if self.compute_gt:
gt_rsm = self.search_op.range_ground_truth(
gt_radius, range_search_metric_function
)
results, requires = self.search_op.range_search_benchmark(
dry_run=True,
results=results,
Expand Down Expand Up @@ -847,9 +870,13 @@ def create_gt_knn(self, knn_desc, search=True) -> Optional[KnnDescriptor]:
if self.search_op:
gt_knn_desc = self.search_op.get_flat_desc(knn_desc.flat_name())
if gt_knn_desc is None:
gt_index_desc = self.build_op.get_flat_desc(
knn_desc.index_desc.flat_name()
)
if knn_desc.index_desc is not None:
gt_index_desc = knn_desc.gt_index_desc
else:
gt_index_desc = self.build_op.get_flat_desc(
knn_desc.index_desc.flat_name()
)
knn_desc.gt_index_desc = gt_index_desc
assert gt_index_desc is not None
gt_knn_desc = KnnDescriptor(
d=knn_desc.d,
Expand Down Expand Up @@ -933,7 +960,10 @@ def execute(self, results: Dict[str, Any], dry_run: False):
if self.search_op is not None:
for desc in self.search_op.knn_descs:
results, requires = self.search_one(
knn_desc=desc, results=results, dry_run=dry_run, range=self.search_op.range
knn_desc=desc,
results=results,
dry_run=dry_run,
range=self.search_op.range,
)
if dry_run:
if requires is None:
Expand Down
62 changes: 41 additions & 21 deletions benchs/bench_fw/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,7 @@ def knn_search(
D_gt=None,
):
logger.info("knn_search: begin")
if search_parameters is not None and search_parameters["snap"] == 1:
if search_parameters is not None and search_parameters.get("snap", 0) == 1:
query_vectors = self.snap(query_vectors)
filename = (
self.get_knn_search_name(search_parameters, query_vectors, k)
Expand Down Expand Up @@ -322,7 +322,11 @@ def knn_search(
else:
xq = self.io.get_dataset(query_vectors)
(D, I), t, _ = timer("knn_search", lambda: index.search(xq, k))
if self.is_flat() or not hasattr(self, "database_vectors"): # TODO
if (
self.is_flat()
or not hasattr(self, "database_vectors")
or (self.database_vectors is None)
): # TODO
R = D
else:
xq = self.io.get_dataset(query_vectors)
Expand All @@ -340,9 +344,7 @@ def knn_search(
"nlist": int(stats.nlist // repeat),
"ndis": int(stats.ndis // repeat),
"nheap_updates": int(stats.nheap_updates // repeat),
"quantization_time": int(
stats.quantization_time // repeat
),
"quantization_time": int(stats.quantization_time // repeat),
"search_time": int(stats.search_time // repeat),
}
self.io.write_file(filename, ["D", "I", "R", "P"], [D, I, R, P])
Expand All @@ -352,20 +354,24 @@ def knn_search(
"factory": self.get_model_name(),
"construction_params": self.get_construction_params(),
"search_params": search_parameters,
"knn_intersection": knn_intersection_measure(
I,
I_gt,
)
if I_gt is not None
else None,
"distance_ratio": distance_ratio_measure(
I,
R,
D_gt,
self.metric_type,
)
if D_gt is not None
else None,
"knn_intersection": (
knn_intersection_measure(
I,
I_gt,
)
if I_gt is not None
else None
),
"distance_ratio": (
distance_ratio_measure(
I,
R,
D_gt,
self.metric_type,
)
if D_gt is not None
else None
),
}
logger.info("knn_search: end")
return D, I, R, P, None
Expand Down Expand Up @@ -467,7 +473,7 @@ def range_search(
radius: Optional[float] = None,
):
logger.info("range_search: begin")
if search_parameters is not None and search_parameters.get("snap") == 1:
if search_parameters is not None and search_parameters.get("snap", 0) == 1:
query_vectors = self.snap(query_vectors)
filename = (
self.get_range_search_name(
Expand Down Expand Up @@ -607,6 +613,12 @@ def get_codec(self):
Index.cached_codec.popitem(last=False)
return Index.cached_codec[codec_name]

def get_model(self):
return self.get_index()

def get_model_name(self):
return self.get_index_name()

def get_codec_name(self) -> Optional[str]:
return self.codec_name

Expand Down Expand Up @@ -709,6 +721,11 @@ def get_operating_points(self):
def add_range_or_val(name, range):
op.add_range(
name,
(
[self.search_params[name]]
if self.search_params and name in self.search_params
else range
),
[self.search_params[name]]
if self.search_params and name in self.search_params
else range,
Expand Down Expand Up @@ -808,7 +825,10 @@ def get_pretransform(self):
return quantizer

def get_model_name(self):
return os.path.basename(self.path)
if self.path is not None:
return os.path.basename(self.path)
else:
return self.get_codec_name()

def fetch_meta(self, dry_run=False):
return None, None
Expand Down

0 comments on commit 1bf383c

Please sign in to comment.