rm PD ; update NAS-Bench-102 baselines
This commit is contained in:
		| @@ -213,7 +213,7 @@ def train_single_model(save_dir, workers, datasets, xpaths, splits, use_less, se | ||||
|  | ||||
|  | ||||
| def generate_meta_info(save_dir, max_node, divide=40): | ||||
|   aa_nas_bench_ss = get_search_spaces('cell', 'aa-nas') | ||||
|   aa_nas_bench_ss = get_search_spaces('cell', 'nas-bench-102') | ||||
|   archs = CellStructure.gen_all(aa_nas_bench_ss, max_node, False) | ||||
|   print ('There are {:} archs vs {:}.'.format(len(archs), len(aa_nas_bench_ss) ** ((max_node-1)*max_node/2))) | ||||
|  | ||||
|   | ||||
| @@ -17,6 +17,7 @@ from procedures   import prepare_seed, prepare_logger, save_checkpoint, copy_che | ||||
| from utils        import get_model_infos, obtain_accuracy | ||||
| from log_utils    import AverageMeter, time_string, convert_secs2time | ||||
| from models       import get_cell_based_tiny_net, get_search_spaces | ||||
| from nas_102_api  import NASBench102API as API | ||||
|  | ||||
|  | ||||
| def search_func(xloader, network, criterion, scheduler, w_optimizer, a_optimizer, epoch_str, print_freq, logger): | ||||
| @@ -144,6 +145,11 @@ def main(xargs): | ||||
|   flop, param  = get_model_infos(search_model, xshape) | ||||
|   #logger.log('{:}'.format(search_model)) | ||||
|   logger.log('FLOP = {:.2f} M, Params = {:.2f} MB'.format(flop, param)) | ||||
|   if xargs.arch_nas_dataset is None: | ||||
|     api = None | ||||
|   else: | ||||
|     api = API(xargs.arch_nas_dataset) | ||||
|   logger.log('{:} create API = {:} done'.format(time_string(), api)) | ||||
|  | ||||
|   last_info, model_base_path, model_best_path = logger.path('info'), logger.path('model'), logger.path('best') | ||||
|   network, criterion = torch.nn.DataParallel(search_model).cuda(), criterion.cuda() | ||||
| @@ -165,7 +171,7 @@ def main(xargs): | ||||
|     start_epoch, valid_accuracies, genotypes = 0, {'best': -1}, {} | ||||
|  | ||||
|   # start training | ||||
|   start_time, epoch_time, total_epoch = time.time(), AverageMeter(), config.epochs + config.warmup | ||||
|   start_time, search_time, epoch_time, total_epoch = time.time(), AverageMeter(), AverageMeter(), config.epochs + config.warmup | ||||
|   for epoch in range(start_epoch, total_epoch): | ||||
|     w_scheduler.update(epoch, 0.0) | ||||
|     need_time = 'Time Left: {:}'.format( convert_secs2time(epoch_time.val * (total_epoch-epoch), True) ) | ||||
| @@ -173,7 +179,8 @@ def main(xargs): | ||||
|     logger.log('\n[Search the {:}-th epoch] {:}, LR={:}'.format(epoch_str, need_time, min(w_scheduler.get_lr()))) | ||||
|  | ||||
|     search_w_loss, search_w_top1, search_w_top5 = search_func(search_loader, network, criterion, w_scheduler, w_optimizer, a_optimizer, epoch_str, xargs.print_freq, logger) | ||||
|     logger.log('[{:}] searching : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'.format(epoch_str, search_w_loss, search_w_top1, search_w_top5)) | ||||
|     search_time.update(time.time() - start_time) | ||||
|     logger.log('[{:}] searching : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%, time-cost={:.1f} s'.format(epoch_str, search_w_loss, search_w_top1, search_w_top5, search_time.sum)) | ||||
|     valid_a_loss , valid_a_top1 , valid_a_top5  = valid_func(valid_loader, network, criterion) | ||||
|     logger.log('[{:}] evaluate  : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'.format(epoch_str, valid_a_loss, valid_a_top1, valid_a_top5)) | ||||
|     # check the best accuracy | ||||
| @@ -204,6 +211,8 @@ def main(xargs): | ||||
|     if find_best: | ||||
|       logger.log('<<<--->>> The {:}-th epoch : find the highest validation accuracy : {:.2f}%.'.format(epoch_str, valid_a_top1)) | ||||
|       copy_checkpoint(model_base_path, model_best_path, logger) | ||||
|     if api is not None: | ||||
|       logger.log('{:}'.format(api.query_by_arch( genotypes[epoch] ))) | ||||
|     with torch.no_grad(): | ||||
|       logger.log('arch-parameters :\n{:}'.format( nn.functional.softmax(search_model.arch_parameters, dim=-1).cpu() )) | ||||
|     # measure elapsed time | ||||
| @@ -211,22 +220,8 @@ def main(xargs): | ||||
|     start_time = time.time() | ||||
|  | ||||
|   logger.log('\n' + '-'*100) | ||||
|   # check the performance from the architecture dataset | ||||
|   #if xargs.arch_nas_dataset is None or not os.path.isfile(xargs.arch_nas_dataset): | ||||
|   #  logger.log('Can not find the architecture dataset : {:}.'.format(xargs.arch_nas_dataset)) | ||||
|   #else: | ||||
|   #  nas_bench = NASBenchmarkAPI(xargs.arch_nas_dataset) | ||||
|   #  geno = genotypes[total_epoch-1] | ||||
|   #  logger.log('The last model is {:}'.format(geno)) | ||||
|   #  info = nas_bench.query_by_arch( geno ) | ||||
|   #  if info is None: logger.log('Did not find this architecture : {:}.'.format(geno)) | ||||
|   #  else           : logger.log('{:}'.format(info)) | ||||
|   #  logger.log('-'*100) | ||||
|   #  geno = genotypes['best'] | ||||
|   #  logger.log('The best model is {:}'.format(geno)) | ||||
|   #  info = nas_bench.query_by_arch( geno ) | ||||
|   #  if info is None: logger.log('Did not find this architecture : {:}.'.format(geno)) | ||||
|   #  else           : logger.log('{:}'.format(info)) | ||||
|   logger.log('DARTS-V1 : run {:} epochs, cost {:.1f} s, last-geno is {:}.'.format(total_epoch, search_time.sum, genotypes[total_epoch-1])) | ||||
|   if api is not None: logger.log('{:}'.format( api.query_by_arch(genotypes[total_epoch-1]) )) | ||||
|   logger.close() | ||||
|    | ||||
|  | ||||
|   | ||||
| @@ -59,9 +59,9 @@ def train_and_eval(arch, nas_bench, extra_info): | ||||
|   if nas_bench is not None: | ||||
|     arch_index = nas_bench.query_index_by_arch( arch ) | ||||
|     assert arch_index >= 0, 'can not find this arch : {:}'.format(arch) | ||||
|     info = nas_bench.arch2infos[ arch_index ] | ||||
|     _, valid_acc = info.get_metrics('cifar10-valid', 'x-valid' , 25, True) # use the validation accuracy after 25 training epochs | ||||
|     #import pdb; pdb.set_trace() | ||||
|     info = nas_bench.get_more_info(arch_index, 'cifar10-valid', True) | ||||
|     import pdb; pdb.set_trace() | ||||
|     #_, valid_acc = info.get_metrics('cifar10-valid', 'x-valid' , 25, True) # use the validation accuracy after 25 training epochs | ||||
|   else: | ||||
|     # train a model from scratch. | ||||
|     raise ValueError('NOT IMPLEMENT YET') | ||||
|   | ||||
		Reference in New Issue
	
	Block a user