Shortcuts

Source code for mmrazor.models.losses.relational_kd

# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
import torch.nn.functional as F

from ..builder import LOSSES


def euclidean_distance(pred, squared=False, eps=1e-12):
    """Calculate the Euclidean distance between the two examples in the output
    representation space.

    Args:
        pred (torch.Tensor): The prediction of the teacher or student with
            shape (N, C).
        squared (bool): Whether to calculate the squared Euclidean
            distance. Defaults to False.
        eps (float): The minimum Euclidean distance between the two
            examples. Defaults to 1e-12.
    """
    pred_square = pred.pow(2).sum(dim=-1)  # (N, )
    prod = torch.mm(pred, pred.t())  # (N, N)
    distance = (pred_square.unsqueeze(1) + pred_square.unsqueeze(0) -
                2 * prod).clamp(min=eps)  # (N, N)

    if not squared:
        distance = distance.sqrt()

    distance = distance.clone()
    distance[range(len(prod)), range(len(prod))] = 0
    return distance


def angle(pred):
    """Calculate the angle-wise relational potential which measures the angle
    formed by the three examples in the output representation space.

    Args:
        pred (torch.Tensor): The prediction of the teacher or student with
            shape (N, C).
    """
    pred_vec = pred.unsqueeze(0) - pred.unsqueeze(1)  # (N, N, C)
    norm_pred_vec = F.normalize(pred_vec, p=2, dim=2)
    angle = torch.bmm(norm_pred_vec,
                      norm_pred_vec.transpose(1, 2)).view(-1)  # (N*N*N, )
    return angle


[docs]@LOSSES.register_module() class DistanceWiseRKD(nn.Module): """PyTorch version of distance-wise loss of `Relational Knowledge Distillation. <https://arxiv.org/abs/1904.05068>`_. Args: loss_weight (float): Weight of distance-wise distillation loss. Defaults to 25.0. with_l2_norm (bool): Whether to normalize the model predictions before calculating the loss. Defaults to True. """ def __init__(self, loss_weight=25.0, with_l2_norm=True): super(DistanceWiseRKD, self).__init__() self.loss_weight = loss_weight self.with_l2_norm = with_l2_norm
[docs] def distance_loss(self, preds_S, preds_T): """Calculate distance-wise distillation loss.""" d_T = euclidean_distance(preds_T, squared=False) # mean_d_T is a normalization factor for distance mean_d_T = d_T[d_T > 0].mean() d_T = d_T / mean_d_T d_S = euclidean_distance(preds_S, squared=False) mean_d_S = d_S[d_S > 0].mean() d_S = d_S / mean_d_S return F.smooth_l1_loss(d_S, d_T)
[docs] def forward(self, preds_S, preds_T): """Forward computation. Args: preds_S (torch.Tensor): The student model prediction with shape (N, C, H, W) or shape (N, C). preds_T (torch.Tensor): The teacher model prediction with shape (N, C, H, W) or shape (N, C). Return: torch.Tensor: The calculated loss value. """ preds_S = preds_S.view(preds_S.shape[0], -1) preds_T = preds_T.view(preds_T.shape[0], -1) if self.with_l2_norm: preds_S = F.normalize(preds_S, p=2, dim=1) preds_T = F.normalize(preds_T, p=2, dim=1) loss = self.distance_loss(preds_S, preds_T) * self.loss_weight return loss
[docs]@LOSSES.register_module() class AngleWiseRKD(nn.Module): """PyTorch version of angle-wise loss of `Relational Knowledge Distillation. <https://arxiv.org/abs/1904.05068>`_. Args: loss_weight (float): Weight of angle-wise distillation loss. Defaults to 50.0. with_l2_norm (bool): Whether to normalize the model predictions before calculating the loss. Defaults to True. """ def __init__(self, loss_weight=50.0, with_l2_norm=True): super(AngleWiseRKD, self).__init__() self.loss_weight = loss_weight self.with_l2_norm = with_l2_norm
[docs] def angle_loss(self, preds_S, preds_T): """Calculate the angle-wise distillation loss.""" angle_T = angle(preds_T) angle_S = angle(preds_S) return F.smooth_l1_loss(angle_S, angle_T)
[docs] def forward(self, preds_S, preds_T): """Forward computation. Args: preds_S (torch.Tensor): The student model prediction with shape (N, C, H, W) or shape (N, C). preds_T (torch.Tensor): The teacher model prediction with shape (N, C, H, W) or shape (N, C). Return: torch.Tensor: The calculated loss value. """ preds_S = preds_S.view(preds_S.shape[0], -1) preds_T = preds_T.view(preds_T.shape[0], -1) if self.with_l2_norm: preds_S = F.normalize(preds_S, p=2, dim=-1) preds_T = F.normalize(preds_T, p=2, dim=-1) loss = self.angle_loss(preds_S, preds_T) * self.loss_weight return loss
Read the Docs v: v0.3.0
Versions
latest
stable
v0.3.0
v0.2.0
Downloads
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.