Skip to content

Commit

Permalink
[RLlib] Fix TorchMultiCategorical.to_deterministic to return a mult…
Browse files Browse the repository at this point in the history
…i-dimensional tensor instead of a list. (ray-project#49098)

Signed-off-by: gexiaoxiao7 <1004083966@qq.com>
  • Loading branch information
simonsays1980 authored and gexiaoxiao7 committed Dec 9, 2024
1 parent d9417dc commit 1372174
Showing 1 changed file with 11 additions and 2 deletions.
13 changes: 11 additions & 2 deletions rllib/models/torch/torch_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -497,8 +497,17 @@ def from_logits(

return TorchMultiCategorical(categoricals=categoricals)

def to_deterministic(self) -> "TorchMultiDistribution":
    """Return a multi-distribution of the deterministic sub-distributions.

    Each entry of `self._cats` is converted through its own
    `to_deterministic()` and the results are wrapped, in order, in a
    `TorchMultiDistribution`.
    """
    deterministic_parts = []
    for categorical in self._cats:
        deterministic_parts.append(categorical.to_deterministic())
    return TorchMultiDistribution(deterministic_parts)
def to_deterministic(self) -> "TorchDeterministic":
    """Return a deterministic distribution located at the per-categorical argmax.

    The probabilities (or logits) of all sub-categoricals in `self._cats`
    are transposed, padded to a common length and stacked with
    `pad_sequence`, and the argmax is taken along the padded (first)
    dimension, so the result is a single tensor rather than a list.

    Returns:
        A `TorchDeterministic` whose `loc` holds the argmax indices.
    """
    # The first categorical decides which parameterization is read for all
    # of them. Previously both branches read `cat.logits`, making the check
    # a no-op and presumably breaking when a categorical was built from
    # `probs` only (logits unset) -- TODO confirm against TorchCategorical.
    if self._cats[0].probs is not None:
        per_categorical = [cat.probs.t() for cat in self._cats]
    else:
        per_categorical = [cat.logits.t() for cat in self._cats]

    # Pad with -inf so padded slots can never win the argmax; this holds
    # for logits as well as for (non-negative) probabilities.
    probs_or_logits = nn.utils.rnn.pad_sequence(
        per_categorical, padding_value=-torch.inf
    )

    return TorchDeterministic(loc=torch.argmax(probs_or_logits, dim=0))


@DeveloperAPI
Expand Down

0 comments on commit 1372174

Please sign in to comment.