improve infocallback, try lunarLander
valentincuzin committed Jan 3, 2025
1 parent 3af62e9 commit 6f46ea7
Showing 4 changed files with 145 additions and 23 deletions.
src/Environments/LunarLander.py (6 changes: 3 additions & 3 deletions)
@@ -11,7 +11,7 @@ def __init__(self, algo: Algo):
     def __repr__(self):
         return "LunarLander-v3"

-    def success_func(self, env: gym.Env, info: dict) -> bool:
+    def success_func(self, env: gym.Env, info: dict) -> tuple[bool, bool]:
         """
         This function checks whether the lander is "awake" and updates the info.
         """
@@ -22,9 +22,9 @@ def success_func(self, env: gym.Env, info: dict) -> bool:

         # check if the lander is awake
         if hasattr(base_env, "lander") and not base_env.lander.awake:
-            return True
+            return True, False
         else:
-            return False
+            return False, True

     def objective_metric(self, states) -> list[dict[str, float]]:
         pass  # TODO
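
With this change, success_func reports two flags instead of one. Below is a minimal sketch of how a caller might unpack them, assuming the pair reads as (success, failure); the rollout loop, the policy argument, and the wrapper variable name are illustrative and not part of this commit.

# Hypothetical rollout loop: only the two-flag return of success_func
# comes from the commit above; everything else here is assumed.
import gymnasium as gym

def run_episode(env: gym.Env, wrapper, policy) -> bool:
    """Roll out one episode and report whether the lander came to rest."""
    obs, info = env.reset()
    done = False
    while not done:
        obs, reward, terminated, truncated, info = env.step(policy(obs))
        done = terminated or truncated
    success, failure = wrapper.success_func(env, info)  # new two-flag return
    return success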
src/PolicyTrainer/TrainingInfoCallback.py (33 changes: 14 additions & 19 deletions)
@@ -17,34 +17,29 @@ def __init__(self):

     def _on_training_start(self):
         """Initialize at the start of training."""
-        self.num_envs = self.training_env.num_envs
-        self.current_episode_rewards = np.zeros(self.num_envs)
-        self.current_episode_lengths = np.zeros(self.num_envs, dtype=int)
+        self.current_episode_rewards = 0
+        self.current_episode_lengths = 0

     def _on_step(self) -> bool:
         """Called at every step."""
         obs = self.locals["new_obs"]
         rewards = self.locals["rewards"]
         dones = self.locals["dones"]

-        self.current_episode_rewards += rewards
+        self.current_episode_rewards += rewards[0]
         self.current_episode_lengths += 1

-        for i in range(self.num_envs):
-            if dones[i]:
-                self.training_metrics["episode_observations"].append(
-                    obs
-                )
-                self.training_metrics["episode_rewards"].append(
-                    self.current_episode_rewards[i]
-                )
-                self.training_metrics["episode_lengths"].append(
-                    self.current_episode_lengths[i]
-                )
-
-                self.current_episode_rewards[i] = 0
-                self.current_episode_lengths[i] = 0
-
+        self.training_metrics["episode_observations"].append(
+            obs[0]
+        )
+        if dones[0]:
+            self.training_metrics["episode_rewards"].append(
+                self.current_episode_rewards
+            )
+            self.training_metrics["episode_lengths"].append(
+                self.current_episode_lengths
+            )
+            self.current_episode_rewards = 0
+            self.current_episode_lengths = 0
         return True

     def _on_training_end(self) -> None:
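
For context, here is one way this callback might be wired into a training run. This is a usage sketch, not part of the commit: the PPO choice, the timestep budget, and the assumption that training_metrics is a dict of lists initialized in __init__ are all illustrative. Note that _on_step now reads only index 0 of rewards, dones, and new_obs, so the callback effectively assumes a single (non-vectorized) training environment.

# Hypothetical usage sketch; only TrainingInfoCallback itself is from this commit.
import gymnasium as gym
from stable_baselines3 import PPO

from src.PolicyTrainer.TrainingInfoCallback import TrainingInfoCallback

env = gym.make("LunarLander-v3")  # requires gymnasium[box2d]
callback = TrainingInfoCallback()

model = PPO("MlpPolicy", env, verbose=0)
model.learn(total_timesteps=50_000, callback=callback)

# episode_rewards and episode_lengths hold one entry per finished episode;
# episode_observations is appended at every step (see _on_step above).
print(len(callback.training_metrics["episode_rewards"]))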