Skip to content

Commit

Permalink
fix: Update RemoveActionMaskWrapper
Browse files Browse the repository at this point in the history
  • Loading branch information
strakam committed Nov 22, 2024
1 parent 0829aa8 commit 8b9676e
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 14 deletions.
5 changes: 5 additions & 0 deletions generals/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,5 +30,10 @@ def _register_gym_generals_envs():
entry_point="generals.envs.initializers:gyms_generals_normalized_v0",
)

register(
id="gym-generals-image-v0",
entry_point="generals.envs.initializers:gym_image_observations",
)


_register_gym_generals_envs()
27 changes: 14 additions & 13 deletions generals/envs/gymnasium_wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,24 +57,25 @@ def observation(self, observation):
class RemoveActionMaskWrapper(gym.ObservationWrapper):
def __init__(self, env):
super().__init__(env)
grid_multi_binary = gym.spaces.MultiBinary(self.game.grid_dims)
unit_box = gym.spaces.Box(low=0, high=1, dtype=np.float32)
grid_multi_binary = gym.spaces.MultiBinary(self.grid_dims)
grid_discrete = np.ones(self.grid_dims, dtype=int) * self.max_army_value
self.observation_space = gym.spaces.Dict(
{
"army": gym.spaces.Box(low=0, high=1, shape=self.game.grid_dims, dtype=np.float32),
"general": grid_multi_binary,
"city": grid_multi_binary,
"armies": gym.spaces.MultiDiscrete(grid_discrete),
"generals": grid_multi_binary,
"cities": grid_multi_binary,
"mountains": grid_multi_binary,
"neutral_cells": grid_multi_binary,
"owned_cells": grid_multi_binary,
"opponent_cells": grid_multi_binary,
"neutral_cells": grid_multi_binary,
"visible_cells": grid_multi_binary,
"fog_cells": grid_multi_binary,
"structures_in_fog": grid_multi_binary,
"owned_land_count": unit_box,
"owned_army_count": unit_box,
"opponent_land_count": unit_box,
"opponent_army_count": unit_box,
"is_winner": gym.spaces.Discrete(2),
"timestep": unit_box,
"owned_land_count": gym.spaces.Discrete(self.max_army_value),
"owned_army_count": gym.spaces.Discrete(self.max_army_value),
"opponent_land_count": gym.spaces.Discrete(self.max_army_value),
"opponent_army_count": gym.spaces.Discrete(self.max_army_value),
"timestep": gym.spaces.Discrete(self.max_timestep),
"priority": gym.spaces.Discrete(2),
}
)

Expand Down
23 changes: 22 additions & 1 deletion generals/envs/initializers.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from generals import GridFactory
from generals.agents import Agent
from generals.envs.gymnasium_generals import GymnasiumGenerals, RewardFn
from generals.envs.gymnasium_wrappers import NormalizedObservationWrapper
from generals.envs.gymnasium_wrappers import NormalizedObservationWrapper, ObservationAsImageWrapper

"""
Here we can define environment initialization functions that
Expand Down Expand Up @@ -35,3 +35,24 @@ def gym_generals_normalized_v0(
)
env = NormalizedObservationWrapper(_env)
return env

def gym_image_observations(
grid_factory: GridFactory | None = None,
npc: Agent | None = None,
agent: Agent | None = None,
render_mode: str | None = None,
reward_fn: RewardFn | None = None,
):
"""
Example of a Gymnasium environment initializer that creates
an environment that returns image observations.
"""
_env = GymnasiumGenerals(
grid_factory=grid_factory,
npc=npc,
agent=agent,
render_mode=render_mode,
reward_fn=reward_fn,
)
env = ObservationAsImageWrapper(_env)
return env

0 comments on commit 8b9676e

Please sign in to comment.