Skip to content

Commit

Permalink
Make MultiDiscrete a Tuple-like space (openai#2364)
Browse files Browse the repository at this point in the history
* Make MultiDiscrete a Tuple-like space

* Update test cases for MultiDiscrete
  • Loading branch information
XuehaiPan authored Sep 11, 2021
1 parent 8da6224 commit a8f551e
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 5 deletions.
17 changes: 16 additions & 1 deletion gym/spaces/multi_discrete.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import numpy as np
from gym.logger import warn
from .space import Space
from .discrete import Discrete


class MultiDiscrete(Space):
Expand All @@ -24,7 +26,6 @@ class MultiDiscrete(Space):
"""

def __init__(self, nvec, dtype=np.int64):

"""
nvec: vector of counts of each categorical variable
"""
Expand Down Expand Up @@ -54,5 +55,19 @@ def from_jsonable(self, sample_n):
def __repr__(self):
return "MultiDiscrete({})".format(self.nvec)

def __getitem__(self, index):
    """Return the subspace selected by ``index``.

    ``index`` is applied directly to ``self.nvec``. A zero-dimensional
    result becomes a ``Discrete`` space; anything else becomes a
    ``MultiDiscrete`` built with this space's ``dtype``.
    """
    selected = self.nvec[index]
    subspace = (
        Discrete(selected)
        if selected.ndim == 0
        else MultiDiscrete(selected, self.dtype)
    )
    # Copy the parent's RNG state into the subspace, for reproducibility.
    subspace.np_random.set_state(self.np_random.get_state())
    return subspace

def __len__(self):
if self.nvec.ndim >= 2:
warn("Get length of a multi-dimensional MultiDiscrete space.")
return len(self.nvec)

def __eq__(self, other):
return isinstance(other, MultiDiscrete) and np.all(self.nvec == other.nvec)
60 changes: 56 additions & 4 deletions gym/spaces/tests/test_spaces.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import json # note: ujson fails this test due to float equality
import copy
from collections import OrderedDict

import numpy as np
import pytest
Expand Down Expand Up @@ -244,6 +243,10 @@ def convert_sample_hashable(sample):
return sample


def sample_equal(sample1, sample2):
    """Compare two samples after passing each through ``convert_sample_hashable``."""
    hashable1 = convert_sample_hashable(sample1)
    hashable2 = convert_sample_hashable(sample2)
    return hashable1 == hashable2


@pytest.mark.parametrize(
"space",
[
Expand Down Expand Up @@ -277,9 +280,7 @@ def test_seed_reproducibility(space):
space2.seed(None)

assert space1.seed(0) == space2.seed(0)

sample1, sample2 = space1.sample(), space2.sample()
assert convert_sample_hashable(sample1) == convert_sample_hashable(sample2)
assert sample_equal(space1.sample(), space2.sample())


@pytest.mark.parametrize(
Expand Down Expand Up @@ -314,3 +315,54 @@ def test_seed_subspace_incorrelated(space):
]

assert len(states) == len(set(states))


def test_multidiscrete_as_tuple():
    """Check shape, indexing, slicing and ``len`` on 1D and 2D MultiDiscrete spaces."""
    # 1D multi-discrete
    flat = MultiDiscrete([3, 4, 5])

    assert flat.shape == (3,)
    assert flat[0] == Discrete(3)
    assert flat[0:1] == MultiDiscrete([3])
    assert flat[0:2] == MultiDiscrete([3, 4])
    # A full slice is an equal space, but a distinct object.
    assert flat[:] == flat
    assert flat[:] is not flat
    assert len(flat) == 3

    # 2D multi-discrete
    grid = MultiDiscrete([[3, 4, 5], [6, 7, 8]])

    assert grid.shape == (2, 3)
    assert grid[0, 1] == Discrete(4)
    assert grid[0] == MultiDiscrete([3, 4, 5])
    assert grid[0:1] == MultiDiscrete([[3, 4, 5]])
    assert grid[0:2, :] == MultiDiscrete([[3, 4, 5], [6, 7, 8]])
    assert grid[:, 0:1] == MultiDiscrete([[3], [6]])
    assert grid[0:2, 0:2] == MultiDiscrete([[3, 4], [6, 7]])
    # Full slices along one or both axes are equal to, but not the same
    # object as, the original space.
    assert grid[:] == grid
    assert grid[:] is not grid
    assert grid[:, :] == grid
    assert grid[:, :] is not grid


def test_multidiscrete_subspace_reproducibility():
    """Check that subspaces taken from the same parent yield equal samples.

    Each index expression is applied to the parent twice, and the samples
    drawn from the two resulting subspaces must compare equal. A
    full-slice subspace must also produce the same sample as the parent.
    """
    # 1D multi-discrete
    space = MultiDiscrete([100, 200, 300])
    space.seed(None)

    flat_indices = [
        np.s_[0],
        np.s_[0:1],
        np.s_[0:2],
        np.s_[:],
    ]
    for idx in flat_indices:
        assert sample_equal(space[idx].sample(), space[idx].sample())
    # Sampling the parent comes last, as in the per-index checks above it.
    assert sample_equal(space[:].sample(), space.sample())

    # 2D multi-discrete
    space = MultiDiscrete([[300, 400, 500], [600, 700, 800]])
    space.seed(None)

    grid_indices = [
        np.s_[0, 1],
        np.s_[0],
        np.s_[0:1],
        np.s_[0:2, :],
        np.s_[:, 0:1],
        np.s_[0:2, 0:2],
        np.s_[:],
        np.s_[:, :],
    ]
    for idx in grid_indices:
        assert sample_equal(space[idx].sample(), space[idx].sample())
    assert sample_equal(space[:, :].sample(), space.sample())

0 comments on commit a8f551e

Please sign in to comment.