From ca079463dab11a5fef74292833638e3076a0eca4 Mon Sep 17 00:00:00 2001 From: The-truthh <821372701@qq.com> Date: Mon, 17 Jul 2023 18:59:29 +0800 Subject: [PATCH] feat: support distillation strategy --- config.py | 14 ++++++ mindcv/utils/__init__.py | 1 + mindcv/utils/distill_loss_cell.py | 83 +++++++++++++++++++++++++++++++ mindcv/utils/trainer_factory.py | 19 ++++++- train.py | 16 ++++++ 5 files changed, 131 insertions(+), 2 deletions(-) create mode 100644 mindcv/utils/distill_loss_cell.py diff --git a/config.py b/config.py index b78e5a9b8..9b5742caa 100644 --- a/config.py +++ b/config.py @@ -256,6 +256,20 @@ def create_parser(): group.add_argument('--drop_overflow_update', type=bool, default=False, help='Whether to execute optimizer if there is an overflow (default=False)') + # distillation + group = parser.add_argument_group('Distillation parameters') + group.add_argument('--distillation_type', type=str, default=None, + choices=['hard', 'soft'], + help='The type of distillation (default=None)') + group.add_argument('--teacher_model', type=str, default=None, + help='Name of teacher model (default=None)') + group.add_argument('--teacher_ckpt_path', type=str, default='', + help='Initialize teacher model from this checkpoint. ' + 'If resume training, specify the checkpoint path (default="").') + group.add_argument('--distillation_alpha', type=float, default=0.5, + help='The coefficient balancing the distillation loss and base loss' + '(default=0.5)') + # modelarts group = parser.add_argument_group('modelarts') group.add_argument('--enable_modelarts', type=str2bool, nargs='?', const=True, default=False, diff --git a/mindcv/utils/__init__.py b/mindcv/utils/__init__.py index 39b346e04..8b1ce6d63 100644 --- a/mindcv/utils/__init__.py +++ b/mindcv/utils/__init__.py @@ -2,6 +2,7 @@ from .amp import * from .callbacks import * from .checkpoint_manager import * +from .distill_loss_cell import * from .download import * from .logger import * from .path import * diff --git a/mindcv/utils/distill_loss_cell.py b/mindcv/utils/distill_loss_cell.py new file mode 100644 index 000000000..b590504bc --- /dev/null +++ b/mindcv/utils/distill_loss_cell.py @@ -0,0 +1,83 @@ +""" distillation loss cell define """ +from mindspore import nn +from mindspore.ops import functional as F + + +class HardDistillLossCell(nn.WithLossCell): + """ + Wraps the network with hard distillation loss function. + + Get the loss of student network and an extra knowledge hard distillation loss + by taking a teacher model prediction and using it as additional supervision. + + Args: + backbone (Cell): The student network to train and calculate base loss. + loss_fn (Cell): The loss function used to compute loss of student network. + teacher_model (Cell): The teacher network to calculate distillation loss. + alpha (float): Distillation factor. the coefficient to balance the distillation + loss and base loss. Default: 0.5. + """ + + def __init__(self, backbone, loss_fn, teacher_model, alpha=0.5): + super().__init__(backbone, loss_fn) + self.teacher_model = teacher_model + self.alpha = alpha + + def construct(self, data, label): + out = self._backbone(data) + + out, out_kd = out + base_loss = self._loss_fn(out, label) + + teacher_out = self.teacher_model(data) + + distillation_loss = F.cross_entropy(out_kd, teacher_out.argmax(axis=1)) + loss = base_loss * (1 - self.alpha) + distillation_loss * self.alpha + + return loss + + +class SoftDistillLossCell(nn.WithLossCell): + """ + Wraps the network with soft distillation loss function. + + Get the loss of student network and an extra knowledge soft distillation loss + by taking a teacher model prediction and using it as additional supervision. + + Args: + backbone (Cell): The student network to train and calculate base loss. + loss_fn (Cell): The loss function used to compute loss of student network. + teacher_model (Cell): The teacher network to calculate distillation loss. + alpha (float): Distillation factor. the coefficient balancing the distillation + loss and base loss. Default: 0.5. + tau (float): Distillation temperature. The higher the temperature, the lower the + dispersion of the loss calculated by Kullback-Leibler divergence loss. Default: 1.0. + """ + + def __init__(self, backbone, loss_fn, teacher_model, alpha=0.5, tau=1.0): + super().__init__(backbone, loss_fn) + self.teacher_model = teacher_model + self.alpha = alpha + self.tau = tau + + def construct(self, data, label): + out = self._backbone(data) + + out, out_kd = out + base_loss = self._loss_fn(out, label) + + teacher_out = self.teacher_model(data) + + T = self.tau + distillation_loss = ( + F.kl_div( + F.log_softmax(out_kd / T, axis=1), + F.log_softmax(teacher_out / T, axis=1), + reduction="sum", + ) + * (T * T) + / F.size(out_kd) + ) + loss = base_loss * (1 - self.alpha) + distillation_loss * self.alpha + + return loss diff --git a/mindcv/utils/trainer_factory.py b/mindcv/utils/trainer_factory.py index db47a48e6..ad5273ef8 100644 --- a/mindcv/utils/trainer_factory.py +++ b/mindcv/utils/trainer_factory.py @@ -9,6 +9,7 @@ from mindspore.train import DynamicLossScaleManager, FixedLossScaleManager, Model from .amp import auto_mixed_precision +from .distill_loss_cell import HardDistillLossCell, SoftDistillLossCell from .train_step import TrainStep __all__ = [ @@ -38,6 +39,7 @@ def require_customized_train_step( clip_grad: bool = False, gradient_accumulation_steps: int = 1, amp_cast_list: Optional[str] = None, + distillation_type: Optional[str] = None, ): if ema: return True @@ -47,6 +49,8 @@ def require_customized_train_step( return True if amp_cast_list: return True + if distillation_type: + return True return False @@ -88,6 +92,9 @@ def create_trainer( clip_grad: bool = False, clip_value: float = 15.0, gradient_accumulation_steps: int = 1, + distillation_type: Optional[str] = None, + teacher_network: Optional[nn.Cell] = None, + distillation_alpha: float = 0.5, ): """Create Trainer. @@ -120,7 +127,7 @@ def create_trainer( if gradient_accumulation_steps < 1: raise ValueError("`gradient_accumulation_steps` must be >= 1!") - if not require_customized_train_step(ema, clip_grad, gradient_accumulation_steps, amp_cast_list): + if not require_customized_train_step(ema, clip_grad, gradient_accumulation_steps, amp_cast_list, distillation_type): mindspore_kwargs = dict( network=network, loss_fn=loss, @@ -149,7 +156,15 @@ def create_trainer( else: # require customized train step eval_network = nn.WithEvalCell(network, loss, amp_level in ["O2", "O3", "auto"]) auto_mixed_precision(network, amp_level, amp_cast_list) - net_with_loss = add_loss_network(network, loss, amp_level) + if distillation_type: + if distillation_type == "hard": + net_with_loss = HardDistillLossCell(network, loss, teacher_network, distillation_alpha) + elif distillation_type == "soft": + net_with_loss = SoftDistillLossCell(network, loss, teacher_network, distillation_alpha) + else: + raise ValueError(f"Distillation type only support ['hard', 'soft'], but got {distillation_type}.") + else: + net_with_loss = add_loss_network(network, loss, amp_level) train_step_kwargs = dict( network=net_with_loss, optimizer=optimizer, diff --git a/train.py b/train.py index 644948a54..6034020b7 100644 --- a/train.py +++ b/train.py @@ -180,6 +180,18 @@ def train(args): aux_factor=args.aux_factor, ) + # create teacher model + teacher_network = None + if args.distillation_type: + if not args.teacher_ckpt_path: + logger.warning("You are using distillation, but your teacher model has not loaded weights.") + teacher_network = create_model( + model_name=args.teacher_model, + num_classes=num_classes, + checkpoint_path=args.teacher_ckpt_path, + ) + teacher_network.set_train(False) + # create learning rate schedule lr_scheduler = create_scheduler( num_batches, @@ -213,6 +225,7 @@ def train(args): args.clip_grad, args.gradient_accumulation_steps, args.amp_cast_list, + args.distillation_type, ) ): optimizer_loss_scale = args.loss_scale @@ -250,6 +263,9 @@ def train(args): clip_grad=args.clip_grad, clip_value=args.clip_value, gradient_accumulation_steps=args.gradient_accumulation_steps, + distillation_type=args.distillation_type, + teacher_network=teacher_network, + distillation_alpha=args.distillation_alpha, ) # callback