Skip to content

Commit

Permalink
Timm support (#262)
Browse files Browse the repository at this point in the history
  • Loading branch information
A-Jacobson authored Feb 1, 2022
1 parent 257720a commit fa1b992
Show file tree
Hide file tree
Showing 10 changed files with 180 additions and 20 deletions.
2 changes: 2 additions & 0 deletions composer/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@
from composer.models.resnet56_cifar10 import CIFARResNetHparams as CIFARResNetHparams
from composer.models.resnet101 import ResNet101 as ResNet101
from composer.models.resnet101 import ResNet101Hparams as ResNet101Hparams
from composer.models.timm import Timm as Timm
from composer.models.timm import TimmHparams as TimmHparams
from composer.models.transformer_shared import MosaicTransformer as MosaicTransformer
from composer.models.unet import UNet as UNet
from composer.models.unet import UnetHparams as UnetHparams
3 changes: 3 additions & 0 deletions composer/models/timm/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# Copyright 2021 MosaicML. All Rights Reserved.
from composer.models.timm.model import Timm as Timm
from composer.models.timm.timm_hparams import TimmHparams as TimmHparams
46 changes: 46 additions & 0 deletions composer/models/timm/model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
# Copyright 2021 MosaicML. All Rights Reserved.
from typing import Optional

from composer.models.base import MosaicClassifier


class Timm(MosaicClassifier):
    """A wrapper around :func:`timm.create_model` that builds a MosaicML classifier from a timm model.

    Args:
        model_name (str): timm model name, e.g. ``'resnet50'``. The list of available models can be
            found at https://github.com/rwightman/pytorch-image-models.
        pretrained (bool): whether to load ImageNet-pretrained weights. default: ``False``.
        num_classes (int): number of output classes; needed for classification tasks. default: ``1000``.
        drop_rate (float): dropout rate. default: ``0.0``.
        drop_path_rate (Optional[float]): drop path rate (model default if ``None``). default: ``None``.
        drop_block_rate (Optional[float]): drop block rate (model default if ``None``). default: ``None``.
        global_pool (Optional[str]): global pool type, one of (``fast``, ``avg``, ``max``, ``avgmax``,
            ``avgmaxc``). Model default if ``None``. default: ``None``.
        bn_momentum (Optional[float]): BatchNorm momentum override (model default if ``None``).
            default: ``None``.
        bn_eps (Optional[float]): BatchNorm epsilon override (model default if ``None``).
            default: ``None``.

    Raises:
        ImportError: if the optional ``timm`` package is not installed.
    """

    def __init__(
        self,
        model_name: str,
        pretrained: bool = False,
        num_classes: int = 1000,
        drop_rate: float = 0.0,
        drop_path_rate: Optional[float] = None,
        drop_block_rate: Optional[float] = None,
        global_pool: Optional[str] = None,
        bn_momentum: Optional[float] = None,
        bn_eps: Optional[float] = None,
    ) -> None:
        # timm is an optional dependency; import lazily and raise an actionable
        # error instead of a bare ModuleNotFoundError at import time.
        try:
            import timm
        except ImportError as e:
            raise ImportError("timm is not installed. Run `pip install timm` to use the Timm model wrapper.") from e

        model = timm.create_model(
            model_name=model_name,
            pretrained=pretrained,
            num_classes=num_classes,
            drop_rate=drop_rate,
            drop_path_rate=drop_path_rate,
            drop_block_rate=drop_block_rate,
            global_pool=global_pool,
            bn_momentum=bn_momentum,
            bn_eps=bn_eps,
        )
        super().__init__(module=model)
42 changes: 42 additions & 0 deletions composer/models/timm/timm_hparams.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# Copyright 2021 MosaicML. All Rights Reserved.
from dataclasses import dataclass
from typing import Optional

import yahp as hp

from composer.models.model_hparams import ModelHparams
from composer.models.timm.model import Timm


@dataclass
class TimmHparams(ModelHparams):
    """yahp hyperparameters for creating a :class:`Timm` model.

    See :class:`Timm` for the meaning of each field.
    """

    # Optional[str], not str: yahp optional fields require a default, and there is
    # no sensible default model name -- validate() rejects None with the list of
    # available timm models.
    model_name: Optional[str] = hp.optional(
        "timm model name e.g: 'resnet50', list of models can be found at https://github.com/rwightman/pytorch-image-models",
        default=None,
    )
    pretrained: bool = hp.optional("imagenet pretrained", default=False)
    num_classes: int = hp.optional("The number of classes. Needed for classification tasks", default=1000)
    drop_rate: float = hp.optional("dropout rate", default=0.0)
    drop_path_rate: Optional[float] = hp.optional("drop path rate (model default if None)", default=None)
    drop_block_rate: Optional[float] = hp.optional("drop block rate (model default if None)", default=None)
    global_pool: Optional[str] = hp.optional(
        "Global pool type, one of (fast, avg, max, avgmax, avgmaxc). Model default if None.", default=None)
    bn_momentum: Optional[float] = hp.optional("BatchNorm momentum override (model default if not None)", default=None)
    bn_eps: Optional[float] = hp.optional("BatchNorm epsilon override (model default if not None)", default=None)

    def validate(self):
        """Reject a missing ``model_name``, listing the valid timm model names.

        Raises:
            ValueError: if ``model_name`` is ``None``.
        """
        # NOTE(review): consider also calling super().validate() -- confirm against
        # ModelHparams before changing behavior.
        if self.model_name is None:
            import timm  # local import: timm is an optional dependency
            raise ValueError(f"model must be one of {timm.models.list_models()}")

    def initialize_object(self):
        """Construct and return the :class:`Timm` model from these hyperparameters."""
        return Timm(model_name=self.model_name,
                    pretrained=self.pretrained,
                    num_classes=self.num_classes,
                    drop_rate=self.drop_rate,
                    drop_path_rate=self.drop_path_rate,
                    drop_block_rate=self.drop_block_rate,
                    global_pool=self.global_pool,
                    bn_momentum=self.bn_momentum,
                    bn_eps=self.bn_eps)
9 changes: 5 additions & 4 deletions composer/trainer/trainer_hparams.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,10 @@
from composer.datasets import DataloaderHparams
from composer.loggers import (BaseLoggerBackendHparams, FileLoggerBackendHparams, MosaicMLLoggerBackendHparams,
TQDMLoggerBackendHparams, WandBLoggerBackendHparams)
from composer.models import (BERTForClassificationHparams, BERTHparams, CIFARResNet9Hparams, CIFARResNetHparams,
DeepLabV3Hparams, EfficientNetB0Hparams, GPT2Hparams, MnistClassifierHparams, ModelHparams,
ResNet18Hparams, ResNet50Hparams, ResNet101Hparams, UnetHparams)
from composer.models.resnet20_cifar10.resnet20_cifar10_hparams import CIFARResNet20Hparams
from composer.models import (BERTForClassificationHparams, BERTHparams, CIFARResNet9Hparams, CIFARResNet20Hparams,
CIFARResNetHparams, DeepLabV3Hparams, EfficientNetB0Hparams, GPT2Hparams,
MnistClassifierHparams, ModelHparams, ResNet18Hparams, ResNet50Hparams, ResNet101Hparams,
TimmHparams, UnetHparams)
from composer.optim import (AdamHparams, AdamWHparams, DecoupledAdamWHparams, DecoupledSGDWHparams, OptimizerHparams,
RAdamHparams, RMSPropHparams, SchedulerHparams, SGDHparams, scheduler)
from composer.profiler import ProfilerHparams
Expand Down Expand Up @@ -73,6 +73,7 @@
"gpt2": GPT2Hparams,
"bert": BERTHparams,
"bert_classification": BERTForClassificationHparams,
"timm": TimmHparams
}

dataset_registry = {
Expand Down
56 changes: 56 additions & 0 deletions composer/yamls/models/timm_resnet50_imagenet.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
# Train a timm ResNet-50 on ImageNet: 90 epochs, decoupled SGDW with an
# 8-epoch linear warmup followed by cosine decay, batch size 2048, AMP precision.
train_dataset:
  imagenet:
    resize_size: -1  # presumably -1 disables resizing before the train crop -- verify against the dataset hparams
    crop_size: 224
    is_train: true
    datadir: /datasets/ImageNet
    shuffle: true
    drop_last: true
val_dataset:
  imagenet:
    resize_size: 256  # resize shorter side to 256, then center-crop 224 for eval
    crop_size: 224
    is_train: false
    datadir: /datasets/ImageNet
    shuffle: false
    drop_last: false
optimizer:
  decoupled_sgdw:
    lr: 2.048  # large LR scaled for the 2048 global batch size
    momentum: 0.875
    weight_decay: 5.0e-4
    dampening: 0
    nesterov: false
schedulers:
  # Linear warmup for the first 8 epochs, stepped per-iteration.
  - warmup:
      warmup_iters: "8ep"
      warmup_method: linear
      warmup_factor: 0
      verbose: false
      interval: step
  # Cosine decay over the remaining 82 epochs (8ep warmup + 82ep = 90ep total).
  - cosine_decay:
      T_max: "82ep"
      eta_min: 0
      verbose: false
      interval: step
model:
  timm:
    model_name: 'resnet50'  # any timm model name is valid here
    num_classes: 1000
loggers:
  - tqdm: {}
max_duration: 90ep
train_batch_size: 2048
eval_batch_size: 2048
seed: 17
device:
  gpu: {}
dataloader:
  pin_memory: true
  timeout: 0
  prefetch_factor: 2
  persistent_workers: true
  num_workers: 8
validate_every_n_epochs: 1
grad_accum: 1
precision: amp  # automatic mixed precision
2 changes: 2 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,8 @@ def package_files(directory: str):
'datasets>=1.14.0',
]

extra_deps['vision'] = ['timm>=0.5.4']

extra_deps['unet'] = [
'monai>=0.7.0',
'scikit-learn>=1.0.1',
Expand Down
11 changes: 10 additions & 1 deletion tests/test_hparams.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def walk_model_yamls():

def _configure_dataset_for_synthetic(dataset_hparams: DatasetHparams) -> None:
if not isinstance(dataset_hparams, SyntheticHparamsMixin):
pytest.xfail(f"{dataset_hparams.__class__.__name__} does not support synthetic data or num_total_batchjes")
pytest.xfail(f"{dataset_hparams.__class__.__name__} does not support synthetic data or num_total_batches")

assert isinstance(dataset_hparams, SyntheticHparamsMixin)

Expand All @@ -36,10 +36,19 @@ def _configure_dataset_for_synthetic(dataset_hparams: DatasetHparams) -> None:
class TestHparamsCreate:

def test_hparams_create(self, hparams_file: str):
    """Every model yaml should parse into a valid TrainerHparams instance."""
    # timm and monai are optional dependencies; skip (rather than fail) the
    # yamls that require them when the package is not installed.
    if "timm" in hparams_file:
        pytest.importorskip("timm")
    if hparams_file in ["unet.yaml"]:
        pytest.importorskip("monai")

    hparams = TrainerHparams.create(hparams_file, cli_args=False)
    assert isinstance(hparams, TrainerHparams)

def test_trainer_initialize(self, hparams_file: str):
if "timm" in hparams_file:
pytest.importorskip("timm")
if hparams_file in ["unet.yaml"]:
pytest.importorskip("monai")
hparams = TrainerHparams.create(hparams_file, cli_args=False)
hparams.dataloader.num_workers = 0
hparams.dataloader.persistent_workers = False
Expand Down
10 changes: 5 additions & 5 deletions tests/test_load.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,7 @@ def get_model_algs(model_name: str) -> List[str]:
if is_image_model:
algs.remove("alibi")
if "alibi" in algs:
try:
import transformers
del transformers
except ImportError:
pytest.skip("Unable to import transformers; skipping alibi")
pytest.importorskip("transformers")
if model_name in ("unet", "gpt2_52m", "gpt2_83m", 'gpt2_125m'):
algs.remove("mixup")
algs.remove("cutmix")
Expand All @@ -41,6 +37,10 @@ def get_model_algs(model_name: str) -> List[str]:
@pytest.mark.parametrize('model_name', model_names)
@pytest.mark.timeout(15)
def test_load(model_name: str):
if 'timm' in model_name:
pytest.importorskip("timm")
if model_name in ['unet']:
pytest.importorskip("monai")
if model_name in ['deeplabv3_ade20k']:
pytest.skip(f"Model {model_name} requires GPU")

Expand Down
19 changes: 9 additions & 10 deletions tests/test_model_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,17 @@

import pytest

from composer.models import BaseMosaicModel, ModelHparams
from composer.models import ModelHparams
from composer.trainer.trainer_hparams import model_registry


@pytest.mark.parametrize("model_name", model_registry.keys())
def test_model_registry(model_name, request):
if model_name in ['timm']:
pytest.importorskip("timm")
if model_name in ['unet']:
pytest.importorskip("monai")

# TODO (Moin + Ravi): create dummy versions of these models to pass unit tests
if model_name in ['gpt2', 'bert', 'bert_classification']: # do not pull from HF model hub
request.applymarker(pytest.mark.xfail())
Expand All @@ -31,13 +36,7 @@ def test_model_registry(model_name, request):
if model_name == "deeplabv3":
model_hparams.is_backbone_pretrained = False

assert isinstance(model_hparams, ModelHparams)
if model_name == "timm":
model_hparams.model_name = "resnet18"

try:
# create the model object using the hparams
model = model_hparams.initialize_object()
assert isinstance(model, BaseMosaicModel)
except ModuleNotFoundError as e:
if model_name == "unet" and e.name == 'monai':
pytest.skip("Unet not installed -- skipping")
raise e
assert isinstance(model_hparams, ModelHparams)

0 comments on commit fa1b992

Please sign in to comment.