diff --git a/benchs/bench_fw/benchmark.py b/benchs/bench_fw/benchmark.py index e6220330b7..8e4d4d1d6a 100644 --- a/benchs/bench_fw/benchmark.py +++ b/benchs/bench_fw/benchmark.py @@ -348,7 +348,7 @@ def build_index_wrapper(self, knn_desc: KnnDescriptor): if hasattr(knn_desc, "index"): return - if knn_desc.index_desc.index is not None: + if hasattr(knn_desc.index_desc, "index"): knn_desc.index = knn_desc.index_desc.index knn_desc.index.knn_name = knn_desc.get_name() knn_desc.index.search_params = knn_desc.search_params @@ -359,6 +359,7 @@ def build_index_wrapper(self, knn_desc: KnnDescriptor): metric=self.distance_metric, bucket=knn_desc.index_desc.bucket, index_path=knn_desc.index_desc.path, + index_name=knn_desc.index_desc.get_name(), # knn_name=knn_desc.get_name(), search_params=knn_desc.search_params, ) @@ -544,6 +545,30 @@ def experiment(parameters, cost_metric, perf_metric): def knn_search_benchmark( self, dry_run, results: Dict[str, Any], knn_desc: KnnDescriptor ): + gt_knn_D = None + gt_knn_I = None + if hasattr(self, "gt_knn_D"): + gt_knn_D = self.gt_knn_D + gt_knn_I = self.gt_knn_I + + if not knn_desc.index.is_flat_index() and gt_knn_I is None: + key = knn_desc.index.get_knn_search_name( + search_parameters=knn_desc.search_params, + query_vectors=knn_desc.query_dataset, + k=knn_desc.k, + reconstruct=False, + ) + metrics, requires = knn_desc.index.knn_search( + dry_run, + knn_desc.search_params, + knn_desc.query_dataset, + knn_desc.k, + )[3:] + if requires is not None: + return results, requires + results["experiments"][key] = metrics + return results, requires + return self.search_benchmark( name="knn_search", search_func=lambda parameters: knn_desc.index.knn_search( @@ -551,8 +576,8 @@ def knn_search_benchmark( parameters, knn_desc.query_dataset, knn_desc.k, - self.gt_knn_I, - self.gt_knn_D, + gt_knn_I, + gt_knn_D, )[3:], key_func=lambda parameters: knn_desc.index.get_knn_search_name( search_parameters=parameters, @@ -634,6 +659,7 @@ class ExecutionOperator: train_op: Optional[TrainOperator] = None build_op: Optional[BuildOperator] = None search_op: Optional[SearchOperator] = None + compute_gt: bool = True def __post_init__(self): if self.distance_metric == "IP": @@ -698,16 +724,11 @@ def search_one( faiss.omp_set_num_threads(self.num_threads) assert self.search_op is not None - if not dry_run: + if not dry_run and self.compute_gt: self.create_gt_knn(knn_desc) self.create_range_ref_knn(knn_desc) self.search_op.build_index_wrapper(knn_desc) - meta, requires = knn_desc.index.fetch_meta(dry_run=dry_run) - if requires is not None: - # return results, (requires if train else None) - return results, requires - results["indices"][knn_desc.index.get_codec_name()] = meta # results, requires = self.reconstruct_benchmark( # dry_run=True, @@ -741,8 +762,8 @@ def search_one( assert requires is None if ( - knn_desc.range_ref_index_desc is None or - not knn_desc.index.supports_range_search() + knn_desc.range_ref_index_desc is None + or not knn_desc.index.supports_range_search() ): return results, None @@ -766,9 +787,11 @@ def search_one( ref_index_desc.search_params, range_metric, ) - gt_rsm = self.search_op.range_ground_truth( - gt_radius, range_search_metric_function - ) + gt_rsm = None + if self.compute_gt: + gt_rsm = self.search_op.range_ground_truth( + gt_radius, range_search_metric_function + ) results, requires = self.search_op.range_search_benchmark( dry_run=True, results=results, @@ -847,9 +870,13 @@ def create_gt_knn(self, knn_desc, search=True) -> Optional[KnnDescriptor]: if self.search_op: gt_knn_desc = self.search_op.get_flat_desc(knn_desc.flat_name()) if gt_knn_desc is None: - gt_index_desc = self.build_op.get_flat_desc( - knn_desc.index_desc.flat_name() - ) + if knn_desc.index_desc is not None: + gt_index_desc = knn_desc.gt_index_desc + else: + gt_index_desc = self.build_op.get_flat_desc( + knn_desc.index_desc.flat_name() + ) + knn_desc.gt_index_desc = gt_index_desc assert gt_index_desc is not None gt_knn_desc = KnnDescriptor( d=knn_desc.d, @@ -933,7 +960,10 @@ def execute(self, results: Dict[str, Any], dry_run: False): if self.search_op is not None: for desc in self.search_op.knn_descs: results, requires = self.search_one( - knn_desc=desc, results=results, dry_run=dry_run, range=self.search_op.range + knn_desc=desc, + results=results, + dry_run=dry_run, + range=self.search_op.range, ) if dry_run: if requires is None: diff --git a/benchs/bench_fw/index.py b/benchs/bench_fw/index.py index 090722f54a..72d5af8df2 100644 --- a/benchs/bench_fw/index.py +++ b/benchs/bench_fw/index.py @@ -274,7 +274,7 @@ def knn_search( D_gt=None, ): logger.info("knn_search: begin") - if search_parameters is not None and search_parameters["snap"] == 1: + if search_parameters is not None and search_parameters.get("snap", 0) == 1: query_vectors = self.snap(query_vectors) filename = ( self.get_knn_search_name(search_parameters, query_vectors, k) @@ -322,7 +322,11 @@ def knn_search( else: xq = self.io.get_dataset(query_vectors) (D, I), t, _ = timer("knn_search", lambda: index.search(xq, k)) - if self.is_flat() or not hasattr(self, "database_vectors"): # TODO + if ( + self.is_flat() + or not hasattr(self, "database_vectors") + or (self.database_vectors is None) + ): # TODO R = D else: xq = self.io.get_dataset(query_vectors) @@ -340,9 +344,7 @@ def knn_search( "nlist": int(stats.nlist // repeat), "ndis": int(stats.ndis // repeat), "nheap_updates": int(stats.nheap_updates // repeat), - "quantization_time": int( - stats.quantization_time // repeat - ), + "quantization_time": int(stats.quantization_time // repeat), "search_time": int(stats.search_time // repeat), } self.io.write_file(filename, ["D", "I", "R", "P"], [D, I, R, P]) @@ -352,20 +354,24 @@ def knn_search( "factory": self.get_model_name(), "construction_params": self.get_construction_params(), "search_params": search_parameters, - "knn_intersection": knn_intersection_measure( - I, - I_gt, - ) - if I_gt is not None - else None, - "distance_ratio": distance_ratio_measure( - I, - R, - D_gt, - self.metric_type, - ) - if D_gt is not None - else None, + "knn_intersection": ( + knn_intersection_measure( + I, + I_gt, + ) + if I_gt is not None + else None + ), + "distance_ratio": ( + distance_ratio_measure( + I, + R, + D_gt, + self.metric_type, + ) + if D_gt is not None + else None + ), } logger.info("knn_search: end") return D, I, R, P, None @@ -467,7 +473,7 @@ def range_search( radius: Optional[float] = None, ): logger.info("range_search: begin") - if search_parameters is not None and search_parameters.get("snap") == 1: + if search_parameters is not None and search_parameters.get("snap", 0) == 1: query_vectors = self.snap(query_vectors) filename = ( self.get_range_search_name( @@ -607,6 +613,12 @@ def get_codec(self): Index.cached_codec.popitem(last=False) return Index.cached_codec[codec_name] + def get_model(self): + return self.get_index() + + def get_model_name(self): + return self.get_index_name() + def get_codec_name(self) -> Optional[str]: return self.codec_name @@ -709,6 +721,11 @@ def get_operating_points(self): def add_range_or_val(name, range): op.add_range( name, + ( + [self.search_params[name]] + if self.search_params and name in self.search_params + else range + ), [self.search_params[name]] if self.search_params and name in self.search_params else range, @@ -808,7 +825,10 @@ def get_pretransform(self): return quantizer def get_model_name(self): - return os.path.basename(self.path) + if self.path is not None: + return os.path.basename(self.path) + else: + return self.get_codec_name() def fetch_meta(self, dry_run=False): return None, None