Shortcuts

Source code for mmrazor.models.mutables.mutable_channel.base_mutable_channel

# Copyright (c) OpenMMLab. All rights reserved.
"""Base class for mutable channels, which act as channel masks for
DynamicOps."""
from abc import abstractmethod

import torch

from mmrazor.utils.typing import DumpChosen
from ..base_mutable import BaseMutable
from ..derived_mutable import DerivedMethodMixin


class BaseMutableChannel(BaseMutable, DerivedMethodMixin):
    """BaseMutableChannel works as a channel mask for DynamicOps to select
    channels.

    |---------------------------------------|
    |mutable_in_channel(BaseMutableChannel) |
    |---------------------------------------|
    |              DynamicOp                |
    |---------------------------------------|
    |mutable_out_channel(BaseMutableChannel)|
    |---------------------------------------|

    All subclasses should implement the following APIs and the other
    abstract method in ``BaseMutable``

    - ``current_mask``

    Args:
        num_channels (int): number(dimension) of channels(mask).
        **kwargs: Forwarded unchanged to ``BaseMutable.__init__``.
    """

    def __init__(self, num_channels: int, **kwargs):
        super().__init__(**kwargs)
        # Starts empty here; presumably assigned by the owning code later
        # (not visible in this file) — TODO confirm.
        self.name = ''
        self.num_channels = num_channels

    @property  # type: ignore
    @abstractmethod
    def current_mask(self) -> torch.Tensor:
        """Return a mask indicating the channel selection."""
        raise NotImplementedError()

    @property
    def activated_channels(self) -> int:
        """Number of activated channels.

        Counts the entries of ``current_mask`` that compare equal to 1
        (``True == 1``, so boolean masks are counted correctly as well).
        """
        return (self.current_mask == 1).sum().item()

    # implementation of abstract methods

    def fix_chosen(self, chosen=None):
        """Fix the mutable with chosen.

        Args:
            chosen (optional): If not ``None``, it is assigned to
                ``current_choice`` before the mutable is marked as fixed.

        Raises:
            AttributeError: If the mutable is already fixed. The check is
                performed before ``chosen`` is applied, so a rejected call
                leaves the current choice unchanged.
        """
        # Reject first so that a failed call never mutates a fixed mutable.
        if self.is_fixed:
            raise AttributeError(
                'The mode of current MUTABLE is `fixed`. '
                'Please do not call `fix_chosen` function again.')

        if chosen is not None:
            self.current_choice = chosen

        self.is_fixed = True

    def dump_chosen(self) -> DumpChosen:
        """Dump chosen.

        Returns:
            DumpChosen: ``chosen`` is the value of ``export_chosen()`` (the
            number of activated channels) and ``meta['max_channels']`` is
            the total number of channels.
        """
        # ``num_channels`` is set in ``__init__`` and is the mask dimension;
        # this base class defines no ``mask`` attribute to read it from.
        meta = dict(max_channels=self.num_channels)
        chosen = self.export_chosen()
        return DumpChosen(chosen=chosen, meta=meta)

    def export_chosen(self) -> int:
        """Return the number of activated channels as the exported choice."""
        return self.activated_channels

    def num_choices(self) -> int:
        """Number of available choices."""
        raise NotImplementedError()

    # others

    def __repr__(self):
        return (f'{self.__class__.__name__}('
                f'num_channels={self.num_channels}, '
                f'activated_channels={self.activated_channels})')
Read the Docs v: latest
Versions
latest
stable
v1.0.0
v1.0.0rc2
v1.0.0rc1
v1.0.0rc0
v0.3.1
v0.3.0
v0.2.0
quantize
main
dev-1.x
Downloads
pdf
html
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.