diff --git a/gym/spaces/space.py b/gym/spaces/space.py index 7274e94b4d5..4c077064390 100644 --- a/gym/spaces/space.py +++ b/gym/spaces/space.py @@ -60,6 +60,25 @@ def contains(self, x): def __contains__(self, x): return self.contains(x) + def __setstate__(self, state): + # Don't mutate the original state + state = dict(state) + + # Allow for loading of legacy states. + # See: + # https://github.com/openai/gym/pull/2397 -- shape + # https://github.com/openai/gym/pull/1913 -- np_random + # + if "shape" in state: + state["_shape"] = state["shape"] + del state["shape"] + if "np_random" in state: + state["_np_random"] = state["np_random"] + del state["np_random"] + + # Update our state + self.__dict__.update(state) + def to_jsonable(self, sample_n): """Convert a batch of samples from this space to a JSONable data type.""" # By default, assume identity is JSONable diff --git a/gym/spaces/tests/test_spaces.py b/gym/spaces/tests/test_spaces.py index 3198857d267..3bdc23ffca7 100644 --- a/gym/spaces/tests/test_spaces.py +++ b/gym/spaces/tests/test_spaces.py @@ -413,3 +413,25 @@ def test_multidiscrete_subspace_reproducibility(): assert sample_equal(space[:].sample(), space[:].sample()) assert sample_equal(space[:, :].sample(), space[:, :].sample()) assert sample_equal(space[:, :].sample(), space.sample()) + + +def test_space_legacy_state_pickling(): + legacy_state = { + "shape": ( + 1, + 2, + 3, + ), + "dtype": np.int64, + "np_random": np.random.default_rng(), + "n": 3, + } + space = Discrete(1) + space.__setstate__(legacy_state) + + assert space.shape == legacy_state["shape"] + assert space._shape == legacy_state["shape"] + assert space.np_random == legacy_state["np_random"] + assert space._np_random == legacy_state["np_random"] + assert space.n == 3 + assert space.dtype == legacy_state["dtype"]