[Eval-Only]: Optional timing and dataloader attributes on state; removed evaluators from the state. #832

Merged · 29 commits · Apr 22, 2022
Commits
bd3e847
[Eval-Only]: Made the `state.dataloader` optional; removed `state.ste…
ravi-mosaicml Mar 25, 2022
558f279
Restored `dataloader_len` on state
ravi-mosaicml Apr 12, 2022
f5d0a1b
Fixed tests
ravi-mosaicml Apr 12, 2022
13c6e53
Merge branch 'dev' into i40_1
ravi-mosaicml Apr 12, 2022
e4facaa
Added `dataloader_label`; removed `evaluators` from State
ravi-mosaicml Apr 12, 2022
f3f47ea
Merge branch 'dev' into i40_1
ravi-mosaicml Apr 12, 2022
56fd87d
Fixed pyright
ravi-mosaicml Apr 12, 2022
b89c3bc
Fixed pyright
ravi-mosaicml Apr 12, 2022
79a3094
Merge branch 'dev' into ravi/optional_dataloader
ravi-mosaicml Apr 13, 2022
096b44c
Made `max_duration` optional
ravi-mosaicml Apr 13, 2022
d102506
Merge branch 'ravi/optional_dataloader' of github.com:mosaicml/compos…
ravi-mosaicml Apr 13, 2022
fe70133
Addressed PR feedback; fixed Time type annotations
ravi-mosaicml Apr 14, 2022
7e6e7cf
Merge branch 'dev' into ravi/optional_dataloader
ravi-mosaicml Apr 14, 2022
11544bc
Fixed doctests
ravi-mosaicml Apr 14, 2022
1313a49
Fixed selective backprop
ravi-mosaicml Apr 14, 2022
b5b6192
Increased timeout
ravi-mosaicml Apr 14, 2022
ecc4a59
Merge branch 'dev' into ravi/optional_dataloader
ravi-mosaicml Apr 14, 2022
d3408ba
Remove optimizers from state on init; clean up PR
ravi-mosaicml Apr 15, 2022
1470d11
Bind the schedulers to the state in `__init__()`, rather than on `fit…
ravi-mosaicml Apr 15, 2022
7f00806
Merge branch 'dev' into ravi/optional_dataloader
ravi-mosaicml Apr 15, 2022
a4b2697
Merge branch 'dev' into ravi/optional_dataloader
ravi-mosaicml Apr 20, 2022
54d6b88
Fixed the deepspeed schedulers
ravi-mosaicml Apr 20, 2022
f067ed7
Merge branch 'dev' into ravi/optional_dataloader
ravi-mosaicml Apr 21, 2022
59af154
Merge branch 'dev' into ravi/optional_dataloader
ravi-mosaicml Apr 21, 2022
49e68e8
* Addressed PR Feedback
ravi-mosaicml Apr 21, 2022
93dd36f
Merge branch 'ravi/optional_dataloader' of github.com:mosaicml/compos…
ravi-mosaicml Apr 21, 2022
da5bd40
Fixing the `dataloader_len` setter
ravi-mosaicml Apr 21, 2022
42ab447
Fix tests
ravi-mosaicml Apr 22, 2022
8cac6be
Merge branch 'dev' into ravi/optional_dataloader
ravi-mosaicml Apr 22, 2022
Files changed
8 changes: 6 additions & 2 deletions composer/algorithms/augmix/augmix.py
@@ -251,15 +251,19 @@ def __init__(self,
self._transformed_datasets = weakref.WeakSet()

def match(self, event: Event, state: State) -> bool:
return event == Event.FIT_START and state.train_dataloader.dataset not in self._transformed_datasets
if event != Event.FIT_START:
return False
assert state.dataloader is not None, "dataloader should be defined on fit start"
return state.dataloader.dataset not in self._transformed_datasets

def apply(self, event: Event, state: State, logger: Logger) -> None:
am = AugmentAndMixTransform(severity=self.severity,
depth=self.depth,
width=self.width,
alpha=self.alpha,
augmentation_set=self.augmentation_set)
dataset = state.train_dataloader.dataset
assert state.dataloader is not None, "dataloader should be defined on fit start"
dataset = state.dataloader.dataset
if not isinstance(dataset, VisionDataset):
raise TypeError(
textwrap.dedent(f"""\
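For context, a minimal sketch (not part of this diff) of the guard pattern the AugMix change above introduces: `state.dataloader` is now `Optional`, so dataset-transforming algorithms check the event first and only then assert that the dataloader is populated. The `DatasetTransformSketch` class and its no-op transform are illustrative, not Composer API.

```python
import weakref

from composer.core import Algorithm, Event, State
from composer.loggers import Logger


class DatasetTransformSketch(Algorithm):
    """Illustrative algorithm that transforms the training dataset once, on FIT_START."""

    def __init__(self):
        self._transformed_datasets = weakref.WeakSet()

    def match(self, event: Event, state: State) -> bool:
        if event != Event.FIT_START:
            return False
        # The trainer sets state.dataloader before FIT_START; this assert narrows the
        # Optional type for static checkers rather than guarding a real None case.
        assert state.dataloader is not None, "dataloader should be defined on fit start"
        return state.dataloader.dataset not in self._transformed_datasets

    def apply(self, event: Event, state: State, logger: Logger) -> None:
        assert state.dataloader is not None, "dataloader should be defined on fit start"
        dataset = state.dataloader.dataset
        # ... wrap or mutate `dataset` here ...
        self._transformed_datasets.add(dataset)
```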
8 changes: 6 additions & 2 deletions composer/algorithms/colout/colout.py
@@ -214,11 +214,15 @@ def match(self, event: Event, state: State) -> bool:
if self.batch:
return event == Event.AFTER_DATALOADER
else:
return event == Event.FIT_START and state.train_dataloader.dataset not in self._transformed_datasets
if event != Event.FIT_START:
return False
assert state.dataloader is not None, "dataloader should be defined on fit start"
return state.dataloader.dataset not in self._transformed_datasets

def _apply_sample(self, state: State) -> None:
"""Add the ColOut dataset transform to the dataloader."""
dataset = state.train_dataloader.dataset
assert state.dataloader is not None, "dataloader should be defined on fit start"
dataset = state.dataloader.dataset

transform = ColOutTransform(p_row=self.p_row, p_col=self.p_col, resize_target=self.resize_target)

4 changes: 3 additions & 1 deletion composer/algorithms/layer_freezing/layer_freezing.py
@@ -133,10 +133,12 @@ def apply(self, event: Event, state: State, logger: Logger) -> Optional[int]:
del event # unused
optimizers = state.optimizers
assert optimizers is not None
elapsed_duration = state.get_elapsed_duration()
assert elapsed_duration is not None, "elapsed duration should be set on Event.EPOCH_END"
freeze_depth, freeze_percentage = freeze_layers(
model=state.model,
optimizers=optimizers,
current_duration=float(state.get_elapsed_duration()),
current_duration=float(elapsed_duration),
freeze_start=self.freeze_start,
freeze_level=self.freeze_level,
)
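The same `Optional` handling now applies to `state.get_elapsed_duration()`, which can return `None` before `max_duration` is known. A small sketch of the consuming pattern (the helper function below is hypothetical):

```python
from composer.core import State


def current_duration_fraction(state: State) -> float:
    """Fraction of max_duration elapsed, assuming an event at which max_duration is set."""
    elapsed_duration = state.get_elapsed_duration()
    # None before max_duration is set (e.g. on Event.INIT); during training events it is
    # a duration fraction in [0, 1].
    assert elapsed_duration is not None, "elapsed duration should be set during training events"
    return float(elapsed_duration)
```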
@@ -197,7 +197,9 @@ def apply(self, event: Event, state: State, logger: Optional[Logger] = None) ->
# Calculate the current size of the inputs to use
initial_size = self.initial_scale
finetune_fraction = self.finetune_fraction
scale_frac_elapsed = min([state.get_elapsed_duration().value / (1 - finetune_fraction), 1])
elapsed_duration = state.get_elapsed_duration()
assert elapsed_duration is not None, "elapsed duration should be set on Event.AFTER_DATALOADER"
scale_frac_elapsed = min([elapsed_duration.value / (1 - finetune_fraction), 1])

# Linearly increase to full size at the start of the fine tuning period
scale_factor = initial_size + (1 - initial_size) * scale_frac_elapsed
9 changes: 6 additions & 3 deletions composer/algorithms/randaugment/randaugment.py
@@ -188,12 +188,15 @@ def __init__(self, severity: int = 9, depth: int = 2, augmentation_set: str = "a
self._transformed_datasets = weakref.WeakSet()

def match(self, event: Event, state: State) -> bool:
return event == Event.FIT_START and state.train_dataloader.dataset not in self._transformed_datasets
if event != Event.FIT_START:
return False
assert state.dataloader is not None, "dataloader should be defined on fit start"
return state.dataloader.dataset not in self._transformed_datasets

def apply(self, event: Event, state: State, logger: Logger) -> None:
ra = RandAugmentTransform(severity=self.severity, depth=self.depth, augmentation_set=self.augmentation_set)
assert state.train_dataloader is not None
dataset = state.train_dataloader.dataset
assert state.dataloader is not None, "dataloader should be defined on fit start"
dataset = state.dataloader.dataset
if not isinstance(dataset, VisionDataset):
raise TypeError(
textwrap.dedent(f"""\
10 changes: 7 additions & 3 deletions composer/algorithms/selective_backprop/selective_backprop.py
@@ -12,6 +12,7 @@
from torch.nn import functional as F

from composer.core import Algorithm, Event, State
from composer.core.precision import get_precision_context
from composer.loggers import Logger
from composer.models import ComposerModel

@@ -93,7 +94,7 @@ def select_using_loss(input: torch.Tensor,
This function runs an extra forward pass through the model on the batch of data.
If you are using a non-default precision, ensure that this forward pass
runs in your desired precision. For example:

.. testsetup::

N_sb, D_sb = 16, 8
@@ -223,8 +224,11 @@ def match(self, event: Event, state: State) -> bool:
if not is_keep:
return False

elapsed_duration = state.get_elapsed_duration()
assert elapsed_duration is not None, "elapsed duration should be set on Event.AFTER_DATALOADER"

is_chosen = should_selective_backprop(
current_duration=float(state.get_elapsed_duration()),
current_duration=float(elapsed_duration),
batch_idx=state.timer.batch_in_epoch.value,
start=self.start,
end=self.end,
@@ -250,6 +254,6 @@ def loss(p, y, reduction="none"):
assert self._loss_fn is not None, "loss_fn should be set on Event.INIT"
return self._loss_fn(p, (torch.Tensor(), y), reduction=reduction)

with state.precision_context:
with get_precision_context(state.precision):
new_input, new_target = select_using_loss(input, target, model, loss, self.keep, self.scale_factor)
state.batch = (new_input, new_target)
15 changes: 9 additions & 6 deletions composer/algorithms/seq_length_warmup/seq_length_warmup.py
@@ -9,6 +9,7 @@
import torch

from composer.core import Algorithm, Event, State
from composer.core.precision import get_precision_context
from composer.core.time import TimeUnit
from composer.core.types import Batch
from composer.loggers import Logger
@@ -179,12 +180,12 @@ def apply(self, event: Event, state: State, logger: Logger) -> Optional[int]:
{type(self).__name__} requires state.model to be of type {ComposerTransformer.__name__}, not of type {type(state.model)}"""
))

if state.train_dataloader.batch_size is None:
raise RuntimeError("Sequence Length Warmup algorithm requires constant batch size.")

self._original_model = state.model
return

assert state.dataloader is not None, "dataloader should be set on AFTER_DATALOADER"
assert state.max_duration is not None, "max_duration should be set on AFTER_DATALOADER"

# in order to avoid OOMs, we do a forward and a backward pass on a dummy input.
if not self._activated:
# ensure that input_ids is a valid model input. since we don't need the
@@ -204,7 +205,7 @@ def apply(self, event: Event, state: State, logger: Logger) -> Optional[int]:
# all of the parameters
device = next(state.model.parameters()).device

per_gpu_macrobatch = state.train_dataloader.batch_size
per_gpu_macrobatch = state.dataloader.batch_size
if per_gpu_macrobatch is None:
raise RuntimeError("Sequence Length Warmup algorithm requires constant batch size.")
per_gpu_batch = ceil(per_gpu_macrobatch / state.grad_accum)
@@ -223,7 +224,7 @@ def apply(self, event: Event, state: State, logger: Logger) -> Optional[int]:

# start by running a forward and backward pass
# of the maximum sequence length to allocate cache.
with state.precision_context:
with get_precision_context(state.precision):
outputs = state.model.forward(model_inputs)
loss = self._original_model.loss(outputs, model_inputs)

@@ -238,7 +239,9 @@ def apply(self, event: Event, state: State, logger: Logger) -> Optional[int]:
self._activated = True

if state.max_duration.unit == TimeUnit.EPOCH:
num_optimization_steps = state.steps_per_epoch * state.max_duration.value
if state.dataloader_len is None:
raise RuntimeError("Sequential Length Warmup requires the dataloader to be sized.")
num_optimization_steps = int(state.dataloader_len) * state.max_duration.value
elif state.max_duration.unit == TimeUnit.BATCH:
num_optimization_steps = state.max_duration.value
else:
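The step-count logic above can be summarized by the following sketch, which assumes a sized dataloader whenever `max_duration` is expressed in epochs (the helper name is illustrative):

```python
from composer.core import State
from composer.core.time import TimeUnit


def total_optimization_steps(state: State) -> int:
    assert state.max_duration is not None, "max_duration should be set during training"
    if state.max_duration.unit == TimeUnit.EPOCH:
        if state.dataloader_len is None:
            raise RuntimeError("An epoch-based max_duration requires a sized dataloader.")
        return int(state.dataloader_len) * int(state.max_duration.value)
    if state.max_duration.unit == TimeUnit.BATCH:
        return int(state.max_duration.value)
    raise NotImplementedError(f"Unsupported max_duration unit: {state.max_duration.unit}")
```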
6 changes: 4 additions & 2 deletions composer/algorithms/stochastic_depth/stochastic_depth.py
@@ -232,8 +232,10 @@ def apply(self, event: Event, state: State, logger: Logger) -> Optional[int]:
logger.data_epoch({'stochastic_depth/num_stochastic_layers': num_stochastic_layers})

elif event == Event.BATCH_START:
if state.get_elapsed_duration() < self.drop_warmup:
current_drop_rate = float(state.get_elapsed_duration() / self.drop_warmup) * self.drop_rate
elapsed_duration = state.get_elapsed_duration()
assert elapsed_duration is not None, "elapsed duration is set on BATCH_START"
if elapsed_duration < self.drop_warmup:
current_drop_rate = float(elapsed_duration / self.drop_warmup) * self.drop_rate
_update_drop_rate(state.model, stochastic_layer, current_drop_rate, self.drop_distribution)
else:
current_drop_rate = self.drop_rate
19 changes: 13 additions & 6 deletions composer/algorithms/swa/swa.py
@@ -172,8 +172,12 @@ def __init__(self,
self.match_event = Event.EPOCH_END

def match(self, event: Event, state: State) -> bool:
if event != self.match_event:
return False
if self.swa_start.unit == TimeUnit.DURATION:
should_start_swa = state.get_elapsed_duration() >= self.swa_start and not self.swa_completed
elapsed_duration = state.get_elapsed_duration()
assert elapsed_duration is not None, "elapsed duration should be set on Event.BATCH_END or Event.EPOCH_END"
should_start_swa = elapsed_duration >= self.swa_start and not self.swa_completed
elif self.swa_start.unit == TimeUnit.EPOCH:
should_start_swa = state.timer.get("ep") >= self.swa_start and not self.swa_completed
else:
@@ -219,16 +223,19 @@ def apply(self, event: Event, state: State, logger: Logger) -> None:

self.step_counter += 1

elapsed_duration = state.get_elapsed_duration()
assert elapsed_duration is not None, "elapsed duration should be set on Event.BATCH_END or Event.EPOCH_END"

# Determine whether it's time to end SWA
if self.swa_end.unit == TimeUnit.DURATION and (state.get_elapsed_duration() >= self.swa_end):
if self.swa_end.unit == TimeUnit.DURATION and (elapsed_duration >= self.swa_end):
self.swa_completed = True
if self.swa_end.unit == TimeUnit.EPOCH and (state.timer.get("ep") >= self.swa_end):
self.swa_completed = True
if self.swa_completed:
if state.get_elapsed_duration() == 1:
log.warning("The baseline model was replaced with the SWA model after the end of "
"training. This means that SWA model will not have its batch norm "
"statistics updated. This will negatively impact accuracy. See the "
"documentation for the `swa_end` parameter for details.")
log.warning(("The baseline model was replaced with the SWA model after the end of "
"training. This means that SWA model will not have its batch norm "
"statistics updated. This will negatively impact accuracy. See the "
"documentation for the `swa_end` parameter for details."))
state.model.load_state_dict(self.swa_model.module.state_dict()) # type: ignore
log.info('Set model to the averaged model')
13 changes: 10 additions & 3 deletions composer/callbacks/checkpoint_saver.py
@@ -57,7 +57,10 @@ def checkpoint_periodically(interval: Union[str, int, Time]) -> Callable[[State,

def save_interval(state: State, event: Event):
nonlocal last_checkpoint_batch
if state.get_elapsed_duration() >= 1.0:
elapsed_duration = state.get_elapsed_duration()
assert elapsed_duration is not None, "elapsed_duration is set on the BATCH_CHECKPOINT and EPOCH_CHECKPOINT"

if elapsed_duration >= 1.0:
# if doing batch-wise checkpointing, and we saved a checkpoint at the batch_checkpoint event
# right before the epoch_checkpoint event, do not save another checkpoint at the epoch_checkpoint
# event if the batch count didn't increase.
@@ -332,12 +335,16 @@ def fit_start(self, state: State, logger: Logger) -> None:
def batch_checkpoint(self, state: State, logger: Logger):
if self.save_interval(state, Event.BATCH_CHECKPOINT):
# If training is finished, log at the FIT loglevel
log_level = LogLevel.BATCH if state.get_elapsed_duration() < 1.0 else LogLevel.FIT
elapsed_duration = state.get_elapsed_duration()
assert elapsed_duration is not None, "elapsed_duration is set on Event.BATCH_CHECKPOINT"
log_level = LogLevel.BATCH if elapsed_duration < 1.0 else LogLevel.FIT
self._save_checkpoint(state, logger, log_level)

def epoch_checkpoint(self, state: State, logger: Logger):
if self.save_interval(state, Event.EPOCH_CHECKPOINT):
log_level = LogLevel.EPOCH if state.get_elapsed_duration() < 1.0 else LogLevel.FIT
elapsed_duration = state.get_elapsed_duration()
assert elapsed_duration is not None, "elapsed_duration is set on Event.EPOCH_CHECKPOINT"
log_level = LogLevel.EPOCH if elapsed_duration < 1.0 else LogLevel.FIT
self._save_checkpoint(state, logger, log_level)

def _save_checkpoint(self, state: State, logger: Logger, log_level: LogLevel):
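Because `save_interval` is invoked as `save_interval(state, event)` above, a custom interval callable can follow the same `Optional`-aware pattern. A hypothetical example (not part of this PR) that saves every other epoch and always once training completes:

```python
from composer.core import Event, State


def save_every_other_epoch(state: State, event: Event) -> bool:
    elapsed_duration = state.get_elapsed_duration()
    assert elapsed_duration is not None, "elapsed_duration is set on the checkpointing events"
    if elapsed_duration >= 1.0:
        return True  # always save once training is complete
    if event != Event.EPOCH_CHECKPOINT:
        return False
    return state.timer.epoch.value % 2 == 0
```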
1 change: 1 addition & 0 deletions composer/callbacks/speed_monitor.py
@@ -126,6 +126,7 @@ def epoch_start(self, state: State, logger: Logger):
def batch_end(self, state: State, logger: Logger):
self.batch_end_times.append(time.time())
new_num_samples = state.timer.sample
assert self.batch_start_num_samples is not None, "self.batch_start_num_samples should have been set on Event.BATCH_START"
batch_num_samples = int(new_num_samples - self.batch_start_num_samples)
self.batch_num_samples.append(batch_num_samples)
self.train_examples_per_epoch += batch_num_samples
25 changes: 25 additions & 0 deletions composer/core/engine.py
@@ -83,6 +83,25 @@
Traces = Dict[str, "Trace"]

_ALWAYS_RECORD_EVENTS = [Event.INIT, Event.FIT_START, Event.EPOCH_START, Event.EPOCH_END]
_EVENTS_WHERE_DATALOADER_IS_SET = [e for e in Event if e != Event.INIT]
_EVENTS_WHERE_MAX_DURATION_IS_SET = [
Event.FIT_START,
Event.EPOCH_START,
Event.BATCH_START,
Event.AFTER_DATALOADER,
Event.BEFORE_TRAIN_BATCH,
Event.BEFORE_FORWARD,
Event.AFTER_FORWARD,
Event.BEFORE_LOSS,
Event.AFTER_LOSS,
Event.BEFORE_BACKWARD,
Event.AFTER_BACKWARD,
Event.AFTER_TRAIN_BATCH,
Event.BATCH_END,
Event.BATCH_CHECKPOINT,
Event.EPOCH_END,
Event.EPOCH_CHECKPOINT,
]


@dataclass
@@ -174,6 +193,12 @@ def run_event(
if event.is_after_event and duration_marker is not None:
duration_marker.finish()

if event in _EVENTS_WHERE_DATALOADER_IS_SET:
assert self.state.dataloader is not None, f"The trainer should have set state.dataloader for event {event}."

if event in _EVENTS_WHERE_MAX_DURATION_IS_SET:
assert self.state.max_duration is not None, f"The trainer should have set state.max_duration for event {event}."

if event == Event.INIT:
# For the INIT event, run the callbacks first to initialize the loggers
# For other events, run the algorithms first, so the callbacks have the state
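These engine-level asserts give algorithms and callbacks an invariant to lean on: after `Event.INIT`, `state.dataloader` is populated, and on the training-loop events listed above, `state.max_duration` is as well. A hypothetical callback relying on that guarantee:

```python
from composer.core import Callback, State
from composer.loggers import Logger


class BatchSizeLogger(Callback):
    """Illustrative callback; the engine guarantees state.dataloader is set after INIT."""

    def epoch_start(self, state: State, logger: Logger) -> None:
        assert state.dataloader is not None, "set by the trainer for all events after INIT"
        batch_size = getattr(state.dataloader, "batch_size", None)
        logger.data_epoch({"dataloader/batch_size": batch_size if batch_size is not None else -1})
```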
37 changes: 36 additions & 1 deletion composer/core/precision.py
@@ -2,9 +2,15 @@

"""Enum class for the numerical precision to be used by the model."""

import contextlib
from typing import Generator, Union

import torch
from packaging import version

from composer.utils.string_enum import StringEnum

__all__ = ["Precision"]
__all__ = ["Precision", "get_precision_context"]


class Precision(StringEnum):
@@ -23,3 +29,32 @@ class Precision(StringEnum):
FP16 = "fp16"
FP32 = "fp32"
BF16 = "bf16"


@contextlib.contextmanager
def get_precision_context(precision: Union[str, Precision]) -> Generator[None, None, None]:
"""Returns a context manager to automatically cast to a specific precision.

Args:
precision (str or Precision): Precision for the context
"""

precision = Precision(precision)
if precision == Precision.FP32:
if torch.cuda.is_available():
with torch.cuda.amp.autocast(False):
yield
else:
# Yield here to avoid warnings about cuda not being available
yield
elif precision == Precision.AMP:
# Retain compatibility with PyTorch < 1.10
with torch.cuda.amp.autocast(True):
yield
elif precision == Precision.BF16:
if version.parse(torch.__version__) < version.parse("1.10"):
raise ValueError(f"BF16 precision requires torch > 1.10, got version {torch.__version__}")
with torch.cuda.amp.autocast(True, torch.bfloat16):
yield
else:
raise ValueError(f"Unsupported precision: {precision}")
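A short usage example for the new context manager (the model and inputs below are placeholders):

```python
import torch

from composer.core.precision import Precision, get_precision_context

model = torch.nn.Linear(8, 4)
inputs = torch.randn(16, 8)

# FP32 disables autocast when CUDA is available and is a plain no-op on CPU-only machines.
with get_precision_context(Precision.FP32):
    outputs = model(inputs)
```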