From 216896d4545696c745bf9bf220ce9b21652f37ec Mon Sep 17 00:00:00 2001 From: belerico Date: Tue, 30 Jan 2024 17:35:48 +0100 Subject: [PATCH 1/3] Evaluate functions run with torch.no_grad() by default --- sheeprl/utils/registry.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/sheeprl/utils/registry.py b/sheeprl/utils/registry.py index aa876774..d8cec6c8 100644 --- a/sheeprl/utils/registry.py +++ b/sheeprl/utils/registry.py @@ -3,6 +3,8 @@ import sys from typing import Any, Callable, Dict, List +import torch + # Mapping of tasks with their relative algorithms. # A new task can be added as: # tasks[module] = [..., {"name": algorithm, "entrypoint": entrypoint, "decoupled": decoupled}] @@ -35,7 +37,13 @@ def _register_algorithm(fn: Callable[..., Any], decoupled: bool = False) -> Call return fn -def _register_evaluation(fn: Callable[..., Any], algorithms: str | List[str]) -> Callable[..., Any]: +def _register_evaluation( +    fn: Callable[..., Any], algorithms: str | List[str], disable_grads: bool = True +) -> Callable[..., Any]: +    # Disable gradient computation for evaluation +    if disable_grads and torch.is_grad_enabled(): +        fn = torch.no_grad(fn) +    # lookup containing module if fn.__module__ == "__main__": return fn @@ -101,8 +109,8 @@ def inner_decorator(fn): return inner_decorator -def register_evaluation(algorithms: str | List[str]): +def register_evaluation(algorithms: str | List[str], disable_grads: bool = True): def inner_decorator(fn): -        return _register_evaluation(fn, algorithms=algorithms) +        return _register_evaluation(fn, algorithms=algorithms, disable_grads=disable_grads) return inner_decorator From b0ffa3651e70f1a2bfa458ea2f6a69b7c70fd92d Mon Sep 17 00:00:00 2001 From: belerico Date: Tue, 30 Jan 2024 17:41:06 +0100 Subject: [PATCH 2/3] Import user-defined evaluation file --- sheeprl/cli.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/sheeprl/cli.py b/sheeprl/cli.py index 802dbb5b..74e2775c 
100644 --- a/sheeprl/cli.py +++ b/sheeprl/cli.py @@ -186,12 +186,14 @@ def eval_algorithm(cfg: DictConfig): # the entrypoint will be launched by Fabric with `fabric.launch(entrypoint)` module = None entrypoint = None + evaluation_file = None algo_name = cfg.algo.name for _module, _algos in evaluation_registry.items(): for _algo in _algos: if algo_name == _algo["name"]: module = _module entrypoint = _algo["entrypoint"] + evaluation_file = _algo["evaluation_file"] break if module is None: raise RuntimeError(f"Given the algorithm named `{algo_name}`, no module has been found to be imported.") @@ -200,7 +202,12 @@ def eval_algorithm(cfg: DictConfig): f"Given the module and algorithm named `{module}` and `{algo_name}` respectively, " "no entrypoint has been found to be imported." ) - task = importlib.import_module(f"{module}.evaluate") + if evaluation_file is None: + raise RuntimeError( + f"Given the module and algorithm named `{module}` and `{algo_name}` respectively, " + "no evaluation file has been found to be imported." 
+ ) + task = importlib.import_module(f"{module}.{evaluation_file}") command = task.__dict__[entrypoint] fabric.launch(command, cfg, state) From dd41bb7bd182c24601067d0d7437c045ff65fc4b Mon Sep 17 00:00:00 2001 From: belerico Date: Mon, 5 Feb 2024 09:44:11 +0100 Subject: [PATCH 3/3] Disable grads during evaluation with user choosing --- sheeprl/cli.py | 3 +++ sheeprl/configs/eval_config.yaml | 1 + sheeprl/utils/registry.py | 14 +++----------- 3 files changed, 7 insertions(+), 11 deletions(-) diff --git a/sheeprl/cli.py b/sheeprl/cli.py index 74e2775c..4861d827 100644 --- a/sheeprl/cli.py +++ b/sheeprl/cli.py @@ -6,6 +6,7 @@ from typing import Any, Dict import hydra +import torch from lightning import Fabric from lightning.fabric.strategies import STRATEGY_REGISTRY, DDPStrategy, SingleDeviceStrategy, Strategy from omegaconf import DictConfig, OmegaConf, open_dict @@ -209,6 +210,8 @@ def eval_algorithm(cfg: DictConfig): ) task = importlib.import_module(f"{module}.{evaluation_file}") command = task.__dict__[entrypoint] + if getattr(cfg, "disable_grads", True): + command = torch.no_grad(command) fabric.launch(command, cfg, state) diff --git a/sheeprl/configs/eval_config.yaml b/sheeprl/configs/eval_config.yaml index 2eeba61a..6487d7a3 100644 --- a/sheeprl/configs/eval_config.yaml +++ b/sheeprl/configs/eval_config.yaml @@ -17,4 +17,5 @@ fabric: env: capture_video: True +disable_grads: True checkpoint_path: ??? diff --git a/sheeprl/utils/registry.py b/sheeprl/utils/registry.py index d8cec6c8..aa876774 100644 --- a/sheeprl/utils/registry.py +++ b/sheeprl/utils/registry.py @@ -3,8 +3,6 @@ import sys from typing import Any, Callable, Dict, List -import torch - # Mapping of tasks with their relative algorithms. 
# A new task can be added as: # tasks[module] = [..., {"name": algorithm, "entrypoint": entrypoint, "decoupled": decoupled}] @@ -37,13 +35,7 @@ def _register_algorithm(fn: Callable[..., Any], decoupled: bool = False) -> Call return fn -def _register_evaluation( - fn: Callable[..., Any], algorithms: str | List[str], disable_grads: bool = True -) -> Callable[..., Any]: - # Disable gradient computation for evaluation - if disable_grads and torch.is_grad_enabled(): - fn = torch.no_grad(fn) - +def _register_evaluation(fn: Callable[..., Any], algorithms: str | List[str]) -> Callable[..., Any]: # lookup containing module if fn.__module__ == "__main__": return fn @@ -109,8 +101,8 @@ def inner_decorator(fn): return inner_decorator -def register_evaluation(algorithms: str | List[str], disable_grads: bool = True): +def register_evaluation(algorithms: str | List[str]): def inner_decorator(fn): - return _register_evaluation(fn, algorithms=algorithms, disable_grads=disable_grads) + return _register_evaluation(fn, algorithms=algorithms) return inner_decorator