Skip to content

Commit

Permalink
refactor: Return info dict right after reset
Browse files (browse the repository at this point in the history)
  • Loading branch information
strakam committed Jan 3, 2025
1 parent 5caac61 commit 1156440
Showing 1 changed file with 15 additions and 5 deletions.
20 changes: 15 additions & 5 deletions generals/envs/multiagent_gymnasium_generals.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -107,12 +107,22 @@ def reset(
elif hasattr(self, "replay"):
del self.replay

obs1 = self.game.agent_observation(self.agents[0]).as_tensor()
obs2 = self.game.agent_observation(self.agents[1]).as_tensor()
observations = np.stack([obs1, obs2], dtype=np.float32)
_obs = {agent: self.game.agent_observation(agent) for agent in self.agents}
observations = np.stack([_obs[agent].as_tensor() for agent in self.agents], dtype=np.float32)

info: dict[str, Any] = {}
return observations, info
infos: dict[str, Any] = self.game.get_infos()
# flatten infos
infos = {
agent: [
infos[agent]["army"],
infos[agent]["land"],
infos[agent]["is_done"],
infos[agent]["is_winner"],
compute_valid_move_mask(_obs[agent]),
]
for i, agent in enumerate(self.agents)
}
return observations, infos

def step(self, actions: list[Action]) -> tuple[Any, Any, bool, bool, dict[str, Any]]:
_actions = {
Expand Down

0 comments on commit 1156440

Please sign in to comment.