Skip to content

Commit

Permalink
Forward environment properties to the wrapper (openai#2373)
Browse files Browse the repository at this point in the history
* Forward environment properties to the wrapper, fixes openai#2175

* Add tests for property forwarding in Wrapper

* Rename klass to class_ in test_core
  • Loading branch information
tristandeleu authored Sep 17, 2021
1 parent 5a4709b commit d35d211
Show file tree
Hide file tree
Showing 2 changed files with 139 additions and 5 deletions.
49 changes: 45 additions & 4 deletions gym/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,10 +224,11 @@ class Wrapper(Env):

def __init__(self, env):
    """Wrap ``env`` without snapshotting any of its properties.

    The four private override slots start out as ``None``. While a slot is
    ``None`` the matching public property (``action_space``,
    ``observation_space``, ``reward_range``, ``metadata``) reads the value
    from ``self.env`` on every access; assigning to the public property
    stores an override in the slot instead.

    Args:
        env: The environment to wrap; stored as ``self.env``.
    """
    self.env = env

    # Deliberately no eager ``self.action_space = self.env.action_space``
    # style copies here: they would be overwritten by the resets below and
    # would read the wrapped env's spaces before it may have defined them.
    self._action_space = None
    self._observation_space = None
    self._reward_range = None
    self._metadata = None

def __getattr__(self, name):
if name.startswith("_"):
Expand All @@ -244,6 +245,46 @@ def spec(self):
def class_name(cls):
return cls.__name__

@property
def action_space(self):
    """Return the wrapper's own action space, or the wrapped env's if unset."""
    override = self._action_space
    if override is not None:
        return override
    # No override stored: forward to the wrapped environment on each access.
    return self.env.action_space

@action_space.setter
def action_space(self, space):
    """Store ``space`` as this wrapper's action-space override."""
    self._action_space = space

@property
def observation_space(self):
    """Return the wrapper's own observation space, or the wrapped env's if unset."""
    override = self._observation_space
    if override is not None:
        return override
    # No override stored: forward to the wrapped environment on each access.
    return self.env.observation_space

@observation_space.setter
def observation_space(self, space):
    """Store ``space`` as this wrapper's observation-space override."""
    self._observation_space = space

@property
def reward_range(self):
    """Return the wrapper's own reward range, or the wrapped env's if unset."""
    override = self._reward_range
    if override is not None:
        return override
    # No override stored: forward to the wrapped environment on each access.
    return self.env.reward_range

@reward_range.setter
def reward_range(self, value):
    """Store ``value`` as this wrapper's reward-range override."""
    self._reward_range = value

@property
def metadata(self):
    """Return the wrapper's own metadata, or the wrapped env's if unset."""
    override = self._metadata
    if override is not None:
        return override
    # No override stored: forward to the wrapped environment on each access.
    return self.env.metadata

@metadata.setter
def metadata(self, value):
    """Store ``value`` as this wrapper's metadata override."""
    self._metadata = value

def step(self, action):
    """Pass ``action`` to the wrapped env's ``step`` and return its result unchanged."""
    wrapped = self.env
    return wrapped.step(action)

Expand Down
95 changes: 94 additions & 1 deletion gym/tests/test_core.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
from gym import core
import pytest
import numpy as np

from gym import core, spaces


class ArgumentEnv(core.Env):
Expand All @@ -9,9 +12,99 @@ def __init__(self, arg):
self.arg = arg


class UnittestEnv(core.Env):
    """Minimal environment whose spaces are declared as class attributes."""

    observation_space = spaces.Box(low=0, high=255, shape=(64, 64, 3), dtype=np.uint8)
    action_space = spaces.Discrete(3)

    def reset(self):
        # Dummy observation drawn from the declared observation space.
        dummy_obs = self.observation_space.sample()
        return dummy_obs

    def step(self, action):
        # The action is ignored: returns a dummy observation, zero reward,
        # done=False and an empty info dict.
        dummy_obs = self.observation_space.sample()
        return dummy_obs, 0.0, False, {}


class UnknownSpacesEnv(core.Env):
    """Environment that only assigns its spaces inside ``reset``.

    The observation and action spaces are set on the instance during the
    first call to ``reset`` rather than at class or construction time. This
    pattern is sometimes necessary when implementing a new environment
    (e.g. if it depends on external resources), but it is not encouraged.
    """

    def reset(self):
        obs_space = spaces.Box(low=0, high=255, shape=(64, 64, 3), dtype=np.uint8)
        self.observation_space = obs_space
        self.action_space = spaces.Discrete(3)
        # Dummy observation drawn from the space that was just assigned.
        return self.observation_space.sample()

    def step(self, action):
        # The action is ignored: returns a dummy observation, zero reward,
        # done=False and an empty info dict.
        dummy_obs = self.observation_space.sample()
        return dummy_obs, 0.0, False, {}


class NewPropertyWrapper(core.Wrapper):
    """Wrapper that overrides only the properties passed explicitly to it."""

    def __init__(
        self,
        env,
        observation_space=None,
        action_space=None,
        reward_range=None,
        metadata=None,
    ):
        super().__init__(env)
        requested = (
            ("observation_space", observation_space),
            ("action_space", action_space),
            ("reward_range", reward_range),
            ("metadata", metadata),
        )
        for attr_name, attr_value in requested:
            # Only assign values that were actually supplied, so that the
            # remaining properties are left untouched to test forwarding.
            if attr_value is not None:
                setattr(self, attr_name, attr_value)


def test_env_instantiation():
    # This looks pretty trivial, but given our usage of __new__,
    # it's worth having.
    instance = ArgumentEnv("arg")
    assert instance.arg == "arg"
    assert instance.calls == 1


def _unit_float_image_box():
    """Build a fresh 64x64x3 float32 Box with bounds [0.0, 1.0]."""
    return spaces.Box(low=0.0, high=1.0, shape=(64, 64, 3), dtype=np.float32)


# Each entry is one set of keyword overrides handed to NewPropertyWrapper by
# the parametrized forwarding test; the last entry overrides two at once.
properties = [
    {"observation_space": _unit_float_image_box()},
    {"action_space": spaces.Discrete(2)},
    {"reward_range": (-1.0, 1.0)},
    {"metadata": {"render.modes": ["human", "rgb_array"]}},
    {
        "observation_space": _unit_float_image_box(),
        "action_space": spaces.Discrete(2),
    },
]


@pytest.mark.parametrize("class_", [UnittestEnv, UnknownSpacesEnv])
@pytest.mark.parametrize("props", properties)
def test_wrapper_property_forwarding(class_, props):
    """Check that overridden properties win and all others are forwarded."""
    wrapped = NewPropertyWrapper(class_(), **props)

    # UnknownSpacesEnv only defines its spaces once reset has been called.
    if isinstance(wrapped.unwrapped, UnknownSpacesEnv):
        wrapped.reset()

    # Properties supplied to the wrapper must be the ones it reports.
    for name, expected in props.items():
        assert getattr(wrapped, name) == expected

    # Every property that was not supplied must match the wrapped env's.
    forwarded = {"observation_space", "action_space", "reward_range", "metadata"}
    forwarded -= props.keys()
    for name in forwarded:
        assert getattr(wrapped, name) == getattr(wrapped.unwrapped, name)

0 comments on commit d35d211

Please sign in to comment.