diff --git a/sheeprl/cli.py b/sheeprl/cli.py
index 9360c578..8eeef18d 100644
--- a/sheeprl/cli.py
+++ b/sheeprl/cli.py
@@ -53,6 +53,11 @@ def run_algorithm(cfg: Dict[str, Any]):
     Args:
         cfg (Dict[str, Any]): the loaded configuration.
     """
+
+    # Torch settings
+    os.environ["OMP_NUM_THREADS"] = str(cfg.num_threads)
+    torch.set_float32_matmul_precision(cfg.float32_matmul_precision)
+
     # Given the algorithm's name, retrieve the module where
     # 'cfg.algo.name'.py is contained; from there retrieve the
     # 'register_algorithm'-decorated entrypoint;
@@ -172,6 +177,10 @@ def eval_algorithm(cfg: DictConfig):
     """
     cfg = dotdict(OmegaConf.to_container(cfg, resolve=True, throw_on_missing=True))
 
+    # Torch settings
+    os.environ["OMP_NUM_THREADS"] = str(cfg.num_threads)
+    torch.set_float32_matmul_precision(cfg.float32_matmul_precision)
+
     # TODO: change the number of devices when FSDP will be supported
     accelerator = cfg.fabric.get("accelerator", "auto")
     fabric: Fabric = hydra.utils.instantiate(
@@ -224,6 +233,12 @@ def check_configs(cfg: Dict[str, Any]):
     Args:
         cfg (Dict[str, Any]): the loaded configuration to check.
     """
+    if cfg.float32_matmul_precision not in {"medium", "high", "highest"}:
+        raise ValueError(
+            f"Invalid value '{cfg.float32_matmul_precision}' for the 'float32_matmul_precision' parameter. "
+            "It must be one of 'medium', 'high' or 'highest'."
+        )
+
     decoupled = False
     algo_name = cfg.algo.name
     for _, _algos in algorithm_registry.items():
@@ -286,6 +301,12 @@ def check_configs(cfg: Dict[str, Any]):
 
 
 def check_configs_evaluation(cfg: DictConfig):
+    if cfg.float32_matmul_precision not in {"medium", "high", "highest"}:
+        raise ValueError(
+            f"Invalid value '{cfg.float32_matmul_precision}' for the 'float32_matmul_precision' parameter. "
+            "It must be one of 'medium', 'high' or 'highest'."
+        )
+
     if cfg.checkpoint_path is None:
         raise ValueError("You must specify the evaluation checkpoint path")
 
diff --git a/sheeprl/configs/config.yaml b/sheeprl/configs/config.yaml
index a9e795cf..0a743380 100644
--- a/sheeprl/configs/config.yaml
+++ b/sheeprl/configs/config.yaml
@@ -15,6 +15,7 @@ defaults:
   - exp: ???
 
 num_threads: 1
+float32_matmul_precision: "high"
 
 # Set it to True to run a single optimization step
 dry_run: False
diff --git a/sheeprl/configs/eval_config.yaml b/sheeprl/configs/eval_config.yaml
index bc1b4df6..52d99423 100644
--- a/sheeprl/configs/eval_config.yaml
+++ b/sheeprl/configs/eval_config.yaml
@@ -18,5 +18,7 @@ env:
   capture_video: True
 
 seed: null
+num_threads: 1
 disable_grads: True
 checkpoint_path: ???
+float32_matmul_precision: "high"
\ No newline at end of file
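
Note on the new settings (not part of the diff): torch.set_float32_matmul_precision accepts only "medium", "high" or "highest", which is what the new checks in check_configs/check_configs_evaluation enforce up front with a clearer error message. The standalone sketch below assumes only PyTorch; the hard-coded values stand in for cfg.num_threads and cfg.float32_matmul_precision and are not taken from the repository.

    import os

    # OMP_NUM_THREADS is read when the OpenMP runtime starts up, so this
    # sketch exports it before importing torch; the diff instead sets it at
    # the very beginning of run_algorithm/eval_algorithm.
    os.environ["OMP_NUM_THREADS"] = "1"

    import torch

    # Only these three values are accepted by PyTorch; they trade float32
    # matmul accuracy for speed (e.g. allowing TF32 kernels on Ampere GPUs).
    for precision in ("medium", "high", "highest"):
        torch.set_float32_matmul_precision(precision)
        print(precision, "->", torch.get_float32_matmul_precision())

Since both keys are plain top-level entries in config.yaml/eval_config.yaml, they should be overridable with the usual Hydra syntax, e.g. by appending num_threads=4 float32_matmul_precision=medium to the launch command.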