diff --git a/sheeprl/cli.py b/sheeprl/cli.py index 74e2775c..4861d827 100644 --- a/sheeprl/cli.py +++ b/sheeprl/cli.py @@ -6,6 +6,7 @@ from typing import Any, Dict import hydra +import torch from lightning import Fabric from lightning.fabric.strategies import STRATEGY_REGISTRY, DDPStrategy, SingleDeviceStrategy, Strategy from omegaconf import DictConfig, OmegaConf, open_dict @@ -209,6 +210,8 @@ def eval_algorithm(cfg: DictConfig): ) task = importlib.import_module(f"{module}.{evaluation_file}") command = task.__dict__[entrypoint] + if getattr(cfg, "disable_grads", True): + command = torch.no_grad(command) fabric.launch(command, cfg, state) diff --git a/sheeprl/configs/eval_config.yaml b/sheeprl/configs/eval_config.yaml index 2eeba61a..6487d7a3 100644 --- a/sheeprl/configs/eval_config.yaml +++ b/sheeprl/configs/eval_config.yaml @@ -17,4 +17,5 @@ fabric: env: capture_video: True +disable_grads: True checkpoint_path: ???