fix bugs in RANDOM-NAS and BOHB
This commit is contained in:
		| @@ -53,43 +53,50 @@ def config2structure_func(max_nodes): | ||||
|  | ||||
| class MyWorker(Worker): | ||||
|  | ||||
|   def __init__(self, *args, convert_func=None, nas_bench=None, time_scale=None, **kwargs): | ||||
|   def __init__(self, *args, convert_func=None, nas_bench=None, time_budget=None, **kwargs): | ||||
|     super().__init__(*args, **kwargs) | ||||
|     self.convert_func   = convert_func | ||||
|     self.nas_bench      = nas_bench | ||||
|     self.time_scale     = time_scale | ||||
|     self.seen_arch      = 0 | ||||
|     self.time_budget    = time_budget | ||||
|     self.seen_archs     = [] | ||||
|     self.sim_cost_time  = 0 | ||||
|     self.real_cost_time = 0 | ||||
|     self.is_end         = False | ||||
|  | ||||
|   def get_the_best(self): | ||||
|     assert len(self.seen_archs) > 0 | ||||
|     best_index, best_acc = -1, None | ||||
|     for arch_index in self.seen_archs: | ||||
|       info = self.nas_bench.get_more_info(arch_index, 'cifar10-valid', None, True) | ||||
|       vacc = info['valid-accuracy'] | ||||
|       if best_acc is None or best_acc < vacc: | ||||
|         best_acc = vacc | ||||
|         best_index = arch_index | ||||
|     assert best_index != -1 | ||||
|     return best_index | ||||
|  | ||||
|   def compute(self, config, budget, **kwargs): | ||||
|     start_time = time.time() | ||||
|     structure  = self.convert_func( config ) | ||||
|     arch_index = self.nas_bench.query_index_by_arch( structure ) | ||||
|     iepoch     = 0 | ||||
|     while iepoch < 12: | ||||
|       info     = self.nas_bench.get_more_info(arch_index, 'cifar10-valid', iepoch, True) | ||||
|       cur_time = info['train-all-time'] + info['valid-per-time'] | ||||
|       cur_vacc = info['valid-accuracy'] | ||||
|       if time.time() - start_time + cur_time / self.time_scale > budget: | ||||
|         break | ||||
|       else: | ||||
|         iepoch += 1 | ||||
|     self.sim_cost_time += cur_time | ||||
|     self.seen_arch += 1 | ||||
|     remaining_time = cur_time / self.time_scale - (time.time() - start_time) | ||||
|     if remaining_time > 0: | ||||
|       time.sleep(remaining_time) | ||||
|     else: | ||||
|       import pdb; pdb.set_trace() | ||||
|     info       = self.nas_bench.get_more_info(arch_index, 'cifar10-valid', None, True) | ||||
|     cur_time   = info['train-all-time'] + info['valid-per-time'] | ||||
|     cur_vacc   = info['valid-accuracy'] | ||||
|     self.real_cost_time += (time.time() - start_time) | ||||
|     return ({ | ||||
|             'loss': 100 - float(cur_vacc), | ||||
|             'info': {'seen-arch'     : self.seen_arch, | ||||
|                      'sim-test-time' : self.sim_cost_time, | ||||
|                      'real-test-time': self.real_cost_time, | ||||
|                      'current-arch'  : arch_index, | ||||
|                      'current-budget': budget} | ||||
|     if self.sim_cost_time + cur_time <= self.time_budget and not self.is_end: | ||||
|       self.sim_cost_time += cur_time | ||||
|       self.seen_archs.append( arch_index ) | ||||
|       return ({'loss': 100 - float(cur_vacc), | ||||
|                'info': {'seen-arch'     : len(self.seen_archs), | ||||
|                         'sim-test-time' : self.sim_cost_time, | ||||
|                         'current-arch'  : arch_index} | ||||
|             }) | ||||
|     else: | ||||
|       self.is_end = True | ||||
|       return ({'loss': 100, | ||||
|                'info': {'seen-arch'     : len(self.seen_archs), | ||||
|                         'sim-test-time' : self.sim_cost_time, | ||||
|                         'current-arch'  : None} | ||||
|             }) | ||||
|  | ||||
|  | ||||
| @@ -139,16 +146,14 @@ def main(xargs, nas_bench): | ||||
|   #logger.log('{:} Create NAS-BENCH-API DONE'.format(time_string())) | ||||
|   workers = [] | ||||
|   for i in range(num_workers): | ||||
|     w = MyWorker(nameserver=ns_host, nameserver_port=ns_port, convert_func=config2structure, nas_bench=nas_bench, time_scale=xargs.time_scale, run_id=hb_run_id, id=i) | ||||
|     w = MyWorker(nameserver=ns_host, nameserver_port=ns_port, convert_func=config2structure, nas_bench=nas_bench, time_budget=xargs.time_budget, run_id=hb_run_id, id=i) | ||||
|     w.run(background=True) | ||||
|     workers.append(w) | ||||
|  | ||||
|   simulate_time_budge = xargs.time_budget // xargs.time_scale | ||||
|   start_time = time.time() | ||||
|   logger.log('simulate_time_budge : {:} (in seconds).'.format(simulate_time_budge)) | ||||
|   bohb = BOHB(configspace=cs, | ||||
|             run_id=hb_run_id, | ||||
|             eta=3, min_budget=simulate_time_budge//3, max_budget=simulate_time_budge, | ||||
|             eta=3, min_budget=12, max_budget=200, | ||||
|             nameserver=ns_host, | ||||
|             nameserver_port=ns_port, | ||||
|             num_samples=xargs.num_samples, | ||||
| @@ -161,11 +166,9 @@ def main(xargs, nas_bench): | ||||
|   NS.shutdown() | ||||
|  | ||||
|   real_cost_time = time.time() - start_time | ||||
|   import pdb; pdb.set_trace() | ||||
|  | ||||
|   id2config = results.get_id2config_mapping() | ||||
|   incumbent = results.get_incumbent_id() | ||||
|  | ||||
|   logger.log('Best found configuration: {:}'.format(id2config[incumbent]['config'])) | ||||
|   best_arch = config2structure( id2config[incumbent]['config'] ) | ||||
|  | ||||
| @@ -174,7 +177,7 @@ def main(xargs, nas_bench): | ||||
|   else           : logger.log('{:}'.format(info)) | ||||
|   logger.log('-'*100) | ||||
|  | ||||
|   logger.log('workers : {:}'.format(workers[0].test_time)) | ||||
|   logger.log('workers : {:.1f}s with {:} archs'.format(workers[0].time_budget, len(workers[0].seen_archs))) | ||||
|   logger.close() | ||||
|   return logger.log_dir, nas_bench.query_index_by_arch( best_arch ) | ||||
|    | ||||
| @@ -190,14 +193,13 @@ if __name__ == '__main__': | ||||
|   parser.add_argument('--channel',            type=int,   help='The number of channels.') | ||||
|   parser.add_argument('--num_cells',          type=int,   help='The number of cells in one stage.') | ||||
|   parser.add_argument('--time_budget',        type=int,   help='The total time cost budge for searching (in seconds).') | ||||
|   parser.add_argument('--time_scale' ,        type=int,   help='The time scale to accelerate the time budget.') | ||||
|   # BOHB | ||||
|   parser.add_argument('--strategy', default="sampling", type=str, nargs='?', help='optimization strategy for the acquisition function') | ||||
|   parser.add_argument('--min_bandwidth',    default=.3, type=float, nargs='?', help='minimum bandwidth for KDE') | ||||
|   parser.add_argument('--num_samples',      default=64, type=int, nargs='?', help='number of samples for the acquisition function') | ||||
|   parser.add_argument('--strategy', default="sampling",  type=str, nargs='?', help='optimization strategy for the acquisition function') | ||||
|   parser.add_argument('--min_bandwidth',    default=.3,  type=float, nargs='?', help='minimum bandwidth for KDE') | ||||
|   parser.add_argument('--num_samples',      default=64,  type=int, nargs='?', help='number of samples for the acquisition function') | ||||
|   parser.add_argument('--random_fraction',  default=.33, type=float, nargs='?', help='fraction of random configurations') | ||||
|   parser.add_argument('--bandwidth_factor', default=3, type=int, nargs='?', help='factor multiplied to the bandwidth') | ||||
|   parser.add_argument('--n_iters', default=100, type=int, nargs='?', help='number of iterations for optimization method') | ||||
|   parser.add_argument('--bandwidth_factor', default=3,   type=int, nargs='?', help='factor multiplied to the bandwidth') | ||||
|   parser.add_argument('--n_iters',          default=100, type=int, nargs='?', help='number of iterations for optimization method') | ||||
|   # log | ||||
|   parser.add_argument('--workers',            type=int,   default=2,    help='number of data loading workers (default: 2)') | ||||
|   parser.add_argument('--save_dir',           type=str,   help='Folder to save checkpoints and log.') | ||||
|   | ||||
| @@ -82,14 +82,29 @@ def valid_func(xloader, network, criterion): | ||||
|   return arch_losses.avg, arch_top1.avg, arch_top5.avg | ||||
|  | ||||
|  | ||||
| def search_find_best(valid_loader, network, criterion, select_num): | ||||
|   best_arch, best_acc = None, -1 | ||||
|   for iarch in range(select_num): | ||||
|     arch = network.module.random_genotype( True ) | ||||
|     valid_a_loss, valid_a_top1, valid_a_top5  = valid_func(valid_loader, network, criterion) | ||||
|     if best_arch is None or best_acc < valid_a_top1: | ||||
|       best_arch, best_acc = arch, valid_a_top1 | ||||
|   return best_arch | ||||
| def search_find_best(xloader, network, n_samples): | ||||
|   with torch.no_grad(): | ||||
|     network.eval() | ||||
|     archs, valid_accs = [], [] | ||||
|     #print ('obtain the top-{:} architectures'.format(n_samples)) | ||||
|     loader_iter = iter(xloader) | ||||
|     for i in range(n_samples): | ||||
|       arch = network.module.random_genotype( True ) | ||||
|       try: | ||||
|         inputs, targets = next(loader_iter) | ||||
|       except: | ||||
|         loader_iter = iter(xloader) | ||||
|         inputs, targets = next(loader_iter) | ||||
|  | ||||
|       _, logits = network(inputs) | ||||
|       val_top1, val_top5 = obtain_accuracy(logits.cpu().data, targets.data, topk=(1, 5)) | ||||
|  | ||||
|       archs.append( arch ) | ||||
|       valid_accs.append( val_top1.item() ) | ||||
|  | ||||
|     best_idx = np.argmax(valid_accs) | ||||
|     best_arch, best_valid_acc = archs[best_idx], valid_accs[best_idx] | ||||
|     return best_arch, best_valid_acc | ||||
|  | ||||
|  | ||||
| def main(xargs): | ||||
| @@ -127,7 +142,7 @@ def main(xargs): | ||||
|   search_data   = SearchDataset(xargs.dataset, train_data, train_split, valid_split) | ||||
|   # data loader | ||||
|   search_loader = torch.utils.data.DataLoader(search_data, batch_size=config.batch_size, shuffle=True , num_workers=xargs.workers, pin_memory=True) | ||||
|   valid_loader  = torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(valid_split), num_workers=xargs.workers, pin_memory=True) | ||||
|   valid_loader  = torch.utils.data.DataLoader(valid_data, batch_size=config.test_batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(valid_split), num_workers=xargs.workers, pin_memory=True) | ||||
|   logger.log('||||||| {:10s} ||||||| Search-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'.format(xargs.dataset, len(search_loader), len(valid_loader), config.batch_size)) | ||||
|   logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config)) | ||||
|  | ||||
| @@ -177,7 +192,8 @@ def main(xargs): | ||||
|     logger.log('[{:}] searching : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%, time-cost={:.1f} s'.format(epoch_str, search_w_loss, search_w_top1, search_w_top5, search_time.sum)) | ||||
|     valid_a_loss , valid_a_top1 , valid_a_top5  = valid_func(valid_loader, network, criterion) | ||||
|     logger.log('[{:}] evaluate  : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'.format(epoch_str, valid_a_loss, valid_a_top1, valid_a_top5)) | ||||
|     cur_arch = search_find_best(valid_loader, network, criterion, xargs.select_num) | ||||
|     cur_arch, cur_valid_acc = search_find_best(valid_loader, network, xargs.select_num) | ||||
|     logger.log('[{:}] find-the-best : {:}, accuracy@1={:.2f}%'.format(epoch_str, cur_arch, cur_valid_acc)) | ||||
|     genotypes[epoch] = cur_arch | ||||
|     # check the best accuracy | ||||
|     valid_accuracies[epoch] = valid_a_top1 | ||||
| @@ -211,13 +227,7 @@ def main(xargs): | ||||
|   logger.log('\n' + '-'*200) | ||||
|   logger.log('Pre-searching costs {:.1f} s'.format(search_time.sum)) | ||||
|   start_time = time.time() | ||||
|   best_arch, best_acc = None, -1 | ||||
|   for iarch in range(xargs.select_num): | ||||
|     arch = search_model.random_genotype( True ) | ||||
|     valid_a_loss, valid_a_top1, valid_a_top5  = valid_func(valid_loader, network, criterion) | ||||
|     logger.log('final evaluation [{:02d}/{:02d}] : {:} : accuracy={:.2f}%, loss={:.3f}'.format(iarch, xargs.select_num, arch, valid_a_top1, valid_a_loss)) | ||||
|     if best_arch is None or best_acc < valid_a_top1: | ||||
|       best_arch, best_acc = arch, valid_a_top1 | ||||
|   best_arch, best_acc = search_find_best(valid_loader, network, xargs.select_num) | ||||
|   search_time.update(time.time() - start_time) | ||||
|   logger.log('RANDOM-NAS finds the best one : {:} with accuracy={:.2f}%, with {:.1f} s.'.format(best_arch, best_acc, search_time.sum)) | ||||
|   if api is not None: logger.log('{:}'.format( api.query_by_arch(best_arch) )) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user