support layer_decay in optim_factory
sageyou committed Jan 9, 2024
1 parent a086e4e commit 4132fd8
Showing 3 changed files with 167 additions and 2 deletions.
2 changes: 2 additions & 0 deletions config.py
@@ -167,6 +167,8 @@ def create_parser():
help='Whether use clip grad (default=False)')
group.add_argument('--clip_value', type=float, default=15.0,
help='Clip value (default=15.0)')
group.add_argument('--layer_decay', type=float, default=None,
help='layer-wise learning rate decay (default: None)')
group.add_argument('--gradient_accumulation_steps', type=int, default=1,
help="Accumulate the gradients of n batches before update.")

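Illustrative sketch, not part of this commit: the new flag parses like any optional float and stays None unless explicitly set, which is what disables layer decay downstream.

# Standalone sketch mirroring the flag added above; not the repository's parser.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--layer_decay', type=float, default=None,
                    help='layer-wise learning rate decay (default: None)')

args = parser.parse_args(['--layer_decay', '0.75'])
assert args.layer_decay == 0.75  # omitting the flag leaves it at None
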
164 changes: 163 additions & 1 deletion mindcv/optim/optim_factory.py
@@ -1,7 +1,11 @@
""" optim factory """
import collections
import logging
import os
from typing import Optional
import re
from collections import defaultdict
from itertools import chain, islice
from typing import Any, Callable, Dict, Iterator, Optional, Tuple, Union

from mindspore import load_checkpoint, load_param_into_net, nn

@@ -14,6 +18,8 @@

_logger = logging.getLogger(__name__)

MATCH_PREV_GROUP = (99999,)


def init_group_params(params, weight_decay, weight_decay_filter, no_weight_decay):
if weight_decay_filter == "disable":
@@ -37,6 +43,152 @@ def init_group_params(params, weight_decay, weight_decay_filter, no_weight_decay
]


def param_groups_layer_decay(
model: nn.Cell,
lr: Optional[float] = 1e-3,
weight_decay: float = 0.05,
no_weight_decay_list: Tuple[str] = (),
layer_decay: float = 0.75,
):
"""
Parameter groups for layer-wise lr decay & weight decay
"""
no_weight_decay_list = set(no_weight_decay_list)
param_group_names = {} # NOTE for debugging
param_groups = {}
if hasattr(model, "group_matcher"):
layer_map = group_with_matcher(model.trainable_params(), model.group_matcher(coarse=False), reverse=True)
else:
layer_map = _layer_map(model)

num_layers = max(layer_map.values()) + 1
layer_max = num_layers - 1
layer_scales = list(layer_decay ** (layer_max - i) for i in range(num_layers))

for name, param in model.parameters_and_names():
if not param.requires_grad:
continue

# no decay: all 1D parameters and model specific ones
if param.ndim == 1 or name in no_weight_decay_list:
g_decay = "no_decay"
this_decay = 0.0
else:
g_decay = "decay"
this_decay = weight_decay

layer_id = layer_map.get(name, layer_max)
group_name = "layer_%d_%s" % (layer_id, g_decay)

if group_name not in param_groups:
this_scale = layer_scales[layer_id]
param_group_names[group_name] = {
"lr": [learning_rate * this_scale for learning_rate in lr],
"weight_decay": this_decay,
"param_names": [],
}
param_groups[group_name] = {
"lr": [learning_rate * this_scale for learning_rate in lr],
"weight_decay": this_decay,
"params": [],
}

param_group_names[group_name]["param_names"].append(name)
param_groups[group_name]["params"].append(param)

return list(param_groups.values())
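
Quick illustration (not part of the diff) of the scales computed above: with layer_decay=0.75 and four layers, shallower layers receive geometrically smaller learning-rate multipliers.

# Illustrative only: reproduces the layer_scales computation from param_groups_layer_decay.
layer_decay, num_layers = 0.75, 4
layer_max = num_layers - 1
layer_scales = [layer_decay ** (layer_max - i) for i in range(num_layers)]
print(layer_scales)  # [0.421875, 0.5625, 0.75, 1.0] -- the deepest layer keeps the full lr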


MATCH_PREV_GROUP = (99999,)


def group_with_matcher(
named_objects: Iterator[Tuple[str, Any]], group_matcher: Union[Dict, Callable], reverse: bool = False
):
if isinstance(group_matcher, dict):
# dictionary matcher contains a dict of raw-string regex expr that must be compiled
compiled = []
for group_ordinal, (_, mspec) in enumerate(group_matcher.items()):
if mspec is None:
continue
# map all matching specifications into 3-tuple (compiled re, prefix, suffix)
if isinstance(mspec, (tuple, list)):
# multi-entry match specifications require each sub-spec to be a 2-tuple (re, suffix)
for sspec in mspec:
compiled += [(re.compile(sspec[0]), (group_ordinal,), sspec[1])]
else:
compiled += [(re.compile(mspec), (group_ordinal,), None)]
group_matcher = compiled

def _get_grouping(name):
if isinstance(group_matcher, (list, tuple)):
for match_fn, prefix, suffix in group_matcher:
r = match_fn.match(name)
if r:
parts = (prefix, r.groups(), suffix)
# map all tuple elem to int for numeric sort, filter out None entries
return tuple(map(float, chain.from_iterable(filter(None, parts))))
return (float("inf"),) # un-matched layers (neck, head) mapped to largest ordinal
else:
ord = group_matcher(name)
if not isinstance(ord, collections.abc.Iterable):
return (ord,)
return tuple(ord)

grouping = defaultdict(list)
for param in named_objects:
grouping[_get_grouping(param.name)].append(param.name)
# remap to integers
layer_id_to_param = defaultdict(list)
lid = -1
for k in sorted(filter(lambda x: x is not None, grouping.keys())):
if lid < 0 or k[-1] != MATCH_PREV_GROUP[0]:
lid += 1
layer_id_to_param[lid].extend(grouping[k])

if reverse:
# output reverse mapping
param_to_layer_id = {}
for lid, lm in layer_id_to_param.items():
for n in lm:
param_to_layer_id[n] = lid
return param_to_layer_id

return layer_id_to_param
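
Rough usage sketch for group_with_matcher, assuming this module is importable as mindcv.optim.optim_factory with this change applied; the stand-in objects below only need a .name attribute, whereas real callers pass model.trainable_params().

# Sketch: bucket parameter names with a timm-style dict matcher (illustrative, not from the repo).
from collections import namedtuple
from mindcv.optim.optim_factory import group_with_matcher

P = namedtuple("P", ["name"])
params = [P("patch_embed.proj.weight"), P("blocks.0.attn.qkv.weight"),
          P("blocks.1.mlp.fc1.weight"), P("norm.weight"), P("head.weight")]
matcher = {
    "stem": r"^patch_embed",                                      # group 0
    "blocks": [(r"^blocks\.(\d+)", None), (r"^norm", (99999,))],  # one group per block; norm merges into the last block
}
print(group_with_matcher(params, matcher))
# Expected grouping (printed as a defaultdict):
# {0: ['patch_embed.proj.weight'], 1: ['blocks.0.attn.qkv.weight'],
#  2: ['blocks.1.mlp.fc1.weight', 'norm.weight'], 3: ['head.weight']}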


def _group(it, size):
it = iter(it)
return iter(lambda: tuple(islice(it, size)), ())


def _layer_map(model, layers_per_group=12, num_groups=None):
def _in_head(n, hp):
if not hp:
return True
elif isinstance(hp, (tuple, list)):
return any([n.startswith(hpi) for hpi in hp])
else:
return n.startswith(hp)

    # NOTE: the model needs a pretrained_cfg attr with a "classifier" entry; otherwise every parameter is treated as a head parameter
head_prefix = getattr(model, "pretrained_cfg", {}).get("classifier", None)
names_trunk = []
names_head = []
for n, _ in model.parameters_and_names():
names_head.append(n) if _in_head(n, head_prefix) else names_trunk.append(n)

# group non-head layers
num_trunk_layers = len(names_trunk)
if num_groups is not None:
layers_per_group = -(num_trunk_layers // -num_groups)
names_trunk = list(_group(names_trunk, layers_per_group))
num_trunk_groups = len(names_trunk)
layer_map = {n: i for i, l in enumerate(names_trunk) for n in l}
layer_map.update({n: num_trunk_groups for n in names_head})
return layer_map
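
Small sketch of the fallback chunking (illustrative; _group is copied here so the snippet runs standalone): without a group_matcher, trunk parameter names are split into fixed-size groups and head parameters all get the final group index.

# Illustrative copy of the _group helper above, chunking names into tuples of `size`.
from itertools import islice

def _group(it, size):
    it = iter(it)
    return iter(lambda: tuple(islice(it, size)), ())

names = [f"layer{i}.weight" for i in range(5)]
print(list(_group(names, 2)))
# [('layer0.weight', 'layer1.weight'), ('layer2.weight', 'layer3.weight'), ('layer4.weight',)]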


def create_optimizer(
model_or_params,
opt: str = "adam",
@@ -45,6 +197,7 @@ def create_optimizer(
momentum: float = 0.9,
nesterov: bool = False,
weight_decay_filter: str = "disable",
layer_decay: Optional[float] = None,
loss_scale: float = 1.0,
schedule_decay: float = 4e-3,
checkpoint_path: str = "",
@@ -95,6 +248,15 @@ def create_optimizer(
"when creating an mindspore.nn.Optimizer instance. "
"NOTE: mindspore.nn.Optimizer will filter Norm parmas from weight decay. "
)
elif layer_decay is not None and isinstance(model_or_params, nn.Cell):
params = param_groups_layer_decay(
model_or_params,
lr=lr,
weight_decay=weight_decay,
layer_decay=layer_decay,
no_weight_decay_list=no_weight_decay,
)
weight_decay = 0.0
elif weight_decay_filter == "disable" or "norm_and_bias":
params = init_group_params(params, weight_decay, weight_decay_filter, no_weight_decay)
weight_decay = 0.0
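
Usage sketch, not from the repository: passing a whole nn.Cell together with layer_decay now routes parameter grouping through param_groups_layer_decay above. The model and hyperparameters below are placeholders, and lr is passed as a per-step list because param_groups_layer_decay scales each entry.

# Hypothetical usage; the tiny SequentialCell is only a placeholder so the snippet is
# self-contained. Layer decay matters most for deep backbones that expose a group_matcher.
from mindspore import nn
from mindcv.optim.optim_factory import create_optimizer

model = nn.SequentialCell([nn.Dense(16, 16), nn.ReLU(), nn.Dense(16, 10)])
lr_per_step = [1e-3] * 1000      # per-step lr list; each entry gets scaled per layer group
optimizer = create_optimizer(
    model,                        # pass the Cell itself, not model.trainable_params()
    opt="adamw",                  # assumes the "adamw" optimizer name; any supported opt works
    lr=lr_per_step,
    weight_decay=0.05,
    layer_decay=0.75,             # triggers the param_groups_layer_decay branch above
)
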
3 changes: 2 additions & 1 deletion train.py
@@ -210,13 +210,14 @@ def main():
else:
optimizer_loss_scale = 1.0
optimizer = create_optimizer(
network.trainable_params(),
network,
opt=args.opt,
lr=lr_scheduler,
weight_decay=args.weight_decay,
momentum=args.momentum,
nesterov=args.use_nesterov,
weight_decay_filter=args.weight_decay_filter,
layer_decay=args.layer_decay,
loss_scale=optimizer_loss_scale,
checkpoint_path=opt_ckpt_path,
eps=args.eps,
