[RLlib] Checkpointing enhancements: Experimentally support `msgpack` and separate state from architecture. (#49497)
sven1977 authored and srinathk10 committed Jan 3, 2025
1 parent bd3d139 commit 2b78a5d
Showing 24 changed files with 706 additions and 287 deletions.
32 changes: 24 additions & 8 deletions rllib/BUILD
@@ -1589,13 +1589,24 @@ py_test(
# Tag: utils
# --------------------------------------------------------------------

# Checkpointables
py_test(
name = "utils/tests/test_checkpointable",
tags = ["team:rllib", "utils"],
size = "large",
data = glob(["utils/tests/old_checkpoints/**"]),
srcs = ["utils/tests/test_checkpointable.py"]
)

# Errors
py_test(
name = "test_errors",
tags = ["team:rllib", "utils"],
size = "medium",
srcs = ["utils/tests/test_errors.py"]
)

# @OldAPIStack
py_test(
name = "test_minibatch_utils",
tags = ["team:rllib", "utils"],
@@ -1610,20 +1621,15 @@ py_test(
srcs = ["utils/tests/test_serialization.py"]
)

py_test(
name = "test_curiosity",
tags = ["team:rllib", "utils"],
size = "large",
srcs = ["utils/exploration/tests/test_curiosity.py"]
)

# @OldAPIStack
py_test(
name = "test_explorations",
tags = ["team:rllib", "utils"],
size = "large",
srcs = ["utils/exploration/tests/test_explorations.py"]
)

# @OldAPIStack
py_test(
name = "test_value_predictions",
tags = ["team:rllib", "utils"],
@@ -1646,6 +1652,7 @@ py_test(
srcs = ["utils/schedules/tests/test_schedules.py"]
)

# @OldAPIStack
py_test(
name = "test_framework_agnostic_components",
tags = ["team:rllib", "utils"],
@@ -1991,6 +1998,15 @@ py_test(

# subdirectory: checkpoints/
# ....................................
py_test(
name = "examples/checkpoints/change_config_during_training",
main = "examples/checkpoints/change_config_during_training.py",
tags = ["team:rllib", "exclusive", "examples", "examples_use_all_core"],
size = "large",
srcs = ["examples/checkpoints/change_config_during_training.py"],
args = ["--enable-new-api-stack", "--as-test", "--stop-reward-first-config=150.0", "--stop-reward=450.0"]
)

py_test(
name = "examples/checkpoints/checkpoint_by_custom_criteria",
main = "examples/checkpoints/checkpoint_by_custom_criteria.py",
@@ -2673,7 +2689,7 @@ py_test(
tags = ["team:rllib", "exclusive", "examples", "examples_use_all_core", "no_main"],
size = "large",
srcs = ["examples/checkpoints/restore_1_of_n_agents_from_checkpoint.py"],
args = ["--enable-new-api-stack", "--num-agents=2", "--framework=torch", "--checkpoint-freq=20", "--checkpoint-at-end", "--num-cpus=4", "--algo=PPO"]
args = ["--enable-new-api-stack", "--as-test", "--num-agents=2", "--framework=torch", "--checkpoint-freq=20", "--checkpoint-at-end", "--num-cpus=4", "--algo=PPO"]
)

py_test(
51 changes: 43 additions & 8 deletions rllib/algorithms/algorithm.py
@@ -91,7 +91,7 @@
from ray.rllib.utils.checkpoints import (
Checkpointable,
CHECKPOINT_VERSION,
CHECKPOINT_VERSION_LEARNER,
CHECKPOINT_VERSION_LEARNER_AND_ENV_RUNNER,
get_checkpoint_info,
try_import_msgpack,
)
@@ -300,7 +300,7 @@ class Algorithm(Checkpointable, Trainable, AlgorithmBase):
# Backward compatibility with old checkpoint system (now through the
# `Checkpointable` API).
METADATA_FILE_NAME = "rllib_checkpoint.json"
STATE_FILE_NAME = "algorithm_state.pkl"
STATE_FILE_NAME = "algorithm_state"

@classmethod
@override(Checkpointable)
@@ -1633,16 +1633,38 @@ def restore_workers(self, workers: EnvRunnerGroup) -> None:
# Get the state of the correct (reference) worker. For example the local
# worker of an EnvRunnerGroup.
state = from_worker.get_state()
state_ref = ray.put(state)

def _sync_env_runner(er):
er.set_state(ray.get(state_ref))

# Take out (old) connector states from local worker's state.
if not self.config.enable_env_runner_and_connector_v2:
for pol_states in state["policy_states"].values():
pol_states.pop("connector_configs", None)
state_ref = ray.put(state)

elif self.config.is_multi_agent():

multi_rl_module_spec = MultiRLModuleSpec.from_module(from_worker.module)

def _sync_env_runner(er): # noqa
# Remove modules, if necessary.
for module_id, module in er.module._rl_modules.copy().items():
if module_id not in multi_rl_module_spec.rl_module_specs:
er.module.remove_module(
module_id, raise_err_if_not_found=True
)
# Add modules, if necessary.
for mid, mod_spec in multi_rl_module_spec.rl_module_specs.items():
if mid not in er.module:
er.module.add_module(mid, mod_spec.build(), override=False)
# Now that the MultiRLModule is fixed, update the state.
er.set_state(ray.get(state_ref))

# By default, the entire local EnvRunner state is synced after restoration
# to bring the previously failed EnvRunner up to date.
workers.foreach_worker(
func=lambda w: w.set_state(ray.get(state_ref)),
func=_sync_env_runner,
remote_worker_ids=restored,
# Don't update the local EnvRunner, b/c it's the one we are synching
# from.
@@ -1949,6 +1971,12 @@ def _remove(_env_runner):
_env_runner.config.multi_agent(
policy_mapping_fn=new_agent_to_module_mapping_fn
)
# Force reset all ongoing episodes on the EnvRunner to avoid having
# different ModuleIDs compute actions for the same AgentID in the same
# episode.
# TODO (sven): Create an API for this.
_env_runner._needs_initial_reset = True

return MultiRLModuleSpec.from_module(_env_runner.module)

# Remove from (training) EnvRunners and sync weights.
@@ -2603,7 +2631,10 @@ def save_checkpoint(self, checkpoint_dir: str) -> None:
# New API stack: Delegate to the `Checkpointable` implementation of
# `save_to_path()`.
if self.config.enable_rl_module_and_learner:
return self.save_to_path(checkpoint_dir)
return self.save_to_path(
checkpoint_dir,
use_msgpack=self.config._use_msgpack_checkpoints,
)

checkpoint_dir = pathlib.Path(checkpoint_dir)

@@ -2617,7 +2648,7 @@ def save_checkpoint(self, checkpoint_dir: str) -> None:

# Add RLlib checkpoint version.
if self.config.enable_rl_module_and_learner:
state["checkpoint_version"] = CHECKPOINT_VERSION_LEARNER
state["checkpoint_version"] = CHECKPOINT_VERSION_LEARNER_AND_ENV_RUNNER
else:
state["checkpoint_version"] = CHECKPOINT_VERSION

@@ -2813,9 +2844,13 @@ def restore_from_path(self, path, *args, **kwargs):
# Override from parent method, b/c we might have to sync the EnvRunner weights
# after having restored/loaded the LearnerGroup state.
super().restore_from_path(path, *args, **kwargs)
# Sync EnvRunners, but only if LearnerGroup's checkpoint can be found in path.

# Sync EnvRunners if the LearnerGroup's checkpoint can be found in `path`,
# or if the user loaded a subcomponent within the LearnerGroup (for example, a module).
path = pathlib.Path(path)
if (path / "learner_group").is_dir():
if (path / COMPONENT_LEARNER_GROUP).is_dir() or (
"component" in kwargs and COMPONENT_LEARNER_GROUP in kwargs["component"]
):
# Also make sure all (training) EnvRunners get the just-loaded weights, but
# only their inference-only versions.
self.env_runner_group.sync_weights(
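A hedged sketch of the new `restore_from_path()` behavior above: EnvRunner weights are now also synced when the user restores only a subcomponent inside the LearnerGroup, not just a full LearnerGroup checkpoint. The component path string and the pre-built `algo` are assumptions pieced together from the `COMPONENT_*` constants touched in this commit.

```python
# Sketch (assumption): `algo` is a built Algorithm and a checkpoint was
# previously written to /tmp/msgpack_ckpt. Restoring just the RLModule
# inside the LearnerGroup now also triggers the EnvRunner weight sync
# added above.
from ray.rllib.core import (
    COMPONENT_LEARNER,
    COMPONENT_LEARNER_GROUP,
    COMPONENT_RL_MODULE,
)

algo.restore_from_path(
    "/tmp/msgpack_ckpt",
    component=(
        COMPONENT_LEARNER_GROUP
        + "/" + COMPONENT_LEARNER
        + "/" + COMPONENT_RL_MODULE
    ),
)
```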
16 changes: 6 additions & 10 deletions rllib/algorithms/algorithm_config.py
@@ -550,6 +550,7 @@ def __init__(self, algo_class: Optional[type] = None):
self._per_module_overrides: Dict[ModuleID, "AlgorithmConfig"] = {}

# `self.experimental()`
self._use_msgpack_checkpoints = False
self._torch_grad_scaler_class = None
self._torch_lr_scheduler_classes = None
self._tf_policy_handles_more_than_one_loss = False
@@ -3458,6 +3459,7 @@ def rl_module(
def experimental(
self,
*,
_use_msgpack_checkpoints: Optional[bool] = NotProvided,
_torch_grad_scaler_class: Optional[Type] = NotProvided,
_torch_lr_scheduler_classes: Optional[
Union[List[Type], Dict[ModuleID, List[Type]]]
@@ -3466,12 +3468,12 @@ def experimental(
_disable_preprocessor_api: Optional[bool] = NotProvided,
_disable_action_flattening: Optional[bool] = NotProvided,
_disable_initialize_loss_from_dummy_batch: Optional[bool] = NotProvided,
# Deprecated args.
_enable_new_api_stack=DEPRECATED_VALUE,
) -> "AlgorithmConfig":
"""Sets the config's experimental settings.
Args:
_use_msgpack_checkpoints: Create state files in all checkpoints through
msgpack rather than pickle.
_torch_grad_scaler_class: Class to use for torch loss scaling (and gradient
unscaling). The class must implement the following methods to be
compatible with a `TorchLearner`. These methods/APIs match exactly those
@@ -3511,14 +3513,8 @@ def experimental(
Returns:
This updated AlgorithmConfig object.
"""
if _enable_new_api_stack != DEPRECATED_VALUE:
deprecation_warning(
old="config.experimental(_enable_new_api_stack=...)",
new="config.api_stack(enable_rl_module_and_learner=...,"
"enable_env_runner_and_connector_v2=...)",
error=True,
)

if _use_msgpack_checkpoints is not NotProvided:
self._use_msgpack_checkpoints = _use_msgpack_checkpoints
if _tf_policy_handles_more_than_one_loss is not NotProvided:
self._tf_policy_handles_more_than_one_loss = (
_tf_policy_handles_more_than_one_loss
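Putting the new flag together with the `save_checkpoint()` change earlier in the commit, a minimal sketch of how msgpack checkpointing would be enabled end to end; the PPO/CartPole setup and the target path are illustrative assumptions only.

```python
# Sketch (assumption: PPO on CartPole, chosen purely for illustration).
from ray.rllib.algorithms.ppo import PPOConfig

config = (
    PPOConfig()
    .api_stack(
        enable_rl_module_and_learner=True,
        enable_env_runner_and_connector_v2=True,
    )
    .environment("CartPole-v1")
    # New experimental flag from this commit: write checkpoint state files
    # via msgpack instead of pickle.
    .experimental(_use_msgpack_checkpoints=True)
)
algo = config.build()
algo.train()

# `save_checkpoint()` now forwards the flag as
# `save_to_path(checkpoint_dir, use_msgpack=...)`; calling `save_to_path()`
# directly with `use_msgpack=True` should be equivalent.
checkpoint_path = algo.save_to_path("/tmp/msgpack_ckpt", use_msgpack=True)
```

Presumably this is why the two halves of the commit travel together: unlike pickle, msgpack cannot serialize arbitrary Python objects, so the state files must hold plain data while the architecture is rebuilt from specs at load time (see the `multi_rl_module.py` changes below).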
2 changes: 0 additions & 2 deletions rllib/core/__init__.py
@@ -14,7 +14,6 @@
COMPONENT_LEARNER_GROUP = "learner_group"
COMPONENT_METRICS_LOGGER = "metrics_logger"
COMPONENT_MODULE_TO_ENV_CONNECTOR = "module_to_env_connector"
COMPONENT_MULTI_RL_MODULE_SPEC = "_multi_rl_module_spec"
COMPONENT_OPTIMIZER = "optimizer"
COMPONENT_RL_MODULE = "rl_module"

@@ -28,7 +27,6 @@
"COMPONENT_LEARNER_GROUP",
"COMPONENT_METRICS_LOGGER",
"COMPONENT_MODULE_TO_ENV_CONNECTOR",
"COMPONENT_MULTI_RL_MODULE_SPEC",
"COMPONENT_OPTIMIZER",
"COMPONENT_RL_MODULE",
"DEFAULT_AGENT_ID",
3 changes: 0 additions & 3 deletions rllib/core/learner/learner_group.py
@@ -20,7 +20,6 @@
from ray import ObjectRef
from ray.rllib.core import (
COMPONENT_LEARNER,
COMPONENT_MULTI_RL_MODULE_SPEC,
COMPONENT_RL_MODULE,
)
from ray.rllib.core.learner.learner import Learner
@@ -797,8 +796,6 @@ def get_weights(
)
]
state = self.get_state(components)[COMPONENT_LEARNER][COMPONENT_RL_MODULE]
# Remove the MultiRLModuleSpec to just get the weights.
state.pop(COMPONENT_MULTI_RL_MODULE_SPEC, None)
return state

def set_weights(self, weights) -> None:
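After this change, the state returned for the RLModule component is pure weights, so `get_weights()` no longer has to strip a spec entry. A sketch of the equivalent direct `get_state()` call, mirroring the updated tests below; `learner_group` is assumed to be an existing LearnerGroup.

```python
# Sketch: fetch module weights from a LearnerGroup; there is no
# COMPONENT_MULTI_RL_MODULE_SPEC entry to pop anymore.
from ray.rllib.core import COMPONENT_LEARNER, COMPONENT_RL_MODULE

weights = learner_group.get_state(
    components=COMPONENT_LEARNER + "/" + COMPONENT_RL_MODULE
)[COMPONENT_LEARNER][COMPONENT_RL_MODULE]
```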
4 changes: 0 additions & 4 deletions rllib/core/learner/tests/test_learner_group.py
@@ -10,7 +10,6 @@
from ray.rllib.core import (
Columns,
COMPONENT_LEARNER,
COMPONENT_MULTI_RL_MODULE_SPEC,
COMPONENT_RL_MODULE,
DEFAULT_MODULE_ID,
)
@@ -431,7 +430,6 @@ def test_save_to_path_and_restore_from_path(self):
weights_after_update = learner_group.get_state(
components=COMPONENT_LEARNER + "/" + COMPONENT_RL_MODULE
)[COMPONENT_LEARNER][COMPONENT_RL_MODULE]
weights_after_update.pop(COMPONENT_MULTI_RL_MODULE_SPEC)
# Weights after the update must be different from original ones.
check(initial_weights, weights_after_update, false=True)

@@ -454,7 +452,6 @@ def test_save_to_path_and_restore_from_path(self):
weights_after_2_updates_with_break = learner_group.get_state(
components=COMPONENT_LEARNER + "/" + COMPONENT_RL_MODULE
)[COMPONENT_LEARNER][COMPONENT_RL_MODULE]
weights_after_2_updates_with_break.pop(COMPONENT_MULTI_RL_MODULE_SPEC)
learner_group.shutdown()
del learner_group

@@ -464,7 +461,6 @@ def test_save_to_path_and_restore_from_path(self):
weights_after_restore = learner_group.get_state(
components=COMPONENT_LEARNER + "/" + COMPONENT_RL_MODULE
)[COMPONENT_LEARNER][COMPONENT_RL_MODULE]
weights_after_restore.pop(COMPONENT_MULTI_RL_MODULE_SPEC)
check(initial_weights, weights_after_restore)
# Perform 2 updates to get to the same state as the previous learners.
learner_group.update_from_episodes(FAKE_EPISODES)
20 changes: 12 additions & 8 deletions rllib/core/learner/torch/torch_learner.py
@@ -44,7 +44,7 @@
WEIGHTS_SEQ_NO,
)
from ray.rllib.utils.numpy import convert_to_numpy
from ray.rllib.utils.torch_utils import convert_to_torch_tensor, copy_torch_tensors
from ray.rllib.utils.torch_utils import convert_to_torch_tensor
from ray.rllib.utils.typing import (
ModuleID,
Optimizer,
@@ -321,21 +321,25 @@ def _get_optimizer_state(self) -> StateDict:
for name, optim in self._named_optimizers.items():
ret[name] = {
"module_id": self._optimizer_name_to_module[name],
"state": copy_torch_tensors(optim.state_dict(), device="cpu"),
"state": convert_to_numpy(optim.state_dict()),
}
return ret

@override(Learner)
def _set_optimizer_state(self, state: StateDict) -> None:
for name, state_dict in state.items():
if name not in self._named_optimizers:
# Skip optimizers that belong to submodules not present in this
# Learner's MultiRLModule.
module_id = state_dict["module_id"]
if name not in self._named_optimizers and module_id in self.module:
self.configure_optimizers_for_module(
state_dict["module_id"],
config=self.config.get_config_for_module(state_dict["module_id"]),
module_id=module_id,
config=self.config.get_config_for_module(module_id=module_id),
)
if name in self._named_optimizers:
self._named_optimizers[name].load_state_dict(
convert_to_torch_tensor(state_dict["state"], device=self._device)
)
self._named_optimizers[name].load_state_dict(
copy_torch_tensors(state_dict["state"], device=self._device)
)

@override(Learner)
def get_param_ref(self, param: Param) -> Hashable:
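The switch above from `copy_torch_tensors(..., device="cpu")` to `convert_to_numpy()` turns the saved optimizer state into plain numpy structures, which is what makes it msgpack-serializable; `convert_to_torch_tensor()` restores it onto the Learner's device on load. A small sketch of that round-trip, where the throwaway parameter/optimizer setup is an assumption for illustration:

```python
# Sketch: optimizer state out as numpy (save side), back in as torch
# tensors (load side). The tiny Adam setup is illustrative only.
import torch

from ray.rllib.utils.numpy import convert_to_numpy
from ray.rllib.utils.torch_utils import convert_to_torch_tensor

param = torch.nn.Parameter(torch.zeros(4))
optim = torch.optim.Adam([param], lr=1e-3)
param.grad = torch.ones(4)
optim.step()  # populates Adam's per-parameter state (exp_avg, etc.)

# Save side: torch tensors -> numpy arrays (msgpack-friendly).
numpy_state = convert_to_numpy(optim.state_dict())

# Load side: numpy arrays -> torch tensors on the target device.
optim.load_state_dict(convert_to_torch_tensor(numpy_state, device="cpu"))
```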
31 changes: 0 additions & 31 deletions rllib/core/rl_module/multi_rl_module.py
@@ -20,7 +20,6 @@

import gymnasium as gym

from ray.rllib.core import COMPONENT_MULTI_RL_MODULE_SPEC
from ray.rllib.core.models.specs.typing import SpecType
from ray.rllib.core.rl_module.rl_module import RLModule, RLModuleSpec
from ray.rllib.utils import force_list
@@ -420,17 +419,6 @@ def get_state(
) -> StateDict:
state = {}

# We store the current RLModuleSpec as well as it might have changed over time
# (modules added/removed from `self`).
if self._check_component(
COMPONENT_MULTI_RL_MODULE_SPEC,
components,
not_components,
):
state[COMPONENT_MULTI_RL_MODULE_SPEC] = MultiRLModuleSpec.from_module(
self
).to_dict()

for module_id, rl_module in self.get_checkpointable_components():
if self._check_component(module_id, components, not_components):
state[module_id] = rl_module.get_state(
@@ -454,27 +442,8 @@ def set_state(self, state: StateDict) -> None:
Args:
state: The state dict to set.
"""
# Check the given MultiRLModuleSpec and - if there are changes in the individual
# sub-modules - apply these to this MultiRLModule.
if COMPONENT_MULTI_RL_MODULE_SPEC in state:
multi_rl_module_spec = MultiRLModuleSpec.from_dict(
state[COMPONENT_MULTI_RL_MODULE_SPEC]
)
# Go through all of our current modules and check, whether they are listed
# in the given MultiRLModuleSpec. If not, erase them from `self`.
for module_id, module in self._rl_modules.copy().items():
if module_id not in multi_rl_module_spec.rl_module_specs:
self.remove_module(module_id, raise_err_if_not_found=True)
# Go through all the modules in the given MultiRLModuleSpec and if
# they are not present in `self`, add them.
for module_id, module_spec in multi_rl_module_spec.rl_module_specs.items():
if module_id not in self:
self.add_module(module_id, module_spec.build(), override=False)

# Now, set the individual states
for module_id, module_state in state.items():
if module_id == COMPONENT_MULTI_RL_MODULE_SPEC:
continue
if module_id in self:
self._rl_modules[module_id].set_state(module_state)

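With the spec entry removed from `get_state()`/`set_state()`, a MultiRLModule's state is now weights only, and the architecture travels separately as a `MultiRLModuleSpec`, the same pattern the new `restore_workers()` code applies above. A hedged sketch of the split, assuming `module` is an existing, built MultiRLModule:

```python
# Sketch: architecture and state now travel separately.
from ray.rllib.core.rl_module.multi_rl_module import MultiRLModuleSpec

spec = MultiRLModuleSpec.from_module(module)  # architecture only
state = module.get_state()  # weights only; no embedded spec anymore

# Rebuild elsewhere: first the architecture from the spec, then the state.
rebuilt = spec.build()
rebuilt.set_state(state)
```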
(Diffs for the remaining 16 changed files are not rendered here.)
