Skip to content

Commit

Permalink
fix: division by 0 (#283)
Browse files Browse the repository at this point in the history
  • Loading branch information
michele-milesi authored May 9, 2024
1 parent 304e931 commit 75d752f
Show file tree
Hide file tree
Showing 20 changed files with 53 additions and 46 deletions.
4 changes: 2 additions & 2 deletions howto/register_external_algorithm.md
Original file line number Diff line number Diff line change
Expand Up @@ -663,13 +663,13 @@ def ext_sota_main(fabric: Fabric, cfg: Dict[str, Any]):
# Sync distributed timers
if not timer.disabled:
timer_metrics = timer.compute()
if "Time/train_time" in timer_metrics:
if "Time/train_time" in timer_metrics and timer_metrics["Time/train_time"] > 0:
fabric.log(
"Time/sps_train",
(train_step - last_train) / timer_metrics["Time/train_time"],
policy_step,
)
if "Time/env_interaction_time" in timer_metrics:
if "Time/env_interaction_time" in timer_metrics and timer_metrics["Time/env_interaction_time"] > 0:
fabric.log(
"Time/sps_env_interaction",
((policy_step - last_log) / world_size * cfg.env.action_repeat)
Expand Down
4 changes: 2 additions & 2 deletions howto/register_new_algorithm.md
Original file line number Diff line number Diff line change
Expand Up @@ -661,13 +661,13 @@ def sota_main(fabric: Fabric, cfg: Dict[str, Any]):
# Sync distributed timers
if not timer.disabled:
timer_metrics = timer.compute()
if "Time/train_time" in timer_metrics:
if "Time/train_time" in timer_metrics and timer_metrics["Time/train_time"] > 0:
fabric.log(
"Time/sps_train",
(train_step - last_train) / timer_metrics["Time/train_time"],
policy_step,
)
if "Time/env_interaction_time" in timer_metrics:
if "Time/env_interaction_time" in timer_metrics and timer_metrics["Time/env_interaction_time"] > 0:
fabric.log(
"Time/sps_env_interaction",
((policy_step - last_log) / world_size * cfg.env.action_repeat)
Expand Down
4 changes: 2 additions & 2 deletions sheeprl/algos/a2c/a2c.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,13 +335,13 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
# Sync distributed timers
if not timer.disabled:
timer_metrics = timer.compute()
if "Time/train_time" in timer_metrics:
if "Time/train_time" in timer_metrics and timer_metrics["Time/train_time"] > 0:
fabric.log(
"Time/sps_train",
(train_step - last_train) / timer_metrics["Time/train_time"],
policy_step,
)
if "Time/env_interaction_time" in timer_metrics:
if "Time/env_interaction_time" in timer_metrics and timer_metrics["Time/env_interaction_time"] > 0:
fabric.log(
"Time/sps_env_interaction",
((policy_step - last_log) / world_size * cfg.env.action_repeat)
Expand Down
4 changes: 2 additions & 2 deletions sheeprl/algos/dreamer_v1/dreamer_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -691,13 +691,13 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
# Sync distributed timers
if not timer.disabled:
timer_metrics = timer.compute()
if "Time/train_time" in timer_metrics:
if "Time/train_time" in timer_metrics and timer_metrics["Time/train_time"] > 0:
fabric.log(
"Time/sps_train",
(train_step - last_train) / timer_metrics["Time/train_time"],
policy_step,
)
if "Time/env_interaction_time" in timer_metrics:
if "Time/env_interaction_time" in timer_metrics and timer_metrics["Time/env_interaction_time"] > 0:
fabric.log(
"Time/sps_env_interaction",
((policy_step - last_log) / world_size * cfg.env.action_repeat)
Expand Down
4 changes: 2 additions & 2 deletions sheeprl/algos/dreamer_v2/dreamer_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -725,13 +725,13 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
# Sync distributed timers
if not timer.disabled:
timer_metrics = timer.compute()
if "Time/train_time" in timer_metrics:
if "Time/train_time" in timer_metrics and timer_metrics["Time/train_time"] > 0:
fabric.log(
"Time/sps_train",
(train_step - last_train) / timer_metrics["Time/train_time"],
policy_step,
)
if "Time/env_interaction_time" in timer_metrics:
if "Time/env_interaction_time" in timer_metrics and timer_metrics["Time/env_interaction_time"] > 0:
fabric.log(
"Time/sps_env_interaction",
((policy_step - last_log) / world_size * cfg.env.action_repeat)
Expand Down
4 changes: 2 additions & 2 deletions sheeprl/algos/dreamer_v3/dreamer_v3.py
Original file line number Diff line number Diff line change
Expand Up @@ -710,13 +710,13 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
# Sync distributed timers
if not timer.disabled:
timer_metrics = timer.compute()
if "Time/train_time" in timer_metrics:
if "Time/train_time" in timer_metrics and timer_metrics["Time/train_time"] > 0:
fabric.log(
"Time/sps_train",
(train_step - last_train) / timer_metrics["Time/train_time"],
policy_step,
)
if "Time/env_interaction_time" in timer_metrics:
if "Time/env_interaction_time" in timer_metrics and timer_metrics["Time/env_interaction_time"] > 0:
fabric.log(
"Time/sps_env_interaction",
((policy_step - last_log) / world_size * cfg.env.action_repeat)
Expand Down
4 changes: 2 additions & 2 deletions sheeprl/algos/droq/droq.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,13 +379,13 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
# Sync distributed timers
if not timer.disabled:
timer_metrics = timer.compute()
if "Time/train_time" in timer_metrics:
if "Time/train_time" in timer_metrics and timer_metrics["Time/train_time"] > 0:
fabric.log(
"Time/sps_train",
(train_step - last_train) / timer_metrics["Time/train_time"],
policy_step,
)
if "Time/env_interaction_time" in timer_metrics:
if "Time/env_interaction_time" in timer_metrics and timer_metrics["Time/env_interaction_time"] > 0:
fabric.log(
"Time/sps_env_interaction",
((policy_step - last_log) / world_size * cfg.env.action_repeat)
Expand Down
4 changes: 2 additions & 2 deletions sheeprl/algos/p2e_dv1/p2e_dv1_exploration.py
Original file line number Diff line number Diff line change
Expand Up @@ -725,13 +725,13 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
# Sync distributed timers
if not timer.disabled:
timer_metrics = timer.compute()
if "Time/train_time" in timer_metrics:
if "Time/train_time" in timer_metrics and timer_metrics["Time/train_time"] > 0:
fabric.log(
"Time/sps_train",
(train_step - last_train) / timer_metrics["Time/train_time"],
policy_step,
)
if "Time/env_interaction_time" in timer_metrics:
if "Time/env_interaction_time" in timer_metrics and timer_metrics["Time/env_interaction_time"] > 0:
fabric.log(
"Time/sps_env_interaction",
((policy_step - last_log) / world_size * cfg.env.action_repeat)
Expand Down
4 changes: 2 additions & 2 deletions sheeprl/algos/p2e_dv1/p2e_dv1_finetuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,13 +378,13 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]):
# Sync distributed timers
if not timer.disabled:
timer_metrics = timer.compute()
if "Time/train_time" in timer_metrics:
if "Time/train_time" in timer_metrics and timer_metrics["Time/train_time"] > 0:
fabric.log(
"Time/sps_train",
(train_step - last_train) / timer_metrics["Time/train_time"],
policy_step,
)
if "Time/env_interaction_time" in timer_metrics:
if "Time/env_interaction_time" in timer_metrics and timer_metrics["Time/env_interaction_time"] > 0:
fabric.log(
"Time/sps_env_interaction",
((policy_step - last_log) / world_size * cfg.env.action_repeat)
Expand Down
4 changes: 2 additions & 2 deletions sheeprl/algos/p2e_dv2/p2e_dv2_exploration.py
Original file line number Diff line number Diff line change
Expand Up @@ -874,13 +874,13 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
# Sync distributed timers
if not timer.disabled:
timer_metrics = timer.compute()
if "Time/train_time" in timer_metrics:
if "Time/train_time" in timer_metrics and timer_metrics["Time/train_time"] > 0:
fabric.log(
"Time/sps_train",
(train_step - last_train) / timer_metrics["Time/train_time"],
policy_step,
)
if "Time/env_interaction_time" in timer_metrics:
if "Time/env_interaction_time" in timer_metrics and timer_metrics["Time/env_interaction_time"] > 0:
fabric.log(
"Time/sps_env_interaction",
((policy_step - last_log) / world_size * cfg.env.action_repeat)
Expand Down
4 changes: 2 additions & 2 deletions sheeprl/algos/p2e_dv2/p2e_dv2_finetuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,13 +405,13 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]):
# Sync distributed timers
if not timer.disabled:
timer_metrics = timer.compute()
if "Time/train_time" in timer_metrics:
if "Time/train_time" in timer_metrics and timer_metrics["Time/train_time"] > 0:
fabric.log(
"Time/sps_train",
(train_step - last_train) / timer_metrics["Time/train_time"],
policy_step,
)
if "Time/env_interaction_time" in timer_metrics:
if "Time/env_interaction_time" in timer_metrics and timer_metrics["Time/env_interaction_time"] > 0:
fabric.log(
"Time/sps_env_interaction",
((policy_step - last_log) / world_size * cfg.env.action_repeat)
Expand Down
4 changes: 2 additions & 2 deletions sheeprl/algos/p2e_dv3/p2e_dv3_exploration.py
Original file line number Diff line number Diff line change
Expand Up @@ -964,13 +964,13 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
# Sync distributed timers
if not timer.disabled:
timer_metrics = timer.compute()
if "Time/train_time" in timer_metrics:
if "Time/train_time" in timer_metrics and timer_metrics["Time/train_time"] > 0:
fabric.log(
"Time/sps_train",
(train_step - last_train) / timer_metrics["Time/train_time"],
policy_step,
)
if "Time/env_interaction_time" in timer_metrics:
if "Time/env_interaction_time" in timer_metrics and timer_metrics["Time/env_interaction_time"] > 0:
fabric.log(
"Time/sps_env_interaction",
((policy_step - last_log) / world_size * cfg.env.action_repeat)
Expand Down
4 changes: 2 additions & 2 deletions sheeprl/algos/p2e_dv3/p2e_dv3_finetuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,13 +405,13 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]):
# Sync distributed timers
if not timer.disabled:
timer_metrics = timer.compute()
if "Time/train_time" in timer_metrics:
if "Time/train_time" in timer_metrics and timer_metrics["Time/train_time"] > 0:
fabric.log(
"Time/sps_train",
(train_step - last_train) / timer_metrics["Time/train_time"],
policy_step,
)
if "Time/env_interaction_time" in timer_metrics:
if "Time/env_interaction_time" in timer_metrics and timer_metrics["Time/env_interaction_time"] > 0:
fabric.log(
"Time/sps_env_interaction",
((policy_step - last_log) / world_size * cfg.env.action_repeat)
Expand Down
4 changes: 2 additions & 2 deletions sheeprl/algos/ppo/ppo.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,13 +392,13 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
# Sync distributed timers
if not timer.disabled:
timer_metrics = timer.compute()
if "Time/train_time" in timer_metrics:
if "Time/train_time" in timer_metrics and timer_metrics["Time/train_time"] > 0:
fabric.log(
"Time/sps_train",
(train_step - last_train) / timer_metrics["Time/train_time"],
policy_step,
)
if "Time/env_interaction_time" in timer_metrics:
if "Time/env_interaction_time" in timer_metrics and timer_metrics["Time/env_interaction_time"] > 0:
fabric.log(
"Time/sps_env_interaction",
((policy_step - last_log) / world_size * cfg.env.action_repeat)
Expand Down
14 changes: 8 additions & 6 deletions sheeprl/algos/ppo/ppo_decoupled.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,11 +320,12 @@ def player(
# Sync timers
if not timer.disabled:
timer_metrics = timer.compute()
fabric.log(
"Time/sps_env_interaction",
((policy_step - last_log) * cfg.env.action_repeat) / timer_metrics["Time/env_interaction_time"],
policy_step,
)
if "Time/env_interaction_time" in timer_metrics and timer_metrics["Time/env_interaction_time"] > 0:
fabric.log(
"Time/sps_env_interaction",
((policy_step - last_log) * cfg.env.action_repeat) / timer_metrics["Time/env_interaction_time"],
policy_step,
)
timer.reset()

# Reset counters
Expand Down Expand Up @@ -563,7 +564,8 @@ def trainer(
# Sync distributed timers
if not timer.disabled:
timers = timer.compute()
metrics.update({"Time/sps_train": (train_step - last_train) / timers["Time/train_time"]})
if "Time/train_time" in timers and timers["Time/train_time"] > 0:
metrics.update({"Time/sps_train": (train_step - last_train) / timers["Time/train_time"]})
timer.reset()

# Send metrics to the player
Expand Down
4 changes: 2 additions & 2 deletions sheeprl/algos/ppo_recurrent/ppo_recurrent.py
Original file line number Diff line number Diff line change
Expand Up @@ -465,13 +465,13 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
# Sync distributed timers
if not timer.disabled:
timer_metrics = timer.compute()
if "Time/train_time" in timer_metrics:
if "Time/train_time" in timer_metrics and timer_metrics["Time/train_time"] > 0:
fabric.log(
"Time/sps_train",
(train_step - last_train) / timer_metrics["Time/train_time"],
policy_step,
)
if "Time/env_interaction_time" in timer_metrics:
if "Time/env_interaction_time" in timer_metrics and timer_metrics["Time/env_interaction_time"] > 0:
fabric.log(
"Time/sps_env_interaction",
((policy_step - last_log) / world_size * cfg.env.action_repeat)
Expand Down
4 changes: 2 additions & 2 deletions sheeprl/algos/sac/sac.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,13 +370,13 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
# Sync distributed timers
if not timer.disabled:
timer_metrics = timer.compute()
if "Time/train_time" in timer_metrics:
if "Time/train_time" in timer_metrics and timer_metrics["Time/train_time"] > 0:
fabric.log(
"Time/sps_train",
(train_step - last_train) / timer_metrics["Time/train_time"],
policy_step,
)
if "Time/env_interaction_time" in timer_metrics:
if "Time/env_interaction_time" in timer_metrics and timer_metrics["Time/env_interaction_time"] > 0:
fabric.log(
"Time/sps_env_interaction",
((policy_step - last_log) / world_size * cfg.env.action_repeat)
Expand Down
14 changes: 8 additions & 6 deletions sheeprl/algos/sac/sac_decoupled.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,11 +284,12 @@ def player(
# Sync timers
if not timer.disabled:
timer_metrics = timer.compute()
fabric.log(
"Time/sps_env_interaction",
((policy_step - last_log) * cfg.env.action_repeat) / timer_metrics["Time/env_interaction_time"],
policy_step,
)
if "Time/env_interaction_time" in timer_metrics and timer_metrics["Time/env_interaction_time"] > 0:
fabric.log(
"Time/sps_env_interaction",
((policy_step - last_log) * cfg.env.action_repeat) / timer_metrics["Time/env_interaction_time"],
policy_step,
)
timer.reset()

# Reset counters
Expand Down Expand Up @@ -500,7 +501,8 @@ def trainer(
# Sync distributed timers
if not timer.disabled:
timers = timer.compute()
metrics.update({"Time/sps_train": (train_step - last_train) / timers["Time/train_time"]})
if "Time/train_time" in timers and timers["Time/train_time"] > 0:
metrics.update({"Time/sps_train": (train_step - last_train) / timers["Time/train_time"]})
timer.reset()

if global_rank == 1:
Expand Down
4 changes: 2 additions & 2 deletions sheeprl/algos/sac_ae/sac_ae.py
Original file line number Diff line number Diff line change
Expand Up @@ -441,13 +441,13 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
# Sync distributed timers
if not timer.disabled:
timer_metrics = timer.compute()
if "Time/train_time" in timer_metrics:
if "Time/train_time" in timer_metrics and timer_metrics["Time/train_time"] > 0:
fabric.log(
"Time/sps_train",
(train_step - last_train) / timer_metrics["Time/train_time"],
policy_step,
)
if "Time/env_interaction_time" in timer_metrics:
if "Time/env_interaction_time" in timer_metrics and timer_metrics["Time/env_interaction_time"] > 0:
fabric.log(
"Time/sps_env_interaction",
((policy_step - last_log) / world_size * cfg.env.action_repeat)
Expand Down
3 changes: 3 additions & 0 deletions sheeprl/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,6 +340,9 @@ def check_configs(cfg: Dict[str, Any]):
if cfg.algo.learning_starts is not None and cfg.algo.learning_starts < 0:
raise ValueError("The `algo.learning_starts` parameter must be greater or equal to zero.")

if cfg.env.action_repeat < 1:
cfg.env.action_repeat = 1


def check_configs_evaluation(cfg: DictConfig):
if cfg.float32_matmul_precision not in {"medium", "high", "highest"}:
Expand Down

0 comments on commit 75d752f

Please sign in to comment.