Shortcuts

Source code for mmrazor.core.hooks.sampler_seed

# Copyright (c) Open-MMLab. All rights reserved.
from mmcv.runner import Hook


# @HOOKS.register_module()
class DistSamplerSeedHook(Hook):
    """Hook that sets the current epoch on data-loader samplers.

    Before each epoch, ``runner.epoch`` is passed to ``set_epoch`` on every
    sampler that provides it. Samplers such as PyTorch's
    ``DistributedSampler`` rely on this call to shuffle differently from one
    epoch to the next.

    In distributed training this hook is only useful together with
    :obj:`EpochBasedRunner`; :obj:`IterBasedRunner` achieves the same purpose
    through :obj:`IterLoader`.
    """

    def before_epoch(self, runner):
        """Executed in the before_epoch stage.

        Args:
            runner: The runner driving training. ``runner.data_loader`` is
                either a single data loader or a wrapper that exposes the
                wrapped loaders through a ``_dataloaders`` attribute.
                ``runner.epoch`` is the epoch value forwarded to the samplers.
        """
        if hasattr(runner.data_loader, '_dataloaders'):
            data_loaders = runner.data_loader._dataloaders
        else:
            data_loaders = [runner.data_loader]

        for data_loader in data_loaders:
            self._set_sampler_epoch(data_loader, runner.epoch)

    @staticmethod
    def _set_sampler_epoch(data_loader, epoch):
        """Call ``set_epoch(epoch)`` on the sampler of ``data_loader``, if any.

        The loader's own ``sampler`` is tried first. Otherwise the sampler
        wrapped by its ``batch_sampler`` is tried. If neither provides
        ``set_epoch``, nothing happens.

        Args:
            data_loader: A data loader exposing a ``sampler`` attribute and,
                optionally, a ``batch_sampler`` attribute.
            epoch (int): Epoch number to forward to the sampler.
        """
        sampler = data_loader.sampler
        if hasattr(sampler, 'set_epoch'):
            sampler.set_epoch(epoch)
            return

        # A PyTorch batch sampler wraps the index sampler as its ``sampler``
        # attribute. ``batch_sampler`` is None when automatic batching is
        # disabled (``batch_size=None``), and custom batch samplers need not
        # expose ``sampler`` at all, so both lookups are done defensively
        # instead of raising AttributeError.
        batch_sampler = getattr(data_loader, 'batch_sampler', None)
        inner_sampler = getattr(batch_sampler, 'sampler', None)
        if hasattr(inner_sampler, 'set_epoch'):
            inner_sampler.set_epoch(epoch)
Read the Docs v: v0.2.0
Versions
latest
stable
v0.2.0
Downloads
pdf
html
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.