This project implements an approach to Knowledge Distillation (KD) that combines:
- Contrastive Learning Loss: Enhances alignment between teacher and student representations.
- Logit Standardization: Normalizes logits for improved training stability and performance.
- Lion Optimizer: A memory-efficient, fast optimizer for large-scale machine learning tasks.
The framework aims to improve the transfer of knowledge from a ResNet-101 teacher model to a ResNet-18 student model, achieving high accuracy on the Imagenette dataset at a fraction of the teacher's complexity.
- Knowledge Distillation: Transfers knowledge from a larger pre-trained teacher model (ResNet-101) to a smaller student model (ResNet-18).
- Contrastive Supervised Loss: Aligns the student's representations with the ground-truth labels, pulling together samples from the same class.
- Logit Standardization Loss: Derived from the CVPR 2024 paper Logit Standardization in Knowledge Distillation, this loss stabilizes training by normalizing logits.
- Lion Optimizer: Speeds up convergence and improves generalization (a usage sketch follows this list).
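Lion is a drop-in replacement for other PyTorch optimizers. A minimal sketch, assuming the `lion-pytorch` package (lucidrains' implementation); the repository may wire the optimizer up differently:

```python
# Minimal Lion setup; the package and hyperparameters here are
# illustrative, not necessarily what the repository uses.
from lion_pytorch import Lion
from torchvision.models import resnet18

model = resnet18(num_classes=10)  # e.g., the ResNet-18 student

# Lion typically wants a roughly 3-10x smaller learning rate and a
# larger weight decay than AdamW.
optimizer = Lion(model.parameters(), lr=1e-4, weight_decay=1e-2)
```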
This project utilizes the Imagenette dataset, a subset of ImageNet designed for efficient training and experimentation. Imagenette contains 10 easily classified ImageNet classes with the same high-quality annotations (a loading sketch follows the list below).
- Dataset Link: Imagenette on GitHub
- Why Imagenette?
  - Smaller and faster to train on than the full ImageNet dataset.
  - Ideal for benchmarking lightweight models and training methods.
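As a sketch, the dataset can also be loaded directly through torchvision (version 0.16 or newer ships an `Imagenette` dataset class); the repository's own scripts may preprocess it differently:

```python
# Sketch: loading Imagenette via torchvision (requires torchvision >= 0.16;
# the repository's preprocessing scripts may differ).
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import Imagenette

transform = transforms.Compose([
    transforms.Resize((224, 224)),  # standard ResNet input size
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],  # ImageNet statistics
                         std=[0.229, 0.224, 0.225]),
])

train_set = Imagenette(root="data", split="train", download=True,
                       transform=transform)
train_loader = DataLoader(train_set, batch_size=64, shuffle=True)
```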
This project uses pre-trained ResNet models:
- Teacher Model: ResNet-101, a deep and expressive model.
- Student Model: ResNet-18, a smaller, more efficient architecture.
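A minimal sketch of instantiating the two backbones with torchvision; the notebooks may adapt the classifier heads differently:

```python
# Sketch: teacher and student backbones (assumed setup, 10 Imagenette classes).
import torch.nn as nn
from torchvision.models import resnet18, resnet101, ResNet101_Weights

teacher = resnet101(weights=ResNet101_Weights.DEFAULT)
teacher.fc = nn.Linear(teacher.fc.in_features, 10)  # re-head for Imagenette
teacher.eval()  # the teacher is frozen during distillation

student = resnet18(num_classes=10)  # trained via the combined loss
```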
The Logit Standardization Loss aligns the outputs of the teacher and student models by z-scoring each logit vector (subtracting its mean and dividing by its standard deviation) before temperature scaling. It is defined as:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def normalize1(logit):
    # Z-score each logit vector (the standardization from the CVPR 2024 paper).
    mean = logit.mean(dim=-1, keepdim=True)
    stdv = logit.std(dim=-1, keepdim=True)
    return (logit - mean) / (1e-7 + stdv)

class LogitStandardizationLoss(nn.Module):
    def __init__(self):
        super(LogitStandardizationLoss, self).__init__()

    def forward(self, logits_teacher_in, logits_student_in, temperature=2):
        logits_student = normalize1(logits_student_in)
        logits_teacher = normalize1(logits_teacher_in)
        log_pred_student = F.log_softmax(logits_student / temperature, dim=1)
        pred_teacher = F.softmax(logits_teacher / temperature, dim=1)
        # Per-sample KL divergence, averaged over the batch.
        loss_kd = F.kl_div(log_pred_student, pred_teacher, reduction="none").sum(1).mean()
        loss_kd *= temperature**2  # standard KD temperature correction
        return loss_kd
```
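For example, given a batch of logits with shape `[batch, num_classes]`:

```python
import torch

# Hypothetical random logits, just to show the call signature.
teacher_logits = torch.randn(32, 10)
student_logits = torch.randn(32, 10)

kd_criterion = LogitStandardizationLoss()
kd_loss = kd_criterion(teacher_logits, student_logits, temperature=2)
```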
A contrastive supervised loss, adapted from the SupContrast repository, pulls the student's feature embeddings for samples sharing a ground-truth label together and pushes those of different labels apart:
```python
class SupContrastiveLoss(nn.Module):
    def __init__(self, temperature=0.07):
        super(SupContrastiveLoss, self).__init__()
        self.temperature = temperature

    def forward(self, features, labels):
        device = features.device
        batch_size = features.size(0)
        labels = labels.contiguous().view(-1, 1)
        # Positive-pair mask: 1 where two samples share a label.
        mask = torch.eq(labels, labels.T).float().to(device)
        # Exclude each sample's comparison with itself, as in SupContrast.
        logits_mask = 1 - torch.eye(batch_size, device=device)
        mask = mask * logits_mask
        features = F.normalize(features, dim=1)
        similarity_matrix = torch.matmul(features, features.T) / self.temperature
        # Subtract the row-wise max for numerical stability.
        logits_max, _ = torch.max(similarity_matrix, dim=1, keepdim=True)
        logits = similarity_matrix - logits_max.detach()
        exp_logits = torch.exp(logits) * logits_mask
        log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True))
        # Mean log-probability over each anchor's positives (clamped so
        # anchors without positives contribute zero rather than NaN).
        mean_log_prob_pos = (mask * log_prob).sum(1) / mask.sum(1).clamp(min=1)
        loss = -mean_log_prob_pos.mean()
        return loss
```
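A quick usage sketch; here the `features` are assumed to be penultimate-layer embeddings, though logits can be fed in the same way:

```python
# Hypothetical usage: 32 embeddings of dimension 512 with 10-class labels.
features = torch.randn(32, 512)
labels = torch.randint(0, 10, (32,))

con_criterion = SupContrastiveLoss(temperature=0.07)
contrastive_loss = con_criterion(features, labels)
```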
The final loss combines the Logit Standardization Loss and the Contrastive Supervised Loss:

```python
total_loss = kd_loss + alpha * contrastive_loss
```
Where:
- `kd_loss`: the Logit Standardization Loss.
- `contrastive_loss`: the Contrastive Supervised Loss.
- `alpha`: a weighting factor that balances the two losses.
- Teacher Model:
  - The pre-trained ResNet-101 generates the logits and feature embeddings.
- Student Model:
  - ResNet-18 learns from the combined loss, supervised by both the teacher and the ground-truth labels (see the training sketch below).
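Putting the pieces together, one distillation step might look like the following. This is a sketch only: it reuses the objects from the earlier snippets (`teacher`, `student`, `optimizer`, `train_loader`), assumes the teacher has already been fine-tuned on Imagenette, and the weighting `alpha` is an illustrative value, not the repository's setting:

```python
# Sketch of the distillation loop (assumed wiring; the notebooks may differ).
alpha = 0.5  # hypothetical weighting factor

kd_criterion = LogitStandardizationLoss()
con_criterion = SupContrastiveLoss(temperature=0.07)

teacher.eval()
student.train()
for images, labels in train_loader:
    with torch.no_grad():
        teacher_logits = teacher(images)  # frozen teacher
    student_logits = student(images)

    kd_loss = kd_criterion(teacher_logits, student_logits, temperature=2)
    # The student's logits stand in for the contrastive features here;
    # penultimate-layer embeddings are another common choice.
    contrastive_loss = con_criterion(student_logits, labels)
    total_loss = kd_loss + alpha * contrastive_loss

    optimizer.zero_grad()
    total_loss.backward()
    optimizer.step()
```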
Clone the repository and install dependencies:

```bash
git clone https://github.com/pariyajebreili/ConKD-Lion.git
cd ConKD-Lion
```
- Prepare the Dataset: Download and preprocess the Imagenette dataset using the scripts provided in the repository.
- Train the Teacher Model: Use the notebook `resnet101-teacher.ipynb` to train a ResNet-101 model as the teacher.
- Run Knowledge Distillation: Use the notebook `resnet101-18-teacher-student.ipynb` to train the student model (ResNet-18) using the pre-trained ResNet-101 teacher.
- Evaluate the Student Model: Evaluate the trained ResNet-18 student model (a minimal sketch follows this list).
- Run Additional Experiments: Use the provided notebooks as templates to experiment with other model architectures, loss functions, or optimizers.
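A minimal evaluation sketch, assuming a `val_loader` built the same way as the training loader:

```python
# Sketch: top-1 accuracy of the distilled student (val_loader is assumed).
import torch

correct, total = 0, 0
student.eval()
with torch.no_grad():
    for images, labels in val_loader:
        preds = student(images).argmax(dim=1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)
print(f"Student top-1 accuracy: {correct / total:.2%}")
```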
- Dataset: Imagenette on GitHub
- Teacher Model: ResNet-101
- Student Model: ResNet-18
- Logit Standardization: Logit Standardization in Knowledge Distillation (CVPR 2024)
- Supervised Contrastive Learning: SupContrast repository
- Lion Optimizer: GitHub Repository for Lion Optimizer
This project is licensed under the MIT License.
Contributions are welcome! Please open an issue or submit a pull request.