Skip to content

Commit

Permalink
refactor: Fix some naming stuff
Browse files Browse the repository at this point in the history
  • Loading branch information
strakam committed Oct 12, 2024
1 parent 03465b8 commit acf8fd5
Show file tree
Hide file tree
Showing 7 changed files with 29 additions and 34 deletions.
3 changes: 2 additions & 1 deletion examples/complete_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,9 @@
env = gym.make(
"gym-generals-v0", # Environment name
grid_factory=grid_factory, # Grid factory
agent=agent, # Your agent (used to get metadata like name and color)
npc=npc, # NPC that will play against the agent
agent_id="Agent", # Agent ID
agent_color=(67, 70, 86), # Agent color
render_mode="human", # "human" mode is for rendering, None is for no rendering
)

Expand Down
2 changes: 1 addition & 1 deletion examples/pettingzoo_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
} # Environment calls agents by name

# Create environment -- render modes: {None, "human"}
env = gym.make("pz-generals-v0", agent_ids=list(agents.keys()), render_mode="human")
env = gym.make("pz-generals-v0", agents=list(agents.keys()), render_mode="human")
observations, info = env.reset()

done = False
Expand Down
3 changes: 2 additions & 1 deletion examples/record_replay_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@
env = gym.make(
"gym-generals-v0", # Environment name
grid_factory=grid_factory, # Grid factory
agent=agent, # Your agent (used to get metadata like name and color)
agent_id="Agent", # Agent ID
agent_color=(67, 70, 86), # Agent color
npc=npc, # NPC that will play against the agent
)

Expand Down
14 changes: 3 additions & 11 deletions generals/envs/env.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,5 @@
from .gymnasium_generals import GymnasiumGenerals
from .pettingzoo_generals import PettingZooGenerals
from .gymnasium_wrappers import (
NormalizeObservationWrapper,
RemoveActionMaskWrapper,
ObservationAsImageWrapper,
)
from generals.agents import Agent, AgentFactory

from generals import GridFactory
Expand All @@ -22,13 +17,13 @@

def pz_generals_v0(
grid_factory: GridFactory = GridFactory(),
agent_ids: list[str] = None,
agents: list[str] = None,
render_mode=None,
):
assert len(agent_ids) == 2, "For now, only 2 agents are supported in PZ_Generals."
assert len(agents) == 2, "For now, only 2 agents are supported in PZ_Generals."
env = PettingZooGenerals(
grid_factory=grid_factory,
agent_ids=agent_ids,
agents=agents,
render_mode=render_mode,
)
return env
Expand All @@ -54,7 +49,4 @@ def gym_generals_v0(
agent_color=agent_color,
reward_fn=reward_fn,
)
# env = NormalizeObservationWrapper(env)
# env = RemoveActionMaskWrapper(env)
# env = ObservationAsImageWrapper(env)
return env
3 changes: 2 additions & 1 deletion generals/envs/gymnasium_generals.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,4 +144,5 @@ def _default_reward(
return reward

def close(self) -> None:
self.gui.close()
if hasattr(self, "replay"):
self.gui.close()
23 changes: 12 additions & 11 deletions generals/envs/pettingzoo_generals.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ class PettingZooGenerals(pettingzoo.ParallelEnv):
def __init__(
self,
grid_factory: GridFactory,
agent_ids: list[str],
agents: list[str],
reward_fn: RewardFn = None,
render_mode=None,
):
Expand All @@ -46,10 +46,10 @@ def __init__(

self.agent_data = {
agent_id: {"color": color}
for agent_id, color in zip(agent_ids, self.default_colors)
for agent_id, color in zip(agents, self.default_colors)
}
self.agent_ids = agent_ids
self.possible_agents = agent_ids
self.agents = agents
self.possible_agents = agents

assert len(self.possible_agents) == len(
set(self.possible_agents)
Expand Down Expand Up @@ -83,7 +83,7 @@ def reset(
else:
grid = self.grid_factory.grid_from_generator(seed=seed)

self.game = Game(grid, self.agent_ids)
self.game = Game(grid, self.agents)

if self.render_mode == "human":
self.gui = GUI(self.game, self.agent_data, GuiMode.TRAIN)
Expand All @@ -99,7 +99,7 @@ def reset(
del self.replay

observations = self.game.get_all_observations()
infos = {agent: {} for agent in self.agent_ids}
infos = {agent: {} for agent in self.agents}
return observations, infos

def step(
Expand All @@ -112,9 +112,9 @@ def step(
dict[AgentID, Info],
]:
observations, infos = self.game.step(actions)
truncated = {agent: False for agent in self.agent_ids} # no truncation
truncated = {agent: False for agent in self.agents} # no truncation
terminated = {
agent: True if self.game.is_done() else False for agent in self.agent_ids
agent: True if self.game.is_done() else False for agent in self.agents
}
rewards = {
agent: self.reward_fn(
Expand All @@ -123,7 +123,7 @@ def step(
terminated[agent] or truncated[agent],
infos[agent],
)
for agent in self.agent_ids
for agent in self.agents
}

if hasattr(self, "replay"):
Expand All @@ -132,7 +132,7 @@ def step(
# if any agent dies, all agents are terminated
terminate = any(terminated.values())
if terminate:
self.agents_ids = []
self.agents = []
if hasattr(self, "replay"):
self.replay.store()
return observations, rewards, terminated, truncated, infos
Expand All @@ -154,4 +154,5 @@ def _default_reward(
return reward

def close(self) -> None:
self.gui.close()
if self.render_mode == "human":
self.gui.close()
15 changes: 7 additions & 8 deletions tests/parallel_api_check.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
from __future__ import annotations
from generals import pz_generals
from generals.agents import RandomAgent
from generals.core.grid import GridFactory
from generals import GridFactory, AgentFactory
import warnings

import numpy as np
import gymnasium as gym

from pettingzoo.test.api_test import missing_attr_warning
from pettingzoo.utils.conversions import (
Expand Down Expand Up @@ -141,13 +140,13 @@ def parallel_api_test(par_env: ParallelEnv, num_cycles=1000):

if __name__ == "__main__":
mapper = GridFactory()
agent1 = RandomAgent(name="A")
agent2 = RandomAgent(name="B")
agent1 = AgentFactory.make_agent("expander", id="A")
agent2 = AgentFactory.make_agent("random", id="B")
agents = {
agent1.name: agent1,
agent2.name: agent2,
agent1.id: agent1,
agent2.id: agent2,
}
env = pz_generals(mapper, agents)
env = gym.make("pz-generals-v0", agents=list(agents.keys()), grid_factory=mapper)
# test the environment with parallel_api_test
import time
start = time.time()
Expand Down

0 comments on commit acf8fd5

Please sign in to comment.