feat: support distillation strategy
The-truthh committed Jul 19, 2023
1 parent 85450da commit ca07946
Showing 5 changed files with 131 additions and 2 deletions.
14 changes: 14 additions & 0 deletions config.py
@@ -256,6 +256,20 @@ def create_parser():
group.add_argument('--drop_overflow_update', type=bool, default=False,
help='Whether to execute optimizer if there is an overflow (default=False)')

# distillation
group = parser.add_argument_group('Distillation parameters')
group.add_argument('--distillation_type', type=str, default=None,
choices=['hard', 'soft'],
help='The type of distillation (default=None)')
group.add_argument('--teacher_model', type=str, default=None,
help='Name of teacher model (default=None)')
group.add_argument('--teacher_ckpt_path', type=str, default='',
help='Initialize the teacher model from this checkpoint. '
'If resuming training, specify the checkpoint path (default="").')
group.add_argument('--distillation_alpha', type=float, default=0.5,
help='The coefficient balancing the distillation loss and the base loss '
'(default=0.5)')

# modelarts
group = parser.add_argument_group('modelarts')
group.add_argument('--enable_modelarts', type=str2bool, nargs='?', const=True, default=False,
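For reference, the new flags parse as in the short sketch below. The parser here is a standalone stand-in that mirrors only the distillation group added above (create_parser() itself is only partially shown in this diff), and the teacher model name is a placeholder.

import argparse

# Stand-in parser mirroring only the new distillation group (illustration only).
parser = argparse.ArgumentParser()
group = parser.add_argument_group('Distillation parameters')
group.add_argument('--distillation_type', type=str, default=None, choices=['hard', 'soft'])
group.add_argument('--teacher_model', type=str, default=None)
group.add_argument('--teacher_ckpt_path', type=str, default='')
group.add_argument('--distillation_alpha', type=float, default=0.5)

args = parser.parse_args(['--distillation_type', 'soft', '--teacher_model', 'resnet101'])
print(args.distillation_type, args.teacher_model, args.distillation_alpha)  # soft resnet101 0.5

Leaving --teacher_ckpt_path at its empty default means the teacher starts from random weights, which the new check in train.py below warns about.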
1 change: 1 addition & 0 deletions mindcv/utils/__init__.py
@@ -2,6 +2,7 @@
from .amp import *
from .callbacks import *
from .checkpoint_manager import *
from .distill_loss_cell import *
from .download import *
from .logger import *
from .path import *
83 changes: 83 additions & 0 deletions mindcv/utils/distill_loss_cell.py
@@ -0,0 +1,83 @@
""" distillation loss cell define """
from mindspore import nn
from mindspore.ops import functional as F


class HardDistillLossCell(nn.WithLossCell):
"""
Wraps the student network with a hard distillation loss function.
Computes the student network's base loss plus a hard knowledge-distillation loss
that uses the teacher model's predicted class as additional supervision.

Args:
backbone (Cell): The student network to train; its output is used for the base loss.
loss_fn (Cell): The loss function used to compute the student network's base loss.
teacher_model (Cell): The teacher network used to compute the distillation loss.
alpha (float): Distillation factor, the coefficient balancing the distillation
loss and the base loss. Default: 0.5.
"""

def __init__(self, backbone, loss_fn, teacher_model, alpha=0.5):
super().__init__(backbone, loss_fn)
self.teacher_model = teacher_model
self.alpha = alpha

def construct(self, data, label):
out = self._backbone(data)

# The student is expected to return two sets of logits:
# classification-head logits and distillation-head logits.
out, out_kd = out
base_loss = self._loss_fn(out, label)

teacher_out = self.teacher_model(data)

# Hard distillation: cross-entropy against the teacher's predicted class.
distillation_loss = F.cross_entropy(out_kd, teacher_out.argmax(axis=1))
loss = base_loss * (1 - self.alpha) + distillation_loss * self.alpha

return loss


class SoftDistillLossCell(nn.WithLossCell):
"""
Wraps the student network with a soft distillation loss function.
Computes the student network's base loss plus a soft knowledge-distillation loss
that uses the teacher model's logits as additional supervision.

Args:
backbone (Cell): The student network to train; its output is used for the base loss.
loss_fn (Cell): The loss function used to compute the student network's base loss.
teacher_model (Cell): The teacher network used to compute the distillation loss.
alpha (float): Distillation factor, the coefficient balancing the distillation
loss and the base loss. Default: 0.5.
tau (float): Distillation temperature. A higher temperature produces softer
probability distributions in the Kullback-Leibler divergence term. Default: 1.0.
"""

def __init__(self, backbone, loss_fn, teacher_model, alpha=0.5, tau=1.0):
super().__init__(backbone, loss_fn)
self.teacher_model = teacher_model
self.alpha = alpha
self.tau = tau

def construct(self, data, label):
out = self._backbone(data)

# The student is expected to return two sets of logits:
# classification-head logits and distillation-head logits.
out, out_kd = out
base_loss = self._loss_fn(out, label)

teacher_out = self.teacher_model(data)

# Soft distillation: temperature-scaled KL divergence between the student's
# distillation logits and the teacher's logits, scaled by tau^2 and averaged
# over the number of elements.
T = self.tau
distillation_loss = (
F.kl_div(
F.log_softmax(out_kd / T, axis=1),
F.log_softmax(teacher_out / T, axis=1),
reduction="sum",
)
* (T * T)
/ F.size(out_kd)
)
loss = base_loss * (1 - self.alpha) + distillation_loss * self.alpha

return loss
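In both cells the total loss is base_loss * (1 - alpha) + distillation_loss * alpha: the hard variant penalizes the student's distillation logits with cross-entropy against the teacher's predicted class, while the soft variant uses a temperature-scaled Kullback-Leibler term scaled by tau^2. Below is a minimal, self-contained usage sketch; the toy networks are illustrative only (the real student is expected to return a pair of logits, one per head), and it assumes mindcv with this change is importable.

import mindspore as ms
import mindspore.numpy as mnp
from mindspore import nn

from mindcv.utils import SoftDistillLossCell  # exported via mindcv/utils/__init__.py above


class ToyStudent(nn.Cell):
    """Toy student returning (classification logits, distillation logits)."""

    def __init__(self, num_classes=10):
        super().__init__()
        self.flatten = nn.Flatten()
        self.head = nn.Dense(3 * 32 * 32, num_classes)
        self.head_dist = nn.Dense(3 * 32 * 32, num_classes)

    def construct(self, x):
        x = self.flatten(x)
        return self.head(x), self.head_dist(x)


class ToyTeacher(nn.Cell):
    """Toy teacher returning a single set of logits."""

    def __init__(self, num_classes=10):
        super().__init__()
        self.flatten = nn.Flatten()
        self.head = nn.Dense(3 * 32 * 32, num_classes)

    def construct(self, x):
        return self.head(self.flatten(x))


student, teacher = ToyStudent(), ToyTeacher()
teacher.set_train(False)  # the teacher is frozen; only the student is trained

net_with_loss = SoftDistillLossCell(student, nn.CrossEntropyLoss(), teacher, alpha=0.5, tau=2.0)

data = mnp.randn(4, 3, 32, 32)
label = ms.Tensor([0, 1, 2, 3], ms.int32)
print(net_with_loss(data, label))  # scalar training loss

HardDistillLossCell is wired the same way, minus the tau argument; during training this wrapping is performed by create_trainer in the next file.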
19 changes: 17 additions & 2 deletions mindcv/utils/trainer_factory.py
@@ -9,6 +9,7 @@
from mindspore.train import DynamicLossScaleManager, FixedLossScaleManager, Model

from .amp import auto_mixed_precision
from .distill_loss_cell import HardDistillLossCell, SoftDistillLossCell
from .train_step import TrainStep

__all__ = [
@@ -38,6 +39,7 @@ def require_customized_train_step(
clip_grad: bool = False,
gradient_accumulation_steps: int = 1,
amp_cast_list: Optional[str] = None,
distillation_type: Optional[str] = None,
):
if ema:
return True
@@ -47,6 +49,8 @@
return True
if amp_cast_list:
return True
if distillation_type:
return True
return False


@@ -88,6 +92,9 @@ def create_trainer(
clip_grad: bool = False,
clip_value: float = 15.0,
gradient_accumulation_steps: int = 1,
distillation_type: Optional[str] = None,
teacher_network: Optional[nn.Cell] = None,
distillation_alpha: float = 0.5,
):
"""Create Trainer.
@@ -120,7 +127,7 @@
if gradient_accumulation_steps < 1:
raise ValueError("`gradient_accumulation_steps` must be >= 1!")

if not require_customized_train_step(ema, clip_grad, gradient_accumulation_steps, amp_cast_list):
if not require_customized_train_step(ema, clip_grad, gradient_accumulation_steps, amp_cast_list, distillation_type):
mindspore_kwargs = dict(
network=network,
loss_fn=loss,
Expand Down Expand Up @@ -149,7 +156,15 @@ def create_trainer(
else: # require customized train step
eval_network = nn.WithEvalCell(network, loss, amp_level in ["O2", "O3", "auto"])
auto_mixed_precision(network, amp_level, amp_cast_list)
net_with_loss = add_loss_network(network, loss, amp_level)
if distillation_type:
if distillation_type == "hard":
net_with_loss = HardDistillLossCell(network, loss, teacher_network, distillation_alpha)
elif distillation_type == "soft":
net_with_loss = SoftDistillLossCell(network, loss, teacher_network, distillation_alpha)
else:
raise ValueError(f"Distillation type only support ['hard', 'soft'], but got {distillation_type}.")
else:
net_with_loss = add_loss_network(network, loss, amp_level)
train_step_kwargs = dict(
network=net_with_loss,
optimizer=optimizer,
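With distillation enabled, training is forced onto the customized TrainStep path because require_customized_train_step now returns True. A quick check (assuming ema is the first positional parameter, as the function body above suggests):

from mindcv.utils.trainer_factory import require_customized_train_step

# No EMA, clipping, accumulation, cast list, or distillation: plain MindSpore Model path.
print(require_customized_train_step(False))                            # False
# Any distillation type requires the customized TrainStep with a distillation loss cell.
print(require_customized_train_step(False, distillation_type='soft'))  # True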
16 changes: 16 additions & 0 deletions train.py
@@ -180,6 +180,18 @@ def train(args):
aux_factor=args.aux_factor,
)

# create teacher model
teacher_network = None
if args.distillation_type:
if not args.teacher_ckpt_path:
logger.warning("You are using distillation, but your teacher model has not loaded weights.")
teacher_network = create_model(
model_name=args.teacher_model,
num_classes=num_classes,
checkpoint_path=args.teacher_ckpt_path,
)
teacher_network.set_train(False)

# create learning rate schedule
lr_scheduler = create_scheduler(
num_batches,
@@ -213,6 +225,7 @@
args.clip_grad,
args.gradient_accumulation_steps,
args.amp_cast_list,
args.distillation_type,
)
):
optimizer_loss_scale = args.loss_scale
@@ -250,6 +263,9 @@
clip_grad=args.clip_grad,
clip_value=args.clip_value,
gradient_accumulation_steps=args.gradient_accumulation_steps,
distillation_type=args.distillation_type,
teacher_network=teacher_network,
distillation_alpha=args.distillation_alpha,
)

# callback
