Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[rllib] Fix Multidiscrete support #4869

Merged
merged 4 commits into from
May 30, 2019
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 5 additions & 26 deletions python/ray/rllib/agents/impala/vtrace_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.tf_policy import TFPolicy, \
LearningRateSchedule
from ray.rllib.models.action_dist import MultiCategorical
from ray.rllib.models.action_dist import Categorical
from ray.rllib.models.catalog import ModelCatalog
from ray.rllib.utils.annotations import override
from ray.rllib.utils.explained_variance import explained_variance
Expand Down Expand Up @@ -191,9 +191,7 @@ def __init__(self,
unpacked_outputs = tf.split(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this needed anymore?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's used below in a spot.

self.model.outputs, output_hidden_shape, axis=1)

dist_inputs = unpacked_outputs if is_multidiscrete else \
self.model.outputs
action_dist = dist_class(dist_inputs)
action_dist = dist_class(self.model.outputs)

values = self.model.value_function()
self.var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
Expand Down Expand Up @@ -258,32 +256,13 @@ def make_time_major(tensor, drop_last=False):
rewards=make_time_major(rewards, drop_last=True),
values=make_time_major(values, drop_last=True),
bootstrap_value=make_time_major(values)[-1],
dist_class=dist_class,
dist_class=Categorical if is_multidiscrete else dist_class,
valid_mask=make_time_major(mask, drop_last=True),
vf_loss_coeff=self.config["vf_loss_coeff"],
entropy_coeff=self.config["entropy_coeff"],
clip_rho_threshold=self.config["vtrace_clip_rho_threshold"],
clip_pg_rho_threshold=self.config["vtrace_clip_pg_rho_threshold"])

# KL divergence between worker and learner logits for debugging
model_dist = MultiCategorical(unpacked_outputs)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removed the KL stats since I doubt they were very useful

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think they should stay in. They can be useful for debugging the policy (e.g. large fluctuations are a bad sign). See slide 21 here.

behaviour_dist = MultiCategorical(unpacked_behaviour_logits)

kls = model_dist.kl(behaviour_dist)
if len(kls) > 1:
self.KL_stats = {}

for i, kl in enumerate(kls):
self.KL_stats.update({
"mean_KL_{}".format(i): tf.reduce_mean(kl),
"max_KL_{}".format(i): tf.reduce_max(kl),
})
else:
self.KL_stats = {
"mean_KL": tf.reduce_mean(kls[0]),
"max_KL": tf.reduce_max(kls[0]),
}

# Initialize TFPolicy
loss_in = [
(SampleBatch.ACTIONS, actions),
Expand Down Expand Up @@ -318,7 +297,7 @@ def make_time_major(tensor, drop_last=False):
self.sess.run(tf.global_variables_initializer())

self.stats_fetches = {
LEARNER_STATS_KEY: dict({
LEARNER_STATS_KEY: {
"cur_lr": tf.cast(self.cur_lr, tf.float64),
"policy_loss": self.loss.pi_loss,
"entropy": self.loss.entropy,
Expand All @@ -328,7 +307,7 @@ def make_time_major(tensor, drop_last=False):
"vf_explained_var": explained_variance(
tf.reshape(self.loss.vtrace_returns.vs, [-1]),
tf.reshape(make_time_major(values, drop_last=True), [-1])),
}, **self.KL_stats),
},
}

@override(TFPolicy)
Expand Down
11 changes: 7 additions & 4 deletions python/ray/rllib/models/action_dist.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ class Categorical(ActionDistribution):
@override(ActionDistribution)
def logp(self, x):
return -tf.nn.sparse_softmax_cross_entropy_with_logits(
logits=self.inputs, labels=x)
logits=self.inputs, labels=tf.cast(x, tf.int32))

@override(ActionDistribution)
def entropy(self):
Expand Down Expand Up @@ -126,14 +126,17 @@ def _build_sample_op(self):
class MultiCategorical(ActionDistribution):
"""Categorical distribution for discrete action spaces."""

def __init__(self, inputs):
self.cats = [Categorical(input_) for input_ in inputs]
def __init__(self, inputs, input_lens):
self.cats = [
Categorical(input_)
for input_ in tf.split(inputs, input_lens, axis=1)
]
self.sample_op = self._build_sample_op()

def logp(self, actions):
# If tensor is provided, unstack it into list
if isinstance(actions, tf.Tensor):
actions = tf.unstack(actions, axis=1)
actions = tf.unstack(tf.cast(actions, tf.int32), axis=1)
logps = tf.stack(
[cat.logp(act) for cat, act in zip(self.cats, actions)])
return tf.reduce_sum(logps, axis=0)
Expand Down
3 changes: 2 additions & 1 deletion python/ray/rllib/models/catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,8 @@ def get_action_dist(action_space, config, dist_type=None, torch=False):
elif isinstance(action_space, gym.spaces.multi_discrete.MultiDiscrete):
if torch:
raise NotImplementedError
return MultiCategorical, int(sum(action_space.nvec))
return partial(MultiCategorical, input_lens=action_space.nvec), \
int(sum(action_space.nvec))

raise NotImplementedError("Unsupported args: {} {}".format(
action_space, dist_type))
Expand Down
3 changes: 2 additions & 1 deletion python/ray/rllib/tests/test_supported_spaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import traceback

import gym
from gym.spaces import Box, Discrete, Tuple, Dict
from gym.spaces import Box, Discrete, Tuple, Dict, MultiDiscrete
from gym.envs.registration import EnvSpec
import numpy as np
import sys
Expand All @@ -17,6 +17,7 @@
ACTION_SPACES_TO_TEST = {
"discrete": Discrete(5),
"vector": Box(-1.0, 1.0, (5, ), dtype=np.float32),
"multidiscrete": MultiDiscrete([1, 2, 3, 4]),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should there be an explicit test for the case that failed earlier? It's possible that this test already covers it though.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep, this covers the failing case. The test fails before the other changes in the PR.

"tuple": Tuple(
[Discrete(2),
Discrete(3),
Expand Down