-
Notifications
You must be signed in to change notification settings - Fork 6k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[RLlib] New API stack: (Multi)RLModule overhaul vol 03 (Introduce gen…
…eric `_forward` to further simplify the user experience). (#47889)
- Loading branch information
Showing
29 changed files
with
545 additions
and
585 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,32 @@ | ||
from ray.rllib.algorithms.bc.bc_rl_module import BCRLModule | ||
from typing import Any, Dict | ||
|
||
from ray.rllib.core import Columns | ||
from ray.rllib.core.models.base import ENCODER_OUT | ||
from ray.rllib.core.rl_module.rl_module import RLModule | ||
from ray.rllib.core.rl_module.torch.torch_rl_module import TorchRLModule | ||
from ray.rllib.utils.annotations import override | ||
|
||
|
||
class BCTorchRLModule(TorchRLModule): | ||
@override(RLModule) | ||
def setup(self): | ||
# __sphinx_doc_begin__ | ||
# Build models from catalog | ||
self.encoder = self.catalog.build_encoder(framework=self.framework) | ||
self.pi = self.catalog.build_pi_head(framework=self.framework) | ||
|
||
@override(RLModule) | ||
def _forward(self, batch: Dict, **kwargs) -> Dict[str, Any]: | ||
"""Generic BC forward pass (for all phases of training/evaluation).""" | ||
output = {} | ||
|
||
# State encodings. | ||
encoder_outs = self.encoder(batch) | ||
if Columns.STATE_OUT in encoder_outs: | ||
output[Columns.STATE_OUT] = encoder_outs[Columns.STATE_OUT] | ||
|
||
# Actions. | ||
action_logits = self.pi(encoder_outs[ENCODER_OUT]) | ||
output[Columns.ACTION_DIST_INPUTS] = action_logits | ||
|
||
class BCTorchRLModule(TorchRLModule, BCRLModule): | ||
pass | ||
return output |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.