Multiple calls to Trainer.fit() (#948)
This PR allows arguments to be specified when calling `Trainer.fit()` instead of when constructing the trainer; a short usage sketch follows the list below.

1. Refactored `Trainer.__init__` to move shared helper code out, so it can be re-used when parameters are passed to `fit()`.
2. Modified the signature of `fit()` to take training parameters. Any parameter not specified on `Trainer.__init__` must be specified on `fit()`.
3. Rearranged the position of arguments in the `Trainer` and `TrainerHparams` classes to group by functionality. Re-ordered the docstrings to be in the same order as the arguments.
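
The sketch below illustrates the intended call pattern. It is a minimal, hypothetical example: the `fit()` parameter names (`train_dataloader`, `duration`), the `ComposerClassifier` wrapper, and the synthetic dataset are assumptions for illustration, since the `trainer.py` diff is not shown on this page.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

from composer import Trainer
from composer.models import ComposerClassifier

# Tiny synthetic classification setup, purely for illustration.
dataset = TensorDataset(torch.randn(64, 8), torch.randint(0, 2, (64,)))
train_dl = DataLoader(dataset, batch_size=16)
model = ComposerClassifier(torch.nn.Linear(8, 2))

# Before this PR: training arguments had to be given at construction time.
trainer = Trainer(model=model, train_dataloader=train_dl, max_duration="1ep")
trainer.fit()

# After this PR (parameter names assumed): construct once, then pass the
# training arguments to fit(), which may now be called multiple times.
trainer = Trainer(model=model)
trainer.fit(train_dataloader=train_dl, duration="1ep")
trainer.fit(train_dataloader=train_dl, duration="1ep")
```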
ravi-mosaicml authored May 10, 2022
1 parent 3289073 commit b1e89b4
Showing 14 changed files with 1,900 additions and 916 deletions.
39 changes: 27 additions & 12 deletions composer/core/__init__.py
@@ -7,15 +7,30 @@
:class:`~.logger.Logger` and :class:`~.time.Timestamp` are implemented under core.
"""

from composer.core.algorithm import Algorithm as Algorithm
from composer.core.callback import Callback as Callback
from composer.core.data_spec import DataSpec as DataSpec
from composer.core.engine import Engine as Engine
from composer.core.engine import Trace as Trace
from composer.core.evaluator import Evaluator as Evaluator
from composer.core.event import Event as Event
from composer.core.precision import Precision as Precision
from composer.core.state import State as State
from composer.core.time import Time as Time
from composer.core.time import Timestamp as Timestamp
from composer.core.time import TimeUnit as TimeUnit
from composer.core.algorithm import Algorithm
from composer.core.callback import Callback
from composer.core.data_spec import DataSpec, ensure_data_spec
from composer.core.engine import Engine, Trace
from composer.core.evaluator import Evaluator, ensure_evaluator
from composer.core.event import Event
from composer.core.precision import Precision
from composer.core.state import State
from composer.core.time import Time, Timestamp, TimeUnit, ensure_time

__all__ = [
"Algorithm",
"Callback",
"DataSpec",
"ensure_data_spec",
"Engine",
"Trace",
"Evaluator",
"Event",
"Precision",
"State",
"Time",
"Timestamp",
"TimeUnit",
"ensure_time",
"ensure_evaluator",
]
37 changes: 32 additions & 5 deletions composer/core/data_spec.py
@@ -5,7 +5,7 @@

import collections.abc
import textwrap
from typing import TYPE_CHECKING, Any, Callable, Iterable, List, Mapping, Optional, Sequence, Tuple
from typing import TYPE_CHECKING, Any, Callable, Iterable, List, Mapping, Optional, Sequence, Tuple, Union

import torch
import torch.utils.data
@@ -15,10 +15,10 @@
if TYPE_CHECKING:
from composer.core.types import Batch

__all__ = ["DataSpec"]
__all__ = ["DataSpec", "ensure_data_spec"]


def _split_list(l, num_microbatches):
def _split_list(l, num_microbatches: int):
if len(l) < num_microbatches:
raise ValueError(
textwrap.dedent(f"""\
@@ -27,7 +27,7 @@ def _split_list(l, num_microbatches):
return [l[i::num_microbatches] for i in range(num_microbatches)]


def _split_tensor(t, num_microbatches):
def _split_tensor(t, num_microbatches: int):
if len(t) < num_microbatches:
raise ValueError(
textwrap.dedent(f"""\
@@ -36,7 +36,7 @@ def _split_tensor(t, num_microbatches):
return t.chunk(num_microbatches)


def _split_mapping(m, num_microbatches):
def _split_mapping(m, num_microbatches: int):
chunked = {}
for k, v in m.items():
if isinstance(v, torch.Tensor):
@@ -176,6 +176,15 @@ def __init__(
else:
self.num_samples = None

if isinstance(dataloader, torch.utils.data.DataLoader) and dataloader._iterator is not None:
raise ValueError(
("The dataloader has an active iterator. This could occur "
"if `persistent_workers=True` and the dataloader has already been iterated, "
"or if the dataloader is mid-epoch. It is required that the training dataloader "
"does not have an active iterator, so CPU dataset augmentations can be "
"correctly inserted. To fix, please do not iterate over the dataloader before passing it into "
"the Trainer."))

def _default_device_transforms(self, batch: Batch):
return batch

@@ -203,3 +212,21 @@ def _default_get_num_samples_in_batch(self, batch: Batch) -> int:
def _default_get_num_tokens_in_batch(self, batch: Batch) -> int:
del batch # unused
return 0


def ensure_data_spec(dataloader: Union[DataSpec, Iterable, dict]) -> DataSpec:
"""Ensures that the ``dataloader`` is a :class:`.DataSpec`
Args:
dataloader (DataSpec | Iterable | dict): A DataSpec, DataLoader, or Dict of DataSpec kwargs.
Returns:
DataSpec: A DataSpec
"""
if isinstance(dataloader, dict):
# treat as kwargs for DataSpec
dataloader = DataSpec(**dataloader)
if not isinstance(dataloader, DataSpec):
dataloader = DataSpec(dataloader)

return dataloader
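
A short usage sketch of the `ensure_data_spec` helper added above. The synthetic dataset is illustrative, and `num_samples` is assumed to be a `DataSpec` constructor argument, as the diff suggests; behavior otherwise follows the function body shown.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

from composer.core import DataSpec, ensure_data_spec

dataset = TensorDataset(torch.randn(32, 4), torch.randint(0, 2, (32,)))
loader = DataLoader(dataset, batch_size=8)

# A plain dataloader is wrapped in a DataSpec.
spec = ensure_data_spec(loader)
assert isinstance(spec, DataSpec)

# A dict is treated as keyword arguments for the DataSpec constructor.
spec_from_kwargs = ensure_data_spec({"dataloader": loader, "num_samples": 32})
assert spec_from_kwargs.num_samples == 32

# An existing DataSpec passes through unchanged.
assert ensure_data_spec(spec) is spec
```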
54 changes: 41 additions & 13 deletions composer/core/evaluator.py
@@ -5,16 +5,16 @@
from __future__ import annotations

import copy
from typing import Callable, Iterable, Optional, Union
from typing import Any, Callable, Dict, Iterable, Optional, Union

from torchmetrics import Metric, MetricCollection

from composer.core.data_spec import DataSpec
from composer.core.data_spec import DataSpec, ensure_data_spec
from composer.core.event import Event
from composer.core.state import State
from composer.core.time import Time, TimeUnit

__all__ = ["Evaluator", "evaluate_periodically"]
__all__ = ["Evaluator", "evaluate_periodically", "ensure_evaluator"]


def evaluate_periodically(eval_interval: Union[str, Time, int]):
@@ -72,8 +72,8 @@ class Evaluator:
Args:
label (str): Name of the Evaluator
dataloader (Union[DataSpec, Iterable]): Iterable that yields batches or a :class:`.DataSpec` for evaluation
data.
dataloader (DataSpec | Iterable | Dict[str, Any]): Iterable that yields batches, a :class:`.DataSpec` for evaluation,
or a Dict of :class:`.DataSpec` kwargs.
metrics (Metric | MetricCollection): :class:`torchmetrics.Metric` to log. ``metrics`` will be deep-copied to
ensure that each evaluator updates only its ``metrics``.
subset_num_batches (int, optional): The maximum number of batches to use for each evaluation. Defaults to
@@ -97,20 +97,19 @@ class Evaluator:
or :attr:`.Event.EPOCH_END`.
"""

_eval_interval: Optional[Callable[[State, Event], bool]]

def __init__(
self,
*,
label: str,
dataloader: Union[DataSpec, Iterable],
dataloader: Union[DataSpec, Iterable, Dict[str, Any]],
metrics: Union[Metric, MetricCollection],
subset_num_batches: Optional[int] = None,
eval_interval: Optional[Union[int, str, Time, Callable[[State, Event], bool]]] = None,
):
self.label = label
if isinstance(dataloader, DataSpec):
self.dataloader = dataloader
else:
self.dataloader = DataSpec(dataloader)
self.dataloader = ensure_data_spec(dataloader)

# Forcing metrics to be a MetricCollection simplifies logging results
metrics = copy.deepcopy(metrics)
@@ -120,10 +119,39 @@ def __init__(
self.metrics = metrics

self.subset_num_batches = subset_num_batches
self.eval_interval = eval_interval

@property
def eval_interval(self):
return self._eval_interval

@eval_interval.setter
def eval_interval(self, eval_interval: Optional[Union[int, str, Time, Callable[[State, Event], bool]]]):
if eval_interval is None:
self.should_eval = None
self._eval_interval = None
elif not callable(eval_interval):
self.should_eval = evaluate_periodically(eval_interval)
self._eval_interval = evaluate_periodically(eval_interval)
else:
self.should_eval = eval_interval
self._eval_interval = eval_interval


def ensure_evaluator(evaluator: Union[Evaluator, DataSpec, Iterable, Dict[str, Any]],
default_metrics: Union[Metric, MetricCollection]):
"""Ensure that ``evaluator`` is an :class:`.Evaluator`.
Args:
evaluator (Evaluator | DataSpec | Iterable | Dict[str, Any]): A dataloader,
:class:`.DataSpec` instance, dictionary of :class:`.DataSpec` kwargs, or existing evaluator.
default_metrics (Union[Metric, MetricCollection]): The metrics for the ``evaluator``, if a dataloader was specified.
Returns:
Evaluator: An evaluator.
"""
if isinstance(evaluator, Evaluator):
return evaluator
else:
return Evaluator(
label="eval",
dataloader=evaluator,
metrics=default_metrics,
)
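
A brief usage sketch of `ensure_evaluator` and the new `eval_interval` property, based on the code shown in the diff. The dataset and the `torchmetrics.Accuracy` metric are illustrative choices (newer torchmetrics releases may require a `task` argument for `Accuracy`).

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
from torchmetrics import Accuracy

from composer.core import Evaluator, ensure_evaluator

dataset = TensorDataset(torch.randn(32, 4), torch.randint(0, 2, (32,)))
loader = DataLoader(dataset, batch_size=8)

# A plain dataloader becomes an Evaluator labeled "eval" using the default
# metrics; an existing Evaluator is returned unchanged.
evaluator = ensure_evaluator(loader, default_metrics=Accuracy())
assert isinstance(evaluator, Evaluator)
assert ensure_evaluator(evaluator, default_metrics=Accuracy()) is evaluator

# Assigning eval_interval installs a matching should_eval callable; None clears it.
evaluator.eval_interval = "2ep"
assert callable(evaluator.should_eval)
evaluator.eval_interval = None
assert evaluator.should_eval is None
```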
2 changes: 1 addition & 1 deletion composer/datasets/dataloader.py
@@ -64,7 +64,7 @@ class DataLoaderHparams(hp.Hparams):
If ``num_workers = 0``, then the ``pin_memory`` must be ``False``."""),
default=True)
timeout: float = hp.optional(
"Timeout, in seconds, for collecting a batch from workers. Set to ``0`` for no timeout.", default=0)
"Timeout, in seconds, for collecting a batch from workers. Set to ``0`` for no timeout.", default=0.0)

def initialize_object(
self,
2 changes: 1 addition & 1 deletion composer/models/ssd/hparams.yaml
@@ -34,7 +34,7 @@ schedulers:
train_batch_size: 1024
eval_batch_size: 1024
seed: 0
validate_every_n_epochs: 10
eval_interval: 10ep
grad_accum: 1
device:
gpu: {}
58 changes: 31 additions & 27 deletions composer/trainer/_deepspeed.py
@@ -4,7 +4,7 @@

import copy
import warnings
from typing import Any, Dict, Optional, cast
from typing import Any, Dict, cast

import torch
import torch.utils.data
@@ -18,7 +18,9 @@


def _add_batch_config(config: Dict[str, Any], state: State):
assert state.dataloader is not None, "dataloader should be set on FIT_START, which is where the Deepspeed config is applied."
if state.dataloader is None:
raise ValueError(
"When using DeepSpeed, the `train_dataloader` must be specified when constructing the Trainer.")

grad_accum = state.grad_accum

@@ -55,20 +57,20 @@ def _add_batch_config(config: Dict[str, Any], state: State):
if "gradient_accumulation_steps" in config:
ds_grad_accum = config["gradient_accumulation_steps"]
if ds_grad_accum != grad_accum:
raise ValueError(f"Provided DeepSpeed configuration specifies grad accum={ds_grad_accum}, "
f"but the Mosaic trainer has been configured with grad accum={grad_accum}.")
raise ValueError((f"Provided DeepSpeed configuration specifies grad accum={ds_grad_accum}, "
f"but the Mosaic trainer has been configured with grad accum={grad_accum}."))

config["gradient_accumulation_steps"] = grad_accum


def _ensure_no_optim_in_config(config: Dict[str, Any]):
if "optimizer" in config:
raise ValueError("The DeepSpeed configuration specifies an optimizer, but the Mosaic "
"trainer will override this setting.")
raise ValueError(("The DeepSpeed configuration specifies an optimizer, but the Mosaic "
"trainer will override this setting."))

if "scheduler" in config:
raise ValueError("The DeepSpeed configuration specifies a scheduler, but the Mosaic "
"trainer will override this setting.")
raise ValueError(("The DeepSpeed configuration specifies a scheduler, but the Mosaic "
"trainer will override this setting."))


def _add_precision_config(config: Dict[str, Any], state: State):
@@ -78,15 +80,15 @@ def _add_precision_config(config: Dict[str, Any], state: State):
if "fp16" in config and "enabled" in config["fp16"] and config["fp16"]["enabled"]:
ds_precision = Precision.FP16
if "bf16" in config and "enabled" in config["bf16"] and config["bf16"]["enabled"]:
raise ValueError("DeepSpeed is configured to use BFLOAT16, but this is unsupported by the "
"Mosaic trainer.")
raise ValueError(("DeepSpeed is configured to use BFLOAT16, but this is unsupported by the "
"Mosaic trainer."))
if "amp" in config and "enabled" in config["amp"] and config["amp"]["enabled"]:
raise ValueError("DeepSpeed is configured to use Apex AMP, but this is unsupported by the "
"Mosaic trainer.")
raise ValueError(("DeepSpeed is configured to use Apex AMP, but this is unsupported by the "
"Mosaic trainer."))

if ds_precision is not None and ds_precision != precision:
raise ValueError(f"Provided DeepSpeed configuration specifies precision={ds_precision}, "
f"but the Mosaic trainer has been configured with precision={precision}.")
raise ValueError((f"Provided DeepSpeed configuration specifies precision={ds_precision}, "
f"but the Mosaic trainer has been configured with precision={precision}."))

if precision == Precision.FP16:
if "fp16" not in config:
@@ -99,28 +101,30 @@ def _add_precision_config(config: Dict[str, Any], state: State):
fp16_config.setdefault("loss_scale_window", 2000)


def _add_other_config(config: Dict[str, Any], grad_clip_norm: Optional[float]):
def _add_other_config(config: Dict[str, Any], grad_clip_norm: float):
if "gradient_clipping" in config:
ds_grad_clip_norm = config["gradient_clipping"]
if ds_grad_clip_norm != grad_clip_norm:
raise ValueError("Provided DeepSpeed configuration specifies grad clip norm="
f"{ds_grad_clip_norm}, but the Mosaic trainer has been configured "
f"with grad clip norm={grad_clip_norm}")
raise ValueError(("Provided DeepSpeed configuration specifies grad clip norm="
f"{ds_grad_clip_norm}, but the Mosaic trainer has been configured "
f"with grad clip norm={grad_clip_norm}"))

if grad_clip_norm is not None:
if grad_clip_norm >= 0:
config["gradient_clipping"] = grad_clip_norm

if "zero_allow_untested_optimizer" in config and not config["zero_allow_untested_optimizer"]:
warnings.warn("Provided DeepSpeed configuration specifies zero_allow_untested_optimizer=False. "
"This causes DeepSpeed to reject certain Mosaic optimizers that are known to "
"work well with DeepSpeed.")
warnings.warn(("Provided DeepSpeed configuration specifies zero_allow_untested_optimizer=False. "
"This causes DeepSpeed to reject certain Mosaic optimizers that are known to "
"work well with DeepSpeed."))

config["zero_allow_untested_optimizer"] = True


def _parse_deepspeed_config(config: Dict[str, Any],
state: State,
grad_clip_norm: Optional[float] = None) -> Dict[str, Any]:
def _parse_deepspeed_config(
config: Dict[str, Any],
state: State,
grad_clip_norm: float,
) -> Dict[str, Any]:
"""Parses the provided DeepSpeed config for compatibility with the Mosaic trainer.
Broadly speaking, this function does three things.
Expand All @@ -135,8 +139,8 @@ def _parse_deepspeed_config(config: Dict[str, Any],
config (Dict[str, Any]): The DeepSpeed config to use. Must follow the format specified
in `DeepSpeed's documentation <https://www.deepspeed.ai/docs/config-json/>`_.
state (State): The state of the trainer.
grad_clip_norm (Optional[float]): The norm to clip gradient magnitudes to.
``None`` results in no gradient clipping. (default: ``None``)
grad_clip_norm (float, optional): The norm to clip gradient magnitudes to. Set to ``-1``
for no gradient clipping. (default: ``-1.0``)
Returns:
Dict[str, Any]: The DeepSpeed config updated with values from the arguments passed to the
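
As a standalone illustration of the gradient-clipping convention change above (``-1`` now means "no clipping", replacing ``None``), the sketch below re-implements just that merge rule with plain dicts. It is not the Trainer's actual code path, only a mirror of the checks in `_add_other_config`.

```python
from typing import Any, Dict


def merge_grad_clipping(config: Dict[str, Any], grad_clip_norm: float) -> Dict[str, Any]:
    # Error on a conflicting value already present in the DeepSpeed config,
    # and treat a negative grad_clip_norm as "no clipping requested".
    if "gradient_clipping" in config and config["gradient_clipping"] != grad_clip_norm:
        raise ValueError("DeepSpeed config and trainer disagree on the grad clip norm.")
    if grad_clip_norm >= 0:
        config["gradient_clipping"] = grad_clip_norm
    return config


print(merge_grad_clipping({}, 1.0))   # {'gradient_clipping': 1.0}
print(merge_grad_clipping({}, -1.0))  # {} -- clipping disabled
```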