Shortcuts

Source code for mmrazor.models.mutators.differentiable_mutator

# Copyright (c) OpenMMLab. All rights reserved.
from abc import abstractmethod
from functools import partial

from torch import nn

from mmrazor.models.builder import MUTATORS
from mmrazor.models.mutables import MutableModule
from .base import BaseMutator


@MUTATORS.register_module()
class DifferentiableMutator(BaseMutator):
    """A mutator for differentiable NAS.

    It provides the core functions for changing the structure of
    ``ARCHITECTURES``:

    * building one architecture parameter per ``space_id``;
    * binding that parameter into the ``forward`` of every
      :class:`MutableModule` that shares the ``space_id``.
    """

    def __init__(self, **kwargs):
        # All configuration is handled by ``BaseMutator``.
        super().__init__(**kwargs)

    def prepare_from_supernet(self, supernet):
        """Prepare the mutator for ``supernet``.

        Runs ``BaseMutator.prepare_from_supernet`` first. It then stores the
        result of :meth:`build_arch_params` as ``self.arch_params`` and calls
        :meth:`modify_supernet_forward`.

        Args:
            supernet (:obj:`torch.nn.Module`): The architecture to be used
                in your algorithm.
        """
        super().prepare_from_supernet(supernet)
        self.arch_params = self.build_arch_params(supernet)
        # Must run after ``self.arch_params`` is set: the forward patching
        # reads the params from that attribute.
        self.modify_supernet_forward(supernet)

    def _iter_mutable_children(self, module):
        """Yield every :class:`MutableModule` strictly below ``module``.

        Descendants are visited depth-first in pre-order, so a child comes
        before its own children. ``module`` itself is never yielded. The
        recursion descends into every child, mutable or not, so mutables
        nested inside other mutables are also found.
        """
        for child in module.children():
            if isinstance(child, MutableModule):
                yield child
            yield from self._iter_mutable_children(child)

    def build_arch_params(self, supernet):
        """Build the architecture parameters of ``supernet``.

        These parameters are generally used in differentiable search
        algorithms, such as the DARTS series. Each ``space_id`` corresponds
        to one arch param, so mutables with the same ``space_id`` share the
        same arch param.

        Args:
            supernet (:obj:`torch.nn.Module`): The architecture to be used
                in your algorithm.

        Returns:
            torch.nn.ParameterDict: The arch params collected while
                traversing the supernet, keyed by ``space_id``.
        """
        arch_params = nn.ParameterDict()
        for mutable in self._iter_mutable_children(supernet):
            space_id = mutable.space_id
            # The first mutable of a space that yields a param defines it;
            # later mutables with the same space_id reuse that param.
            if space_id in arch_params:
                continue
            space_arch_param = mutable.build_arch_param()
            # ``build_arch_param`` may return None. Then nothing is recorded,
            # and a later mutable with the same space_id may still build one.
            if space_arch_param is not None:
                arch_params[space_id] = space_arch_param
        return arch_params

    def modify_supernet_forward(self, supernet):
        """Bind each mutable's arch param into its ``forward``.

        Every :class:`MutableModule` below ``supernet`` whose ``space_id``
        has an entry in ``self.arch_params`` gets its ``forward`` replaced.
        The replacement is a ``functools.partial`` that supplies
        ``arch_param=self.arch_params[space_id]``. Mutables without an entry
        are left unchanged.

        Args:
            supernet (:obj:`torch.nn.Module`): The architecture to be used
                in your algorithm.
        """
        for mutable in self._iter_mutable_children(supernet):
            space_id = mutable.space_id
            if space_id not in self.arch_params:
                continue
            # Stored as an instance attribute, so it shadows the class-level
            # ``forward`` for this module only.
            mutable.forward = partial(
                mutable.forward, arch_param=self.arch_params[space_id])

    @abstractmethod
    def search_subnet(self):
        """Search for a subnet.

        Marked ``@abstractmethod``; concrete differentiable mutators are
        expected to override it.
        """
        pass
Read the Docs v: v0.2.0
Versions
latest
stable
v0.2.0
Downloads
pdf
html
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.