
[RLlib] Remove CLI from docs (soon to be deprecated and replaced by python API). #46724

Merged 11 commits on Jul 23, 2024
161 changes: 93 additions & 68 deletions doc/source/rllib/doc_code/getting_started.py
@@ -1,126 +1,151 @@
# flake8: noqa

# __rllib-first-config-begin__
from ray.rllib.algorithms.ppo import PPOConfig
from ray.tune.logger import pretty_print
from pprint import pprint

from ray.rllib.algorithms.ppo import PPOConfig

algo = (
config = (
PPOConfig()
.api_stack(
enable_rl_module_and_learner=True,
enable_env_runner_and_connector_v2=True,
)
.environment("CartPole-v1")
.env_runners(num_env_runners=1)
.resources(num_gpus=0)
.environment(env="CartPole-v1")
.build()
)

algo = config.build()

for i in range(10):
result = algo.train()
print(pretty_print(result))
result.pop("config")
pprint(result)

if i % 5 == 0:
checkpoint_dir = algo.save().checkpoint.path
checkpoint_dir = algo.save_to_path()
print(f"Checkpoint saved in directory {checkpoint_dir}")
# __rllib-first-config-end__

import ray

ray.shutdown()
algo.stop()

if False:
# __rllib-tune-config-begin__
import ray
from ray import train, tune

ray.init()

config = PPOConfig().training(lr=tune.grid_search([0.01, 0.001, 0.0001]))
config = (
PPOConfig()
.api_stack(
enable_rl_module_and_learner=True,
enable_env_runner_and_connector_v2=True,
)
.environment("CartPole-v1")
.training(
lr=tune.grid_search([0.01, 0.001, 0.0001]),
)
)

tuner = tune.Tuner(
"PPO",
param_space=config,
run_config=train.RunConfig(
stop={"env_runners/episode_return_mean": 150},
stop={"env_runners/episode_return_mean": 150.0},
),
param_space=config,
)

tuner.fit()
# __rllib-tune-config-end__

# __rllib-tuner-begin__
# ``Tuner.fit()`` allows setting a custom log directory (other than ``~/ray-results``)
tuner = ray.tune.Tuner(
"PPO",
param_space=config,
run_config=train.RunConfig(
stop={"env_runners/episode_return_mean": 150},
checkpoint_config=train.CheckpointConfig(checkpoint_at_end=True),
),
)

results = tuner.fit()
# __rllib-tuner-begin__
from ray import train, tune

# Get the best result based on a particular metric.
best_result = results.get_best_result(
metric="env_runners/episode_return_mean", mode="max"
)
# Tuner.fit() allows setting a custom log directory (other than ~/ray-results).
tuner = tune.Tuner(
"PPO",
param_space=config,
run_config=train.RunConfig(
stop={"num_env_steps_sampled_lifetime": 20000},
checkpoint_config=train.CheckpointConfig(checkpoint_at_end=True),
),
)

# Get the best checkpoint corresponding to the best result.
best_checkpoint = best_result.checkpoint
# __rllib-tuner-end__
results = tuner.fit()

# Get the best result based on a particular metric.
best_result = results.get_best_result(
metric="env_runners/episode_return_mean", mode="max"
)

# __rllib-compute-action-begin__
# Note: `gymnasium` (not `gym`) will be **the** API supported by RLlib from Ray 2.3 on.
try:
import gymnasium as gym
# Get the best checkpoint corresponding to the best result.
best_checkpoint = best_result.checkpoint
# __rllib-tuner-end__

gymnasium = True
except Exception:
import gym

gymnasium = False
# __rllib-compute-action-begin__
import pathlib
import gymnasium as gym
import numpy as np
import torch
from ray.rllib.core.rl_module import RLModule

from ray.rllib.algorithms.ppo import PPOConfig
env = gym.make("CartPole-v1")

env_name = "CartPole-v1"
env = gym.make(env_name)
algo = PPOConfig().environment(env_name).build()
# Create only the neural network (RLModule) from our checkpoint.
rl_module = RLModule.from_checkpoint(
pathlib.Path(best_checkpoint.path) / "learner_group" / "learner" / "rl_module"
)["default_policy"]

episode_reward = 0
episode_return = 0
terminated = truncated = False

if gymnasium:
obs, info = env.reset()
else:
obs = env.reset()
obs, info = env.reset()

while not terminated and not truncated:
action = algo.compute_single_action(obs)
if gymnasium:
obs, reward, terminated, truncated, info = env.step(action)
else:
obs, reward, terminated, info = env.step(action)
episode_reward += reward
# Compute the next action from a batch (B=1) of observations.
torch_obs_batch = torch.from_numpy(np.array([obs]))
action_logits = rl_module.forward_inference({"obs": torch_obs_batch})[
"action_dist_inputs"
]
# The default RLModule used here produces action logits (from which
# we'll have to sample an action or use the max-likelihood one).
action = torch.argmax(action_logits[0]).numpy()
obs, reward, terminated, truncated, info = env.step(action)
episode_return += reward

print(f"Reached episode return of {episode_return}.")
# __rllib-compute-action-end__


del rl_module


# __rllib-get-state-begin__
from ray.rllib.algorithms.dqn import DQNConfig
from ray.rllib.algorithms.ppo import PPOConfig

algo = DQNConfig().environment(env="CartPole-v1").build()
algo = (
PPOConfig()
.api_stack(
enable_rl_module_and_learner=True,
enable_env_runner_and_connector_v2=True,
)
.environment("CartPole-v1")
.env_runners(num_env_runners=2)
).build()

# Get weights of the default local policy
algo.get_policy().get_weights()
# Get weights of the algo's RLModule.
algo.get_module().get_state()

# Same as above
algo.env_runner.policy_map["default_policy"].get_weights()
algo.env_runner.module.get_state()

# Get list of weights of each worker, including remote replicas
algo.env_runner_group.foreach_worker(
lambda env_runner: env_runner.get_policy().get_weights()
)
# Get list of weights of each EnvRunner, including remote replicas.
algo.env_runner_group.foreach_worker(lambda env_runner: env_runner.module.get_state())

# Same as above, but with index.
algo.env_runner_group.foreach_worker_with_id(
lambda _id, worker: worker.get_policy().get_weights()
lambda _id, env_runner: env_runner.module.get_state()
)
# __rllib-get-state-end__

algo.stop()
9 changes: 1 addition & 8 deletions doc/source/rllib/key-concepts.rst
@@ -53,7 +53,7 @@ Algorithms
----------

Algorithms bring all RLlib components together, making learning of different tasks
accessible via RLlib's Python API and its command line interface (CLI).
accessible via RLlib's Python API.
Each ``Algorithm`` class is managed by its respective ``AlgorithmConfig``; for example, to
configure a ``PPO`` instance, use the ``PPOConfig`` class.
An ``Algorithm`` sets up its rollout workers and optimizers, and collects training metrics.
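
As an aside (not part of this diff), a minimal sketch of the config-to-``Algorithm`` flow the paragraph describes, mirroring the updated getting_started.py snippet above:

from ray.rllib.algorithms.ppo import PPOConfig

# Config class -> built Algorithm -> one training iteration.
config = PPOConfig().environment("CartPole-v1").env_runners(num_env_runners=1)
algo = config.build()
result = algo.train()
algo.stop()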
@@ -97,13 +97,6 @@ which implements the proximal policy optimization algorithm in RLlib.
tune.run("PPO", config=config)


.. tab-item:: RLlib Command Line

.. code-block:: bash

rllib train --run=PPO --env=CartPole-v1 --config='{"train_batch_size": 4000}'
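
As an aside (not part of this diff), a rough Python-API equivalent of the removed CLI call; mapping ``train_batch_size`` through ``.training()`` and the stop criterion are illustrative assumptions:

from ray import train, tune
from ray.rllib.algorithms.ppo import PPOConfig

# Roughly what `rllib train --run=PPO --env=CartPole-v1
# --config='{"train_batch_size": 4000}'` did, expressed via the Python API.
config = (
    PPOConfig()
    .environment("CartPole-v1")
    .training(train_batch_size=4000)
)
tune.Tuner(
    "PPO",
    param_space=config,
    # Stop criterion chosen only for illustration.
    run_config=train.RunConfig(stop={"training_iteration": 5}),
).fit()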


RLlib `Algorithm classes <rllib-concepts.html#algorithms>`__ coordinate the distributed workflow of running rollouts and optimizing policies.
Algorithm classes leverage parallel iterators to implement the desired computation pattern.
The following figure shows *synchronous sampling*, the simplest of `these patterns <rllib-algorithms.html>`__:
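
Because the figure itself isn't reproduced here, a purely illustrative sketch of the synchronous-sampling pattern (hypothetical ``sample_fns``/``update_fn`` callables, not RLlib's actual classes):

from typing import Callable, Dict, List, Sequence

def synchronous_sampling_step(
    sample_fns: Sequence[Callable[[], List[dict]]],
    update_fn: Callable[[List[dict]], Dict],
) -> Dict:
    # 1) Each "EnvRunner" produces one batch of experiences (in parallel in practice).
    batches = [sample() for sample in sample_fns]
    # 2) Concatenate the per-runner batches into one training batch.
    train_batch = [step for batch in batches for step in batch]
    # 3) A single centralized update closes the loop; then sampling repeats.
    return update_fn(train_batch)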