Shortcuts

Source code for mmrazor.core.hooks.drop_path_prob

# Copyright (c) Open-MMLab. All rights reserved.

from mmcv.cnn.bricks import DropPath
from mmcv.runner import HOOKS, Hook


[docs]@HOOKS.register_module() class DropPathProbHook(Hook): """Set drop_path_prob periodically. Args: max_prob (float): The max probability of dropping. interval (int): The saving period. If ``by_epoch=True``, interval indicates epochs, otherwise it indicates iterations. Default: -1, which means "never". by_epoch (bool): Saving checkpoints by epoch or by iteration. Default: True. """ def __init__(self, max_prob, interval=-1, by_epoch=True, **kwargs): self.interval = interval self.by_epoch = by_epoch self.max_prob = max_prob assert self.by_epoch
[docs] def before_train_epoch(self, runner): """Executed in before_train_epoch stage.""" if not self.by_epoch: return if self.every_n_epochs( runner, self.interval) or (self.save_last and self.is_last_epoch(runner)): cur_epoch = runner.epoch max_epoch = runner._max_epochs drop_prob = self.max_prob * (cur_epoch * 1.0) / (max_epoch * 1.0) for module in runner.model.modules(): if isinstance(module, DropPath): module.drop_prob = drop_prob runner.logger.info(f'Set drop_prob to {drop_prob} \ at {runner.epoch + 1} epochs')
Read the Docs v: v0.2.0
Versions
latest
stable
v0.2.0
Downloads
pdf
html
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.