
Source code for mmrazor.models.mutables.mutable_module.diff_mutable_module

# Copyright (c) OpenMMLab. All rights reserved.
from abc import abstractmethod
from functools import partial
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor

from mmrazor.registry import MODELS
from mmrazor.utils.typing import DumpChosen
from .mutable_module import MutableModule

PartialType = Callable[[Any, Optional[nn.Parameter]], Any]


class DiffMutableModule(MutableModule):
    """Base class for differentiable mutables.

    Args:
        module_kwargs (dict[str, dict], optional): Module initialization named
            arguments. Defaults to None.
        alias (str, optional): alias of the `MUTABLE`.
        init_cfg (dict, optional): initialization configuration dict for
            ``BaseModule``. OpenMMLab has implemented 6 initializers,
            including `Constant`, `Xavier`, `Normal`, `Uniform`, `Kaiming`,
            and `Pretrained`.

    Note:
        :meth:`forward_all` is called when calculating FLOPs.
    """

    def __init__(self, **kwargs) -> None:
        super().__init__(**kwargs)

    @abstractmethod
    def sample_choice(self, arch_param: Tensor):
        """Sample choice according to arch parameters."""
        raise NotImplementedError

    def forward(self, x: Any, arch_param: Optional[nn.Parameter] = None):
        """Call either :func:`forward_fixed` or :func:`forward_arch_param`,
        depending on whether :func:`is_fixed` is ``True`` and whether
        `arch_param` is None.

        To reduce the coupling between `Mutable` and `Mutator`, the
        `arch_param` is generated by the `Mutator` and is passed to the
        forward function as an argument.

        Note:
            :meth:`forward_fixed` is called in `fixed` mode.
            :meth:`forward_arch_param` is called in `unfixed` mode.

        Args:
            x (Any): input data for forward computation.
            arch_param (nn.Parameter, optional): the architecture parameters
                for ``DiffMutableModule``.

        Returns:
            Any: the result of forward.
        """
        if self.is_fixed:
            return self.forward_fixed(x)
        else:
            if arch_param is None:
                return self.forward_all(x)
            else:
                return self.forward_arch_param(x, arch_param=arch_param)

    def compute_arch_probs(self, arch_param: nn.Parameter) -> Tensor:
        """Compute choice probabilities according to the architecture params."""
        return F.softmax(arch_param, -1)

    @abstractmethod
    def forward_arch_param(self, x, arch_param: nn.Parameter):
        """Forward when the mutable is not fixed.

        All subclasses must implement this method.
        """

    def set_forward_args(self, arch_param: nn.Parameter) -> None:
        """Interface for modifying the arch_param using partial."""
        forward_with_default_args: PartialType = \
            partial(self.forward, arch_param=arch_param)
        setattr(self, 'forward', forward_with_default_args)
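
A minimal sketch of how ``set_forward_args`` is typically used: a mutator-style
helper binds one architecture parameter to each ``DiffMutableModule`` in a
supernet, so later calls to ``mutable(x)`` dispatch to ``forward_arch_param``
without threading ``arch_param`` through every call site. The ``supernet`` and
``arch_params`` names and the ``bind_arch_params`` helper below are
illustrative placeholders, not library API.

from typing import Dict

import torch.nn as nn

def bind_arch_params(supernet: nn.Module,
                     arch_params: Dict[str, nn.Parameter]) -> None:
    """Bind each mutable's architecture parameter in place (illustrative)."""
    for name, module in supernet.named_modules():
        if isinstance(module, DiffMutableModule):
            # After binding, ``module(x)`` behaves like
            # ``module(x, arch_param=arch_params[name])``.
            module.set_forward_args(arch_param=arch_params[name])
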
@MODELS.register_module()
class DiffMutableOP(DiffMutableModule):
    """A type of ``MUTABLES`` for differentiable architecture search, such as
    DARTS. Searches for the best module by the learnable parameters
    `arch_param`.

    Args:
        candidates (dict[str, dict]): the configs for the candidate
            operations.
        fix_threshold (float): The threshold that determines whether to fix
            the choice of the current module as the op with the maximum
            `probs`. It happens when the maximum prob exceeds all the other
            probs by `fix_threshold` or more. Defaults to 1.0.
        module_kwargs (dict[str, dict], optional): Module initialization named
            arguments. Defaults to None.
        alias (str, optional): alias of the `MUTABLE`.
        init_cfg (dict, optional): initialization configuration dict for
            ``BaseModule``. OpenMMLab has implemented 6 initializers,
            including `Constant`, `Xavier`, `Normal`, `Uniform`, `Kaiming`,
            and `Pretrained`.
    """

    def __init__(
        self,
        candidates: Dict[str, Dict],
        fix_threshold: float = 1.0,
        module_kwargs: Optional[Dict[str, Dict]] = None,
        alias: Optional[str] = None,
        init_cfg: Optional[Dict] = None,
    ) -> None:
        super().__init__(
            module_kwargs=module_kwargs, alias=alias, init_cfg=init_cfg)
        assert len(candidates) >= 1, \
            f'Number of candidate ops must be greater than or equal to 1, ' \
            f'but got: {len(candidates)}'

        self._is_fixed = False
        if fix_threshold < 0 or fix_threshold > 1.0:
            raise ValueError(
                f'The fix_threshold should be in [0, 1]. Got {fix_threshold}.')
        self.fix_threshold = fix_threshold
        self._candidates = self._build_ops(candidates, self.module_kwargs)

    @staticmethod
    def _build_ops(candidates: Dict[str, Dict],
                   module_kwargs: Optional[Dict[str, Dict]]) -> nn.ModuleDict:
        """Build candidate operations based on the candidate configs.

        Args:
            candidates (dict[str, dict]): the configs for the candidate
                operations.
            module_kwargs (dict[str, dict], optional): Module initialization
                named arguments.

        Returns:
            ModuleDict (dict[str, Any], optional): the key of ``ops`` is the
                name of each choice in the configs and the value of ``ops``
                is the corresponding candidate operation.
        """
        ops = nn.ModuleDict()
        for name, op_cfg in candidates.items():
            assert name not in ops
            if module_kwargs is not None:
                op_cfg.update(module_kwargs)
            ops[name] = MODELS.build(op_cfg)
        return ops

    def forward_fixed(self, x) -> Tensor:
        """Forward when the mutable is in `fixed` mode.

        Args:
            x (Any): x could be a Torch.tensor or a tuple of Torch.tensor,
                containing input data for forward computation.

        Returns:
            Tensor: the result of forwarding the fixed operation.
        """
        return sum(self._candidates[choice](x) for choice in self._chosen)

    def forward_arch_param(self, x, arch_param: nn.Parameter) -> Tensor:
        """Forward with architecture parameters.

        Args:
            x (Any): x could be a Torch.tensor or a tuple of Torch.tensor,
                containing input data for forward computation.
            arch_param (nn.Parameter): architecture parameters for
                ``DiffMutableModule``.

        Returns:
            Tensor: the result of forward with ``arch_param``.
        """
        # compute the probs of each choice
        probs = self.compute_arch_probs(arch_param=arch_param)

        # forward based on probs
        outputs = list()
        for prob, module in zip(probs, self._candidates.values()):
            if prob > 0.:
                outputs.append(prob * module(x))

        return sum(outputs)

    def forward_all(self, x) -> Tensor:
        """Forward all choices. Used to calculate FLOPs.

        Args:
            x (Any): x could be a Torch.tensor or a tuple of Torch.tensor,
                containing input data for forward computation.

        Returns:
            Tensor: the sum of forwarding all ``choice`` operations.
        """
        outputs = list()
        for op in self._candidates.values():
            outputs.append(op(x))
        return sum(outputs)

    def fix_chosen(self, chosen: Union[str, List[str]]) -> None:
        """Fix mutable with `choice`. This operation would convert `unfixed`
        mode to `fixed` mode. The :attr:`is_fixed` will be set to True and
        only the selected operations are retained.

        Args:
            chosen (str | list[str]): the chosen key(s) in ``MUTABLE``.
        """
        if self.is_fixed:
            raise AttributeError(
                'The mode of current MUTABLE is `fixed`. '
                'Please do not call `fix_chosen` function again.')

        if isinstance(chosen, str):
            chosen = [chosen]

        for c in self.choices:
            if c not in chosen:
                self._candidates.pop(c)

        self._chosen = chosen
        self.is_fixed = True

    def sample_choice(self, arch_param: Tensor) -> str:
        """Sample choice based on arch_param."""
        return self.choices[torch.argmax(arch_param).item()]

    def dump_chosen(self) -> DumpChosen:
        """Dump the currently chosen operation together with all choices."""
        chosen = self.export_chosen()
        meta = dict(all_choices=self.choices)
        return DumpChosen(chosen=chosen, meta=meta)

    def export_chosen(self) -> str:
        """Export the name of the currently chosen operation."""
        assert self.current_choice is not None
        return self.current_choice

    @property
    def choices(self) -> List[str]:
        """list: all choices."""
        return list(self._candidates.keys())
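
A minimal usage sketch for ``DiffMutableOP``. The candidate type names
('PlaceholderConv', 'PlaceholderSkip') are placeholders for whatever candidate
operations are registered in mmrazor's MODELS registry, and the input shape is
likewise only illustrative.

import torch
import torch.nn as nn

candidates = dict(
    conv=dict(type='PlaceholderConv'),   # placeholder registered op config
    skip=dict(type='PlaceholderSkip'),   # placeholder registered op config
)
mutable_op = DiffMutableOP(candidates=candidates)

# One learnable logit per candidate; softmax turns them into mixing weights.
arch_param = nn.Parameter(torch.zeros(len(mutable_op.choices)))

x = torch.randn(2, 16, 32, 32)
out = mutable_op(x, arch_param=arch_param)   # softmax-weighted sum of candidates
best = mutable_op.sample_choice(arch_param)  # argmax choice, e.g. 'conv'
mutable_op.fix_chosen(best)                  # prune the other candidates
out_fixed = mutable_op(x)                    # now dispatches to forward_fixed
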
@MODELS.register_module()
class OneHotMutableOP(DiffMutableOP):
    """A type of ``MUTABLES`` for one-hot sample based architecture search,
    such as DSNAS. Searches for the best module by the learnable parameters
    `arch_param`.

    Args:
        candidates (dict[str, dict]): the configs for the candidate
            operations.
        module_kwargs (dict[str, dict], optional): Module initialization named
            arguments. Defaults to None.
        alias (str, optional): alias of the `MUTABLE`.
        init_cfg (dict, optional): initialization configuration dict for
            ``BaseModule``. OpenMMLab has implemented 6 initializers,
            including `Constant`, `Xavier`, `Normal`, `Uniform`, `Kaiming`,
            and `Pretrained`.
    """

    def sample_weights(self,
                       arch_param: nn.Parameter,
                       probs: torch.Tensor,
                       random_sample: bool = False) -> Tensor:
        """Use a one-hot distribution to sample the arch weights based on the
        arch params.

        Args:
            arch_param (nn.Parameter): architecture parameters for
                `DiffMutableModule`.
            probs (Tensor): the probs of each choice.
            random_sample (bool): Whether to sample the arch weights from a
                uniform distribution instead of ``probs``. Defaults to False.

        Returns:
            Tensor: Sampled one-hot arch weights.
        """
        import torch.distributions as D
        if random_sample:
            uni = torch.ones_like(arch_param)
            m = D.one_hot_categorical.OneHotCategorical(uni)
        else:
            m = D.one_hot_categorical.OneHotCategorical(probs=probs)
        return m.sample()

    def forward_arch_param(
        self,
        x: Any,
        arch_param: nn.Parameter,
    ) -> Tensor:
        """Forward with architecture parameters.

        Args:
            x (Any): x could be a Torch.tensor or a tuple of Torch.tensor,
                containing input data for forward computation.
            arch_param (nn.Parameter): architecture parameters for
                `DiffMutableModule`.

        Returns:
            Tensor: the result of forward with ``arch_param``.
        """
        # compute the probs of each choice
        probs = self.compute_arch_probs(arch_param=arch_param)

        if not self.is_fixed:
            # sample one-hot arch weights; if the gap between the two largest
            # probs reaches `fix_threshold`, fix the choice of this module.
            self.arch_weights = self.sample_weights(arch_param, probs)
            sorted_param = torch.topk(probs, 2)
            index = (
                sorted_param[0][0] - sorted_param[0][1] >= self.fix_threshold)
            if index:
                self.fix_chosen(self.choices[index])

        if self.is_fixed:
            index = self.choices.index(self._chosen[0])
            self.arch_weights.data.zero_()
            self.arch_weights.data[index].fill_(1.0)
            self.arch_weights.requires_grad_()

        # forward based on self.arch_weights
        outputs = list()
        for prob, module in zip(self.arch_weights, self._candidates.values()):
            if prob > 0.:
                outputs.append(prob * module(x))

        return sum(outputs)
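
A standalone sketch of the sampling behaviour used above, assuming three toy
candidate ops: a one-hot vector is drawn from the softmax of the architecture
logits, so exactly one candidate branch contributes to each forward pass.

import torch
import torch.nn.functional as F
from torch.distributions.one_hot_categorical import OneHotCategorical

arch_param = torch.tensor([0.2, 1.5, -0.3])        # toy logits, one per candidate
probs = F.softmax(arch_param, dim=-1)
weights = OneHotCategorical(probs=probs).sample()
print(weights)        # e.g. tensor([0., 1., 0.]): a single active choice
print(weights.sum())  # tensor(1.)
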
@MODELS.register_module()
class DiffChoiceRoute(DiffMutableModule):
    """A type of ``MUTABLES`` for Neural Architecture Search, which can select
    inputs from different edges in a differentiable or non-differentiable way.
    It is commonly used in DARTS.

    Args:
        edges (nn.ModuleDict): the key of `edges` is the name of different
            edges. The value of `edges` can be :class:`nn.Module` or
            :class:`DiffMutableModule`.
        num_chosen (int): the number of edges to keep. Defaults to 2.
        with_arch_param (bool): whether to forward with arch_param. When set
            to `True`, a differentiable way is adopted. When set to `False`,
            a non-differentiable way is adopted.
        alias (str, optional): alias of the `DiffChoiceRoute`.
        init_cfg (dict, optional): initialization configuration dict for
            ``BaseModule``. OpenMMLab has implemented 6 initializers,
            including `Constant`, `Xavier`, `Normal`, `Uniform`, `Kaiming`,
            and `Pretrained`.

    Examples:
        >>> import torch
        >>> import torch.nn as nn
        >>> edges_dict = nn.ModuleDict()
        >>> edges_dict.add_module('first_edge', nn.Conv2d(32, 32, 3, 1, 1))
        >>> edges_dict.add_module('second_edge', nn.Conv2d(32, 32, 5, 1, 2))
        >>> edges_dict.add_module('third_edge', nn.MaxPool2d(3, 1, 1))
        >>> edges_dict.add_module('fourth_edge', nn.MaxPool2d(5, 1, 2))
        >>> edges_dict.add_module('fifth_edge', nn.MaxPool2d(7, 1, 3))
        >>> diff_choice_route_cfg = dict(
        ...     type="DiffChoiceRoute",
        ...     edges=edges_dict,
        ...     with_arch_param=True,
        ... )
        >>> diffchoiceroute = MODELS.build(diff_choice_route_cfg)
        >>> arch_param = nn.Parameter(torch.randn(len(edges_dict)) * 1e-3)
        >>> arch_param
        Parameter containing:
        tensor([-6.1426e-04,  2.3596e-04,  1.4427e-03,  7.1668e-05,
                -8.9739e-04], requires_grad=True)
        >>> x = [torch.randn(4, 32, 64, 64) for _ in range(5)]
        >>> output = diffchoiceroute.forward_arch_param(x, arch_param)
        >>> output.shape
        torch.Size([4, 32, 64, 64])
    """

    def __init__(
        self,
        edges: nn.ModuleDict,
        num_chosen: int = 2,
        with_arch_param: bool = False,
        alias: Optional[str] = None,
        init_cfg: Optional[Dict] = None,
    ) -> None:
        super().__init__(alias=alias, init_cfg=init_cfg)
        assert len(edges) >= 1, \
            f'Number of edges must be greater than or equal to 1, ' \
            f'but got: {len(edges)}'

        self._with_arch_param = with_arch_param
        self._is_fixed = False
        self._candidates: nn.ModuleDict = edges
        self.num_chosen = num_chosen

    def forward(self, x: Any, arch_param: Optional[nn.Parameter] = None):
        """Call either :func:`forward_fixed` or :func:`forward_arch_param`,
        depending on whether :func:`is_fixed` is ``True`` and whether
        `arch_param` is None.

        To reduce the coupling between `Mutable` and `Mutator`, the
        `arch_param` is generated by the `Mutator` and is passed to the
        forward function as an argument.

        Note:
            :meth:`forward_fixed` is called in `fixed` mode.
            :meth:`forward_arch_param` is called in `unfixed` mode.

        Args:
            x (Any): input data for forward computation.
            arch_param (nn.Parameter, optional): the architecture parameters
                for ``DiffMutableModule``.

        Returns:
            Any: the result of forward.
        """
        if self.is_fixed:
            return self.forward_fixed(x)
        else:
            if arch_param is not None and self._with_arch_param:
                return self.forward_arch_param(x, arch_param=arch_param)
            else:
                return self.forward_all(x)

    def forward_fixed(self, inputs: Union[List, Tuple]) -> Tensor:
        """Forward when the mutable is in `fixed` mode.

        Args:
            inputs (Union[List[Any], Tuple[Any]]): inputs could be a list or
                a tuple of Torch.tensor, containing input data for forward
                computation.

        Returns:
            Tensor: the result of forwarding the fixed operation.
        """
        assert self._chosen is not None, \
            'Please call fix_chosen before calling `forward_fixed`.'

        outputs = list()
        for choice, x in zip(self._unfixed_choices, inputs):
            if choice in self._chosen:
                outputs.append(self._candidates[choice](x))
        return sum(outputs)

    def forward_arch_param(self, x, arch_param: nn.Parameter) -> Tensor:
        """Forward with architecture parameters.

        Args:
            x (list[Any] | tuple[Any]): x could be a list or a tuple of
                Torch.tensor, containing input data for forward selection.
            arch_param (nn.Parameter): architecture parameters for
                ``DiffMutableModule``.

        Returns:
            Tensor: the result of forward with ``arch_param``.
        """
        assert len(x) == len(self._candidates), \
            f'Length of `edges` {len(self._candidates)} should be ' \
            f'the same as the length of inputs {len(x)}.'

        probs = self.compute_arch_probs(arch_param=arch_param)

        outputs = list()
        for prob, module, input in zip(probs, self._candidates.values(), x):
            if prob > 0:
                # prob may equal 0 in gumbel softmax.
                outputs.append(prob * module(input))
        return sum(outputs)

    def forward_all(self, x):
        """Forward all choices.

        Args:
            x (Any): x could be a Torch.tensor or a tuple of Torch.tensor,
                containing input data for forward computation.

        Returns:
            Tensor: the sum of forwarding all ``choice`` operations.
        """
        assert len(x) == len(self._candidates), \
            f'Length of edges {len(self._candidates)} should be the same ' \
            f'as the length of inputs {len(x)}.'

        outputs = list()
        for op, input in zip(self._candidates.values(), x):
            outputs.append(op(input))
        return sum(outputs)

    def fix_chosen(self, chosen: List[str]) -> None:
        """Fix mutable with `choice`. This operation would convert to `fixed`
        mode. The :attr:`is_fixed` will be set to True and only the selected
        operations are retained.

        Args:
            chosen (list[str]): the chosen keys in ``MUTABLE``.
        """
        self._unfixed_choices = self.choices

        if self.is_fixed:
            raise AttributeError(
                'The mode of current MUTABLE is `fixed`. '
                'Please do not call `fix_chosen` function again.')

        for c in self.choices:
            if c not in chosen:
                self._candidates.pop(c)

        self._chosen = chosen
        self.is_fixed = True

    @property
    def choices(self) -> List[str]:
        """list: all choices."""
        return list(self._candidates.keys())

    def dump_chosen(self) -> DumpChosen:
        """Dump the currently chosen edges together with all choices."""
        chosen = self.export_chosen()
        meta = dict(all_choices=self.choices)
        return DumpChosen(chosen=chosen, meta=meta)

    def export_chosen(self) -> str:
        """Export the currently chosen edges."""
        assert self.current_choice is not None
        return self.current_choice

    def sample_choice(self, arch_param: Tensor) -> List[str]:
        """Sample the top ``num_chosen`` choices based on `arch_param`."""
        sort_idx = torch.argsort(-arch_param).cpu().numpy().tolist()
        choice_idx = sort_idx[:self.num_chosen]
        choice = [self.choices[i] for i in choice_idx]
        return choice
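
A usage sketch for ``DiffChoiceRoute``, built directly from plain torch
modules so no registry is needed; the edge names, shapes, and ``num_chosen``
value are illustrative. After search, ``sample_choice`` keeps the two edges
with the largest logits and ``fix_chosen`` prunes the rest.

import torch
import torch.nn as nn

edges = nn.ModuleDict({
    'edge0': nn.Conv2d(16, 16, 3, padding=1),
    'edge1': nn.Conv2d(16, 16, 5, padding=2),
    'edge2': nn.MaxPool2d(3, stride=1, padding=1),
})
route = DiffChoiceRoute(edges=edges, num_chosen=2, with_arch_param=True)

arch_param = nn.Parameter(1e-3 * torch.randn(len(route.choices)))
inputs = [torch.randn(2, 16, 32, 32) for _ in range(3)]

out = route(inputs, arch_param=arch_param)  # softmax-weighted sum of all edges
chosen = route.sample_choice(arch_param)    # e.g. ['edge1', 'edge0']
route.fix_chosen(chosen)
out_fixed = route(inputs)                   # sums only the two chosen edges
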
@MODELS.register_module()
class GumbelChoiceRoute(DiffChoiceRoute):
    """A type of ``MUTABLES`` for Neural Architecture Search using the
    Gumbel-Max trick, which can select inputs from different edges in a
    differentiable or non-differentiable way. It is commonly used in DARTS.

    Args:
        edges (nn.ModuleDict): the key of `edges` is the name of different
            edges. The value of `edges` can be :class:`nn.Module` or
            :class:`DiffMutableModule`.
        tau (float): non-negative scalar temperature in gumbel softmax.
        hard (bool): if `True`, the returned samples will be discretized as
            one-hot vectors, but will be differentiated as if they were the
            soft samples in autograd. Defaults to `True`.
        with_arch_param (bool): whether to forward with arch_param. When set
            to `True`, a differentiable way is adopted. When set to `False`,
            a non-differentiable way is adopted.
        alias (str, optional): alias of the `GumbelChoiceRoute`.
        init_cfg (dict, optional): initialization configuration dict for
            ``BaseModule``. OpenMMLab has implemented 6 initializers,
            including `Constant`, `Xavier`, `Normal`, `Uniform`, `Kaiming`,
            and `Pretrained`.
    """

    def __init__(
        self,
        edges: nn.ModuleDict,
        tau: float = 1.0,
        hard: bool = True,
        with_arch_param: bool = False,
        alias: Optional[str] = None,
        init_cfg: Optional[Dict] = None,
    ) -> None:
        super().__init__(
            edges=edges,
            with_arch_param=with_arch_param,
            alias=alias,
            init_cfg=init_cfg)
        self.tau = tau
        self.hard = hard

    def compute_arch_probs(self, arch_param: nn.Parameter) -> Tensor:
        """Compute chosen probs by the Gumbel-Max trick."""
        return F.gumbel_softmax(
            arch_param, tau=self.tau, hard=self.hard, dim=-1)

    def set_temperature(self, tau: float) -> None:
        """Set the temperature of gumbel softmax."""
        self.tau = tau
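
A usage sketch for ``GumbelChoiceRoute``: it behaves like ``DiffChoiceRoute``,
but with ``hard=True`` it draws one-hot yet differentiable edge weights via
``F.gumbel_softmax``, and the temperature is typically annealed over training
with ``set_temperature``. The edge modules, shapes, and annealing value below
are illustrative only.

import torch
import torch.nn as nn

edges = nn.ModuleDict({
    'skip': nn.Identity(),
    'pool': nn.MaxPool2d(3, stride=1, padding=1),
})
gumbel_route = GumbelChoiceRoute(
    edges=edges, tau=5.0, hard=True, with_arch_param=True)

arch_param = nn.Parameter(torch.zeros(len(gumbel_route.choices)))
inputs = [torch.randn(2, 8, 16, 16) for _ in range(2)]

# With hard one-hot weights, a single edge contributes per forward pass,
# while gradients still flow to ``arch_param`` via the soft sample.
out = gumbel_route(inputs, arch_param=arch_param)
gumbel_route.set_temperature(1.0)  # sharpen the distribution as training proceeds
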