Streaming vision datasets (#284)
Implements streaming for vision datasets.
knighton authored Mar 8, 2022
1 parent 6116263 commit bb33b1a
Showing 22 changed files with 1,185 additions and 93 deletions.
15 changes: 9 additions & 6 deletions composer/datasets/__init__.py
@@ -1,28 +1,27 @@
# Copyright 2021 MosaicML. All Rights Reserved.

"""Datasets TODO -- more description.
:class:`DataloaderHparams` contains the :class:`torch.utils.data.DataLoader` settings that are common across both training and eval datasets:
* ``num_workers``
* ``prefetch_factor``
* ``persistent_workers``
* ``pin_memory``
* ``timeout``
Each :class:`DatasetHparams` is then responsible for settings such as:
* ``dataset``
* ``drop_last``
* ``shuffle``
* ``collate_fn``
A :class:`DatasetHparams` is responsible for returning a :class:`torch.utils.data.DataLoader` or a :class:`DataSpec`.
"""
from composer.datasets.ade20k import ADE20kDatasetHparams as ADE20kDatasetHparams
from composer.datasets.ade20k import ADE20kWebDatasetHparams as ADE20kWebDatasetHparams
from composer.datasets.brats import BratsDatasetHparams as BratsDatasetHparams
from composer.datasets.c4 import C4DatasetHparams as C4DatasetHparams
from composer.datasets.cifar10 import CIFAR10DatasetHparams as CIFAR10DatasetHparams
from composer.datasets.cifar import CIFAR10DatasetHparams as CIFAR10DatasetHparams
from composer.datasets.cifar import CIFAR10WebDatasetHparams as CIFAR10WebDatasetHparams
from composer.datasets.cifar import CIFAR20WebDatasetHparams as CIFAR20WebDatasetHparams
from composer.datasets.cifar import CIFAR100WebDatasetHparams as CIFAR100WebDatasetHparams
from composer.datasets.coco import COCODatasetHparams as COCODatasetHparams
from composer.datasets.dataloader import DataloaderHparams as DataloaderHparams
from composer.datasets.dataloader import WrappedDataLoader as WrappedDataLoader
@@ -31,9 +30,13 @@
from composer.datasets.glue import GLUEHparams as GLUEHparams
from composer.datasets.hparams import DatasetHparams as DatasetHparams
from composer.datasets.hparams import SyntheticHparamsMixin as SyntheticHparamsMixin
from composer.datasets.hparams import WebDatasetHparams as WebDatasetHparams
from composer.datasets.imagenet import Imagenet1kWebDatasetHparams as Imagenet1kWebDatasetHparams
from composer.datasets.imagenet import ImagenetDatasetHparams as ImagenetDatasetHparams
from composer.datasets.imagenet import TinyImagenet200WebDatasetHparams as TinyImagenet200WebDatasetHparams
from composer.datasets.lm_datasets import LMDatasetHparams as LMDatasetHparams
from composer.datasets.mnist import MNISTDatasetHparams as MNISTDatasetHparams
from composer.datasets.mnist import MNISTWebDatasetHparams as MNISTWebDatasetHparams
from composer.datasets.synthetic import MemoryFormat as MemoryFormat
from composer.datasets.synthetic import SyntheticBatchPairDataset as SyntheticBatchPairDataset
from composer.datasets.synthetic import SyntheticDataLabelType as SyntheticDataLabelType
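The module docstring above splits responsibilities between DataloaderHparams (the num_workers, prefetch_factor, persistent_workers, pin_memory, and timeout settings shared by train and eval) and each DatasetHparams subclass (dataset, shuffle, drop_last). The following is a minimal sketch, not part of this commit, of how the two might be wired together for CIFAR-10; argument names are taken from the docstring and the classes in this diff, and exact signatures or defaults may differ in this release.

from composer.datasets.cifar import CIFAR10DatasetHparams
from composer.datasets.dataloader import DataloaderHparams

# Settings shared by every dataloader (values here are illustrative).
dataloader_hparams = DataloaderHparams(num_workers=8,
                                       prefetch_factor=2,
                                       persistent_workers=True,
                                       pin_memory=True,
                                       timeout=0)

# Dataset-specific settings live on the DatasetHparams subclass.
dataset_hparams = CIFAR10DatasetHparams(is_train=True,
                                        datadir='/tmp/cifar10',
                                        download=True,
                                        shuffle=True,
                                        drop_last=True)

# initialize_object returns a torch DataLoader (or a DataSpec for some datasets).
train_dataloader = dataset_hparams.initialize_object(batch_size=1024,
                                                     dataloader_hparams=dataloader_hparams)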
107 changes: 105 additions & 2 deletions composer/datasets/ade20k.py
@@ -14,7 +14,7 @@
from torchvision import transforms

from composer.core.types import DataSpec
from composer.datasets.hparams import DatasetHparams, SyntheticHparamsMixin
from composer.datasets.hparams import DatasetHparams, SyntheticHparamsMixin, WebDatasetHparams
from composer.datasets.imagenet import IMAGENET_CHANNEL_MEAN, IMAGENET_CHANNEL_STD
from composer.datasets.synthetic import SyntheticBatchPairDataset
from composer.datasets.utils import NormalizationFn, pil_image_collate
@@ -310,7 +310,8 @@ def initialize_object(self, batch_size, dataloader_hparams) -> DataSpec:
RandomHFlipPair(),
)

# Photometric distortion values come from mmsegmentation: https://github.com/open-mmlab/mmsegmentation/blob/master/mmseg/datasets/pipelines/transforms.py#L837
# Photometric distortion values come from mmsegmentation:
# https://github.com/open-mmlab/mmsegmentation/blob/master/mmseg/datasets/pipelines/transforms.py#L837
r_mean, g_mean, b_mean = IMAGENET_CHANNEL_MEAN
image_transforms = torch.nn.Sequential(
PhotometricDistoration(brightness=32. / 255, contrast=0.5, saturation=0.5, hue=18. / 255),
@@ -344,3 +345,105 @@ def initialize_object(self, batch_size, dataloader_hparams) -> DataSpec:
collate_fn=collate_fn,
drop_last=self.drop_last),
device_transforms=device_transform_fn)


@dataclass
class ADE20kWebDatasetHparams(WebDatasetHparams):
"""Defines an instance of the ADE20k dataset for semantic segmentation.
Parameters:
remote (str): S3 bucket or root directory where dataset is stored.
name (str): Key used to determine where dataset is cached on local filesystem.
split (str): the dataset split to use, either 'train', 'val', or 'test'. Default is 'train'.
base_size (int): initial size of the image and target before other augmentations. Default is 512.
min_resize_scale (float): the minimum scale to which the samples can be rescaled. Default is 0.5.
max_resize_scale (float): the maximum scale to which the samples can be rescaled. Default is 2.0.
final_size (int): the final size of the image and target. Default is 512.
ignore_background (bool): if true, ignore the background class when calculating the training loss.
Default is true.
"""

remote: str = hp.optional('WebDataset S3 bucket name', default='s3://mosaicml-internal-dataset-ade20k')
name: str = hp.optional('WebDataset local cache name', default='ade20k')
split: str = hp.optional("Which split of the dataset to use. Either ['train', 'val', 'test']", default='train')
base_size: int = hp.optional("Initial size of the image and target before other augmentations", default=512)
min_resize_scale: float = hp.optional("Minimum value that the image and target can be scaled", default=0.5)
max_resize_scale: float = hp.optional("Maximum value that the image and target can be scaled", default=2.0)
final_size: int = hp.optional("Final size of the image and target", default=512)
ignore_background: bool = hp.optional("If true, ignore the background class in training loss", default=True)

def validate(self):
if self.split not in ['train', 'val', 'test']:
raise ValueError(f"split value {self.split} must be one of ['train', 'val', 'test'].")

if self.base_size <= 0:
raise ValueError("base_size cannot be zero or negative.")

if self.min_resize_scale <= 0:
raise ValueError("min_resize_scale cannot be zero or negative")

if self.max_resize_scale < self.min_resize_scale:
raise ValueError("max_resize_scale cannot be less than min_resize_scale")

def initialize_object(self, batch_size, dataloader_hparams) -> DataSpec:
from composer.datasets.webdataset import load_webdataset

self.validate()
# Define data transformations based on data split
if self.split == 'train':
both_transforms = torch.nn.Sequential(
RandomResizePair(min_scale=self.min_resize_scale,
max_scale=self.max_resize_scale,
base_size=(self.base_size, self.base_size)),
RandomCropPair(crop_size=(self.final_size, self.final_size)),
RandomHFlipPair(),
)

# Photometric distortion values come from mmsegmentation:
# https://github.com/open-mmlab/mmsegmentation/blob/master/mmseg/datasets/pipelines/transforms.py#L837
r_mean, g_mean, b_mean = IMAGENET_CHANNEL_MEAN
image_transforms = torch.nn.Sequential(
PhotometricDistoration(brightness=32. / 255, contrast=0.5, saturation=0.5, hue=18. / 255),
PadToSize(size=(self.final_size, self.final_size), fill=(int(r_mean), int(g_mean), int(b_mean))))

target_transforms = transforms.Compose([
PadToSize(size=(self.final_size, self.final_size), fill=0),
transforms.Grayscale(),
])
else:
both_transforms = None
image_transforms = transforms.Resize(size=(self.final_size, self.final_size),
interpolation=TF.InterpolationMode.BILINEAR)
target_transforms = transforms.Compose([
transforms.Resize(size=(self.final_size, self.final_size), interpolation=TF.InterpolationMode.NEAREST),
transforms.Grayscale(),
])

def map_fn(args):
x, y = args
if both_transforms:
x, y = both_transforms((x, y))
if image_transforms:
x = image_transforms(x)
if target_transforms:
y = target_transforms(y)
return x, y

preprocess = lambda dataset: dataset.decode('pil').to_tuple('scene.jpg', 'annotation.png').map(map_fn)
dataset = load_webdataset(self.remote, self.name, self.split, self.webdataset_cache_dir,
self.webdataset_cache_verbose, self.shuffle, self.shuffle_buffer, preprocess,
dist.get_world_size(), dataloader_hparams.num_workers, batch_size, self.drop_last)

collate_fn = pil_image_collate
device_transform_fn = NormalizationFn(mean=IMAGENET_CHANNEL_MEAN,
std=IMAGENET_CHANNEL_STD,
ignore_background=self.ignore_background)

return DataSpec(dataloader=dataloader_hparams.initialize_object(
dataset=dataset,
batch_size=batch_size,
sampler=None,
drop_last=self.drop_last,
collate_fn=collate_fn,
),
device_transforms=device_transform_fn)
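As a usage illustration, here is a hedged sketch, not part of this commit, of building a streaming ADE20k training DataSpec from the class above; the bucket URL is a hypothetical stand-in for a reachable WebDataset location, and the DataloaderHparams arguments mirror the fields listed in the package docstring, so exact signatures may differ.

from composer.datasets.ade20k import ADE20kWebDatasetHparams
from composer.datasets.dataloader import DataloaderHparams

ade20k_hparams = ADE20kWebDatasetHparams(remote='s3://my-bucket/ade20k',  # hypothetical bucket
                                         name='ade20k',
                                         split='train',
                                         base_size=512,
                                         min_resize_scale=0.5,
                                         max_resize_scale=2.0,
                                         final_size=512,
                                         shuffle=True,
                                         drop_last=True)

dataloader_hparams = DataloaderHparams(num_workers=8, prefetch_factor=2,
                                       persistent_workers=True, pin_memory=True, timeout=0)

# Returns a DataSpec; NormalizationFn with ImageNet channel statistics is applied as a
# device transform rather than inside the dataloader workers.
train_spec = ade20k_hparams.initialize_object(batch_size=128,
                                              dataloader_hparams=dataloader_hparams)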
165 changes: 165 additions & 0 deletions composer/datasets/cifar.py
@@ -0,0 +1,165 @@
# Copyright 2021 MosaicML. All Rights Reserved.

from dataclasses import dataclass
from typing import List

import yahp as hp
from torchvision import transforms
from torchvision.datasets import CIFAR10

from composer.core.types import DataLoader
from composer.datasets.dataloader import DataloaderHparams
from composer.datasets.hparams import DatasetHparams, SyntheticHparamsMixin, WebDatasetHparams
from composer.datasets.synthetic import SyntheticBatchPairDataset
from composer.utils import dist


@dataclass
class CIFAR10DatasetHparams(DatasetHparams, SyntheticHparamsMixin):
"""Defines an instance of the CIFAR-10 dataset for image classification.
Parameters:
download (bool): Whether to download the dataset, if needed.
"""
download: bool = hp.optional("whether to download the dataset, if needed", default=True)

def initialize_object(self, batch_size: int, dataloader_hparams: DataloaderHparams) -> DataLoader:
cifar10_mean = 0.4914, 0.4822, 0.4465
cifar10_std = 0.247, 0.243, 0.261

if self.use_synthetic:
total_dataset_size = 50_000 if self.is_train else 10_000
dataset = SyntheticBatchPairDataset(
total_dataset_size=total_dataset_size,
data_shape=[3, 32, 32],
num_classes=10,
num_unique_samples_to_create=self.synthetic_num_unique_samples,
device=self.synthetic_device,
memory_format=self.synthetic_memory_format,
)

else:
if self.datadir is None:
raise ValueError("datadir is required if use_synthetic is False")

if self.is_train:
transformation = transforms.Compose([
transforms.RandomCrop(32, padding=4),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize(cifar10_mean, cifar10_std),
])
else:
transformation = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(cifar10_mean, cifar10_std),
])

dataset = CIFAR10(
self.datadir,
train=self.is_train,
download=self.download,
transform=transformation,
)
sampler = dist.get_sampler(dataset, drop_last=self.drop_last, shuffle=self.shuffle)

return dataloader_hparams.initialize_object(dataset,
batch_size=batch_size,
sampler=sampler,
drop_last=self.drop_last)
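Because of the use_synthetic branch above, a CIFAR-10 smoke test can run without downloading any data. Below is a rough sketch of that path, assuming the SyntheticHparamsMixin field names referenced in initialize_object; exact field names and defaults may differ in this release.

from composer.datasets.cifar import CIFAR10DatasetHparams
from composer.datasets.dataloader import DataloaderHparams

# Synthetic data path: no datadir or download needed (illustrative values).
smoke_hparams = CIFAR10DatasetHparams(is_train=True,
                                      use_synthetic=True,
                                      synthetic_num_unique_samples=100,
                                      shuffle=False,
                                      drop_last=True)

loader = smoke_hparams.initialize_object(
    batch_size=256,
    dataloader_hparams=DataloaderHparams(num_workers=2, prefetch_factor=2,
                                         persistent_workers=False, pin_memory=False, timeout=0))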


@dataclass
class CIFARWebDatasetHparams(WebDatasetHparams):
"""Common functionality for CIFAR WebDatasets.
Parameters:
remote (str): S3 bucket or root directory where dataset is stored.
name (str): Key used to determine where dataset is cached on local filesystem.
n_train_samples (int): Number of training samples.
n_val_samples (int): Number of validation samples.
height (int): Sample image height in pixels.
width (int): Sample image width in pixels.
n_classes (int): Number of output classes.
channel_means (list of float): Channel means for normalization.
channel_stds (list of float): Channel stds for normalization.
"""

remote: str = hp.optional('WebDataset S3 bucket name', default='')
name: str = hp.optional('WebDataset local cache name', default='')

n_train_samples: int = hp.optional('Number of samples in training split', default=0)
n_val_samples: int = hp.optional('Number of samples in validation split', default=0)
height: int = hp.optional('Image height', default=32)
width: int = hp.optional('Image width', default=32)
n_classes: int = hp.optional('Number of output classes', default=0)
channel_means: List[float] = hp.optional('Mean per image channel', default=(0, 0, 0))
channel_stds: List[float] = hp.optional('Std per image channel', default=(0, 0, 0))

def initialize_object(self, batch_size: int, dataloader_hparams: DataloaderHparams) -> DataLoader:
from composer.datasets.webdataset import load_webdataset

if self.is_train:
split = 'train'
transform = transforms.Compose([
transforms.RandomCrop((self.height, self.width), (self.height // 8, self.width // 8)),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize(self.channel_means, self.channel_stds),
])
else:
split = 'val'
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(self.channel_means, self.channel_stds),
])
preprocess = lambda dataset: dataset.decode('pil').map_dict(jpg=transform).to_tuple('jpg', 'cls')
dataset = load_webdataset(self.remote, self.name, split, self.webdataset_cache_dir,
self.webdataset_cache_verbose, self.shuffle, self.shuffle_buffer, preprocess,
dist.get_world_size(), dataloader_hparams.num_workers, batch_size, self.drop_last)
return dataloader_hparams.initialize_object(dataset,
batch_size=batch_size,
sampler=None,
drop_last=self.drop_last)
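The preprocess closure above follows webdataset's composable pipeline style: decode each sample's image bytes to PIL, apply the torchvision transform to the 'jpg' field, and emit (image, label) tuples. For orientation, here is a rough standalone sketch of an equivalent pipeline built directly with the webdataset package; the shard path pattern and the transform are hypothetical stand-ins for what load_webdataset resolves and caches.

import webdataset as wds
from torchvision import transforms

# Hypothetical local shard pattern; the real shards are fetched and cached by load_webdataset.
shards = '/tmp/webdataset-cache/cifar10/train_{000..015}.tar'

transform = transforms.ToTensor()

# decode('pil') turns image bytes into PIL images, map_dict applies the transform to the
# 'jpg' field, and to_tuple yields (image, class index) pairs for the dataloader.
dataset = wds.WebDataset(shards).decode('pil').map_dict(jpg=transform).to_tuple('jpg', 'cls')

image, label = next(iter(dataset))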


@dataclass
class CIFAR10WebDatasetHparams(CIFARWebDatasetHparams):
"""Defines an instance of the CIFAR-10 WebDataset for image classification."""

remote: str = hp.optional('WebDataset S3 bucket name', default='s3://mosaicml-internal-dataset-cifar10')
name: str = hp.optional('WebDataset local cache name', default='cifar10')

n_train_samples: int = hp.optional('Number of samples in training split', default=50_000)
n_val_samples: int = hp.optional('Number of samples in validation split', default=10_000)
n_classes: int = hp.optional('Number of output classes', default=10)
channel_means: List[float] = hp.optional('Mean per image channel', default=(0.4914, 0.4822, 0.4465))
channel_stds: List[float] = hp.optional('Std per image channel', default=(0.247, 0.243, 0.261))


@dataclass
class CIFAR20WebDatasetHparams(CIFARWebDatasetHparams):
"""Defines an instance of the CIFAR-20 WebDataset for image classification."""

remote: str = hp.optional('WebDataset S3 bucket name', default='s3://mosaicml-internal-dataset-cifar20')
name: str = hp.optional('WebDataset local cache name', default='cifar20')

n_train_samples: int = hp.optional('Number of samples in training split', default=50_000)
n_val_samples: int = hp.optional('Number of samples in validation split', default=10_000)
n_classes: int = hp.optional('Number of output classes', default=20)
channel_means: List[float] = hp.optional('Mean per image channel', default=(0.5071, 0.4867, 0.4408))
channel_stds: List[float] = hp.optional('Std per image channel', default=(0.2675, 0.2565, 0.2761))


@dataclass
class CIFAR100WebDatasetHparams(CIFARWebDatasetHparams):
"""Defines an instance of the CIFAR-100 WebDataset for image classification."""

remote: str = hp.optional('WebDataset S3 bucket name', default='s3://mosaicml-internal-dataset-cifar100')
name: str = hp.optional('WebDataset local cache name', default='cifar100')

n_train_samples: int = hp.optional('Number of samples in training split', default=50_000)
n_val_samples: int = hp.optional('Number of samples in validation split', default=10_000)
n_classes: int = hp.optional('Number of output classes', default=100)
channel_means: List[float] = hp.optional('Mean per image channel', default=(0.5071, 0.4867, 0.4408))
channel_stds: List[float] = hp.optional('Std per image channel', default=(0.2675, 0.2565, 0.2761))
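The three subclasses above override only class-level defaults; the loading logic stays in CIFARWebDatasetHparams.initialize_object. As a hypothetical sketch, not part of this commit, a further CIFAR-style WebDataset could be added the same way; the bucket, cache name, and channel statistics below are placeholders.

@dataclass
class MyCIFARLikeWebDatasetHparams(CIFARWebDatasetHparams):
    """Hypothetical CIFAR-style WebDataset defined purely through default overrides."""

    remote: str = hp.optional('WebDataset S3 bucket name', default='s3://my-bucket/my-cifar-like')
    name: str = hp.optional('WebDataset local cache name', default='my-cifar-like')

    n_train_samples: int = hp.optional('Number of samples in training split', default=50_000)
    n_val_samples: int = hp.optional('Number of samples in validation split', default=10_000)
    n_classes: int = hp.optional('Number of output classes', default=10)
    channel_means: List[float] = hp.optional('Mean per image channel', default=(0.5, 0.5, 0.5))
    channel_stds: List[float] = hp.optional('Std per image channel', default=(0.25, 0.25, 0.25))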
