Multiple calls to Trainer.fit() #948

Merged · 123 commits merged on May 10, 2022
Changes from all commits (123 commits)
bd3e847
[Eval-Only]: Made the `state.dataloader` optional; removed `state.ste…
ravi-mosaicml Mar 25, 2022
558f279
Restored `dataloader_len` on state
ravi-mosaicml Apr 12, 2022
f5d0a1b
Fixed tests
ravi-mosaicml Apr 12, 2022
13c6e53
Merge branch 'dev' into i40_1
ravi-mosaicml Apr 12, 2022
e4facaa
Added `dataloader_label`; removed `evaluators` from State
ravi-mosaicml Apr 12, 2022
f3f47ea
Merge branch 'dev' into i40_1
ravi-mosaicml Apr 12, 2022
56fd87d
Fixed pyright
ravi-mosaicml Apr 12, 2022
b89c3bc
Fixed pyright
ravi-mosaicml Apr 12, 2022
79a3094
Merge branch 'dev' into ravi/optional_dataloader
ravi-mosaicml Apr 13, 2022
096b44c
Made `max_duration` optional
ravi-mosaicml Apr 13, 2022
d102506
Merge branch 'ravi/optional_dataloader' of github.com:mosaicml/compos…
ravi-mosaicml Apr 13, 2022
fe70133
Addressed PR feedback; fixed Time type annotations
ravi-mosaicml Apr 14, 2022
7e6e7cf
Merge branch 'dev' into ravi/optional_dataloader
ravi-mosaicml Apr 14, 2022
11544bc
Fixed doctests
ravi-mosaicml Apr 14, 2022
1313a49
Fixed selective backprop
ravi-mosaicml Apr 14, 2022
b5b6192
Increased timeout
ravi-mosaicml Apr 14, 2022
ecc4a59
Merge branch 'dev' into ravi/optional_dataloader
ravi-mosaicml Apr 14, 2022
d3408ba
Remove optimizers from state on init; clean up PR
ravi-mosaicml Apr 15, 2022
1470d11
Bind the schedulers to the state in `__init__()`, rather than on `fit…
ravi-mosaicml Apr 15, 2022
7f00806
Merge branch 'dev' into ravi/optional_dataloader
ravi-mosaicml Apr 15, 2022
a4b2697
Merge branch 'dev' into ravi/optional_dataloader
ravi-mosaicml Apr 20, 2022
54d6b88
Fixed the deepspeed schedulers
ravi-mosaicml Apr 20, 2022
f067ed7
Merge branch 'dev' into ravi/optional_dataloader
ravi-mosaicml Apr 21, 2022
ebb9fe5
Multiple calls to fit/eval WIP
ravi-mosaicml Apr 21, 2022
3237c9d
Merge branch 'dev' into trainer_fit_eval_signature
ravi-mosaicml Apr 21, 2022
af7bc21
WIP
ravi-mosaicml Apr 21, 2022
59af154
Merge branch 'dev' into ravi/optional_dataloader
ravi-mosaicml Apr 21, 2022
49e68e8
* Addressed PR Feedback
ravi-mosaicml Apr 21, 2022
93dd36f
Merge branch 'ravi/optional_dataloader' of github.com:mosaicml/compos…
ravi-mosaicml Apr 21, 2022
7d3a073
Merge branch 'ravi/optional_dataloader' into trainer_fit_eval_signature
ravi-mosaicml Apr 21, 2022
a746b9c
WIP
ravi-mosaicml Apr 21, 2022
da5bd40
Fixing the `dataloader_len` setter
ravi-mosaicml Apr 21, 2022
4d34efd
Merge branch 'ravi/optional_dataloader' into trainer_fit_eval_signature
ravi-mosaicml Apr 21, 2022
0b16878
Added docs
ravi-mosaicml Apr 21, 2022
42ab447
Fix tests
ravi-mosaicml Apr 22, 2022
8cac6be
Merge branch 'dev' into ravi/optional_dataloader
ravi-mosaicml Apr 22, 2022
f65e904
Merge branch 'ravi/optional_dataloader' into trainer_fit_eval_signature
ravi-mosaicml Apr 22, 2022
3816438
Fixed style
ravi-mosaicml Apr 22, 2022
ff4f400
Fix tests
ravi-mosaicml Apr 22, 2022
ea25fb5
Merge branch 'dev' into trainer_fit_eval_signature
ravi-mosaicml Apr 22, 2022
4b2c09e
Added `Timer.reset`
ravi-mosaicml Apr 22, 2022
c7b165d
Multiple Evaluator Improvements
ravi-mosaicml Apr 22, 2022
a2f4896
Merge branch 'multi_eval_improvements' into trainer_fit_eval_signature
ravi-mosaicml Apr 22, 2022
643c8b7
Fix merge
ravi-mosaicml Apr 25, 2022
e436ea7
Merge branch 'dev' into multi_eval_improvements
ravi-mosaicml Apr 25, 2022
22a136f
Fixed yaml
ravi-mosaicml Apr 25, 2022
964dd2d
Fixed doctests
ravi-mosaicml Apr 25, 2022
50ae29f
Test Cleanup
ravi-mosaicml Apr 25, 2022
5471ef8
Refactor
ravi-mosaicml Apr 25, 2022
a007ddd
Merge test-cleanup
ravi-mosaicml Apr 25, 2022
45b45d5
Added eval only test cases
ravi-mosaicml Apr 25, 2022
5419804
Added `eval_subset_num_batches` and `eval_interval` tests.
ravi-mosaicml Apr 26, 2022
6f6c92f
Fix `test_memory_monitor`
ravi-mosaicml Apr 26, 2022
677389e
Merge branch 'test-cleanup' into multi_eval_improvements
ravi-mosaicml Apr 26, 2022
42e5c8a
Merge branch 'multi_eval_improvements' into trainer_fit_eval_signature
ravi-mosaicml Apr 26, 2022
14e5e6a
Added trainer fit or eval tests
ravi-mosaicml Apr 26, 2022
46cf48e
More cleanup
ravi-mosaicml Apr 26, 2022
231fe5e
Fixed docs
ravi-mosaicml Apr 26, 2022
065186a
Merge branch 'dev' into test-cleanup
ravi-mosaicml Apr 26, 2022
45befa9
Rename `TestMetricsCallback` to `MetricsCallback`
ravi-mosaicml Apr 26, 2022
0c34f36
Merge branch 'test-cleanup' of github.com:ravi-mosaicml/ravi-composer…
ravi-mosaicml Apr 26, 2022
b3564f1
Merge branch 'test-cleanup' into multi_eval_improvements
ravi-mosaicml Apr 26, 2022
1acfa4e
Merge branch 'multi_eval_improvements' into trainer_fit_eval_signature
ravi-mosaicml Apr 26, 2022
f3d57a7
Fixed docstrings
ravi-mosaicml Apr 26, 2022
eb3331c
Merge branch 'dev' into multi_eval_improvements
ravi-mosaicml Apr 26, 2022
9c4a7d9
Merge branch 'multi_eval_improvements' into trainer_fit_eval_signature
ravi-mosaicml Apr 26, 2022
a794b54
Fixed `make style`
ravi-mosaicml Apr 26, 2022
65e0e75
Trigger Jenkins
ravi-mosaicml Apr 26, 2022
486d2da
Merge branch 'dev' into trainer_fit_eval_signature
ravi-mosaicml Apr 26, 2022
3ce4012
Merge branch 'dev' into multi_eval_improvements
ravi-mosaicml Apr 27, 2022
cc0b0fa
Merge branch 'multi_eval_improvements' into trainer_fit_eval_signature
ravi-mosaicml Apr 27, 2022
58b905f
Merge branch 'trainer_fit_eval_signature' of github.com:ravi-mosaicml…
ravi-mosaicml Apr 27, 2022
b2a5ca6
Require that optimizers are specified on `Trainer.__init__()`
ravi-mosaicml Apr 27, 2022
dab3e92
Synchronized the defaults between the hparams class and trainer class
ravi-mosaicml Apr 28, 2022
3212408
Merge branch 'dev' into multi_eval_improvements
ravi-mosaicml Apr 28, 2022
4173204
Merge branch 'dev' into multi_eval_improvements
ravi-mosaicml Apr 28, 2022
a5553a6
Merge branch 'multi_eval_improvements' of github.com:ravi-mosaicml/ra…
ravi-mosaicml Apr 28, 2022
19c9b5d
Fixed docstrings
ravi-mosaicml Apr 28, 2022
9d01e57
Fixed checkpoint test merge
ravi-mosaicml Apr 28, 2022
ba0c6a8
Merge branch 'multi_eval_improvements' into trainer_fit_eval_signature
ravi-mosaicml Apr 28, 2022
663be21
Fix a bad merge
ravi-mosaicml Apr 28, 2022
134d4b3
Merge branch 'dev' into trainer_fit_eval_signature
ravi-mosaicml Apr 28, 2022
f2010ec
Merge branch 'dev' into trainer_fit_eval_signature
ravi-mosaicml Apr 29, 2022
6d17ff7
Cleanup
ravi-mosaicml Apr 29, 2022
762be8d
Added tests for multiple calls to fit
ravi-mosaicml Apr 29, 2022
c72dab5
Merge branch 'dev' into trainer_fit_eval_signature
ravi-mosaicml Apr 29, 2022
87b5467
Fix FP16 Precision with DeepSpeed
ravi-mosaicml May 2, 2022
afa8db4
Fix tests
ravi-mosaicml May 2, 2022
ec3e662
Address PR feedback
ravi-mosaicml May 2, 2022
b87de7f
Test DeepSpeed on all precisions
ravi-mosaicml May 2, 2022
a3bedff
Testing all precisions
ravi-mosaicml May 2, 2022
7f92f36
Fix tests
ravi-mosaicml May 3, 2022
70304fd
Merge branch 'fix_deepspeed' into trainer_fit_eval_signature
ravi-mosaicml May 3, 2022
732510b
Addressed some PR feedback
ravi-mosaicml May 3, 2022
18b83dc
More changes
ravi-mosaicml May 3, 2022
0c4727d
More refactoring
ravi-mosaicml May 3, 2022
5b0e683
Merge branch 'dev' into trainer_fit_eval_signature
ravi-mosaicml May 3, 2022
0965cb1
Fix tests
ravi-mosaicml May 3, 2022
2662dc3
Specify the train dataloader in init when using deepspeed
ravi-mosaicml May 3, 2022
6b1cfea
Fix the precision
ravi-mosaicml May 3, 2022
977a8dd
Fixed docs build
ravi-mosaicml May 3, 2022
e581acc
Merge branch 'dev' into trainer_fit_eval_signature
ravi-mosaicml May 3, 2022
a5ad524
Addressing PR feedback
ravi-mosaicml May 3, 2022
7bfba07
Fixed duration in `Trainer.fit`
ravi-mosaicml May 3, 2022
49b35f0
Addressed more PR feedback
ravi-mosaicml May 3, 2022
21f6f44
Merge branch 'dev' into trainer_fit_eval_signature
ravi-mosaicml May 3, 2022
71c5f65
Refactored grad_accum
ravi-mosaicml May 3, 2022
9bbc212
Refactored default scheduler frequency function
ravi-mosaicml May 3, 2022
faea817
Merge branch 'dev' into trainer_fit_eval_signature
ravi-mosaicml May 3, 2022
06530a9
Fixed tests
ravi-mosaicml May 3, 2022
8d6c0c3
Merge branch 'dev' into trainer_fit_eval_signature
ravi-mosaicml May 4, 2022
c56179d
Fixed tests; addressed some PR feedback
ravi-mosaicml May 4, 2022
db3a7e3
Addressed more PR feedback
ravi-mosaicml May 4, 2022
8d1e908
Added docs
ravi-mosaicml May 4, 2022
87152b7
Fixed docs
ravi-mosaicml May 4, 2022
a578676
Fixed a bug with timing. Added tests
ravi-mosaicml May 4, 2022
e17bdca
Merge branch 'dev' into trainer_fit_eval_signature
ravi-mosaicml May 5, 2022
8a94b44
Merge branch 'dev' into trainer_fit_eval_signature
ravi-mosaicml May 6, 2022
a8dc9be
Merge branch 'dev' into trainer_fit_eval_signature
ravi-mosaicml May 6, 2022
64b8a43
Merge branch 'dev' into trainer_fit_eval_signature
ravi-mosaicml May 7, 2022
96fdb93
Merge branch 'dev' into trainer_fit_eval_signature
ravi-mosaicml May 9, 2022
94e0476
Merge branch 'dev' into trainer_fit_eval_signature
ravi-mosaicml May 9, 2022
f5b518f
Addressed PR feedback
ravi-mosaicml May 9, 2022
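
Taken together, these commits let a single Trainer be constructed once and then trained, resumed, and evaluated across several calls. The snippet below is only an illustrative sketch of that workflow: the `train_dataloader` and `duration` arguments to `fit()` are inferred from the commit messages above rather than confirmed against the final API, and `model`, `train_dl`, and `finetune_dl` are hypothetical placeholders assumed to be defined elsewhere.

from composer import Trainer

# `model`, `train_dl`, and `finetune_dl` are placeholders (a ComposerModel and two dataloaders).
trainer = Trainer(model=model, train_dataloader=train_dl, max_duration="2ep")
trainer.fit()  # first training run: two epochs

# Call fit() again to continue training, optionally with a new dataloader and extra duration.
# (These fit() argument names are assumptions based on this PR's commits, not a documented contract.)
trainer.fit(train_dataloader=finetune_dl, duration="1ep")
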
39 changes: 27 additions & 12 deletions composer/core/__init__.py
@@ -7,15 +7,30 @@
:class:`~.logger.Logger` and :class:`~.time.Timestamp` are implemented under core.
"""

from composer.core.algorithm import Algorithm as Algorithm
from composer.core.callback import Callback as Callback
from composer.core.data_spec import DataSpec as DataSpec
from composer.core.engine import Engine as Engine
from composer.core.engine import Trace as Trace
from composer.core.evaluator import Evaluator as Evaluator
from composer.core.event import Event as Event
from composer.core.precision import Precision as Precision
from composer.core.state import State as State
from composer.core.time import Time as Time
from composer.core.time import Timestamp as Timestamp
from composer.core.time import TimeUnit as TimeUnit
from composer.core.algorithm import Algorithm
from composer.core.callback import Callback
from composer.core.data_spec import DataSpec, ensure_data_spec
from composer.core.engine import Engine, Trace
from composer.core.evaluator import Evaluator, ensure_evaluator
from composer.core.event import Event
from composer.core.precision import Precision
from composer.core.state import State
from composer.core.time import Time, Timestamp, TimeUnit, ensure_time

__all__ = [
"Algorithm",
"Callback",
"DataSpec",
"ensure_data_spec",
"Engine",
"Trace",
"Evaluator",
"Event",
"Precision",
"State",
"Time",
"Timestamp",
"TimeUnit",
"ensure_time",
"ensure_evaluator",
]
37 changes: 32 additions & 5 deletions composer/core/data_spec.py
@@ -5,7 +5,7 @@

import collections.abc
import textwrap
from typing import TYPE_CHECKING, Any, Callable, Iterable, List, Mapping, Optional, Sequence, Tuple
from typing import TYPE_CHECKING, Any, Callable, Iterable, List, Mapping, Optional, Sequence, Tuple, Union

import torch
import torch.utils.data
@@ -15,10 +15,10 @@
if TYPE_CHECKING:
from composer.core.types import Batch

__all__ = ["DataSpec"]
__all__ = ["DataSpec", "ensure_data_spec"]


def _split_list(l, num_microbatches):
def _split_list(l, num_microbatches: int):
if len(l) < num_microbatches:
raise ValueError(
textwrap.dedent(f"""\
@@ -27,7 +27,7 @@ def _split_list(l, num_microbatches):
return [l[i::num_microbatches] for i in range(num_microbatches)]


def _split_tensor(t, num_microbatches):
def _split_tensor(t, num_microbatches: int):
if len(t) < num_microbatches:
raise ValueError(
textwrap.dedent(f"""\
@@ -36,7 +36,7 @@ def _split_tensor(t, num_microbatches):
return t.chunk(num_microbatches)


def _split_mapping(m, num_microbatches):
def _split_mapping(m, num_microbatches: int):
chunked = {}
for k, v in m.items():
if isinstance(v, torch.Tensor):
@@ -176,6 +176,15 @@ def __init__(
else:
self.num_samples = None

if isinstance(dataloader, torch.utils.data.DataLoader) and dataloader._iterator is not None:
raise ValueError(
("The dataloader has an active iterator. This could occur "
"if `persistent_workers=True` and the dataloader has already been iterated, "
"or if the dataloader is mid-epoch. It is required that the training dataloader "
"does not have an active iterator, so CPU dataset augmentations can be "
"correctly inserted. To fix, please do not iterate over the dataloader before passing it into "
"the Trainer."))

def _default_device_transforms(self, batch: Batch):
return batch

@@ -203,3 +212,21 @@ def _default_get_num_samples_in_batch(self, batch: Batch) -> int:
def _default_get_num_tokens_in_batch(self, batch: Batch) -> int:
del batch # unused
return 0


def ensure_data_spec(dataloader: Union[DataSpec, Iterable, dict]) -> DataSpec:
"""Ensures that the ``dataloader`` is a :class:`.DataSpec`

Args:
dataloader (DataSpec | Iterable | dict): A DataSpec, DataLoader, or Dict of DataSpec kwargs.

Returns:
DataSpec: A DataSpec
"""
if isinstance(dataloader, dict):
# treat as kwargs for DataSpec
dataloader = DataSpec(**dataloader)
if not isinstance(dataloader, DataSpec):
dataloader = DataSpec(dataloader)

return dataloader
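
As a quick, hedged illustration of the new helper (assuming composer and torch are importable; the TensorDataset below is only a throwaway stand-in dataset):

import torch
from torch.utils.data import DataLoader, TensorDataset
from composer.core import DataSpec, ensure_data_spec

# A small dataloader used only to exercise ensure_data_spec.
loader = DataLoader(TensorDataset(torch.randn(8, 2), torch.randn(8, 1)), batch_size=4)

spec_from_loader = ensure_data_spec(loader)                  # plain iterable -> wrapped in a DataSpec
spec_from_kwargs = ensure_data_spec({"dataloader": loader})  # dict treated as DataSpec kwargs
assert isinstance(spec_from_loader, DataSpec) and isinstance(spec_from_kwargs, DataSpec)
assert ensure_data_spec(spec_from_loader) is spec_from_loader  # an existing DataSpec passes through unchanged
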
54 changes: 41 additions & 13 deletions composer/core/evaluator.py
@@ -5,16 +5,16 @@
from __future__ import annotations

import copy
from typing import Callable, Iterable, Optional, Union
from typing import Any, Callable, Dict, Iterable, Optional, Union

from torchmetrics import Metric, MetricCollection

from composer.core.data_spec import DataSpec
from composer.core.data_spec import DataSpec, ensure_data_spec
from composer.core.event import Event
from composer.core.state import State
from composer.core.time import Time, TimeUnit

__all__ = ["Evaluator", "evaluate_periodically"]
__all__ = ["Evaluator", "evaluate_periodically", "ensure_evaluator"]


def evaluate_periodically(eval_interval: Union[str, Time, int]):
@@ -72,8 +72,8 @@ class Evaluator:

Args:
label (str): Name of the Evaluator
dataloader (Union[DataSpec, Iterable]): Iterable that yields batches or a :class:`.DataSpec` for evaluation
data.
dataloader (DataSpec | Iterable | Dict[str, Any]): Iterable that yields batches, a :class:`.DataSpec` for evaluation,
or a Dict of :class:`.DataSpec` kwargs.
metrics (Metric | MetricCollection): :class:`torchmetrics.Metric` to log. ``metrics`` will be deep-copied to
ensure that each evaluator updates only its ``metrics``.
subset_num_batches (int, optional): The maximum number of batches to use for each evaluation. Defaults to
@@ -97,20 +97,19 @@ class Evaluator:
or :attr:`.Event.EPOCH_END`.
"""

_eval_interval: Optional[Callable[[State, Event], bool]]

def __init__(
self,
*,
label: str,
dataloader: Union[DataSpec, Iterable],
dataloader: Union[DataSpec, Iterable, Dict[str, Any]],
metrics: Union[Metric, MetricCollection],
subset_num_batches: Optional[int] = None,
eval_interval: Optional[Union[int, str, Time, Callable[[State, Event], bool]]] = None,
):
self.label = label
if isinstance(dataloader, DataSpec):
self.dataloader = dataloader
else:
self.dataloader = DataSpec(dataloader)
self.dataloader = ensure_data_spec(dataloader)

# Forcing metrics to be a MetricCollection simplifies logging results
metrics = copy.deepcopy(metrics)
@@ -120,10 +119,39 @@ def __init__(
self.metrics = metrics

self.subset_num_batches = subset_num_batches
self.eval_interval = eval_interval

@property
def eval_interval(self):
return self._eval_interval

@eval_interval.setter
def eval_interval(self, eval_interval: Optional[Union[int, str, Time, Callable[[State, Event], bool]]]):
if eval_interval is None:
self.should_eval = None
self._eval_interval = None
elif not callable(eval_interval):
self.should_eval = evaluate_periodically(eval_interval)
self._eval_interval = evaluate_periodically(eval_interval)
else:
self.should_eval = eval_interval
self._eval_interval = eval_interval


def ensure_evaluator(evaluator: Union[Evaluator, DataSpec, Iterable, Dict[str, Any]],
default_metrics: Union[Metric, MetricCollection]):
"""Ensure that ``evaluator`` is an :class:`.Evaluator`.

Args:
evaluator (Evaluator | DataSpec | Iterable | Dict[str, Any]): A dataloader,
:class:`.DataSpec` instance, dictionary of :class:`.DataSpec` kwargs, or existing evaluator.
default_metrics (Union[Metric, MetricCollection]): The metrics for the ``evaluator``, if a dataloader was specified.

Returns:
Evaluator: An evaluator.
"""
if isinstance(evaluator, Evaluator):
return evaluator
else:
return Evaluator(
label="eval",
dataloader=evaluator,
metrics=default_metrics,
)
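
A minimal sketch of how `ensure_evaluator` and the new `eval_interval` setter fit together (assuming composer, torch, and torchmetrics are installed; newer torchmetrics releases may additionally require a `task` argument to `Accuracy`):

import torch
from torch.utils.data import DataLoader, TensorDataset
from torchmetrics import Accuracy
from composer.core import Evaluator, ensure_evaluator

eval_loader = DataLoader(TensorDataset(torch.randn(8, 2), torch.randint(0, 2, (8,))), batch_size=4)

# A bare dataloader is wrapped into an Evaluator labeled "eval" using the provided default metrics.
wrapped = ensure_evaluator(eval_loader, default_metrics=Accuracy())

# An existing Evaluator is returned as-is, keeping its own label, metrics, and eval_interval.
explicit = Evaluator(label="val", dataloader=eval_loader, metrics=Accuracy(), eval_interval="1ep")
assert ensure_evaluator(explicit, default_metrics=Accuracy()) is explicit
assert callable(explicit.should_eval)  # populated by the eval_interval setter via evaluate_periodically
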
2 changes: 1 addition & 1 deletion composer/datasets/dataloader.py
@@ -64,7 +64,7 @@ class DataLoaderHparams(hp.Hparams):
If ``num_workers = 0``, then the ``pin_memory`` must be ``False``."""),
default=True)
timeout: float = hp.optional(
"Timeout, in seconds, for collecting a batch from workers. Set to ``0`` for no timeout.", default=0)
"Timeout, in seconds, for collecting a batch from workers. Set to ``0`` for no timeout.", default=0.0)

def initialize_object(
self,
2 changes: 1 addition & 1 deletion composer/models/ssd/hparams.yaml
@@ -34,7 +34,7 @@ schedulers:
train_batch_size: 1024
eval_batch_size: 1024
seed: 0
validate_every_n_epochs: 10
eval_interval: 10ep
grad_accum: 1
device:
gpu: {}
58 changes: 31 additions & 27 deletions composer/trainer/_deepspeed.py
@@ -4,7 +4,7 @@

import copy
import warnings
from typing import Any, Dict, Optional, cast
from typing import Any, Dict, cast

import torch
import torch.utils.data
@@ -18,7 +18,9 @@


def _add_batch_config(config: Dict[str, Any], state: State):
assert state.dataloader is not None, "dataloader should be set on FIT_START, which is where the Deepspeed config is applied."
if state.dataloader is None:
raise ValueError(
"When using DeepSpeed, the `train_dataloader` must be specified when constructing the Trainer.")

grad_accum = state.grad_accum

@@ -55,20 +57,20 @@ def _add_batch_config(config: Dict[str, Any], state: State):
if "gradient_accumulation_steps" in config:
ds_grad_accum = config["gradient_accumulation_steps"]
if ds_grad_accum != grad_accum:
raise ValueError(f"Provided DeepSpeed configuration specifies grad accum={ds_grad_accum}, "
f"but the Mosaic trainer has been configured with grad accum={grad_accum}.")
raise ValueError((f"Provided DeepSpeed configuration specifies grad accum={ds_grad_accum}, "
f"but the Mosaic trainer has been configured with grad accum={grad_accum}."))

config["gradient_accumulation_steps"] = grad_accum


def _ensure_no_optim_in_config(config: Dict[str, Any]):
if "optimizer" in config:
raise ValueError("The DeepSpeed configuration specifies an optimizer, but the Mosaic "
"trainer will override this setting.")
raise ValueError(("The DeepSpeed configuration specifies an optimizer, but the Mosaic "
"trainer will override this setting."))

if "scheduler" in config:
raise ValueError("The DeepSpeed configuration specifies a scheduler, but the Mosaic "
"trainer will override this setting.")
raise ValueError(("The DeepSpeed configuration specifies a scheduler, but the Mosaic "
"trainer will override this setting."))


def _add_precision_config(config: Dict[str, Any], state: State):
@@ -78,15 +80,15 @@ def _add_precision_config(config: Dict[str, Any], state: State):
if "fp16" in config and "enabled" in config["fp16"] and config["fp16"]["enabled"]:
ds_precision = Precision.FP16
if "bf16" in config and "enabled" in config["bf16"] and config["bf16"]["enabled"]:
raise ValueError("DeepSpeed is configured to use BFLOAT16, but this is unsupported by the "
"Mosaic trainer.")
raise ValueError(("DeepSpeed is configured to use BFLOAT16, but this is unsupported by the "
"Mosaic trainer."))
if "amp" in config and "enabled" in config["amp"] and config["amp"]["enabled"]:
raise ValueError("DeepSpeed is configured to use Apex AMP, but this is unsupported by the "
"Mosaic trainer.")
raise ValueError(("DeepSpeed is configured to use Apex AMP, but this is unsupported by the "
"Mosaic trainer."))

if ds_precision is not None and ds_precision != precision:
raise ValueError(f"Provided DeepSpeed configuration specifies precision={ds_precision}, "
f"but the Mosaic trainer has been configured with precision={precision}.")
raise ValueError((f"Provided DeepSpeed configuration specifies precision={ds_precision}, "
f"but the Mosaic trainer has been configured with precision={precision}."))

if precision == Precision.FP16:
if "fp16" not in config:
@@ -99,28 +101,30 @@
fp16_config.setdefault("loss_scale_window", 2000)


def _add_other_config(config: Dict[str, Any], grad_clip_norm: Optional[float]):
def _add_other_config(config: Dict[str, Any], grad_clip_norm: float):
if "gradient_clipping" in config:
ds_grad_clip_norm = config["gradient_clipping"]
if ds_grad_clip_norm != grad_clip_norm:
raise ValueError("Provided DeepSpeed configuration specifies grad clip norm="
f"{ds_grad_clip_norm}, but the Mosaic trainer has been configured "
f"with grad clip norm={grad_clip_norm}")
raise ValueError(("Provided DeepSpeed configuration specifies grad clip norm="
f"{ds_grad_clip_norm}, but the Mosaic trainer has been configured "
f"with grad clip norm={grad_clip_norm}"))

if grad_clip_norm is not None:
if grad_clip_norm >= 0:
config["gradient_clipping"] = grad_clip_norm

if "zero_allow_untested_optimizer" in config and not config["zero_allow_untested_optimizer"]:
warnings.warn("Provided DeepSpeed configuration specifies zero_allow_untested_optimizer=False. "
"This causes DeepSpeed to reject certain Mosaic optimizers that are known to "
"work well with DeepSpeed.")
warnings.warn(("Provided DeepSpeed configuration specifies zero_allow_untested_optimizer=False. "
"This causes DeepSpeed to reject certain Mosaic optimizers that are known to "
"work well with DeepSpeed."))

config["zero_allow_untested_optimizer"] = True


def _parse_deepspeed_config(config: Dict[str, Any],
state: State,
grad_clip_norm: Optional[float] = None) -> Dict[str, Any]:
def _parse_deepspeed_config(
config: Dict[str, Any],
state: State,
grad_clip_norm: float,
) -> Dict[str, Any]:
"""Parses the provided DeepSpeed config for compatibility with the Mosaic trainer.

Broadly speaking, this function does three things.
@@ -135,8 +139,8 @@
config (Dict[str, Any]): The DeepSpeed config to use. Must follow the format specified
in `DeepSpeed's documentation <https://www.deepspeed.ai/docs/config-json/>`_.
state (State): The state of the trainer.
grad_clip_norm (Optional[float]): The norm to clip gradient magnitudes to.
``None`` results in no gradient clipping. (default: ``None``)
grad_clip_norm (float, optional): The norm to clip gradient magnitudes to. Set to ``-1``
for no gradient clipping. (default: ``-1.0``)

Returns:
Dict[str, Any]: The DeepSpeed config updated with values from the arguments passed to the
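
Pulling these checks together: a user-supplied DeepSpeed config must now agree with the trainer wherever the two overlap, and may not define its own optimizer or scheduler. Below is a sketch of a config that would pass the validation above, assuming a trainer built with grad_accum=2, FP16 precision, and the documented grad_clip_norm default of -1.0 (no clipping):

# Passes _parse_deepspeed_config's checks under the assumptions stated above.
ds_config = {
    "gradient_accumulation_steps": 2,  # must equal the trainer's grad_accum
    "fp16": {"enabled": True},         # must match the trainer's precision (FP16 here)
    # "gradient_clipping" may only be set if it equals the trainer's grad_clip_norm;
    # "optimizer" and "scheduler" sections are rejected outright -- the Mosaic trainer owns both.
}
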