diff --git a/gym/spaces/box.py b/gym/spaces/box.py index e621cfc35d4..12a8c4b1be0 100644 --- a/gym/spaces/box.py +++ b/gym/spaces/box.py @@ -23,7 +23,7 @@ class Box(Space): """ - def __init__(self, low, high, shape=None, dtype=np.float32): + def __init__(self, low, high, shape=None, dtype=np.float32, seed=None): assert dtype is not None, "dtype must be explicitly provided. " self.dtype = np.dtype(dtype) @@ -81,7 +81,7 @@ def _get_precision(dtype): self.bounded_below = -np.inf < self.low self.bounded_above = np.inf > self.high - super(Box, self).__init__(self.shape, self.dtype) + super(Box, self).__init__(self.shape, self.dtype, seed) def is_bounded(self, manner="both"): below = np.all(self.bounded_below) diff --git a/gym/spaces/dict.py b/gym/spaces/dict.py index 2a9d6ecc669..3e82c77b901 100644 --- a/gym/spaces/dict.py +++ b/gym/spaces/dict.py @@ -33,10 +33,11 @@ class Dict(Space): }) """ - def __init__(self, spaces=None, **spaces_kwargs): + def __init__(self, spaces=None, seed=None, **spaces_kwargs): assert (spaces is None) or ( not spaces_kwargs ), "Use either Dict(spaces=dict(...)) or Dict(foo=x, bar=z)" + if spaces is None: spaces = spaces_kwargs if isinstance(spaces, dict) and not isinstance(spaces, OrderedDict): @@ -49,28 +50,45 @@ def __init__(self, spaces=None, **spaces_kwargs): space, Space ), "Values of the dict should be instances of gym.Space" super(Dict, self).__init__( - None, None + None, None, seed ) # None for shape and dtype, since it'll require special handling def seed(self, seed=None): - seed = super().seed(seed) - try: - subseeds = self.np_random.choice( - np.iinfo(int).max, - size=len(self.spaces), - replace=False, # unique subseed for each subspace - ) - except ValueError: - subseeds = self.np_random.choice( - np.iinfo(int).max, - size=len(self.spaces), - replace=True, # we get more than INT_MAX subspaces - ) - - for subspace, subseed in zip(self.spaces.values(), subseeds): - seed.append(subspace.seed(int(subseed))[0]) - - return seed + seeds = [] + if isinstance(seed, dict): + for key, seed_key in zip(self.spaces, seed): + assert key == seed_key, print( + "Key value", + seed_key, + "in passed seed dict did not match key value", + key, + "in spaces Dict.", + ) + seeds += self.spaces[key].seed(seed[seed_key]) + elif isinstance(seed, int): + seeds = super().seed(seed) + try: + subseeds = self.np_random.choice( + np.iinfo(int).max, + size=len(self.spaces), + replace=False, # unique subseed for each subspace + ) + except ValueError: + subseeds = self.np_random.choice( + np.iinfo(int).max, + size=len(self.spaces), + replace=True, # we get more than INT_MAX subspaces + ) + + for subspace, subseed in zip(self.spaces.values(), subseeds): + seeds.append(subspace.seed(int(subseed))[0]) + elif seed is None: + for space in self.spaces.values(): + seeds += space.seed(seed) + else: + raise TypeError("Passed seed not of an expected type: dict or int or None") + + return seeds def sample(self): return OrderedDict([(k, space.sample()) for k, space in self.spaces.items()]) diff --git a/gym/spaces/discrete.py b/gym/spaces/discrete.py index bdbecd942ce..eb358e6c1af 100644 --- a/gym/spaces/discrete.py +++ b/gym/spaces/discrete.py @@ -11,10 +11,10 @@ class Discrete(Space): """ - def __init__(self, n): + def __init__(self, n, seed=None): assert n >= 0 self.n = n - super(Discrete, self).__init__((), np.int64) + super(Discrete, self).__init__((), np.int64, seed) def sample(self): return self.np_random.randint(self.n) diff --git a/gym/spaces/multi_binary.py b/gym/spaces/multi_binary.py index 5054983972c..d8b315a1c88 100644 --- a/gym/spaces/multi_binary.py +++ b/gym/spaces/multi_binary.py @@ -26,13 +26,13 @@ class MultiBinary(Space): """ - def __init__(self, n): + def __init__(self, n, seed=None): self.n = n if type(n) in [tuple, list, np.ndarray]: input_n = n else: input_n = (n,) - super(MultiBinary, self).__init__(input_n, np.int8) + super(MultiBinary, self).__init__(input_n, np.int8, seed) def sample(self): return self.np_random.randint(low=0, high=2, size=self.n, dtype=self.dtype) diff --git a/gym/spaces/multi_discrete.py b/gym/spaces/multi_discrete.py index 4ab7b41b27a..fdfdb19fd51 100644 --- a/gym/spaces/multi_discrete.py +++ b/gym/spaces/multi_discrete.py @@ -25,14 +25,14 @@ class MultiDiscrete(Space): """ - def __init__(self, nvec, dtype=np.int64): + def __init__(self, nvec, dtype=np.int64, seed=None): """ nvec: vector of counts of each categorical variable """ assert (np.array(nvec) > 0).all(), "nvec (counts) have to be positive" self.nvec = np.asarray(nvec, dtype=dtype) - super(MultiDiscrete, self).__init__(self.nvec.shape, dtype) + super(MultiDiscrete, self).__init__(self.nvec.shape, dtype, seed) def sample(self): return (self.np_random.random_sample(self.nvec.shape) * self.nvec).astype( diff --git a/gym/spaces/space.py b/gym/spaces/space.py index 9bf124cef3d..7274e94b4d5 100644 --- a/gym/spaces/space.py +++ b/gym/spaces/space.py @@ -16,12 +16,14 @@ class Space(object): not handle custom spaces properly. Use custom spaces with care. """ - def __init__(self, shape=None, dtype=None): + def __init__(self, shape=None, dtype=None, seed=None): import numpy as np # takes about 300-400ms to import, so we load lazily self._shape = None if shape is None else tuple(shape) self.dtype = None if dtype is None else np.dtype(dtype) self._np_random = None + if seed is not None: + self.seed(seed) @property def np_random(self): diff --git a/gym/spaces/tests/test_spaces.py b/gym/spaces/tests/test_spaces.py index 1d427789809..3198857d267 100644 --- a/gym/spaces/tests/test_spaces.py +++ b/gym/spaces/tests/test_spaces.py @@ -180,6 +180,53 @@ def test_bad_space_calls(space_fn): space_fn() +def test_seed_Dict(): + test_space = Dict( + { + "a": Box(low=0, high=1, shape=(3, 3)), + "b": Dict( + { + "b_1": Box(low=-100, high=100, shape=(2,)), + "b_2": Box(low=-1, high=1, shape=(2,)), + } + ), + "c": Discrete(5), + } + ) + + seed_dict = { + "a": 0, + "b": { + "b_1": 1, + "b_2": 2, + }, + "c": 3, + } + + test_space.seed(seed_dict) + + # "Unpack" the dict sub-spaces into individual spaces + a = Box(low=0, high=1, shape=(3, 3)) + a.seed(0) + b_1 = Box(low=-100, high=100, shape=(2,)) + b_1.seed(1) + b_2 = Box(low=-1, high=1, shape=(2,)) + b_2.seed(2) + c = Discrete(5) + c.seed(3) + + for i in range(10): + test_s = test_space.sample() + a_s = a.sample() + assert (test_s["a"] == a_s).all() + b_1_s = b_1.sample() + assert (test_s["b"]["b_1"] == b_1_s).all() + b_2_s = b_2.sample() + assert (test_s["b"]["b_2"] == b_2_s).all() + c_s = c.sample() + assert test_s["c"] == c_s + + def test_box_dtype_check(): # Related Issues: # https://github.com/openai/gym/issues/2357 diff --git a/gym/spaces/tuple.py b/gym/spaces/tuple.py index c00935ff23a..bb3133a5063 100644 --- a/gym/spaces/tuple.py +++ b/gym/spaces/tuple.py @@ -10,33 +10,44 @@ class Tuple(Space): self.observation_space = spaces.Tuple((spaces.Discrete(2), spaces.Discrete(3))) """ - def __init__(self, spaces): + def __init__(self, spaces, seed=None): self.spaces = spaces for space in spaces: assert isinstance( space, Space ), "Elements of the tuple must be instances of gym.Space" - super(Tuple, self).__init__(None, None) + super(Tuple, self).__init__(None, None, seed) def seed(self, seed=None): - seed = super().seed(seed) - try: - subseeds = self.np_random.choice( - np.iinfo(int).max, - size=len(self.spaces), - replace=False, # unique subseed for each subspace - ) - except ValueError: - subseeds = self.np_random.choice( - np.iinfo(int).max, - size=len(self.spaces), - replace=True, # we get more than INT_MAX subspaces - ) + seeds = [] + + if isinstance(seed, list): + for i, space in enumerate(self.spaces): + seeds += space.seed(seed[i]) + elif isinstance(seed, int): + seeds = super().seed(seed) + try: + subseeds = self.np_random.choice( + np.iinfo(int).max, + size=len(self.spaces), + replace=False, # unique subseed for each subspace + ) + except ValueError: + subseeds = self.np_random.choice( + np.iinfo(int).max, + size=len(self.spaces), + replace=True, # we get more than INT_MAX subspaces + ) - for subspace, subseed in zip(self.spaces, subseeds): - seed.append(subspace.seed(int(subseed))[0]) + for subspace, subseed in zip(self.spaces, subseeds): + seeds.append(subspace.seed(int(subseed))[0]) + elif seed is None: + for space in self.spaces: + seeds += space.seed(seed) + else: + raise TypeError("Passed seed not of an expected type: list or int or None") - return seed + return seeds def sample(self): return tuple([space.sample() for space in self.spaces])