[Checkpointing - PR1] Store the rank_zero_seed on state (#680)
This PR is the first in a series cleaning up the checkpoint API. One of the prerequisites is storing the seed on the state.
Here, only the rank zero seed is stored on the state, since only the rank zero state is persisted in a checkpoint. The trainer broadcasts the rank zero seed to every rank, so the same seed is restored when resuming from a checkpoint, even if a seed was not originally specified.

This PR ignores the `seed` parameter passed into the trainer when resuming from a checkpoint. For the time being, if a new seed is desired, the `seed` attribute must be removed from the checkpoint state dict. #497 will introduce a cleaner API for this (edge) use case.
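As a rough illustration of the scheme described above, here is a minimal sketch (not Composer code; `distribute_seed` is a hypothetical helper mirroring the `trainer.py` hunk further down):

```python
import torch
import torch.distributed as dist


def distribute_seed(seed: int) -> int:
    """Broadcast rank zero's seed and derive this rank's seed from it.

    Hypothetical helper: only the rank zero seed needs to be persisted in a
    checkpoint, because every rank can re-derive its own seed as
    ``rank_zero_seed + global_rank``.
    """
    if dist.is_available() and dist.is_initialized():
        # int64 to avoid overflow for large seeds during the broadcast
        seed_tensor = torch.tensor([seed], dtype=torch.int64)
        dist.broadcast(seed_tensor, src=0)
        rank_zero_seed = int(seed_tensor.item())
        return rank_zero_seed + dist.get_rank()
    # Single-process run: the passed seed simply acts as the rank zero seed.
    return seed
```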
ravi-mosaicml authored Mar 11, 2022
1 parent 55d3d98 commit 7439108
Showing 15 changed files with 108 additions and 65 deletions.
3 changes: 0 additions & 3 deletions composer/algorithms/scale_schedule/scale_schedule.py
@@ -21,9 +21,6 @@ class ScaleSchedule(Algorithm):
ratio (float, optional): The factor by which to scale the duration of the schedule. E.g., 0.5
makes the schedule take half as long and 2.0 makes it
take twice as long. Default: ``1.0``.
See also:
:func:`composer.trainer.scale_schedule.scale_scheduler`
"""

def __init__(self, ratio: float = 1.0):
70 changes: 54 additions & 16 deletions composer/core/state.py
@@ -17,7 +17,7 @@
from composer.core.precision import Precision
from composer.core.serializable import Serializable
from composer.core.time import Time, Timer, TimeUnit
from composer.utils import ensure_tuple
from composer.utils import dist, ensure_tuple

if TYPE_CHECKING:
from composer.core.algorithm import Algorithm
@@ -63,6 +63,8 @@ class State(Serializable):
Args:
model (:attr:`~.types.Model`): The model, typically as a subclass of :class:`~.ComposerModel`.
rank_zero_seed (int): The seed used on the rank zero process. It is assumed that each rank's seed is
``rank_zero_seed + dist.get_global_rank()``.
grad_accum (int): The number of gradient accumulation steps to use. With this argument, micro batch size for
each device becomes ``microbatch_size = train_batch_size / (num_devices * grad_accum)``.
train_dataloader (types.DataLoader, DataSpec, or dict):
@@ -111,6 +113,8 @@ class State(Serializable):
+-----------------------+-------------------------------------------------------------+
| timer | The timer that tracks training loop progress. |
+-----------------------+-------------------------------------------------------------+
| rank_zero_seed | The seed of the rank zero process. |
+-----------------------+-------------------------------------------------------------+
"""

_max_duration: Time[int]
@@ -129,6 +133,7 @@ def __init__(

# stopping conditions
max_duration: Union[str, Time[int]],
rank_zero_seed: int,

# data configurations
train_dataloader: types.DataLoader,
@@ -152,6 +157,7 @@ def __init__(
# steps per epoch
steps_per_epoch: Optional[int] = None,
):
self.rank_zero_seed = rank_zero_seed
self.model = model
self.grad_accum = grad_accum
self.train_dataloader = train_dataloader
@@ -188,10 +194,17 @@ def __init__(
"callbacks",
"scaler",
"timer",
"rank_zero_seed",
]

@property
def seed(self):
"""The seed for the current rank."""
return self.rank_zero_seed + dist.get_global_rank()

@property
def max_duration(self):
"""The maximum training duration."""
return self._max_duration

@max_duration.setter
@@ -254,11 +267,24 @@ def state_dict(self) -> types.StateDict:
# Save model directly instead of by class name, since model may be wrapped by DistributedDataParallel
serialized_value = state_field_value.state_dict()
else:
serialized_value = {
obj.__class__.__qualname__: obj.state_dict()
for obj in ensure_tuple(state_field_value)
if obj is not None
}
# Duck typing since runtime checkable protocols are not available in Python 3.7
is_state_dict_serializable = any(hasattr(x, "state_dict") for x in ensure_tuple(state_field_value))
if len(ensure_tuple(state_field_value)) > 0:
# any and all should be the same, except for an empty collection, since
# any([]) is False, while all([]) is True
if is_state_dict_serializable != all(
hasattr(x, "state_dict") for x in ensure_tuple(state_field_value)):
raise RuntimeError(
f"Every member of {state_field_name} should support `state_dict`, or no members should support it."
)

if is_state_dict_serializable:
serialized_value = {
obj.__class__.__qualname__: obj.state_dict() for obj in ensure_tuple(state_field_value)
}
else:
serialized_value = state_field_value

state_dict[state_field_name] = serialized_value

state_dict["_is_model_ddp_wrapped"] = isinstance(self.model, DistributedDataParallel)
@@ -297,16 +323,28 @@ def load_state_dict(self, state: types.StateDict, strict: bool = False):
if state_field_name == "model":
self.load_model_state(state, strict=strict)
else:
for target in ensure_tuple(state_field_value):
if target is None:
continue
if target.__class__.__qualname__ not in serialized_value:
warnings.warn(
f"{target.__class__.__qualname__} was not found in the state_dict. Its state will NOT be restored",
category=UserWarning)
continue
source = serialized_value[target.__class__.__qualname__]
target.load_state_dict(source)
# Duck typing since runtime checkable protocols are not available in Python 3.7
is_state_dict_serializable = any(hasattr(x, "load_state_dict") for x in ensure_tuple(state_field_value))
if len(ensure_tuple(state_field_value)) > 0:
# any and all should be the same, except for an empty collection, since
# any([]) is False, while all([]) is True
if is_state_dict_serializable != all(
hasattr(x, "load_state_dict") for x in ensure_tuple(state_field_value)):
raise RuntimeError(
f"Every member of {state_field_name} should support `load_state_dict`, or no members should support it."
)
if is_state_dict_serializable:
for target in ensure_tuple(state_field_value):
if target.__class__.__qualname__ not in serialized_value:
warnings.warn(
f"{target.__class__.__qualname__} was not found in the state_dict. Its state will NOT be restored",
category=UserWarning)
continue
source = serialized_value[target.__class__.__qualname__]
target.load_state_dict(source)
else:
# direct serialization
setattr(self, state_field_name, serialized_value)

@property
def steps_per_epoch(self):
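The all-or-nothing duck-typing rule added to `state_dict` and `load_state_dict` above can be shown in isolation. A minimal sketch, assuming a field is given as a sequence of members (`serialize_field` is a hypothetical helper, not the Composer implementation):

```python
from typing import Any, Dict, Sequence, Union


def serialize_field(name: str, values: Sequence[Any]) -> Union[Dict[str, Any], Sequence[Any]]:
    """Either every member supports ``state_dict`` (serialize each one keyed by
    its qualified class name) or none do (store the raw value directly)."""
    has_state_dict = [hasattr(v, "state_dict") for v in values]
    # any([]) is False while all([]) is True, so only check non-empty collections.
    if values and any(has_state_dict) != all(has_state_dict):
        raise RuntimeError(
            f"Every member of {name} should support `state_dict`, "
            "or no members should support it.")
    if any(has_state_dict):
        return {v.__class__.__qualname__: v.state_dict() for v in values}
    return values
```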
35 changes: 5 additions & 30 deletions composer/trainer/_checkpoint.py
@@ -26,7 +26,7 @@
from composer.core.types import StateDict
from composer.trainer._deepspeed import is_module_deepspeed
from composer.trainer.devices.device import Device
from composer.utils import ObjectStoreProvider, dist, iterate_with_pbar, reproducibility, run_directory
from composer.utils import ObjectStoreProvider, dist, iterate_with_pbar, run_directory

log = logging.getLogger(__name__)

@@ -257,7 +257,7 @@ def _download_checkpoint(self, node_checkpoint_folder: str) -> Tuple[str, Option
return composer_checkpoint_filepath, extracted_checkpoint_folder, extracted_rank_n

def _restore_checkpoint(self, state: State, composer_checkpoint_filepath: str, extracted_rank_n: bool,
extracted_checkpoint_folder: Optional[str]) -> Optional[int]:
extracted_checkpoint_folder: Optional[str]):
"""Restore a checkpoint into ``state``.
Args:
@@ -268,14 +268,10 @@ def _restore_checkpoint(self, state: State, composer_checkpoint_filepath: str, e
where global rank is greater than 0.
extracted_checkpoint_folder (Optional[str]): The path to the checkpoint folder, which is passed into
:meth:`deepspeed.DeepSpeedEngine.load_checkpoint`.
Returns:
Optional[int]: The seed that was loaded from the checkpoint if it exists otherwise ``None``.
"""
# Now, all ranks load the checkpoint that local rank zero downloaded
state_dict = torch.load(composer_checkpoint_filepath, map_location='cpu')
log.debug(f"Loaded checkpoint with keys {state_dict.keys()} and state with keys {state_dict['state'].keys()}")
seed_to_restore = None

if is_module_deepspeed(state.model):
if extracted_checkpoint_folder is None:
@@ -300,28 +296,11 @@ def _restore_checkpoint(self, state: State, composer_checkpoint_filepath: str, e
state.load_state_dict(state_dict["state"])
self.checkpoint_rng_state = self._get_checkpoint_rng_state(state_dict["rng"])

if "seed" in state_dict:
world_size = dist.get_world_size()
checkpointed_world_size = len(state_dict["seed"])
if world_size != checkpointed_world_size:
warnings.warn(
textwrap.dedent(f"""\
Current world size {world_size} does not match the checkpointed
world size {checkpointed_world_size}. The seed will not be restored."""))
else:
seed_to_restore = state_dict["seed"][dist.get_global_rank()]
reproducibility.seed_all(seed_to_restore)

return seed_to_restore

def load_checkpoint(self, state: State) -> Optional[int]:
def load_checkpoint(self, state: State):
"""Initialize state from the loaded checkpoint's data.
Args:
state (State): The :class:`~composer.core.state.State` to load the checkpoint into.
Returns:
Optional[int]: The seed that was loaded from the checkpoint if it exists otherwise ``None``.
"""

# download the checkpoint to the node-local folder
@@ -331,7 +310,7 @@ def load_checkpoint(self, state: State) -> Optional[int]:
node_checkpoint_folder = self._get_node_checkpoint_download_folder(tempdir)
composer_checkpoint_filepath, extracted_checkpoint_folder, extracted_rank_n = self._download_checkpoint(
node_checkpoint_folder)
seed_to_restore = self._restore_checkpoint(
self._restore_checkpoint(
state,
composer_checkpoint_filepath,
extracted_rank_n,
@@ -345,8 +324,6 @@ def should_checkpoint(self, state: State, event: Event) -> bool:
log.info(f'{"Model weights" if self.load_weights_only else "Trainer checkpoint"}'
f' loaded from {self.path}.')

return seed_to_restore

def restore_checkpoint_rng_state(self, device: Device):
"""Restore the state of all RNG objects in this context from the loaded checkpoint's data.
@@ -478,7 +455,7 @@ def should_checkpoint(self, state: State, event: Event) -> bool:

return False

def save_checkpoint(self, state: State, seed: int, device: Device) -> None:
def save_checkpoint(self, state: State, device: Device) -> None:
"""Save the current state to a new checkpoint file.
There are 3 cases for the format in which the checkpoint is saved:
@@ -495,12 +472,10 @@ def save_checkpoint(self, state: State, seed: int, device: Device) -> None:
Args:
state (State): The current State of the trainer.
seed (int): The seed used for random number generation.
device (Device): The Device in use by this process.
"""
state_dict = {
'rng': self._get_rng_state(device=device), # stored across all ranks
'seed': dist.all_gather_object(seed),
}

if self.save_event == Event.EPOCH_END:
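To summarize what the removals above mean for the on-disk format, here is a hedged sketch of the checkpoint layout before and after this PR (placeholder values; only the `rng`, `seed`, and `state` keys are taken from the diff):

```python
# Before: a seed per rank was gathered and stored at the top level, so the
# seed could only be restored when the world size matched at load time.
old_checkpoint = {
    "rng": {"torch": "..."},       # per-rank RNG state (placeholder)
    "seed": [42, 43, 44, 45],      # dist.all_gather_object(seed), one entry per rank
    "state": {"timer": "..."},     # State.state_dict() (placeholder)
}

# After: only the rank zero seed is persisted, inside the serialized State;
# rank i re-derives its seed as rank_zero_seed + i regardless of world size.
new_checkpoint = {
    "rng": {"torch": "..."},
    "state": {"timer": "...", "rank_zero_seed": 42},
}
```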
38 changes: 24 additions & 14 deletions composer/trainer/trainer.py
@@ -489,17 +489,31 @@ def __init__(
raise ValueError('device must be of class Device')
self._device = device

if self.deepspeed_enabled or dist.get_world_size() > 1:
# deepspeed requires torch.distributed to be initialized, even if the world size is 1
# distributed is always required with multi-rank training
dist.initialize_dist(self._device.dist_backend, datetime.timedelta(seconds=dist_timeout))

if not seed:
seed = reproducibility.get_random_seed()
log.info(f"Seed was None. Setting seed to random value: {seed}")

# Assure that each process has a different seed, necessary if a seed is passed to init
seed += dist.get_global_rank()
# Ensure that each process has a seed = rank_zero_seed + global_rank
# This "deterministically different" seed behavior is required to be able
# to restore seeds when resuming from checkpoints, since only the
# `rank_zero_seed` is stored on state.
if seed < 0 or seed > reproducibility.MAX_SEED:
raise ValueError(f"Invalid seed: {seed}. It must be on [0; 2**32 - 1)")
rank_zero_seed = self._device.tensor_to_device(torch.tensor(
[seed], dtype=torch.int64)) # using int64 to prevent overflow
dist.broadcast(rank_zero_seed, src=0)
rank_zero_seed = rank_zero_seed.item()
assert isinstance(rank_zero_seed, int)
seed = rank_zero_seed + dist.get_global_rank()
log.info(f"Setting seed to {seed}")

# If hparams is used to create the Trainer this function is called twice
# which is okay because all runs with the hparams codepath will do this
reproducibility.seed_all(seed)
self._seed = seed

if not algorithms:
algorithms = []
@@ -509,10 +523,6 @@ def __init__(
find_unused_parameters = any(map(lambda x: x.find_unused_parameters, algorithms))
self._find_unused_parameters = find_unused_parameters

if self.deepspeed_enabled or dist.get_world_size() > 1:
# deepspeed requires torch.distributed to be initialized, even if the world size is 1
# distributed is always required with multi-rank training
dist.initialize_dist(self._device.dist_backend, datetime.timedelta(seconds=dist_timeout))
if ddp_sync_strategy is None:
self._ddp_sync_strategy = DDPSyncStrategy.SINGLE_AUTO_SYNC if not find_unused_parameters else DDPSyncStrategy.FORCED_SYNC
else:
@@ -573,6 +583,7 @@ def __init__(

self.state = State(
max_duration=max_duration,
rank_zero_seed=rank_zero_seed,
algorithms=algorithms,
model=model,
callbacks=callbacks,
@@ -731,9 +742,8 @@ def __init__(
# initialized, but if using PyTorch DDP, the model must be loaded before it is wrapped with
# DDP.
if self._checkpoint_loader is not None:
restored_seed = self._checkpoint_loader.load_checkpoint(state=self.state)
if restored_seed is not None:
self._seed = restored_seed
self._checkpoint_loader.load_checkpoint(state=self.state)
reproducibility.seed_all(self.state.seed)

if not self.deepspeed_enabled:
host_model_params = self.state.model.parameters()
@@ -1001,7 +1011,7 @@ def _train_loop(self) -> None:

if self._checkpoint_saver and self._checkpoint_saver.should_checkpoint(state=self.state,
event=Event.BATCH_END):
self._checkpoint_saver.save_checkpoint(state=self.state, seed=self._seed, device=self._device)
self._checkpoint_saver.save_checkpoint(state=self.state, device=self._device)

if self.state.timer >= self.state.max_duration:
# If max_duration is specified in batches, samples, or tokens, and
@@ -1025,7 +1035,7 @@ def _train_loop(self) -> None:

if self._checkpoint_saver and self._checkpoint_saver.should_checkpoint(state=self.state,
event=Event.EPOCH_END):
self._checkpoint_saver.save_checkpoint(state=self.state, seed=self._seed, device=self._device)
self._checkpoint_saver.save_checkpoint(state=self.state, device=self._device)

def _train_batch(self, microbatches: Sequence[Batch], ddp_sync: bool = True):
"""Run training on a full batch of data.
Expand Down Expand Up @@ -1101,7 +1111,7 @@ def _train_batch_inner(self, microbatches: Sequence[Batch]):
self.engine.run_event(Event.BEFORE_BACKWARD)

if use_grad_scaling:
self.state.loss = self.state.scaler.scale(self.state.loss)
self.state.loss = cast(torch.Tensor, self.state.scaler.scale(self.state.loss))

if self.deepspeed_enabled:
cast("deepspeed.DeepSpeedEngine", self.state.model).backward(self.state.loss)
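The resume-time behavior in the `trainer.py` hunks above (load the checkpoint, then re-seed each rank from the restored state and ignore the `seed` argument) can be condensed into a small hypothetical helper; only the `rank_zero_seed + global_rank` rule is taken from the diff:

```python
from typing import Optional


def resolve_rank_seed(passed_seed: int, restored_rank_zero_seed: Optional[int],
                      global_rank: int) -> int:
    """Hypothetical summary of the seeding semantics after this PR: when a
    checkpoint was loaded, its rank_zero_seed wins and the passed seed is
    ignored; otherwise the passed seed acts as the rank zero seed."""
    rank_zero_seed = (restored_rank_zero_seed
                      if restored_rank_zero_seed is not None else passed_seed)
    return rank_zero_seed + global_rank
```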
2 changes: 2 additions & 0 deletions composer/trainer/trainer_hparams.py
@@ -475,6 +475,8 @@ def initialize_object(self) -> Trainer:
seed = self.seed if self.seed else reproducibility.get_random_seed()
# need to set seed before model initialization for determinism
# don't need to set different seeds per process since only the rank 0 initialization is used
# Algorithms should not use the `seed` on `__init__` but rather on `Event.INIT`, which occurs
# after the seed was properly distributed across ranks to ensure checkpoint compatibility
reproducibility.seed_all(seed)

model = self.model.initialize_object()
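The comment added above suggests that algorithms read the seed at `Event.INIT` rather than in `__init__`; a hedged sketch of that pattern (a hypothetical class, not a real Composer algorithm):

```python
import random


class SeedAwareAlgorithm:
    """Hypothetical algorithm: defer seed-dependent setup to Event.INIT, when
    ``state.seed`` (rank_zero_seed + global_rank) is already available."""

    def match(self, event, state) -> bool:
        # In Composer this would compare against composer.core.Event.INIT.
        return getattr(event, "name", str(event)) == "INIT"

    def apply(self, event, state, logger) -> None:
        # Build a deterministic, rank-aware RNG from the distributed seed.
        self.rng = random.Random(state.seed)
```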
1 change: 1 addition & 0 deletions composer/utils/dist.py
@@ -217,6 +217,7 @@ def broadcast(tensor: torch.Tensor, src: int) -> None:
"""
if dist.is_available() and dist.is_initialized():
dist.broadcast(tensor, src)
return
world_size = get_world_size()
if world_size == 1:
return