From fb9483615a9e82c3f1eb6c4692ec25b88c697b1a Mon Sep 17 00:00:00 2001 From: Martin Aumueller Date: Tue, 21 Jan 2025 12:52:26 +0100 Subject: [PATCH 1/2] fix --list-algorithms using path names instead of algorithm names (fixes #555) --- ann_benchmarks/definitions.py | 32 +++++++++++++++++++++----------- 1 file changed, 21 insertions(+), 11 deletions(-) diff --git a/ann_benchmarks/definitions.py b/ann_benchmarks/definitions.py index 24f1a2312..1ee72e704 100644 --- a/ann_benchmarks/definitions.py +++ b/ann_benchmarks/definitions.py @@ -145,16 +145,15 @@ def load_configs(point_type: str, base_dir: str = "ann_benchmarks/algorithms") - print(f"Error loading YAML from {config_file}: {e}") return configs -def _get_definitions(base_dir: str = "ann_benchmarks/algorithms") -> Dict[str, Dict[str, Any]]: - """Load algorithm configurations for a given point_type.""" +def _get_definitions(base_dir: str = "ann_benchmarks/algorithms") -> List[Dict[str, Any]]: + """Load algorithm configurations.""" config_files = get_config_files(base_dir=base_dir) - configs = {} + configs = [] for config_file in config_files: with open(config_file, 'r') as stream: try: config_data = yaml.safe_load(stream) - algorithm_name = os.path.basename(os.path.dirname(config_file)) - configs[algorithm_name] = config_data + configs.append(config_data) except yaml.YAMLError as e: print(f"Error loading YAML from {config_file}: {e}") return configs @@ -211,16 +210,27 @@ def list_algorithms(base_dir: str = "ann_benchmarks/algorithms") -> None: base_dir (str, optional): The base directory where the algorithms are stored. Defaults to "ann_benchmarks/algorithms". """ - definitions = _get_definitions(base_dir) - - print("The following algorithms are supported...", definitions) - for algorithm in definitions: + all_configs = _get_definitions(base_dir) + data = {} + for algo_configs in all_configs: + for point_type, config_for_point_type in algo_configs.items(): + for metric, ccc in config_for_point_type.items(): + algo_name = ccc[0]["name"] + if algo_name not in data: + data[algo_name] = {} + if point_type not in data[algo_name]: + data[algo_name][point_type] = [] + data[algo_name][point_type].append(metric) + + print("The following algorithms are supported:", ", ".join(data)) + print("Details of support metrics and data types: ") + for algorithm in data: print('\t... for the algorithm "%s"...' % algorithm) - for point_type in definitions[algorithm]: + for point_type in data[algorithm]: print('\t\t... and the point type "%s", metrics: ' % point_type) - for metric in definitions[algorithm][point_type]: + for metric in data[algorithm][point_type]: print("\t\t\t%s" % metric) From 1b31d6b521143c7c2f23d0fea4a990199cdad2d8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Martin=20Aum=C3=BCller?= Date: Tue, 21 Jan 2025 13:06:13 +0100 Subject: [PATCH 2/2] Fixed typo. --- ann_benchmarks/definitions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ann_benchmarks/definitions.py b/ann_benchmarks/definitions.py index 1ee72e704..ce4c3b8f7 100644 --- a/ann_benchmarks/definitions.py +++ b/ann_benchmarks/definitions.py @@ -223,7 +223,7 @@ def list_algorithms(base_dir: str = "ann_benchmarks/algorithms") -> None: data[algo_name][point_type].append(metric) print("The following algorithms are supported:", ", ".join(data)) - print("Details of support metrics and data types: ") + print("Details of supported metrics and data types: ") for algorithm in data: print('\t... for the algorithm "%s"...' % algorithm)