• Docs >
  • Module code >
  • mmrazor.models.mutables.mutable_channel.simple_mutable_channel
Shortcuts

Source code for mmrazor.models.mutables.mutable_channel.simple_mutable_channel

# Copyright (c) OpenMMLab. All rights reserved.

from typing import Union

import torch

from mmrazor.registry import MODELS
from ..derived_mutable import DerivedMutable
from .base_mutable_channel import BaseMutableChannel


@MODELS.register_module()
class SimpleMutableChannel(BaseMutableChannel):
    """SimpleMutableChannel is a simple BaseMutableChannel that directly takes
    a mask as its choice.

    The mask is stored in the ``mask`` buffer (one entry per channel) and is
    exposed as a bool tensor through ``current_choice`` / ``current_mask``.

    Args:
        num_channels (int): number of channels.
    """

    def __init__(self, num_channels: int, **kwargs) -> None:
        super().__init__(num_channels, **kwargs)
        # All channels start enabled.
        mask = torch.ones([self.num_channels])
        # save bool as float for dist training
        self.register_buffer('mask', mask)
        # Annotation only, so type checkers know the buffer is a Tensor.
        self.mask: torch.Tensor

    # choice

    @property
    def current_choice(self) -> torch.Tensor:
        """Get current choice as a bool tensor (True = channel kept)."""
        return self.mask.bool()

    @current_choice.setter
    def current_choice(self, choice: torch.Tensor):
        """Set current choice.

        Args:
            choice (torch.Tensor): the new channel mask. Any value accepted
                by ``torch.as_tensor`` (e.g. a list of bools) also works. It
                is moved to the device of the existing buffer and stored as
                float. A tensor already on that device is not copied by
                ``as_tensor``; ``.float()`` copies only if the dtype changes.
        """
        self.mask = torch.as_tensor(choice, device=self.mask.device).float()

    @property
    def current_mask(self) -> torch.Tensor:
        """Get current mask as a bool tensor."""
        # ``.bool()`` is a no-op for this class, but keeps the result a bool
        # mask even if a subclass overrides ``current_choice``.
        return self.current_choice.bool()

    # basic extension

    def expand_mutable_channel(
            self, expand_ratio: Union[int, float]) -> DerivedMutable:
        """Get a derived mutable whose mask repeats every channel of this
        mask ``expand_ratio`` times along the last dimension.

        For example, mask ``[True, False]`` with ``expand_ratio=2`` becomes
        ``[True, True, False, False]``.

        Args:
            expand_ratio (int | float): how many times each channel is
                repeated. A float is accepted only if it holds an integral
                value (e.g. ``2.0``), because ``Tensor.expand`` needs
                integer sizes.

        Returns:
            DerivedMutable: a mutable that recomputes the expanded mask from
            this mutable's current mask every time it is evaluated.

        Raises:
            ValueError: if ``expand_ratio`` is not an integral value or is
                smaller than 1.
        """
        # Validate eagerly: otherwise a bad ratio only fails (or silently
        # misbehaves, e.g. -1 makes ``expand`` keep size 1) when the
        # derived mutable is evaluated.
        if isinstance(expand_ratio, float):
            if not expand_ratio.is_integer():
                raise ValueError('expand_ratio must be an integral value, '
                                 f'but got {expand_ratio}.')
            expand_ratio = int(expand_ratio)
        if expand_ratio < 1:
            raise ValueError('expand_ratio must be a positive integer, '
                             f'but got {expand_ratio}.')
        ratio = expand_ratio

        def _expand_mask():
            mask = self.current_mask
            # [..., C] -> [..., C, ratio] -> [..., C * ratio]; each channel's
            # entry is repeated ``ratio`` times consecutively.
            mask = torch.unsqueeze(
                mask, -1).expand(list(mask.shape) + [ratio]).flatten(-2)
            return mask

        return DerivedMutable(_expand_mask, _expand_mask, [self])
Read the Docs v: latest
Versions
latest
stable
v1.0.0
v1.0.0rc2
v1.0.0rc1
v1.0.0rc0
v0.3.1
v0.3.0
v0.2.0
quantize
main
dev-1.x
Downloads
pdf
html
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.