Add Normalize env #2387

Merged · 15 commits · Sep 9, 2021
11 changes: 9 additions & 2 deletions docs/wrappers.md
@@ -32,6 +32,14 @@ Gym includes numerous wrappers for environments that include preprocessing and v
`TimeAwareObservation(env)` [text]
* Needs review (including for good assertion messages and test coverage)


`NormalizeObservation(env, epsilon=1e-8)` [text]
* This wrapper normalizes the observations to have approximately zero mean and unit variance


`NormalizeReward(env, gamma=0.99, epsilon=1e-8)` [text]
* This wrapper scales the rewards, which are divided through by the standard deviation of a rolling discounted returns. See page 3 of from [Engstrom, Ilyas et al. (2020)](https://arxiv.org/pdf/2005.12729.pdf)

## Action Wrappers

`ClipAction(env)` [text]
@@ -97,8 +105,7 @@ the `RecordVideo` uses episode counts to trigger video recording based on the `episode_trigger`,
which is a cubic progression for early episodes (1, 8, 27, ...) and then every 1000 episodes (1000, 2000, 3000, ...).
This can be changed by modifying the `episode_trigger` argument of `RecordVideo`.
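For illustration, a trigger implementing that schedule might look roughly like the sketch below (an approximation of the behaviour described above, not the exact `capped_cubic_video_schedule` shipped with Gym):

```python
def cubic_then_every_1000(episode_id: int) -> bool:
    # record perfect-cube episodes (1, 8, 27, ...) early on,
    # then every 1000th episode (1000, 2000, 3000, ...)
    if episode_id < 1000:
        return round(episode_id ** (1.0 / 3)) ** 3 == episode_id
    return episode_id % 1000 == 0
```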

Alternatively, you may also trigger the video recording based on the environment steps via the `step_trigger` like

```python
import gym
```
1 change: 1 addition & 0 deletions gym/wrappers/__init__.py
@@ -14,4 +14,5 @@
from gym.wrappers.resize_observation import ResizeObservation
from gym.wrappers.clip_action import ClipAction
from gym.wrappers.record_episode_statistics import RecordEpisodeStatistics
from gym.wrappers.normalize import NormalizeObservation, NormalizeReward
from gym.wrappers.record_video import RecordVideo, capped_cubic_video_schedule
105 changes: 105 additions & 0 deletions gym/wrappers/normalize.py
@@ -0,0 +1,105 @@
import numpy as np
import gym


# taken from https://github.com/openai/baselines/blob/master/baselines/common/vec_env/vec_normalize.py
class RunningMeanStd(object):
    # https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm
    def __init__(self, epsilon=1e-4, shape=()):
        self.mean = np.zeros(shape, "float64")
        self.var = np.ones(shape, "float64")
        self.count = epsilon

    def update(self, x):
        batch_mean = np.mean(x, axis=0)
        batch_var = np.var(x, axis=0)
        batch_count = x.shape[0]
        self.update_from_moments(batch_mean, batch_var, batch_count)

    def update_from_moments(self, batch_mean, batch_var, batch_count):
        self.mean, self.var, self.count = update_mean_var_count_from_moments(
            self.mean, self.var, self.count, batch_mean, batch_var, batch_count
        )


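# Explanatory note: with n_a = count, n_b = batch_count, delta = batch_mean - mean
# and n = n_a + n_b, the function below pools the moments as
#   new_mean = mean + delta * n_b / n
#   new_var  = (n_a * var + n_b * batch_var + delta**2 * n_a * n_b / n) / n
# i.e. the parallel variance algorithm (Chan et al.) linked above.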
def update_mean_var_count_from_moments(
    mean, var, count, batch_mean, batch_var, batch_count
):
    delta = batch_mean - mean
    tot_count = count + batch_count

    new_mean = mean + delta * batch_count / tot_count
    m_a = var * count
    m_b = batch_var * batch_count
    M2 = m_a + m_b + np.square(delta) * count * batch_count / tot_count
    new_var = M2 / tot_count
    new_count = tot_count

    return new_mean, new_var, new_count


class NormalizeObservation(gym.core.Wrapper):
    def __init__(
        self,
        env,
        epsilon=1e-8,
    ):
        super(NormalizeObservation, self).__init__(env)
        self.num_envs = getattr(env, "num_envs", 1)
        self.is_vector_env = getattr(env, "is_vector_env", False)
        if self.is_vector_env:
            self.obs_rms = RunningMeanStd(shape=self.single_observation_space.shape)
        else:
            self.obs_rms = RunningMeanStd(shape=self.observation_space.shape)
        self.epsilon = epsilon

    def step(self, action):
        obs, rews, dones, infos = self.env.step(action)
        if self.is_vector_env:
            obs = self.normalize(obs)
        else:
            obs = self.normalize(np.array([obs]))[0]
        return obs, rews, dones, infos

    def reset(self):
        obs = self.env.reset()
        if self.is_vector_env:
            obs = self.normalize(obs)
        else:
            obs = self.normalize(np.array([obs]))[0]
        return obs

    def normalize(self, obs):
        self.obs_rms.update(obs)
        return (obs - self.obs_rms.mean) / np.sqrt(self.obs_rms.var + self.epsilon)


class NormalizeReward(gym.core.Wrapper):
    def __init__(
        self,
        env,
        gamma=0.99,
        epsilon=1e-8,
    ):
        super(NormalizeReward, self).__init__(env)
        self.num_envs = getattr(env, "num_envs", 1)
        self.is_vector_env = getattr(env, "is_vector_env", False)
        self.return_rms = RunningMeanStd(shape=())
        self.returns = np.zeros(self.num_envs)
        self.gamma = gamma
        self.epsilon = epsilon

    def step(self, action):
        obs, rews, dones, infos = self.env.step(action)
        if not self.is_vector_env:
            rews = np.array([rews])
        self.returns = self.returns * self.gamma + rews
        rews = self.normalize(rews)
        self.returns[dones] = 0.0
        if not self.is_vector_env:
            rews = rews[0]
        return obs, rews, dones, infos

    def normalize(self, rews):
        self.return_rms.update(self.returns)
        return rews / np.sqrt(self.return_rms.var + self.epsilon)
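To make the reward scaling concrete, here is a small sketch (hypothetical reward values, reusing the `RunningMeanStd` defined above; not part of the diff) of what `NormalizeReward` does on each step for a single environment:

```python
import numpy as np

rms, returns, gamma, epsilon = RunningMeanStd(shape=()), np.zeros(1), 0.99, 1e-8
for r in [1.0, 0.5, 2.0]:                      # hypothetical rewards
    returns = returns * gamma + r              # rolling discounted return
    rms.update(returns)                        # update running stats of the returns
    scaled_r = r / np.sqrt(rms.var + epsilon)  # reward passed on to the agent
```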
108 changes: 108 additions & 0 deletions gym/wrappers/test_normalize.py
@@ -0,0 +1,108 @@
import gym
import numpy as np
from numpy.testing import assert_almost_equal

from gym.wrappers.normalize import NormalizeObservation, NormalizeReward


class DummyRewardEnv(gym.Env):
    metadata = {}

    def __init__(self, return_reward_idx=0):
        self.action_space = gym.spaces.Discrete(2)
        self.observation_space = gym.spaces.Box(
            low=np.array([-1.0]), high=np.array([1.0])
        )
        self.returned_rewards = [0, 1, 2, 3, 4]
        self.return_reward_idx = return_reward_idx
        self.t = self.return_reward_idx

    def step(self, action):
        self.t += 1
        return np.array([self.t]), self.t, self.t == len(self.returned_rewards), {}

    def reset(self):
        self.t = self.return_reward_idx
        return np.array([self.t])


def make_env(return_reward_idx):
    def thunk():
        env = DummyRewardEnv(return_reward_idx)
        return env

    return thunk


def test_normalize_observation():
    env = DummyRewardEnv(return_reward_idx=0)
    env = NormalizeObservation(env)
    env.reset()
    env.step(env.action_space.sample())
    assert_almost_equal(env.obs_rms.mean, 0.5, decimal=4)
    env.step(env.action_space.sample())
    assert_almost_equal(env.obs_rms.mean, 1.0, decimal=4)
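# Arithmetic behind the two assertions above: DummyRewardEnv emits observations
# 0 (reset), 1, 2, ..., so after the first step the running mean is roughly
# (0 + 1) / 2 = 0.5 and after the second step roughly (0 + 1 + 2) / 3 = 1.0
# (only approximate, since RunningMeanStd starts with a pseudo-count of 1e-4).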


def test_normalize_return():
    env = DummyRewardEnv(return_reward_idx=0)
    env = NormalizeReward(env)
    env.reset()
    env.step(env.action_space.sample())
    assert_almost_equal(
        env.return_rms.mean,
        np.mean([1]),  # [first return]
        decimal=4,
    )
    env.step(env.action_space.sample())
    assert_almost_equal(
        env.return_rms.mean,
        np.mean([2 + env.gamma * 1, 1]),  # [second return, first return]
        decimal=4,
    )


def test_normalize_observation_vector_env():
    env_fns = [make_env(0), make_env(1)]
    envs = gym.vector.SyncVectorEnv(env_fns)
    envs.reset()
    obs, reward, _, _ = envs.step(envs.action_space.sample())
    np.testing.assert_almost_equal(obs, np.array([[1], [2]]), decimal=4)
    np.testing.assert_almost_equal(reward, np.array([1, 2]), decimal=4)

    env_fns = [make_env(0), make_env(1)]
    envs = gym.vector.SyncVectorEnv(env_fns)
    envs = NormalizeObservation(envs)
    envs.reset()
    assert_almost_equal(
        envs.obs_rms.mean,
        np.mean([0.5]),  # the mean of first observations [[0, 1]]
        decimal=4,
    )
    obs, reward, _, _ = envs.step(envs.action_space.sample())
    assert_almost_equal(
        envs.obs_rms.mean,
        np.mean([1.0]),  # the mean of first and second observations [[0, 1], [1, 2]]
        decimal=4,
    )


def test_normalize_return_vector_env():
    env_fns = [make_env(0), make_env(1)]
    envs = gym.vector.SyncVectorEnv(env_fns)
    envs = NormalizeReward(envs)
    obs = envs.reset()
    obs, reward, _, _ = envs.step(envs.action_space.sample())
    assert_almost_equal(
        envs.return_rms.mean,
        np.mean([1.5]),  # the mean of first returns [[1, 2]]
        decimal=4,
    )
    obs, reward, _, _ = envs.step(envs.action_space.sample())
    assert_almost_equal(
        envs.return_rms.mean,
        np.mean(
            [[1, 2], [2 + envs.gamma * 1, 3 + envs.gamma * 2]]
        ),  # the mean of first and second returns
        decimal=4,
    )
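For completeness, a minimal sketch (not part of the diff) of how the new wrappers compose with a vectorized environment, matching the code paths exercised by the vector tests above:

```python
import gym
from gym.wrappers import NormalizeObservation, NormalizeReward

envs = gym.vector.SyncVectorEnv([lambda: gym.make("CartPole-v1") for _ in range(4)])
envs = NormalizeObservation(envs)  # uses single_observation_space for the statistics shape
envs = NormalizeReward(envs)       # keeps one rolling return per sub-environment

obs = envs.reset()
obs, rewards, dones, infos = envs.step(envs.action_space.sample())
```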