update NAS-Bench-102 baselines
@@ -62,11 +62,12 @@ class MyWorker(Worker):
 
   def compute(self, config, budget, **kwargs):
     structure = self.convert_func( config )
-    reward    = train_and_eval(structure, self.nas_bench, None)
+    reward, time_cost = train_and_eval(structure, self.nas_bench, None)
+    import pdb; pdb.set_trace()
     self.test_time += 1
     return ({
             'loss': float(100-reward),
-            'info': None})
+            'info': time_cost})
 
 
 def main(xargs, nas_bench):
@@ -121,7 +122,7 @@ def main(xargs, nas_bench):
 
   bohb = BOHB(configspace=cs,
             run_id=hb_run_id,
-            eta=3, min_budget=3, max_budget=108,
+            eta=3, min_budget=3, max_budget=xargs.time_budget,
             nameserver=ns_host,
             nameserver_port=ns_port,
             num_samples=xargs.num_samples,
@@ -130,6 +131,7 @@ def main(xargs, nas_bench):
   #          optimization_strategy=xargs.strategy, num_samples=xargs.num_samples,
 
   results = bohb.run(xargs.n_iters, min_n_workers=num_workers)
+  import pdb; pdb.set_trace()
 
   bohb.shutdown(shutdown_workers=True)
   NS.shutdown()
@@ -160,9 +162,10 @@ if __name__ == '__main__':
   parser.add_argument('--max_nodes',          type=int,   help='The maximum number of nodes.')
   parser.add_argument('--channel',            type=int,   help='The number of channels.')
   parser.add_argument('--num_cells',          type=int,   help='The number of cells in one stage.')
+  parser.add_argument('--time_budget',        type=int,   help='The total time cost budget for searching (in seconds).')
   # BOHB
   parser.add_argument('--strategy', default="sampling", type=str, nargs='?', help='optimization strategy for the acquisition function')
-  parser.add_argument('--min_bandwidth', default=.3, type=float, nargs='?', help='minimum bandwidth for KDE')
+  parser.add_argument('--min_bandwidth',    default=.3, type=float, nargs='?', help='minimum bandwidth for KDE')
   parser.add_argument('--num_samples',      default=64, type=int, nargs='?', help='number of samples for the acquisition function')
   parser.add_argument('--random_fraction',  default=.33, type=float, nargs='?', help='fraction of random configurations')
   parser.add_argument('--bandwidth_factor', default=3, type=int, nargs='?', help='factor multiplied to the bandwidth')
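
The worker change above follows hpbandster's Worker protocol: compute returns a dict whose 'loss' the BOHB master minimizes, and whose 'info' may carry arbitrary metadata, which is why the simulated training time can ride along. Below is a minimal sketch of the updated worker in isolation; the constructor plumbing is an assumption (the diff does not show it), train_and_eval is the repo helper that now returns an (accuracy, time-cost) pair, and the stray import pdb; pdb.set_trace() debugging line is left out.

from hpbandster.core.worker import Worker
from R_EA import train_and_eval   # repo helper; returns (accuracy, time_cost) after this commit

class MyWorker(Worker):
  # assumed constructor: the caller supplies the config->architecture
  # converter and the NAS-Bench-102 lookup table
  def __init__(self, *args, convert_func=None, nas_bench=None, **kwargs):
    super(MyWorker, self).__init__(*args, **kwargs)
    self.convert_func = convert_func
    self.nas_bench    = nas_bench
    self.test_time    = 0

  def compute(self, config, budget, **kwargs):
    structure = self.convert_func( config )
    reward, time_cost = train_and_eval(structure, self.nas_bench, None)
    self.test_time += 1
    return {'loss': float(100 - reward),   # BOHB minimizes, so invert the accuracy
            'info': time_cost}             # surface the simulated cost to the master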
@@ -82,6 +82,16 @@ def valid_func(xloader, network, criterion):
   return arch_losses.avg, arch_top1.avg, arch_top5.avg
 
 
+def search_find_best(valid_loader, network, criterion, select_num):
+  best_arch, best_acc = None, -1
+  for iarch in range(select_num):
+    arch = network.module.random_genotype( True )
+    valid_a_loss, valid_a_top1, valid_a_top5  = valid_func(valid_loader, network, criterion)
+    if best_arch is None or best_acc < valid_a_top1:
+      best_arch, best_acc = arch, valid_a_top1
+  return best_arch
+
+
 def main(xargs):
   assert torch.cuda.is_available(), 'CUDA is not available.'
   torch.backends.cudnn.enabled   = True
@@ -143,6 +153,7 @@ def main(xargs):
     last_info   = torch.load(last_info)
     start_epoch = last_info['epoch']
     checkpoint  = torch.load(last_info['last_checkpoint'])
+    genotypes   = checkpoint['genotypes']
     valid_accuracies = checkpoint['valid_accuracies']
     search_model.load_state_dict( checkpoint['search_model'] )
     w_scheduler.load_state_dict ( checkpoint['w_scheduler'] )
@@ -150,7 +161,7 @@ def main(xargs):
     logger.log("=> loading checkpoint of the last-info '{:}' start with {:}-th epoch.".format(last_info, start_epoch))
   else:
     logger.log("=> do not find the last-info file : {:}".format(last_info))
-    start_epoch, valid_accuracies = 0, {'best': -1}
+    start_epoch, valid_accuracies, genotypes = 0, {'best': -1}, {}
 
   # start training
   start_time, search_time, epoch_time, total_epoch = time.time(), AverageMeter(), AverageMeter(), config.epochs + config.warmup
@@ -160,11 +171,14 @@ def main(xargs):
     epoch_str = '{:03d}-{:03d}'.format(epoch, total_epoch)
     logger.log('\n[Search the {:}-th epoch] {:}, LR={:}'.format(epoch_str, need_time, min(w_scheduler.get_lr())))
 
+    # selected_arch = search_find_best(valid_loader, network, criterion, xargs.select_num)
     search_w_loss, search_w_top1, search_w_top5 = search_func(search_loader, network, criterion, w_scheduler, w_optimizer, epoch_str, xargs.print_freq, logger)
     search_time.update(time.time() - start_time)
     logger.log('[{:}] searching : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%, time-cost={:.1f} s'.format(epoch_str, search_w_loss, search_w_top1, search_w_top5, search_time.sum))
     valid_a_loss , valid_a_top1 , valid_a_top5  = valid_func(valid_loader, network, criterion)
     logger.log('[{:}] evaluate  : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'.format(epoch_str, valid_a_loss, valid_a_top1, valid_a_top5))
+    cur_arch = search_find_best(valid_loader, network, criterion, xargs.select_num)
+    genotypes[epoch] = cur_arch
     # check the best accuracy
     valid_accuracies[epoch] = valid_a_top1
     if valid_a_top1 > valid_accuracies['best']:
@@ -178,6 +192,7 @@ def main(xargs):
                 'search_model': search_model.state_dict(),
                 'w_optimizer' : w_optimizer.state_dict(),
                 'w_scheduler' : w_scheduler.state_dict(),
+                'genotypes'   : genotypes,
                 'valid_accuracies' : valid_accuracies},
                 model_base_path, logger)
     last_info = save_checkpoint({
@@ -188,6 +203,7 @@ def main(xargs):
     if find_best:
       logger.log('<<<--->>> The {:}-th epoch : find the highest validation accuracy : {:.2f}%.'.format(epoch_str, valid_a_top1))
       copy_checkpoint(model_base_path, model_best_path, logger)
+    if api is not None: logger.log('{:}'.format(api.query_by_arch( genotypes[epoch] )))
     # measure elapsed time
     epoch_time.update(time.time() - start_time)
     start_time = time.time()
@@ -202,7 +218,6 @@ def main(xargs):
     logger.log('final evaluation [{:02d}/{:02d}] : {:} : accuracy={:.2f}%, loss={:.3f}'.format(iarch, xargs.select_num, arch, valid_a_top1, valid_a_loss))
     if best_arch is None or best_acc < valid_a_top1:
       best_arch, best_acc = arch, valid_a_top1
-
   search_time.update(time.time() - start_time)
   logger.log('RANDOM-NAS finds the best one : {:} with accuracy={:.2f}%, with {:.1f} s.'.format(best_arch, best_acc, search_time.sum))
   if api is not None: logger.log('{:}'.format( api.query_by_arch(best_arch) ))
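
One detail worth noting in the RANDOM-NAS changes: the new genotypes history must be defined on every path through the resume logic, which is why the commit touches both the checkpoint-loading branch and the fresh-start branch. A condensed sketch of that round-trip, using plain torch.save in place of the repo's save_checkpoint helper and assuming last_info is a pathlib.Path as the surrounding code suggests:

import torch

# saving (inside the epoch loop): the per-epoch genotypes travel in the
# same dict as the weights, so a resumed run keeps appending to them
torch.save({'epoch'           : epoch + 1,
            'search_model'    : search_model.state_dict(),
            'genotypes'       : genotypes,
            'valid_accuracies': valid_accuracies}, model_base_path)

# resuming (before the loop): both branches leave `genotypes` defined,
# otherwise a fresh run would hit NameError at `genotypes[epoch] = cur_arch`
if last_info.exists():
  info       = torch.load(last_info)
  checkpoint = torch.load(info['last_checkpoint'])
  genotypes  = checkpoint['genotypes']
else:
  genotypes  = {}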
@@ -17,7 +17,7 @@ from datasets     import get_datasets, SearchDataset
 from procedures   import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint, get_optim_scheduler
 from utils        import get_model_infos, obtain_accuracy
 from log_utils    import AverageMeter, time_string, convert_secs2time
-from nas_102_api  import NASBench102API
+from nas_102_api  import NASBench102API as API
 from models       import CellStructure, get_search_spaces
 from R_EA import train_and_eval
 
@@ -132,10 +132,18 @@ def main(xargs, nas_bench):
 
   # REINFORCE
   # attempts = 0
-  for istep in range(xargs.RL_steps):
+  logger.log('Will start searching with a time budget of {:} s.'.format(xargs.time_budget))
+  total_steps, total_costs = 0, 0
+  #for istep in range(xargs.RL_steps):
+  while total_costs < xargs.time_budget:
+    start_time = time.time()
     log_prob, action = select_action( policy )
     arch   = policy.generate_arch( action )
-    reward = train_and_eval(arch, nas_bench, extra_info)
+    reward, cost_time = train_and_eval(arch, nas_bench, extra_info)
+    # accumulate time
+    if total_costs + cost_time < xargs.time_budget:
+      total_costs += cost_time
+    else: break
 
     baseline.update(reward)
     # calculate loss
@@ -143,13 +151,15 @@ def main(xargs, nas_bench):
     optimizer.zero_grad()
     policy_loss.backward()
     optimizer.step()
-    logger.log('step [{:3d}/{:3d}] : average-reward={:.3f} : policy_loss={:.4f} : {:}'.format(istep, xargs.RL_steps, baseline.value(), policy_loss.item(), policy.genotype()))
+    # accumulate time
+    total_costs += time.time() - start_time
+    total_steps += 1
+    logger.log('step [{:3d}] : average-reward={:.3f} : policy_loss={:.4f} : {:}'.format(total_steps, baseline.value(), policy_loss.item(), policy.genotype()))
     #logger.log('----> {:}'.format(policy.arch_parameters))
-    logger.log('')
+    #logger.log('')
 
   best_arch = policy.genotype()
+  logger.log('REINFORCE finished with {:} steps and {:.1f} s.'.format(total_steps, total_costs))
   info = nas_bench.query_by_arch( best_arch )
   if info is None: logger.log('Did not find this architecture : {:}.'.format(best_arch))
   else           : logger.log('{:}'.format(info))
@@ -169,8 +179,9 @@ if __name__ == '__main__':
   parser.add_argument('--channel',            type=int,   help='The number of channels.')
   parser.add_argument('--num_cells',          type=int,   help='The number of cells in one stage.')
   parser.add_argument('--learning_rate',      type=float, help='The learning rate for REINFORCE.')
-  parser.add_argument('--RL_steps',           type=int,   help='The steps for REINFORCE.')
+  #parser.add_argument('--RL_steps',           type=int,   help='The steps for REINFORCE.')
   parser.add_argument('--EMA_momentum',       type=float, help='The momentum value for EMA.')
+  parser.add_argument('--time_budget',        type=int,   help='The total time cost budget for searching (in seconds).')
   # log
   parser.add_argument('--workers',            type=int,   default=2,    help='number of data loading workers (default: 2)')
   parser.add_argument('--save_dir',           type=str,   help='Folder to save checkpoints and log.')
@@ -183,7 +194,7 @@ if __name__ == '__main__':
     nas_bench = None
   else:
     print ('{:} build NAS-Benchmark-API from {:}'.format(time_string(), args.arch_nas_dataset))
-    nas_bench = AANASBenchAPI(args.arch_nas_dataset)
+    nas_bench = API(args.arch_nas_dataset)
   if args.rand_seed < 0:
     save_dir, all_indexes, num = None, [], 500
     for i in range(num):
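
The control-flow change is the heart of this commit: REINFORCE stops on an accumulated time budget instead of a fixed RL_steps count, charging both the benchmark's simulated training time and the real time spent on the policy update against the same budget. A condensed view of the new loop follows; the policy-gradient update between train_and_eval and the logging is filled in with the usual REINFORCE-with-baseline form, which the diff context does not show, so treat that line as an assumption.

import time

total_steps, total_costs = 0, 0
while total_costs < xargs.time_budget:
  start_time = time.time()
  log_prob, action = select_action( policy )
  arch = policy.generate_arch( action )
  reward, cost_time = train_and_eval(arch, nas_bench, extra_info)
  if total_costs + cost_time >= xargs.time_budget:
    break                                   # the next simulated training would overrun the budget
  total_costs += cost_time                  # simulated (benchmark-reported) training time
  baseline.update(reward)
  policy_loss = -log_prob * (reward - baseline.value())   # assumed REINFORCE-with-baseline form
  optimizer.zero_grad()
  policy_loss.backward()
  optimizer.step()
  total_costs += time.time() - start_time   # real time spent on the update itself
  total_steps += 1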
@@ -34,5 +34,6 @@ OMP_NUM_THREADS=4 python ./exps/algos/BOHB.py \
 	--dataset ${dataset} --data_path ${data_path} \
 	--search_space_name ${space} \
 	--arch_nas_dataset ${TORCH_HOME}/NAS-Bench-102-v1_0-e61699.pth \
-	--n_iters 6 --num_samples 3 \
+	--time_budget 12000 \
+	--n_iters 100 --num_samples 4 --random_fraction 0 \
 	--workers 4 --print_freq 200 --rand_seed ${seed}
@@ -34,5 +34,6 @@ OMP_NUM_THREADS=4 python ./exps/algos/reinforce.py \
 	--dataset ${dataset} --data_path ${data_path} \
 	--search_space_name ${space} \
 	--arch_nas_dataset ${TORCH_HOME}/NAS-Bench-102-v1_0-e61699.pth \
-	--learning_rate 0.001 --RL_steps 100 --EMA_momentum 0.9 \
+	--time_budget 12000 \
+	--learning_rate 0.001 --EMA_momentum 0.9 \
 	--workers 4 --print_freq 200 --rand_seed ${seed}
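
Both launch scripts now pass the same --time_budget 12000, so the BOHB and REINFORCE baselines are configured with a matching 12000-second search budget, while the per-algorithm knobs (n_iters, num_samples, and random_fraction for BOHB; learning_rate and EMA_momentum for REINFORCE) are adjusted around it.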