diff --git a/examples/complete_example.py b/examples/complete_example.py index 2ca9f83..2b89c65 100644 --- a/examples/complete_example.py +++ b/examples/complete_example.py @@ -17,8 +17,9 @@ env = gym.make( "gym-generals-v0", # Environment name grid_factory=grid_factory, # Grid factory - agent=agent, # Your agent (used to get metadata like name and color) npc=npc, # NPC that will play against the agent + agent_id="Agent", # Agent ID + agent_color=(67, 70, 86), # Agent color render_mode="human", # "human" mode is for rendering, None is for no rendering ) diff --git a/examples/pettingzoo_example.py b/examples/pettingzoo_example.py index fa001bb..2badac9 100644 --- a/examples/pettingzoo_example.py +++ b/examples/pettingzoo_example.py @@ -11,7 +11,7 @@ } # Environment calls agents by name # Create environment -- render modes: {None, "human"} -env = gym.make("pz-generals-v0", agent_ids=list(agents.keys()), render_mode="human") +env = gym.make("pz-generals-v0", agents=list(agents.keys()), render_mode="human") observations, info = env.reset() done = False diff --git a/examples/record_replay_example.py b/examples/record_replay_example.py index faee6e0..f62f772 100644 --- a/examples/record_replay_example.py +++ b/examples/record_replay_example.py @@ -17,7 +17,8 @@ env = gym.make( "gym-generals-v0", # Environment name grid_factory=grid_factory, # Grid factory - agent=agent, # Your agent (used to get metadata like name and color) + agent_id="Agent", # Agent ID + agent_color=(67, 70, 86), # Agent color npc=npc, # NPC that will play against the agent ) diff --git a/generals/envs/env.py b/generals/envs/env.py index 73030ec..7ce4522 100644 --- a/generals/envs/env.py +++ b/generals/envs/env.py @@ -1,10 +1,5 @@ from .gymnasium_generals import GymnasiumGenerals from .pettingzoo_generals import PettingZooGenerals -from .gymnasium_wrappers import ( - NormalizeObservationWrapper, - RemoveActionMaskWrapper, - ObservationAsImageWrapper, -) from generals.agents import Agent, AgentFactory from generals import GridFactory @@ -22,13 +17,13 @@ def pz_generals_v0( grid_factory: GridFactory = GridFactory(), - agent_ids: list[str] = None, + agents: list[str] = None, render_mode=None, ): - assert len(agent_ids) == 2, "For now, only 2 agents are supported in PZ_Generals." + assert len(agents) == 2, "For now, only 2 agents are supported in PZ_Generals." env = PettingZooGenerals( grid_factory=grid_factory, - agent_ids=agent_ids, + agents=agents, render_mode=render_mode, ) return env @@ -54,7 +49,4 @@ def gym_generals_v0( agent_color=agent_color, reward_fn=reward_fn, ) - # env = NormalizeObservationWrapper(env) - # env = RemoveActionMaskWrapper(env) - # env = ObservationAsImageWrapper(env) return env diff --git a/generals/envs/gymnasium_generals.py b/generals/envs/gymnasium_generals.py index 9b9edcb..2602a84 100644 --- a/generals/envs/gymnasium_generals.py +++ b/generals/envs/gymnasium_generals.py @@ -144,4 +144,5 @@ def _default_reward( return reward def close(self) -> None: - self.gui.close() + if hasattr(self, "replay"): + self.gui.close() diff --git a/generals/envs/pettingzoo_generals.py b/generals/envs/pettingzoo_generals.py index 635f267..1cf610d 100644 --- a/generals/envs/pettingzoo_generals.py +++ b/generals/envs/pettingzoo_generals.py @@ -33,7 +33,7 @@ class PettingZooGenerals(pettingzoo.ParallelEnv): def __init__( self, grid_factory: GridFactory, - agent_ids: list[str], + agents: list[str], reward_fn: RewardFn = None, render_mode=None, ): @@ -46,10 +46,10 @@ def __init__( self.agent_data = { agent_id: {"color": color} - for agent_id, color in zip(agent_ids, self.default_colors) + for agent_id, color in zip(agents, self.default_colors) } - self.agent_ids = agent_ids - self.possible_agents = agent_ids + self.agents = agents + self.possible_agents = agents assert len(self.possible_agents) == len( set(self.possible_agents) @@ -83,7 +83,7 @@ def reset( else: grid = self.grid_factory.grid_from_generator(seed=seed) - self.game = Game(grid, self.agent_ids) + self.game = Game(grid, self.agents) if self.render_mode == "human": self.gui = GUI(self.game, self.agent_data, GuiMode.TRAIN) @@ -99,7 +99,7 @@ def reset( del self.replay observations = self.game.get_all_observations() - infos = {agent: {} for agent in self.agent_ids} + infos = {agent: {} for agent in self.agents} return observations, infos def step( @@ -112,9 +112,9 @@ def step( dict[AgentID, Info], ]: observations, infos = self.game.step(actions) - truncated = {agent: False for agent in self.agent_ids} # no truncation + truncated = {agent: False for agent in self.agents} # no truncation terminated = { - agent: True if self.game.is_done() else False for agent in self.agent_ids + agent: True if self.game.is_done() else False for agent in self.agents } rewards = { agent: self.reward_fn( @@ -123,7 +123,7 @@ def step( terminated[agent] or truncated[agent], infos[agent], ) - for agent in self.agent_ids + for agent in self.agents } if hasattr(self, "replay"): @@ -132,7 +132,7 @@ def step( # if any agent dies, all agents are terminated terminate = any(terminated.values()) if terminate: - self.agents_ids = [] + self.agents = [] if hasattr(self, "replay"): self.replay.store() return observations, rewards, terminated, truncated, infos @@ -154,4 +154,5 @@ def _default_reward( return reward def close(self) -> None: - self.gui.close() + if self.render_mode == "human": + self.gui.close() diff --git a/tests/parallel_api_check.py b/tests/parallel_api_check.py index a5b622e..f26988e 100644 --- a/tests/parallel_api_check.py +++ b/tests/parallel_api_check.py @@ -1,10 +1,9 @@ from __future__ import annotations -from generals import pz_generals -from generals.agents import RandomAgent -from generals.core.grid import GridFactory +from generals import GridFactory, AgentFactory import warnings import numpy as np +import gymnasium as gym from pettingzoo.test.api_test import missing_attr_warning from pettingzoo.utils.conversions import ( @@ -141,13 +140,13 @@ def parallel_api_test(par_env: ParallelEnv, num_cycles=1000): if __name__ == "__main__": mapper = GridFactory() - agent1 = RandomAgent(name="A") - agent2 = RandomAgent(name="B") + agent1 = AgentFactory.make_agent("expander", id="A") + agent2 = AgentFactory.make_agent("random", id="B") agents = { - agent1.name: agent1, - agent2.name: agent2, + agent1.id: agent1, + agent2.id: agent2, } - env = pz_generals(mapper, agents) + env = gym.make("pz-generals-v0", agents=list(agents.keys()), grid_factory=mapper) # test the environment with parallel_api_test import time start = time.time()