Commit ca07946 (parent 85450da): 5 changed files with 131 additions and 2 deletions.
""" distillation loss cell definition """
from mindspore import nn
from mindspore.ops import functional as F


class HardDistillLossCell(nn.WithLossCell):
    """
    Wraps the network with a hard distillation loss function.

    Computes the loss of the student network plus an extra hard knowledge
    distillation loss, obtained by taking the teacher model's prediction and
    using it as additional supervision.

    Args:
        backbone (Cell): The student network to train and to compute the base loss.
        loss_fn (Cell): The loss function used to compute the loss of the student network.
        teacher_model (Cell): The teacher network used to compute the distillation loss.
        alpha (float): Distillation factor, the coefficient that balances the
            distillation loss and the base loss. Default: 0.5.
    """

    def __init__(self, backbone, loss_fn, teacher_model, alpha=0.5):
        super().__init__(backbone, loss_fn)
        self.teacher_model = teacher_model
        self.alpha = alpha

    def construct(self, data, label):
        # The student backbone returns both the classification output and the
        # distillation output (e.g. from a distillation token).
        out, out_kd = self._backbone(data)
        base_loss = self._loss_fn(out, label)

        teacher_out = self.teacher_model(data)

        # Hard distillation: the teacher's argmax prediction serves as the label.
        distillation_loss = F.cross_entropy(out_kd, teacher_out.argmax(axis=1))
        loss = base_loss * (1 - self.alpha) + distillation_loss * self.alpha

        return loss


class SoftDistillLossCell(nn.WithLossCell):
    """
    Wraps the network with a soft distillation loss function.

    Computes the loss of the student network plus an extra soft knowledge
    distillation loss, obtained by taking the teacher model's prediction and
    using it as additional supervision.

    Args:
        backbone (Cell): The student network to train and to compute the base loss.
        loss_fn (Cell): The loss function used to compute the loss of the student network.
        teacher_model (Cell): The teacher network used to compute the distillation loss.
        alpha (float): Distillation factor, the coefficient that balances the
            distillation loss and the base loss. Default: 0.5.
        tau (float): Distillation temperature. A higher temperature produces softer
            student and teacher distributions in the Kullback-Leibler divergence loss.
            Default: 1.0.
    """

    def __init__(self, backbone, loss_fn, teacher_model, alpha=0.5, tau=1.0):
        super().__init__(backbone, loss_fn)
        self.teacher_model = teacher_model
        self.alpha = alpha
        self.tau = tau

    def construct(self, data, label):
        out, out_kd = self._backbone(data)
        base_loss = self._loss_fn(out, label)

        teacher_out = self.teacher_model(data)

        # Soft distillation: KL divergence between the temperature-scaled student
        # log-probabilities and the teacher probabilities, rescaled by T^2 and
        # averaged over the number of elements in the student output.
        T = self.tau
        distillation_loss = (
            F.kl_div(
                F.log_softmax(out_kd / T, axis=1),
                F.softmax(teacher_out / T, axis=1),
                reduction="sum",
            )
            * (T * T)
            / F.size(out_kd)
        )
        loss = base_loss * (1 - self.alpha) + distillation_loss * self.alpha

        return loss
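
For context, here is a minimal usage sketch showing how one of these cells could be wired into a training step. The student and teacher constructors, the dataset, and the hyperparameter values below are hypothetical placeholders, not part of this commit; only the loss-cell and TrainOneStepCell wiring reflects the classes above, and HardDistillLossCell plugs in the same way (without tau).

# Minimal sketch, assuming a student network that returns a
# (classification_output, distillation_output) tuple, as construct() expects,
# and a pretrained teacher classifier. build_student(), build_teacher() and
# `dataset` are hypothetical and stand in for the surrounding training script.
from mindspore import nn

student_net = build_student()
teacher_net = build_teacher()
teacher_net.set_train(False)  # teacher is frozen, inference only

loss_fn = nn.CrossEntropyLoss()
loss_net = SoftDistillLossCell(student_net, loss_fn, teacher_net, alpha=0.5, tau=3.0)

optimizer = nn.Momentum(student_net.trainable_params(), learning_rate=0.05, momentum=0.9)
train_net = nn.TrainOneStepCell(loss_net, optimizer)
train_net.set_train()

for data, label in dataset.create_tuple_iterator():
    step_loss = train_net(data, label)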