Source code for mmrazor.models.algorithms.darts

# Copyright (c) OpenMMLab. All rights reserved.
# Copyright (c) Microsoft Corporation.
import copy

import torch
from torch import nn

from mmrazor.models.builder import ALGORITHMS
from .base import BaseAlgorithm


@ALGORITHMS.register_module()
class Darts(BaseAlgorithm):

    def __init__(self, unroll, **kwargs):
        super(Darts, self).__init__(**kwargs)
        self.unroll = unroll
    def train_step(self, data, optimizer):
        """The iteration step during training.

        This method defines an iteration step during training, except for
        the back propagation and optimizer updating, which are done in an
        optimizer hook. Note that in some complicated cases or models, the
        whole process, including back propagation and optimizer updating,
        is also defined in this method, such as GAN.

        Args:
            data (dict): The output of dataloader.
            optimizer (:obj:`torch.optim.Optimizer` | dict): The optimizer of
                runner is passed to ``train_step()``. For Darts it is expected
                to be a dict containing the ``architecture`` and ``mutator``
                optimizers.

        Returns:
            dict: It should contain at least 3 keys: ``loss``, ``log_vars``,
                ``num_samples``.
                ``loss`` is a tensor for back propagation, which can be a
                weighted sum of multiple losses.
                ``log_vars`` contains all the variables to be sent to the
                logger.
                ``num_samples`` indicates the batch size (when the model is
                DDP, it means the batch size on each GPU), which is used for
                averaging the logs.
        """
        if isinstance(data, (tuple, list)) and isinstance(optimizer, dict):
            assert len(data) == len(optimizer)

            train_arch_data, train_supernet_data = data

            # first update the architecture parameters held by the mutator
            optimizer['mutator'].zero_grad()
            if self.unroll:
                self._unrolled_backward(train_arch_data, train_supernet_data,
                                        optimizer)
            else:
                arch_losses = self(**train_arch_data)
                arch_loss, _ = self._parse_losses(arch_losses)
                arch_loss.backward()
            optimizer['mutator'].step()

            # then update the supernet weights on the other data split
            model_losses = self(**train_supernet_data)
            model_loss, log_vars = self._parse_losses(model_losses)

            optimizer['architecture'].zero_grad()
            model_loss.backward()
            nn.utils.clip_grad_norm_(
                self.architecture.parameters(), max_norm=5, norm_type=2)
            optimizer['architecture'].step()

            outputs = dict(
                loss=model_loss,
                log_vars=log_vars,
                num_samples=len(train_supernet_data['img'].data))
        else:
            outputs = super(Darts, self).train_step(data, optimizer)

        return outputs
    def _unrolled_backward(self, train_arch_data, train_supernet_data,
                           optimizer):
        """Compute unrolled loss and backward its gradients."""
        backup_params = copy.deepcopy(tuple(self.architecture.parameters()))

        # do virtual step on training data
        lr = optimizer['architecture'].param_groups[0]['lr']
        momentum = optimizer['architecture'].param_groups[0]['momentum']
        weight_decay = optimizer['architecture'].param_groups[0][
            'weight_decay']
        self._compute_virtual_model(train_supernet_data, lr, momentum,
                                    weight_decay, optimizer)

        # calculate unrolled loss on validation data
        # keep gradients for the model here to compute the hessian
        losses = self(**train_arch_data)
        loss, _ = self._parse_losses(losses)
        w_model, w_arch = tuple(self.architecture.parameters()), tuple(
            self.mutator.parameters())
        w_grads = torch.autograd.grad(loss, w_model + w_arch)
        d_model, d_arch = w_grads[:len(w_model)], w_grads[len(w_model):]

        # compute hessian and final gradients
        hessian = self._compute_hessian(backup_params, d_model,
                                        train_supernet_data)
        with torch.no_grad():
            for param, d, h in zip(w_arch, d_arch, hessian):
                # gradient = dalpha - lr * hessian
                param.grad = d - lr * h

        # restore weights
        self._restore_weights(backup_params)

    def _compute_virtual_model(self, data, lr, momentum, weight_decay,
                               optimizer):
        """Compute unrolled weights w`."""
        # no need to call zero_grad; autograd computes the gradients directly
        losses = self(**data)
        loss, _ = self._parse_losses(losses)
        gradients = torch.autograd.grad(loss, self.architecture.parameters())
        with torch.no_grad():
            for w, g in zip(self.architecture.parameters(), gradients):
                m = optimizer['architecture'].state[w].get(
                    'momentum_buffer', 0.)
                # update the weight in place; a plain ``w = w - ...`` would
                # only rebind the local name and leave the parameter unchanged
                w.copy_(w - lr * (momentum * m + g + weight_decay * w))

    def _restore_weights(self, backup_params):
        with torch.no_grad():
            for param, backup in zip(self.architecture.parameters(),
                                     backup_params):
                param.copy_(backup)

    def _compute_hessian(self, backup_params, dw, data):
        """
            dw = dw` { L_val(w`, alpha) }
            w+ = w + eps * dw
            w- = w - eps * dw
            hessian = (dalpha { L_trn(w+, alpha) } \
                - dalpha { L_trn(w-, alpha) }) / (2*eps)
            eps = 0.01 / ||dw||
        """
        self._restore_weights(backup_params)
        norm = torch.cat([w.view(-1) for w in dw]).norm()
        eps = 0.01 / norm
        if norm < 1E-8:
            print('In computing hessian, norm is smaller than 1E-8, which '
                  f'causes eps to be {eps.item():.6f}.')

        dalphas = []
        for e in [eps, -2. * eps]:
            # first pass: w+ = w + eps*dw`; second pass: w- = w - eps*dw`
            with torch.no_grad():
                for p, d in zip(self.architecture.parameters(), dw):
                    p += e * d

            losses = self(**data)
            loss, _ = self._parse_losses(losses)
            dalphas.append(
                torch.autograd.grad(loss, tuple(self.mutator.parameters())))

        # dalpha { L_trn(w+) }, dalpha { L_trn(w-) }
        dalpha_pos, dalpha_neg = dalphas
        hessian = [(p - n) / (2. * eps)
                   for p, n in zip(dalpha_pos, dalpha_neg)]
        return hessian
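For reference, the update that ``_unrolled_backward`` and ``_compute_hessian`` implement is the second-order approximation from the DARTS paper (Liu et al., 2019). In the docstring's notation, the architecture gradient is

    \nabla_\alpha \mathcal{L}_{val}(w', \alpha)
        - \xi \, \nabla^2_{\alpha, w} \mathcal{L}_{trn}(w, \alpha)\,
          \nabla_{w'} \mathcal{L}_{val}(w', \alpha),
    \qquad w' = w - \xi \nabla_w \mathcal{L}_{trn}(w, \alpha),

with the Hessian-vector product approximated by a central difference

    \nabla^2_{\alpha, w} \mathcal{L}_{trn}(w, \alpha)\, d_w
        \approx \frac{\nabla_\alpha \mathcal{L}_{trn}(w^+, \alpha)
                    - \nabla_\alpha \mathcal{L}_{trn}(w^-, \alpha)}{2\epsilon},
    \qquad w^\pm = w \pm \epsilon\, d_w, \quad
    \epsilon = \frac{0.01}{\lVert d_w \rVert}.

Here \xi corresponds to ``lr``, d_w to ``d_model``, and the line ``param.grad = d - lr * h`` applies exactly this rule; ``train_arch_data`` plays the role of the validation loss L_val and ``train_supernet_data`` that of the training loss L_trn.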
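As a minimal, self-contained illustration of the finite-difference trick used in ``_compute_hessian``, the toy sketch below approximates the same Hessian-vector product on a two-parameter model. It is not part of MMRazor; the names ``w``, ``alpha`` and ``toy_loss`` are invented for this sketch, and a single toy loss stands in for both the training and validation losses.

import torch
from torch import nn

torch.manual_seed(0)
# ``w`` plays the role of the supernet weights, ``alpha`` of the
# architecture parameters (both hypothetical, for illustration only)
w = nn.Parameter(torch.randn(4))
alpha = nn.Parameter(torch.randn(4))
x = torch.randn(8, 4)


def toy_loss():
    # any differentiable loss coupling w and alpha will do
    return ((x @ (w * torch.softmax(alpha, dim=0))) ** 2).mean()


# d_w stands in for ``d_model``; in the real code it is the gradient of the
# validation loss w.r.t. the unrolled weights
d_w = torch.autograd.grad(toy_loss(), w)[0]
eps = 0.01 / d_w.norm()

# grad_alpha at w + eps * d_w
with torch.no_grad():
    w += eps * d_w
dalpha_pos = torch.autograd.grad(toy_loss(), alpha)[0]

# grad_alpha at w - eps * d_w
with torch.no_grad():
    w -= 2. * eps * d_w
dalpha_neg = torch.autograd.grad(toy_loss(), alpha)[0]

# restore the original weights, as _restore_weights does
with torch.no_grad():
    w += eps * d_w

# central-difference estimate of (d^2 L / d alpha d w) @ d_w
hessian_vec = (dalpha_pos - dalpha_neg) / (2. * eps)
print(hessian_vec)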