Remove the DDP DataLoader #245

Merged · 5 commits · Jan 21, 2022
Changes from 2 commits
1 change: 0 additions & 1 deletion composer/datasets/__init__.py
@@ -4,7 +4,6 @@
from composer.datasets.brats import BratsDatasetHparams as BratsDatasetHparams
from composer.datasets.cifar10 import CIFAR10DatasetHparams as CIFAR10DatasetHparams
from composer.datasets.dataloader import DataloaderHparams as DataloaderHparams
from composer.datasets.dataloader import DDPDataLoader as DDPDataLoader
from composer.datasets.dataloader import WrappedDataLoader as WrappedDataLoader
from composer.datasets.glue import GLUEHparams as GLUEHparams
from composer.datasets.hparams import DatasetHparams as DatasetHparams
40 changes: 0 additions & 40 deletions composer/datasets/dataloader.py
@@ -3,15 +3,13 @@
from __future__ import annotations

import textwrap
import warnings
from dataclasses import dataclass
from typing import Any, Callable, Iterator, Optional

import torch
import torch.distributed
import torch.utils.data
import yahp as hp
from torch.utils.data.distributed import DistributedSampler

from composer.core.types import Batch, DataLoader, Dataset

@@ -45,44 +43,6 @@ def __setattr__(self, name: str, value: Any) -> None:
return super().__setattr__(name, value)


class DDPDataLoader(WrappedDataLoader):
"""Wraps the dataset to ensure that, if the dataset sampler is a
:class:`~torch.utils.data.distributed.DistributedSampler`, then
:meth:`~torch.utils.data.distributed.DistributedSampler.set_epoch`
is called after each epoch.

If the dataset sampler is not a :class:`~torch.utils.data.distributed.DistributedSampler`,
then this wrapper is a no-op.
"""

def __init__(self, dataloader: DataLoader) -> None:
super().__init__(dataloader)
self._iterator: Optional[Iterator[Batch]] = None

def __iter__(self) -> DDPDataLoader:
if self._iterator is not None:
warnings.warn(
"DataloaderMultipleIterationWarning: "
"The dataloader detected the start of a new iteration before the previous iteration finished. "
"The dataloader is skipping ahead to the start of the next epoch. "
"Multiple simultaneous iterations through the DDP dataloader prohibited, since "
"it automatically tracks the current epoch.")
if isinstance(self.sampler, DistributedSampler):
self.sampler.set_epoch(epoch=self.sampler.epoch + 1)
self._iterator = iter(self.dataloader)
return self

def __next__(self) -> Batch:
assert self._iterator is not None
try:
return next(self._iterator)
except StopIteration:
self._iterator = None
if isinstance(self.sampler, DistributedSampler):
self.sampler.set_epoch(epoch=self.sampler.epoch + 1)
raise


@dataclass
class DataloaderHparams(hp.Hparams):
"""Hyperparameters to initialize a :class:`~torch.utils.data.Dataloader`.
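
For reference, the removed wrapper only automated a call to DistributedSampler.set_epoch before each pass over the data; a minimal sketch of the manual equivalent that the trainer now performs inline (the function and argument names here are illustrative, not part of Composer):

from torch.utils.data import DataLoader, DistributedSampler

def iterate_one_epoch(dataloader: DataLoader, epoch: int):
    # Without the wrapper, the caller advances the sampler's epoch directly
    # so every rank reshuffles consistently from epoch to epoch.
    if isinstance(dataloader.sampler, DistributedSampler):
        dataloader.sampler.set_epoch(epoch)
    for batch in dataloader:
        yield batch

Compared with the wrapper, this leaves the dataloader object untouched and makes the epoch bookkeeping explicit at each call site, which is what the trainer changes below do.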
22 changes: 13 additions & 9 deletions composer/trainer/trainer.py
@@ -22,7 +22,6 @@
from composer.core.algorithm import Algorithm
from composer.core.logging import BaseLoggerBackend, LogLevel
from composer.core.types import BreakEpochException, DataLoader, Metrics, Precision
from composer.datasets.dataloader import DDPDataLoader
from composer.loggers.tqdm_logger import TQDMLoggerBackend
from composer.models.base import BaseMosaicModel
from composer.optim import (ComposedScheduler, CosineAnnealingLRHparams, DecoupledSGDWHparams, OptimizerHparams,
@@ -212,10 +211,8 @@ def __init__(

if not isinstance(train_dataloader, DataSpec):
train_dataloader = DataSpec(train_dataloader)
train_dataloader.dataloader = DDPDataLoader(train_dataloader.dataloader)
if not isinstance(eval_dataloader, DataSpec):
eval_dataloader = DataSpec(eval_dataloader)
eval_dataloader.dataloader = DDPDataLoader(eval_dataloader.dataloader)

self._train_data_spec = train_dataloader
self._eval_data_spec = eval_dataloader
@@ -498,18 +495,17 @@ def _spin_dataloaders(self):
since only the first batch is being loaded, the dataloader may
not be completely iterated through.
"""
# surpressing this multiple iteration warning -- it is OK to ignore
warnings.filterwarnings(action="ignore", message=r"^DataloaderMultipleIterationWarning", append=True)
assert self.state.train_dataloader is not None, "train dataloader should be set"
assert self.state.eval_dataloader is not None, "eval dataloader should be set"

# spin the eval dataloader once to initialize its sampler deterministically
# so it does not affect any other RNG reads
if isinstance(self.state.eval_dataloader.sampler, torch.utils.data.DistributedSampler):
self.state.eval_dataloader.sampler.set_epoch(0)
for _ in self.state.eval_dataloader:
break

# spin the train dataloader's sampler to get to the state of the desired epoch
for _ in range(self.state.epoch):
for epoch in range(int(self.state.timer.epoch)):
if isinstance(self.state.train_dataloader.sampler, torch.utils.data.DistributedSampler):
self.state.train_dataloader.sampler.set_epoch(epoch)
for _ in self.state.train_dataloader:
break

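A standalone sketch of the dataloader-spinning idea documented above: advance a DistributedSampler to a target epoch deterministically by seeding each past epoch and drawing a single batch (spin_dataloader and its arguments are hypothetical names, not Composer's API):

import torch.utils.data

def spin_dataloader(dataloader: torch.utils.data.DataLoader, target_epoch: int) -> None:
    # Replay each past epoch: seed the sampler for that epoch and pull one
    # batch so its state matches a run that already trained those epochs,
    # without disturbing any other RNG streams.
    for epoch in range(target_epoch):
        if isinstance(dataloader.sampler, torch.utils.data.DistributedSampler):
            dataloader.sampler.set_epoch(epoch)
        for _ in dataloader:
            break
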
@@ -555,6 +551,9 @@ def _train_loop(self) -> None:
self.engine.run_event(Event.EPOCH_START)
self.logger.metric_epoch({"epoch": self.state.epoch})

if isinstance(self.state.train_dataloader.sampler, torch.utils.data.DistributedSampler):
self.state.train_dataloader.sampler.set_epoch(int(self.state.timer.epoch))

for batch_idx, state.batch in enumerate(
itertools.islice(state.train_dataloader, self.state.steps_per_epoch)):

@@ -807,6 +806,11 @@ def eval(self, is_batch: bool):

metrics = self._get_metrics_as_collection(is_train=False)

if isinstance(self.state.eval_dataloader.sampler, torch.utils.data.DistributedSampler):
# using the batch as the epoch, as the evaluator may be running every batch
# instead of every epoch
self.state.eval_dataloader.sampler.set_epoch(int(self.state.timer.batch))

for state.batch in itertools.islice(state.eval_dataloader, self._eval_subset_num_batches):
state.batch = self.device.batch_to_device(state.batch)
state.batch = self._eval_data_spec.device_transforms(state.batch)
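
The eval path applies the same pattern, except that evaluation may run every batch rather than every epoch, so a monotonically increasing batch count stands in for the epoch when seeding the sampler; a rough usage sketch with hypothetical names:

from torch.utils.data import DistributedSampler

def evaluate(eval_dataloader, batch_count: int) -> None:
    # Seed the sampler with the global batch count so back-to-back
    # evaluations do not reuse the same shuffle order across ranks.
    if isinstance(getattr(eval_dataloader, "sampler", None), DistributedSampler):
        eval_dataloader.sampler.set_epoch(batch_count)
    for batch in eval_dataloader:
        ...  # forward pass and metric updates go here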