forked from openai/gym
Commit message (squashed):

* initial commit
* undo black
* add code
* add test cases and refactor
* add docs
* black
* documentation update
* break feature apart
* quick fix
* quick fix
* quick fix
* update documentation
* update documentation
* Update wrapper naming
* fix ci
Showing 4 changed files with 223 additions and 2 deletions.
New file gym/wrappers/normalize.py (@@ -0,0 +1,105 @@):
import numpy as np
import gym


# taken from https://github.com/openai/baselines/blob/master/baselines/common/vec_env/vec_normalize.py
class RunningMeanStd(object):
    # https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm
    def __init__(self, epsilon=1e-4, shape=()):
        self.mean = np.zeros(shape, "float64")
        self.var = np.ones(shape, "float64")
        self.count = epsilon

    def update(self, x):
        batch_mean = np.mean(x, axis=0)
        batch_var = np.var(x, axis=0)
        batch_count = x.shape[0]
        self.update_from_moments(batch_mean, batch_var, batch_count)

    def update_from_moments(self, batch_mean, batch_var, batch_count):
        self.mean, self.var, self.count = update_mean_var_count_from_moments(
            self.mean, self.var, self.count, batch_mean, batch_var, batch_count
        )


def update_mean_var_count_from_moments(
    mean, var, count, batch_mean, batch_var, batch_count
):
    delta = batch_mean - mean
    tot_count = count + batch_count

    new_mean = mean + delta * batch_count / tot_count
    m_a = var * count
    m_b = batch_var * batch_count
    M2 = m_a + m_b + np.square(delta) * count * batch_count / tot_count
    new_var = M2 / tot_count
    new_count = tot_count

    return new_mean, new_var, new_count
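update_mean_var_count_from_moments is the parallel update from the Wikipedia article linked above (Chan et al.): it merges the moments of two disjoint batches without revisiting the raw data. As a hypothetical sanity check (not part of the commit), the streamed statistics should agree with NumPy's direct computation over the concatenated data once the epsilon pseudo-count is zeroed out:

import numpy as np
from gym.wrappers.normalize import RunningMeanStd

a = np.random.randn(100, 3)
b = np.random.randn(50, 3)

rms = RunningMeanStd(epsilon=0.0, shape=(3,))  # epsilon=0 makes the merge exact
rms.update(a)
rms.update(b)

both = np.concatenate([a, b])
assert np.allclose(rms.mean, both.mean(axis=0))
assert np.allclose(rms.var, both.var(axis=0))
assert rms.count == 150

The default epsilon=1e-4 is a tiny pseudo-count that keeps count positive before the first update; the bias it introduces is negligible.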
class NormalizeObservation(gym.core.Wrapper):
    def __init__(
        self,
        env,
        epsilon=1e-8,
    ):
        super(NormalizeObservation, self).__init__(env)
        self.num_envs = getattr(env, "num_envs", 1)
        self.is_vector_env = getattr(env, "is_vector_env", False)
        if self.is_vector_env:
            self.obs_rms = RunningMeanStd(shape=self.single_observation_space.shape)
        else:
            self.obs_rms = RunningMeanStd(shape=self.observation_space.shape)
        self.epsilon = epsilon

    def step(self, action):
        obs, rews, dones, infos = self.env.step(action)
        if self.is_vector_env:
            obs = self.normalize(obs)
        else:
            obs = self.normalize(np.array([obs]))[0]
        return obs, rews, dones, infos

    def reset(self):
        obs = self.env.reset()
        if self.is_vector_env:
            obs = self.normalize(obs)
        else:
            obs = self.normalize(np.array([obs]))[0]
        return obs

    def normalize(self, obs):
        self.obs_rms.update(obs)
        return (obs - self.obs_rms.mean) / np.sqrt(self.obs_rms.var + self.epsilon)
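Note how the single-env branch reuses the batched code path by wrapping the observation in a length-1 batch and unwrapping afterward, and that the running statistics update on every reset and step. Over time the emitted observations approach zero mean and unit variance; a hypothetical standalone check (not part of the commit):

import numpy as np
from gym.wrappers.normalize import RunningMeanStd

rms = RunningMeanStd(shape=(3,))
stream = np.random.randn(10000, 3) * 5.0 + 2.0  # raw stream: mean ~2, std ~5
for batch in np.split(stream, 100):             # feed it in 100 batches
    rms.update(batch)
normalized = (stream - rms.mean) / np.sqrt(rms.var + 1e-8)
print(normalized.mean(axis=0), normalized.std(axis=0))  # ~[0 0 0] and ~[1 1 1]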
class NormalizeReward(gym.core.Wrapper):
    def __init__(
        self,
        env,
        gamma=0.99,
        epsilon=1e-8,
    ):
        super(NormalizeReward, self).__init__(env)
        self.num_envs = getattr(env, "num_envs", 1)
        self.is_vector_env = getattr(env, "is_vector_env", False)
        self.return_rms = RunningMeanStd(shape=())
        self.returns = np.zeros(self.num_envs)
        self.gamma = gamma
        self.epsilon = epsilon

    def step(self, action):
        obs, rews, dones, infos = self.env.step(action)
        if not self.is_vector_env:
            rews = np.array([rews])
        self.returns = self.returns * self.gamma + rews
        rews = self.normalize(rews)
        self.returns[dones] = 0.0
        if not self.is_vector_env:
            rews = rews[0]
        return obs, rews, dones, infos

    def normalize(self, rews):
        self.return_rms.update(self.returns)
        return rews / np.sqrt(self.return_rms.var + self.epsilon)
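A minimal usage sketch (not part of the commit; "CartPole-v1" is just a concrete stand-in, and the 4-tuple step / bare reset reflect the gym API this code targets). The wrappers compose like any other gym wrapper:

import gym
from gym.wrappers.normalize import NormalizeObservation, NormalizeReward

env = gym.make("CartPole-v1")
env = NormalizeObservation(env)         # observations -> ~zero mean, unit variance
env = NormalizeReward(env, gamma=0.99)  # rewards scaled by std of discounted return

obs = env.reset()
for _ in range(10):
    obs, rew, done, info = env.step(env.action_space.sample())
    if done:
        obs = env.reset()

Note that NormalizeReward does not center rewards: it only divides by the standard deviation of the running discounted return, so reward signs are preserved.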
New test file (@@ -0,0 +1,108 @@):
import gym
import numpy as np
from numpy.testing import assert_almost_equal

from gym.wrappers.normalize import NormalizeObservation, NormalizeReward


class DummyRewardEnv(gym.Env):
    metadata = {}

    def __init__(self, return_reward_idx=0):
        self.action_space = gym.spaces.Discrete(2)
        self.observation_space = gym.spaces.Box(
            low=np.array([-1.0]), high=np.array([1.0])
        )
        self.returned_rewards = [0, 1, 2, 3, 4]
        self.return_reward_idx = return_reward_idx
        self.t = self.return_reward_idx

    def step(self, action):
        self.t += 1
        return np.array([self.t]), self.t, self.t == len(self.returned_rewards), {}

    def reset(self):
        self.t = self.return_reward_idx
        return np.array([self.t])


def make_env(return_reward_idx):
    def thunk():
        env = DummyRewardEnv(return_reward_idx)
        return env

    return thunk


def test_normalize_observation():
    env = DummyRewardEnv(return_reward_idx=0)
    env = NormalizeObservation(env)
    env.reset()
    env.step(env.action_space.sample())
    assert_almost_equal(env.obs_rms.mean, 0.5, decimal=4)
    env.step(env.action_space.sample())
    assert_almost_equal(env.obs_rms.mean, 1.0, decimal=4)


def test_normalize_return():
    env = DummyRewardEnv(return_reward_idx=0)
    env = NormalizeReward(env)
    env.reset()
    env.step(env.action_space.sample())
    assert_almost_equal(
        env.return_rms.mean,
        np.mean([1]),  # [first return]
        decimal=4,
    )
    env.step(env.action_space.sample())
    assert_almost_equal(
        env.return_rms.mean,
        np.mean([2 + env.gamma * 1, 1]),  # [second return, first return]
        decimal=4,
    )
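To unpack the expected values in test_normalize_return: return_rms is updated with the running discounted return (returns = returns * gamma + reward), not with the raw reward. DummyRewardEnv pays rewards 1 and then 2 here, so with the default gamma = 0.99 the two samples the running mean sees are:

gamma = 0.99
ret_1 = 0.0 * gamma + 1.0    # after the first step
ret_2 = ret_1 * gamma + 2.0  # after the second step -> 2.99
print((ret_1 + ret_2) / 2)   # 1.995, i.e. np.mean([2 + gamma * 1, 1])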
def test_normalize_observation_vector_env():
    env_fns = [make_env(0), make_env(1)]
    envs = gym.vector.SyncVectorEnv(env_fns)
    envs.reset()
    obs, reward, _, _ = envs.step(envs.action_space.sample())
    np.testing.assert_almost_equal(obs, np.array([[1], [2]]), decimal=4)
    np.testing.assert_almost_equal(reward, np.array([1, 2]), decimal=4)

    env_fns = [make_env(0), make_env(1)]
    envs = gym.vector.SyncVectorEnv(env_fns)
    envs = NormalizeObservation(envs)
    envs.reset()
    assert_almost_equal(
        envs.obs_rms.mean,
        np.mean([0.5]),  # the mean of first observations [[0, 1]]
        decimal=4,
    )
    obs, reward, _, _ = envs.step(envs.action_space.sample())
    assert_almost_equal(
        envs.obs_rms.mean,
        np.mean([1.0]),  # the mean of first and second observations [[0, 1], [1, 2]]
        decimal=4,
    )


def test_normalize_return_vector_env():
    env_fns = [make_env(0), make_env(1)]
    envs = gym.vector.SyncVectorEnv(env_fns)
    envs = NormalizeReward(envs)
    obs = envs.reset()
    obs, reward, _, _ = envs.step(envs.action_space.sample())
    assert_almost_equal(
        envs.return_rms.mean,
        np.mean([1.5]),  # the mean of first returns [[1, 2]]
        decimal=4,
    )
    obs, reward, _, _ = envs.step(envs.action_space.sample())
    assert_almost_equal(
        envs.return_rms.mean,
        # the mean of first and second returns [[1, 2], [2 + gamma * 1, 3 + gamma * 2]]
        np.mean([[1, 2], [2 + envs.gamma * 1, 3 + envs.gamma * 2]]),
        decimal=4,
    )
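Finally, a vectorized sketch in the same spirit as the tests (hypothetical: it stacks both wrappers on one SyncVectorEnv, which the tests only exercise separately, and uses "CartPole-v1" as a stand-in):

import gym
from gym.wrappers.normalize import NormalizeObservation, NormalizeReward

envs = gym.vector.SyncVectorEnv([lambda: gym.make("CartPole-v1") for _ in range(4)])
envs = NormalizeObservation(envs)  # obs_rms shaped by single_observation_space
envs = NormalizeReward(envs)       # one running discounted return per sub-env

obs = envs.reset()                 # shape (4, 4): num_envs x obs_dim
for _ in range(100):
    obs, rews, dones, infos = envs.step(envs.action_space.sample())

Because gym.Wrapper forwards unknown attributes to the wrapped env, is_vector_env and num_envs propagate through the stack, so both wrappers take their vectorized code paths.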