diff --git a/mindcv/optim/optim_factory.py b/mindcv/optim/optim_factory.py index 9d5e54ccc..5d8c14504 100644 --- a/mindcv/optim/optim_factory.py +++ b/mindcv/optim/optim_factory.py @@ -15,7 +15,7 @@ _logger = logging.getLogger(__name__) -def init_group_params(params, weight_decay, weight_decay_filter): +def init_group_params(params, weight_decay, weight_decay_filter, no_weight_decay): if weight_decay_filter == "disable": return [ {"params": params, "weight_decay": weight_decay}, @@ -24,11 +24,12 @@ def init_group_params(params, weight_decay, weight_decay_filter): decay_params = [] no_decay_params = [] + no_weight_decay = set(no_weight_decay) for param in params: - if "beta" not in param.name and "gamma" not in param.name and "bias" not in param.name: - decay_params.append(param) - else: + if "beta" in param.name or "gamma" in param.name or "bias" in param.name or param.name in no_weight_decay: no_decay_params.append(param) + else: + decay_params.append(param) return [ {"params": decay_params, "weight_decay": weight_decay}, {"params": no_decay_params, "weight_decay": 0.0}, @@ -37,7 +38,7 @@ def init_group_params(params, weight_decay, weight_decay_filter): def create_optimizer( - params, + model_or_params, opt: str = "adam", lr: Optional[float] = 1e-3, weight_decay: float = 0, @@ -78,7 +79,16 @@ def create_optimizer( Optimizer object """ - opt = opt.lower() + no_weight_decay = {} + if isinstance(model_or_params, nn.Cell): + # a model was passed in, extract parameters and add weight decays to appropriate layers + if hasattr(model_or_params, "no_weight_decay"): + no_weight_decay = model_or_params.no_weight_decay() + params = model_or_params.trainable_params() + + else: + params = model_or_params + if weight_decay_filter == "auto": _logger.warning( "You are using AUTO weight decay filter, which means the weight decay filter isn't explicitly pass in " @@ -86,13 +96,14 @@ def create_optimizer( "NOTE: mindspore.nn.Optimizer will filter Norm parmas from weight decay. " ) elif weight_decay_filter == "disable" or "norm_and_bias": - params = init_group_params(params, weight_decay, weight_decay_filter) + params = init_group_params(params, weight_decay, weight_decay_filter, no_weight_decay) weight_decay = 0.0 else: raise ValueError( f"weight decay filter only support ['disable', 'auto', 'norm_and_bias'], but got{weight_decay_filter}." ) + opt = opt.lower() opt_args = dict(**kwargs) # if lr is not None: # opt_args.setdefault('lr', lr)