Skip to content

Commit

Permalink
[RLlib] Cleanup examples folder (vol 30): BC pretraining, then PPO finetuning (new API stack with RLModule checkpoints). (ray-project#47838)
Browse files Browse the repository at this point in the history

Signed-off-by: ujjawal-khare <ujjawal.khare@dream11.com>
  • Loading branch information
sven1977 authored and ujjawal-khare committed Oct 15, 2024
1 parent bd81c9e commit 9ac5d16
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 29 deletions.
26 changes: 2 additions & 24 deletions rllib/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -995,7 +995,7 @@ py_test(
name = "test_bc",
tags = ["team:rllib", "algorithms_dir"],
size = "medium",
# Include the parquet data files.
# Include the offline data files.
data = ["tests/data/cartpole/cartpole-v1_large"],
srcs = ["algorithms/bc/tests/test_bc.py"]
)
Expand Down Expand Up @@ -1060,7 +1060,7 @@ py_test(
name = "test_marwil",
tags = ["team:rllib", "algorithms_dir"],
size = "large",
# Include the parquet data folder.
# Include the offline data files.
data = [
"tests/data/cartpole/cartpole-v1_large",
"tests/data/pendulum/pendulum-v1_large",
Expand Down Expand Up @@ -1719,28 +1719,6 @@ py_test(
]
)

py_test(
name = "test_offline_prelearner",
tags = ["team:rllib", "offline"],
size = "small",
srcs = ["offline/tests/test_offline_prelearner.py"],
data = [
"tests/data/cartpole/cartpole-v1_large",
"tests/data/cartpole/large.json",
],
)

py_test(
name = "test_offline_prelearner",
tags = ["team:rllib", "offline"],
size = "small",
srcs = ["offline/tests/test_offline_prelearner.py"],
data = [
"tests/data/cartpole/cartpole-v1_large",
"tests/data/cartpole/large.json",
],
)

# --------------------------------------------------------------------
# Policies
# rllib/policy/
Expand Down
9 changes: 4 additions & 5 deletions rllib/examples/offline_rl/train_w_bc_finetune_w_ppo.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,12 +176,11 @@ def _forward_train(self, batch, **kwargs):
}

@override(ValueFunctionAPI)
def compute_values(self, batch, embeddings=None):
# Compute embeddings ...
if embeddings is None:
embeddings = self._encoder(batch)[ENCODER_OUT]
def compute_values(self, batch):
# Compute features ...
features = self._encoder(batch)[ENCODER_OUT]
# then values using our value head.
return self._vf(embeddings).squeeze(-1)
return self._vf(features).squeeze(-1)


if __name__ == "__main__":
Expand Down

0 comments on commit 9ac5d16

Please sign in to comment.