Add multi-gpu testing to test_algorithm_resumption #2016

Merged (11 commits) on Mar 1, 2023
4 changes: 4 additions & 0 deletions composer/algorithms/gyro_dropout/gyro_dropout.py

@@ -121,6 +121,10 @@ def __init__(self, p: float = 0.5, sigma: int = 256, tau: int = 16):
         self.sigma = sigma
         self.tau = tau
 
+        log.warning(
+            'GyroDropout is not implemented in a way that allows correct resumption from checkpoint, which may lead to incorrect behavior.'
+        )
+
     def __repr__(self) -> str:
         return f'{self.__class__.__name__}()'
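
Note: these hunks call `log.warning` without showing where `log` is defined; in Composer's algorithm modules it is presumably the conventional module-level logger, along the lines of this sketch (the import sits outside the diff context shown above):

```python
import logging

# Module-level logger, assumed to already exist in each algorithm file.
log = logging.getLogger(__name__)
```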
3 changes: 3 additions & 0 deletions composer/algorithms/sam/sam.py

@@ -147,6 +147,9 @@ def __init__(
         epsilon: float = 1.0e-12,
         interval: int = 1,
     ):
+        log.warning(
+            'SAM has known issues of weight mismatch when loading from a checkpoint, which will cause an error when resuming without `load_weights_only=True`.'
+        )
         """__init__ is constructed from the same fields as in hparams."""
         self.rho = rho
         self.epsilon = epsilon
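
The warning above (and the identical one added to Stochastic Depth below) names `load_weights_only=True` as the workaround. A minimal sketch of resuming that way, using only Trainer arguments that appear elsewhere in this PR — the model, dataloader, and checkpoint path here are illustrative:

```python
from composer import Trainer

# Hypothetical resumption that sidesteps the SAM / Stochastic Depth
# weight-mismatch error: restore model weights only, skipping the
# optimizer/scheduler/algorithm state that triggers the mismatch.
trainer = Trainer(
    model=model,
    train_dataloader=train_dataloader,
    load_path='folder1/ep1-rank{rank}',
    load_weights_only=True,
    max_duration='2ep',
)
trainer.fit()
```

Under this flag the timestamp, optimizer, and scheduler state are not restored, so this is a workaround rather than true resumption — which is why the test below keeps `load_weights_only=False` when checking checkpoint equality.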
4 changes: 4 additions & 0 deletions composer/algorithms/stochastic_depth/stochastic_depth.py

@@ -139,6 +139,10 @@ def __init__(self,
                  drop_distribution: str = 'linear',
                  drop_warmup: Union[float, Time, str] = 0.0):
 
+        log.warning(
+            'Stochastic depth has known issues of weight mismatch when loading from a checkpoint, which will cause an error when resuming without `load_weights_only=True`.'
+        )
+
         if drop_rate == 0.0:
             log.warning('Stochastic Depth will have no effect when drop_rate set to 0')
4 changes: 4 additions & 0 deletions composer/algorithms/swa/swa.py

@@ -114,6 +114,10 @@ def __init__(self,
                  anneal_strategy: str = 'linear',
                  anneal_steps: int = 10,
                  swa_lr: Optional[float] = None):
+
+        log.warning(
+            'SWA has known issues when resuming from a checkpoint on multiple GPUs, which will cause an error when resuming without `load_weights_only=True`.'
+        )
         self.schedule_swa_lr = schedule_swa_lr
         self.anneal_strategy = anneal_strategy
         self.anneal_steps = anneal_steps
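
SWA's warning is specific to multi-GPU resumption, and the test change below xfails exactly that combination. An illustrative guard, not part of this PR, assuming `composer.utils.dist.get_world_size` (used alongside `get_global_rank` in Composer's dist utilities):

```python
from composer.algorithms import SWA
from composer.utils import dist

# Only enable SWA on a single device until multi-GPU resumption is fixed.
algorithms = [SWA()] if dist.get_world_size() == 1 else []
```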
29 changes: 17 additions & 12 deletions tests/algorithms/algorithm_settings.py

@@ -11,7 +11,7 @@
 from typing import Any, Dict, Optional, Type
 
 import pytest
-from torch.utils.data import DataLoader, Dataset
+from torch.utils.data import DataLoader
 
 import composer
 import composer.algorithms
@@ -24,6 +24,7 @@
                                 WeightStandardization)
 from composer.models import composer_resnet
 from composer.models.base import ComposerModel
+from composer.utils import dist
 from tests.common import get_module_subclasses
 from tests.common.datasets import RandomImageDataset, SimpleDataset, dummy_bert_lm_dataloader, dummy_gpt_lm_dataloader
 from tests.common.models import (SimpleConvModel, SimpleModelWithDropout, configure_tiny_bert_hf_model,
@@ -215,25 +216,29 @@ def get_alg_model(alg_cls: Type[Algorithm]) -> ComposerModel:
     return cls(**kwargs)
 
 
-def get_alg_dataloader(alg_cls: Type[Algorithm]) -> DataLoader:
+def get_alg_dataloader(alg_cls: Type[Algorithm], multigpu=False) -> DataLoader:
     """Return an instance of the dataset for an algorithm."""
     settings = _get_alg_settings(alg_cls)
 
     if 'dataloader' in settings:
-        settings = settings['dataloader']
+        dataloader_cls, kwargs = settings['dataloader']
+        if 'dataset' in kwargs and multigpu:
+            kwargs['sampler'] = dist.get_sampler(kwargs['dataset'])
+
+        dataloader = dataloader_cls(**kwargs)
+
     elif 'dataset' in settings:
-        settings = settings['dataset']
+        if isinstance(settings['dataset'], tuple):
+            dataset_cls, kwargs = settings['dataset']
+        else:
+            dataset_cls = settings['dataset']
+            kwargs = {}
+        dataset = dataset_cls(**kwargs)
+        sampler = dist.get_sampler(dataset) if multigpu else None
+        dataloader = DataLoader(dataset=dataset, batch_size=4, sampler=sampler)
     else:
         raise ValueError(f'Neither dataset nor dataloader have been provided for algorithm {alg_cls}')
 
-    if isinstance(settings, tuple):
-        (cls, kwargs) = settings
-    else:
-        (cls, kwargs) = (settings, {})
-
-    dataloader = cls(**kwargs)
-    if isinstance(dataloader, Dataset):
-        dataloader = DataLoader(dataset=dataloader, batch_size=2)
     return dataloader
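
The new `multigpu` flag routes through `dist.get_sampler`, which presumably wraps `torch.utils.data.DistributedSampler` so that each rank draws a distinct shard of the dataset. A plain-PyTorch sketch of the equivalent behavior, assuming the process group is already initialized (`make_multigpu_dataloader` is an illustrative name, not Composer API):

```python
import torch.distributed as torch_dist
from torch.utils.data import DataLoader, DistributedSampler

def make_multigpu_dataloader(dataset, batch_size: int = 4) -> DataLoader:
    # Shard the dataset across ranks: each process sees a disjoint
    # slice per epoch, which is what dist.get_sampler arranges above.
    sampler = DistributedSampler(
        dataset,
        num_replicas=torch_dist.get_world_size(),
        rank=torch_dist.get_rank(),
    )
    # A DataLoader accepts either `shuffle` or `sampler`, not both.
    return DataLoader(dataset, batch_size=batch_size, sampler=sampler)
```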
43 changes: 28 additions & 15 deletions tests/algorithms/test_algorithm_resumption.py

@@ -10,21 +10,27 @@
 import torch
 
 from composer import Algorithm, Trainer
-from composer.algorithms import SAM, GyroDropout, LayerFreezing, SeqLengthWarmup, StochasticDepth
+from composer.algorithms import SAM, SWA, GyroDropout, LayerFreezing, SeqLengthWarmup, StochasticDepth
+from composer.utils import dist
 from tests.algorithms.algorithm_settings import get_alg_dataloader, get_alg_kwargs, get_alg_model, get_algs_with_marks
 from tests.common import deep_compare
+from tests.common.markers import world_size
 
 
 @pytest.mark.gpu
 @pytest.mark.parametrize('alg_cls', get_algs_with_marks())
+@pytest.mark.filterwarnings('ignore:Detected call of `lr_scheduler.step()'
+                            )  # optimizer.step() sometimes skipped when NaN/inf on low batch size
+@world_size(1, 2)
 def test_algorithm_resumption(
     tmp_path: pathlib.Path,
     alg_cls: Type[Algorithm],
+    world_size,
 ):
     folder1 = os.path.join(tmp_path, 'folder1')
     folder2 = os.path.join(tmp_path, 'folder2')
     os.makedirs(folder1, exist_ok=True)
     os.makedirs(folder2, exist_ok=True)
 
     model = get_alg_model(alg_cls)
     alg_kwargs = get_alg_kwargs(alg_cls)
@@ -40,20 +46,24 @@ def test_algorithm_resumption(
     if alg_cls is GyroDropout:
         pytest.xfail('GyroDropoutLayer is not implemented in a way that allows correct resumption.')
 
+    if alg_cls is SWA and world_size > 1:
+        pytest.xfail('SWA is not implemented in a way that is compatible with correct resumption on multiple devices.')
+
     optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
     scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=5)
 
     shared_config = {
         'max_duration': '2ep',
         'save_filename': 'ep{epoch}-rank{rank}',
-        'train_subset_num_batches': 4,
         'save_interval': '1ep',
+        'train_subset_num_batches': 2,
+        'precision': 'amp_fp16',
     }
 
+    train_dataloader = get_alg_dataloader(alg_cls) if world_size == 1 else get_alg_dataloader(alg_cls, multigpu=True)
     # train model once, saving checkpoints every epoch
     trainer1 = Trainer(
         model=model,
-        train_dataloader=get_alg_dataloader(alg_cls),
+        train_dataloader=train_dataloader,
         optimizers=optimizer,
         schedulers=scheduler,
         save_folder=folder1,
@@ -74,9 +84,11 @@ def test_algorithm_resumption(
     # when reloading.
     if alg_cls is SeqLengthWarmup:
         alg._activated = True  # type: ignore
+
+    train_dataloader = get_alg_dataloader(alg_cls) if world_size == 1 else get_alg_dataloader(alg_cls, multigpu=True)
     trainer2 = Trainer(
         model=copied_model,
-        train_dataloader=get_alg_dataloader(alg_cls),
+        train_dataloader=train_dataloader,
         load_path=os.path.join(folder1, 'ep1-rank{rank}'),
         load_weights_only=False,
         load_strict_model_weights=False,
@@ -87,20 +99,21 @@ def test_algorithm_resumption(
         **shared_config,
     )
     trainer2.fit()
 
     # check that the checkpoints are equal
-    _assert_checkpoints_equal(
-        file1=os.path.join(folder1, 'ep2-rank0'),
-        file2=os.path.join(folder2, 'ep2-rank0'),
-    )
+    if world_size == 1 or dist.get_global_rank() == 0:
+        _assert_checkpoints_equal(
+            file1=os.path.join(folder1, 'ep2-rank0'),
+            file2=os.path.join(folder2, 'ep2-rank0'),
+        )
 
     # check that different epoch checkpoints are _not_ equal
     # this ensures that the model weights are being updated.
-    with pytest.raises(AssertionError):
-        _assert_model_weights_equal(
-            file1=os.path.join(folder1, 'ep1-rank0'),
-            file2=os.path.join(folder1, 'ep2-rank0'),
-        )
+    if world_size == 1 or dist.get_global_rank() == 0:
+        with pytest.raises(AssertionError):
+            _assert_model_weights_equal(
+                file1=os.path.join(folder1, 'ep1-rank0'),
+                file2=os.path.join(folder1, 'ep2-rank0'),
+            )
 
 
 def _assert_checkpoints_equal(file1, file2):
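
For reference, the `@world_size(1, 2)` decorator imported from `tests.common.markers` is not shown in this diff. A plausible sketch of how such a marker could parametrize the test (illustrative only, not the repo's actual implementation):

```python
import pytest

def world_size(*sizes: int):
    """Parametrize a test over world sizes, tagging multi-GPU variants
    with a custom marker so CI can schedule them on multi-device runners."""
    return pytest.mark.parametrize(
        'world_size',
        [pytest.param(s, marks=pytest.mark.world_size(s)) for s in sizes],
    )
```

Note also the rank guards added at the end of the test: `if world_size == 1 or dist.get_global_rank() == 0` keeps the checkpoint-file comparisons on rank 0 only, since every rank would otherwise assert against the same `rank0` files.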