Shortcuts

Source code for mmrazor.models.mutables.mutable_op

# Copyright (c) OpenMMLab. All rights reserved.
import copy

import torch
import torch.nn as nn
import torch.nn.functional as F

from ..builder import MUTABLES, build_op
from .mutable_module import MutableModule


class MutableOP(MutableModule):
    """An important type of ``MUTABLES``, inherits from ``MutableModule``.

    Args:
        choices (dict): The configs for the choices, the chosen ``OPS`` used
            to combine ``MUTABLES``.
        choice_args (dict): The args used to set chosen ``OPS``.
    """

    def __init__(self, choices, choice_args, **kwargs):
        super(MutableOP, self).__init__(**kwargs)
        self.choices = self.build_choices(choices, choice_args)
        # Boolean mask over the choices; provided by the parent
        # ``MutableModule``.
        self.choice_mask = self.build_choice_mask()
        # Snapshot the original choice names so forward passes can keep the
        # mask aligned even after choices are pruned from
        # ``self.choice_names``.
        self.full_choice_names = copy.deepcopy(self.choice_names)

    def build_choices(self, cfgs, choice_args):
        """Build all chosen ``OPS`` used to combine ``MUTABLES``, and the
        choices will be sampled.

        Args:
            cfgs (dict): The configs for the choices.
            choice_args (dict): The args used to set chosen ``OPS``.

        Returns:
            torch.nn.ModuleDict: Consists of chosen ``OPS`` in the arg
                `cfgs`.
        """
        choices = nn.ModuleDict()
        for name, cfg in cfgs.items():
            # Merge into a deep copy instead of ``cfg.update(choice_args)``
            # so the caller's config dicts are not mutated as a side effect.
            op_cfg = copy.deepcopy(cfg)
            op_cfg.update(choice_args)
            choices.add_module(name, build_op(op_cfg))
        return choices
@MUTABLES.register_module()
class OneShotOP(MutableOP):
    """A type of ``MUTABLES`` for the one-shot NAS."""

    def __init__(self, **kwargs):
        super(OneShotOP, self).__init__(**kwargs)
        # One-shot search activates exactly one candidate op at a time.
        assert self.num_chosen == 1

    def forward(self, x):
        """Forward computation for chosen ``OPS``, in one-shot NAS, the
        number of chosen ``OPS`` can only be one.

        Args:
            x (tensor | tuple[tensor]): x could be a Torch.tensor or a tuple
                of Torch.tensor, containing input data for forward
                computation.

        Returns:
            torch.Tensor: The result of forward.
        """
        # Run only the ops that are both still present in ``choice_names``
        # and enabled by the mask; the mask is indexed by the full
        # (unpruned) name list.
        outputs = [
            self.choices[name](x)
            for name, chosen in zip(self.full_choice_names, self.choice_mask)
            if name in self.choice_names and chosen
        ]
        assert len(outputs) > 0
        return sum(outputs)
@MUTABLES.register_module()
class DifferentiableOP(MutableOP):
    """Differentiable OP.

    Search the best module from choices by learnable parameters.

    Args:
        with_arch_param (bool): whether build learnable architecture
            parameters.
    """

    def __init__(self, with_arch_param, **kwargs):
        super(DifferentiableOP, self).__init__(**kwargs)
        self.with_arch_param = with_arch_param

    def build_arch_param(self):
        """Build learnable architecture parameters.

        Returns:
            nn.Parameter | None: one learnable weight per choice when
                ``with_arch_param`` is True, otherwise None.
        """
        if self.with_arch_param:
            # Small random init keeps the initial softmax close to uniform.
            return nn.Parameter(torch.randn(self.num_choices) * 1e-3)
        else:
            return None

    def compute_arch_probs(self, arch_param):
        """Compute chosen probs according architecture parameters."""
        return F.softmax(arch_param, -1)

    def forward(self, x, arch_param=None):
        """Forward function.

        In some algorithms, there are several ``MutableModule`` share the
        same architecture parameters. So the architecture parameters are
        passed in as args.

        Args:
            x (tensor | tuple[tensor]): input data for forward computation.
            arch_param (torch.nn.Parameter, optional): architecture
                parameters; must not be None when ``with_arch_param=True``.

        Returns:
            torch.Tensor: the (possibly probability-weighted) sum of the
                active choices' outputs.
        """
        if self.with_arch_param:
            # Note: the original message contained a backslash
            # line-continuation inside the f-string, which injected a run of
            # literal spaces into the emitted text; rewritten cleanly.
            assert arch_param is not None, (
                f'In {self.space_id}, the arch_param can not be None '
                f'when with_arch_param=True.')
            # 1. compute choices' probs.
            probs = self.compute_arch_probs(arch_param)
            # 2. compute every op's output, weighted by its prob; ops with
            # zero probability are skipped entirely.
            outputs = list()
            for prob, module in zip(probs, self.choice_modules):
                if prob > 0:
                    outputs.append(prob * module(x))
        else:
            # Fall back to mask-based selection (same scheme as the
            # one-shot forward): run ops that are still in ``choice_names``
            # and enabled in the mask.
            outputs = list()
            for name, chosen_bool in zip(self.full_choice_names,
                                         self.choice_mask):
                if name not in self.choice_names:
                    continue
                if not chosen_bool:
                    continue
                outputs.append(self.choices[name](x))
        assert len(outputs) > 0
        return sum(outputs)
@MUTABLES.register_module()
class GumbelOP(DifferentiableOP):
    """Gumbel OP.

    Search the best module from choices by gumbel trick.

    Args:
        tau (float): temperature for ``F.gumbel_softmax``.
        hard (bool): if True, the sampled probs are discretized to one-hot
            in the forward pass (straight-through estimator).
    """

    def __init__(self, tau=1.0, hard=True, **kwargs):
        super(GumbelOP, self).__init__(**kwargs)
        self.tau = tau
        self.hard = hard

    def set_temperature(self, temperature):
        """Modify the temperature.

        Bug fix: the original assigned ``self.temperature``, which nothing
        reads — ``compute_arch_probs`` samples with ``self.tau``, so
        temperature annealing silently had no effect. Assign ``self.tau``
        instead.
        """
        self.tau = temperature

    def compute_arch_probs(self, arch_param):
        """compute chosen probs by gumbel trick."""
        probs = F.gumbel_softmax(
            arch_param, tau=self.tau, hard=self.hard, dim=-1)
        return probs
Read the Docs v: v0.2.0
Versions
latest
stable
v0.2.0
Downloads
pdf
html
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.