From 4132fd8789955143ef5958ccdf763d3653feb78b Mon Sep 17 00:00:00 2001
From: hanhuiyu1996
Date: Thu, 4 Jan 2024 20:44:14 +0800
Subject: [PATCH] support layer_decay in optim_factory

---
 config.py                     |   2 +
 mindcv/optim/optim_factory.py | 164 +++++++++++++++++++++++++++++++++-
 train.py                      |   3 +-
 3 files changed, 167 insertions(+), 2 deletions(-)

diff --git a/config.py b/config.py
index 0b4028afe..06bbabef0 100644
--- a/config.py
+++ b/config.py
@@ -167,6 +167,8 @@ def create_parser():
                        help='Whether use clip grad (default=False)')
     group.add_argument('--clip_value', type=float, default=15.0,
                        help='Clip value (default=15.0)')
+    group.add_argument('--layer_decay', type=float, default=None,
+                       help='layer-wise learning rate decay (default: None)')
     group.add_argument('--gradient_accumulation_steps', type=int, default=1,
                        help="Accumulate the gradients of n batches before update.")
diff --git a/mindcv/optim/optim_factory.py b/mindcv/optim/optim_factory.py
index 5d8c14504..1e5d32b11 100644
--- a/mindcv/optim/optim_factory.py
+++ b/mindcv/optim/optim_factory.py
@@ -1,7 +1,11 @@
 """ optim factory """
+import collections
 import logging
 import os
-from typing import Optional
+import re
+from collections import defaultdict
+from itertools import chain, islice
+from typing import Any, Callable, Dict, Iterator, Optional, Tuple, Union
 
 from mindspore import load_checkpoint, load_param_into_net, nn
 
@@ -14,6 +18,8 @@
 
 _logger = logging.getLogger(__name__)
 
+MATCH_PREV_GROUP = [9]
+
 
 def init_group_params(params, weight_decay, weight_decay_filter, no_weight_decay):
     if weight_decay_filter == "disable":
@@ -37,6 +43,152 @@
     ]
 
 
+def param_groups_layer_decay(
+    model: nn.Cell,
+    lr: Optional[float] = 1e-3,
+    weight_decay: float = 0.05,
+    no_weight_decay_list: Tuple[str] = (),
+    layer_decay: float = 0.75,
+):
+    """
+    Parameter groups for layer-wise lr decay & weight decay
+    """
+    no_weight_decay_list = set(no_weight_decay_list)
+    param_group_names = {}  # NOTE for debugging
+    param_groups = {}
+    if hasattr(model, "group_matcher"):
+        layer_map = group_with_matcher(model.trainable_params(), model.group_matcher(coarse=False), reverse=True)
+    else:
+        layer_map = _layer_map(model)
+
+    num_layers = max(layer_map.values()) + 1
+    layer_max = num_layers - 1
+    layer_scales = list(layer_decay ** (layer_max - i) for i in range(num_layers))
+
+    for name, param in model.parameters_and_names():
+        if not param.requires_grad:
+            continue
+
+        # no decay: all 1D parameters and model specific ones
+        if param.ndim == 1 or name in no_weight_decay_list:
+            g_decay = "no_decay"
+            this_decay = 0.0
+        else:
+            g_decay = "decay"
+            this_decay = weight_decay
+
+        layer_id = layer_map.get(name, layer_max)
+        group_name = "layer_%d_%s" % (layer_id, g_decay)
+
+        if group_name not in param_groups:
+            this_scale = layer_scales[layer_id]
+            param_group_names[group_name] = {
+                "lr": [learning_rate * this_scale for learning_rate in lr],
+                "weight_decay": this_decay,
+                "param_names": [],
+            }
+            param_groups[group_name] = {
+                "lr": [learning_rate * this_scale for learning_rate in lr],
+                "weight_decay": this_decay,
+                "params": [],
+            }
+
+        param_group_names[group_name]["param_names"].append(name)
+        param_groups[group_name]["params"].append(param)
+
+    return list(param_groups.values())
+
+
+MATCH_PREV_GROUP = (99999,)
+
+
+def group_with_matcher(
+    named_objects: Iterator[Tuple[str, Any]], group_matcher: Union[Dict, Callable], reverse: bool = False
+):
+    if isinstance(group_matcher, dict):
+        # dictionary matcher contains a dict of raw-string regex expr that must be compiled
+        compiled = []
+        for group_ordinal, (_, mspec) in enumerate(group_matcher.items()):
+            if mspec is None:
+                continue
+            # map all matching specifications into 3-tuple (compiled re, prefix, suffix)
+            if isinstance(mspec, (tuple, list)):
+                # multi-entry match specifications require each sub-spec to be a 2-tuple (re, suffix)
+                for sspec in mspec:
+                    compiled += [(re.compile(sspec[0]), (group_ordinal,), sspec[1])]
+            else:
+                compiled += [(re.compile(mspec), (group_ordinal,), None)]
+        group_matcher = compiled
+
+    def _get_grouping(name):
+        if isinstance(group_matcher, (list, tuple)):
+            for match_fn, prefix, suffix in group_matcher:
+                r = match_fn.match(name)
+                if r:
+                    parts = (prefix, r.groups(), suffix)
+                    # map all tuple elem to int for numeric sort, filter out None entries
+                    return tuple(map(float, chain.from_iterable(filter(None, parts))))
+            return (float("inf"),)  # un-matched layers (neck, head) mapped to largest ordinal
+        else:
+            ord = group_matcher(name)
+            if not isinstance(ord, collections.abc.Iterable):
+                return (ord,)
+            return tuple(ord)
+
+    grouping = defaultdict(list)
+    for param in named_objects:
+        grouping[_get_grouping(param.name)].append(param.name)
+    # remap to integers
+    layer_id_to_param = defaultdict(list)
+    lid = -1
+    for k in sorted(filter(lambda x: x is not None, grouping.keys())):
+        if lid < 0 or k[-1] != MATCH_PREV_GROUP[0]:
+            lid += 1
+        layer_id_to_param[lid].extend(grouping[k])
+
+    if reverse:
+        # output reverse mapping
+        param_to_layer_id = {}
+        for lid, lm in layer_id_to_param.items():
+            for n in lm:
+                param_to_layer_id[n] = lid
+        return param_to_layer_id
+
+    return layer_id_to_param
+
+
+def _group(it, size):
+    it = iter(it)
+    return iter(lambda: tuple(islice(it, size)), ())
+
+
+def _layer_map(model, layers_per_group=12, num_groups=None):
+    def _in_head(n, hp):
+        if not hp:
+            return True
+        elif isinstance(hp, (tuple, list)):
+            return any([n.startswith(hpi) for hpi in hp])
+        else:
+            return n.startswith(hp)
+
+    # attention: need to add pretrained_cfg attr to model
+    head_prefix = getattr(model, "pretrained_cfg", {}).get("classifier", None)
+    names_trunk = []
+    names_head = []
+    for n, _ in model.parameters_and_names():
+        names_head.append(n) if _in_head(n, head_prefix) else names_trunk.append(n)
+
+    # group non-head layers
+    num_trunk_layers = len(names_trunk)
+    if num_groups is not None:
+        layers_per_group = -(num_trunk_layers // -num_groups)
+    names_trunk = list(_group(names_trunk, layers_per_group))
+    num_trunk_groups = len(names_trunk)
+    layer_map = {n: i for i, l in enumerate(names_trunk) for n in l}
+    layer_map.update({n: num_trunk_groups for n in names_head})
+    return layer_map
+
+
 def create_optimizer(
     model_or_params,
     opt: str = "adam",
@@ -45,6 +197,7 @@
     momentum: float = 0.9,
     nesterov: bool = False,
     weight_decay_filter: str = "disable",
+    layer_decay: Optional[float] = None,
     loss_scale: float = 1.0,
     schedule_decay: float = 4e-3,
     checkpoint_path: str = "",
@@ -95,6 +248,15 @@
             "when creating an mindspore.nn.Optimizer instance. "
             "NOTE: mindspore.nn.Optimizer will filter Norm parmas from weight decay. "
         )
+    elif layer_decay is not None and isinstance(model_or_params, nn.Cell):
+        params = param_groups_layer_decay(
+            model_or_params,
+            lr=lr,
+            weight_decay=weight_decay,
+            layer_decay=layer_decay,
+            no_weight_decay_list=no_weight_decay,
+        )
+        weight_decay = 0.0
     elif weight_decay_filter == "disable" or "norm_and_bias":
         params = init_group_params(params, weight_decay, weight_decay_filter, no_weight_decay)
         weight_decay = 0.0
diff --git a/train.py b/train.py
index 0406d90a1..62b637863 100644
--- a/train.py
+++ b/train.py
@@ -210,13 +210,14 @@ def main():
     else:
         optimizer_loss_scale = 1.0
     optimizer = create_optimizer(
-        network.trainable_params(),
+        network,
         opt=args.opt,
         lr=lr_scheduler,
         weight_decay=args.weight_decay,
         momentum=args.momentum,
         nesterov=args.use_nesterov,
         weight_decay_filter=args.weight_decay_filter,
+        layer_decay=args.layer_decay,
         loss_scale=optimizer_loss_scale,
         checkpoint_path=opt_ckpt_path,
         eps=args.eps,