From 3f590fe2181d9b21bbd9df4cab5cb02d69defdf1 Mon Sep 17 00:00:00 2001
From: Matej Straka
Date: Fri, 11 Oct 2024 14:26:01 +0200
Subject: [PATCH] feat: Add RemoveActionMaskWrapper

Add RemoveActionMaskWrapper, an observation wrapper that returns only
the "observation" entry of a dict observation, and apply it after
NormalizeObservationWrapper in gym_generals_v0.
---
 generals/envs/env.py                    |  3 +-
 generals/envs/wrappers/test_wrappers.py | 43 +++++++++++++++++++++++++
 2 files changed, 45 insertions(+), 1 deletion(-)

diff --git a/generals/envs/env.py b/generals/envs/env.py
index 6cbd92f..6c86077 100644
--- a/generals/envs/env.py
+++ b/generals/envs/env.py
@@ -1,6 +1,6 @@
 from .gymnasium_environment import GymnasiumGenerals
 from .pettingzoo_environment import PettingZooGenerals
-from .wrappers.test_wrappers import NormalizeObservationWrapper
+from .wrappers.test_wrappers import NormalizeObservationWrapper, RemoveActionMaskWrapper
 from generals.agents import Agent, AgentFactory
 from generals import GridFactory
 
@@ -50,4 +50,5 @@ def gym_generals_v0(
         agent_color=agent_color,
     )
     env = NormalizeObservationWrapper(env)
+    env = RemoveActionMaskWrapper(env)
     return env
diff --git a/generals/envs/wrappers/test_wrappers.py b/generals/envs/wrappers/test_wrappers.py
index 289299c..f4ca506 100644
--- a/generals/envs/wrappers/test_wrappers.py
+++ b/generals/envs/wrappers/test_wrappers.py
@@ -68,3 +68,46 @@ def observation(self, observation):
         )
         observation["observation"] = _observation
         return observation
+
+
+class RemoveActionMaskWrapper(gym.ObservationWrapper):
+    """Observation wrapper that keeps only the ``"observation"`` entry.
+
+    Judging by its name, this is meant to drop the action mask that
+    accompanies the game observation. ``observation_space`` is redefined
+    as a flat ``Dict`` of the per-cell and scalar channels listed below.
+    """
+
+    def __init__(self, env):
+        super(RemoveActionMaskWrapper, self).__init__(env)
+        # One MultiBinary space is reused for every grid-shaped binary channel.
+        grid_multi_binary = gym.spaces.MultiBinary(self.game.grid_dims)
+        # Scalar channels are declared as single-element boxes in [0, 1].
+        unit_box = gym.spaces.Box(low=0, high=1, shape=(1,), dtype=np.float32)
+        self.observation_space = gym.spaces.Dict(
+            {
+                "army": gym.spaces.Box(
+                    low=0, high=1, shape=self.game.grid_dims, dtype=np.float32
+                ),
+                "general": grid_multi_binary,
+                "city": grid_multi_binary,
+                "owned_cells": grid_multi_binary,
+                "opponent_cells": grid_multi_binary,
+                "neutral_cells": grid_multi_binary,
+                "visible_cells": grid_multi_binary,
+                "structure": grid_multi_binary,
+                "owned_land_count": unit_box,
+                "owned_army_count": unit_box,
+                "opponent_land_count": unit_box,
+                "opponent_army_count": unit_box,
+                "is_winner": gym.spaces.Discrete(2),
+                "timestep": unit_box,
+            }
+        )
+
+    def observation(self, observation):
+        """Return ``observation["observation"]`` if present, else the input."""
+        _observation = (
+            observation["observation"] if "observation" in observation else observation
+        )
+        return _observation