From 1bf383cc335a6e7d7b0425ca5d49b0c27de8e9d7 Mon Sep 17 00:00:00 2001 From: Kumar Saurabh Arora Date: Tue, 3 Sep 2024 14:15:28 -0700 Subject: [PATCH] Allow search Index without Gt Summary: There are a few fixes in this diff which allow us to execute a search on an existing index without needing to compare it with ground truth. This is currently supported only for knn search (not range search). Differential Revision: D61825365 --- benchs/bench_fw/benchmark.py | 66 ++++++++++++++++++++++++++---------- benchs/bench_fw/index.py | 62 +++++++++++++++++++++------------ 2 files changed, 89 insertions(+), 39 deletions(-) diff --git a/benchs/bench_fw/benchmark.py b/benchs/bench_fw/benchmark.py index e6220330b7..8e4d4d1d6a 100644 --- a/benchs/bench_fw/benchmark.py +++ b/benchs/bench_fw/benchmark.py @@ -348,7 +348,7 @@ def build_index_wrapper(self, knn_desc: KnnDescriptor): if hasattr(knn_desc, "index"): return - if knn_desc.index_desc.index is not None: + if hasattr(knn_desc.index_desc, "index"): knn_desc.index = knn_desc.index_desc.index knn_desc.index.knn_name = knn_desc.get_name() knn_desc.index.search_params = knn_desc.search_params @@ -359,6 +359,7 @@ def build_index_wrapper(self, knn_desc: KnnDescriptor): metric=self.distance_metric, bucket=knn_desc.index_desc.bucket, index_path=knn_desc.index_desc.path, + index_name=knn_desc.index_desc.get_name(), # knn_name=knn_desc.get_name(), search_params=knn_desc.search_params, ) @@ -544,6 +545,30 @@ def experiment(parameters, cost_metric, perf_metric): def knn_search_benchmark( self, dry_run, results: Dict[str, Any], knn_desc: KnnDescriptor ): + gt_knn_D = None + gt_knn_I = None + if hasattr(self, "gt_knn_D"): + gt_knn_D = self.gt_knn_D + gt_knn_I = self.gt_knn_I + + if not knn_desc.index.is_flat_index() and gt_knn_I is None: + key = knn_desc.index.get_knn_search_name( + search_parameters=knn_desc.search_params, + query_vectors=knn_desc.query_dataset, + k=knn_desc.k, + reconstruct=False, + ) + metrics, requires =
knn_desc.index.knn_search( + dry_run, + knn_desc.search_params, + knn_desc.query_dataset, + knn_desc.k, + )[3:] + if requires is not None: + return results, requires + results["experiments"][key] = metrics + return results, requires + return self.search_benchmark( name="knn_search", search_func=lambda parameters: knn_desc.index.knn_search( @@ -551,8 +576,8 @@ def knn_search_benchmark( parameters, knn_desc.query_dataset, knn_desc.k, - self.gt_knn_I, - self.gt_knn_D, + gt_knn_I, + gt_knn_D, )[3:], key_func=lambda parameters: knn_desc.index.get_knn_search_name( search_parameters=parameters, @@ -634,6 +659,7 @@ class ExecutionOperator: train_op: Optional[TrainOperator] = None build_op: Optional[BuildOperator] = None search_op: Optional[SearchOperator] = None + compute_gt: bool = True def __post_init__(self): if self.distance_metric == "IP": @@ -698,16 +724,11 @@ def search_one( faiss.omp_set_num_threads(self.num_threads) assert self.search_op is not None - if not dry_run: + if not dry_run and self.compute_gt: self.create_gt_knn(knn_desc) self.create_range_ref_knn(knn_desc) self.search_op.build_index_wrapper(knn_desc) - meta, requires = knn_desc.index.fetch_meta(dry_run=dry_run) - if requires is not None: - # return results, (requires if train else None) - return results, requires - results["indices"][knn_desc.index.get_codec_name()] = meta # results, requires = self.reconstruct_benchmark( # dry_run=True, @@ -741,8 +762,8 @@ def search_one( assert requires is None if ( - knn_desc.range_ref_index_desc is None or - not knn_desc.index.supports_range_search() + knn_desc.range_ref_index_desc is None + or not knn_desc.index.supports_range_search() ): return results, None @@ -766,9 +787,11 @@ def search_one( ref_index_desc.search_params, range_metric, ) - gt_rsm = self.search_op.range_ground_truth( - gt_radius, range_search_metric_function - ) + gt_rsm = None + if self.compute_gt: + gt_rsm = self.search_op.range_ground_truth( + gt_radius, range_search_metric_function + ) 
results, requires = self.search_op.range_search_benchmark( dry_run=True, results=results, @@ -847,9 +870,13 @@ def create_gt_knn(self, knn_desc, search=True) -> Optional[KnnDescriptor]: if self.search_op: gt_knn_desc = self.search_op.get_flat_desc(knn_desc.flat_name()) if gt_knn_desc is None: - gt_index_desc = self.build_op.get_flat_desc( - knn_desc.index_desc.flat_name() - ) + if knn_desc.index_desc is not None: + gt_index_desc = knn_desc.gt_index_desc + else: + gt_index_desc = self.build_op.get_flat_desc( + knn_desc.index_desc.flat_name() + ) + knn_desc.gt_index_desc = gt_index_desc assert gt_index_desc is not None gt_knn_desc = KnnDescriptor( d=knn_desc.d, @@ -933,7 +960,10 @@ def execute(self, results: Dict[str, Any], dry_run: False): if self.search_op is not None: for desc in self.search_op.knn_descs: results, requires = self.search_one( - knn_desc=desc, results=results, dry_run=dry_run, range=self.search_op.range + knn_desc=desc, + results=results, + dry_run=dry_run, + range=self.search_op.range, ) if dry_run: if requires is None: diff --git a/benchs/bench_fw/index.py b/benchs/bench_fw/index.py index 090722f54a..72d5af8df2 100644 --- a/benchs/bench_fw/index.py +++ b/benchs/bench_fw/index.py @@ -274,7 +274,7 @@ def knn_search( D_gt=None, ): logger.info("knn_search: begin") - if search_parameters is not None and search_parameters["snap"] == 1: + if search_parameters is not None and search_parameters.get("snap", 0) == 1: query_vectors = self.snap(query_vectors) filename = ( self.get_knn_search_name(search_parameters, query_vectors, k) @@ -322,7 +322,11 @@ def knn_search( else: xq = self.io.get_dataset(query_vectors) (D, I), t, _ = timer("knn_search", lambda: index.search(xq, k)) - if self.is_flat() or not hasattr(self, "database_vectors"): # TODO + if ( + self.is_flat() + or not hasattr(self, "database_vectors") + or (self.database_vectors is None) + ): # TODO R = D else: xq = self.io.get_dataset(query_vectors) @@ -340,9 +344,7 @@ def knn_search( "nlist": 
int(stats.nlist // repeat), "ndis": int(stats.ndis // repeat), "nheap_updates": int(stats.nheap_updates // repeat), - "quantization_time": int( - stats.quantization_time // repeat - ), + "quantization_time": int(stats.quantization_time // repeat), "search_time": int(stats.search_time // repeat), } self.io.write_file(filename, ["D", "I", "R", "P"], [D, I, R, P]) @@ -352,20 +354,24 @@ def knn_search( "factory": self.get_model_name(), "construction_params": self.get_construction_params(), "search_params": search_parameters, - "knn_intersection": knn_intersection_measure( - I, - I_gt, - ) - if I_gt is not None - else None, - "distance_ratio": distance_ratio_measure( - I, - R, - D_gt, - self.metric_type, - ) - if D_gt is not None - else None, + "knn_intersection": ( + knn_intersection_measure( + I, + I_gt, + ) + if I_gt is not None + else None + ), + "distance_ratio": ( + distance_ratio_measure( + I, + R, + D_gt, + self.metric_type, + ) + if D_gt is not None + else None + ), } logger.info("knn_search: end") return D, I, R, P, None @@ -467,7 +473,7 @@ def range_search( radius: Optional[float] = None, ): logger.info("range_search: begin") - if search_parameters is not None and search_parameters.get("snap") == 1: + if search_parameters is not None and search_parameters.get("snap", 0) == 1: query_vectors = self.snap(query_vectors) filename = ( self.get_range_search_name( @@ -607,6 +613,12 @@ def get_codec(self): Index.cached_codec.popitem(last=False) return Index.cached_codec[codec_name] + def get_model(self): + return self.get_index() + + def get_model_name(self): + return self.get_index_name() + def get_codec_name(self) -> Optional[str]: return self.codec_name @@ -709,6 +721,11 @@ def get_operating_points(self): def add_range_or_val(name, range): op.add_range( name, + ( + [self.search_params[name]] + if self.search_params and name in self.search_params + else range + ), [self.search_params[name]] if self.search_params and name in self.search_params else range, @@ -808,7 
+825,10 @@ def get_pretransform(self): return quantizer def get_model_name(self): - return os.path.basename(self.path) + if self.path is not None: + return os.path.basename(self.path) + else: + return self.get_codec_name() def fetch_meta(self, dry_run=False): return None, None