Update for Rebuttal
This commit is contained in:
		| @@ -95,7 +95,7 @@ def mutate_size_func(info): | ||||
|   return mutate_size_func | ||||
|  | ||||
|  | ||||
| def regularized_evolution(cycles, population_size, sample_size, time_budget, random_arch, mutate_arch, api, dataset): | ||||
| def regularized_evolution(cycles, population_size, sample_size, time_budget, random_arch, mutate_arch, api, use_proxy, dataset): | ||||
|   """Algorithm for regularized evolution (i.e. aging evolution). | ||||
|    | ||||
|   Follows "Algorithm 1" in Real et al. "Regularized Evolution for Image | ||||
| @@ -119,7 +119,10 @@ def regularized_evolution(cycles, population_size, sample_size, time_budget, ran | ||||
|   while len(population) < population_size: | ||||
|     model = Model() | ||||
|     model.arch = random_arch() | ||||
|     model.accuracy, _, _, total_cost = api.simulate_train_eval(model.arch, dataset, hp='12') | ||||
|     if use_proxy: | ||||
|       model.accuracy, _, _, total_cost = api.simulate_train_eval(model.arch, dataset, hp='12') | ||||
|     else: | ||||
|       model.accuracy, _, _, total_cost = api.simulate_train_eval(model.arch, dataset, hp=api.full_train_epochs) | ||||
|     # Append the info | ||||
|     population.append(model) | ||||
|     history.append((model.accuracy, model.arch)) | ||||
| @@ -171,7 +174,11 @@ def main(xargs, api): | ||||
|   x_start_time = time.time() | ||||
|   logger.log('{:} use api : {:}'.format(time_string(), api)) | ||||
|   logger.log('-'*30 + ' start searching with the time budget of {:} s'.format(xargs.time_budget)) | ||||
|   history, current_best_index, total_times = regularized_evolution(xargs.ea_cycles, xargs.ea_population, xargs.ea_sample_size, xargs.time_budget, random_arch, mutate_arch, api, xargs.dataset) | ||||
|   history, current_best_index, total_times = regularized_evolution(xargs.ea_cycles, | ||||
|                                                                    xargs.ea_population, | ||||
|                                                                    xargs.ea_sample_size, | ||||
|                                                                    xargs.time_budget, | ||||
|                                                                    random_arch, mutate_arch, api, xargs.use_proxy > 0, xargs.dataset) | ||||
|   logger.log('{:} regularized_evolution finish with history of {:} arch with {:.1f} s (real-cost={:.2f} s).'.format(time_string(), len(history), total_times[-1], time.time()-x_start_time)) | ||||
|   best_arch = max(history, key=lambda x: x[0])[1] | ||||
|   logger.log('{:} best arch is {:}'.format(time_string(), best_arch)) | ||||
| @@ -187,11 +194,13 @@ if __name__ == '__main__': | ||||
|   parser = argparse.ArgumentParser("Regularized Evolution Algorithm") | ||||
|   parser.add_argument('--dataset',            type=str,   choices=['cifar10', 'cifar100', 'ImageNet16-120'], help='Choose between Cifar10/100 and ImageNet-16.') | ||||
|   parser.add_argument('--search_space',       type=str,   choices=['tss', 'sss'], help='Choose the search space.') | ||||
|   # channels and number-of-cells | ||||
|   # hyperparameters for REA | ||||
|   parser.add_argument('--ea_cycles',          type=int,   help='The number of cycles in EA.') | ||||
|   parser.add_argument('--ea_population',      type=int,   help='The population size in EA.') | ||||
|   parser.add_argument('--ea_sample_size',     type=int,   help='The sample size in EA.') | ||||
|   parser.add_argument('--time_budget',        type=int,   default=20000, help='The total time cost budge for searching (in seconds).') | ||||
|   parser.add_argument('--use_proxy',          type=int,   default=1,     help='Whether to use the proxy (H0) task or not.') | ||||
|   # | ||||
|   parser.add_argument('--loops_if_rand',      type=int,   default=500,   help='The total runs for evaluation.') | ||||
|   # log | ||||
|   parser.add_argument('--save_dir',           type=str,   default='./output/search', help='Folder to save checkpoints and log.') | ||||
| @@ -201,7 +210,8 @@ if __name__ == '__main__': | ||||
|   api = create(None, args.search_space, fast_mode=True, verbose=False) | ||||
|  | ||||
|   args.save_dir = os.path.join('{:}-{:}'.format(args.save_dir, args.search_space), | ||||
|                                '{:}-T{:}'.format(args.dataset, args.time_budget), 'R-EA-SS{:}'.format(args.ea_sample_size)) | ||||
|                                '{:}-T{:}{:}'.format(args.dataset, args.time_budget, '' if args.use_proxy > 0 else '-FULL'), | ||||
|                                'R-EA-SS{:}'.format(args.ea_sample_size)) | ||||
|   print('save-dir : {:}'.format(args.save_dir)) | ||||
|   print('xargs : {:}'.format(args)) | ||||
|  | ||||
|   | ||||
		Reference in New Issue
	
	Block a user