diff --git a/trojanvision/models/nas/darts.py b/trojanvision/models/nas/darts.py index 0276f49c..ba77599e 100644 --- a/trojanvision/models/nas/darts.py +++ b/trojanvision/models/nas/darts.py @@ -18,6 +18,7 @@ from typing import Union from typing import TYPE_CHECKING from trojanzoo.utils.fim import KFAC, EKFAC +from trojanzoo.utils.model import ExponentialMovingAverage from torch.optim import Optimizer from torch.optim.lr_scheduler import _LRScheduler import torch.utils.data diff --git a/trojanzoo/models.py b/trojanzoo/models.py index d4290d3f..088c8bc9 100644 --- a/trojanzoo/models.py +++ b/trojanzoo/models.py @@ -345,12 +345,12 @@ def define_optimizer( raise RuntimeError( f'Invalid warmup lr method "{lr_warmup_method}".' 'Only linear and constant are supported.') - lr_scheduler = torch.optim.lr_scheduler.SequentialLR( + _lr_scheduler = torch.optim.lr_scheduler.SequentialLR( optimizer, schedulers=[warmup_lr_scheduler, main_lr_scheduler], milestones=[lr_warmup_epochs]) else: - lr_scheduler = main_lr_scheduler + _lr_scheduler = main_lr_scheduler return optimizer, _lr_scheduler # define loss function