update vis
This commit is contained in:
		| @@ -131,15 +131,17 @@ class NASBench102API(object): | ||||
|     else                  : basestr, arch2infos = '200epochs', self.arch2infos_full | ||||
|     best_index, highest_accuracy = -1, None | ||||
|     for i, idx in enumerate(self.evaluated_indexes): | ||||
|       flop, param, latency = arch2infos[idx].get_comput_costs(dataset) | ||||
|       info = arch2infos[idx].get_comput_costs(dataset) | ||||
|       flop, param, latency = info['flops'], info['params'], info['latency'] | ||||
|       if FLOP_max  is not None and flop  > FLOP_max : continue | ||||
|       if Param_max is not None and param > Param_max: continue | ||||
|       loss, accuracy = arch2infos[idx].get_metrics(dataset, metric_on_set) | ||||
|       xinfo = arch2infos[idx].get_metrics(dataset, metric_on_set) | ||||
|       loss, accuracy = xinfo['loss'], xinfo['accuracy'] | ||||
|       if best_index == -1: | ||||
|         best_index, highest_accuracy = idx, accuracy | ||||
|       elif highest_accuracy < accuracy: | ||||
|         best_index, highest_accuracy = idx, accuracy | ||||
|     return best_index | ||||
|     return best_index, highest_accuracy | ||||
|  | ||||
|   # return the topology structure of the `index`-th architecture | ||||
|   def arch(self, index): | ||||
| @@ -183,10 +185,18 @@ class NASBench102API(object): | ||||
|         test__info = archresult.get_metrics(dataset, 'ori-test', iepoch=iepoch, is_random=is_random) | ||||
|       else: | ||||
|         test__info = archresult.get_metrics(dataset, 'x-test', iepoch=iepoch, is_random=is_random) | ||||
|       return {'train-loss'    : train_info['loss'], | ||||
|       try: | ||||
|         valid_info = archresult.get_metrics(dataset, 'x-valid', iepoch=iepoch, is_random=is_random) | ||||
|       except: | ||||
|         valid_info = None | ||||
|       xifo = {'train-loss'    : train_info['loss'], | ||||
|               'train-accuracy': train_info['accuracy'], | ||||
|               'test-loss'     : test__info['loss'], | ||||
|               'test-accuracy' : test__info['accuracy']} | ||||
|       if valid_info is not None: | ||||
|         xifo['valid-loss'] = valid_info['loss'] | ||||
|         xifo['valid-accuracy'] = valid_info['accuracy'] | ||||
|       return xifo | ||||
|  | ||||
|   def show(self, index=-1): | ||||
|     if index < 0: # show all architectures | ||||
|   | ||||
		Reference in New Issue
	
	Block a user