
Commit

add no_weight_decay layers from the model when filtering layers from weight decay in optim_factory
sageyou committed Jan 4, 2024
1 parent b63150f commit a086e4e
Showing 1 changed file with 18 additions and 7 deletions.
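In short: create_optimizer now accepts either a list of parameters or the model itself; if the model defines a no_weight_decay() method, the parameter names it returns are kept out of weight decay alongside the usual beta/gamma/bias parameters. A minimal usage sketch follows (ToyModel, its parameter names, and the import path are illustrative assumptions, not part of this commit):

import numpy as np
import mindspore as ms
import mindspore.nn as nn

from mindcv.optim import create_optimizer  # factory defined in mindcv/optim/optim_factory.py


class ToyModel(nn.Cell):
    """Hypothetical model that opts one parameter out of weight decay."""

    def __init__(self):
        super().__init__()
        # a learnable token that should not be decayed, as in ViT-style models
        self.cls_token = ms.Parameter(ms.Tensor(np.zeros((1, 1, 16)), ms.float32), name="cls_token")
        self.head = nn.Dense(16, 10)

    def no_weight_decay(self):
        # parameter names returned here land in the no-decay group
        return {"cls_token"}

    def construct(self, x):
        return self.head(x)


model = ToyModel()
optimizer = create_optimizer(
    model,                                # pass the model itself, not model.trainable_params()
    opt="adamw",
    lr=1e-3,
    weight_decay=0.05,
    weight_decay_filter="norm_and_bias",  # "cls_token" joins beta/gamma/bias in the no-decay group
)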
25 changes: 18 additions & 7 deletions mindcv/optim/optim_factory.py
@@ -15,7 +15,7 @@
 _logger = logging.getLogger(__name__)


-def init_group_params(params, weight_decay, weight_decay_filter):
+def init_group_params(params, weight_decay, weight_decay_filter, no_weight_decay):
     if weight_decay_filter == "disable":
         return [
             {"params": params, "weight_decay": weight_decay},
@@ -24,11 +24,12 @@ def init_group_params(params, weight_decay, weight_decay_filter):

     decay_params = []
     no_decay_params = []
+    no_weight_decay = set(no_weight_decay)
     for param in params:
-        if "beta" not in param.name and "gamma" not in param.name and "bias" not in param.name:
-            decay_params.append(param)
-        else:
+        if "beta" in param.name or "gamma" in param.name or "bias" in param.name or param.name in no_weight_decay:
             no_decay_params.append(param)
+        else:
+            decay_params.append(param)
     return [
         {"params": decay_params, "weight_decay": weight_decay},
         {"params": no_decay_params, "weight_decay": 0.0},
@@ -37,7 +38,7 @@ def init_group_params(params, weight_decay, weight_decay_filter):


 def create_optimizer(
-    params,
+    model_or_params,
     opt: str = "adam",
     lr: Optional[float] = 1e-3,
     weight_decay: float = 0,
@@ -78,21 +79,31 @@
         Optimizer object
     """

-    opt = opt.lower()
+    no_weight_decay = {}
+    if isinstance(model_or_params, nn.Cell):
+        # a model was passed in; extract its parameters and record the layers
+        # the model itself excludes from weight decay
+        if hasattr(model_or_params, "no_weight_decay"):
+            no_weight_decay = model_or_params.no_weight_decay()
+        params = model_or_params.trainable_params()
+
+    else:
+        params = model_or_params
+
     if weight_decay_filter == "auto":
         _logger.warning(
             "You are using the AUTO weight decay filter, which means the weight decay filter isn't explicitly passed in "
             "when creating a mindspore.nn.Optimizer instance. "
             "NOTE: mindspore.nn.Optimizer will filter Norm params from weight decay. "
         )
elif weight_decay_filter == "disable" or "norm_and_bias":
params = init_group_params(params, weight_decay, weight_decay_filter)
params = init_group_params(params, weight_decay, weight_decay_filter, no_weight_decay)
weight_decay = 0.0
else:
raise ValueError(
f"weight decay filter only support ['disable', 'auto', 'norm_and_bias'], but got{weight_decay_filter}."
)

opt = opt.lower()
opt_args = dict(**kwargs)
# if lr is not None:
# opt_args.setdefault('lr', lr)
Expand Down
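The resulting selection rule: a parameter is excluded from weight decay when its name contains "beta", "gamma", or "bias", or when its exact name appears in the set returned by the model's no_weight_decay(). A standalone sketch of that predicate (the helper name is hypothetical):

def excluded_from_weight_decay(name: str, no_weight_decay: set) -> bool:
    """Mirrors the updated filter in init_group_params: substring match on
    norm/bias parameter names, exact match against the model-provided set."""
    return (
        "beta" in name
        or "gamma" in name
        or "bias" in name
        or name in no_weight_decay
    )


# With no_weight_decay = {"cls_token"}:
#   excluded_from_weight_decay("head.bias", {"cls_token"})   -> True   (bias term)
#   excluded_from_weight_decay("cls_token", {"cls_token"})   -> True   (listed by the model)
#   excluded_from_weight_decay("head.weight", {"cls_token"}) -> False  (kept in the decay group)

Because the membership test is an exact match on param.name, no_weight_decay() must return full parameter names as registered on the Cell; prefix or wildcard matching is not performed.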
