From 7439108acaad88fe95b8ae46a4a5c90b70d90460 Mon Sep 17 00:00:00 2001 From: ravi-mosaicml Date: Thu, 10 Mar 2022 17:46:07 -0800 Subject: [PATCH] [Checkpointing - PR1] Store the `rank_zero_seed` on state (#680) This PR is the first in a series for cleaning up the checkpoint API. One of the prerequisites is storing the seed on the state. Here, only the rank zero seed is stored on state, since only the rank zero state is persisted in a checkpoint. The trainer uses a distributed reduction to share the seed across states, so the same seed will be restored when resuming from checkpointing, even if a seed was not originally specified. This PR ignores the `seed` parameter passed into the trainer when resuming from a checkpoint. For the time being, if a new seed is desired, the `seed` attribute must be removed from the checkpoint state dict. #497 will introduce a cleaner API for this (edge) use case. --- .../scale_schedule/scale_schedule.py | 3 - composer/core/state.py | 70 ++++++++++++++----- composer/trainer/_checkpoint.py | 35 ++-------- composer/trainer/trainer.py | 38 ++++++---- composer/trainer/trainer_hparams.py | 2 + composer/utils/dist.py | 1 + composer/utils/reproducibility.py | 15 +++- docs/source/doctest_fixtures.py | 1 + tests/algorithms/test_layer_freezing.py | 1 + tests/fixtures/dummy_fixtures.py | 2 + tests/fixtures/new_fixtures.py | 1 + tests/test_state.py | 1 + tests/trainer/test_ddp_sync_strategy.py | 1 + tests/trainer/test_scheduler.py | 1 + tests/utils/classifer.py | 1 + 15 files changed, 108 insertions(+), 65 deletions(-) diff --git a/composer/algorithms/scale_schedule/scale_schedule.py b/composer/algorithms/scale_schedule/scale_schedule.py index e4223ac9db..368d89dae5 100644 --- a/composer/algorithms/scale_schedule/scale_schedule.py +++ b/composer/algorithms/scale_schedule/scale_schedule.py @@ -21,9 +21,6 @@ class ScaleSchedule(Algorithm): ratio (float, optional): The factor by which to scale the duration of the schedule. E.g., 0.5 makes the schedule take half as long and 2.0 makes it take twice as long. Default: ``1.0``. - - See also: - :func:`composer.trainer.scale_schedule.scale_scheduler` """ def __init__(self, ratio: float = 1.0): diff --git a/composer/core/state.py b/composer/core/state.py index b4c0dbf56d..652bb38eef 100755 --- a/composer/core/state.py +++ b/composer/core/state.py @@ -17,7 +17,7 @@ from composer.core.precision import Precision from composer.core.serializable import Serializable from composer.core.time import Time, Timer, TimeUnit -from composer.utils import ensure_tuple +from composer.utils import dist, ensure_tuple if TYPE_CHECKING: from composer.core.algorithm import Algorithm @@ -63,6 +63,8 @@ class State(Serializable): Args: model (:attr:`~.types.Model`): The model, typically as a subclass of :class:`~.ComposerModel`. + rank_zero_seed (int): The seed used on the rank zero process. It is assumed that each rank's seed is + ``rank_zero_seed + dist.get_global_rank()``. grad_accum (int): The number of gradient accumulation steps to use. With this argument, micro batch size for each device becomes ``microbatch_size = train_batch_size / (num_devices * grad_accum)``. train_dataloader (types.DataLoader, DataSpec, or dict): @@ -111,6 +113,8 @@ class State(Serializable): +-----------------------+-------------------------------------------------------------+ | timer | The timer that tracks training loop progress. | +-----------------------+-------------------------------------------------------------+ + | rank_zero_seed | The seed of the rank zero process. 
| + +-----------------------+-------------------------------------------------------------+ """ _max_duration: Time[int] @@ -129,6 +133,7 @@ def __init__( # stopping conditions max_duration: Union[str, Time[int]], + rank_zero_seed: int, # data configurations train_dataloader: types.DataLoader, @@ -152,6 +157,7 @@ def __init__( # steps per epoch steps_per_epoch: Optional[int] = None, ): + self.rank_zero_seed = rank_zero_seed self.model = model self.grad_accum = grad_accum self.train_dataloader = train_dataloader @@ -188,10 +194,17 @@ def __init__( "callbacks", "scaler", "timer", + "rank_zero_seed", ] + @property + def seed(self): + """The seed for the current rank.""" + return self.rank_zero_seed + dist.get_global_rank() + @property def max_duration(self): + """The maximum training duration.""" return self._max_duration @max_duration.setter @@ -254,11 +267,24 @@ def state_dict(self) -> types.StateDict: # Save model directly instead of by class name, since model may be wrapped by DistributedDataParallel serialized_value = state_field_value.state_dict() else: - serialized_value = { - obj.__class__.__qualname__: obj.state_dict() - for obj in ensure_tuple(state_field_value) - if obj is not None - } + # Duck typing since runtime checkable protocols are not available in Python 3.7 + is_state_dict_serializable = any(hasattr(x, "state_dict") for x in ensure_tuple(state_field_value)) + if len(ensure_tuple(state_field_value)) > 0: + # any and all should be the same, except in for an empty collection, since + # any([]) is False, while all([]) is True + if is_state_dict_serializable != all( + hasattr(x, "state_dict") for x in ensure_tuple(state_field_value)): + raise RuntimeError( + f"Every member of {state_field_name} should support `state_dict`, or no members should support it." + ) + + if is_state_dict_serializable: + serialized_value = { + obj.__class__.__qualname__: obj.state_dict() for obj in ensure_tuple(state_field_value) + } + else: + serialized_value = state_field_value + state_dict[state_field_name] = serialized_value state_dict["_is_model_ddp_wrapped"] = isinstance(self.model, DistributedDataParallel) @@ -297,16 +323,28 @@ def load_state_dict(self, state: types.StateDict, strict: bool = False): if state_field_name == "model": self.load_model_state(state, strict=strict) else: - for target in ensure_tuple(state_field_value): - if target is None: - continue - if target.__class__.__qualname__ not in serialized_value: - warnings.warn( - f"{target.__class__.__qualname__} was not found in the state_dict. Its state will NOT be restored", - category=UserWarning) - continue - source = serialized_value[target.__class__.__qualname__] - target.load_state_dict(source) + # Duck typing since runtime checkable protocols are not available in Python 3.7 + is_state_dict_serializable = any(hasattr(x, "load_state_dict") for x in ensure_tuple(state_field_value)) + if len(ensure_tuple(state_field_value)) > 0: + # any and all should be the same, except in for an empty collection, since + # any([]) is False, while all([]) is True + if is_state_dict_serializable != all( + hasattr(x, "load_state_dict") for x in ensure_tuple(state_field_value)): + raise RuntimeError( + f"Every member of {state_field_name} should support `load_state_dict`, or no members should support it." + ) + if is_state_dict_serializable: + for target in ensure_tuple(state_field_value): + if target.__class__.__qualname__ not in serialized_value: + warnings.warn( + f"{target.__class__.__qualname__} was not found in the state_dict. 
Its state will NOT be restored", + category=UserWarning) + continue + source = serialized_value[target.__class__.__qualname__] + target.load_state_dict(source) + else: + # direct serialization + setattr(self, state_field_name, serialized_value) @property def steps_per_epoch(self): diff --git a/composer/trainer/_checkpoint.py b/composer/trainer/_checkpoint.py index eb01e0b1b4..082551cabc 100755 --- a/composer/trainer/_checkpoint.py +++ b/composer/trainer/_checkpoint.py @@ -26,7 +26,7 @@ from composer.core.types import StateDict from composer.trainer._deepspeed import is_module_deepspeed from composer.trainer.devices.device import Device -from composer.utils import ObjectStoreProvider, dist, iterate_with_pbar, reproducibility, run_directory +from composer.utils import ObjectStoreProvider, dist, iterate_with_pbar, run_directory log = logging.getLogger(__name__) @@ -257,7 +257,7 @@ def _download_checkpoint(self, node_checkpoint_folder: str) -> Tuple[str, Option return composer_checkpoint_filepath, extracted_checkpoint_folder, extracted_rank_n def _restore_checkpoint(self, state: State, composer_checkpoint_filepath: str, extracted_rank_n: bool, - extracted_checkpoint_folder: Optional[str]) -> Optional[int]: + extracted_checkpoint_folder: Optional[str]): """Restore a checkpoint into ``state``. Args: @@ -268,14 +268,10 @@ def _restore_checkpoint(self, state: State, composer_checkpoint_filepath: str, e where global rank is greater than 0. extracted_checkpoint_folder (Optional[str]): The path to the checkpoint folder, which is passed into :meth:`deepspeed.DeepSpeedEngine.load_checkpoint`. - - Returns: - Optional[int]: The seed that was loaded from the checkpoint if it exists otherwise ``None``. """ # Now, all ranks load the checkpoint that local rank zero downloaded state_dict = torch.load(composer_checkpoint_filepath, map_location='cpu') log.debug(f"Loaded checkpoint with keys {state_dict.keys()} and state with keys {state_dict['state'].keys()}") - seed_to_restore = None if is_module_deepspeed(state.model): if extracted_checkpoint_folder is None: @@ -300,28 +296,11 @@ def _restore_checkpoint(self, state: State, composer_checkpoint_filepath: str, e state.load_state_dict(state_dict["state"]) self.checkpoint_rng_state = self._get_checkpoint_rng_state(state_dict["rng"]) - if "seed" in state_dict: - world_size = dist.get_world_size() - checkpointed_world_size = len(state_dict["seed"]) - if world_size != checkpointed_world_size: - warnings.warn( - textwrap.dedent(f"""\ - Current world size {world_size} does not match the checkpointed - world size {checkpointed_world_size}. The seed will not be restored.""")) - else: - seed_to_restore = state_dict["seed"][dist.get_global_rank()] - reproducibility.seed_all(seed_to_restore) - - return seed_to_restore - - def load_checkpoint(self, state: State) -> Optional[int]: + def load_checkpoint(self, state: State): """Initialize state from the loaded checkpoint's data. Args: state (State): The :class:`~composer.core.state.State` to load the checkpoint into. - - Returns: - Optional[int]: The seed that was loaded from the checkpoint if it exists otherwise ``None``. 
""" # download the checkpoint to the node-local folder @@ -331,7 +310,7 @@ def load_checkpoint(self, state: State) -> Optional[int]: node_checkpoint_folder = self._get_node_checkpoint_download_folder(tempdir) composer_checkpoint_filepath, extracted_checkpoint_folder, extracted_rank_n = self._download_checkpoint( node_checkpoint_folder) - seed_to_restore = self._restore_checkpoint( + self._restore_checkpoint( state, composer_checkpoint_filepath, extracted_rank_n, @@ -345,8 +324,6 @@ def load_checkpoint(self, state: State) -> Optional[int]: log.info(f'{"Model weights" if self.load_weights_only else "Trainer checkpoint"}' f' loaded from {self.path}.') - return seed_to_restore - def restore_checkpoint_rng_state(self, device: Device): """Restore the state of all RNG objects in this context from the loaded checkpoint's data. @@ -478,7 +455,7 @@ def should_checkpoint(self, state: State, event: Event) -> bool: return False - def save_checkpoint(self, state: State, seed: int, device: Device) -> None: + def save_checkpoint(self, state: State, device: Device) -> None: """Save the current state to a new checkpoint file. There are 3 cases for the format in which the checkpoint is saved: @@ -495,12 +472,10 @@ def save_checkpoint(self, state: State, seed: int, device: Device) -> None: Args: state (State): The current State of the trainer. - seed (int): The seed used for random number generation. device (Device): The Device in use by this process. """ state_dict = { 'rng': self._get_rng_state(device=device), # stored across all ranks - 'seed': dist.all_gather_object(seed), } if self.save_event == Event.EPOCH_END: diff --git a/composer/trainer/trainer.py b/composer/trainer/trainer.py index 3f7064c0a9..fbfc51b7fc 100755 --- a/composer/trainer/trainer.py +++ b/composer/trainer/trainer.py @@ -489,17 +489,31 @@ def __init__( raise ValueError('device must be of class Device') self._device = device + if self.deepspeed_enabled or dist.get_world_size() > 1: + # deepspeed requires torch.distributed to be initialized, even if the world size is 1 + # distributed is always required with multi-rank training + dist.initialize_dist(self._device.dist_backend, datetime.timedelta(seconds=dist_timeout)) + if not seed: seed = reproducibility.get_random_seed() - log.info(f"Seed was None. Setting seed to random value: {seed}") - # Assure that each process has a different seed, necessary if a seed is passed to init - seed += dist.get_global_rank() + # Ensure that each process has a seed = rank_zero_seed + global_rank + # This "deterministically different" seed behavior is required to be able + # to restore seeds when resuming form checkpoints, since only the + # `rank_zero_seed` is stored on state. + if seed < 0 or seed > reproducibility.MAX_SEED: + raise ValueError(f"Invalid seed: {seed}. 
It must be on [0; 2**32 - 1)") + rank_zero_seed = self._device.tensor_to_device(torch.tensor( + [seed], dtype=torch.int64)) # using int64 to prevent overflow + dist.broadcast(rank_zero_seed, src=0) + rank_zero_seed = rank_zero_seed.item() + assert isinstance(rank_zero_seed, int) + seed = rank_zero_seed + dist.get_global_rank() + log.info(f"Setting seed to {seed}") # If hparams is used to create the Trainer this function is called twice # which is okay because all runs with the hparams codepath will do this reproducibility.seed_all(seed) - self._seed = seed if not algorithms: algorithms = [] @@ -509,10 +523,6 @@ def __init__( find_unused_parameters = any(map(lambda x: x.find_unused_parameters, algorithms)) self._find_unused_parameters = find_unused_parameters - if self.deepspeed_enabled or dist.get_world_size() > 1: - # deepspeed requires torch.distributed to be initialized, even if the world size is 1 - # distributed is always required with multi-rank training - dist.initialize_dist(self._device.dist_backend, datetime.timedelta(seconds=dist_timeout)) if ddp_sync_strategy is None: self._ddp_sync_strategy = DDPSyncStrategy.SINGLE_AUTO_SYNC if not find_unused_parameters else DDPSyncStrategy.FORCED_SYNC else: @@ -573,6 +583,7 @@ def __init__( self.state = State( max_duration=max_duration, + rank_zero_seed=rank_zero_seed, algorithms=algorithms, model=model, callbacks=callbacks, @@ -731,9 +742,8 @@ def __init__( # initialized, but if using PyTorch DDP, the model must be loaded before it is wrapped with # DDP. if self._checkpoint_loader is not None: - restored_seed = self._checkpoint_loader.load_checkpoint(state=self.state) - if restored_seed is not None: - self._seed = restored_seed + self._checkpoint_loader.load_checkpoint(state=self.state) + reproducibility.seed_all(self.state.seed) if not self.deepspeed_enabled: host_model_params = self.state.model.parameters() @@ -1001,7 +1011,7 @@ def _train_loop(self) -> None: if self._checkpoint_saver and self._checkpoint_saver.should_checkpoint(state=self.state, event=Event.BATCH_END): - self._checkpoint_saver.save_checkpoint(state=self.state, seed=self._seed, device=self._device) + self._checkpoint_saver.save_checkpoint(state=self.state, device=self._device) if self.state.timer >= self.state.max_duration: # If max_duration is specified in batches, samples, or tokens, and @@ -1025,7 +1035,7 @@ def _train_loop(self) -> None: if self._checkpoint_saver and self._checkpoint_saver.should_checkpoint(state=self.state, event=Event.EPOCH_END): - self._checkpoint_saver.save_checkpoint(state=self.state, seed=self._seed, device=self._device) + self._checkpoint_saver.save_checkpoint(state=self.state, device=self._device) def _train_batch(self, microbatches: Sequence[Batch], ddp_sync: bool = True): """Run training on a full batch of data. 
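(For context, here is a standalone sketch of the seed-distribution scheme introduced in the hunk above, not part of the patch itself: rank zero's seed is broadcast to every process, and each rank then derives its own seed as ``rank_zero_seed + global_rank``, so only the rank zero seed needs to be persisted in a checkpoint. The ``distribute_seed`` helper and its arguments are hypothetical names for illustration, and the sketch assumes ``torch.distributed`` may or may not have an initialized process group.)

    import torch
    import torch.distributed as dist

    MAX_SEED = 2**32 - 1  # seeds must be 32-bit unsigned integers

    def distribute_seed(seed: int, device: torch.device) -> int:
        """Broadcast rank zero's seed and return this rank's derived seed."""
        if not 0 <= seed <= MAX_SEED:
            raise ValueError(f"Invalid seed: {seed}. It must be in [0, 2**32 - 1].")
        # int64 prevents overflow: a seed near 2**32 - 1 does not fit in int32
        tensor = torch.tensor([seed], dtype=torch.int64, device=device)
        if dist.is_available() and dist.is_initialized():
            dist.broadcast(tensor, src=0)  # every rank now holds rank zero's seed
            rank = dist.get_rank()
        else:
            rank = 0
        rank_zero_seed = int(tensor.item())
        # "deterministically different" per-rank seeds, reconstructable from
        # rank_zero_seed alone when resuming from a checkpoint
        return rank_zero_seed + rank

Storing the seed in an int64 tensor is what lets the broadcast carry any value up to 2**32 - 1 without overflow, and deriving per-rank seeds additively is what makes them recoverable from the single persisted value.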
@@ -1101,7 +1111,7 @@ def _train_batch_inner(self, microbatches: Sequence[Batch]): self.engine.run_event(Event.BEFORE_BACKWARD) if use_grad_scaling: - self.state.loss = self.state.scaler.scale(self.state.loss) + self.state.loss = cast(torch.Tensor, self.state.scaler.scale(self.state.loss)) if self.deepspeed_enabled: cast("deepspeed.DeepSpeedEngine", self.state.model).backward(self.state.loss) diff --git a/composer/trainer/trainer_hparams.py b/composer/trainer/trainer_hparams.py index a8ef7c4315..199197d13b 100755 --- a/composer/trainer/trainer_hparams.py +++ b/composer/trainer/trainer_hparams.py @@ -475,6 +475,8 @@ def initialize_object(self) -> Trainer: seed = self.seed if self.seed else reproducibility.get_random_seed() # need to set seed before model initialization for determinism # don't need to set different seeds per process since only the rank 0 initialization is used + # Algorithms should not use the `seed` on `__init__` but rather on `Event.INIT`, which occurs + # after the seed was properly distributed across ranks to ensure checkpoint compatibility reproducibility.seed_all(seed) model = self.model.initialize_object() diff --git a/composer/utils/dist.py b/composer/utils/dist.py index b16b6d37bc..f821d46647 100755 --- a/composer/utils/dist.py +++ b/composer/utils/dist.py @@ -217,6 +217,7 @@ def broadcast(tensor: torch.Tensor, src: int) -> None: """ if dist.is_available() and dist.is_initialized(): dist.broadcast(tensor, src) + return world_size = get_world_size() if world_size == 1: return diff --git a/composer/utils/reproducibility.py b/composer/utils/reproducibility.py index 0008fa733c..cbe50bb3f3 100755 --- a/composer/utils/reproducibility.py +++ b/composer/utils/reproducibility.py @@ -29,9 +29,13 @@ >>> # model will now be deterministically initialized, since the seed is set. >>> init_weights(model) >>> trainer = Trainer(model=model) + +Attributes: + MAX_SEED (int): The maximum allowed seed, which is :math:`2^{32} - 1`. """ import os import random +import time import warnings import numpy as np @@ -42,8 +46,12 @@ "configure_deterministic_mode", "get_random_seed", "seed_all", + "MAX_SEED", ] +# seeds must be 32-bit unsigned integers +MAX_SEED = 2**32 - 1 + def configure_deterministic_mode(): """Configure PyTorch deterministic mode. @@ -86,7 +94,9 @@ def get_random_seed() -> int: Returns: int: A random seed. """ - seed = int(torch.empty((), dtype=torch.int64).random_(to=2**32).item()) + rng = random.Random(int(time.time_ns())) # get a new RNG does not respect the current seed + seed = rng.randint(0, MAX_SEED) + assert seed >= 0 and seed <= MAX_SEED, "seed should be on this range" return seed @@ -107,7 +117,8 @@ def seed_all(seed: int): Args: seed (int): The random seed """ - + if seed < 0 or seed > MAX_SEED: + raise ValueError(f"Seed {seed} is invalid. 
It must be on [0; 2^32 - 1]") random.seed(seed) np.random.seed(seed) torch.manual_seed(seed) diff --git a/docs/source/doctest_fixtures.py b/docs/source/doctest_fixtures.py index 8ec52d8754..35601c99c9 100644 --- a/docs/source/doctest_fixtures.py +++ b/docs/source/doctest_fixtures.py @@ -75,6 +75,7 @@ ) state = State( + rank_zero_seed=0, model=model, optimizers=optimizer, grad_accum=1, diff --git a/tests/algorithms/test_layer_freezing.py b/tests/algorithms/test_layer_freezing.py index eb82d4434a..96d12247be 100755 --- a/tests/algorithms/test_layer_freezing.py +++ b/tests/algorithms/test_layer_freezing.py @@ -16,6 +16,7 @@ def _generate_state(epoch: int, max_epochs: int): model = SimpleConvModel() state = State(model=model, + rank_zero_seed=0, optimizers=torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.99), precision=Precision.FP32, grad_accum=1, diff --git a/tests/fixtures/dummy_fixtures.py b/tests/fixtures/dummy_fixtures.py index ef5f914fe9..69c3d1f748 100755 --- a/tests/fixtures/dummy_fixtures.py +++ b/tests/fixtures/dummy_fixtures.py @@ -110,6 +110,7 @@ def dummy_state_without_rank(dummy_model: SimpleBatchPairModel, dummy_train_data model=dummy_model, precision=Precision.FP32, grad_accum=1, + rank_zero_seed=0, train_dataloader=dummy_train_dataloader, evaluators=evaluators, optimizers=dummy_optimizer, @@ -232,6 +233,7 @@ def state_with_model(simple_conv_model: Model, dummy_train_dataloader: DataLoade evaluators = [Evaluator(label="dummy_label", dataloader=dummy_val_dataloader, metrics=metric_coll)] state = State( grad_accum=1, + rank_zero_seed=0, max_duration="100ep", model=simple_conv_model, precision=Precision.FP32, diff --git a/tests/fixtures/new_fixtures.py b/tests/fixtures/new_fixtures.py index 878bb01436..40600ddaa2 100755 --- a/tests/fixtures/new_fixtures.py +++ b/tests/fixtures/new_fixtures.py @@ -16,6 +16,7 @@ def minimal_state(): """ return State( model=SimpleModel(), + rank_zero_seed=0, train_dataloader=DataLoader(RandomClassificationDataset()), evaluators=[], max_duration='100ep', diff --git a/tests/test_state.py b/tests/test_state.py index 476af3e0e0..63503621ef 100755 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -26,6 +26,7 @@ def get_dummy_state(model: ComposerModel, train_dataloader: types.DataLoader, va evaluators = [types.Evaluator(label="dummy_label", dataloader=val_dataloader, metrics=model.metrics(train=False))] state = State(model=model, grad_accum=random.randint(0, 100), + rank_zero_seed=random.randint(0, 100), precision=types.Precision.AMP, max_duration=f"{random.randint(0, 100)}ep", train_dataloader=train_dataloader, diff --git a/tests/trainer/test_ddp_sync_strategy.py b/tests/trainer/test_ddp_sync_strategy.py index cf47f13cf4..6979f9fafd 100755 --- a/tests/trainer/test_ddp_sync_strategy.py +++ b/tests/trainer/test_ddp_sync_strategy.py @@ -55,6 +55,7 @@ def test_ddp_sync_strategy(ddp_sync_strategy: str, expected_grads: List[Optional metric_coll = MetricCollection([Accuracy()]) evaluators = [Evaluator(label="dummy_label", dataloader=dummy_val_dataloader, metrics=metric_coll)] state = State(model=original_model, + rank_zero_seed=0, optimizers=optimizer, grad_accum=2, max_duration="1ep", diff --git a/tests/trainer/test_scheduler.py b/tests/trainer/test_scheduler.py index bd036a5e92..3dd17aea7c 100644 --- a/tests/trainer/test_scheduler.py +++ b/tests/trainer/test_scheduler.py @@ -23,6 +23,7 @@ def dummy_schedulers_state(dummy_model: Model, dummy_train_dataloader: DataLoader): return State( model=dummy_model, + rank_zero_seed=0, 
train_dataloader=dummy_train_dataloader, max_duration=MAX_DURATION, steps_per_epoch=STEPS_PER_EPOCH, diff --git a/tests/utils/classifer.py b/tests/utils/classifer.py index d0a5132cf7..28dbc67b04 100755 --- a/tests/utils/classifer.py +++ b/tests/utils/classifer.py @@ -20,6 +20,7 @@ def _get_state(train_dataloader: DataLoader, eval_dataloader: DataLoader, steps_ evaluators = [Evaluator(label="dummy_label", dataloader=eval_dataloader, metrics=metric_coll)] return State( model=model, + rank_zero_seed=0, optimizers=optim.SGD(model.parameters(), lr=.001, momentum=0.0), max_duration="1ep", train_dataloader=train_dataloader,
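For reference, the duck-typed serialization rule that ``State.state_dict()`` and ``State.load_state_dict()`` now share (duck typing is used because runtime-checkable protocols are unavailable in Python 3.7) condenses to a few lines. The following is an illustrative sketch only; ``ensure_tuple`` is stubbed in for self-containment rather than imported from ``composer.utils``:

    from typing import Any, Tuple

    def ensure_tuple(value: Any) -> Tuple[Any, ...]:
        # simplified stand-in for composer.utils.ensure_tuple
        if value is None:
            return ()
        if isinstance(value, (tuple, list)):
            return tuple(value)
        return (value,)

    def serialize_field(name: str, value: Any) -> Any:
        members = ensure_tuple(value)
        has_state_dict = any(hasattr(m, "state_dict") for m in members)
        # flag mixed collections; the `members` guard is needed because
        # any([]) is False while all([]) is True, so an empty field would
        # otherwise be wrongly reported as mixed
        if members and has_state_dict != all(hasattr(m, "state_dict") for m in members):
            raise RuntimeError(f"Every member of {name} should support `state_dict`, "
                               "or no members should support it.")
        if has_state_dict:
            return {m.__class__.__qualname__: m.state_dict() for m in members}
        return value  # fields without state_dict support are stored directly

Loading applies the mirror-image rule with ``load_state_dict``: dict-serialized members are restored by qualified class name (warning on any that are missing from the checkpoint), and directly-serialized fields are assigned back onto the state as-is.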