Customize pruning algorithms
Here we show how to develop a new pruning algorithm, using AutoSlim as an example.
Register a new algorithm
Create a new file mmrazor/models/algorithms/pruning/autoslim.py, in which class AutoSlim inherits from class BaseAlgorithm.
from mmrazor.registry import MODELS
from ..base import BaseAlgorithm


@MODELS.register_module()
class AutoSlim(BaseAlgorithm):

    def __init__(self,
                 mutator,
                 distiller,
                 architecture,
                 data_preprocessor=None,
                 num_random_samples=2,
                 init_cfg=None) -> None:
        # Forward the arguments BaseAlgorithm expects to the parent class.
        super().__init__(architecture, data_preprocessor, init_cfg)

    def train_step(self, data, optim_wrapper):
        pass
Develop new algorithm components (optional)
AutoSlim can directly use class OneShotChannelMutator as its core function provider. If it cannot meet your needs, you can develop new algorithm components for your algorithm, such as a custom channel mutator. We will take OneShotChannelMutator as an example to introduce how to develop a new algorithm component:
a. Create a new file mmrazor/models/mutators/channel_mutator/one_shot_channel_mutator.py, in which class OneShotChannelMutator inherits from ChannelMutator.
b. Implement the functions you need, e.g. build_search_groups, set_choices, sample_choices and so on; a sketch of possible implementations follows the skeleton below.
from mmrazor.registry import MODELS
from .channel_mutator import ChannelMutator


@MODELS.register_module()
class OneShotChannelMutator(ChannelMutator):

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def sample_choices(self):
        pass

    def set_choices(self, choice_dict):
        pass

    # supernet is a kind of architecture in `mmrazor/models/architectures/`
    def build_search_groups(self, supernet):
        pass
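As a reference, here is a minimal sketch of what sample_choices and set_choices might do. It assumes each search group is a list of mutable channel units and that the first unit exposes a candidate_choices attribute; these names are assumptions for illustration, not a guaranteed API.

import random

from .channel_mutator import ChannelMutator


class OneShotChannelMutator(ChannelMutator):

    def sample_choices(self):
        # Pick one candidate choice per search group at random.
        # `candidate_choices` is an assumed attribute name.
        return {
            group_id: random.choice(units[0].candidate_choices)
            for group_id, units in self.search_groups.items()
        }

    def set_choices(self, choice_dict):
        # Propagate each sampled choice to every unit in its group.
        for group_id, choice in choice_dict.items():
            for unit in self.search_groups[group_id]:
                unit.current_choice = choice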
c. Import the module in mmrazor/models/mutators/channel_mutator/__init__.py
from .one_shot_channel_mutator import OneShotChannelMutator
__all__ = [..., 'OneShotChannelMutator']
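Once imported, the registered mutator can be built from a config dict through the registry. The minimal dict below is only illustrative; the real ChannelMutator may require extra construction arguments.

from mmrazor.registry import MODELS

mutator_cfg = dict(type='OneShotChannelMutator')
mutator = MODELS.build(mutator_cfg)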
Rewrite its train_step
Develop the key logic of your algorithm in the function train_step.
from typing import Dict, List

import torch
from mmengine.optim import OptimWrapper
from mmengine.structures import BaseDataElement

from mmrazor.models.utils import add_prefix
from mmrazor.registry import MODELS
from ..base import BaseAlgorithm


@MODELS.register_module()
class AutoSlim(BaseAlgorithm):

    def __init__(self,
                 mutator,
                 distiller,
                 architecture,
                 data_preprocessor=None,
                 num_random_samples=2,
                 init_cfg=None) -> None:
        super().__init__(architecture, data_preprocessor, init_cfg)
        # Train the max subnet, the min subnet and several random
        # subnets in every iteration (the "sandwich rule").
        self.sample_kinds = ['max', 'min']
        for i in range(num_random_samples):
            self.sample_kinds.append('random' + str(i))

    def train_step(self, data: List[dict],
                   optim_wrapper: OptimWrapper) -> Dict[str, torch.Tensor]:

        def distill_step(
            batch_inputs: torch.Tensor, data_samples: List[BaseDataElement]
        ) -> Dict[str, torch.Tensor]:
            ...
            return subnet_losses

        batch_inputs, data_samples = self.data_preprocessor(data, True)

        total_losses = dict()
        for kind in self.sample_kinds:
            # update the max subnet loss.
            if kind == 'max':
                self.set_max_subnet()
                with optim_wrapper.optim_context(
                        self), self.distiller.teacher_recorders:  # type: ignore
                    max_subnet_losses = self(
                        batch_inputs, data_samples, mode='loss')
                parsed_max_subnet_losses, _ = self.parse_losses(
                    max_subnet_losses)
                optim_wrapper.update_params(parsed_max_subnet_losses)
                total_losses.update(add_prefix(max_subnet_losses, 'max_subnet'))
            # update the min subnet loss.
            elif kind == 'min':
                self.set_min_subnet()
                min_subnet_losses = distill_step(batch_inputs, data_samples)
                total_losses.update(add_prefix(min_subnet_losses, 'min_subnet'))
            # update the random subnets loss.
            elif 'random' in kind:
                self.set_subnet(self.sample_subnet())
                random_subnet_losses = distill_step(batch_inputs, data_samples)
                total_losses.update(
                    add_prefix(random_subnet_losses, f'{kind}_subnet'))

        return total_losses
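The body of distill_step is elided above. One possible implementation of this inner function mirrors the max-subnet branch, but trains the sampled subnet against the teacher outputs recorded earlier in the iteration. The student_recorders and compute_distill_losses names below assume a distiller such as MMRazor's ConfigurableDistiller; adapt them to the distiller you use.

def distill_step(
    batch_inputs: torch.Tensor, data_samples: List[BaseDataElement]
) -> Dict[str, torch.Tensor]:
    subnet_losses = dict()
    with optim_wrapper.optim_context(
            self), self.distiller.student_recorders:  # type: ignore
        # Forward the current (smaller) subnet and distill it from the
        # outputs the teacher (max subnet) recorded in this iteration.
        _ = self(batch_inputs, data_samples, mode='loss')
        soft_loss = self.distiller.compute_distill_losses()
        subnet_losses.update(soft_loss)
        parsed_subnet_losses, _ = self.parse_losses(subnet_losses)
        optim_wrapper.update_params(parsed_subnet_losses)
    return subnet_losses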
Add your custom functions (optional)
After finishing the key logic in function train_step, if you also need other custom functions, you can add them to class AutoSlim.
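For instance, train_step above calls helpers such as set_max_subnet, set_min_subnet, sample_subnet and set_subnet. A minimal sketch of what they might look like, assuming the mutator exposes max_choices and min_choices (these attribute names are assumptions, not a guaranteed API):

# Inside class AutoSlim:

def sample_subnet(self):
    # Delegate random sampling to the mutator.
    return self.mutator.sample_choices()

def set_subnet(self, subnet):
    self.mutator.set_choices(subnet)

def set_max_subnet(self):
    # `max_choices` / `min_choices` are assumed mutator attributes
    # holding the largest / smallest choice per search group.
    self.mutator.set_choices(self.mutator.max_choices)

def set_min_subnet(self):
    self.mutator.set_choices(self.mutator.min_choices)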
Import the class
You can either add the following lines to mmrazor/models/algorithms/__init__.py
from .pruning import AutoSlim
__all__ = [..., 'AutoSlim']
Or alternatively add
custom_imports = dict(
    imports=['mmrazor.models.algorithms.pruning.autoslim'],
    allow_failed_imports=False)
to the config file to avoid modifying the original code.
Use the algorithm in your config file
model = dict(
    type='AutoSlim',
    architecture=...,
    mutator=dict(type='OneShotChannelMutator', ...),
)
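Note that AutoSlim's __init__ above also expects a distiller, so a fuller config would include one as well. The sketch below is illustrative only; the distiller type and the architecture's cfg_path are assumptions to be replaced with the components you actually use.

model = dict(
    type='AutoSlim',
    architecture=dict(
        cfg_path='mmcls::resnet/resnet50_8xb32_in1k.py',  # hypothetical path
        pretrained=False),
    mutator=dict(type='OneShotChannelMutator', ...),
    distiller=dict(type='ConfigurableDistiller', ...),
    num_random_samples=2,
)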