Let the user choose num_threads and matmul precision #203

Merged 1 commit on Feb 10, 2024
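This PR exposes two torch-related settings through the top-level config: the number of OpenMP threads and the float32 matmul precision. A minimal standalone sketch of what the two knobs control (illustrative only, not sheeprl's actual entrypoint; the values are examples):

import os
import torch

# Caps the OpenMP worker-thread pool; read when the OpenMP runtime initializes.
os.environ["OMP_NUM_THREADS"] = "1"

# "highest" keeps full float32 matmuls; "high" allows TF32-class matmuls on
# supported hardware; "medium" additionally allows bfloat16-level precision.
torch.set_float32_matmul_precision("high")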
21 changes: 21 additions & 0 deletions sheeprl/cli.py
@@ -53,6 +53,11 @@ def run_algorithm(cfg: Dict[str, Any]):
    Args:
        cfg (Dict[str, Any]): the loaded configuration.
    """

    # Torch settings
    os.environ["OMP_NUM_THREADS"] = str(cfg.num_threads)
    torch.set_float32_matmul_precision(cfg.float32_matmul_precision)

    # Given the algorithm's name, retrieve the module where
    # 'cfg.algo.name'.py is contained; from there retrieve the
    # 'register_algorithm'-decorated entrypoint;
@@ -172,6 +177,10 @@ def eval_algorithm(cfg: DictConfig):
    """
    cfg = dotdict(OmegaConf.to_container(cfg, resolve=True, throw_on_missing=True))

    # Torch settings
    os.environ["OMP_NUM_THREADS"] = str(cfg.num_threads)
    torch.set_float32_matmul_precision(cfg.float32_matmul_precision)

    # TODO: change the number of devices when FSDP will be supported
    accelerator = cfg.fabric.get("accelerator", "auto")
    fabric: Fabric = hydra.utils.instantiate(
@@ -224,6 +233,12 @@ def check_configs(cfg: Dict[str, Any]):
    Args:
        cfg (Dict[str, Any]): the loaded configuration to check.
    """
    if cfg.float32_matmul_precision not in {"medium", "high", "highest"}:
        raise ValueError(
            f"Invalid value '{cfg.float32_matmul_precision}' for the 'float32_matmul_precision' parameter. "
            "It must be one of 'medium', 'high' or 'highest'."
        )

    decoupled = False
    algo_name = cfg.algo.name
    for _, _algos in algorithm_registry.items():
@@ -286,6 +301,12 @@ def check_configs(cfg: Dict[str, Any]):


def check_configs_evaluation(cfg: DictConfig):
    if cfg.float32_matmul_precision not in {"medium", "high", "highest"}:
        raise ValueError(
            f"Invalid value '{cfg.float32_matmul_precision}' for the 'float32_matmul_precision' parameter. "
            "It must be one of 'medium', 'high' or 'highest'."
        )

    if cfg.checkpoint_path is None:
        raise ValueError("You must specify the evaluation checkpoint path")
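For illustration, the new guard extracted in isolation behaves as follows (a sketch with a hypothetical helper name, mirroring the validation added to check_configs and check_configs_evaluation above):

def check_float32_matmul_precision(value: str) -> None:
    # Same accepted set as torch.set_float32_matmul_precision.
    if value not in {"medium", "high", "highest"}:
        raise ValueError(
            f"Invalid value '{value}' for the 'float32_matmul_precision' parameter. "
            "It must be one of 'medium', 'high' or 'highest'."
        )

check_float32_matmul_precision("medium")  # passes
check_float32_matmul_precision("half")    # raises ValueError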
1 change: 1 addition & 0 deletions sheeprl/configs/config.yaml
@@ -15,6 +15,7 @@ defaults:
  - exp: ???

num_threads: 1
float32_matmul_precision: "high"

# Set it to True to run a single optimization step
dry_run: False
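The new default of "high" trades a little float32 matmul precision for speed where TF32 (or an equivalent reduced-precision path) is available. A quick round-trip check using the standard torch API:

import torch

torch.set_float32_matmul_precision("high")
print(torch.get_float32_matmul_precision())  # -> "high"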
2 changes: 2 additions & 0 deletions sheeprl/configs/eval_config.yaml
@@ -18,5 +18,7 @@ env:
  capture_video: True

seed: null
num_threads: 1
disable_grads: True
checkpoint_path: ???
float32_matmul_precision: "high"
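Since both configs are plain OmegaConf/Hydra YAML, the new defaults can be overridden from the command line like any other key. A hedged sketch of the merge semantics (the OmegaConf calls are the real API; the override values are just examples):

from omegaconf import OmegaConf

defaults = OmegaConf.create({"num_threads": 1, "float32_matmul_precision": "high"})
overrides = OmegaConf.from_cli(["num_threads=4", "float32_matmul_precision=medium"])
cfg = OmegaConf.merge(defaults, overrides)
print(cfg.num_threads, cfg.float32_matmul_precision)  # 4 medium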