Update REA, REINFORCE, RANDOM, and BOHB
This commit is contained in:
		
							
								
								
									
										2
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										2
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							| @@ -123,3 +123,5 @@ scripts-search/l2s-algos | |||||||
| TEMP-L.sh | TEMP-L.sh | ||||||
|  |  | ||||||
| .nfs00* | .nfs00* | ||||||
|  | *.swo | ||||||
|  | */*.swo | ||||||
|   | |||||||
| @@ -5,9 +5,9 @@ | |||||||
| # required to install hpbandster ################################## | # required to install hpbandster ################################## | ||||||
| # pip install hpbandster         ################################## | # pip install hpbandster         ################################## | ||||||
| ################################################################### | ################################################################### | ||||||
| # python exps/algos-v2/bohb.py --num_samples 4 --random_fraction 0.0 --bandwidth_factor 3 | # OMP_NUM_THREADS=4 python exps/algos-v2/bohb.py --search_space tss --dataset cifar10 --num_samples 4 --random_fraction 0.0 --bandwidth_factor 3 --rand_seed 1 | ||||||
| ################################################################### | ################################################################### | ||||||
| import os, sys, time, random, argparse | import os, sys, time, random, argparse, collections | ||||||
| from copy import deepcopy | from copy import deepcopy | ||||||
| from pathlib import Path | from pathlib import Path | ||||||
| import torch | import torch | ||||||
| @@ -17,7 +17,7 @@ from config_utils import load_config | |||||||
| from datasets     import get_datasets, SearchDataset | from datasets     import get_datasets, SearchDataset | ||||||
| from procedures   import prepare_seed, prepare_logger | from procedures   import prepare_seed, prepare_logger | ||||||
| from log_utils    import AverageMeter, time_string, convert_secs2time | from log_utils    import AverageMeter, time_string, convert_secs2time | ||||||
| from nas_201_api  import NASBench201API as API | from nas_201_api  import NASBench201API, NASBench301API | ||||||
| from models       import CellStructure, get_search_spaces | from models       import CellStructure, get_search_spaces | ||||||
| # BOHB: Robust and Efficient Hyperparameter Optimization at Scale, ICML 2018 | # BOHB: Robust and Efficient Hyperparameter Optimization at Scale, ICML 2018 | ||||||
| import ConfigSpace | import ConfigSpace | ||||||
| @@ -63,52 +63,21 @@ def config2topology_func(max_nodes=4): | |||||||
|  |  | ||||||
| class MyWorker(Worker): | class MyWorker(Worker): | ||||||
|  |  | ||||||
|   def __init__(self, *args, convert_func=None, dataname=None, nas_bench=None, time_budget=None, **kwargs): |   def __init__(self, *args, convert_func=None, dataset=None, api=None, **kwargs): | ||||||
|     super().__init__(*args, **kwargs) |     super().__init__(*args, **kwargs) | ||||||
|     self.convert_func   = convert_func |     self.convert_func   = convert_func | ||||||
|     self._dataname      = dataname |     self._dataset       = dataset | ||||||
|     self._nas_bench     = nas_bench |     self._api           = api | ||||||
|     self.time_budget    = time_budget |     self.total_times    = [] | ||||||
|     self.seen_archs     = [] |     self.trajectory     = [] | ||||||
|     self.sim_cost_time  = 0 |  | ||||||
|     self.real_cost_time = 0 |  | ||||||
|     self.is_end         = False |  | ||||||
|  |  | ||||||
|   def get_the_best(self): |  | ||||||
|     assert len(self.seen_archs) > 0 |  | ||||||
|     best_index, best_acc = -1, None |  | ||||||
|     for arch_index in self.seen_archs: |  | ||||||
|       info = self._nas_bench.get_more_info(arch_index, self._dataname, None, hp='200', is_random=True) |  | ||||||
|       vacc = info['valid-accuracy'] |  | ||||||
|       if best_acc is None or best_acc < vacc: |  | ||||||
|         best_acc = vacc |  | ||||||
|         best_index = arch_index |  | ||||||
|     assert best_index != -1 |  | ||||||
|     return best_index |  | ||||||
|  |  | ||||||
|   def compute(self, config, budget, **kwargs): |   def compute(self, config, budget, **kwargs): | ||||||
|     start_time = time.time() |     arch  = self.convert_func( config ) | ||||||
|     structure  = self.convert_func( config ) |     accuracy, latency, time_cost, total_time = self._api.simulate_train_eval(arch, self._dataset, iepoch=int(budget)-1, hp='12') | ||||||
|     arch_index = self._nas_bench.query_index_by_arch( structure ) |     self.trajectory.append((accuracy, arch)) | ||||||
|     info       = self._nas_bench.get_more_info(arch_index, self._dataname, None, hp='200', is_random=True) |     self.total_times.append(total_time) | ||||||
|     cur_time   = info['train-all-time'] + info['valid-per-time'] |     return ({'loss': 100 - accuracy, | ||||||
|     cur_vacc   = info['valid-accuracy'] |              'info': self._api.query_index_by_arch(arch)}) | ||||||
|     self.real_cost_time += (time.time() - start_time) |  | ||||||
|     if self.sim_cost_time + cur_time <= self.time_budget and not self.is_end: |  | ||||||
|       self.sim_cost_time += cur_time |  | ||||||
|       self.seen_archs.append( arch_index ) |  | ||||||
|       return ({'loss': 100 - float(cur_vacc), |  | ||||||
|                'info': {'seen-arch'     : len(self.seen_archs), |  | ||||||
|                         'sim-test-time' : self.sim_cost_time, |  | ||||||
|                         'current-arch'  : arch_index} |  | ||||||
|             }) |  | ||||||
|     else: |  | ||||||
|       self.is_end = True |  | ||||||
|       return ({'loss': 100, |  | ||||||
|                'info': {'seen-arch'     : len(self.seen_archs), |  | ||||||
|                         'sim-test-time' : self.sim_cost_time, |  | ||||||
|                         'current-arch'  : None} |  | ||||||
|             }) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def main(xargs, api): | def main(xargs, api): | ||||||
| @@ -117,12 +86,13 @@ def main(xargs, api): | |||||||
|   logger = prepare_logger(args) |   logger = prepare_logger(args) | ||||||
|  |  | ||||||
|   logger.log('{:} use api : {:}'.format(time_string(), api)) |   logger.log('{:} use api : {:}'.format(time_string(), api)) | ||||||
|  |   api.reset_time() | ||||||
|   search_space = get_search_spaces(xargs.search_space, 'nas-bench-301') |   search_space = get_search_spaces(xargs.search_space, 'nas-bench-301') | ||||||
|   if xargs.search_space == 'tss': |   if xargs.search_space == 'tss': | ||||||
|   	cs = get_topology_config_space(xargs.max_nodes, search_space) |   	cs = get_topology_config_space(search_space) | ||||||
|   	config2structure = config2topology_func(xargs.max_nodes) |   	config2structure = config2topology_func() | ||||||
|   else: |   else: | ||||||
|   	cs = get_size_config_space(xargs.max_nodes, search_space) |     cs = get_size_config_space(search_space) | ||||||
|     import pdb; pdb.set_trace() |     import pdb; pdb.set_trace() | ||||||
|    |    | ||||||
|   hb_run_id = '0' |   hb_run_id = '0' | ||||||
| @@ -133,14 +103,13 @@ def main(xargs, api): | |||||||
|  |  | ||||||
|   workers = [] |   workers = [] | ||||||
|   for i in range(num_workers): |   for i in range(num_workers): | ||||||
|     w = MyWorker(nameserver=ns_host, nameserver_port=ns_port, convert_func=config2structure, dataname=dataname, nas_bench=nas_bench, time_budget=xargs.time_budget, run_id=hb_run_id, id=i) |     w = MyWorker(nameserver=ns_host, nameserver_port=ns_port, convert_func=config2structure, dataset=xargs.dataset, api=api, run_id=hb_run_id, id=i) | ||||||
|     w.run(background=True) |     w.run(background=True) | ||||||
|     workers.append(w) |     workers.append(w) | ||||||
|  |  | ||||||
|   start_time = time.time() |   start_time = time.time() | ||||||
|   bohb = BOHB(configspace=cs, |   bohb = BOHB(configspace=cs, run_id=hb_run_id, | ||||||
|             run_id=hb_run_id, |       eta=3, min_budget=1, max_budget=12, | ||||||
|             eta=3, min_budget=12, max_budget=200, |  | ||||||
|       nameserver=ns_host, |       nameserver=ns_host, | ||||||
|       nameserver_port=ns_port, |       nameserver_port=ns_port, | ||||||
|       num_samples=xargs.num_samples, |       num_samples=xargs.num_samples, | ||||||
| @@ -152,22 +121,23 @@ def main(xargs, api): | |||||||
|   bohb.shutdown(shutdown_workers=True) |   bohb.shutdown(shutdown_workers=True) | ||||||
|   NS.shutdown() |   NS.shutdown() | ||||||
|  |  | ||||||
|   real_cost_time = time.time() - start_time |   # print('There are {:} runs.'.format(len(results.get_all_runs()))) | ||||||
|  |   # workers[0].total_times | ||||||
|  |   # workers[0].trajectory | ||||||
|  |   current_best_index = [] | ||||||
|  |   for idx in range(len(workers[0].trajectory)): | ||||||
|  |     trajectory = workers[0].trajectory[:idx+1] | ||||||
|  |     arch = max(trajectory, key=lambda x: x[0])[1] | ||||||
|  |     current_best_index.append(api.query_index_by_arch(arch)) | ||||||
|    |    | ||||||
|   id2config = results.get_id2config_mapping() |   best_arch = max(workers[0].trajectory, key=lambda x: x[0])[1] | ||||||
|   incumbent = results.get_incumbent_id() |   logger.log('Best found configuration: {:} within {:.3f} s'.format(best_arch, workers[0].total_times[-1])) | ||||||
|   logger.log('Best found configuration: {:} within {:.3f} s'.format(id2config[incumbent]['config'], real_cost_time)) |   info = api.query_info_str_by_arch(best_arch, '200' if xargs.search_space == 'tss' else '90') | ||||||
|   best_arch = config2structure( id2config[incumbent]['config'] ) |   logger.log('{:}'.format(info)) | ||||||
|  |  | ||||||
|   info = nas_bench.query_by_arch(best_arch, '200') |  | ||||||
|   if info is None: logger.log('Did not find this architecture : {:}.'.format(best_arch)) |  | ||||||
|   else           : logger.log('{:}'.format(info)) |  | ||||||
|   logger.log('-'*100) |   logger.log('-'*100) | ||||||
|  |  | ||||||
|   logger.log('workers : {:.1f}s with {:} archs'.format(workers[0].time_budget, len(workers[0].seen_archs))) |  | ||||||
|   logger.close() |   logger.close() | ||||||
|   return logger.log_dir, nas_bench.query_index_by_arch( best_arch ), real_cost_time |  | ||||||
|  |  | ||||||
|  |   return logger.log_dir, current_best_index, workers[0].total_times | ||||||
|  |  | ||||||
|  |  | ||||||
| if __name__ == '__main__': | if __name__ == '__main__': | ||||||
| @@ -185,8 +155,8 @@ if __name__ == '__main__': | |||||||
|   parser.add_argument('--bandwidth_factor', default=3,   type=int, nargs='?', help='factor multiplied to the bandwidth') |   parser.add_argument('--bandwidth_factor', default=3,   type=int, nargs='?', help='factor multiplied to the bandwidth') | ||||||
|   parser.add_argument('--n_iters',          default=300, type=int, nargs='?', help='number of iterations for optimization method') |   parser.add_argument('--n_iters',          default=300, type=int, nargs='?', help='number of iterations for optimization method') | ||||||
|   # log |   # log | ||||||
|   parser.add_argument('--save_dir',           type=str,   help='Folder to save checkpoints and log.') |   parser.add_argument('--save_dir',           type=str,  default='./output/search', help='Folder to save checkpoints and log.') | ||||||
|   parser.add_argument('--rand_seed',          type=int,   help='manual seed') |   parser.add_argument('--rand_seed',          type=int,  default=-1, help='manual seed') | ||||||
|   args = parser.parse_args() |   args = parser.parse_args() | ||||||
|    |    | ||||||
|   if args.search_space == 'tss': |   if args.search_space == 'tss': | ||||||
|   | |||||||
| @@ -43,7 +43,7 @@ def main(xargs, api): | |||||||
|   current_best_index = [] |   current_best_index = [] | ||||||
|   while len(total_time_cost) == 0 or total_time_cost[-1] < xargs.time_budget: |   while len(total_time_cost) == 0 or total_time_cost[-1] < xargs.time_budget: | ||||||
|     arch = random_arch() |     arch = random_arch() | ||||||
|     accuracy, _, _, total_cost = api.simulate_train_eval(arch, xargs.dataset, '12') |     accuracy, _, _, total_cost = api.simulate_train_eval(arch, xargs.dataset, hp='12') | ||||||
|     total_time_cost.append(total_cost) |     total_time_cost.append(total_cost) | ||||||
|     history.append(arch) |     history.append(arch) | ||||||
|     if best_arch is None or best_acc < accuracy: |     if best_arch is None or best_acc < accuracy: | ||||||
|   | |||||||
| @@ -160,7 +160,7 @@ def regularized_evolution(cycles, population_size, sample_size, time_budget, ran | |||||||
|   while len(population) < population_size: |   while len(population) < population_size: | ||||||
|     model = Model() |     model = Model() | ||||||
|     model.arch = random_arch() |     model.arch = random_arch() | ||||||
|     model.accuracy, _, _, total_cost = api.simulate_train_eval(model.arch, dataset, '12') |     model.accuracy, _, _, total_cost = api.simulate_train_eval(model.arch, dataset, hp='12') | ||||||
|     # Append the info |     # Append the info | ||||||
|     population.append(model) |     population.append(model) | ||||||
|     history.append((model.accuracy, model.arch)) |     history.append((model.accuracy, model.arch)) | ||||||
| @@ -184,7 +184,7 @@ def regularized_evolution(cycles, population_size, sample_size, time_budget, ran | |||||||
|     # Create the child model and store it. |     # Create the child model and store it. | ||||||
|     child = Model() |     child = Model() | ||||||
|     child.arch = mutate_arch(parent.arch) |     child.arch = mutate_arch(parent.arch) | ||||||
|     child.accuracy, _, _, total_cost = api.simulate_train_eval(child.arch, dataset, '12') |     child.accuracy, _, _, total_cost = api.simulate_train_eval(child.arch, dataset, hp='12') | ||||||
|     # Append the info |     # Append the info | ||||||
|     population.append(child) |     population.append(child) | ||||||
|     history.append((child.accuracy, child.arch)) |     history.append((child.accuracy, child.arch)) | ||||||
|   | |||||||
| @@ -150,7 +150,7 @@ def main(xargs, api): | |||||||
|     start_time = time.time() |     start_time = time.time() | ||||||
|     log_prob, action = select_action( policy ) |     log_prob, action = select_action( policy ) | ||||||
|     arch   = policy.generate_arch( action ) |     arch   = policy.generate_arch( action ) | ||||||
|     reward, _, _, current_total_cost = api.simulate_train_eval(arch, xargs.dataset, '12') |     reward, _, _, current_total_cost = api.simulate_train_eval(arch, xargs.dataset, hp='12') | ||||||
|     trace.append((reward, arch)) |     trace.append((reward, arch)) | ||||||
|     total_costs.append(current_total_cost) |     total_costs.append(current_total_cost) | ||||||
|  |  | ||||||
|   | |||||||
| @@ -1,18 +1,19 @@ | |||||||
| #!/bin/bash | #!/bin/bash | ||||||
| # bash ./exps/algos-v2/run-all.sh | # bash ./exps/algos-v2/run-all.sh | ||||||
|  | set -e | ||||||
| echo script name: $0 | echo script name: $0 | ||||||
| echo $# arguments | echo $# arguments | ||||||
|  |  | ||||||
| datasets="cifar10 cifar100 ImageNet16-120" | datasets="cifar10 cifar100 ImageNet16-120" | ||||||
| search_spaces="tss sss" | search_spaces="tss sss" | ||||||
|  |  | ||||||
|  |  | ||||||
| for dataset in ${datasets} | for dataset in ${datasets} | ||||||
| do | do | ||||||
|   for search_space in ${search_spaces} |   for search_space in ${search_spaces} | ||||||
|   do |   do | ||||||
|     # python ./exps/algos-v2/reinforce.py --dataset ${dataset} --search_space ${search_space} --learning_rate 0.001 |     python ./exps/algos-v2/reinforce.py --dataset ${dataset} --search_space ${search_space} --learning_rate 0.001 | ||||||
|     python ./exps/algos-v2/regularized_ea.py --dataset ${dataset} --search_space ${search_space} --ea_cycles 200 --ea_population 10 --ea_sample_size 3 |     python ./exps/algos-v2/regularized_ea.py --dataset ${dataset} --search_space ${search_space} --ea_cycles 200 --ea_population 10 --ea_sample_size 3 | ||||||
|     # python ./exps/algos-v2/random_wo_share.py --dataset ${dataset} --search_space ${search_space} |     python ./exps/algos-v2/random_wo_share.py --dataset ${dataset} --search_space ${search_space} | ||||||
|  |     python exps/algos-v2/bohb.py --dataset ${dataset} --search_space ${search_space} --num_samples 4 --random_fraction 0.0 --bandwidth_factor 3 | ||||||
|   done |   done | ||||||
| done | done | ||||||
|   | |||||||
| @@ -5,7 +5,7 @@ | |||||||
| ############################################################### | ############################################################### | ||||||
| # Usage: python exps/experimental/vis-bench-algos.py          # | # Usage: python exps/experimental/vis-bench-algos.py          # | ||||||
| ############################################################### | ############################################################### | ||||||
| import os, sys, time, torch, argparse | import os, gc, sys, time, torch, argparse | ||||||
| import numpy as np | import numpy as np | ||||||
| from typing import List, Text, Dict, Any | from typing import List, Text, Dict, Any | ||||||
| from shutil import copyfile | from shutil import copyfile | ||||||
| @@ -31,6 +31,7 @@ def fetch_data(root_dir='./output/search', search_space='tss', dataset=None): | |||||||
|   alg2name['REA'] = 'R-EA-SS3' |   alg2name['REA'] = 'R-EA-SS3' | ||||||
|   alg2name['REINFORCE'] = 'REINFORCE-0.001' |   alg2name['REINFORCE'] = 'REINFORCE-0.001' | ||||||
|   alg2name['RANDOM'] = 'RANDOM' |   alg2name['RANDOM'] = 'RANDOM' | ||||||
|  |   alg2name['BOHB'] = 'BOHB' | ||||||
|   for alg, name in alg2name.items(): |   for alg, name in alg2name.items(): | ||||||
|     alg2path[alg] = os.path.join(ss_dir, dataset, name, 'results.pth') |     alg2path[alg] = os.path.join(ss_dir, dataset, name, 'results.pth') | ||||||
|     assert os.path.isfile(alg2path[alg]), 'invalid path : {:}'.format(alg2path[alg]) |     assert os.path.isfile(alg2path[alg]), 'invalid path : {:}'.format(alg2path[alg]) | ||||||
| @@ -58,14 +59,27 @@ def query_performance(api, data, dataset, ticket): | |||||||
|     results.append(interplate) |     results.append(interplate) | ||||||
|   return sum(results) / len(results) |   return sum(results) / len(results) | ||||||
|  |  | ||||||
|  | y_min_s = {('cifar10', 'tss'): 90, | ||||||
|  |            ('cifar10', 'sss'): 92, | ||||||
|  |            ('cifar100', 'tss'): 65, | ||||||
|  |            ('cifar100', 'sss'): 65, | ||||||
|  |            ('ImageNet16-120', 'tss'): 36, | ||||||
|  |            ('ImageNet16-120', 'sss'): 40} | ||||||
|  |  | ||||||
|  | y_max_s = {('cifar10', 'tss'): 94.5, | ||||||
|  |            ('cifar10', 'sss'): 93.3, | ||||||
|  |            ('cifar100', 'tss'): 72, | ||||||
|  |            ('cifar100', 'sss'): 70, | ||||||
|  |            ('ImageNet16-120', 'tss'): 44, | ||||||
|  |            ('ImageNet16-120', 'sss'): 46} | ||||||
|  |  | ||||||
| def visualize_curve(api, vis_save_dir, search_space, max_time): | def visualize_curve(api, vis_save_dir, search_space, max_time): | ||||||
|   vis_save_dir = vis_save_dir.resolve() |   vis_save_dir = vis_save_dir.resolve() | ||||||
|   vis_save_dir.mkdir(parents=True, exist_ok=True) |   vis_save_dir.mkdir(parents=True, exist_ok=True) | ||||||
|  |  | ||||||
|   dpi, width, height = 250, 5100, 1500 |   dpi, width, height = 250, 5200, 1400 | ||||||
|   figsize = width / float(dpi), height / float(dpi) |   figsize = width / float(dpi), height / float(dpi) | ||||||
|   LabelSize, LegendFontsize = 14, 14 |   LabelSize, LegendFontsize = 16, 16 | ||||||
|  |  | ||||||
|   def sub_plot_fn(ax, dataset): |   def sub_plot_fn(ax, dataset): | ||||||
|     alg2data = fetch_data(search_space=search_space, dataset=dataset) |     alg2data = fetch_data(search_space=search_space, dataset=dataset) | ||||||
| @@ -73,6 +87,8 @@ def visualize_curve(api, vis_save_dir, search_space, max_time): | |||||||
|     total_tickets = 150 |     total_tickets = 150 | ||||||
|     time_tickets = [float(i) / total_tickets * max_time for i in range(total_tickets)] |     time_tickets = [float(i) / total_tickets * max_time for i in range(total_tickets)] | ||||||
|     colors = ['b', 'g', 'c', 'm', 'y'] |     colors = ['b', 'g', 'c', 'm', 'y'] | ||||||
|  |     ax.set_xlim(0, 200) | ||||||
|  |     ax.set_ylim(y_min_s[(dataset, search_space)], y_max_s[(dataset, search_space)]) | ||||||
|     for idx, (alg, data) in enumerate(alg2data.items()): |     for idx, (alg, data) in enumerate(alg2data.items()): | ||||||
|       print('plot alg : {:}'.format(alg)) |       print('plot alg : {:}'.format(alg)) | ||||||
|       accuracies = [] |       accuracies = [] | ||||||
| @@ -107,5 +123,7 @@ if __name__ == '__main__': | |||||||
|  |  | ||||||
|   api201 = NASBench201API(verbose=False) |   api201 = NASBench201API(verbose=False) | ||||||
|   visualize_curve(api201, save_dir, 'tss', args.max_time) |   visualize_curve(api201, save_dir, 'tss', args.max_time) | ||||||
|  |   del api201 | ||||||
|  |   gc.collect() | ||||||
|   api301 = NASBench301API(verbose=False) |   api301 = NASBench301API(verbose=False) | ||||||
|   visualize_curve(api301, save_dir, 'sss', args.max_time) |   visualize_curve(api301, save_dir, 'sss', args.max_time) | ||||||
|   | |||||||
| @@ -68,14 +68,14 @@ class NASBenchMetaAPI(metaclass=abc.ABCMeta): | |||||||
|   def reset_time(self): |   def reset_time(self): | ||||||
|     self._used_time = 0 |     self._used_time = 0 | ||||||
|  |  | ||||||
|   def simulate_train_eval(self, arch, dataset, hp='12', account_time=True): |   def simulate_train_eval(self, arch, dataset, iepoch=None, hp='12', account_time=True): | ||||||
|     index = self.query_index_by_arch(arch) |     index = self.query_index_by_arch(arch) | ||||||
|     all_names = ('cifar10', 'cifar100', 'ImageNet16-120') |     all_names = ('cifar10', 'cifar100', 'ImageNet16-120') | ||||||
|     assert dataset in all_names, 'Invalid dataset name : {:} vs {:}'.format(dataset, all_names) |     assert dataset in all_names, 'Invalid dataset name : {:} vs {:}'.format(dataset, all_names) | ||||||
|     if dataset == 'cifar10': |     if dataset == 'cifar10': | ||||||
|       info = self.get_more_info(index, 'cifar10-valid', iepoch=None, hp=hp, is_random=True) |       info = self.get_more_info(index, 'cifar10-valid', iepoch=iepoch, hp=hp, is_random=True) | ||||||
|     else: |     else: | ||||||
|       info = self.get_more_info(index, dataset, iepoch=None, hp=hp, is_random=True) |       info = self.get_more_info(index, dataset, iepoch=iepoch, hp=hp, is_random=True) | ||||||
|     valid_acc, time_cost = info['valid-accuracy'], info['train-all-time'] + info['valid-per-time'] |     valid_acc, time_cost = info['valid-accuracy'], info['train-all-time'] + info['valid-per-time'] | ||||||
|     latency = self.get_latency(index, dataset) |     latency = self.get_latency(index, dataset) | ||||||
|     if account_time: |     if account_time: | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user