# Source code for mmrazor.models.algorithms.align_method_kd
# Copyright (c) OpenMMLab. All rights reserved.
from mmrazor.models.builder import ALGORITHMS
from .general_distill import GeneralDistill
@ALGORITHMS.register_module()
class AlignMethodDistill(GeneralDistill):
    """Distillation algorithm whose training step runs inside the
    distiller's context manager.

    The class is registered in ``ALGORITHMS`` and otherwise inherits all
    behaviour from ``GeneralDistill``; the only override is
    :meth:`train_step`.

    Args:
        **kwargs: Keyword arguments forwarded unchanged to
            ``GeneralDistill.__init__``.
    """

    def __init__(self, **kwargs):
        # Zero-argument ``super()`` keeps this consistent with
        # ``train_step`` below.
        super().__init__(**kwargs)

    def train_step(self, data, optimizer):
        """Run the parent ``train_step`` under ``distiller.context_manager``.

        Args:
            data: Batch passed through unchanged to the parent
                ``train_step``.
            optimizer: Optimizer passed through unchanged to the parent
                ``train_step``.

        Returns:
            Whatever the parent ``train_step`` returns.
        """
        # NOTE(review): ``context_manager`` presumably activates the
        # distiller's method-alignment hooks for the duration of the
        # step -- confirm against the distiller implementation.
        with self.distiller.context_manager:
            outputs = super().train_step(data, optimizer)
        # Returned after the ``with`` block so the context manager has
        # already exited when the result is handed back to the caller.
        return outputs