update for NAS-Bench-102
This commit is contained in:
		| @@ -1,5 +1,5 @@ | ||||
| ################################################## | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 # | ||||
| ################################################## | ||||
| from .api import AANASBenchAPI | ||||
| from .api import NASBench102API | ||||
| from .api import ArchResults, ResultsCount | ||||
| @@ -2,7 +2,7 @@ | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 # | ||||
| ################################################## | ||||
| import os, sys, copy, random, torch, numpy as np | ||||
| from collections import OrderedDict | ||||
| from collections import OrderedDict, defaultdict | ||||
| 
 | ||||
| 
 | ||||
| def print_information(information, extra_info=None, show=False): | ||||
| @@ -30,16 +30,17 @@ def print_information(information, extra_info=None, show=False): | ||||
|   return strings | ||||
| 
 | ||||
| 
 | ||||
| class AANASBenchAPI(object): | ||||
| class NASBench102API(object): | ||||
| 
 | ||||
|   def __init__(self, file_path_or_dict, verbose=True): | ||||
|     if isinstance(file_path_or_dict, str): | ||||
|       if verbose: print('try to create AA-NAS-Bench api from {:}'.format(file_path_or_dict)) | ||||
|       if verbose: print('try to create NAS-Bench-102 api from {:}'.format(file_path_or_dict)) | ||||
|       assert os.path.isfile(file_path_or_dict), 'invalid path : {:}'.format(file_path_or_dict) | ||||
|       file_path_or_dict = torch.load(file_path_or_dict) | ||||
|     else: | ||||
|       file_path_or_dict = copy.deepcopy( file_path_or_dict ) | ||||
|     assert isinstance(file_path_or_dict, dict), 'It should be a dict instead of {:}'.format(type(file_path_or_dict)) | ||||
|     import pdb; pdb.set_trace() # we will update this api soon | ||||
|     keys = ('meta_archs', 'arch2infos', 'evaluated_indexes') | ||||
|     for key in keys: assert key in file_path_or_dict, 'Can not find key[{:}] in the dict'.format(key) | ||||
|     self.meta_archs = copy.deepcopy( file_path_or_dict['meta_archs'] ) | ||||
| @@ -144,27 +145,46 @@ class ArchResults(object): | ||||
|   def get_comput_costs(self, dataset): | ||||
|     x_seeds = self.dataset_seed[dataset] | ||||
|     results = [self.all_results[ (dataset, seed) ] for seed in x_seeds] | ||||
|     flops   = [result.flop for result in results] | ||||
|     params  = [result.params for result in results] | ||||
| 
 | ||||
|     flops      = [result.flop for result in results] | ||||
|     params     = [result.params for result in results] | ||||
|     lantencies = [result.get_latency() for result in results] | ||||
|     return np.mean(flops), np.mean(params), np.mean(lantencies) | ||||
|     lantencies = [x for x in lantencies if x > 0] | ||||
|     mean_latency = np.mean(lantencies) if len(lantencies) > 0 else None | ||||
|     time_infos = defaultdict(list) | ||||
|     for result in results: | ||||
|       time_info = result.get_times() | ||||
|       for key, value in time_info.items(): time_infos[key].append( value ) | ||||
|       | ||||
|     info = {'flops'  : np.mean(flops), | ||||
|             'params' : np.mean(params), | ||||
|             'latency': mean_latency} | ||||
|     for key, value in time_infos.items(): | ||||
|       if len(value) > 0 and value[0] is not None: | ||||
|         info[key] = np.mean(value) | ||||
|       else: info[key] = None | ||||
|     return info | ||||
| 
 | ||||
|   def get_metrics(self, dataset, setname, iepoch=None, is_random=False): | ||||
|     x_seeds = self.dataset_seed[dataset] | ||||
|     results = [self.all_results[ (dataset, seed) ] for seed in x_seeds] | ||||
|     loss, accuracy = [], [] | ||||
|     infos   = defaultdict(list) | ||||
|     for result in results: | ||||
|       if setname == 'train': | ||||
|         info = result.get_train(iepoch) | ||||
|       else: | ||||
|         info = result.get_eval(setname, iepoch) | ||||
|       loss.append( info['loss'] ) | ||||
|       accuracy.append( info['accuracy'] ) | ||||
|       for key, value in info.items(): infos[key].append( value ) | ||||
|     return_info = dict() | ||||
|     if is_random: | ||||
|       index = random.randint(0, len(loss)-1) | ||||
|       return loss[index], accuracy[index] | ||||
|       index = random.randint(0, len(results)-1) | ||||
|       for key, value in infos.items(): return_info[key] = value[index] | ||||
|     else: | ||||
|       return float(np.mean(loss)), float(np.mean(accuracy)) | ||||
|       for key, value in infos.items(): | ||||
|         if len(value) > 0 and value[0] is not None: | ||||
|           return_info[key] = np.mean(value) | ||||
|         else: return_info[key] = None | ||||
|     return return_info | ||||
| 
 | ||||
|   def show(self, is_print=False): | ||||
|     return print_information(self, None, is_print) | ||||
| @@ -245,8 +265,10 @@ class ResultsCount(object): | ||||
|   def __init__(self, name, state_dict, train_accs, train_losses, params, flop, arch_config, seed, epochs, latency): | ||||
|     self.name           = name | ||||
|     self.net_state_dict = state_dict | ||||
|     self.train_accs   = copy.deepcopy(train_accs) | ||||
|     self.train_acc1es = copy.deepcopy(train_accs) | ||||
|     self.train_acc5es = None | ||||
|     self.train_losses = copy.deepcopy(train_losses) | ||||
|     self.train_times  = None | ||||
|     self.arch_config  = copy.deepcopy(arch_config) | ||||
|     self.params     = params | ||||
|     self.flop       = flop | ||||
| @@ -256,44 +278,97 @@ class ResultsCount(object): | ||||
|     # evaluation results | ||||
|     self.reset_eval() | ||||
| 
 | ||||
|   def update_train_info(self, train_acc1es, train_acc5es, train_losses, train_times): | ||||
|     self.train_acc1es = train_acc1es | ||||
|     self.train_acc5es = train_acc5es | ||||
|     self.train_losses = train_losses | ||||
|     self.train_times  = train_times | ||||
| 
 | ||||
|   def reset_eval(self): | ||||
|     self.eval_names  = [] | ||||
|     self.eval_accs   = {} | ||||
|     self.eval_acc1es = {} | ||||
|     self.eval_times  = {} | ||||
|     self.eval_losses = {} | ||||
| 
 | ||||
|   def update_latency(self, latency): | ||||
|     self.latency = copy.deepcopy( latency ) | ||||
| 
 | ||||
|   def update_eval(self, accs, losses, times): # old version | ||||
|     data_names = set([x.split('@')[0] for x in accs.keys()]) | ||||
|     for data_name in data_names: | ||||
|       assert data_name not in self.eval_names, '{:} has already been added into eval-names'.format(data_name) | ||||
|       self.eval_names.append( data_name ) | ||||
|       for iepoch in range(self.epochs): | ||||
|         xkey = '{:}@{:}'.format(data_name, iepoch) | ||||
|         self.eval_acc1es[ xkey ] = accs[ xkey ] | ||||
|         self.eval_losses[ xkey ] = losses[ xkey ] | ||||
|         self.eval_times [ xkey ] = times[ xkey ] | ||||
| 
 | ||||
|   def update_OLD_eval(self, name, accs, losses): # old version | ||||
|     assert name not in self.eval_names, '{:} has already added'.format(name) | ||||
|     self.eval_names.append( name ) | ||||
|     for iepoch in range(self.epochs): | ||||
|       if iepoch in accs: | ||||
|         self.eval_acc1es['{:}@{:}'.format(name,iepoch)] = accs[iepoch] | ||||
|         self.eval_losses['{:}@{:}'.format(name,iepoch)] = losses[iepoch] | ||||
| 
 | ||||
|   def __repr__(self): | ||||
|     num_eval = len(self.eval_names) | ||||
|     set_name = '[' + ', '.join(self.eval_names) + ']' | ||||
|     return ('{name}({xname}, arch={arch}, FLOP={flop:.2f}M, Param={param:.3f}MB, seed={seed}, {num_eval} eval-sets: {set_name})'.format(name=self.__class__.__name__, xname=self.name, arch=self.arch_config['arch_str'], flop=self.flop, param=self.params, seed=self.seed, num_eval=num_eval, set_name=set_name)) | ||||
| 
 | ||||
|   def get_latency(self): | ||||
|     if self.latency is None: return -1 | ||||
|     else: return sum(self.latency) / len(self.latency) | ||||
| 
 | ||||
|   def update_eval(self, name, accs, losses): | ||||
|     assert name not in self.eval_names, '{:} has already added'.format(name) | ||||
|     self.eval_names.append( name ) | ||||
|     self.eval_accs[name] = copy.deepcopy( accs ) | ||||
|     self.eval_losses[name] = copy.deepcopy( losses ) | ||||
|   def get_times(self): | ||||
|     if self.train_times is not None and isinstance(self.train_times, dict): | ||||
|       train_times = list( self.train_times.values() ) | ||||
|       time_info = {'T-train@epoch': np.mean(train_times), 'T-train@total': np.sum(train_times)} | ||||
|       for name in self.eval_names: | ||||
|         xtimes = [self.eval_times['{:}@{:}'.format(name,i)] for i in range(self.epochs)] | ||||
|         time_info['T-{:}@epoch'.format(name)] = np.mean(xtimes) | ||||
|         time_info['T-{:}@total'.format(name)] = np.sum(xtimes) | ||||
|     else: | ||||
|       time_info = {'T-train@epoch':                 None, 'T-train@total':               None } | ||||
|       for name in self.eval_names: | ||||
|         time_info['T-{:}@epoch'.format(name)] = None | ||||
|         time_info['T-{:}@total'.format(name)] = None | ||||
|     return time_info | ||||
| 
 | ||||
|   def __repr__(self): | ||||
|     num_eval = len(self.eval_names) | ||||
|     return ('{name}({xname}, arch={arch}, FLOP={flop:.2f}M, Param={param:.3f}MB, seed={seed}, {num_eval} eval-sets)'.format(name=self.__class__.__name__, xname=self.name, arch=self.arch_config['arch_str'], flop=self.flop, param=self.params, seed=self.seed, num_eval=num_eval)) | ||||
| 
 | ||||
|   def valid_evaluation_set(self): | ||||
|   def get_eval_set(self): | ||||
|     return self.eval_names | ||||
| 
 | ||||
|   def get_train(self, iepoch=None): | ||||
|     if iepoch is None: iepoch = self.epochs-1 | ||||
|     assert 0 <= iepoch < self.epochs, 'invalid iepoch={:} < {:}'.format(iepoch, self.epochs) | ||||
|     return {'loss': self.train_losses[iepoch], 'accuracy': self.train_accs[iepoch]} | ||||
|     if self.train_times is not None: xtime = self.train_times[iepoch] | ||||
|     else                           : xtime = None | ||||
|     return {'iepoch'  : iepoch, | ||||
|             'loss'    : self.train_losses[iepoch], | ||||
|             'accuracy': self.train_acc1es[iepoch], | ||||
|             'time'    : xtime} | ||||
| 
 | ||||
|   def get_eval(self, name, iepoch=None): | ||||
|     if iepoch is None: iepoch = self.epochs-1 | ||||
|     assert 0 <= iepoch < self.epochs, 'invalid iepoch={:} < {:}'.format(iepoch, self.epochs) | ||||
|     return {'loss': self.eval_losses[name][iepoch], 'accuracy': self.eval_accs[name][iepoch]} | ||||
|     if isinstance(self.eval_times,dict) and len(self.eval_times) > 0: | ||||
|       xtime = self.eval_times['{:}@{:}'.format(name,iepoch)] | ||||
|     else: xtime = None | ||||
|     return {'iepoch'  : iepoch, | ||||
|             'loss'    : self.eval_losses['{:}@{:}'.format(name,iepoch)], | ||||
|             'accuracy': self.eval_acc1es['{:}@{:}'.format(name,iepoch)], | ||||
|             'time'    : xtime} | ||||
| 
 | ||||
|   def get_net_param(self): | ||||
|     return self.net_state_dict | ||||
| 
 | ||||
|   def get_config(self, str2structure): | ||||
|     #return copy.deepcopy(self.arch_config) | ||||
|     return {'name': 'infer.tiny', 'C': self.arch_config['channel'], \ | ||||
|             'N'   : self.arch_config['num_cells'], \ | ||||
|             'genotype': str2structure(self.arch_config['arch_str']), 'num_classes': self.arch_config['class_num']} | ||||
| 
 | ||||
|   def state_dict(self): | ||||
|     _state_dict = {key: value for key, value in self.__dict__.items()} | ||||
|     return _state_dict | ||||
		Reference in New Issue
	
	Block a user