diff --git a/composer/models/__init__.py b/composer/models/__init__.py
index 2a904229ec..51508124ab 100644
--- a/composer/models/__init__.py
+++ b/composer/models/__init__.py
@@ -27,6 +27,8 @@
 from composer.models.resnet56_cifar10 import CIFARResNetHparams as CIFARResNetHparams
 from composer.models.resnet101 import ResNet101 as ResNet101
 from composer.models.resnet101 import ResNet101Hparams as ResNet101Hparams
+from composer.models.timm import Timm as Timm
+from composer.models.timm import TimmHparams as TimmHparams
 from composer.models.transformer_shared import MosaicTransformer as MosaicTransformer
 from composer.models.unet import UNet as UNet
 from composer.models.unet import UnetHparams as UnetHparams
diff --git a/composer/models/timm/__init__.py b/composer/models/timm/__init__.py
new file mode 100644
index 0000000000..a21aeee6ec
--- /dev/null
+++ b/composer/models/timm/__init__.py
@@ -0,0 +1,3 @@
+# Copyright 2021 MosaicML. All Rights Reserved.
+from composer.models.timm.model import Timm as Timm
+from composer.models.timm.timm_hparams import TimmHparams as TimmHparams
diff --git a/composer/models/timm/model.py b/composer/models/timm/model.py
new file mode 100644
index 0000000000..162a860cf1
--- /dev/null
+++ b/composer/models/timm/model.py
@@ -0,0 +1,48 @@
+# Copyright 2021 MosaicML. All Rights Reserved.
+from typing import Optional
+
+from composer.models.base import MosaicClassifier
+
+
+class Timm(MosaicClassifier):
+    """A wrapper around timm.create_model() that creates a MosaicClassifier from a timm model.
+
+    Args:
+        model_name (str): timm model name, e.g. 'resnet50'. A list of models can be found at https://github.com/rwightman/pytorch-image-models.
+        pretrained (bool): Whether to load ImageNet-pretrained weights. Default: False.
+        num_classes (int): The number of classes. Needed for classification tasks. Default: 1000.
+        drop_rate (float): Dropout rate. Default: 0.0.
+        drop_path_rate (Optional[float]): Drop path rate (model default if None). Default: None.
+        drop_block_rate (Optional[float]): Drop block rate (model default if None). Default: None.
+        global_pool (Optional[str]): Global pool type, one of (fast, avg, max, avgmax, avgmaxc). Model default if None. Default: None.
+        bn_momentum (Optional[float]): BatchNorm momentum override (model default if None). Default: None.
+        bn_eps (Optional[float]): BatchNorm epsilon override (model default if None). Default: None.
+    """
+
+    def __init__(
+        self,
+        model_name: str,
+        pretrained: bool = False,
+        num_classes: int = 1000,
+        drop_rate: float = 0.0,
+        drop_path_rate: Optional[float] = None,
+        drop_block_rate: Optional[float] = None,
+        global_pool: Optional[str] = None,
+        bn_momentum: Optional[float] = None,
+        bn_eps: Optional[float] = None,
+    ) -> None:
+        # Import here so that timm stays an optional dependency.
+        import timm
+
+        model = timm.create_model(
+            model_name=model_name,
+            pretrained=pretrained,
+            num_classes=num_classes,
+            drop_rate=drop_rate,
+            drop_path_rate=drop_path_rate,
+            drop_block_rate=drop_block_rate,
+            global_pool=global_pool,
+            bn_momentum=bn_momentum,
+            bn_eps=bn_eps,
+        )
+        super().__init__(module=model)
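For reference, a minimal usage sketch of the wrapper added above. It assumes timm is installed and that MosaicClassifier exposes the wrapped network as .module (as the module=model call suggests); the model and input sizes are arbitrary:

    # Hypothetical example, not part of this diff.
    import torch

    from composer.models import Timm

    model = Timm(model_name="resnet18", num_classes=10)
    logits = model.module(torch.randn(2, 3, 224, 224))  # forward through the underlying timm network
    print(logits.shape)  # torch.Size([2, 10])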
diff --git a/composer/models/timm/timm_hparams.py b/composer/models/timm/timm_hparams.py
new file mode 100644
index 0000000000..0e4e0d08f4
--- /dev/null
+++ b/composer/models/timm/timm_hparams.py
@@ -0,0 +1,43 @@
+# Copyright 2021 MosaicML. All Rights Reserved.
+from dataclasses import dataclass
+from typing import Optional
+
+import yahp as hp
+
+from composer.models.model_hparams import ModelHparams
+from composer.models.timm.model import Timm
+
+
+@dataclass
+class TimmHparams(ModelHparams):
+
+    model_name: Optional[str] = hp.optional(
+        "timm model name, e.g. 'resnet50'. A list of models can be found at https://github.com/rwightman/pytorch-image-models",
+        default=None,
+    )
+    pretrained: bool = hp.optional("Whether to load ImageNet-pretrained weights", default=False)
+    num_classes: int = hp.optional("The number of classes. Needed for classification tasks", default=1000)
+    drop_rate: float = hp.optional("Dropout rate", default=0.0)
+    drop_path_rate: Optional[float] = hp.optional("Drop path rate (model default if None)", default=None)
+    drop_block_rate: Optional[float] = hp.optional("Drop block rate (model default if None)", default=None)
+    global_pool: Optional[str] = hp.optional(
+        "Global pool type, one of (fast, avg, max, avgmax, avgmaxc). Model default if None.", default=None)
+    bn_momentum: Optional[float] = hp.optional("BatchNorm momentum override (model default if None)", default=None)
+    bn_eps: Optional[float] = hp.optional("BatchNorm epsilon override (model default if None)", default=None)
+
+    def validate(self):
+        if self.model_name is None:
+            import timm
+            raise ValueError(f"model_name must be one of {timm.models.list_models()}")
+        super().validate()
+
+    def initialize_object(self):
+        return Timm(model_name=self.model_name,
+                    pretrained=self.pretrained,
+                    num_classes=self.num_classes,
+                    drop_rate=self.drop_rate,
+                    drop_path_rate=self.drop_path_rate,
+                    drop_block_rate=self.drop_block_rate,
+                    global_pool=self.global_pool,
+                    bn_momentum=self.bn_momentum,
+                    bn_eps=self.bn_eps)
diff --git a/composer/trainer/trainer_hparams.py b/composer/trainer/trainer_hparams.py
index 100c5eceb0..27f73fa4c5 100755
--- a/composer/trainer/trainer_hparams.py
+++ b/composer/trainer/trainer_hparams.py
@@ -21,10 +21,10 @@
 from composer.datasets import DataloaderHparams
 from composer.loggers import (BaseLoggerBackendHparams, FileLoggerBackendHparams, MosaicMLLoggerBackendHparams,
                               TQDMLoggerBackendHparams, WandBLoggerBackendHparams)
-from composer.models import (BERTForClassificationHparams, BERTHparams, CIFARResNet9Hparams, CIFARResNetHparams,
-                             DeepLabV3Hparams, EfficientNetB0Hparams, GPT2Hparams, MnistClassifierHparams, ModelHparams,
-                             ResNet18Hparams, ResNet50Hparams, ResNet101Hparams, UnetHparams)
-from composer.models.resnet20_cifar10.resnet20_cifar10_hparams import CIFARResNet20Hparams
+from composer.models import (BERTForClassificationHparams, BERTHparams, CIFARResNet9Hparams, CIFARResNet20Hparams,
+                             CIFARResNetHparams, DeepLabV3Hparams, EfficientNetB0Hparams, GPT2Hparams,
+                             MnistClassifierHparams, ModelHparams, ResNet18Hparams, ResNet50Hparams, ResNet101Hparams,
+                             TimmHparams, UnetHparams)
 from composer.optim import (AdamHparams, AdamWHparams, DecoupledAdamWHparams, DecoupledSGDWHparams, OptimizerHparams,
                             RAdamHparams, RMSPropHparams, SchedulerHparams, SGDHparams, scheduler)
 from composer.profiler import ProfilerHparams
@@ -73,6 +73,7 @@
     "gpt2": GPT2Hparams,
     "bert": BERTHparams,
     "bert_classification": BERTForClassificationHparams,
+    "timm": TimmHparams,
 }
 
 dataset_registry = {
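With the registry entry in place, "timm" can be selected as the model in any trainer YAML. The hparams class can also be driven directly; a minimal sketch, assuming ModelHparams defines no additional required fields and timm is installed:

    # Hypothetical example, not part of this diff.
    from composer.models import TimmHparams

    hparams = TimmHparams(model_name="resnet18", num_classes=10)
    hparams.validate()                   # raises ValueError when model_name is None
    model = hparams.initialize_object()  # builds the Timm wrapper defined above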
diff --git a/composer/yamls/models/timm_resnet50_imagenet.yaml b/composer/yamls/models/timm_resnet50_imagenet.yaml
new file mode 100644
index 0000000000..5aff3e305e
--- /dev/null
+++ b/composer/yamls/models/timm_resnet50_imagenet.yaml
@@ -0,0 +1,56 @@
+train_dataset:
+  imagenet:
+    resize_size: -1
+    crop_size: 224
+    is_train: true
+    datadir: /datasets/ImageNet
+    shuffle: true
+    drop_last: true
+val_dataset:
+  imagenet:
+    resize_size: 256
+    crop_size: 224
+    is_train: false
+    datadir: /datasets/ImageNet
+    shuffle: false
+    drop_last: false
+optimizer:
+  decoupled_sgdw:
+    lr: 2.048
+    momentum: 0.875
+    weight_decay: 5.0e-4
+    dampening: 0
+    nesterov: false
+schedulers:
+  - warmup:
+      warmup_iters: "8ep"
+      warmup_method: linear
+      warmup_factor: 0
+      verbose: false
+      interval: step
+  - cosine_decay:
+      T_max: "82ep"
+      eta_min: 0
+      verbose: false
+      interval: step
+model:
+  timm:
+    model_name: 'resnet50'
+    num_classes: 1000
+loggers:
+  - tqdm: {}
+max_duration: 90ep
+train_batch_size: 2048
+eval_batch_size: 2048
+seed: 17
+device:
+  gpu: {}
+dataloader:
+  pin_memory: true
+  timeout: 0
+  prefetch_factor: 2
+  persistent_workers: true
+  num_workers: 8
+validate_every_n_epochs: 1
+grad_accum: 1
+precision: amp
diff --git a/setup.py b/setup.py
index 8b3f7dec2d..1b3574af38 100755
--- a/setup.py
+++ b/setup.py
@@ -92,6 +92,8 @@ def package_files(directory: str):
     'datasets>=1.14.0',
 ]
 
+extra_deps['vision'] = ['timm>=0.5.4']
+
 extra_deps['unet'] = [
     'monai>=0.7.0',
     'scikit-learn>=1.0.1',
diff --git a/tests/test_hparams.py b/tests/test_hparams.py
index f644f9906f..8945ab0249 100644
--- a/tests/test_hparams.py
+++ b/tests/test_hparams.py
@@ -25,7 +25,7 @@ def walk_model_yamls():
 
 def _configure_dataset_for_synthetic(dataset_hparams: DatasetHparams) -> None:
     if not isinstance(dataset_hparams, SyntheticHparamsMixin):
-        pytest.xfail(f"{dataset_hparams.__class__.__name__} does not support synthetic data or num_total_batchjes")
+        pytest.xfail(f"{dataset_hparams.__class__.__name__} does not support synthetic data or num_total_batches")
 
     assert isinstance(dataset_hparams, SyntheticHparamsMixin)
 
@@ -36,10 +36,19 @@
 class TestHparamsCreate:
 
     def test_hparams_create(self, hparams_file: str):
+        if "timm" in hparams_file:
+            pytest.importorskip("timm")
+        if "unet" in hparams_file:
+            pytest.importorskip("monai")
+
         hparams = TrainerHparams.create(hparams_file, cli_args=False)
         assert isinstance(hparams, TrainerHparams)
 
     def test_trainer_initialize(self, hparams_file: str):
+        if "timm" in hparams_file:
+            pytest.importorskip("timm")
+        if "unet" in hparams_file:
+            pytest.importorskip("monai")
         hparams = TrainerHparams.create(hparams_file, cli_args=False)
         hparams.dataloader.num_workers = 0
         hparams.dataloader.persistent_workers = False
diff --git a/tests/test_load.py b/tests/test_load.py
index d5c8dffc58..e4da24e7cc 100755
--- a/tests/test_load.py
+++ b/tests/test_load.py
@@ -27,11 +27,7 @@ def get_model_algs(model_name: str) -> List[str]:
     if is_image_model:
         algs.remove("alibi")
     if "alibi" in algs:
-        try:
-            import transformers
-            del transformers
-        except ImportError:
-            pytest.skip("Unable to import transformers; skipping alibi")
+        pytest.importorskip("transformers")
     if model_name in ("unet", "gpt2_52m", "gpt2_83m", 'gpt2_125m'):
         algs.remove("mixup")
         algs.remove("cutmix")
@@ -41,6 +37,10 @@
 
 @pytest.mark.parametrize('model_name', model_names)
 @pytest.mark.timeout(15)
 def test_load(model_name: str):
+    if 'timm' in model_name:
+        pytest.importorskip("timm")
+    if model_name in ['unet']:
+        pytest.importorskip("monai")
     if model_name in ['deeplabv3_ade20k']:
         pytest.skip(f"Model {model_name} requires GPU")
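The tests above lean on pytest.importorskip, which imports a module and skips the surrounding test when the import fails, so CI without the optional dependency still passes. A small self-contained sketch of the pattern (the test name is hypothetical):

    # Hypothetical example, not part of this diff.
    import pytest

    def test_needs_timm():
        timm = pytest.importorskip("timm")  # skips this test when timm is not installed
        assert timm.list_models()           # importorskip returns the imported module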
diff --git a/tests/test_model_registry.py b/tests/test_model_registry.py
index e63200f63f..37ccee8d89 100755
--- a/tests/test_model_registry.py
+++ b/tests/test_model_registry.py
@@ -2,12 +2,17 @@
 
 import pytest
 
-from composer.models import BaseMosaicModel, ModelHparams
+from composer.models import ModelHparams
 from composer.trainer.trainer_hparams import model_registry
 
 
 @pytest.mark.parametrize("model_name", model_registry.keys())
 def test_model_registry(model_name, request):
+    if model_name in ['timm']:
+        pytest.importorskip("timm")
+    if model_name in ['unet']:
+        pytest.importorskip("monai")
+
     # TODO (Moin + Ravi): create dummy versions of these models to pass unit tests
     if model_name in ['gpt2', 'bert', 'bert_classification']:  # do not pull from HF model hub
         request.applymarker(pytest.mark.xfail())
@@ -31,13 +36,7 @@ def test_model_registry(model_name, request):
     if model_name == "deeplabv3":
         model_hparams.is_backbone_pretrained = False
 
-    assert isinstance(model_hparams, ModelHparams)
+    if model_name == "timm":
+        model_hparams.model_name = "resnet18"
 
-    try:
-        # create the model object using the hparams
-        model = model_hparams.initialize_object()
-        assert isinstance(model, BaseMosaicModel)
-    except ModuleNotFoundError as e:
-        if model_name == "unet" and e.name == 'monai':
-            pytest.skip("Unet not installed -- skipping")
-        raise e
+    assert isinstance(model_hparams, ModelHparams)
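Finally, a sketch of smoke-testing the new YAML by hand, mirroring tests/test_hparams.py. The path assumes the repo root as working directory, and the last assertion assumes TrainerHparams stores the parsed model hparams on a model field:

    # Hypothetical example, not part of this diff.
    from composer.trainer.trainer_hparams import TrainerHparams

    hparams = TrainerHparams.create("composer/yamls/models/timm_resnet50_imagenet.yaml", cli_args=False)
    assert isinstance(hparams, TrainerHparams)
    assert type(hparams.model).__name__ == "TimmHparams"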