ExperimentHparams class; Set state.train_dataloader (#966)
- Added an `ExperimentHparams` class. This class describes how to run a training job that may involve multiple calls to `Trainer.fit` and/or `Trainer.eval`. Specifically, `ExperimentHparams.initialize_object()` returns a `(Trainer, List[FitKwargs], List[EvalKwargs])` tuple, which the user's entrypoint can then consume (see the sketch after this list).
  This class does not automatically train the model, nor does it include an entrypoint.
- Added typing definitions for `FitKwargs` and `EvalKwargs`, along with test cases to ensure they stay in sync with the `Trainer` signature.
- Fixed a bug introduced in #948, which removed the setting of `State.train_dataloader`. Added back the lines that correctly set the train dataloader.
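
A minimal sketch of an entrypoint that consumes `ExperimentHparams` (the `create(...)` loader call and the `"experiment.yaml"` config path are assumptions; only the shape of `initialize_object()`'s return value comes from this change):

```python
from composer.trainer import ExperimentHparams

# Hypothetical entrypoint; "experiment.yaml" is an assumed config file.
hparams = ExperimentHparams.create("experiment.yaml")
trainer, fit_kwargs_list, eval_kwargs_list = hparams.initialize_object()

# Nothing trains automatically; the entrypoint drives each call.
for fit_kwargs in fit_kwargs_list:
    trainer.fit(**fit_kwargs)
for eval_kwargs in eval_kwargs_list:
    trainer.eval(**eval_kwargs)
```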
ravi-mosaicml authored May 11, 2022
1 parent 2e941d0 commit e85302b
Showing 6 changed files with 529 additions and 70 deletions.
7 changes: 4 additions & 3 deletions composer/callbacks/__init__.py
```diff
@@ -6,9 +6,9 @@
 Each callback inherits from the :class:`~composer.core.callback.Callback` base class. See detailed description and
 examples for writing your own callbacks at the :class:`~composer.core.callback.Callback` base class.
 """
-from composer.callbacks.callback_hparams import (CallbackHparams, CheckpointSaverHparams, GradMonitorHparams,
-                                                 LRMonitorHparams, MemoryMonitorHparams, MLPerfCallbackHparams,
-                                                 SpeedMonitorHparams)
+from composer.callbacks.callback_hparams import (CallbackHparams, CheckpointSaverHparams, EarlyStopperHparams,
+                                                 GradMonitorHparams, LRMonitorHparams, MemoryMonitorHparams,
+                                                 MLPerfCallbackHparams, SpeedMonitorHparams)
 from composer.callbacks.checkpoint_saver import CheckpointSaver
 from composer.callbacks.grad_monitor import GradMonitor
 from composer.callbacks.lr_monitor import LRMonitor
@@ -26,6 +26,7 @@
     # hparams objects
     "CallbackHparams",
     "CheckpointSaverHparams",
+    "EarlyStopperHparams",
     "GradMonitorHparams",
     "LRMonitorHparams",
     "MemoryMonitorHparams",
```
4 changes: 2 additions & 2 deletions composer/trainer/__init__.py
```diff
@@ -5,8 +5,8 @@
 
 from composer.trainer import devices as devices
 from composer.trainer.trainer import Trainer
-from composer.trainer.trainer_hparams import TrainerHparams
+from composer.trainer.trainer_hparams import EvalHparams, ExperimentHparams, FitHparams, TrainerHparams
 
 load = TrainerHparams.load
 
-__all__ = ["Trainer", "TrainerHparams"]
+__all__ = ["Trainer", "TrainerHparams", "ExperimentHparams", "FitHparams", "EvalHparams"]
```
2 changes: 2 additions & 0 deletions composer/trainer/trainer.py
```diff
@@ -804,6 +804,7 @@ def __init__(
         if self._train_data_spec is not None:
             self.state.set_dataloader(self._train_data_spec.dataloader, train_dataloader_label,
                                       train_subset_num_batches)
+            self.state.train_dataloader = self.state.dataloader
         self.train_metrics = _get_training_metrics(model) if compute_training_metrics else None
 
         # Max Duration
@@ -1106,6 +1107,7 @@ def fit(
         if train_dataloader is not None:
             self._train_data_spec = ensure_data_spec(train_dataloader)
             self.state.set_dataloader(self._train_data_spec.dataloader, train_dataloader_label)
+            self.state.train_dataloader = self.state.dataloader
         if self._train_data_spec is None:
             _raise_missing_argument_exception("train_dataloader")
         if train_subset_num_batches is not None:
```
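
A minimal sketch of the behavior this fix restores (the toy dataset and model are assumed; `ComposerClassifier` is used here only as a convenient wrapper, and the final assertion reflects the assignment added in the diff above):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

from composer import Trainer
from composer.models import ComposerClassifier

# Toy dataset and model, purely illustrative.
dataset = TensorDataset(torch.randn(16, 8), torch.randint(0, 2, (16,)))
dataloader = DataLoader(dataset, batch_size=4)
model = ComposerClassifier(torch.nn.Linear(8, 2))

trainer = Trainer(model=model, train_dataloader=dataloader, max_duration="1ep")

# With the fix, the train dataloader is tracked on state again
# (it was left unset after #948).
assert trainer.state.train_dataloader is trainer.state.dataloader
```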