-
Notifications
You must be signed in to change notification settings - Fork 433
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
257720a
commit fa1b992
Showing
10 changed files
with
180 additions
and
20 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
# Copyright 2021 MosaicML. All Rights Reserved. | ||
from composer.models.timm.model import Timm as Timm | ||
from composer.models.timm.timm_hparams import TimmHparams as TimmHparams |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
# Copyright 2021 MosaicML. All Rights Reserved. | ||
from typing import Optional | ||
|
||
from composer.models.base import MosaicClassifier | ||
|
||
|
||
class Timm(MosaicClassifier): | ||
"""A wrapper around timm.create_model() used to create mosaic classifiers from timm models | ||
Args: | ||
model_name (str): timm model name e.g:'resnet50'list of models can be found at https://github.com/rwightman/pytorch-image-models | ||
pretrained (bool): imagenet pretrained. default: False | ||
num_classes (int): The number of classes. Needed for classification tasks. default: 1000 | ||
drop_rate (float): dropout rate. default: 0.0 | ||
drop_path_rate (float): drop path rate (model default if None). default: None | ||
drop_block_rate (float): drop block rate (model default if None). default: None | ||
global_pool (str): Global pool type, one of (fast, avg, max, avgmax, avgmaxc). Model default if None. default: None | ||
bn_momentum (float): BatchNorm momentum override (model default if not None). default: None | ||
bn_eps (float): BatchNorm epsilon override (model default if not None). default: None | ||
""" | ||
|
||
def __init__( | ||
self, | ||
model_name: str, | ||
pretrained: bool = False, | ||
num_classes: int = 1000, | ||
drop_rate: float = 0.0, | ||
drop_path_rate: Optional[float] = None, | ||
drop_block_rate: Optional[float] = None, | ||
global_pool: Optional[str] = None, | ||
bn_momentum: Optional[float] = None, | ||
bn_eps: Optional[float] = None, | ||
) -> None: | ||
import timm | ||
|
||
model = timm.create_model( | ||
model_name=model_name, | ||
pretrained=pretrained, | ||
num_classes=num_classes, | ||
drop_rate=drop_rate, | ||
drop_path_rate=drop_path_rate, | ||
drop_block_rate=drop_block_rate, | ||
global_pool=global_pool, | ||
bn_momentum=bn_momentum, | ||
bn_eps=bn_eps, | ||
) | ||
super().__init__(module=model) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
# Copyright 2021 MosaicML. All Rights Reserved. | ||
from dataclasses import dataclass | ||
from typing import Optional | ||
|
||
import yahp as hp | ||
|
||
from composer.models.model_hparams import ModelHparams | ||
from composer.models.timm.model import Timm | ||
|
||
|
||
@dataclass | ||
class TimmHparams(ModelHparams): | ||
|
||
model_name: str = hp.optional( | ||
"timm model name e.g: 'resnet50', list of models can be found at https://github.com/rwightman/pytorch-image-models", | ||
default=None, | ||
) | ||
pretrained: bool = hp.optional("imagenet pretrained", default=False) | ||
num_classes: int = hp.optional("The number of classes. Needed for classification tasks", default=1000) | ||
drop_rate: float = hp.optional("dropout rate", default=0.0) | ||
drop_path_rate: Optional[float] = hp.optional("drop path rate (model default if None)", default=None) | ||
drop_block_rate: Optional[float] = hp.optional("drop block rate (model default if None)", default=None) | ||
global_pool: Optional[str] = hp.optional( | ||
"Global pool type, one of (fast, avg, max, avgmax, avgmaxc). Model default if None.", default=None) | ||
bn_momentum: Optional[float] = hp.optional("BatchNorm momentum override (model default if not None)", default=None) | ||
bn_eps: Optional[float] = hp.optional("BatchNorm epsilon override (model default if not None)", default=None) | ||
|
||
def validate(self): | ||
if self.model_name is None: | ||
import timm | ||
raise ValueError(f"model must be one of {timm.models.list_models()}") | ||
|
||
def initialize_object(self): | ||
return Timm(model_name=self.model_name, | ||
pretrained=self.pretrained, | ||
num_classes=self.num_classes, | ||
drop_rate=self.drop_rate, | ||
drop_path_rate=self.drop_path_rate, | ||
drop_block_rate=self.drop_block_rate, | ||
global_pool=self.global_pool, | ||
bn_momentum=self.bn_momentum, | ||
bn_eps=self.bn_eps) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
train_dataset: | ||
imagenet: | ||
resize_size: -1 | ||
crop_size: 224 | ||
is_train: true | ||
datadir: /datasets/ImageNet | ||
shuffle: true | ||
drop_last: true | ||
val_dataset: | ||
imagenet: | ||
resize_size: 256 | ||
crop_size: 224 | ||
is_train: false | ||
datadir: /datasets/ImageNet | ||
shuffle: false | ||
drop_last: false | ||
optimizer: | ||
decoupled_sgdw: | ||
lr: 2.048 | ||
momentum: 0.875 | ||
weight_decay: 5.0e-4 | ||
dampening: 0 | ||
nesterov: false | ||
schedulers: | ||
- warmup: | ||
warmup_iters: "8ep" | ||
warmup_method: linear | ||
warmup_factor: 0 | ||
verbose: false | ||
interval: step | ||
- cosine_decay: | ||
T_max: "82ep" | ||
eta_min: 0 | ||
verbose: false | ||
interval: step | ||
model: | ||
timm: | ||
model_name: 'resnet50' | ||
num_classes: 1000 | ||
loggers: | ||
- tqdm: {} | ||
max_duration: 90ep | ||
train_batch_size: 2048 | ||
eval_batch_size: 2048 | ||
seed: 17 | ||
device: | ||
gpu: {} | ||
dataloader: | ||
pin_memory: true | ||
timeout: 0 | ||
prefetch_factor: 2 | ||
persistent_workers: true | ||
num_workers: 8 | ||
validate_every_n_epochs: 1 | ||
grad_accum: 1 | ||
precision: amp |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters