Synthetic Datasets and Subset Sampling #110

Merged — 67 commits, Dec 3, 2021

Commits
6357f2e
Added `run_event` to callback
ravi-mosaicml Nov 15, 2021
f395df4
Removed callback helper methods
ravi-mosaicml Nov 16, 2021
0f1aa69
Fixed tests
ravi-mosaicml Nov 16, 2021
06cac4b
Formatting
ravi-mosaicml Nov 16, 2021
d886af6
Addressed PR feedback
ravi-mosaicml Nov 18, 2021
9644ad9
Fixed tests
ravi-mosaicml Nov 18, 2021
cf5e533
Formatting
ravi-mosaicml Nov 18, 2021
b1bf400
Fixed _run_event
ravi-mosaicml Nov 18, 2021
9bffe3b
Merge branch 'dev' into ravi/run_event
ravi-mosaicml Nov 19, 2021
4ed9f4f
Formatting
ravi-mosaicml Nov 19, 2021
75944eb
Removed ip
ravi-mosaicml Nov 19, 2021
c8ccb49
Merge branch 'dev' into ravi/run_event
ravi-mosaicml Nov 22, 2021
5214f39
Supporting both styles for callbacks
ravi-mosaicml Nov 23, 2021
47158fb
Minimizing Diff
ravi-mosaicml Nov 23, 2021
35faa29
Fixed tests
ravi-mosaicml Nov 23, 2021
d20c914
Merge branch 'dev' into ravi/run_event
ravi-mosaicml Nov 23, 2021
254bd51
Merge branch 'dev' into ravi/run_event
ravi-mosaicml Nov 23, 2021
f3aa6bd
Remove the composer.trainer.ddp class
ravi-mosaicml Nov 23, 2021
a28ce89
Merge branch 'ravi/run_event' into ravi/ddp_global
ravi-mosaicml Nov 23, 2021
c30a274
Merge branch 'dev' into ravi/run_event
ravi-mosaicml Nov 23, 2021
0509df5
Merge branch 'ravi/run_event' into ravi/ddp_global
ravi-mosaicml Nov 23, 2021
20059ac
Dataset and Dataloader Upgrades
ravi-mosaicml Nov 25, 2021
42be8f4
Fixed most tests and updated some docs
ravi-mosaicml Nov 30, 2021
0a6dadc
Merge branch 'dev' into ravi/remove_dataloader_spec
ravi-mosaicml Nov 30, 2021
bd72dea
Dataloader Upgrades
ravi-mosaicml Nov 30, 2021
80af818
Merge branch 'ravi/dataloaders_in_trainer' into ravi/remove_dataloade…
ravi-mosaicml Nov 30, 2021
c94ac5d
Cleaned up diff
ravi-mosaicml Dec 1, 2021
f2bb206
Copied device changes
ravi-mosaicml Dec 1, 2021
52258bf
Simplifying diff
ravi-mosaicml Dec 1, 2021
6b628eb
Fixed common dataset fields
ravi-mosaicml Dec 1, 2021
17e0a10
Fixes
ravi-mosaicml Dec 1, 2021
40a804f
Merge branch 'ravi/dataloaders_in_trainer' into ravi/remove_dataloade…
ravi-mosaicml Dec 1, 2021
e02250c
Removed extra newline
ravi-mosaicml Dec 1, 2021
025c13a
Added tests and docstrings
ravi-mosaicml Dec 1, 2021
c2eda28
Fixed license header
ravi-mosaicml Dec 1, 2021
9750fef
Increased num_classes to prevent ddp flakiness
ravi-mosaicml Dec 1, 2021
2baf762
Use sublcassing for optional fields
ravi-mosaicml Dec 1, 2021
df6e9ad
Removed prefetching in cuda streams
ravi-mosaicml Dec 1, 2021
5d07292
Added missing newline
ravi-mosaicml Dec 1, 2021
61e9f68
Merge branch 'ravi/dataloaders_in_trainer' into ravi/remove_dataloade…
ravi-mosaicml Dec 1, 2021
1fec74b
Fixed test cutmix
ravi-mosaicml Dec 1, 2021
7507898
Fixed docstrings
ravi-mosaicml Dec 1, 2021
5d852ff
Fixed docstrings
ravi-mosaicml Dec 1, 2021
f54b545
Merge branch 'dev' into ravi/dataloaders_in_trainer
ravi-mosaicml Dec 1, 2021
4ff04c2
Fixed formatting
ravi-mosaicml Dec 1, 2021
787df7d
Added in Dataloader to hparams
ravi-mosaicml Dec 1, 2021
cc3a470
Merge branch 'ravi/dataloaders_in_trainer' into ravi/remove_dataloade…
ravi-mosaicml Dec 1, 2021
e03e1e7
Updated dataset docs to reflect removed dataloader spec
ravi-mosaicml Dec 1, 2021
467d586
Fixed tests
ravi-mosaicml Dec 1, 2021
30cd7b8
Fixed formatting
ravi-mosaicml Dec 1, 2021
15a5582
Removed prefetch
ravi-mosaicml Dec 1, 2021
e7996b3
Merge branch 'ravi/dataloaders_in_trainer' into ravi/remove_dataloade…
ravi-mosaicml Dec 1, 2021
cf4d449
Simplified
ravi-mosaicml Dec 2, 2021
37ac41d
Update composer/trainer/trainer_hparams.py
ravi-mosaicml Dec 2, 2021
3395548
Update tests/test_dataset_registry.py
ravi-mosaicml Dec 2, 2021
2c06b86
Addressed PR Comments
ravi-mosaicml Dec 2, 2021
2e156a4
Removed profiler import
ravi-mosaicml Dec 2, 2021
97aab84
Addressed remaining PR feedback
ravi-mosaicml Dec 2, 2021
d32fb9b
Fixed tests
ravi-mosaicml Dec 2, 2021
7df6d9a
Fixed tests
ravi-mosaicml Dec 2, 2021
53cc738
Merge branch 'dev' into ravi/dataloaders_in_trainer
ravi-mosaicml Dec 3, 2021
ae398e8
Merge branch 'ravi/dataloaders_in_trainer' into ravi/remove_dataloade…
ravi-mosaicml Dec 3, 2021
b11abb7
Updated docs
ravi-mosaicml Dec 3, 2021
80f88d5
Create samplers before ddp
ravi-mosaicml Dec 3, 2021
a82aba6
Merge branch 'ravi/dataloaders_in_trainer' into ravi/remove_dataloade…
ravi-mosaicml Dec 3, 2021
0259704
Merge branch 'dev' into ravi/remove_dataloader_spec
ravi-mosaicml Dec 3, 2021
16062af
Minimizing diff
ravi-mosaicml Dec 3, 2021
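
The diffs that follow wire two features through the dataset hparams: opt-in synthetic data (SyntheticHparamsMixin / SyntheticBatchPairDataset) and subset sampling (subset_num_batches). As orientation, here is a minimal sketch of constructing the new SyntheticBatchPairDataset directly; the keyword names are taken from the cifar10.py diff below, while the concrete sizes and the device string are illustrative assumptions:

from composer.datasets import MemoryFormat, SyntheticBatchPairDataset

# Keyword names appear in the cifar10.py diff below; values are made up.
dataset = SyntheticBatchPairDataset(
    total_dataset_size=320,            # e.g. subset_num_batches * batch_size
    data_shape=[3, 32, 32],            # CIFAR-10-shaped samples
    num_classes=10,
    num_unique_samples_to_create=100,  # samples are reused to fill the dataset
    device="cpu",                      # assumed to accept a torch device string
    memory_format=MemoryFormat.CONTIGUOUS_FORMAT,
)
x, y = dataset[0]                      # one (data, label) pair, generated in memory

The per-file diffs follow.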
composer/__init__.py — 9 changes: 8 additions & 1 deletion

@@ -1,5 +1,13 @@
 # Copyright 2021 MosaicML. All Rights Reserved.

+from composer import algorithms as algorithms
+from composer import callbacks as callbacks
+from composer import datasets as datasets
+from composer import loggers as loggers
+from composer import models as models
+from composer import optim as optim
+from composer import trainer as trainer
+from composer import utils as utils
 from composer.algorithms import functional as functional
 from composer.core import Algorithm as Algorithm
 from composer.core import Callback as Callback
@@ -8,5 +16,4 @@
 from composer.core import Logger as Logger
 from composer.core import State as State
 from composer.core import types as types
-from composer.datasets import DataloaderSpec as DataloaderSpec
 from composer.trainer import Trainer as Trainer
composer/callbacks/benchmarker.py — 2 changes: 1 addition & 1 deletion

@@ -7,8 +7,8 @@
 import warnings
 from typing import Sequence

-from composer import Logger, State
 from composer.callbacks.callback_hparams import BenchmarkerHparams
+from composer.core import Logger, State
 from composer.core.callback import Callback
 from composer.core.types import BreakEpochException
 from composer.utils import ddp
composer/callbacks/speed_monitor.py — 2 changes: 1 addition & 1 deletion

@@ -6,8 +6,8 @@
 from collections import deque
 from typing import Deque, Optional

-from composer import Logger, State
 from composer.callbacks.callback_hparams import SpeedMonitorHparams
+from composer.core import Logger, State
 from composer.core.callback import RankZeroCallback
 from composer.core.types import StateDict
 from composer.utils import ddp
composer/callbacks/torch_profiler.py — 7 changes: 2 additions & 5 deletions

@@ -4,20 +4,17 @@

 import warnings
 from dataclasses import asdict, dataclass
-from typing import TYPE_CHECKING, Optional
+from typing import Optional

 import torch.profiler
 from torch.profiler.profiler import ProfilerAction

-from composer import Callback
 from composer.callbacks.callback_hparams import TorchProfilerHparams
+from composer.core import Callback, Logger, State
 from composer.core.types import StateDict
 from composer.utils.ddp import get_global_rank
 from composer.utils.run_directory import get_relative_to_run_directory

-if TYPE_CHECKING:
-    from composer.core import Logger, State
-
 _PROFILE_MISSING_ERROR = "The profiler has not been setup. Please call profiler.training_start() before training starts."
composer/core/types.py — 8 changes: 8 additions & 0 deletions

@@ -20,6 +20,7 @@
 from composer.core.precision import Precision as Precision
 from composer.core.serializable import Serializable as Serializable
 from composer.core.state import State as State
+from composer.utils.string_enum import StringEnum

 try:
     from typing import Protocol
@@ -153,3 +154,10 @@ def __len__(self) -> int:
 TDeviceTransformFn = Callable[[Batch], Batch]

 StateDict = Dict[str, Any]
+
+
+class MemoryFormat(StringEnum):
+    CONTIGUOUS_FORMAT = "contiguous_format"
+    CHANNELS_LAST = "channels_last"
+    CHANNELS_LAST_3D = "channels_last_3d"
+    PRESERVE_FORMAT = "preserve_format"
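
Since the MemoryFormat values mirror the attribute names on the torch module, a caller can translate the enum with getattr. A minimal sketch — the helper below is hypothetical, not part of this PR:

import torch

from composer.core.types import MemoryFormat

def to_torch_memory_format(fmt: MemoryFormat) -> torch.memory_format:
    # MemoryFormat.CHANNELS_LAST.value == "channels_last", the name of
    # the corresponding torch.memory_format object on the torch module.
    return getattr(torch, fmt.value)

x = torch.randn(8, 3, 32, 32)  # channels_last requires a 4-D tensor
x = x.to(memory_format=to_torch_memory_format(MemoryFormat.CHANNELS_LAST))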
composer/datasets/__init__.py — 6 changes: 4 additions & 2 deletions

@@ -7,9 +7,11 @@
 from composer.datasets.dataloader import WrappedDataLoader as WrappedDataLoader
 from composer.datasets.hparams import DataloaderSpec as DataloaderSpec
 from composer.datasets.hparams import DatasetHparams as DatasetHparams
+from composer.datasets.hparams import SyntheticHparamsMixin as SyntheticHparamsMixin
 from composer.datasets.imagenet import ImagenetDatasetHparams as ImagenetDatasetHparams
 from composer.datasets.lm_datasets import LMDatasetHparams as LMDatasetHparams
 from composer.datasets.mnist import MNISTDatasetHparams as MNISTDatasetHparams
-from composer.datasets.synthetic import SyntheticDataset as SyntheticDataset
-from composer.datasets.synthetic import SyntheticDatasetHparams as SyntheticDatasetHparams
+from composer.datasets.synthetic import MemoryFormat as MemoryFormat
+from composer.datasets.synthetic import SyntheticBatchPairDataset as SyntheticBatchPairDataset
+from composer.datasets.synthetic import SyntheticDataLabelType as SyntheticDataLabelType
 from composer.datasets.synthetic import SyntheticDataType as SyntheticDataType
composer/datasets/brats.py — 26 changes: 11 additions & 15 deletions

@@ -7,19 +7,20 @@

 import numpy as np
 import torch
+import torch.utils.data
 import torchvision
 import yahp as hp
-from torch.utils.data import Dataset

 from composer.core.types import DataLoader, Dataset
 from composer.datasets.dataloader import DataloaderHparams
 from composer.datasets.hparams import DatasetHparams
 from composer.utils import ddp
+from composer.utils.data import get_subset_dataset

 PATCH_SIZE = [1, 192, 160]


-def my_collate(batch):
+def _my_collate(batch):
     """Custom collate function to handle images with different depths.

     """
@@ -34,30 +35,25 @@ class BratsDatasetHparams(DatasetHparams):
     """Defines an instance of the BraTS dataset for image segmentation.

     Parameters:
-        is_train (bool): Whether to load the training or validation dataset.
-        datadir (str): Data directory to use.
-        download (bool): Whether to download the dataset, if needed.
-        drop_last (bool): Whether to drop the last samples for the last batch.
-        shuffle (bool): Whether to shuffle the dataset for each epoch.
         oversampling (float): The oversampling ratio to use.
     """

-    is_train: bool = hp.required("whether to load the training or validation dataset")
-    datadir: str = hp.required("data directory")
-    download: bool = hp.required("whether to download the dataset, if needed")
-    drop_last: bool = hp.optional("Whether to drop the last samples for the last batch", default=True)
-    shuffle: bool = hp.optional("Whether to shuffle the dataset for each epoch", default=True)
     oversampling: float = hp.optional("oversampling", default=0.33)

     def initialize_object(self, batch_size: int, dataloader_hparams: DataloaderHparams) -> DataLoader:

-        datadir = self.datadir
         oversampling = self.oversampling

-        x_train, y_train, x_val, y_val = get_data_split(datadir)
+        if self.datadir is None:
+            raise ValueError("datadir must be specified if self.synthetic is False")
+        x_train, y_train, x_val, y_val = get_data_split(self.datadir)
         dataset = PytTrain(x_train, y_train, oversampling) if self.is_train else PytVal(x_val, y_val)
-        collate_fn = None if self.is_train else my_collate
+        if self.subset_num_batches is not None:
+            size = batch_size * self.subset_num_batches * ddp.get_world_size()
+            dataset = get_subset_dataset(size, dataset)
+        collate_fn = None if self.is_train else _my_collate
         sampler = ddp.get_sampler(dataset, drop_last=self.drop_last, shuffle=self.shuffle)

         return dataloader_hparams.initialize_object(
             dataset=dataset,
             batch_size=batch_size,
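
The subset logic above sizes the subset so that every rank can draw subset_num_batches full batches. A sketch of that arithmetic, with get_subset_dataset's behavior assumed to be a simple truncation via torch.utils.data.Subset:

import torch
from torch.utils.data import Subset, TensorDataset

def get_subset_dataset_sketch(size: int, dataset):
    # Assumed behavior of composer.utils.data.get_subset_dataset:
    # keep the first `size` samples, failing loudly if there are too few.
    if len(dataset) < size:
        raise ValueError(f"requested {size} samples but the dataset only has {len(dataset)}")
    return Subset(dataset, range(size))

full = TensorDataset(torch.randn(1000, 4))
# batch_size=4, subset_num_batches=10, world_size=8 -> 4 * 10 * 8 = 320 samples
subset = get_subset_dataset_sketch(4 * 10 * 8, full)
assert len(subset) == 320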
composer/datasets/cifar10.py — 83 changes: 49 additions & 34 deletions

@@ -2,59 +2,74 @@

 from dataclasses import dataclass

+import torch.utils.data
 import yahp as hp
 from torchvision import transforms
 from torchvision.datasets import CIFAR10

 from composer.core.types import DataLoader
 from composer.datasets.dataloader import DataloaderHparams
-from composer.datasets.hparams import DatasetHparams
+from composer.datasets.hparams import DatasetHparams, SyntheticHparamsMixin
+from composer.datasets.synthetic import SyntheticBatchPairDataset
 from composer.utils import ddp
+from composer.utils.data import get_subset_dataset


 @dataclass
-class CIFAR10DatasetHparams(DatasetHparams):
+class CIFAR10DatasetHparams(DatasetHparams, SyntheticHparamsMixin):
     """Defines an instance of the CIFAR-10 dataset for image classification.

     Parameters:
-        is_train (bool): Whether to load the training or validation dataset.
-        datadir (str): Data directory to use.
         download (bool): Whether to download the dataset, if needed.
-        drop_last (bool): Whether to drop the last samples for the last batch.
-        shuffle (bool): Whether to shuffle the dataset for each epoch.
     """

-    is_train: bool = hp.required("whether to load the training or validation dataset")
-    datadir: str = hp.required("data directory")
-    download: bool = hp.required("whether to download the dataset, if needed")
-    drop_last: bool = hp.optional("Whether to drop the last samples for the last batch", default=True)
-    shuffle: bool = hp.optional("Whether to shuffle the dataset for each epoch", default=True)
+    download: bool = hp.optional("whether to download the dataset, if needed", default=True)

     def initialize_object(self, batch_size: int, dataloader_hparams: DataloaderHparams) -> DataLoader:
         cifar10_mean, cifar10_std = [0.4914, 0.4822, 0.4465], [0.247, 0.243, 0.261]
-        datadir = self.datadir

-        if self.is_train:
-            transformation = transforms.Compose([
-                transforms.RandomCrop(32, padding=4),
-                transforms.RandomHorizontalFlip(),
-                transforms.ToTensor(),
-                transforms.Normalize(mean=cifar10_mean, std=cifar10_std),
-            ])
+        if self.use_synthetic:
+            if self.subset_num_batches is None:
+                raise ValueError("subset_num_batches is required if use_synthetic is True")
+            dataset = SyntheticBatchPairDataset(
+                total_dataset_size=self.subset_num_batches * batch_size,
+                data_shape=[3, 32, 32],
+                num_classes=10,
+                num_unique_samples_to_create=self.synthetic_num_unique_samples,
+                device=self.synthetic_device,
+                memory_format=self.synthetic_memory_format,
+            )
+            if self.shuffle:
+                sampler = torch.utils.data.RandomSampler(dataset)
+            else:
+                sampler = torch.utils.data.SequentialSampler(dataset)
         else:
-            transformation = transforms.Compose([
-                transforms.ToTensor(),
-                transforms.Normalize(mean=cifar10_mean, std=cifar10_std),
-            ])
-
-        dataset = CIFAR10(
-            datadir,
-            train=self.is_train,
-            download=self.download,
-            transform=transformation,
-        )
-
-        sampler = ddp.get_sampler(dataset, drop_last=self.drop_last, shuffle=self.shuffle)
+            if self.datadir is None:
+                raise ValueError("datadir is required if use_synthetic is False")
+
+            if self.is_train:
+                transformation = transforms.Compose([
+                    transforms.RandomCrop(32, padding=4),
+                    transforms.RandomHorizontalFlip(),
+                    transforms.ToTensor(),
+                    transforms.Normalize(mean=cifar10_mean, std=cifar10_std),
+                ])
+            else:
+                transformation = transforms.Compose([
+                    transforms.ToTensor(),
+                    transforms.Normalize(mean=cifar10_mean, std=cifar10_std),
+                ])
+
+            dataset = CIFAR10(
+                self.datadir,
+                train=self.is_train,
+                download=self.download,
+                transform=transformation,
+            )
+            if self.subset_num_batches is not None:
+                size = batch_size * self.subset_num_batches * ddp.get_world_size()
+                dataset = get_subset_dataset(size, dataset)
+            sampler = ddp.get_sampler(dataset, drop_last=self.drop_last, shuffle=self.shuffle)

         return dataloader_hparams.initialize_object(dataset,
                                                     batch_size=batch_size,
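
Putting the two paths together, a hedged usage sketch of the synthetic branch. It assumes that is_train, shuffle, and drop_last now live on the shared DatasetHparams base class (as the "Fixed common dataset fields" commit suggests), and that DataloaderHparams exposes the standard DataLoader knobs — those field names are assumptions, not confirmed by this diff:

from composer.datasets import CIFAR10DatasetHparams
from composer.datasets.dataloader import DataloaderHparams

# Synthetic path: no datadir and no download needed.
hparams = CIFAR10DatasetHparams(
    is_train=True,
    use_synthetic=True,
    subset_num_batches=10,  # required when use_synthetic is True
)
loader = hparams.initialize_object(
    batch_size=32,
    dataloader_hparams=DataloaderHparams(  # field names are assumptions
        num_workers=2,
        prefetch_factor=2,
        persistent_workers=False,
        pin_memory=False,
        timeout=0,
    ),
)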
composer/datasets/dataloader.py — 4 changes: 2 additions & 2 deletions

@@ -113,7 +113,7 @@ def initialize_object(
         dataset: Dataset,
         *,
         batch_size: int,
-        sampler: torch.utils.data.Sampler[int],
+        sampler: Optional[torch.utils.data.Sampler[int]],
         drop_last: bool,
         collate_fn: Optional[Callable] = None,
         worker_init_fn: Optional[Callable] = None,
@@ -123,7 +123,7 @@
         Args:
             dataset (Dataset): The dataset.
             batch_size (int): The per-device batch size.
-            sampler (torch.utils.data.Sampler[int]): The sampler to use for the dataloader.
+            sampler (torch.utils.data.Sampler[int] or None): The sampler to use for the dataloader.
             drop_last (bool): Whether to drop the last batch if the number of
                 samples is not evenly divisible by the batch size.
             collate_fn (callable, optional): Custom collate function. Defaults to None.
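
The loosened signature mirrors torch's own contract, where sampler=None means "fall back to the default SequentialSampler or RandomSampler". A quick illustration against plain torch.utils.data:

import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.randn(64, 3), torch.randint(0, 10, (64,)))

# sampler=None is valid: DataLoader builds its default sampler internally,
# so callers of initialize_object no longer have to construct one themselves.
loader = DataLoader(dataset, batch_size=32, sampler=None, drop_last=False)
for batch_x, batch_y in loader:
    pass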