Shortcuts

Source code for mmrazor.models.losses.cwd

# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
import torch.nn.functional as F

from ..builder import LOSSES


@LOSSES.register_module()
class ChannelWiseDivergence(nn.Module):
    """PyTorch version of `Channel-wise Distillation for Semantic Segmentation.
    <https://arxiv.org/abs/2011.13256>`_.

    Each (sample, channel) slice of the student and teacher maps is turned
    into a distribution over its ``H * W`` spatial positions with a
    temperature-scaled softmax. The loss is
    ``sum(p_T * (log p_T - log p_S))`` over all positions and channels
    (i.e. KL(p_T || p_S) summed over slices), multiplied by ``tau ** 2``
    and by ``loss_weight``, and divided by ``N * C``.

    Args:
        tau (float): Temperature coefficient. Defaults to 1.0.
        loss_weight (float): Weight of loss. Defaults to 1.0.
    """

    def __init__(
        self,
        tau=1.0,
        loss_weight=1.0,
    ):
        super().__init__()
        self.tau = tau
        self.loss_weight = loss_weight

    def forward(self, preds_S, preds_T):
        """Forward computation.

        Args:
            preds_S (torch.Tensor): The student model prediction with
                shape (N, C, H, W).
            preds_T (torch.Tensor): The teacher model prediction with
                shape (N, C, H, W).

        Return:
            torch.Tensor: The calculated loss value (a 0-dim tensor).
        """
        # The element-wise products below pair student and teacher rows one
        # to one. Comparing only the spatial dims let mismatched N/C either
        # fail deep inside broadcasting or, when one side has N * C == 1,
        # broadcast silently into a meaningless loss.
        assert preds_S.shape == preds_T.shape, (
            'student/teacher shape mismatch: '
            f'{tuple(preds_S.shape)} vs {tuple(preds_T.shape)}')
        N, C, H, W = preds_S.shape

        # One row per (sample, channel); the softmax runs over spatial
        # positions. ``reshape`` rather than ``view`` so non-contiguous
        # inputs (e.g. channels_last or permuted tensors) are accepted too.
        log_prob_T = F.log_softmax(
            preds_T.reshape(-1, H * W) / self.tau, dim=1)
        log_prob_S = F.log_softmax(
            preds_S.reshape(-1, H * W) / self.tau, dim=1)
        prob_T = log_prob_T.exp()

        # sum over rows and positions of p_T * (log p_T - log p_S),
        # scaled by tau ** 2.
        loss = torch.sum(prob_T * (log_prob_T - log_prob_S)) * (self.tau**2)

        loss = self.loss_weight * loss / (C * N)
        return loss
Read the Docs v: v0.2.0
Versions
latest
stable
v0.2.0
Downloads
pdf
html
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.