[RLlib] New API stack: (Multi)RLModule overhaul vol 03 (Introduce generic `_forward` to further simplify the user experience) #47889
Changes to `compute_loss_for_module()` (PPO Torch learner):

```diff
@@ -40,6 +40,8 @@ def compute_loss_for_module(
     batch: Dict[str, Any],
     fwd_out: Dict[str, TensorType],
 ) -> TensorType:
+    module = self.module[module_id].unwrapped()
+
     # Possibly apply masking to some sub loss terms and to the total loss term
     # at the end. Masking could be used for RNN-based model (zero padded `batch`)
     # and for PPO's batched value function (and bootstrap value) computations,
```
```diff
@@ -55,12 +57,8 @@ def possibly_masked_mean(data_):
     else:
         possibly_masked_mean = torch.mean

-    action_dist_class_train = (
-        self.module[module_id].unwrapped().get_train_action_dist_cls()
-    )
-    action_dist_class_exploration = (
-        self.module[module_id].unwrapped().get_exploration_action_dist_cls()
-    )
+    action_dist_class_train = module.get_train_action_dist_cls()
+    action_dist_class_exploration = module.get_exploration_action_dist_cls()
```
Review comment: Hey, I have a question here: shouldn't […]? This affects the KL loss calculation, which might end up using a different distribution class […].
```diff
     curr_action_dist = action_dist_class_train.from_logits(
         fwd_out[Columns.ACTION_DIST_INPUTS]
     )
```
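To make the reviewer's concern concrete: PPO's KL term is built by re-materializing the sampling-time logits with the distribution class chosen here. Below is a hedged sketch, not the PR's code; the helper and its wiring are illustrative, and only `from_logits` and `kl` are standard RLlib `Distribution` methods.

```python
from typing import Any, Dict


def kl_term(action_dist_class_train, batch: Dict[str, Any],
            fwd_out: Dict[str, Any], dist_inputs_key: str):
    # The logits stored in `batch` were produced at *sampling* time, i.e. by
    # the exploration distribution. Re-building them with the *train*
    # distribution class is only sound if both classes interpret the logits
    # identically -- otherwise the KL below compares mismatched
    # distributions, which is the reviewer's concern.
    prev_dist = action_dist_class_train.from_logits(batch[dist_inputs_key])
    curr_dist = action_dist_class_train.from_logits(fwd_out[dist_inputs_key])
    return prev_dist.kl(curr_dist)
```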
```diff
@@ -91,12 +89,14 @@ def possibly_masked_mean(data_):

     # Compute a value function loss.
     if config.use_critic:
-        value_fn_out = fwd_out[Columns.VF_PREDS]
+        value_fn_out = module.compute_values(
+            batch, features=fwd_out.get(Columns.FEATURES)
+        )
         vf_loss = torch.pow(value_fn_out - batch[Postprocessing.VALUE_TARGETS], 2.0)
         vf_loss_clipped = torch.clamp(vf_loss, 0, config.vf_clip_param)
         mean_vf_loss = possibly_masked_mean(vf_loss_clipped)
         mean_vf_unclipped_loss = possibly_masked_mean(vf_loss)
-    # Ignore the value function.
+    # Ignore the value function -> Set all to 0.0.
     else:
         z = torch.tensor(0.0, device=surrogate_loss.device)
         value_fn_out = mean_vf_unclipped_loss = vf_loss_clipped = mean_vf_loss = z
```

Review comment (on the `compute_values` call): I wonder if this causes problems again in the DDP case. I remember similar problems with CQL and SAC when not running everything in […].

Reply: Yeah, good point, I think you are right. Let's see what the tests say ...
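For reference, the value-loss math in this hunk boils down to a clipped squared error. A minimal sketch, with the masking helper and the `Postprocessing.VALUE_TARGETS` lookup simplified into plain arguments:

```python
import torch


def clipped_vf_loss(value_fn_out: torch.Tensor,
                    value_targets: torch.Tensor,
                    vf_clip_param: float) -> torch.Tensor:
    # Squared error between the (re-)computed values and the GAE value targets.
    vf_loss = torch.pow(value_fn_out - value_targets, 2.0)
    # Clip elementwise, then average. (The actual learner code uses
    # `possibly_masked_mean` here to respect zero-padding in RNN batches.)
    return torch.clamp(vf_loss, 0.0, vf_clip_param).mean()
```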
Changes to `PPOTorchRLModule`:

```diff
@@ -1,4 +1,4 @@
-from typing import Any, Dict
+from typing import Any, Dict, Optional

 from ray.rllib.algorithms.ppo.ppo_rl_module import PPORLModule
 from ray.rllib.core.columns import Columns
```
```diff
@@ -17,63 +17,50 @@ class PPOTorchRLModule(TorchRLModule, PPORLModule):
     framework: str = "torch"

     @override(RLModule)
-    def _forward_inference(self, batch: Dict[str, Any]) -> Dict[str, Any]:
+    def _forward(self, batch: Dict[str, Any], **kwargs) -> Dict[str, Any]:
+        """Default forward pass (used for inference and exploration)."""
         output = {}

         # Encoder forward pass.
         encoder_outs = self.encoder(batch)
         # Stateful encoder?
         if Columns.STATE_OUT in encoder_outs:
             output[Columns.STATE_OUT] = encoder_outs[Columns.STATE_OUT]

         # Pi head.
         output[Columns.ACTION_DIST_INPUTS] = self.pi(encoder_outs[ENCODER_OUT][ACTOR])

         return output

     @override(RLModule)
-    def _forward_exploration(self, batch: Dict[str, Any], **kwargs) -> Dict[str, Any]:
-        return self._forward_inference(batch)
-
-    @override(RLModule)
-    def _forward_train(self, batch: Dict[str, Any]) -> Dict[str, Any]:
-        if self.config.inference_only:
-            raise RuntimeError(
-                "Trying to train a module that is not a learner module. Set the "
-                "flag `inference_only=False` when building the module."
-            )
+    def _forward_train(self, batch: Dict[str, Any], **kwargs) -> Dict[str, Any]:
+        """Train forward pass (keep features for possible shared value func. call)."""
         output = {}

         # Shared encoder.
         encoder_outs = self.encoder(batch)
+        output[Columns.FEATURES] = encoder_outs[ENCODER_OUT][CRITIC]
```
Review comment (on the `Columns.FEATURES` line): Imo […].

Reply: You are right!
```diff
         if Columns.STATE_OUT in encoder_outs:
             output[Columns.STATE_OUT] = encoder_outs[Columns.STATE_OUT]

-        # Value head.
-        vf_out = self.vf(encoder_outs[ENCODER_OUT][CRITIC])
-        # Squeeze out last dim (value function node).
-        output[Columns.VF_PREDS] = vf_out.squeeze(-1)
-
         # Policy head.
-        action_logits = self.pi(encoder_outs[ENCODER_OUT][ACTOR])
-        output[Columns.ACTION_DIST_INPUTS] = action_logits
-
+        output[Columns.ACTION_DIST_INPUTS] = self.pi(encoder_outs[ENCODER_OUT][ACTOR])
         return output
```
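This is the simplification the PR title refers to: with a generic `_forward`, a custom module no longer has to implement `_forward_inference` and `_forward_exploration` separately. A hypothetical minimal module under that assumption (the class name, layer sizes, and plain `nn.Sequential` policy net are made up for illustration):

```python
import torch
from typing import Any, Dict

from ray.rllib.core.columns import Columns
from ray.rllib.core.rl_module.torch import TorchRLModule


class MyTinyModule(TorchRLModule):
    """Hypothetical module: one `_forward` covers inference and exploration;
    training also falls back to it unless `_forward_train` is overridden."""

    def setup(self):
        # Arbitrary toy policy net: 4 obs dims -> 2 action logits.
        self.net = torch.nn.Sequential(
            torch.nn.Linear(4, 32),
            torch.nn.ReLU(),
            torch.nn.Linear(32, 2),
        )

    def _forward(self, batch: Dict[str, Any], **kwargs) -> Dict[str, Any]:
        # Return action-distribution inputs; RLlib constructs the action
        # distribution from these logits downstream.
        return {Columns.ACTION_DIST_INPUTS: self.net(batch[Columns.OBS])}
```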
```diff

     @override(ValueFunctionAPI)
-    def compute_values(self, batch: Dict[str, Any]) -> TensorType:
-        # Separate vf-encoder.
-        if hasattr(self.encoder, "critic_encoder"):
-            batch_ = batch
-            if self.is_stateful():
-                # The recurrent encoders expect a `(state_in, h)` key in the
-                # input dict while the key returned is `(state_in, critic, h)`.
-                batch_ = batch.copy()
-                batch_[Columns.STATE_IN] = batch[Columns.STATE_IN][CRITIC]
-            encoder_outs = self.encoder.critic_encoder(batch_)[ENCODER_OUT]
-        # Shared encoder.
-        else:
-            encoder_outs = self.encoder(batch)[ENCODER_OUT][CRITIC]
+    def compute_values(
+        self,
+        batch: Dict[str, Any],
+        features: Optional[Any] = None,
+    ) -> TensorType:
+        if features is None:
```

Review comment (on the new `features` argument): Why not using […]?

Reply: Good question. This would mean that we would have to change the batch (add a new key to it) during the update procedure, which might clash when we have to (torch-)compile this operation. We had the same problem with the tf static graph.
```diff
+            # Separate vf-encoder.
+            if hasattr(self.encoder, "critic_encoder"):
+                batch_ = batch
+                if self.is_stateful():
+                    # The recurrent encoders expect a `(state_in, h)` key in the
+                    # input dict while the key returned is `(state_in, critic, h)`.
+                    batch_ = batch.copy()
+                    batch_[Columns.STATE_IN] = batch[Columns.STATE_IN][CRITIC]
+                features = self.encoder.critic_encoder(batch_)[ENCODER_OUT]
+            # Shared encoder.
+            else:
+                features = self.encoder(batch)[ENCODER_OUT][CRITIC]

         # Value head.
-        vf_out = self.vf(encoder_outs)
+        vf_out = self.vf(features)
         # Squeeze out last dimension (single node value head).
         return vf_out.squeeze(-1)
```
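Usage-wise, the new signature supports both call patterns and, per the reply above, keeps `batch` unmodified by passing features as an explicit argument instead of as a new batch key. A hedged sketch of a caller (the helper itself is hypothetical; `compute_values` and `Columns.FEATURES` follow the diff):

```python
from typing import Any, Dict, Optional

from ray.rllib.core.columns import Columns


def get_values(module, batch: Dict[str, Any],
               fwd_out: Optional[Dict[str, Any]] = None):
    # Inside the learner, the train forward pass already produced encoder
    # features, so `compute_values` skips its own (critic-)encoder pass.
    # Anywhere else (e.g. GAE on sampled episodes), `features` stays None
    # and the method encodes `batch` itself.
    features = fwd_out.get(Columns.FEATURES) if fwd_out is not None else None
    return module.compute_values(batch, features=features)
```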
Review comment: Don't we need to use `unwrapped` in case DDP is used?

Reply: Great question! DDP already wraps this method to use the `unwrapped` underlying RLModule, so this is OK here.
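A rough sketch of what that reply describes; this is illustrative only (the wrapper class and method bodies are assumptions, not RLlib's actual DDP wrapper code):

```python
import torch
from typing import Any, Dict, Optional


class DDPRLModuleWrapper(torch.nn.parallel.DistributedDataParallel):
    """Illustrative DDP wrapper: value computations are forwarded to the
    wrapped RLModule, so callers never call `.unwrapped()` themselves."""

    def unwrapped(self):
        # DDP stores the wrapped module under `self.module`.
        return self.module

    def compute_values(self, batch: Dict[str, Any],
                       features: Optional[Any] = None):
        # Delegate directly to the underlying RLModule; no DDP forward
        # machinery is needed for a pure value computation.
        return self.unwrapped().compute_values(batch, features=features)
```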