Skip to content

Commit

Permalink
Import user-defined evaluation file (#199)
Browse files Browse the repository at this point in the history
  • Loading branch information
belerico authored Jan 31, 2024
1 parent 030cdd5 commit cd8a84c
Showing 1 changed file with 8 additions and 1 deletion.
9 changes: 8 additions & 1 deletion sheeprl/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,12 +186,14 @@ def eval_algorithm(cfg: DictConfig):
# the entrypoint will be launched by Fabric with `fabric.launch(entrypoint)`
module = None
entrypoint = None
evaluation_file = None
algo_name = cfg.algo.name
for _module, _algos in evaluation_registry.items():
for _algo in _algos:
if algo_name == _algo["name"]:
module = _module
entrypoint = _algo["entrypoint"]
evaluation_file = _algo["evaluation_file"]
break
if module is None:
raise RuntimeError(f"Given the algorithm named `{algo_name}`, no module has been found to be imported.")
Expand All @@ -200,7 +202,12 @@ def eval_algorithm(cfg: DictConfig):
f"Given the module and algorithm named `{module}` and `{algo_name}` respectively, "
"no entrypoint has been found to be imported."
)
task = importlib.import_module(f"{module}.evaluate")
if evaluation_file is None:
raise RuntimeError(
f"Given the module and algorithm named `{module}` and `{algo_name}` respectively, "
"no evaluation file has been found to be imported."
)
task = importlib.import_module(f"{module}.{evaluation_file}")
command = task.__dict__[entrypoint]
fabric.launch(command, cfg, state)

Expand Down

0 comments on commit cd8a84c

Please sign in to comment.