What is Mutable¶
Mutable is one of the basic components in NAS algorithms and some pruning algorithms. It makes a supernet searchable by providing optional modules or parameters.
To understand it better, we take the mutable module as an example below.
As shown in the figure above, `Mutable` is a container that holds some candidate operations, so it can sample candidates to constitute the subnet. A supernet usually consists of multiple `Mutable`s; therefore, the supernet becomes searchable with the help of `Mutable`, and all the candidate operations in the `Mutable`s constitute the search space of the supernet.
If you want to know more about the relationship between Mutable and Mutator, please refer to the Mutator documentation.
1. Support module mutable¶
A mutable module holds several candidate operations and samples one of them to execute the forward computation, which makes the modules of the supernet searchable.
2. Support parameter mutable¶
To make implementing more complicated algorithms easier, we also support making some important parameters searchable, such as input channels, output channels, kernel size and so on.
What is more, we can implement dynamic ops by using mutable parameters.
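For illustration, a searchable value could be declared with a config like the sketch below. `MutableValue` is one of the base classes introduced later in this document, but the exact field names here are assumptions rather than a guaranteed API.

```python
# Illustrative config for a searchable kernel size; `value_list` is an
# assumed field name, not a documented API guarantee.
mutable_kernel_size = dict(
    _scope_='mmrazor',
    type='MutableValue',
    value_list=[3, 5, 7])
```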
3. Support deriving from mutable parameter¶
Because of the constraints imposed by the defined architecture, there may be correlations between some mutable parameters, for example around concat operations and expand ratios.
```
if   conv3 = concat(conv1, conv2)
and  out_channel(conv1) = 3, out_channel(conv2) = 4,
then in_channel(conv3) must be 7 rather than mutable.
```
So we use a mutable derived from conv1 and conv2 to generate in_channel(conv3).
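To make the derivation pattern concrete, here is a minimal, framework-agnostic sketch; `SimpleMutable` and `DerivedValue` are hypothetical names for illustration, not MMRazor classes (MMRazor implements this pattern via `DerivedMethodMixin`, introduced below).

```python
# A framework-agnostic sketch of the derivation pattern: a derived value
# recomputes itself from its source mutables on every access.
class SimpleMutable:

    def __init__(self, current_choice):
        self.current_choice = current_choice


class DerivedValue:

    def __init__(self, *sources, fn):
        self.sources = sources
        self.fn = fn

    @property
    def current_choice(self):
        return self.fn(*(s.current_choice for s in self.sources))


out_conv1 = SimpleMutable(current_choice=3)
out_conv2 = SimpleMutable(current_choice=4)
# in_channel(conv3) follows conv1 and conv2 instead of being searched itself.
in_conv3 = DerivedValue(out_conv1, out_conv2, fn=lambda a, b: a + b)
assert in_conv3.current_choice == 7
```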
With the help of derived mutables, we can meet these special requirements in some NAS algorithms and pruning algorithms. What is more, they can be used to deal with different granularities between search spaces.
As shown in the figure above:

White blocks stand for the basic classes, which include `BaseMutable` and `DerivedMethodMixin`. `BaseMutable` is the base class for all mutables; it defines the required properties and abstract methods. `DerivedMethodMixin` is a mixin class that provides mutable parameters with some useful methods for deriving mutables.

Gray blocks stand for the different types of base mutables. Because there are correlations between the channels of some layers, we divide mutable parameters into `MutableChannel` and `MutableValue`, so you can think of `MutableChannel` as a special `MutableValue`.

To support module and parameter mutables, we provide `MutableModule`, `MutableChannel` and `MutableValue` as base classes that implement the required basic functions. We also add `OneShotMutableModule` and `DiffMutableModule`, two types based on `MutableModule`, to meet the requirements of different types of algorithms.

To support deriving from mutable parameters, we make the mutable parameter classes also inherit from `DerivedMethodMixin`, so that they obtain the derive functions it provides.
Red blocks and green blocks stand for registered classes that implement specific algorithms, which means you can use them directly in configs. If they do not meet your requirements, you can also customize your own mutable based on our base classes. If you are interested in their implementation, please refer to their docstrings.
How to use existing mutables to configure searchable backbones¶
We will use `OneShotMutableOP` to build a `SearchableShuffleNetV2` backbone as follows.

Configure the needed mutables:
```python
# We only use OneShotMutableOP, then take 4 shuffle ops as its candidates.
_STAGE_MUTABLE = dict(
    _scope_='mmrazor',
    type='OneShotMutableOP',
    candidates=dict(
        shuffle_3x3=dict(type='ShuffleBlock', kernel_size=3),
        shuffle_5x5=dict(type='ShuffleBlock', kernel_size=5),
        shuffle_7x7=dict(type='ShuffleBlock', kernel_size=7),
        shuffle_xception=dict(type='ShuffleXception')))
```
```python
# Use the _STAGE_MUTABLE in various stages.
arch_setting = [
    # Parameters to build layers. 3 parameters are needed to construct a
    # layer, from left to right: channel, num_blocks, mutable_cfg.
    [64, 4, _STAGE_MUTABLE],
    [160, 4, _STAGE_MUTABLE],
    [320, 8, _STAGE_MUTABLE],
    [640, 4, _STAGE_MUTABLE]
]
```
Configure the searchable backbone:
```python
nas_backbone = dict(
    _scope_='mmrazor',
    type='SearchableShuffleNetV2',
    widen_factor=1.0,
    arch_setting=arch_setting)
```
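As a rough usage sketch (assuming mmrazor is installed and its modules are registered), this config can be built and run like any other MMEngine module:

```python
# A minimal usage sketch, assuming mmrazor is installed.
import torch
from mmrazor.registry import MODELS

backbone = MODELS.build(nas_backbone)
# Each stage now contains OneShotMutableOP blocks holding the 4 candidates.
feats = backbone(torch.randn(1, 3, 224, 224))
```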
Then you can use it in your architecture. If the existing mutables do not meet your needs, you can also customize your own mutable.
How to customize your mutable¶
About base mutable¶
Before customizing mutables, we need to know what the base mutables do.
In order to implement the searchable mechanism, mutables need to own some basic functions, such as changing their status from mutable to fixed, and recording the current status and current choice. So in `BaseMutable`, the relevant abstract methods and properties are defined as follows.
```python
# Copyright (c) OpenMMLab. All rights reserved.
from abc import ABC, abstractmethod
from typing import Dict, Generic, Optional, TypeVar

from mmengine.model import BaseModule

CHOICE_TYPE = TypeVar('CHOICE_TYPE')
CHOSEN_TYPE = TypeVar('CHOSEN_TYPE')


class BaseMutable(BaseModule, ABC, Generic[CHOICE_TYPE, CHOSEN_TYPE]):

    def __init__(self,
                 alias: Optional[str] = None,
                 init_cfg: Optional[Dict] = None) -> None:
        super().__init__(init_cfg=init_cfg)

        self.alias = alias
        self._is_fixed = False
        self._current_choice: Optional[CHOICE_TYPE] = None

    @property
    def current_choice(self) -> Optional[CHOICE_TYPE]:
        return self._current_choice

    @current_choice.setter
    def current_choice(self, choice: Optional[CHOICE_TYPE]) -> None:
        self._current_choice = choice

    @property
    def is_fixed(self) -> bool:
        return self._is_fixed

    @is_fixed.setter
    def is_fixed(self, is_fixed: bool) -> None:
        ...
        self._is_fixed = is_fixed

    @abstractmethod
    def fix_chosen(self, chosen: CHOSEN_TYPE) -> None:
        pass

    @abstractmethod
    def dump_chosen(self) -> CHOSEN_TYPE:
        pass

    @property
    @abstractmethod
    def num_choices(self) -> int:
        pass
```
`MutableModule` not only owns these basic functions, but also needs some specialized functions to implement module mutables, such as getting all choices and executing the forward computation.
```python
# Copyright (c) OpenMMLab. All rights reserved.
from abc import abstractmethod
from typing import Any, Dict, List, Optional

from ..base_mutable import CHOICE_TYPE, CHOSEN_TYPE, BaseMutable


class MutableModule(BaseMutable[CHOICE_TYPE, CHOSEN_TYPE]):

    def __init__(self,
                 module_kwargs: Optional[Dict[str, Dict]] = None,
                 **kwargs) -> None:
        super().__init__(**kwargs)

        self.module_kwargs = module_kwargs

    @property
    @abstractmethod
    def choices(self) -> List[CHOICE_TYPE]:
        """list: all choices. All subclasses must implement this method."""

    @abstractmethod
    def forward(self, x: Any) -> Any:
        """Forward computation."""

    @property
    def num_choices(self) -> int:
        """Number of choices."""
        return len(self.choices)
```
If you want to know more about the other types of mutables, please refer to their docstrings.
Steps of customizing mutables¶
There are 4 steps to implement a custom mutable:

1. Register a new mutable.
2. Implement abstract methods.
3. Implement other methods.
4. Import the class.
Then you can use your customized mutable in configs, as in the previous chapter. Below we take `OneShotMutableOP` as an example of customizing a mutable.
1. Register a new mutable¶
First, you need to determine which type of mutable to implement, so that you can implement your mutable faster by inheriting from the corresponding base mutable.

Then create a new file `mmrazor/models/mutables/mutable_module/one_shot_mutable_module.py`, in which the class `OneShotMutableOP` inherits from `OneShotMutableModule`:
```python
# Copyright (c) OpenMMLab. All rights reserved.
import random
from abc import abstractmethod
from typing import Any, Dict, List, Optional, Union

import numpy as np
import torch.nn as nn
from torch import Tensor

from mmrazor.registry import MODELS
from ..base_mutable import CHOICE_TYPE, CHOSEN_TYPE
from .mutable_module import MutableModule

# `OneShotMutableModule`, a subclass of `MutableModule`, is defined earlier
# in this same file and is omitted here.


@MODELS.register_module()
class OneShotMutableOP(OneShotMutableModule[str, str]):
    ...
```
2. Implement abstract methods¶
2.1 Basic abstract methods¶
These basic abstract methods are mainly from `BaseMutable` and `MutableModule`, such as `fix_chosen`, `dump_chosen`, `choices` and `num_choices`:
```python
@MODELS.register_module()
class OneShotMutableOP(OneShotMutableModule[str, str]):
    ...

    def fix_chosen(self, chosen: str) -> None:
        """Fix mutable with subnet config. This operation would convert
        `unfixed` mode to `fixed` mode. The :attr:`is_fixed` will be set to
        True and only the selected operations can be retained.

        Args:
            chosen (str): the chosen key in ``MUTABLE``. Defaults to None.
        """
        if self.is_fixed:
            raise AttributeError(
                'The mode of current MUTABLE is `fixed`. '
                'Please do not call `fix_chosen` function again.')

        for c in self.choices:
            if c != chosen:
                self._candidates.pop(c)

        self._chosen = chosen
        self.is_fixed = True

    def dump_chosen(self) -> str:
        assert self.current_choice is not None
        return self.current_choice

    @property
    def choices(self) -> List[str]:
        """list: all choices."""
        return list(self._candidates.keys())

    @property
    def num_choices(self):
        return len(self.choices)
```
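For intuition, here is a hypothetical sketch of how these methods are typically used when exporting a subnet; `mutable` is a placeholder for a `OneShotMutableOP` instance built from the config in the previous chapter.

```python
# Hypothetical usage sketch: fixing a mutable to a chosen candidate.
mutable.current_choice = 'shuffle_3x3'  # set the sampled choice
chosen = mutable.dump_chosen()          # -> 'shuffle_3x3'
mutable.fix_chosen(chosen)              # drop the other candidates
assert mutable.is_fixed and mutable.num_choices == 1
```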
2.2 Specified abstract methods¶
In `OneShotMutableModule`, the sample- and forward-related abstract methods are defined, such as `sample_choice`, `forward_fixed`, `forward_choice` and `forward_all`. So we need to implement these abstract methods:
```python
@MODELS.register_module()
class OneShotMutableOP(OneShotMutableModule[str, str]):
    ...

    def sample_choice(self) -> str:
        """uniform sampling."""
        # np.random.choice returns a one-element array; unwrap it so that a
        # plain string is returned, matching the annotated return type.
        return np.random.choice(self.choices, 1)[0]

    def forward_fixed(self, x: Any) -> Tensor:
        """Forward with the `fixed` mutable.

        Args:
            x (Any): x could be a Torch.tensor or a tuple of Torch.tensor,
                containing input data for forward computation.

        Returns:
            Tensor: the result of forward the fixed operation.
        """
        return self._candidates[self._chosen](x)

    def forward_choice(self, x: Any, choice: str) -> Tensor:
        """Forward with the `unfixed` mutable and current choice is not None.

        Args:
            x (Any): x could be a Torch.tensor or a tuple of Torch.tensor,
                containing input data for forward computation.
            choice (str): the chosen key in ``OneShotMutableOP``.

        Returns:
            Tensor: the result of forward the ``choice`` operation.
        """
        assert isinstance(choice, str) and choice in self.choices
        return self._candidates[choice](x)

    def forward_all(self, x: Any) -> Tensor:
        """Forward all choices. Used to calculate FLOPs.

        Args:
            x (Any): x could be a Torch.tensor or a tuple of Torch.tensor,
                containing input data for forward computation.

        Returns:
            Tensor: the result of forward all of the ``choice`` operation.
        """
        outputs = list()
        for op in self._candidates.values():
            outputs.append(op(x))
        return sum(outputs)
```
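To make the control flow concrete, here is a hypothetical sketch of how these methods interact during supernet training and subnet export; `mutable` and `x` are placeholders for a `OneShotMutableOP` instance and its input tensor.

```python
# Hypothetical sketch of the training-then-export control flow.
choice = mutable.sample_choice()         # uniform sampling, e.g. 'shuffle_5x5'
out = mutable.forward_choice(x, choice)  # run only the sampled candidate

mutable.current_choice = choice
mutable.fix_chosen(mutable.dump_chosen())
out = mutable.forward_fixed(x)           # only the chosen op remains
```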
3. Implement other methods¶
After finishing the required methods, we need to add some special methods, such as `_build_ops`, which is needed to build the candidate operations for sampling:
```python
@MODELS.register_module()
class OneShotMutableOP(OneShotMutableModule[str, str]):
    ...

    @staticmethod
    def _build_ops(
            candidates: Union[Dict[str, Dict], nn.ModuleDict],
            module_kwargs: Optional[Dict[str, Dict]] = None) -> nn.ModuleDict:
        """Build candidate operations based on the choice configs.

        Args:
            candidates (dict[str, dict] | :obj:`nn.ModuleDict`): the configs
                for the candidate operations or nn.ModuleDict.
            module_kwargs (dict[str, dict], optional): Module initialization
                named arguments.

        Returns:
            ModuleDict (dict[str, Any], optional): the key of ``ops`` is the
                name of each choice in configs and the value of ``ops`` is
                the corresponding candidate operation.
        """
        if isinstance(candidates, nn.ModuleDict):
            return candidates

        ops = nn.ModuleDict()
        for name, op_cfg in candidates.items():
            assert name not in ops
            if module_kwargs is not None:
                op_cfg.update(module_kwargs)
            ops[name] = MODELS.build(op_cfg)
        return ops
```
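As a rough illustration (a sketch assuming the mmrazor registry scope is active so that `ShuffleBlock` can be resolved), calling `_build_ops` with a small candidates config returns an `nn.ModuleDict` keyed by candidate name:

```python
# Hypothetical sketch of what _build_ops produces for a small config.
ops = OneShotMutableOP._build_ops(
    candidates=dict(
        shuffle_3x3=dict(type='ShuffleBlock', kernel_size=3),
        shuffle_5x5=dict(type='ShuffleBlock', kernel_size=5)))
# ops -> ModuleDict({'shuffle_3x3': ShuffleBlock(...),
#                    'shuffle_5x5': ShuffleBlock(...)})
```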
4. Import the class¶
You can either add the following lines to `mmrazor/models/mutables/mutable_module/__init__.py`:
```python
from .one_shot_mutable_module import OneShotMutableModule, OneShotMutableOP

__all__ = ['OneShotMutableModule', 'OneShotMutableOP']
```
or alternatively add
```python
custom_imports = dict(
    imports=['mmrazor.models.mutables.mutable_module.one_shot_mutable_module'],
    allow_failed_imports=False)
```
to the config file to avoid modifying the original code.
The customization of `OneShotMutableOP` is now complete, and you can use it directly in your algorithm.