From a8f551ed4426f8f88ed1fb4d59fcf30f770cf79a Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Sun, 12 Sep 2021 00:54:52 +0800 Subject: [PATCH] Make `MultiDiscrete` a `Tuple`-like space (#2364) * Make MultiDiscrete a Tuple-like space * Update test cases for MultiDiscrete --- gym/spaces/multi_discrete.py | 17 +++++++++- gym/spaces/tests/test_spaces.py | 60 ++++++++++++++++++++++++++++++--- 2 files changed, 72 insertions(+), 5 deletions(-) diff --git a/gym/spaces/multi_discrete.py b/gym/spaces/multi_discrete.py index e4d30a0fa7e..4ab7b41b27a 100644 --- a/gym/spaces/multi_discrete.py +++ b/gym/spaces/multi_discrete.py @@ -1,5 +1,7 @@ import numpy as np +from gym.logger import warn from .space import Space +from .discrete import Discrete class MultiDiscrete(Space): @@ -24,7 +26,6 @@ class MultiDiscrete(Space): """ def __init__(self, nvec, dtype=np.int64): - """ nvec: vector of counts of each categorical variable """ @@ -54,5 +55,19 @@ def from_jsonable(self, sample_n): def __repr__(self): return "MultiDiscrete({})".format(self.nvec) + def __getitem__(self, index): + nvec = self.nvec[index] + if nvec.ndim == 0: + subspace = Discrete(nvec) + else: + subspace = MultiDiscrete(nvec, self.dtype) + subspace.np_random.set_state(self.np_random.get_state()) # for reproducibility + return subspace + + def __len__(self): + if self.nvec.ndim >= 2: + warn("Get length of a multi-dimensional MultiDiscrete space.") + return len(self.nvec) + def __eq__(self, other): return isinstance(other, MultiDiscrete) and np.all(self.nvec == other.nvec) diff --git a/gym/spaces/tests/test_spaces.py b/gym/spaces/tests/test_spaces.py index 5a43f154d60..1d427789809 100644 --- a/gym/spaces/tests/test_spaces.py +++ b/gym/spaces/tests/test_spaces.py @@ -1,6 +1,5 @@ import json # note: ujson fails this test due to float equality import copy -from collections import OrderedDict import numpy as np import pytest @@ -244,6 +243,10 @@ def convert_sample_hashable(sample): return sample +def sample_equal(sample1, sample2): + return convert_sample_hashable(sample1) == convert_sample_hashable(sample2) + + @pytest.mark.parametrize( "space", [ @@ -277,9 +280,7 @@ def test_seed_reproducibility(space): space2.seed(None) assert space1.seed(0) == space2.seed(0) - - sample1, sample2 = space1.sample(), space2.sample() - assert convert_sample_hashable(sample1) == convert_sample_hashable(sample2) + assert sample_equal(space1.sample(), space2.sample()) @pytest.mark.parametrize( @@ -314,3 +315,54 @@ def test_seed_subspace_incorrelated(space): ] assert len(states) == len(set(states)) + + +def test_multidiscrete_as_tuple(): + # 1D multi-discrete + space = MultiDiscrete([3, 4, 5]) + + assert space.shape == (3,) + assert space[0] == Discrete(3) + assert space[0:1] == MultiDiscrete([3]) + assert space[0:2] == MultiDiscrete([3, 4]) + assert space[:] == space and space[:] is not space + assert len(space) == 3 + + # 2D multi-discrete + space = MultiDiscrete([[3, 4, 5], [6, 7, 8]]) + + assert space.shape == (2, 3) + assert space[0, 1] == Discrete(4) + assert space[0] == MultiDiscrete([3, 4, 5]) + assert space[0:1] == MultiDiscrete([[3, 4, 5]]) + assert space[0:2, :] == MultiDiscrete([[3, 4, 5], [6, 7, 8]]) + assert space[:, 0:1] == MultiDiscrete([[3], [6]]) + assert space[0:2, 0:2] == MultiDiscrete([[3, 4], [6, 7]]) + assert space[:] == space and space[:] is not space + assert space[:, :] == space and space[:, :] is not space + + +def test_multidiscrete_subspace_reproducibility(): + # 1D multi-discrete + space = MultiDiscrete([100, 200, 300]) + space.seed(None) + + assert sample_equal(space[0].sample(), space[0].sample()) + assert sample_equal(space[0:1].sample(), space[0:1].sample()) + assert sample_equal(space[0:2].sample(), space[0:2].sample()) + assert sample_equal(space[:].sample(), space[:].sample()) + assert sample_equal(space[:].sample(), space.sample()) + + # 2D multi-discrete + space = MultiDiscrete([[300, 400, 500], [600, 700, 800]]) + space.seed(None) + + assert sample_equal(space[0, 1].sample(), space[0, 1].sample()) + assert sample_equal(space[0].sample(), space[0].sample()) + assert sample_equal(space[0:1].sample(), space[0:1].sample()) + assert sample_equal(space[0:2, :].sample(), space[0:2, :].sample()) + assert sample_equal(space[:, 0:1].sample(), space[:, 0:1].sample()) + assert sample_equal(space[0:2, 0:2].sample(), space[0:2, 0:2].sample()) + assert sample_equal(space[:].sample(), space[:].sample()) + assert sample_equal(space[:, :].sample(), space[:, :].sample()) + assert sample_equal(space[:, :].sample(), space.sample())