Source code for mmrazor.models.losses.kl_divergence
# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn
import torch.nn.functional as F

from ..builder import LOSSES


@LOSSES.register_module()
class KLDivergence(nn.Module):
"""A measure of how one probability distribution Q is different from a
second, reference probability distribution P.
Args:
tau (float): Temperature coefficient. Defaults to 1.0.
reduction (str): Specifies the reduction to apply to the loss:
``'none'`` | ``'batchmean'`` | ``'sum'`` | ``'mean'``.
``'none'``: no reduction will be applied,
``'batchmean'``: the sum of the output will be divided by
the batchsize,
``'sum'``: the output will be summed,
``'mean'``: the output will be divided by the number of
elements in the output.
Default: ``'batchmean'``
loss_weight (float): Weight of loss. Defaults to 1.0.
"""

    def __init__(
        self,
        tau=1.0,
        reduction='batchmean',
        loss_weight=1.0,
    ):
        super(KLDivergence, self).__init__()
        self.tau = tau
        self.loss_weight = loss_weight

        accept_reduction = {'none', 'batchmean', 'sum', 'mean'}
        assert reduction in accept_reduction, \
            f'KLDivergence supports reduction {accept_reduction}, ' \
            f'but got {reduction}.'
        self.reduction = reduction

    def forward(self, preds_S, preds_T):
        """Forward computation.

        Args:
            preds_S (torch.Tensor): The student model prediction with
                shape (N, C, H, W) or shape (N, C).
            preds_T (torch.Tensor): The teacher model prediction with
                shape (N, C, H, W) or shape (N, C).

        Returns:
            torch.Tensor: The calculated loss value.
        """
        preds_T = preds_T.detach()
        softmax_pred_T = F.softmax(preds_T / self.tau, dim=1)
        logsoftmax_preds_S = F.log_softmax(preds_S / self.tau, dim=1)
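        # Multiplying by tau**2 compensates for the 1/tau**2 gradient
        # scaling introduced by the softened logits, as proposed in
        # Hinton et al., "Distilling the Knowledge in a Neural Network".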
        loss = (self.tau**2) * F.kl_div(
            logsoftmax_preds_S, softmax_pred_T, reduction=self.reduction)
        return self.loss_weight * loss
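

# A minimal usage sketch, not part of the original mmrazor source: the
# tensor shapes and the ``tau`` value below are illustrative choices,
# not values mandated by mmrazor. Because of the relative import above,
# this guard only runs when the module is executed within the package
# context (e.g. ``python -m mmrazor.models.losses.kl_divergence``).
if __name__ == '__main__':
    import torch

    criterion = KLDivergence(tau=4.0, reduction='batchmean', loss_weight=1.0)
    student_logits = torch.randn(8, 10, requires_grad=True)  # shape (N, C)
    teacher_logits = torch.randn(8, 10)                      # shape (N, C)
    loss = criterion(student_logits, teacher_logits)  # scalar ('batchmean')
    loss.backward()  # gradients reach only the student logits
    print(loss.item())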