Skip to content

Commit

Permalink
feat: Add RemoveActionMaskWrapper
Browse files Browse the repository at this point in the history
  • Loading branch information
strakam committed Oct 11, 2024
1 parent 74e447a commit 3f590fe
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 1 deletion.
4 changes: 3 additions & 1 deletion generals/envs/env.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from .gymnasium_environment import GymnasiumGenerals
from .pettingzoo_environment import PettingZooGenerals
from .wrappers.test_wrappers import NormalizeObservationWrapper
from .wrappers.test_wrappers import NormalizeObservationWrapper, RemoveActionMaskWrapper
from generals.agents import Agent, AgentFactory

from generals import GridFactory
Expand Down Expand Up @@ -50,4 +50,6 @@ def gym_generals_v0(
agent_color=agent_color,
)
env = NormalizeObservationWrapper(env)
env = RemoveActionMaskWrapper(env)
print(env.observation_space)
return env
33 changes: 33 additions & 0 deletions generals/envs/wrappers/test_wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,3 +68,36 @@ def observation(self, observation):
)
observation["observation"] = _observation
return observation


class RemoveActionMaskWrapper(gym.ObservationWrapper):
def __init__(self, env):
super(RemoveActionMaskWrapper, self).__init__(env)
grid_multi_binary = gym.spaces.MultiBinary(self.game.grid_dims)
unit_box = gym.spaces.Box(low=0, high=1, shape=(1,), dtype=np.float32)
self.observation_space = gym.spaces.Dict(
{
"army": gym.spaces.Box(
low=0, high=1, shape=self.game.grid_dims, dtype=np.float32
),
"general": grid_multi_binary,
"city": grid_multi_binary,
"owned_cells": grid_multi_binary,
"opponent_cells": grid_multi_binary,
"neutral_cells": grid_multi_binary,
"visible_cells": grid_multi_binary,
"structure": grid_multi_binary,
"owned_land_count": unit_box,
"owned_army_count": unit_box,
"opponent_land_count": unit_box,
"opponent_army_count": unit_box,
"is_winner": gym.spaces.Discrete(2),
"timestep": unit_box,
}
)

def observation(self, observation):
_observation = (
observation["observation"] if "observation" in observation else observation
)
return _observation

0 comments on commit 3f590fe

Please sign in to comment.