removed calls to reset from init (openai#2394)
* removed all calls to reset

* passing tests

* fix off-by-one error

* revert

* merge master into branch

* add OrderEnforcing Wrapper

* add orderenforcing to the docs

* add option for disabling

* add argument to EnvSpec
ahmedo42 authored Sep 16, 2021
1 parent e212043 commit 2754d97
Showing 9 changed files with 47 additions and 13 deletions.
11 changes: 11 additions & 0 deletions docs/wrappers.md
@@ -141,4 +141,15 @@ Lastly the `name_prefix` allows you to customize the name of the videos.
 `TimeLimit(env, max_episode_steps)` [text]
 * Needs review (including for good assertion messages and test coverage)
 
+`OrderEnforcing(env)` [text]
+
+`OrderEnforcing` is a light-weight wrapper that raises an error when `env.step()` is called before `env.reset()`. The wrapper is enabled by default for environment specs without `max_episode_steps` and can be disabled by passing `order_enforce=False` when registering:
+```python3
+register(
+    id="CustomEnv-v1",
+    entry_point="...",
+    order_enforce=False,
+)
+```
+
 Some sort of vector environment conversion wrapper needs to be added here, this will be figured out after the API is changed.
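
To make the new docs concrete: a minimal sketch of the default behavior, where the `NoLimitEnv-v0` id and `my_pkg.envs:CustomEnv` entry point are hypothetical names for illustration only:

```python
import gym
from gym.envs.registration import register

# Hypothetical registration: no max_episode_steps, and order_enforce left
# at its default of True, so spec.make() applies the OrderEnforcing wrapper.
register(
    id="NoLimitEnv-v0",
    entry_point="my_pkg.envs:CustomEnv",
)

env = gym.make("NoLimitEnv-v0")
env.step(env.action_space.sample())
# AssertionError: Cannot call env.step() before calling reset()

env.reset()                           # resetting first...
env.step(env.action_space.sample())  # ...makes stepping legal
```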
2 changes: 0 additions & 2 deletions gym/envs/box2d/bipedal_walker.py
@@ -142,8 +142,6 @@ def __init__(self):
             categoryBits=0x0001,
         )
 
-        self.reset()
-
         high = np.array([np.inf] * 24).astype(np.float32)
         self.action_space = spaces.Box(
             np.array([-1, -1, -1, -1]).astype(np.float32),
2 changes: 0 additions & 2 deletions gym/envs/box2d/lunar_lander.py
@@ -117,8 +117,6 @@ def __init__(self):
         # Nop, fire left engine, main engine, right engine
         self.action_space = spaces.Discrete(4)
 
-        self.reset()
-
     def seed(self, seed=None):
         self.np_random, seed = seeding.np_random(seed)
         return [seed]
1 change: 0 additions & 1 deletion gym/envs/classic_control/continuous_mountain_car.py
@@ -85,7 +85,6 @@ def __init__(self, goal_velocity=0):
         )
 
         self.seed()
-        self.reset()
 
     def seed(self, seed=None):
         self.np_random, seed = seeding.np_random(seed)
17 changes: 13 additions & 4 deletions gym/envs/registration.py
@@ -40,6 +40,7 @@ class EnvSpec(object):
         reward_threshold (Optional[int]): The reward threshold before the task is considered solved
         nondeterministic (bool): Whether this environment is non-deterministic even after seeding
         max_episode_steps (Optional[int]): The maximum number of steps that an episode can consist of
+        order_enforce (bool): Whether to wrap the environment in an `OrderEnforcing` wrapper
         kwargs (dict): The kwargs to pass to the environment class
     """
@@ -51,13 +51,15 @@ def __init__(
         reward_threshold=None,
         nondeterministic=False,
         max_episode_steps=None,
+        order_enforce=True,
         kwargs=None,
     ):
         self.id = id
         self.entry_point = entry_point
         self.reward_threshold = reward_threshold
         self.nondeterministic = nondeterministic
         self.max_episode_steps = max_episode_steps
+        self.order_enforce = order_enforce
         self._kwargs = {} if kwargs is None else kwargs
 
         match = env_id_re.search(id)
@@ -77,8 +80,10 @@ def make(self, **kwargs):
                     self.id
                 )
             )
+
         _kwargs = self._kwargs.copy()
         _kwargs.update(kwargs)
+
         if callable(self.entry_point):
             env = self.entry_point(**_kwargs)
         else:
@@ -89,7 +94,15 @@
         spec = copy.deepcopy(self)
         spec._kwargs = _kwargs
         env.unwrapped.spec = spec
+        if env.spec.max_episode_steps is not None:
+            from gym.wrappers.time_limit import TimeLimit
+
+            env = TimeLimit(env, max_episode_steps=env.spec.max_episode_steps)
+        else:
+            if self.order_enforce:
+                from gym.wrappers.order_enforcing import OrderEnforcing
+
+                env = OrderEnforcing(env)
         return env
 
     def __repr__(self):
@@ -115,10 +128,6 @@ def make(self, path, **kwargs):
         logger.info("Making new env: %s", path)
         spec = self.spec(path)
         env = spec.make(**kwargs)
-        if env.spec.max_episode_steps is not None:
-            from gym.wrappers.time_limit import TimeLimit
-
-            env = TimeLimit(env, max_episode_steps=env.spec.max_episode_steps)
         return env
 
     def all(self):
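
The `make()` changes above move the `TimeLimit` wrapping from the registry into `EnvSpec.make()` and add the `OrderEnforcing` branch. A sketch of the resulting behavior, reusing the hypothetical `NoLimitEnv-v0` registration from the earlier sketch:

```python
import gym
from gym.wrappers.time_limit import TimeLimit
from gym.wrappers.order_enforcing import OrderEnforcing

# Specs that declare max_episode_steps keep the old behavior: the env
# comes back wrapped in TimeLimit.
env = gym.make("CartPole-v1")
assert isinstance(env, TimeLimit)

# Specs without max_episode_steps (and order_enforce=True, the default)
# now come back wrapped in OrderEnforcing instead.
env = gym.make("NoLimitEnv-v0")
assert isinstance(env, OrderEnforcing)
```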
2 changes: 0 additions & 2 deletions gym/envs/toy_text/blackjack.py
@@ -86,8 +86,6 @@ def __init__(self, natural=False, sab=False):
 
         # Flag for full agreement with the (Sutton and Barto, 2018) definition. Overrides self.natural
         self.sab = sab
-        # Start the first game
-        self.reset()
 
     def seed(self, seed=None):
         self.np_random, seed = seeding.np_random(seed)
6 changes: 5 additions & 1 deletion gym/utils/env_checker.py
@@ -300,10 +300,14 @@ def check_env(env: gym.Env, warn: bool = True, skip_render_check: bool = True) -> None:
 
     # ============= Check the spaces (observation and action) ================
     _check_spaces(env)
 
     # Define aliases for convenience
     observation_space = env.observation_space
     action_space = env.action_space
+    try:
+        env.step(env.action_space.sample())
+
+    except AssertionError as e:
+        assert str(e) == "Cannot call env.step() before calling reset()"
 
     # Warn the user if needed.
     # A warning means that the environment may run but not work properly with popular RL libraries.
16 changes: 16 additions & 0 deletions gym/wrappers/order_enforcing.py
@@ -0,0 +1,16 @@
+import gym
+
+
+class OrderEnforcing(gym.Wrapper):
+    def __init__(self, env):
+        super(OrderEnforcing, self).__init__(env)
+        self._has_reset = False
+
+    def step(self, action):
+        assert self._has_reset, "Cannot call env.step() before calling reset()"
+        observation, reward, done, info = self.env.step(action)
+        return observation, reward, done, info
+
+    def reset(self, **kwargs):
+        self._has_reset = True
+        return self.env.reset(**kwargs)
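
A usage sketch for the wrapper on its own, outside of registration; `CartPole-v1` is used here only as a convenient built-in, with `.unwrapped` stripping the `TimeLimit` wrapper that `gym.make` applies to it:

```python
import gym
from gym.wrappers.order_enforcing import OrderEnforcing

env = OrderEnforcing(gym.make("CartPole-v1").unwrapped)

try:
    env.step(env.action_space.sample())
except AssertionError as e:
    print(e)  # Cannot call env.step() before calling reset()

env.reset()  # sets _has_reset, so stepping is now allowed
observation, reward, done, info = env.step(env.action_space.sample())
```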
3 changes: 2 additions & 1 deletion gym/wrappers/test_record_video.py
@@ -11,6 +11,7 @@
 
 
 def test_record_video_using_default_trigger():
+
     env = gym.make("CartPole-v1")
     env = gym.wrappers.RecordVideo(env, "videos")
     env.reset()
@@ -66,7 +67,7 @@ def test_record_video_within_vector():
         _, _, _, infos = envs.step(envs.action_space.sample())
         for info in infos:
             if "episode" in info.keys():
-                print(f"i, episode_reward={info['episode']['r']}")
+                print(f"episode_reward={info['episode']['r']}")
                 break
     assert os.path.isdir("videos")
     mp4_files = [file for file in os.listdir("videos") if file.endswith(".mp4")]
