Skip to content

Commit

Permalink
[Update space]: support flat_dim, flatten, unflatten (fixes #1310) (#…
Browse files Browse the repository at this point in the history
…1319)

* Update space.py

* Update box.py

* Update discrete.py

* Update tuple_space.py

* Update box.py

* Update box.py

* Update discrete.py

* Update space.py

* Update box.py

* Update discrete.py

* Update tuple_space.py

* Update multi_binary.py

* Update multi_discrete.py

* Update and rename dict_space.py to dict.py

* Update tuple_space.py

* Rename tuple_space.py to tuple.py

* Update __init__.py

* Update multi_binary.py

* Update multi_discrete.py

* Update space.py

* Update box.py

* Update discrete.py

* Update multi_binary.py

* Update multi_discrete.py

* Update __init__.py

* Update __init__.py

* Update multi_discrete.py

* Update __init__.py

* Update box.py

* Update box.py

* Update multi_discrete.py

* Update discrete.py

* Update multi_discrete.py

* Update discrete.py

* Update dict.py

* Update dict.py

* Update multi_binary.py

* Update multi_discrete.py

* Update tuple.py

* Update discrete.py

* Update __init__.py

* Update box.py

* Update and rename dict.py to dict_space.py

* Update dict_space.py

* Update dict_space.py

* Update dict_space.py

* Update discrete.py

* Update multi_binary.py

* Create utils.py

* Update __init__.py

* Update multi_discrete.py

* Update multi_discrete.py

* Update space.py

* Update and rename tuple.py to tuple_space.py
  • Loading branch information
zuoxingdong authored and pzhokhov committed Mar 24, 2019
1 parent 07645bd commit 5efcd86
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 1 deletion.
6 changes: 5 additions & 1 deletion gym/spaces/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,8 @@
from gym.spaces.tuple_space import Tuple
from gym.spaces.dict_space import Dict

__all__ = ["Space", "Box", "Discrete", "MultiDiscrete", "MultiBinary", "Tuple", "Dict"]
from gym.spaces.utils import flatdim
from gym.spaces.utils import flatten
from gym.spaces.utils import unflatten

__all__ = ["Space", "Box", "Discrete", "MultiDiscrete", "MultiBinary", "Tuple", "Dict", "flatdim", "flatten", "unflatten"]
69 changes: 69 additions & 0 deletions gym/spaces/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
import numpy as np

from gym.spaces import Box
from gym.spaces import Discrete
from gym.spaces import MultiDiscrete
from gym.spaces import MultiBinary
from gym.spaces import Tuple
from gym.spaces import Dict


def flatdim(space):
if isinstance(space, Box):
return int(np.prod(space.shape))
elif isinstance(space, Discrete):
return int(space.n)
elif isinstance(space, Tuple):
return int(sum([flatdim(s) for s in space.spaces]))
elif isinstance(space, Dict):
return int(sum([flatdim(s) for s in space.spaces.values()]))
elif isinstance(space, MultiBinary):
return int(space.n)
elif isinstance(space, MultiDiscrete):
return int(np.prod(space.shape))
else:
raise NotImplementedError


def flatten(space, x):
if isinstance(space, Box):
return np.asarray(x, dtype=np.float32).flatten()
elif isinstance(space, Discrete):
onehot = np.zeros(space.n, dtype=np.float32)
onehot[x] = 1.0
return onehot
elif isinstance(space, Tuple):
return np.concatenate([flatten(s, x_part) for x_part, s in zip(x, space.spaces)])
elif isinstance(space, Dict):
return np.concatenate([flatten(space.spaces[key], item) for key, item in x.items()])
elif isinstance(space, MultiBinary):
return np.asarray(x).flatten()
elif isinstance(space, MultiDiscrete):
return np.asarray(x).flatten()
else:
raise NotImplementedError


def unflatten(space, x):
if isinstance(space, Box):
return np.asarray(x, dtype=np.float32).reshape(space.shape)
elif isinstance(space, Discrete):
return int(np.nonzero(x)[0][0])
elif isinstance(space, Tuple):
dims = [flatdim(s) for s in space.spaces]
list_flattened = np.split(x, np.cumsum(dims)[:-1])
list_unflattened = [unflatten(s, flattened)
for flattened, s in zip(list_flattened, space.spaces)]
return tuple(list_unflattened)
elif isinstance(space, Dict):
dims = [flatdim(s) for s in space.spaces.values()]
list_flattened = np.split(x, np.cumsum(dims)[:-1])
list_unflattened = [(key, unflatten(s, flattened))
for flattened, (key, s) in zip(list_flattened, space.spaces.items())]
return dict(list_unflattened)
elif isinstance(space, MultiBinary):
return np.asarray(x).reshape(space.shape)
elif isinstance(space, MultiDiscrete):
return np.asarray(x).reshape(space.shape)
else:
raise NotImplementedError

0 comments on commit 5efcd86

Please sign in to comment.