Checkpoint logging and doc fixes #270

Merged · 3 commits · Jan 26, 2022
2 changes: 1 addition & 1 deletion composer/algorithms/label_smoothing/label_smoothing.py
@@ -67,7 +67,7 @@ def smooth_labels(logits: Tensor, targets: Tensor, alpha: float):
as in `Szegedy et al. <https://arxiv.org/abs/1512.00567>`_.

This is computed by ``(1 - alpha) * targets + alpha * smoothed_targets``
- where ``smoothed_targets`` is a vector of ones.
+ where ``smoothed_targets`` is a uniform distribution.

Args:
logits: Output of the model. Tensor of shape (N, C, d1, ..., dn) for
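The corrected sentence describes standard label smoothing. A minimal sketch of that formula, assuming one-hot targets with classes along the last dimension (the helper name and layout are illustrative, not Composer's actual implementation):

.. code-block:: python

    import torch

    def smooth_labels_sketch(targets_one_hot: torch.Tensor, alpha: float) -> torch.Tensor:
        # Blend the one-hot targets with a uniform distribution over classes,
        # exactly as the docstring states: (1 - alpha) * targets + alpha * uniform.
        num_classes = targets_one_hot.size(-1)
        uniform = torch.full_like(targets_one_hot, 1.0 / num_classes)
        return (1.0 - alpha) * targets_one_hot + alpha * uniform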
8 changes: 4 additions & 4 deletions composer/core/types.py
@@ -46,7 +46,7 @@

def as_batch_dict(batch: Batch) -> BatchDict:
"""Casts a :class:`Batch` as a :class:`BatchDict`.

Args:
batch (Batch): A batch.
Raises:
@@ -83,7 +83,7 @@ def as_batch_pair(batch: Batch) -> BatchPair:

class BreakEpochException(Exception):
"""Raising this exception will immediately end the current epoch.

If you're wondering whether you should use this, the answer is no.
"""

@@ -96,8 +96,8 @@ class DataLoader(Protocol):

Attributes:
dataset (Dataset): Dataset from which to load the data.
- batch_size (int, optional): How many samples per batch to load
- (default: ``1``).
+ batch_size (int, optional): How many samples per batch to load for a
+ single device (default: ``1``).
num_workers (int): How many subprocesses to use for data loading.
``0`` means that the data will be loaded in the main process.
pin_memory (bool): If ``True``, the data loader will copy Tensors
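The "single device" clarification matters under data-parallel training: the effective batch size per optimizer step scales with the device count. Hypothetical numbers, purely to illustrate:

.. code-block:: python

    per_device_batch_size = 64   # the DataLoader's batch_size attribute
    num_devices = 8              # e.g. 8 GPUs under distributed data parallel
    global_batch_size = per_device_batch_size * num_devices
    print(global_batch_size)     # 512 samples consumed per optimizer step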
12 changes: 7 additions & 5 deletions composer/trainer/checkpoint.py
@@ -47,18 +47,18 @@ class CheckpointLoader:
checkpoint (str): The template path to an existing checkpoint file.
It can be a path to a file on local disk, a URL, or if ``object_store_hparams`` is set, the object name
for a checkpoint in a cloud bucket.

When using DeepSpeed ZeRO, the :class:`CheckpointSaver` shards checkpoints by rank. To load DeepSpeed checkpoints,
specify ``{RANK}`` in the ``checkpoint`` parameter, and this variable will be substituted with the global rank.
For example, suppose that checkpoints are stored in the following structure:

.. code-block::

my_model/rank_0/ep1.tar
my_model/rank_1/ep1.tar
my_model/rank_2/ep1.tar
...

Then, ``checkpoint`` should be set to ``my_model/rank_{RANK}/ep1.tar``, and all ranks will load the correct
data.
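
A sketch of the substitution the docstring describes; ``str.format`` is one plausible mechanism, and the rank value is hypothetical, since this diff does not show Composer's actual substitution code:

.. code-block:: python

    checkpoint_template = "my_model/rank_{RANK}/ep1.tar"
    global_rank = 2  # hypothetical: this process's global rank
    checkpoint_path = checkpoint_template.format(RANK=global_rank)
    print(checkpoint_path)  # my_model/rank_2/ep1.tar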

@@ -189,7 +189,6 @@ def _download_checkpoint(self, node_checkpoint_folder: str) -> Tuple[str, Option
self._retrieve_checkpoint(destination_filepath=rank_zero_checkpoint_archive_filepath,
rank=dist.get_global_rank(),
ignore_not_found_errors=False)

if extracted_checkpoint_folder is not None:
try:
with tarfile.open(rank_zero_checkpoint_archive_filepath) as tarball:
@@ -235,7 +234,7 @@ def _restore_checkpoint(self, state: State, mosaic_checkpoint_filepath: str,
"""
# Now, all ranks load the checkpoint that local rank zero downloaded
state_dict = torch.load(mosaic_checkpoint_filepath, map_location='cpu')

log.debug(f"Loaded checkpoint with keys {state_dict.keys()} and state with keys {state_dict['state'].keys()}")
seed_to_restore = None

if is_module_deepspeed(state.model):
@@ -287,6 +286,9 @@ def load_checkpoint(self, state: State):
mosaic_checkpoint_filepath, extracted_checkpoint_folder = self._download_checkpoint(node_checkpoint_folder)
seed_to_restore = self._restore_checkpoint(state, mosaic_checkpoint_filepath, extracted_checkpoint_folder)

+ log.info(f'{"Model weights" if self.hparams.load_weights_only else "Trainer checkpoint"}'
+ f' loaded from {self.hparams.checkpoint}.')

return seed_to_restore

def restore_checkpoint_rng_state(self, device: Device):
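Reading the new f-string directly, the INFO-level message is one of the following, where ``<checkpoint>`` stands in for ``self.hparams.checkpoint``:

.. code-block:: text

    Model weights loaded from <checkpoint>.
    Trainer checkpoint loaded from <checkpoint>.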
3 changes: 2 additions & 1 deletion composer/trainer/trainer.py
@@ -57,7 +57,8 @@ class Trainer:
or dict of :class:`DataSpec` kwargs for the training data.
eval_dataloader (DataLoader, DataSpec, or dict): The :class:`DataLoader`, :class:`DataSpec`,
or dict of :class:`DataSpec` kwargs for the evaluation data.
- max_epochs (int): The maxmimum number of epochs to train for.
+ max_duration (Union[str, `~composer.core.Time`]): The maximum amount of time to train for.
+ See `~composer.core.Time` for details.
algorithms (List[Algorithm], optional): The algorithms to use during training.
(default: ``[]``)
optimizer_hparams: (OptimizerHparams, optional): The OptimizerHparams for constructing
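For context, the new ``max_duration`` argument accepts either a time string or a ``Time`` object. A hedged usage sketch; the ``"10ep"`` unit syntax and the ``Time.from_timestring`` constructor are assumptions based on the ``composer.core.Time`` API rather than anything this diff confirms:

.. code-block:: python

    from composer.core import Time

    max_duration_str = "10ep"                        # string form: ten epochs
    max_duration_obj = Time.from_timestring("10ep")  # assumed equivalent object form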
2 changes: 1 addition & 1 deletion composer/trainer/trainer_hparams.py
@@ -196,7 +196,7 @@ class TrainerHparams(hp.Hparams):
default=False)

compute_training_metrics: bool = hp.optional(doc="Log validation metrics on training data", default=False)
- log_level: str = hp.optional(doc="Python loglevel to use composer", default="WARNING")
+ log_level: str = hp.optional(doc="Python loglevel to use composer", default="INFO")
datadir: Optional[str] = hp.optional(doc=textwrap.dedent("""
Datadir to apply for both the training and validation datasets. If specified,
it will override train_dataset.datadir and val_dataset.datadir"""),
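The default change from ``WARNING`` to ``INFO`` is what makes the new checkpoint-loaded message in ``composer/trainer/checkpoint.py`` visible out of the box. A sketch of the equivalent standard-library configuration, assuming Composer's loggers live under the ``composer`` namespace:

.. code-block:: python

    import logging

    # INFO and above (WARNING, ERROR, ...) will now be emitted;
    # DEBUG messages remain hidden.
    logging.getLogger("composer").setLevel(logging.INFO)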