Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[RLlib] Better support MultiBinary spaces by treating Tuples as a superset of them in ComplexInputNet #28900

Merged
7 changes: 5 additions & 2 deletions rllib/models/tf/complex_input_net.py
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,4 @@
from gym.spaces import Box, Discrete, MultiDiscrete
from gym.spaces import Box, Discrete, MultiDiscrete, MultiBinary
import numpy as np
import tree # pip install dm_tree

Expand Down Expand Up @@ -57,7 +57,7 @@ def __init__(self, obs_space, action_space, num_outputs, model_config, name):
concat_size = 0
for i, component in enumerate(self.flattened_input_space):
# Image space.
if len(component.shape) == 3:
if len(component.shape) == 3 and isinstance(component, Box):
config = {
"conv_filters": model_config["conv_filters"]
if "conv_filters" in model_config
Expand All @@ -78,6 +78,9 @@ def __init__(self, obs_space, action_space, num_outputs, model_config, name):
elif isinstance(component, (Discrete, MultiDiscrete)):
if isinstance(component, Discrete):
size = component.n
elif isinstance(component, MultiBinary):
# Treat MultiBinary like a Tuple: its size is the product over `component.n`.
size = np.product(component.n)
else:
size = np.sum(component.nvec)
config = {
Expand Down
7 changes: 5 additions & 2 deletions rllib/models/torch/complex_input_net.py
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,4 @@
from gym.spaces import Box, Discrete, MultiDiscrete
from gym.spaces import Box, Discrete, MultiDiscrete, MultiBinary
import numpy as np
import tree # pip install dm_tree

Expand Down Expand Up @@ -67,7 +67,7 @@ def __init__(self, obs_space, action_space, num_outputs, model_config, name):
concat_size = 0
for i, component in enumerate(self.flattened_input_space):
# Image space.
if len(component.shape) == 3:
if len(component.shape) == 3 and isinstance(component, Box):
config = {
"conv_filters": model_config["conv_filters"]
if "conv_filters" in model_config
Expand Down Expand Up @@ -99,6 +99,9 @@ def __init__(self, obs_space, action_space, num_outputs, model_config, name):
elif isinstance(component, (Discrete, MultiDiscrete)):
if isinstance(component, Discrete):
size = component.n
elif isinstance(component, MultiBinary):
# Treat MultiBinary like a Tuple: its size is the product over `component.n`.
size = np.product(component.n)
else:
size = np.sum(component.nvec)
config = {
Expand Down
7 changes: 5 additions & 2 deletions rllib/tests/test_supported_spaces.py
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,4 @@
from gym.spaces import Box, Dict, Discrete, Tuple, MultiDiscrete
from gym.spaces import Box, Dict, Discrete, Tuple, MultiDiscrete, MultiBinary
import numpy as np
import unittest

Expand Down Expand Up @@ -30,6 +30,7 @@
"yet_another_nested_dict": Dict({"a": Tuple([Discrete(2), Discrete(3)])}),
}
),
# TODO: Support "multi_binary": MultiBinary([...]),
}

OBSERVATION_SPACES_TO_TEST = {
Expand All @@ -45,6 +46,7 @@
"position": Box(-1.0, 1.0, (5,), dtype=np.float32),
}
),
"multi_binary": MultiBinary([3, 10, 10]),
}


Expand Down Expand Up @@ -120,7 +122,8 @@ def _do_check(alg, config, a_name, o_name):
for _ in framework_iterator(config, frameworks=frameworks):
# Zip through action- and obs-spaces.
for a_name, o_name in zip(
ACTION_SPACES_TO_TEST.keys(), OBSERVATION_SPACES_TO_TEST.keys()
ACTION_SPACES_TO_TEST.keys(),
list(OBSERVATION_SPACES_TO_TEST.keys()) + ["discrete"],
):
_do_check(alg, config, a_name, o_name)
# Do the remaining obs spaces.
Expand Down