Shortcuts

Source code for mmrazor.structures.subnet.candidate

# Copyright (c) OpenMMLab. All rights reserved.
from collections import UserList
from typing import Any, Dict, List, Optional, Union


class Candidates(UserList):
    """The data structure of sampled candidates.

    Every element of ``self.data`` is a ``Dict[str, Dict]`` that maps the
    string form of a subnet (``str(subnet_dict)``) to its indicator values
    (``score`` / ``flops`` / ``params`` / ``latency``). The accepted input
    format is ``Union[Dict[str, Dict], List[Dict[str, Dict]]]``, plus bare
    subnet dicts (see ``_format``).

    Examples:
        >>> candidates = Candidates()
        >>> subnet_1 = {'1': 'choice1', '2': 'choice2'}
        >>> candidates.append(subnet_1)
        >>> candidates
        [{"{'1': 'choice1', '2': 'choice2'}":
        {'score': -1, 'flops': -1, 'params': -1, 'latency': -1}}]
        >>> candidates.set_resource(0, 49.9, 'flops')
        >>> candidates.set_score(0, 100.)
        >>> candidates
        [{"{'1': 'choice1', '2': 'choice2'}":
        {'score': 100.0, 'flops': 49.9, 'params': -1, 'latency': -1}}]
        >>> subnet_2 = {'choice_3': 'layer_3', 'choice_4': 'layer_4'}
        >>> candidates.append(subnet_2)
        >>> candidates
        [{"{'1': 'choice1', '2': 'choice2'}":
        {'score': 100.0, 'flops': 49.9, 'params': -1, 'latency': -1}},
        {"{'choice_3': 'layer_3', 'choice_4': 'layer_4'}":
        {'score': -1, 'flops': -1, 'params': -1, 'latency': -1}}]
        >>> candidates.subnets
        [{'1': 'choice1', '2': 'choice2'},
        {'choice_3': 'layer_3', 'choice_4': 'layer_4'}]
        >>> candidates.resources('flops')
        [49.9, -1]
        >>> candidates.scores
        [100.0, -1]
    """
    _format_return = Union[Dict[str, Dict], List[Dict[str, Dict]]]
    _format_input = Union[Dict, List[Dict], Dict[str, Dict],
                          List[Dict[str, Dict]]]
    _indicators = ('score', 'flops', 'params', 'latency')

    def __init__(self, initdata: Optional[_format_input] = None):
        """Build candidates from ``initdata`` (any format ``_format``
        accepts), or start empty when ``initdata`` is None."""
        self.data = []
        if initdata is not None:
            initdata = self._format(initdata)
            if isinstance(initdata, list):
                self.data = initdata
            else:
                self.data.append(initdata)

    @property
    def scores(self) -> List[float]:
        """The scores of candidates, rounded to 2 decimals.

        A missing ``score`` entry is reported as 0.
        """
        return [
            round(value.get('score', 0.), 2) for item in self.data
            for _, value in item.items()
        ]

    def resources(self, key_indicator: str = 'flops') -> List[float]:
        """The resources of candidates.

        Args:
            key_indicator (str): One of 'flops', 'params' or 'latency'.
                Defaults to 'flops'.

        Returns:
            List[float]: One value per entry; 0 when the entry is missing.
        """
        assert key_indicator in ['flops', 'params', 'latency']
        return [
            value.get(key_indicator, 0.) for item in self.data
            for _, value in item.items()
        ]

    @property
    def subnets(self) -> List[Dict]:
        """The subnets of candidates, without their indicator values.

        Raises:
            AssertionError: If there are no candidates.
        """
        import copy
        assert len(self.data) > 0, ('Got empty candidates.')
        if 'value_subnet' in self.data[0]:
            # Two-part format {'value_subnet': {...}, 'channel_subnet': {...}};
            # ``_format`` merged the indicators into each part, so strip
            # them from a deep copy before returning.
            subnets = []
            for data in self.data:
                subnet = dict()
                _data = copy.deepcopy(data)
                for k1 in ['value_subnet', 'channel_subnet']:
                    for k2 in self._indicators:
                        _data[k1].pop(k2, None)
                    subnet[k1] = _data[k1]
                subnets.append(subnet)
            return subnets
        # NOTE(review): keys are expected to be ``str(dict)`` reprs built by
        # ``_format``. ``eval`` executes arbitrary code, so keys must never
        # come from untrusted input (e.g. a downloaded checkpoint); consider
        # ``ast.literal_eval`` if all subnet values are Python literals.
        return [eval(key) for item in self.data for key in item]

    def _format(self, data: _format_input) -> _format_return:
        """Transform [Dict, ...] to Union[Dict[str, Dict],
        List[Dict[str, Dict]]].

        Args:
            data: Four types of input are supported:
                1. Dict: only include network information.
                2. List[Dict]: multiple candidates only include network
                    information.
                3. Dict[str, Dict]: network information and the corresponding
                    resources.
                4. List[Dict[str, Dict]]: multiple candidate information.
                A ``UserList`` (e.g. another ``Candidates``) is treated
                like a list.

        Returns:
            Union[Dict[str, Dict], List[Dict[str, Dict]]]:
                A dict or a list of dict that contains a pair of network
                information and the corresponding Score | FLOPs | Params |
                Latency results in each candidate.

        Notes:
            Score | FLOPs | Params | Latency:
                1. a candidate resources with a default value of -1 indicates
                    that it has not been estimated.
                2. a candidate resources with a default value of 0 indicates
                    that some indicators have been evaluated.
            Inputs of type 3/4 are updated in place (missing indicators are
            added to the caller's inner dicts) and returned as-is.
        """

        def _format_item(
                cond: Union[Dict, Dict[str, Dict]]) -> Dict[str, Dict]:
            """Transform Dict to Dict[str, Dict]."""
            values = list(cond.values())
            # Only the first value decides the format; an empty dict is
            # treated as a bare (empty) subnet instead of raising IndexError.
            if values and isinstance(values[0], dict):
                for value in values:
                    for key in self._indicators:
                        value.setdefault(key, 0.)
                return cond
            return {str(cond): dict.fromkeys(self._indicators, -1)}

        if isinstance(data, UserList):
            return [_format_item(i) for i in data.data]
        elif isinstance(data, list):
            return [_format_item(i) for i in data]
        else:
            return _format_item(data)

    def append(self, item: _format_input) -> None:
        """Append one candidate, or every candidate of a list, in place."""
        item = self._format(item)
        if isinstance(item, list):
            # Extend in place so the identity of ``self.data`` is preserved.
            self.data.extend(item)
        else:
            self.data.append(item)

    def insert(self, i: int, item: _format_input) -> None:
        """Insert one candidate, or every candidate of a list, before
        index ``i`` (same index semantics as ``list.insert``)."""
        item = self._format(item)
        if isinstance(item, list):
            # Splice the formatted candidates in rather than nesting a list.
            self.data[i:i] = item
        else:
            self.data.insert(i, item)

    def extend(self, other: Any) -> None:
        """Extend with candidates from ``other`` (any ``_format`` input)."""
        other = self._format(other)
        if isinstance(other, list):
            self.data.extend(other)
        else:
            self.data.extend([other])

    def set_score(self, i: int, score: float) -> None:
        """Set score to the specified subnet by index."""
        self.set_resource(i, score, 'score')

    def set_resource(self,
                     i: int,
                     resources: float,
                     key_indicator: str = 'flops') -> None:
        """Set resources to the specified subnet by index.

        Args:
            i (int): Index of the candidate.
            resources (float): Value to store.
            key_indicator (str): One of 'score', 'flops', 'params' or
                'latency'. Defaults to 'flops'.
        """
        assert key_indicator in ['score', 'flops', 'params', 'latency']
        for _, value in self.data[i].items():
            value[key_indicator] = resources

    def update_resources(self, resources: list, start: int = 0) -> None:
        """Update resources of consecutive candidates.

        Args:
            resources (list): Dicts merged (``dict.update``) into the
                indicator dicts of candidates ``start``, ``start + 1``, ...
            start (int): Index of the first candidate to update.
        """
        end = start + len(resources)
        assert len(
            self.data) >= end, 'Check the number of candidate resources.'
        for item, resource in zip(self.data[start:end], resources):
            for value in item.values():
                value.update(resource)

    def sort_by(self,
                key_indicator: str = 'score',
                reverse: bool = True) -> None:
        """Sort candidates in place by a specific indicator.

        Args:
            key_indicator (str): sort all candidates by key_indicator.
                Defaults to 'score'.
            reverse (bool): sort in descending order when True (default),
                ascending otherwise.
        """
        self.data.sort(
            key=lambda x: list(x.values())[0][key_indicator],
            reverse=reverse)
Read the Docs v: latest
Versions
latest
stable
v1.0.0
v1.0.0rc2
v1.0.0rc1
v1.0.0rc0
v0.3.1
v0.3.0
v0.2.0
quantize
main
dev-1.x
Downloads
pdf
html
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.