Skip to content

Commit

Permalink
Detection recipe enhancements (#5715)
Browse files Browse the repository at this point in the history
* Detection recipe enhancements

* Add back nesterov momentum
  • Loading branch information
datumbox authored Apr 1, 2022
1 parent ec1c2a1 commit d59398b
Show file tree
Hide file tree
Showing 12 changed files with 76 additions and 11 deletions.
2 changes: 1 addition & 1 deletion references/classification/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ def main(args):
criterion = nn.CrossEntropyLoss(label_smoothing=args.label_smoothing)

if args.norm_weight_decay is None:
parameters = model.parameters()
parameters = [p for p in model.parameters() if p.requires_grad]
else:
param_groups = torchvision.ops._utils.split_normalization_params(model)
wd_groups = [args.norm_weight_decay, args.weight_decay]
Expand Down
23 changes: 22 additions & 1 deletion references/detection/presets.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@


class DetectionPresetTrain:
def __init__(self, data_augmentation, hflip_prob=0.5, mean=(123.0, 117.0, 104.0)):
def __init__(self, *, data_augmentation, hflip_prob=0.5, mean=(123.0, 117.0, 104.0)):
if data_augmentation == "hflip":
self.transforms = T.Compose(
[
Expand All @@ -12,6 +12,27 @@ def __init__(self, data_augmentation, hflip_prob=0.5, mean=(123.0, 117.0, 104.0)
T.ConvertImageDtype(torch.float),
]
)
elif data_augmentation == "lsj":
self.transforms = T.Compose(
[
T.ScaleJitter(target_size=(1024, 1024)),
T.FixedSizeCrop(size=(1024, 1024), fill=mean),
T.RandomHorizontalFlip(p=hflip_prob),
T.PILToTensor(),
T.ConvertImageDtype(torch.float),
]
)
elif data_augmentation == "multiscale":
self.transforms = T.Compose(
[
T.RandomShortestSize(
min_size=(480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800), max_size=1333
),
T.RandomHorizontalFlip(p=hflip_prob),
T.PILToTensor(),
T.ConvertImageDtype(torch.float),
]
)
elif data_augmentation == "ssd":
self.transforms = T.Compose(
[
Expand Down
31 changes: 29 additions & 2 deletions references/detection/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ def get_args_parser(add_help=True):
parser.add_argument(
"-j", "--workers", default=4, type=int, metavar="N", help="number of data loading workers (default: 4)"
)
parser.add_argument("--opt", default="sgd", type=str, help="optimizer")
parser.add_argument(
"--lr",
default=0.02,
Expand All @@ -84,6 +85,12 @@ def get_args_parser(add_help=True):
help="weight decay (default: 1e-4)",
dest="weight_decay",
)
parser.add_argument(
"--norm-weight-decay",
default=None,
type=float,
help="weight decay for Normalization layers (default: None, same value as --wd)",
)
parser.add_argument(
"--lr-scheduler", default="multisteplr", type=str, help="name of lr scheduler (default: multisteplr)"
)
Expand Down Expand Up @@ -176,6 +183,8 @@ def main(args):

print("Creating model")
kwargs = {"trainable_backbone_layers": args.trainable_backbone_layers}
if args.data_augmentation in ["multiscale", "lsj"]:
kwargs["_skip_resize"] = True
if "rcnn" in args.model:
if args.rpn_score_thresh is not None:
kwargs["rpn_score_thresh"] = args.rpn_score_thresh
Expand All @@ -191,8 +200,26 @@ def main(args):
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
model_without_ddp = model.module

params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(params, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
if args.norm_weight_decay is None:
parameters = [p for p in model.parameters() if p.requires_grad]
else:
param_groups = torchvision.ops._utils.split_normalization_params(model)
wd_groups = [args.norm_weight_decay, args.weight_decay]
parameters = [{"params": p, "weight_decay": w} for p, w in zip(param_groups, wd_groups) if p]

opt_name = args.opt.lower()
if opt_name.startswith("sgd"):
optimizer = torch.optim.SGD(
parameters,
lr=args.lr,
momentum=args.momentum,
weight_decay=args.weight_decay,
nesterov="nesterov" in opt_name,
)
elif opt_name == "adamw":
optimizer = torch.optim.AdamW(parameters, lr=args.lr, weight_decay=args.weight_decay)
else:
raise RuntimeError(f"Invalid optimizer {args.opt}. Only SGD and AdamW are supported.")

scaler = torch.cuda.amp.GradScaler() if args.amp else None

Expand Down
1 change: 0 additions & 1 deletion test/test_extended_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,6 @@ def test_get_weight(name, weight):
)
def test_naming_conventions(model_fn):
weights_enum = _get_model_weights(model_fn)
print(weights_enum)
assert weights_enum is not None
assert len(weights_enum) == 0 or hasattr(weights_enum, "DEFAULT")

Expand Down
3 changes: 2 additions & 1 deletion torchvision/models/detection/faster_rcnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,7 @@ def __init__(
box_batch_size_per_image=512,
box_positive_fraction=0.25,
bbox_reg_weights=None,
**kwargs,
):

if not hasattr(backbone, "out_channels"):
Expand Down Expand Up @@ -268,7 +269,7 @@ def __init__(
image_mean = [0.485, 0.456, 0.406]
if image_std is None:
image_std = [0.229, 0.224, 0.225]
transform = GeneralizedRCNNTransform(min_size, max_size, image_mean, image_std)
transform = GeneralizedRCNNTransform(min_size, max_size, image_mean, image_std, **kwargs)

super().__init__(backbone, rpn, roi_heads, transform)

Expand Down
3 changes: 2 additions & 1 deletion torchvision/models/detection/fcos.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,6 +373,7 @@ def __init__(
nms_thresh: float = 0.6,
detections_per_img: int = 100,
topk_candidates: int = 1000,
**kwargs,
):
super().__init__()
_log_api_usage_once(self)
Expand Down Expand Up @@ -410,7 +411,7 @@ def __init__(
image_mean = [0.485, 0.456, 0.406]
if image_std is None:
image_std = [0.229, 0.224, 0.225]
self.transform = GeneralizedRCNNTransform(min_size, max_size, image_mean, image_std)
self.transform = GeneralizedRCNNTransform(min_size, max_size, image_mean, image_std, **kwargs)

self.center_sampling_radius = center_sampling_radius
self.score_thresh = score_thresh
Expand Down
2 changes: 2 additions & 0 deletions torchvision/models/detection/keypoint_rcnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,7 @@ def __init__(
keypoint_head=None,
keypoint_predictor=None,
num_keypoints=None,
**kwargs,
):

if not isinstance(keypoint_roi_pool, (MultiScaleRoIAlign, type(None))):
Expand Down Expand Up @@ -259,6 +260,7 @@ def __init__(
box_batch_size_per_image,
box_positive_fraction,
bbox_reg_weights,
**kwargs,
)

self.roi_heads.keypoint_roi_pool = keypoint_roi_pool
Expand Down
2 changes: 2 additions & 0 deletions torchvision/models/detection/mask_rcnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,7 @@ def __init__(
mask_roi_pool=None,
mask_head=None,
mask_predictor=None,
**kwargs,
):

if not isinstance(mask_roi_pool, (MultiScaleRoIAlign, type(None))):
Expand Down Expand Up @@ -254,6 +255,7 @@ def __init__(
box_batch_size_per_image,
box_positive_fraction,
bbox_reg_weights,
**kwargs,
)

self.roi_heads.mask_roi_pool = mask_roi_pool
Expand Down
3 changes: 2 additions & 1 deletion torchvision/models/detection/retinanet.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,6 +342,7 @@ def __init__(
fg_iou_thresh=0.5,
bg_iou_thresh=0.4,
topk_candidates=1000,
**kwargs,
):
super().__init__()
_log_api_usage_once(self)
Expand Down Expand Up @@ -383,7 +384,7 @@ def __init__(
image_mean = [0.485, 0.456, 0.406]
if image_std is None:
image_std = [0.229, 0.224, 0.225]
self.transform = GeneralizedRCNNTransform(min_size, max_size, image_mean, image_std)
self.transform = GeneralizedRCNNTransform(min_size, max_size, image_mean, image_std, **kwargs)

self.score_thresh = score_thresh
self.nms_thresh = nms_thresh
Expand Down
3 changes: 2 additions & 1 deletion torchvision/models/detection/ssd.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,7 @@ def __init__(
iou_thresh: float = 0.5,
topk_candidates: int = 400,
positive_fraction: float = 0.25,
**kwargs: Any,
):
super().__init__()
_log_api_usage_once(self)
Expand Down Expand Up @@ -227,7 +228,7 @@ def __init__(
if image_std is None:
image_std = [0.229, 0.224, 0.225]
self.transform = GeneralizedRCNNTransform(
min(size), max(size), image_mean, image_std, size_divisible=1, fixed_size=size
min(size), max(size), image_mean, image_std, size_divisible=1, fixed_size=size, **kwargs
)

self.score_thresh = score_thresh
Expand Down
6 changes: 5 additions & 1 deletion torchvision/models/detection/transform.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import math
from typing import List, Tuple, Dict, Optional
from typing import List, Tuple, Dict, Optional, Any

import torch
import torchvision
Expand Down Expand Up @@ -91,6 +91,7 @@ def __init__(
image_std: List[float],
size_divisible: int = 32,
fixed_size: Optional[Tuple[int, int]] = None,
**kwargs: Any,
):
super().__init__()
if not isinstance(min_size, (list, tuple)):
Expand All @@ -101,6 +102,7 @@ def __init__(
self.image_std = image_std
self.size_divisible = size_divisible
self.fixed_size = fixed_size
self._skip_resize = kwargs.pop("_skip_resize", False)

def forward(
self, images: List[Tensor], targets: Optional[List[Dict[str, Tensor]]] = None
Expand Down Expand Up @@ -170,6 +172,8 @@ def resize(
) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
h, w = image.shape[-2:]
if self.training:
if self._skip_resize:
return image, target
size = float(self.torch_choice(self.min_size))
else:
# FIXME assume for now that testing uses the largest scale
Expand Down
8 changes: 7 additions & 1 deletion torchvision/ops/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,13 @@ def split_normalization_params(
) -> Tuple[List[Tensor], List[Tensor]]:
# Adapted from https://github.com/facebookresearch/ClassyVision/blob/659d7f78/classy_vision/generic/util.py#L501
if not norm_classes:
norm_classes = [nn.modules.batchnorm._BatchNorm, nn.LayerNorm, nn.GroupNorm]
norm_classes = [
nn.modules.batchnorm._BatchNorm,
nn.LayerNorm,
nn.GroupNorm,
nn.modules.instancenorm._InstanceNorm,
nn.LocalResponseNorm,
]

for t in norm_classes:
if not issubclass(t, nn.Module):
Expand Down

0 comments on commit d59398b

Please sign in to comment.