• Docs >
  • Module code >
  • mmrazor.models.mutables.mutable_channel.units.slimmable_channel_unit
Shortcuts

Source code for mmrazor.models.mutables.mutable_channel.units.slimmable_channel_unit

# Copyright (c) OpenMMLab. All rights reserved.

from typing import List, Optional, Union

import torch.nn as nn

from mmrazor.models.architectures import dynamic_ops
from mmrazor.registry import MODELS
from ..mutable_channel_container import MutableChannelContainer
from .one_shot_mutable_channel_unit import OneShotMutableChannelUnit


@MODELS.register_module()
class SlimmableChannelUnit(OneShotMutableChannelUnit):
    """A type of ``MutableChannelUnit`` to train several subnets together.

    Args:
        num_channels (int): The raw number of channels.
        candidate_choices (List[Union[int, float]], optional):
            A list of candidates. Each candidate indicates how many channels
            to be reserved, interpreted according to ``choice_mode``.
            Defaults to None, in which case a fresh empty list is passed to
            the parent class.
        choice_mode (str, optional): Mode of candidates.
            One of 'ratio' or 'number'. Defaults to 'number'.
        divisor (int, optional): Used to make choice divisible.
            Defaults to 1.
        min_value (int, optional): The minimal value used when make
            divisible. Defaults to 1.
        min_ratio (float, optional): The minimal ratio used when make
            divisible. Defaults to 0.9.
    """

    def __init__(self,
                 num_channels: int,
                 candidate_choices: Optional[List[Union[int, float]]] = None,
                 choice_mode='number',
                 divisor=1,
                 min_value=1,
                 min_ratio=0.9) -> None:
        # A literal `[]` default is evaluated once and shared by every
        # instance; build a fresh list per call instead. Copying an explicit
        # argument also keeps the caller's list from being handed to the
        # parent. NOTE(review): whether the parent mutates the list it
        # receives should be confirmed against OneShotMutableChannelUnit.
        candidate_choices = ([] if candidate_choices is None else
                             list(candidate_choices))
        super().__init__(num_channels, candidate_choices, choice_mode,
                         divisor, min_value, min_ratio)

    def prepare_for_pruning(self, model: nn.Module):
        """Prepare for pruning.

        Passes a module-type mapping (``nn.Conv2d`` -> ``DynamicConv2d``,
        ``nn.BatchNorm2d`` -> ``SwitchableBatchNorm2d``, ``nn.Linear`` ->
        ``DynamicLinear``) to ``_replace_with_dynamic_ops``. It then
        initializes switchable BN candidates from ``self.candidate_choices``.
        Finally it registers channel containers on ``model`` and this unit's
        mutable channel.

        Args:
            model (nn.Module): The model to prepare.
        """
        self._replace_with_dynamic_ops(
            model, {
                nn.Conv2d: dynamic_ops.DynamicConv2d,
                nn.BatchNorm2d: dynamic_ops.SwitchableBatchNorm2d,
                nn.Linear: dynamic_ops.DynamicLinear
            })
        self.alter_candidates_of_switchbn(self.candidate_choices)
        self._register_channel_container(model, MutableChannelContainer)
        self._register_mutable_channel(self.mutable_channel)

    def alter_candidates_of_switchbn(self, candidates: List):
        """Change candidates of SwitchableBatchNorm2d.

        Scans this unit's output-related channels, then its input-related
        channels. For each channel whose module is a
        ``SwitchableBatchNorm2d`` with an empty ``candidate_bn``, it calls
        ``init_candidates(candidates)``. Afterwards the unit's
        ``current_choice`` is set to ``max_choice``.

        Args:
            candidates (List): Candidates forwarded to ``init_candidates``.
        """
        for channel in list(self.output_related) + list(self.input_related):
            # Skip BNs that already have candidates so they are not
            # re-initialized.
            if (isinstance(channel.module, dynamic_ops.SwitchableBatchNorm2d)
                    and len(channel.module.candidate_bn) == 0):
                channel.module.init_candidates(candidates)
        self.current_choice = self.max_choice
Read the Docs v: latest
Versions
latest
stable
v1.0.0
v1.0.0rc2
v1.0.0rc1
v1.0.0rc0
v0.3.1
v0.3.0
v0.2.0
quantize
main
dev-1.x
Downloads
pdf
html
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.