Customize KD algorithms

Here we show how to develop a new KD algorithm, taking SingleTeacherDistill as an example.

  1. Register a new algorithm

Create a new file mmrazor/models/algorithms/distill/configurable/single_teacher_distill.py, in which class SingleTeacherDistill inherits from BaseAlgorithm.

from mmrazor.registry import MODELS
from ..base import BaseAlgorithm


@MODELS.register_module()
class SingleTeacherDistill(BaseAlgorithm):

    def __init__(self, use_gt, **kwargs):
        super().__init__(**kwargs)
        pass

    def train_step(self, data, optimizer):
        pass
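
The empty bodies above are where the real work happens. Purely as a hedged sketch (this is not the actual MMRazor implementation; the teacher and distiller constructor arguments and the data keys are assumptions here), a filled-in version might freeze a teacher model and merge the distillation losses produced by the distiller developed in step 3:

import torch

from mmrazor.registry import MODELS
from ..base import BaseAlgorithm


@MODELS.register_module()
class SingleTeacherDistill(BaseAlgorithm):

    def __init__(self, use_gt, teacher, distiller, **kwargs):
        super().__init__(**kwargs)
        self.use_gt = use_gt  # whether ground-truth supervision is also used
        # Assumed: teacher and distiller are config dicts built via the registry.
        self.teacher = MODELS.build(teacher)
        self.distiller = MODELS.build(distiller)
        # The teacher is frozen; only the student is optimized.
        for param in self.teacher.parameters():
            param.requires_grad = False

    def train_step(self, data, optimizer):
        # Forward the frozen teacher first so its features are recorded
        # by the distiller, then forward the student.
        with torch.no_grad():
            self.teacher(data['img'])
        losses = dict()
        if self.use_gt:
            # self.architecture (the student) is built by BaseAlgorithm.
            losses.update(self.architecture(data['img'], data['gt_label']))
        # Merge in the feature-distillation losses (see step 3).
        losses.update(self.distiller.compute_distill_losses())
        ...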
  2. Develop connectors (optional).

Take ConvModuleConnector as an example.

from mmrazor.registry import MODELS
from .base_connector import BaseConnector


@MODELS.register_module()
class ConvModuleConnector(BaseConnector):

    def __init__(self, in_channel, out_channel, kernel_size=1, stride=1):
        ...

    def forward_train(self, feature):
        ...
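
For illustration, here is a minimal self-contained sketch of what such a connector does (a hypothetical stand-in using a plain nn.Module instead of BaseConnector, whose contract is framework-specific): it projects the student feature map to the teacher's channel dimension so the two can be compared by a distillation loss.

import torch
import torch.nn as nn


class SimpleConvConnector(nn.Module):
    """Hypothetical stand-in for ConvModuleConnector."""

    def __init__(self, in_channel, out_channel, kernel_size=1, stride=1):
        super().__init__()
        # A 1x1 conv is enough to align the channel counts of
        # student and teacher feature maps.
        self.conv = nn.Conv2d(in_channel, out_channel, kernel_size, stride)

    def forward_train(self, feature):
        return self.conv(feature)


# Usage: map a 64-channel student feature to 256 teacher channels.
connector = SimpleConvConnector(in_channel=64, out_channel=256)
student_feature = torch.randn(2, 64, 32, 32)
aligned = connector.forward_train(student_feature)  # shape (2, 256, 32, 32)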
  3. Develop a distiller.

Take ConfigurableDistiller as an example.

from mmrazor.registry import MODELS
from .base_distiller import BaseDistiller


@MODELS.register_module()
class ConfigurableDistiller(BaseDistiller):

    def __init__(self,
                 student_recorders=None,
                 teacher_recorders=None,
                 distill_deliveries=None,
                 connectors=None,
                 distill_losses=None,
                 loss_forward_mappings=None):
        ...

    def build_connectors(self, connectors):
        ...

    def build_distill_losses(self, losses):
        ...

    def compute_distill_losses(self):
        ...
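
A hedged example of how these arguments are typically populated from a config (the recorder source 'head.fc' and the loss key loss_l1 are illustrative; they must match your student/teacher modules and the loss developed in step 4):

distiller = dict(
    type='ConfigurableDistiller',
    # Record the outputs of the student's and teacher's final fc layer.
    student_recorders=dict(fc=dict(type='ModuleOutputs', source='head.fc')),
    teacher_recorders=dict(fc=dict(type='ModuleOutputs', source='head.fc')),
    # One distillation loss; its key is reused in loss_forward_mappings.
    distill_losses=dict(loss_l1=dict(type='L1Loss', loss_weight=1.0)),
    # Map each loss's forward arguments to the recorded features.
    loss_forward_mappings=dict(
        loss_l1=dict(
            s_feature=dict(from_student=True, recorder='fc'),
            t_feature=dict(from_student=False, recorder='fc'))))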
  4. Develop a custom loss (optional).

Here we take L1Loss as an example. Create a new file mmrazor/models/losses/l1_loss.py.

from typing import Optional

import torch.nn as nn
import torch.nn.functional as F

from mmrazor.registry import MODELS


@MODELS.register_module()
class L1Loss(nn.Module):

    def __init__(
        self,
        loss_weight: float = 1.0,
        size_average: Optional[bool] = None,
        reduce: Optional[bool] = None,
        reduction: str = 'mean',
    ) -> None:
        super().__init__()
        self.loss_weight = loss_weight
        self.size_average = size_average
        self.reduce = reduce
        self.reduction = reduction

    def forward(self, s_feature, t_feature):
        # Weighted L1 distance between student and teacher features.
        loss = F.l1_loss(s_feature, t_feature, self.size_average, self.reduce,
                         self.reduction)
        return self.loss_weight * loss
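
A quick sanity check on dummy feature tensors (assuming the L1Loss class above is in scope):

import torch

# Dummy student/teacher feature maps with identical shapes.
s_feature = torch.randn(2, 16, 8, 8)
t_feature = torch.randn(2, 16, 8, 8)

loss_fn = L1Loss(loss_weight=0.5)
loss = loss_fn(s_feature, t_feature)  # 0.5 * mean absolute difference
print(loss.item())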
  5. Import the class

You can either add the following lines to mmrazor/models/algorithms/__init__.py

from .single_teacher_distill import SingleTeacherDistill

__all__ = [..., 'SingleTeacherDistill']

or alternatively add

custom_imports = dict(
    imports=['mmrazor.models.algorithms.distill.configurable.single_teacher_distill'],
    allow_failed_imports=False)

to the config file to avoid modifying the original code.

  6. Use the algorithm in your config file

algorithm = dict(
    type='SingleTeacherDistill',
    distiller=dict(type='ConfigurableDistiller', ...),
    # you can also use your new algorithm components here
    ...
)
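
In practice the algorithm also needs the student (architecture) and teacher models. The exact keys depend on the BaseAlgorithm signature in your MMRazor version; a hedged sketch with illustrative model configs:

algorithm = dict(
    type='SingleTeacherDistill',
    # Illustrative: student and teacher built from existing classification
    # configs; replace with the models you actually want to distill.
    architecture=dict(cfg_path='mmcls::resnet18_8xb32_in1k.py', pretrained=False),
    teacher=dict(cfg_path='mmcls::resnet34_8xb32_in1k.py', pretrained=True),
    distiller=dict(
        type='ConfigurableDistiller',
        ...))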