update GDAS
This commit is contained in:
		| @@ -1,7 +1,7 @@ | ||||
| ################################################## | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 # | ||||
| ################################################## | ||||
| import os, sys, copy, torch, numpy as np | ||||
| import os, sys, copy, random, torch, numpy as np | ||||
| from collections import OrderedDict | ||||
|  | ||||
|  | ||||
| @@ -149,7 +149,7 @@ class ArchResults(object): | ||||
|     lantencies = [result.get_latency() for result in results] | ||||
|     return np.mean(flops), np.mean(params), np.mean(lantencies) | ||||
|  | ||||
|   def get_metrics(self, dataset, setname, iepoch=None): | ||||
|   def get_metrics(self, dataset, setname, iepoch=None, is_random=False): | ||||
|     x_seeds = self.dataset_seed[dataset] | ||||
|     results = [self.all_results[ (dataset, seed) ] for seed in x_seeds] | ||||
|     loss, accuracy = [], [] | ||||
| @@ -160,7 +160,11 @@ class ArchResults(object): | ||||
|         info = result.get_eval(setname, iepoch) | ||||
|       loss.append( info['loss'] ) | ||||
|       accuracy.append( info['accuracy'] ) | ||||
|     return float(np.mean(loss)), float(np.mean(accuracy)) | ||||
|     if is_random: | ||||
|       index = random.randint(0, len(loss)-1) | ||||
|       return loss[index], accuracy[index] | ||||
|     else: | ||||
|       return float(np.mean(loss)), float(np.mean(accuracy)) | ||||
|  | ||||
|   def show(self, is_print=False): | ||||
|     return print_information(self, None, is_print) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user