Shortcuts

Source code for mmrazor.core.searcher.evolution_search

# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import random
import time

import mmcv.fileio
from mmcv.runner import get_dist_info

from ..builder import SEARCHERS
from ..utils import broadcast_object_list


@SEARCHERS.register_module()
class EvolutionSearcher:
    """Implementation of evolution search.

    Each epoch:

    1. Every candidate in the pool is evaluated and scored.
    2. The best ``candidate_top_k`` candidates seen so far are kept.
    3. A new pool is bred from them by mutation and crossover.

    Only subnets that pass :meth:`check_constraints` are admitted to the pool.

    Args:
        algorithm (:obj:`torch.nn.Module`): Algorithm to be searched. It must
            expose a ``.module`` attribute (a model wrapper); otherwise
            ``NotImplementedError`` is raised. The unwrapped ``.module`` is
            used to manipulate subnets and the wrapper is passed to
            ``test_fn``.
        dataloader (torch.utils.data.DataLoader): Data loader used for
            evaluation; its ``dataset.evaluate`` computes the scores.
        test_fn (callable): Test API called as
            ``test_fn(algorithm, dataloader)`` to get model outputs.
        work_dir (str): Working directory where resume checkpoints and the
            final subnet are saved.
        logger (logging.Logger): Logger for the search stage.
        candidate_pool_size (int): Number of candidates evaluated per epoch
            when the pool is filled by random sampling.
        candidate_top_k (int): Number of best-scoring candidates kept.
        constraints (dict, optional): Constraints used to screen candidates.
            Only the ``'flops'`` key is checked. Defaults to
            ``dict(flops=330 * 1e6)``.
        metrics (str, optional): Metrics passed to ``dataset.evaluate``.
        metric_options (optional): Options passed to ``dataset.evaluate``.
        score_key (str): Key selecting the score from the evaluation result.
        max_epoch (int): Epoch at which the evolution search ends.
        num_mutation (int): Number of candidates to generate by mutation.
        num_crossover (int): Number of candidates to generate by crossover.
        mutate_prob (float): Probability of mutation.
        resume_from (str, optional): Path of a saved ``.pkl`` file to resume
            searching from.
        **search_kwargs: Extra keyword arguments; accepted but unused.
    """

    def __init__(self,
                 algorithm,
                 dataloader,
                 test_fn,
                 work_dir,
                 logger,
                 candidate_pool_size=50,
                 candidate_top_k=10,
                 constraints=None,
                 metrics=None,
                 metric_options=None,
                 score_key='accuracy_top-1',
                 max_epoch=20,
                 num_mutation=25,
                 num_crossover=25,
                 mutate_prob=0.1,
                 resume_from=None,
                 **search_kwargs):
        if not hasattr(algorithm, 'module'):
            raise NotImplementedError('Do not support searching with cpu.')
        if constraints is None:
            # Build a fresh dict per instance instead of sharing a mutable
            # default argument between all searchers.
            constraints = dict(flops=330 * 1e6)

        self.algorithm = algorithm.module
        self.algorithm_for_test = algorithm
        self.dataloader = dataloader
        self.test_fn = test_fn
        self.work_dir = work_dir
        self.logger = logger
        self.resume_from = resume_from

        self.constraints = constraints
        self.metrics = metrics
        self.metric_options = metric_options
        self.score_key = score_key

        self.candidate_pool_size = candidate_pool_size
        self.candidate_top_k = candidate_top_k
        self.max_epoch = max_epoch
        self.num_mutation = num_mutation
        self.num_crossover = num_crossover
        self.mutate_prob = mutate_prob

        self.candidate_pool = list()
        # Both dicts map score -> candidate.
        self.candidate_pool_with_score = dict()
        self.top_k_candidates_with_score = dict()

    def check_constraints(self):
        """Check whether the current subnet satisfies the constraints.

        Returns:
            bool: True if the FLOPs of the subnet currently set on
            ``self.algorithm`` are strictly below ``constraints['flops']``.
        """
        flops = self.algorithm.get_subnet_flops()
        return bool(flops < self.constraints['flops'])

    def update_top_k(self):
        """Merge the scored pool into the top-k and keep the best k.

        Candidates are ranked by their score (the dict key) in descending
        order. Because scores are used as keys, candidates with equal scores
        collapse into one entry. When the pool and the current top-k share a
        score, the pool's entry replaces the existing one.
        """
        merged = dict(self.top_k_candidates_with_score)
        merged.update(self.candidate_pool_with_score)
        ranked = sorted(merged.items(), key=lambda item: item[0], reverse=True)
        self.top_k_candidates_with_score = dict(ranked[:self.candidate_top_k])

    def search(self):
        """Execute the pipeline of evolution search.

        Rank 0 fills the candidate pool, which is passed through
        ``broadcast_object_list``. Every rank then runs ``test_fn`` on each
        candidate. Only rank 0 does the following:

        - scores the candidates;
        - updates the top-k and breeds the next pool;
        - writes a per-epoch resume checkpoint;
        - exports the best subnet as YAML at the end.
        """
        epoch_start = self._resume()
        self._log_setting()

        rank = get_dist_info()[0]
        for epoch in range(epoch_start, self.max_epoch):
            if rank == 0:
                self._fill_candidate_pool()
                broadcast_candidate_pool = self.candidate_pool
            else:
                broadcast_candidate_pool = [None] * self.candidate_pool_size
            broadcast_candidate_pool = broadcast_object_list(
                broadcast_candidate_pool)

            for i, candidate in enumerate(broadcast_candidate_pool):
                self.algorithm.mutator.set_subnet(candidate)
                outputs = self.test_fn(self.algorithm_for_test,
                                       self.dataloader)
                if rank == 0:
                    eval_result = self.dataloader.dataset.evaluate(
                        outputs, self.metrics, self.metric_options)
                    score = eval_result[self.score_key]
                    self.candidate_pool_with_score[score] = candidate
                    self.logger.info(f'Epoch:[{epoch + 1}/{self.max_epoch}] '
                                     f'Candidate:[{i + 1}/'
                                     f'{self.candidate_pool_size}] '
                                     f'Score:{score}')

            if rank == 0:
                self._evolve(epoch)

            # Collective call executed on every rank once per epoch.
            self.candidate_pool = broadcast_object_list(self.candidate_pool)

        if rank == 0:
            self._export_final_subnet()

    def _resume(self):
        """Restore searcher state from ``resume_from`` when it is set.

        Returns:
            int: The epoch to start from (0 when not resuming).
        """
        if self.resume_from is None:
            return 0
        # NOTE: a .pkl checkpoint is unpickled here; only resume from
        # trusted files.
        searcher_resume = mmcv.fileio.load(self.resume_from)
        for key, value in searcher_resume.items():
            setattr(self, key, value)
        epoch_start = int(searcher_resume['epoch'])
        self.logger.info('#' * 100)
        self.logger.info(f'Resume from epoch: {epoch_start}')
        return epoch_start

    def _log_setting(self):
        """Log the search hyper-parameters."""
        self.logger.info('#' * 100)
        self.logger.info('Experiment setting:')
        for name in ('candidate_pool_size', 'candidate_top_k',
                     'num_crossover', 'num_mutation', 'mutate_prob',
                     'max_epoch', 'score_key', 'constraints'):
            self.logger.info(f'{name}: {getattr(self, name)}')
        self.logger.info('#' * 100)

    def _fill_candidate_pool(self):
        """Top up ``candidate_pool`` with random subnets meeting constraints."""
        # NOTE(review): this loop has no attempt limit, so it never ends if no
        # sampled subnet satisfies the constraints.
        while len(self.candidate_pool) < self.candidate_pool_size:
            candidate = self.algorithm.mutator.sample_subnet(searching=True)
            self.algorithm.mutator.set_subnet(candidate)
            if self.check_constraints():
                self.candidate_pool.append(candidate)

    def _generate_valid_candidates(self, generate, num_candidates):
        """Collect up to ``num_candidates`` generated subnets that pass checks.

        ``generate`` is called at most ``num_candidates * 10`` times. The
        result can therefore be shorter than requested when few generated
        subnets satisfy the constraints.
        """
        candidates = []
        max_iters = num_candidates * 10
        num_iters = 0
        while len(candidates) < num_candidates and num_iters < max_iters:
            num_iters += 1
            candidate = generate()
            self.algorithm.mutator.set_subnet(candidate)
            if self.check_constraints():
                candidates.append(candidate)
        return candidates

    def _evolve(self, epoch):
        """Update the top-k, breed the next pool and checkpoint (rank 0)."""
        scores_before = list(self.top_k_candidates_with_score.keys())
        self.logger.info(f'top k scores before update: {scores_before}')
        self.update_top_k()
        scores_after = list(self.top_k_candidates_with_score.keys())
        self.logger.info(f'top k scores after update: {scores_after}')

        # The top-k does not change while breeding, so build the parent list
        # once instead of once per attempt.
        parents = list(self.top_k_candidates_with_score.values())
        mutator = self.algorithm.mutator
        mutation_candidates = self._generate_valid_candidates(
            lambda: mutator.mutation(random.choice(parents), self.mutate_prob),
            self.num_mutation)
        crossover_candidates = self._generate_valid_candidates(
            lambda: mutator.crossover(
                random.choice(parents), random.choice(parents)),
            self.num_crossover)
        self.candidate_pool = mutation_candidates + crossover_candidates

        self._save_checkpoint(epoch)
        self.logger.info(
            f'Epoch:[{epoch + 1}/{self.max_epoch}], top1_score: '
            f'{list(self.top_k_candidates_with_score.keys())[0]}')

    def _save_checkpoint(self, epoch):
        """Dump the state needed to resume after ``epoch`` to ``work_dir``."""
        save_for_resume = dict(epoch=epoch + 1)
        for key in ['candidate_pool', 'top_k_candidates_with_score']:
            save_for_resume[key] = getattr(self, key)
        mmcv.fileio.dump(
            save_for_resume,
            osp.join(self.work_dir, f'search_epoch_{epoch + 1}.pkl'))

    def _export_final_subnet(self):
        """Apply the best candidate and save its chosen choices as YAML.

        Raises:
            RuntimeError: If no candidate has been scored.
        """
        if not self.top_k_candidates_with_score:
            raise RuntimeError(
                'No candidate has been evaluated; cannot export a final '
                'subnet.')
        final_subnet_dict = list(self.top_k_candidates_with_score.values())[0]
        self.algorithm.mutator.set_chosen_subnet(final_subnet_dict)
        search_spaces = self.algorithm.mutator.search_spaces
        final_subnet_dict_to_save = {
            key: dict(chosen=search_spaces[key]['chosen'])
            for key in final_subnet_dict.keys()
        }
        timestamp_subnet = time.strftime('%Y%m%d_%H%M', time.localtime())
        save_name = f'final_subnet_{timestamp_subnet}.yaml'
        mmcv.fileio.dump(final_subnet_dict_to_save,
                         osp.join(self.work_dir, save_name))
        self.logger.info('Search finished and '
                         f'{save_name} saved in {self.work_dir}.')
Read the Docs v: v0.2.0
Versions
latest
stable
v0.2.0
Downloads
pdf
html
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.