diff --git a/docs/wrappers.md b/docs/wrappers.md index 81fc3ce4835..7b88c3f5a01 100644 --- a/docs/wrappers.md +++ b/docs/wrappers.md @@ -141,4 +141,15 @@ Lastly the `name_prefix` allows you to customize the name of the videos. `TimeLimit(env, max_episode_steps)` [text] * Needs review (including for good assertion messages and test coverage) +`OrderEnforcing(env)` [text] + +`OrderEnforcing` is a light-weight wrapper that throws an exception when `env.step()` is called before `env.reset()`. The wrapper is enabled by default for environment specs without `max_episode_steps` and can be disabled by passing `order_enforce=False` like: +```python3 +register( + id="CustomEnv-v1", + entry_point="...", + order_enforce=False, +) +``` + Some sort of vector environment conversion wrapper needs to be added here, this will be figured out after the API is changed. diff --git a/gym/envs/box2d/bipedal_walker.py b/gym/envs/box2d/bipedal_walker.py index 039e59f2aff..34b564f217c 100644 --- a/gym/envs/box2d/bipedal_walker.py +++ b/gym/envs/box2d/bipedal_walker.py @@ -142,8 +142,6 @@ def __init__(self): categoryBits=0x0001, ) - self.reset() - high = np.array([np.inf] * 24).astype(np.float32) self.action_space = spaces.Box( np.array([-1, -1, -1, -1]).astype(np.float32), diff --git a/gym/envs/box2d/lunar_lander.py b/gym/envs/box2d/lunar_lander.py index efa68243e00..57dbbaec6f7 100644 --- a/gym/envs/box2d/lunar_lander.py +++ b/gym/envs/box2d/lunar_lander.py @@ -117,8 +117,6 @@ def __init__(self): # Nop, fire left engine, main engine, right engine self.action_space = spaces.Discrete(4) - self.reset() - def seed(self, seed=None): self.np_random, seed = seeding.np_random(seed) return [seed] diff --git a/gym/envs/classic_control/continuous_mountain_car.py b/gym/envs/classic_control/continuous_mountain_car.py index 286bd2210d7..65314444fc9 100644 --- a/gym/envs/classic_control/continuous_mountain_car.py +++ b/gym/envs/classic_control/continuous_mountain_car.py @@ -85,7 +85,6 @@ def 
__init__(self, goal_velocity=0): ) self.seed() - self.reset() def seed(self, seed=None): self.np_random, seed = seeding.np_random(seed) diff --git a/gym/envs/registration.py b/gym/envs/registration.py index a29bfd4f07b..c6a18cc0254 100644 --- a/gym/envs/registration.py +++ b/gym/envs/registration.py @@ -40,6 +40,7 @@ class EnvSpec(object): reward_threshold (Optional[int]): The reward threshold before the task is considered solved nondeterministic (bool): Whether this environment is non-deterministic even after seeding max_episode_steps (Optional[int]): The maximum number of steps that an episode can consist of + order_enforce (bool): Whether to wrap the environment in an OrderEnforcing wrapper kwargs (dict): The kwargs to pass to the environment class """ @@ -51,6 +52,7 @@ def __init__( self, entry_point=None, reward_threshold=None, nondeterministic=False, max_episode_steps=None, + order_enforce=True, kwargs=None, ): self.id = id @@ -58,6 +60,7 @@ def __init__( self.reward_threshold = reward_threshold self.nondeterministic = nondeterministic self.max_episode_steps = max_episode_steps + self.order_enforce = order_enforce self._kwargs = {} if kwargs is None else kwargs match = env_id_re.search(id) @@ -77,8 +80,10 @@ def make(self, **kwargs): self.id ) ) + _kwargs = self._kwargs.copy() _kwargs.update(kwargs) + if callable(self.entry_point): env = self.entry_point(**_kwargs) else: @@ -89,7 +94,15 @@ def make(self, **kwargs): spec = copy.deepcopy(self) spec._kwargs = _kwargs env.unwrapped.spec = spec + if env.spec.max_episode_steps is not None: + from gym.wrappers.time_limit import TimeLimit + env = TimeLimit(env, max_episode_steps=env.spec.max_episode_steps) + else: + if self.order_enforce: + from gym.wrappers.order_enforcing import OrderEnforcing + + env = OrderEnforcing(env) return env def __repr__(self): @@ -115,10 +128,6 @@ def make(self, path, **kwargs): logger.info("Making new env: %s", path) spec = self.spec(path) env = spec.make(**kwargs) - if env.spec.max_episode_steps is 
not None: - from gym.wrappers.time_limit import TimeLimit - - env = TimeLimit(env, max_episode_steps=env.spec.max_episode_steps) return env def all(self): diff --git a/gym/envs/toy_text/blackjack.py b/gym/envs/toy_text/blackjack.py index 103f0a43076..fdbf149f42f 100644 --- a/gym/envs/toy_text/blackjack.py +++ b/gym/envs/toy_text/blackjack.py @@ -86,8 +86,6 @@ def __init__(self, natural=False, sab=False): # Flag for full agreement with the (Sutton and Barto, 2018) definition. Overrides self.natural self.sab = sab - # Start the first game - self.reset() def seed(self, seed=None): self.np_random, seed = seeding.np_random(seed) diff --git a/gym/utils/env_checker.py b/gym/utils/env_checker.py index 87e959f1ca2..66159861232 100644 --- a/gym/utils/env_checker.py +++ b/gym/utils/env_checker.py @@ -300,10 +300,14 @@ def check_env(env: gym.Env, warn: bool = True, skip_render_check: bool = True) - # ============= Check the spaces (observation and action) ================ _check_spaces(env) - # Define aliases for convenience observation_space = env.observation_space action_space = env.action_space + try: + env.step(env.action_space.sample()) + + except AssertionError as e: + assert str(e) == "Cannot call env.step() before calling reset()" # Warn the user if needed. # A warning means that the environment may run but not work properly with popular RL libraries. 
diff --git a/gym/wrappers/order_enforcing.py b/gym/wrappers/order_enforcing.py new file mode 100644 index 00000000000..f3bbd838dd1 --- /dev/null +++ b/gym/wrappers/order_enforcing.py @@ -0,0 +1,16 @@ +import gym + + +class OrderEnforcing(gym.Wrapper): + def __init__(self, env): + super(OrderEnforcing, self).__init__(env) + self._has_reset = False + + def step(self, action): + assert self._has_reset, "Cannot call env.step() before calling reset()" + observation, reward, done, info = self.env.step(action) + return observation, reward, done, info + + def reset(self, **kwargs): + self._has_reset = True + return self.env.reset(**kwargs) diff --git a/gym/wrappers/test_record_video.py b/gym/wrappers/test_record_video.py index 0ff107f1a69..509fe908709 100644 --- a/gym/wrappers/test_record_video.py +++ b/gym/wrappers/test_record_video.py @@ -11,6 +11,7 @@ def test_record_video_using_default_trigger(): + env = gym.make("CartPole-v1") env = gym.wrappers.RecordVideo(env, "videos") env.reset() @@ -66,7 +67,7 @@ def test_record_video_within_vector(): _, _, _, infos = envs.step(envs.action_space.sample()) for info in infos: if "episode" in info.keys(): - print(f"i, episode_reward={info['episode']['r']}") + print(f"episode_reward={info['episode']['r']}") break assert os.path.isdir("videos") mp4_files = [file for file in os.listdir("videos") if file.endswith(".mp4")]