Customize NAS algorithms

Here we show how to develop new NAS algorithms with an example of SPOS.

  1. Register a new algorithm

Create a new file mmrazor/models/algorithms/nas/spos.py, in which class SPOS inherits from BaseAlgorithm:

from mmrazor.registry import MODELS
from ..base import BaseAlgorithm

@MODELS.register_module()
class SPOS(BaseAlgorithm):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def loss(self, batch_inputs, data_samples):
        pass

  2. Develop new algorithm components (optional)

SPOS can directly use class OneShotModuleMutator as its core function provider. If the mutators provided in MMRazor don’t meet your needs, you can develop new algorithm components for your algorithm. Below we take OneShotModuleMutator as an example to show how to develop a new algorithm component:

a. Create a new file mmrazor/models/mutators/module_mutator/one_shot_module_mutator.py, in which class OneShotModuleMutator inherits from ModuleMutator.

b. Implement the functions you need in OneShotModuleMutator, e.g. sample_choices, set_choices and so on.

from typing import Any, Dict

from mmrazor.models.mutables import OneShotMutableModule
from mmrazor.registry import MODELS

from .module_mutator import ModuleMutator


@MODELS.register_module()
class OneShotModuleMutator(ModuleMutator):

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def sample_choices(self) -> Dict[int, Any]:
        pass

    def set_choices(self, choices: Dict[int, Any]) -> None:
        pass

    @property
    def mutable_class_type(self):
        return OneShotMutableModule
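
The stubs above leave the group logic open. The following self-contained toy sketches the usual pattern: sample one choice per search group, then push that choice to every mutable in the group. DummyMutable and ToyOneShotMutator are invented stand-ins for illustration, not MMRazor APIs:

```python
import random
from typing import Any, Dict, List


class DummyMutable:
    """Stand-in for a one-shot mutable module: holds candidate choices."""

    def __init__(self, choices):
        self.choices = list(choices)
        self.current_choice = None

    def sample_choice(self):
        # Uniformly pick one candidate, as single-path one-shot NAS does.
        return random.choice(self.choices)


class ToyOneShotMutator:
    """Sketch of the sample/set pattern over search groups."""

    def __init__(self, search_groups: Dict[int, List[DummyMutable]]):
        self.search_groups = search_groups

    def sample_choices(self) -> Dict[int, Any]:
        # One choice per group; all mutables in a group share it.
        return {gid: mutables[0].sample_choice()
                for gid, mutables in self.search_groups.items()}

    def set_choices(self, choices: Dict[int, Any]) -> None:
        # Activate the sampled choice on every mutable of each group.
        for gid, mutables in self.search_groups.items():
            for m in mutables:
                m.current_choice = choices[gid]


groups = {0: [DummyMutable(['3x3', '5x5'])],
          1: [DummyMutable(['skip', 'conv'])]}
mutator = ToyOneShotMutator(groups)
choices = mutator.sample_choices()
mutator.set_choices(choices)
```

After set_choices, every mutable in a group holds the same active choice, which is what lets the supernet forward a single path per iteration.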

c. Import the new mutator

You can either add the following line to mmrazor/models/mutators/__init__.py

from .module_mutator import OneShotModuleMutator

or alternatively add

custom_imports = dict(
    imports=['mmrazor.models.mutators.module_mutator.one_shot_module_mutator'],
    allow_failed_imports=False)

to the config file to avoid modifying the original code.

d. Use the algorithm component in your config file

mutator=dict(type='mmrazor.OneShotModuleMutator')

Please refer to Mutator for more details.

  3. Rewrite the loss function

Develop the key logic of your algorithm in the function loss. If your algorithm involves special optimization steps, you should also rewrite the function train_step.

@MODELS.register_module()
class SPOS(BaseAlgorithm):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def sample_subnet(self):
        pass

    def set_subnet(self, subnet):
        pass

    def loss(self, batch_inputs, data_samples):
        if self.is_supernet:
            random_subnet = self.sample_subnet()
            self.set_subnet(random_subnet)
        return self.architecture(batch_inputs, data_samples, mode='loss')

  4. Add your custom functions (optional)

After finishing the key logic in the function loss, if you also need other custom functions, you can add them to class SPOS.
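
For example, such helpers often just delegate to the mutator. The sketch below uses invented stand-ins (RecordingMutator, SPOSHelpers) to illustrate that delegation pattern; it is one reasonable design, not MMRazor's exact code:

```python
from typing import Any, Dict


class RecordingMutator:
    """Dummy mutator that records calls; stands in for OneShotModuleMutator."""

    def __init__(self):
        self.last_set = None

    def sample_choices(self) -> Dict[int, Any]:
        return {0: 'conv3x3'}

    def set_choices(self, choices: Dict[int, Any]) -> None:
        self.last_set = choices


class SPOSHelpers:
    """Sketch: custom helpers on the algorithm delegate to self.mutator."""

    def __init__(self, mutator):
        self.mutator = mutator

    def sample_subnet(self):
        # Sampling is the mutator's job; the algorithm just forwards the call.
        return self.mutator.sample_choices()

    def set_subnet(self, subnet):
        # Likewise for activating the sampled subnet.
        self.mutator.set_choices(subnet)


m = RecordingMutator()
algo = SPOSHelpers(m)
subnet = algo.sample_subnet()
algo.set_subnet(subnet)
```

Keeping sampling logic in the mutator means the algorithm class stays agnostic to how the search space is organized.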

  5. Import the class

You can either add the following line to mmrazor/models/algorithms/nas/__init__.py

from .spos import SPOS

__all__ = ['SPOS']

or alternatively add

custom_imports = dict(
    imports=['mmrazor.models.algorithms.nas.spos'],
    allow_failed_imports=False)

to the config file to avoid modifying the original code.

  6. Use the algorithm in your config file

model = dict(
    type='mmrazor.SPOS',
    architecture=supernet,
    mutator=dict(type='mmrazor.OneShotModuleMutator'))
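
To see why registering the class makes the type string in the config work, here is a much-simplified toy registry. ToyRegistry is invented for illustration; the real mmengine Registry additionally handles scope prefixes such as 'mmrazor.', which is omitted here:

```python
class ToyRegistry:
    """Minimal sketch of the register-then-build pattern."""

    def __init__(self):
        self._modules = {}

    def register_module(self):
        # Decorator that maps the class name to the class object.
        def decorator(cls):
            self._modules[cls.__name__] = cls
            return cls
        return decorator

    def build(self, cfg):
        # Look up the class by its 'type' key, pass the rest as kwargs.
        cfg = dict(cfg)
        cls = self._modules[cfg.pop('type')]
        return cls(**cfg)


MODELS = ToyRegistry()


@MODELS.register_module()
class SPOS:
    def __init__(self, depth=1):
        self.depth = depth


model = MODELS.build(dict(type='SPOS', depth=4))
```

This is why step 1 (registration) and step 5 (making sure the module is imported) are both required: build can only find classes that have actually been registered at import time.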