Remove the DDP DataLoader (mosaicml#245)
ravi-mosaicml authored and coryMosaicML committed Feb 23, 2022
1 parent 497f163 commit 5574b64
Showing 3 changed files with 15 additions and 50 deletions.
1 change: 0 additions & 1 deletion composer/datasets/__init__.py
@@ -4,7 +4,6 @@
from composer.datasets.brats import BratsDatasetHparams as BratsDatasetHparams
from composer.datasets.cifar10 import CIFAR10DatasetHparams as CIFAR10DatasetHparams
from composer.datasets.dataloader import DataloaderHparams as DataloaderHparams
from composer.datasets.dataloader import DDPDataLoader as DDPDataLoader
from composer.datasets.dataloader import WrappedDataLoader as WrappedDataLoader
from composer.datasets.glue import GLUEHparams as GLUEHparams
from composer.datasets.hparams import DatasetHparams as DatasetHparams
40 changes: 0 additions & 40 deletions composer/datasets/dataloader.py
@@ -3,15 +3,13 @@
from __future__ import annotations

import textwrap
import warnings
from dataclasses import dataclass
from typing import Any, Callable, Iterator, Optional

import torch
import torch.distributed
import torch.utils.data
import yahp as hp
from torch.utils.data.distributed import DistributedSampler

from composer.core.types import Batch, DataLoader, Dataset

@@ -45,44 +43,6 @@ def __setattr__(self, name: str, value: Any) -> None:
return super().__setattr__(name, value)


class DDPDataLoader(WrappedDataLoader):
"""Wraps the dataset to ensure that, if the dataset sampler is a
:class:`~torch.utils.data.distributed.DistributedSampler`, then
:meth:`~torch.utils.data.distributed.DistributedSampler.set_epoch`
is called after each epoch.
If the dataset sampler is not a :class:`~torch.utils.data.distributed.DistributedSampler`,
then this wrapper is a no-op.
"""

def __init__(self, dataloader: DataLoader) -> None:
super().__init__(dataloader)
self._iterator: Optional[Iterator[Batch]] = None

def __iter__(self) -> DDPDataLoader:
if self._iterator is not None:
warnings.warn(
"DataloaderMultipleIterationWarning: "
"The dataloader detected the start of a new iteration before the previous iteration finished. "
"The dataloader is skipping ahead to the start of the next epoch. "
"Multiple simultaneous iterations through the DDP dataloader prohibited, since "
"it automatically tracks the current epoch.")
if isinstance(self.sampler, DistributedSampler):
self.sampler.set_epoch(epoch=self.sampler.epoch + 1)
self._iterator = iter(self.dataloader)
return self

def __next__(self) -> Batch:
assert self._iterator is not None
try:
return next(self._iterator)
except StopIteration:
self._iterator = None
if isinstance(self.sampler, DistributedSampler):
self.sampler.set_epoch(epoch=self.sampler.epoch + 1)
raise


@dataclass
class DataloaderHparams(hp.Hparams):
"""Hyperparameters to initialize a :class:`~torch.utils.data.Dataloader`.
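With DDPDataLoader gone, the per-epoch bookkeeping the wrapper did implicitly now happens explicitly wherever a torch.utils.data.distributed.DistributedSampler is used, which is what the trainer changes below add. A minimal sketch of that replacement pattern, assuming a toy TensorDataset and single-process num_replicas/rank values (both illustrative, not part of this commit):

import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler

# Illustrative dataset; any map-style dataset behaves the same way.
dataset = TensorDataset(torch.arange(1024).float())

# num_replicas/rank are pinned so the sketch runs without an initialized
# process group; under real DDP they come from torch.distributed.
sampler = DistributedSampler(dataset, num_replicas=1, rank=0, shuffle=True)
loader = DataLoader(dataset, batch_size=32, sampler=sampler)

num_epochs = 3  # illustrative
for epoch in range(num_epochs):
    # DDPDataLoader used to bump sampler.epoch automatically at the start of
    # each new iteration; now the caller sets the epoch explicitly so every
    # rank draws the same shuffle for that epoch.
    sampler.set_epoch(epoch)
    for (batch,) in loader:
        pass  # training step goes here
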
24 changes: 15 additions & 9 deletions composer/trainer/trainer.py
@@ -22,7 +22,6 @@
from composer.core.algorithm import Algorithm
from composer.core.logging import BaseLoggerBackend, LogLevel
from composer.core.types import Batch, BreakEpochException, DataLoader, Metrics, Precision
from composer.datasets.dataloader import DDPDataLoader
from composer.loggers.tqdm_logger import TQDMLoggerBackend
from composer.models.base import BaseMosaicModel
from composer.optim import (ComposedScheduler, CosineAnnealingLRHparams, DecoupledSGDWHparams, OptimizerHparams,
@@ -216,10 +215,8 @@ def __init__(

if not isinstance(train_dataloader, DataSpec):
train_dataloader = DataSpec(train_dataloader)
train_dataloader.dataloader = DDPDataLoader(train_dataloader.dataloader)
if not isinstance(eval_dataloader, DataSpec):
eval_dataloader = DataSpec(eval_dataloader)
eval_dataloader.dataloader = DDPDataLoader(eval_dataloader.dataloader)

self._train_data_spec = train_dataloader
self._eval_data_spec = eval_dataloader
@@ -513,18 +510,17 @@ def _spin_dataloaders(self):
since only the first batch is being loaded, the dataloader may
not be completely iterated through.
"""
# surpressing this multiple iteration warning -- it is OK to ignore
warnings.filterwarnings(action="ignore", message=r"^DataloaderMultipleIterationWarning", append=True)
assert self.state.train_dataloader is not None, "train dataloader should be set"
assert self.state.eval_dataloader is not None, "eval dataloader should be set"

# spin the eval dataloader once to initialize its sampler deterministically
# so it does not affect any other RNG reads
if isinstance(self.state.eval_dataloader.sampler, torch.utils.data.DistributedSampler):
self.state.eval_dataloader.sampler.set_epoch(0)
for _ in self.state.eval_dataloader:
break

# spin the train dataloader's sampler to get to the state of the desired epoch
for _ in range(self.state.epoch):
for epoch in range(int(self.state.timer.epoch)):
if isinstance(self.state.train_dataloader.sampler, torch.utils.data.DistributedSampler):
self.state.train_dataloader.sampler.set_epoch(epoch)
for _ in self.state.train_dataloader:
break

@@ -570,6 +566,9 @@ def _train_loop(self) -> None:
self.engine.run_event(Event.EPOCH_START)
self.logger.metric_epoch({"epoch": self.state.epoch})

if isinstance(self.state.train_dataloader.sampler, torch.utils.data.DistributedSampler):
self.state.train_dataloader.sampler.set_epoch(int(self.state.timer.epoch))

for batch_idx, state.batch in enumerate(
itertools.islice(state.train_dataloader, self.state.steps_per_epoch)):

@@ -822,6 +821,9 @@ def eval(self, is_batch: bool):

metrics = self._get_metrics_as_collection(is_train=False)

if isinstance(self.state.eval_dataloader.sampler, torch.utils.data.DistributedSampler):
# The distributed sampler uses `set_epoch` to set the random seed
# Because evaluation can run on each batch, we use the batch to seed the sampler
# so each evaluation will get a proper shuffle.
# The epoch provided to `set_epoch` need not be sequential, so this is fine.
self.state.eval_dataloader.sampler.set_epoch(int(self.state.timer.batch))

for state.batch in itertools.islice(state.eval_dataloader, self._eval_subset_num_batches):
state.batch = self.device.batch_to_device(state.batch)
state.batch = self._eval_data_spec.device_transforms(state.batch)
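The evaluation branch added above relies on DistributedSampler.set_epoch only seeding the next shuffle; the values passed in do not need to be consecutive epoch numbers. A short sketch of that idea in isolation, with current_batch_count standing in for int(self.state.timer.batch) and the dataset purely illustrative:

import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler

eval_dataset = TensorDataset(torch.arange(256).float())
# Pinned num_replicas/rank keep the sketch runnable without a process group.
eval_sampler = DistributedSampler(eval_dataset, num_replicas=1, rank=0, shuffle=True)
eval_loader = DataLoader(eval_dataset, batch_size=32, sampler=eval_sampler)

def evaluate(current_batch_count: int) -> None:
    # Seed the shuffle with the running batch count instead of an epoch index:
    # evaluation may run several times per epoch, and set_epoch only needs a
    # deterministic integer, not a consecutive one, to reshuffle reproducibly.
    eval_sampler.set_epoch(current_batch_count)
    for (batch,) in eval_loader:
        pass  # metric computation goes here

evaluate(current_batch_count=100)  # e.g. after the 100th training batch
evaluate(current_batch_count=250)  # a later evaluation gets its own shuffle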
