Shortcuts

Source code for mmrazor.models.mutators.darts_mutator

# Copyright (c) OpenMMLab. All rights reserved.

import torch
from torch.nn import functional as F

from mmrazor.models.builder import MUTATORS
from .differentiable_mutator import DifferentiableMutator


@MUTATORS.register_module()
class DartsMutator(DifferentiableMutator):
    """Differentiable mutator that derives a subnet from architecture params.

    Args:
        ignore_choices (tuple[str]): Choice names that are never picked by
            :meth:`search_subnet`. Defaults to ``('zero', )``.
        **kwargs: Forwarded unchanged to ``DifferentiableMutator.__init__``.
    """

    def __init__(self, ignore_choices=('zero', ), **kwargs):
        super().__init__(**kwargs)
        self.ignore_choices = ignore_choices

    def search_subnet(self):
        """Select the final choices of every search space.

        The search runs in two passes over ``self.search_spaces``:

        1. For each space that owns an entry in ``self.arch_params``, softmax
           the parameter and keep the most probable choice whose name is not
           in ``self.ignore_choices``.
        2. For each space without an architecture parameter, rank its
           ``choice_names`` by the probability recorded in pass 1 for the
           entry of the same name, keep the top ``num_chosen`` of them and
           drop the entries of the remaining ones from the result.

        Returns:
            dict: Maps a space id to ``dict(chosen=[...], chosen_probs=[...])``
            for spaces handled in pass 1 and to ``dict(chosen=[...])`` for
            spaces handled in pass 2.

        Raises:
            ValueError: If all choices of a space with an architecture
                parameter are listed in ``ignore_choices``.
        """
        subnet_dict = {}

        # Pass 1: spaces that have a learnable architecture parameter.
        for space_id, sub_space in self.search_spaces.items():
            if space_id not in self.arch_params:
                continue

            # NOTE(review): assumes the parameter is 1-D with one logit per
            # entry of ``choice_names`` -- confirm against the base mutator.
            arch_probs = F.softmax(self.arch_params[space_id], dim=-1)
            choice_names = sub_space['choice_names']
            keep_idx = [
                i for i, name in enumerate(choice_names)
                if name not in self.ignore_choices
            ]
            if not keep_idx:
                # Without this guard torch.max is called on an empty tensor
                # and fails with an error that does not name the space.
                raise ValueError(
                    f'All choices of search space {space_id!r} are listed '
                    f'in ignore_choices={self.ignore_choices!r}; at least '
                    'one selectable choice is required.')

            best_choice_prob, best_pos = torch.max(arch_probs[keep_idx], 0)
            # ``best_pos`` indexes the filtered tensor; map it back to the
            # position inside the unfiltered ``choice_names``.
            best_choice_name = choice_names[keep_idx[best_pos.item()]]

            subnet_dict[space_id] = dict(
                chosen=[best_choice_name],
                chosen_probs=[best_choice_prob.item()])

        def sort_key(x):
            # Probability of the choice selected in pass 1 for entry ``x``.
            return subnet_dict[x]['chosen_probs'][0]

        # Pass 2: spaces without a parameter choose among pass-1 entries.
        for space_id, sub_space in self.search_spaces.items():
            if space_id in self.arch_params:
                continue

            num_chosen = sub_space['num_chosen']
            sorted_edges = sorted(
                sub_space['choice_names'], key=sort_key, reverse=True)
            subnet_dict[space_id] = dict(chosen=sorted_edges[:num_chosen])
            # Entries that lost the ranking are not part of the subnet.
            for not_chosen in sorted_edges[num_chosen:]:
                subnet_dict.pop(not_chosen)

        return subnet_dict
Read the Docs v: v0.2.0
Versions
latest
stable
v0.2.0
Downloads
pdf
html
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.