Source code for mmrazor.models.losses.weighted_soft_label_distillation

# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
import torch.nn.functional as F

from ..builder import LOSSES


@LOSSES.register_module()
class WSLD(nn.Module):
    """PyTorch version of `Rethinking Soft Labels for Knowledge Distillation:
    A Bias-Variance Tradeoff Perspective <https://arxiv.org/abs/2102.00650>`_.

    Args:
        tau (float): Temperature coefficient. Defaults to 1.0.
        loss_weight (float): Weight of loss. Defaults to 1.0.
        num_classes (int): Defaults to 1000.
    """

    def __init__(self, tau=1.0, loss_weight=1.0, num_classes=1000):
        super(WSLD, self).__init__()

        self.tau = tau
        self.loss_weight = loss_weight
        self.num_classes = num_classes
        self.softmax = nn.Softmax(dim=1).cuda()
        self.logsoftmax = nn.LogSoftmax(dim=1).cuda()
    def forward(self, student, teacher):
        gt_labels = self.current_data['gt_label']

        # Temperature-scaled logits for the soft-label cross-entropy term.
        student_logits = student / self.tau
        teacher_logits = teacher / self.tau

        teacher_probs = self.softmax(teacher_logits)

        # Per-sample cross-entropy between the teacher's soft labels and the
        # student's temperature-scaled predictions.
        ce_loss = -torch.sum(
            teacher_probs * self.logsoftmax(student_logits), 1, keepdim=True)

        # Weighted soft labels: compare each sample's hard-label cross-entropy
        # under the (detached) student and teacher to build a per-sample weight.
        student_detach = student.detach()
        teacher_detach = teacher.detach()
        log_softmax_s = self.logsoftmax(student_detach)
        log_softmax_t = self.logsoftmax(teacher_detach)
        one_hot_labels = F.one_hot(
            gt_labels, num_classes=self.num_classes).float()
        ce_loss_s = -torch.sum(one_hot_labels * log_softmax_s, 1, keepdim=True)
        ce_loss_t = -torch.sum(one_hot_labels * log_softmax_t, 1, keepdim=True)

        # focal_weight = 1 - exp(-max(ce_s / ce_t, 0)): samples the student
        # already handles well relative to the teacher get down-weighted.
        focal_weight = ce_loss_s / (ce_loss_t + 1e-7)
        ratio_lower = torch.zeros(1).cuda()
        focal_weight = torch.max(focal_weight, ratio_lower)
        focal_weight = 1 - torch.exp(-focal_weight)
        ce_loss = focal_weight * ce_loss

        # Standard tau^2 scaling for temperature-softened distillation losses.
        loss = (self.tau**2) * torch.mean(ce_loss)
        loss = self.loss_weight * loss

        return loss
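Note that forward reads the ground-truth labels from self.current_data['gt_label'], which the surrounding distiller is expected to set on the loss module before calling it. A minimal standalone sketch of exercising the loss directly, assuming a CUDA device (the module moves its softmax wrappers and the weight lower bound to the GPU) and assuming WSLD is re-exported from mmrazor.models.losses; shapes and hyperparameter values are illustrative:

import torch

from mmrazor.models.losses import WSLD  # assumed re-export path

wsld = WSLD(tau=2.0, loss_weight=2.5, num_classes=1000)

batch_size, num_classes = 4, 1000
student_logits = torch.randn(batch_size, num_classes, device='cuda')
teacher_logits = torch.randn(batch_size, num_classes, device='cuda')
gt_label = torch.randint(0, num_classes, (batch_size,), device='cuda')

# Emulate the label injection normally performed by the distiller.
wsld.current_data = dict(gt_label=gt_label)

loss = wsld(student_logits, teacher_logits)  # scalar tensor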
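In practice the loss is usually not instantiated by hand but selected through the LOSSES registry from a distillation config. A hedged sketch of such a component entry is below; the module names ('head.fc') and the extra 'name' key are illustrative assumptions, not defined in this file:

components = [
    dict(
        student_module='head.fc',   # assumed: student classification head
        teacher_module='head.fc',   # assumed: teacher classification head
        losses=[
            dict(
                type='WSLD',        # registry key registered by the decorator above
                name='loss_wsld',   # assumed per-loss name used by the distiller
                tau=2.0,
                loss_weight=2.5,
                num_classes=1000)
        ])
]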