Skip to content

Commit

Permalink
Use custom definitions for algorithms (#527)
Browse files Browse the repository at this point in the history
* Pass the value of the definitions argument down to use custom definitions for the algorithms

(cherry picked from commit c7aac44)

* fixed indentation

---------

Co-authored-by: tapas <tapas@pizzapatapa.me>
Co-authored-by: Martin Aumüller <maau@itu.dk>
  • Loading branch information
3 people authored May 24, 2024
1 parent 43d0538 commit fcdf494
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 8 deletions.
18 changes: 11 additions & 7 deletions ann_benchmarks/definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ def _get_definitions(base_dir: str = "ann_benchmarks/algorithms") -> Dict[str, D
print(f"Error loading YAML from {config_file}: {e}")
return configs

def _get_algorithm_definitions(point_type: str, distance_metric: str) -> Dict[str, Dict[str, Any]]:
def _get_algorithm_definitions(point_type: str, distance_metric: str, base_dir: str = "ann_benchmarks/algorithms") -> Dict[str, Dict[str, Any]]:
"""Get algorithm definitions for a specific point type and distance metric.
A specific algorithm folder can have multiple algorithm definitions for a given point type and
Expand Down Expand Up @@ -188,7 +188,7 @@ def _get_algorithm_definitions(point_type: str, distance_metric: str) -> Dict[st
}
```
"""
configs = load_configs(point_type)
configs = load_configs(point_type, base_dir)
definitions = {}

# param `_` is filename, not specific name
Expand Down Expand Up @@ -341,12 +341,16 @@ def create_definitions_from_algorithm(name: str, algo: Dict[str, Any], dimension
return definitions

def get_definitions(
dimension: int,
point_type: str = "float",
distance_metric: str = "euclidean",
count: int = 10
dimension: int,
point_type: str = "float",
distance_metric: str = "euclidean",
count: int = 10,
base_dir: str = "ann_benchmarks/algorithms"
) -> List[Definition]:
algorithm_definitions = _get_algorithm_definitions(point_type=point_type, distance_metric=distance_metric)
algorithm_definitions = _get_algorithm_definitions(point_type=point_type,
distance_metric=distance_metric,
base_dir=base_dir
)

definitions: List[Definition] = []

Expand Down
3 changes: 2 additions & 1 deletion ann_benchmarks/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,7 +312,8 @@ def main():
dimension=dimension,
point_type=dataset.attrs.get("point_type", "float"),
distance_metric=dataset.attrs["distance"],
count=args.count
count=args.count,
base_dir=args.definitions,
)
random.shuffle(definitions)

Expand Down

0 comments on commit fcdf494

Please sign in to comment.