Skip to content

Commit

Permalink
Make Tuple and Dicts be seedable with lists and dicts of seeds + make…
Browse files Browse the repository at this point in the history
… the seed in default initialization controllable (openai#1774)

* Make the seed in default initialization controllable

Since seed() is being called in default initialization of Space, it should be controllable for reproducibility.

* Updated derived classes of Space to have their seeds controllable at initialization.

* Allow Tuple's spaces to each have their own seed

* Added dict based seeding for Dict space; test cases for Tuple and Dict seeding

* Update discrete.py

* Update test_spaces.py

* Add seed to __init__()

* blacked

* Fix black

* Fix failing tests
  • Loading branch information
RaghuSpaceRajan authored Sep 13, 2021
1 parent 0eabbaf commit c571b0d
Show file tree
Hide file tree
Showing 8 changed files with 125 additions and 47 deletions.
4 changes: 2 additions & 2 deletions gym/spaces/box.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ class Box(Space):
"""

def __init__(self, low, high, shape=None, dtype=np.float32):
def __init__(self, low, high, shape=None, dtype=np.float32, seed=None):
assert dtype is not None, "dtype must be explicitly provided. "
self.dtype = np.dtype(dtype)

Expand Down Expand Up @@ -81,7 +81,7 @@ def _get_precision(dtype):
self.bounded_below = -np.inf < self.low
self.bounded_above = np.inf > self.high

super(Box, self).__init__(self.shape, self.dtype)
super(Box, self).__init__(self.shape, self.dtype, seed)

def is_bounded(self, manner="both"):
below = np.all(self.bounded_below)
Expand Down
58 changes: 38 additions & 20 deletions gym/spaces/dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,11 @@ class Dict(Space):
})
"""

def __init__(self, spaces=None, **spaces_kwargs):
def __init__(self, spaces=None, seed=None, **spaces_kwargs):
assert (spaces is None) or (
not spaces_kwargs
), "Use either Dict(spaces=dict(...)) or Dict(foo=x, bar=z)"

if spaces is None:
spaces = spaces_kwargs
if isinstance(spaces, dict) and not isinstance(spaces, OrderedDict):
Expand All @@ -49,28 +50,45 @@ def __init__(self, spaces=None, **spaces_kwargs):
space, Space
), "Values of the dict should be instances of gym.Space"
super(Dict, self).__init__(
None, None
None, None, seed
) # None for shape and dtype, since it'll require special handling

def seed(self, seed=None):
seed = super().seed(seed)
try:
subseeds = self.np_random.choice(
np.iinfo(int).max,
size=len(self.spaces),
replace=False, # unique subseed for each subspace
)
except ValueError:
subseeds = self.np_random.choice(
np.iinfo(int).max,
size=len(self.spaces),
replace=True, # we get more than INT_MAX subspaces
)

for subspace, subseed in zip(self.spaces.values(), subseeds):
seed.append(subspace.seed(int(subseed))[0])

return seed
seeds = []
if isinstance(seed, dict):
for key, seed_key in zip(self.spaces, seed):
assert key == seed_key, print(
"Key value",
seed_key,
"in passed seed dict did not match key value",
key,
"in spaces Dict.",
)
seeds += self.spaces[key].seed(seed[seed_key])
elif isinstance(seed, int):
seeds = super().seed(seed)
try:
subseeds = self.np_random.choice(
np.iinfo(int).max,
size=len(self.spaces),
replace=False, # unique subseed for each subspace
)
except ValueError:
subseeds = self.np_random.choice(
np.iinfo(int).max,
size=len(self.spaces),
replace=True, # we get more than INT_MAX subspaces
)

for subspace, subseed in zip(self.spaces.values(), subseeds):
seeds.append(subspace.seed(int(subseed))[0])
elif seed is None:
for space in self.spaces.values():
seeds += space.seed(seed)
else:
raise TypeError("Passed seed not of an expected type: dict or int or None")

return seeds

def sample(self):
return OrderedDict([(k, space.sample()) for k, space in self.spaces.items()])
Expand Down
4 changes: 2 additions & 2 deletions gym/spaces/discrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@ class Discrete(Space):
"""

def __init__(self, n):
def __init__(self, n, seed=None):
assert n >= 0
self.n = n
super(Discrete, self).__init__((), np.int64)
super(Discrete, self).__init__((), np.int64, seed)

def sample(self):
return self.np_random.randint(self.n)
Expand Down
4 changes: 2 additions & 2 deletions gym/spaces/multi_binary.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,13 @@ class MultiBinary(Space):
"""

def __init__(self, n):
def __init__(self, n, seed=None):
self.n = n
if type(n) in [tuple, list, np.ndarray]:
input_n = n
else:
input_n = (n,)
super(MultiBinary, self).__init__(input_n, np.int8)
super(MultiBinary, self).__init__(input_n, np.int8, seed)

def sample(self):
return self.np_random.randint(low=0, high=2, size=self.n, dtype=self.dtype)
Expand Down
4 changes: 2 additions & 2 deletions gym/spaces/multi_discrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,14 @@ class MultiDiscrete(Space):
"""

def __init__(self, nvec, dtype=np.int64):
def __init__(self, nvec, dtype=np.int64, seed=None):
"""
nvec: vector of counts of each categorical variable
"""
assert (np.array(nvec) > 0).all(), "nvec (counts) have to be positive"
self.nvec = np.asarray(nvec, dtype=dtype)

super(MultiDiscrete, self).__init__(self.nvec.shape, dtype)
super(MultiDiscrete, self).__init__(self.nvec.shape, dtype, seed)

def sample(self):
return (self.np_random.random_sample(self.nvec.shape) * self.nvec).astype(
Expand Down
4 changes: 3 additions & 1 deletion gym/spaces/space.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,14 @@ class Space(object):
not handle custom spaces properly. Use custom spaces with care.
"""

def __init__(self, shape=None, dtype=None):
def __init__(self, shape=None, dtype=None, seed=None):
import numpy as np # takes about 300-400ms to import, so we load lazily

self._shape = None if shape is None else tuple(shape)
self.dtype = None if dtype is None else np.dtype(dtype)
self._np_random = None
if seed is not None:
self.seed(seed)

@property
def np_random(self):
Expand Down
47 changes: 47 additions & 0 deletions gym/spaces/tests/test_spaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,53 @@ def test_bad_space_calls(space_fn):
space_fn()


def test_seed_Dict():
test_space = Dict(
{
"a": Box(low=0, high=1, shape=(3, 3)),
"b": Dict(
{
"b_1": Box(low=-100, high=100, shape=(2,)),
"b_2": Box(low=-1, high=1, shape=(2,)),
}
),
"c": Discrete(5),
}
)

seed_dict = {
"a": 0,
"b": {
"b_1": 1,
"b_2": 2,
},
"c": 3,
}

test_space.seed(seed_dict)

# "Unpack" the dict sub-spaces into individual spaces
a = Box(low=0, high=1, shape=(3, 3))
a.seed(0)
b_1 = Box(low=-100, high=100, shape=(2,))
b_1.seed(1)
b_2 = Box(low=-1, high=1, shape=(2,))
b_2.seed(2)
c = Discrete(5)
c.seed(3)

for i in range(10):
test_s = test_space.sample()
a_s = a.sample()
assert (test_s["a"] == a_s).all()
b_1_s = b_1.sample()
assert (test_s["b"]["b_1"] == b_1_s).all()
b_2_s = b_2.sample()
assert (test_s["b"]["b_2"] == b_2_s).all()
c_s = c.sample()
assert test_s["c"] == c_s


def test_box_dtype_check():
# Related Issues:
# https://github.com/openai/gym/issues/2357
Expand Down
47 changes: 29 additions & 18 deletions gym/spaces/tuple.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,33 +10,44 @@ class Tuple(Space):
self.observation_space = spaces.Tuple((spaces.Discrete(2), spaces.Discrete(3)))
"""

def __init__(self, spaces):
def __init__(self, spaces, seed=None):
self.spaces = spaces
for space in spaces:
assert isinstance(
space, Space
), "Elements of the tuple must be instances of gym.Space"
super(Tuple, self).__init__(None, None)
super(Tuple, self).__init__(None, None, seed)

def seed(self, seed=None):
seed = super().seed(seed)
try:
subseeds = self.np_random.choice(
np.iinfo(int).max,
size=len(self.spaces),
replace=False, # unique subseed for each subspace
)
except ValueError:
subseeds = self.np_random.choice(
np.iinfo(int).max,
size=len(self.spaces),
replace=True, # we get more than INT_MAX subspaces
)
seeds = []

if isinstance(seed, list):
for i, space in enumerate(self.spaces):
seeds += space.seed(seed[i])
elif isinstance(seed, int):
seeds = super().seed(seed)
try:
subseeds = self.np_random.choice(
np.iinfo(int).max,
size=len(self.spaces),
replace=False, # unique subseed for each subspace
)
except ValueError:
subseeds = self.np_random.choice(
np.iinfo(int).max,
size=len(self.spaces),
replace=True, # we get more than INT_MAX subspaces
)

for subspace, subseed in zip(self.spaces, subseeds):
seed.append(subspace.seed(int(subseed))[0])
for subspace, subseed in zip(self.spaces, subseeds):
seeds.append(subspace.seed(int(subseed))[0])
elif seed is None:
for space in self.spaces:
seeds += space.seed(seed)
else:
raise TypeError("Passed seed not of an expected type: list or int or None")

return seed
return seeds

def sample(self):
return tuple([space.sample() for space in self.spaces])
Expand Down

0 comments on commit c571b0d

Please sign in to comment.