[RLlib] Fix schedule validation on new API stack (for config settings like lr or entropy_coeff). #49363

Merged
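For orientation (not part of this PR's diff): on the new API stack, schedule-capable settings such as `lr` or `entropy_coeff` take either a fixed value or a list of `(timestep, value)` 2-tuples starting at timestep 0, while the legacy `*_schedule` settings must stay None. A minimal sketch, with the environment and values invented for illustration:

```python
from ray.rllib.algorithms.ppo import PPOConfig

config = (
    PPOConfig()
    .environment("CartPole-v1")
    .training(
        # Piecewise schedules are passed through the base settings themselves:
        # the value at timestep 0 comes first, then the targets to anneal toward.
        lr=[(0, 3e-4), (2_000_000, 1e-5)],
        entropy_coeff=[(0, 0.01), (1_000_000, 0.0)],
        # The deprecated `lr_schedule` / `entropy_coeff_schedule` stay None.
    )
)
```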
10 changes: 10 additions & 0 deletions rllib/algorithms/algorithm_config.py
@@ -4483,6 +4483,16 @@ def _validate_new_api_stack_settings(self):
# `enable_rl_module_and_learner=True`.
return

# Warn about new API stack on by default.
logger.warning(
f"You are running {self.algo_class.__name__} on the new API stack! "
"This is the new default behavior for this algorithm. If you don't "
"want to use the new API stack, set `config.api_stack("
"enable_rl_module_and_learner=False,"
"enable_env_runner_and_connector_v2=False)`. For a detailed migration "
"guide, see here: https://docs.ray.io/en/master/rllib/new-api-stack-migration-guide.html" # noqa
)

# Disabled hybrid API stack. Now, both `enable_rl_module_and_learner` and
# `enable_env_runner_and_connector_v2` must be True or both False.
if not self.enable_env_runner_and_connector_v2:
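The warning added above names the opt-out path. A sketch of that call, taken directly from the warning text (the algorithm choice is arbitrary):

```python
from ray.rllib.algorithms.ppo import PPOConfig

# Revert to the old API stack explicitly; this silences the new-default warning.
config = PPOConfig().api_stack(
    enable_rl_module_and_learner=False,
    enable_env_runner_and_connector_v2=False,
)
```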
61 changes: 22 additions & 39 deletions rllib/algorithms/dqn/dqn.py
@@ -413,30 +413,31 @@ def validate(self) -> None:
# Call super's validation method.
super().validate()

# Warn about new API stack on by default.
if self.enable_rl_module_and_learner:
logger.warning(
f"You are running {self.algo_class.__name__} on the new API stack! "
"This is the new default behavior for this algorithm. If you don't "
"want to use the new API stack, set `config.api_stack("
"enable_rl_module_and_learner=False,"
"enable_env_runner_and_connector_v2=False)`. For a detailed migration "
"guide, see here: https://docs.ray.io/en/master/rllib/new-api-stack-migration-guide.html" # noqa
)

if (
not self.enable_rl_module_and_learner
and self.exploration_config["type"] == "ParameterNoise"
):
if self.batch_mode != "complete_episodes":
# `lr_schedule` checking.
if self.lr_schedule is not None:
raise ValueError(
"ParameterNoise Exploration requires `batch_mode` to be "
"'complete_episodes'. Try setting `config.env_runners("
"batch_mode='complete_episodes')`."
"`lr_schedule` is deprecated and must be None! Use the "
"`lr` setting to setup a schedule."
)

if not self.enable_env_runner_and_connector_v2 and not self.in_evaluation:
validate_buffer_config(self)
else:
if not self.in_evaluation:
validate_buffer_config(self)

# TODO (simon): Find a clean solution to deal with configuration configs
# when using the new API stack.
if self.exploration_config["type"] == "ParameterNoise":
if self.batch_mode != "complete_episodes":
raise ValueError(
"ParameterNoise Exploration requires `batch_mode` to be "
"'complete_episodes'. Try setting `config.env_runners("
"batch_mode='complete_episodes')`."
)
if self.noisy:
raise ValueError(
"ParameterNoise Exploration and `noisy` network cannot be"
" used at the same time!"
)

if self.td_error_loss_fn not in ["huber", "mse"]:
raise ValueError("`td_error_loss_fn` must be 'huber' or 'mse'!")
@@ -454,24 +455,6 @@ def validate(self) -> None:
f"{self.n_step})."
)

# TODO (simon): Find a clean solution to deal with
# configuration configs when using the new API stack.
if (
not self.enable_rl_module_and_learner
and self.exploration_config["type"] == "ParameterNoise"
):
if self.batch_mode != "complete_episodes":
raise ValueError(
"ParameterNoise Exploration requires `batch_mode` to be "
"'complete_episodes'. Try setting `config.env_runners("
"batch_mode='complete_episodes')`."
)
if self.noisy:
raise ValueError(
"ParameterNoise Exploration and `noisy` network cannot be"
" used at the same time!"
)

# Validate that we use the corresponding `EpisodeReplayBuffer` when using
# episodes.
# TODO (sven, simon): Implement the multi-agent case for replay buffers.
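A hedged sketch of how this affects DQN users (standard `DQNConfig` API assumed, environment and values invented): learning-rate annealing now goes through `lr` itself, and a non-None `lr_schedule` fails validation.

```python
from ray.rllib.algorithms.dqn import DQNConfig

config = (
    DQNConfig()
    .environment("CartPole-v1")
    # Anneal the learning rate via `lr`; passing `lr_schedule` instead would
    # now raise "`lr_schedule` is deprecated and must be None!".
    .training(lr=[(0, 5e-4), (500_000, 5e-5)])
)
config.validate()  # should accept the schedule expressed through `lr`
```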
6 changes: 6 additions & 0 deletions rllib/algorithms/impala/impala.py
@@ -381,6 +381,12 @@ def validate(self) -> None:
"does NOT support a mixin replay buffer yet for "
f"{self} (set `config.replay_proportion` to 0.0)!"
)
# `lr_schedule` checking.
if self.lr_schedule is not None:
raise ValueError(
"`lr_schedule` is deprecated and must be None! Use the "
"`lr` setting to setup a schedule."
)
# Entropy coeff schedule checking.
if self.entropy_coeff_schedule is not None:
raise ValueError(
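The same pattern applies to IMPALA's entropy coefficient. A minimal sketch (config class name and values assumed, not taken from this diff):

```python
from ray.rllib.algorithms.impala import IMPALAConfig

config = IMPALAConfig().training(
    # Replaces the deprecated `lr_schedule` / `entropy_coeff_schedule` settings.
    lr=[(0, 5e-4), (2_000_000, 5e-5)],
    entropy_coeff=[(0, 0.01), (2_000_000, 0.0)],
)
```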
19 changes: 7 additions & 12 deletions rllib/algorithms/ppo/ppo.py
@@ -290,17 +290,6 @@ def validate(self) -> None:
# Call super's validation method.
super().validate()

# Warn about new API stack on by default.
if self.enable_rl_module_and_learner:
logger.warning(
f"You are running {self.algo_class.__name__} on the new API stack! "
"This is the new default behavior for this algorithm. If you don't "
"want to use the new API stack, set `config.api_stack("
"enable_rl_module_and_learner=False,"
"enable_env_runner_and_connector_v2=False)`. For a detailed migration "
"guide, see here: https://docs.ray.io/en/master/rllib/new-api-stack-migration-guide.html" # noqa
)

# Synchronous sampling, on-policy/PPO algos -> Check mismatches between
# `rollout_fragment_length` and `train_batch_size_per_learner` to avoid user
# confusion.
@@ -350,8 +339,14 @@ def validate(self) -> None:
"batch_mode=complete_episodes."
)

# Entropy coeff schedule checking.
# New API stack checks.
if self.enable_rl_module_and_learner:
# `lr_schedule` checking.
if self.lr_schedule is not None:
raise ValueError(
"`lr_schedule` is deprecated and must be None! Use the "
"`lr` setting to setup a schedule."
)
if self.entropy_coeff_schedule is not None:
raise ValueError(
"`entropy_coeff_schedule` is deprecated and must be None! Use the "
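A sketch of the failure mode this hunk now catches on the new API stack (usage assumed; the error text is quoted from the diff):

```python
from ray.rllib.algorithms.ppo import PPOConfig

config = PPOConfig().training(
    entropy_coeff_schedule=[(0, 0.01), (1_000_000, 0.0)],  # deprecated setting
)
try:
    config.validate()
except ValueError as err:
    print(err)  # "`entropy_coeff_schedule` is deprecated and must be None! ..."
```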
1 change: 1 addition & 0 deletions rllib/tuned_examples/sac/pendulum_sac.py
@@ -22,6 +22,7 @@
actor_lr=2e-4 * (args.num_learners or 1) ** 0.5,
critic_lr=8e-4 * (args.num_learners or 1) ** 0.5,
alpha_lr=9e-4 * (args.num_learners or 1) ** 0.5,
# TODO (sven): Maybe go back to making this a dict of the sub-learning rates?
lr=None,
target_entropy="auto",
n_step=(2, 5),
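Context for the TODO above, as a hedged sketch (parameter names as in the tuned example, values invented): SAC on the new API stack keeps one learning rate per optimizer, so the global `lr` is set to None.

```python
from ray.rllib.algorithms.sac import SACConfig

config = SACConfig().training(
    actor_lr=2e-4,
    critic_lr=8e-4,
    alpha_lr=9e-4,
    lr=None,  # per-optimizer rates above take precedence; see the TODO in the diff
)
```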
18 changes: 13 additions & 5 deletions rllib/utils/schedules/scheduler.py
@@ -86,6 +86,7 @@ def validate(
Raises:
ValueError: In case, errors are found in the schedule's format.
"""
# Fixed (single) value.
if (
isinstance(fixed_value_or_schedule, (int, float))
or fixed_value_or_schedule is None
@@ -97,17 +98,24 @@
):
raise ValueError(
f"Invalid `{setting_name}` ({fixed_value_or_schedule}) specified! "
f"Must be a list of at least 2 tuples, each of the form "
f"(`timestep`, `{description} to reach`), e.g. "
f"Must be a list of 2 or more tuples, each of the form "
f"(`timestep`, `{description} to reach`), for example "
"`[(0, 0.001), (1e6, 0.0001), (2e6, 0.00005)]`."
)
elif fixed_value_or_schedule[0][0] != 0:
raise ValueError(
f"When providing a `{setting_name}`, the first timestep must be 0 "
f"and the corresponding lr value is the initial {description}! You "
f"provided ts={fixed_value_or_schedule[0][0]} {description}="
f"When providing a `{setting_name}` schedule, the first timestep must "
f"be 0 and the corresponding lr value is the initial {description}! "
f"You provided ts={fixed_value_or_schedule[0][0]} {description}="
f"{fixed_value_or_schedule[0][1]}."
)
elif any(len(pair) != 2 for pair in fixed_value_or_schedule):
raise ValueError(
f"When providing a `{setting_name}` schedule, each tuple in the "
f"schedule list must have exctly 2 items of the form "
f"(`timestep`, `{description} to reach`), for example "
"`[(0, 0.001), (1e6, 0.0001), (2e6, 0.00005)]`."
)

def get_current_value(self) -> TensorType:
"""Returns the current value (as a tensor variable).
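A rough usage sketch of the stricter checks added to `Scheduler.validate` (argument names come from the code above; calling it as a class-level method is an assumption):

```python
from ray.rllib.utils.schedules.scheduler import Scheduler

# Accepted: a schedule of (timestep, value) 2-tuples whose first timestep is 0.
Scheduler.validate(
    fixed_value_or_schedule=[(0, 0.001), (1_000_000, 0.0001)],
    setting_name="lr",
    description="learning rate",
)

# Rejected by the new check: an entry that is not a 2-tuple.
# Scheduler.validate(
#     fixed_value_or_schedule=[(0, 0.001), (1_000_000,)],
#     setting_name="lr",
#     description="learning rate",
# )  # -> ValueError: each tuple must have exactly 2 items
```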