diff --git a/ann_benchmarks/definitions.py b/ann_benchmarks/definitions.py index 24f1a231..ce4c3b8f 100644 --- a/ann_benchmarks/definitions.py +++ b/ann_benchmarks/definitions.py @@ -145,16 +145,15 @@ def load_configs(point_type: str, base_dir: str = "ann_benchmarks/algorithms") - print(f"Error loading YAML from {config_file}: {e}") return configs -def _get_definitions(base_dir: str = "ann_benchmarks/algorithms") -> Dict[str, Dict[str, Any]]: - """Load algorithm configurations for a given point_type.""" +def _get_definitions(base_dir: str = "ann_benchmarks/algorithms") -> List[Dict[str, Any]]: + """Load algorithm configurations.""" config_files = get_config_files(base_dir=base_dir) - configs = {} + configs = [] for config_file in config_files: with open(config_file, 'r') as stream: try: config_data = yaml.safe_load(stream) - algorithm_name = os.path.basename(os.path.dirname(config_file)) - configs[algorithm_name] = config_data + configs.append(config_data) except yaml.YAMLError as e: print(f"Error loading YAML from {config_file}: {e}") return configs @@ -211,16 +210,27 @@ def list_algorithms(base_dir: str = "ann_benchmarks/algorithms") -> None: base_dir (str, optional): The base directory where the algorithms are stored. Defaults to "ann_benchmarks/algorithms". """ - definitions = _get_definitions(base_dir) - - print("The following algorithms are supported...", definitions) - for algorithm in definitions: + all_configs = _get_definitions(base_dir) + data = {} + for algo_configs in all_configs: + for point_type, config_for_point_type in algo_configs.items(): + for metric, ccc in config_for_point_type.items(): + algo_name = ccc[0]["name"] + if algo_name not in data: + data[algo_name] = {} + if point_type not in data[algo_name]: + data[algo_name][point_type] = [] + data[algo_name][point_type].append(metric) + + print("The following algorithms are supported:", ", ".join(data)) + print("Details of supported metrics and data types: ") + for algorithm in data: print('\t... for the algorithm "%s"...' % algorithm) - for point_type in definitions[algorithm]: + for point_type in data[algorithm]: print('\t\t... and the point type "%s", metrics: ' % point_type) - for metric in definitions[algorithm][point_type]: + for metric in data[algorithm][point_type]: print("\t\t\t%s" % metric)