102->201 / NAS->autoDL / more configs of TAS / reorganize docs / fix bugs in NAS baselines
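The 102->201 rename touches every import of the benchmark API (nas_102_api -> nas_201_api, NASBench102API -> NASBench201API) as well as output paths, helper scripts, and setup.py. A minimal usage sketch of the renamed API, assuming the benchmark file has been downloaded to the example path used elsewhere in this diff, and using only calls that appear in the changes below:

    from nas_201_api import NASBench201API as API

    api = API('/home/dxy/.torch/NAS-Bench-201-v1_0-e61699.pth')  # example path, as in the tests below
    index = api.random()                                          # pick a random architecture index
    info = api.arch2infos_full[index]                             # results from full training of that architecture
    metrics = info.get_metrics('cifar10-valid', 'x-valid', None, False)
    print('accuracy = {:}'.format(metrics['accuracy']))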
		| @@ -1,7 +1,7 @@ | ||||
| ################################################## | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 # | ||||
| ################################################## | ||||
| # python exps/NAS-Bench-102/check.py --base_save_dir  | ||||
| # python exps/NAS-Bench-201/check.py --base_save_dir  | ||||
| ################################################## | ||||
| import os, sys, time, argparse, collections | ||||
| from shutil import copyfile | ||||
| @@ -67,8 +67,8 @@ def check_files(save_dir, meta_file, basestr): | ||||
| 
 | ||||
| if __name__ == '__main__': | ||||
| 
 | ||||
|   parser = argparse.ArgumentParser(description='NAS Benchmark 102', formatter_class=argparse.ArgumentDefaultsHelpFormatter) | ||||
|   parser.add_argument('--base_save_dir',  type=str, default='./output/NAS-BENCH-102-4',     help='The base-name of folder to save checkpoints and log.') | ||||
|   parser = argparse.ArgumentParser(description='NAS Benchmark 201', formatter_class=argparse.ArgumentDefaultsHelpFormatter) | ||||
|   parser.add_argument('--base_save_dir',  type=str, default='./output/NAS-BENCH-201-4',     help='The base-name of folder to save checkpoints and log.') | ||||
|   parser.add_argument('--max_node',       type=int, default=4,                                 help='The maximum node in a cell.') | ||||
|   parser.add_argument('--channel',        type=int, default=16,                                help='The number of channels.') | ||||
|   parser.add_argument('--num_cells',      type=int, default=5,                                 help='The number of cells in one stage.') | ||||
| @@ -78,7 +78,7 @@ if __name__ == '__main__': | ||||
|   meta_path = save_dir / 'meta-node-{:}.pth'.format(args.max_node) | ||||
|   assert save_dir.exists(),  'invalid save dir path : {:}'.format(save_dir) | ||||
|   assert meta_path.exists(), 'invalid saved meta path : {:}'.format(meta_path) | ||||
|   print ('check NAS-Bench-102 in {:}'.format(save_dir)) | ||||
|   print ('check NAS-Bench-201 in {:}'.format(save_dir)) | ||||
| 
 | ||||
|   basestr = 'C{:}-N{:}'.format(args.channel, args.num_cells) | ||||
|   check_files(save_dir, meta_path, basestr) | ||||
| @@ -8,15 +8,15 @@ def read(fname='README.md'): | ||||
| 
 | ||||
| 
 | ||||
| setup( | ||||
|     name = "nas_bench_102", | ||||
|     name = "nas_bench_201", | ||||
|     version = "1.0", | ||||
|     author = "Xuanyi Dong", | ||||
|     author_email = "dongxuanyi888@gmail.com", | ||||
|     description = "API for NAS-Bench-102 (a benchmark for neural architecture search).", | ||||
|     description = "API for NAS-Bench-201 (a benchmark for neural architecture search).", | ||||
|     license = "MIT", | ||||
|     keywords = "NAS Dataset API DeepLearning", | ||||
|     url = "https://github.com/D-X-Y/NAS-Projects", | ||||
|     packages=['nas_102_api'], | ||||
|     url = "https://github.com/D-X-Y/NAS-Bench-201", | ||||
|     packages=['nas_201_api'], | ||||
|     long_description=read('README.md'), | ||||
|     long_description_content_type='text/markdown', | ||||
|     classifiers=[ | ||||
| @@ -1,5 +1,5 @@ | ||||
| ############################################################### | ||||
| # NAS-Bench-102, ICLR 2020 (https://arxiv.org/abs/2001.00326) # | ||||
| # NAS-Bench-201, ICLR 2020 (https://arxiv.org/abs/2001.00326) # | ||||
| ############################################################### | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019-2020         # | ||||
| ############################################################### | ||||
| @@ -213,7 +213,7 @@ def train_single_model(save_dir, workers, datasets, xpaths, splits, use_less, se | ||||
| 
 | ||||
| 
 | ||||
| def generate_meta_info(save_dir, max_node, divide=40): | ||||
|   aa_nas_bench_ss = get_search_spaces('cell', 'nas-bench-102') | ||||
|   aa_nas_bench_ss = get_search_spaces('cell', 'nas-bench-201') | ||||
|   archs = CellStructure.gen_all(aa_nas_bench_ss, max_node, False) | ||||
|   print ('There are {:} archs vs {:}.'.format(len(archs), len(aa_nas_bench_ss) ** ((max_node-1)*max_node/2))) | ||||
| 
 | ||||
| @@ -249,15 +249,15 @@ def generate_meta_info(save_dir, max_node, divide=40): | ||||
|   torch.save(info, save_name) | ||||
|   print ('save the meta file into {:}'.format(save_name)) | ||||
| 
 | ||||
|   script_name_full = save_dir / 'BENCH-102-N{:}.opt-full.script'.format(max_node) | ||||
|   script_name_less = save_dir / 'BENCH-102-N{:}.opt-less.script'.format(max_node) | ||||
|   script_name_full = save_dir / 'BENCH-201-N{:}.opt-full.script'.format(max_node) | ||||
|   script_name_less = save_dir / 'BENCH-201-N{:}.opt-less.script'.format(max_node) | ||||
|   full_file = open(str(script_name_full), 'w') | ||||
|   less_file = open(str(script_name_less), 'w') | ||||
|   gaps = total_arch // divide | ||||
|   for start in range(0, total_arch, gaps): | ||||
|     xend = min(start+gaps, total_arch) | ||||
|     full_file.write('bash ./scripts-search/NAS-Bench-102/train-models.sh 0 {:5d} {:5d} -1 \'777 888 999\'\n'.format(start, xend-1)) | ||||
|     less_file.write('bash ./scripts-search/NAS-Bench-102/train-models.sh 1 {:5d} {:5d} -1 \'777 888 999\'\n'.format(start, xend-1)) | ||||
|     full_file.write('bash ./scripts-search/NAS-Bench-201/train-models.sh 0 {:5d} {:5d} -1 \'777 888 999\'\n'.format(start, xend-1)) | ||||
|     less_file.write('bash ./scripts-search/NAS-Bench-201/train-models.sh 1 {:5d} {:5d} -1 \'777 888 999\'\n'.format(start, xend-1)) | ||||
|   print ('save the training script into {:} and {:}'.format(script_name_full, script_name_less)) | ||||
|   full_file.close() | ||||
|   less_file.close() | ||||
| @@ -267,14 +267,14 @@ def generate_meta_info(save_dir, max_node, divide=40): | ||||
|   with open(str(script_name), 'w') as cfile: | ||||
|     for start in range(0, total_arch, gaps): | ||||
|       xend = min(start+gaps, total_arch) | ||||
|       cfile.write('{:} python exps/NAS-Bench-102/statistics.py --mode cal --target_dir {:06d}-{:06d}-C16-N5\n'.format(macro, start, xend-1)) | ||||
|       cfile.write('{:} python exps/NAS-Bench-201/statistics.py --mode cal --target_dir {:06d}-{:06d}-C16-N5\n'.format(macro, start, xend-1)) | ||||
|   print ('save the post-processing script into {:}'.format(script_name)) | ||||
| 
 | ||||
| 
 | ||||
| if __name__ == '__main__': | ||||
|   #mode_choices = ['meta', 'new', 'cover'] + ['specific-{:}'.format(_) for _ in CellArchitectures.keys()] | ||||
|   #parser = argparse.ArgumentParser(description='Algorithm-Agnostic NAS Benchmark', formatter_class=argparse.ArgumentDefaultsHelpFormatter) | ||||
|   parser = argparse.ArgumentParser(description='NAS-Bench-102', formatter_class=argparse.ArgumentDefaultsHelpFormatter) | ||||
|   parser = argparse.ArgumentParser(description='NAS-Bench-201', formatter_class=argparse.ArgumentDefaultsHelpFormatter) | ||||
|   parser.add_argument('--mode'   ,     type=str,   required=True,  help='The script mode.') | ||||
|   parser.add_argument('--save_dir',    type=str,                   help='Folder to save checkpoints and log.') | ||||
|   parser.add_argument('--max_node',    type=int,                   help='The maximum node in a cell.') | ||||
| @@ -12,9 +12,9 @@ if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) | ||||
| from log_utils    import AverageMeter, time_string, convert_secs2time | ||||
| from config_utils import load_config, dict2config | ||||
| from datasets     import get_datasets | ||||
| # NAS-Bench-102 related module or function | ||||
| # NAS-Bench-201 related module or function | ||||
| from models       import CellStructure, get_cell_based_tiny_net | ||||
| from nas_102_api  import ArchResults, ResultsCount | ||||
| from nas_201_api  import ArchResults, ResultsCount | ||||
| from functions    import pure_evaluate | ||||
| 
 | ||||
| 
 | ||||
| @@ -271,9 +271,9 @@ def merge_all(save_dir, meta_file, basestr): | ||||
| 
 | ||||
| if __name__ == '__main__': | ||||
| 
 | ||||
|   parser = argparse.ArgumentParser(description='NAS-BENCH-102', formatter_class=argparse.ArgumentDefaultsHelpFormatter) | ||||
|   parser = argparse.ArgumentParser(description='NAS-BENCH-201', formatter_class=argparse.ArgumentDefaultsHelpFormatter) | ||||
|   parser.add_argument('--mode'         ,  type=str, choices=['cal', 'merge'],            help='The running mode for this script.') | ||||
|   parser.add_argument('--base_save_dir',  type=str, default='./output/NAS-BENCH-102-4',  help='The base-name of folder to save checkpoints and log.') | ||||
|   parser.add_argument('--base_save_dir',  type=str, default='./output/NAS-BENCH-201-4',  help='The base-name of folder to save checkpoints and log.') | ||||
|   parser.add_argument('--target_dir'   ,  type=str,                                      help='The target directory.') | ||||
|   parser.add_argument('--max_node'     ,  type=int, default=4,                           help='The maximum node in a cell.') | ||||
|   parser.add_argument('--channel'      ,  type=int, default=16,                          help='The number of channels.') | ||||
| @@ -1,7 +1,7 @@ | ||||
| ################################################## | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 # | ||||
| ######################################################## | ||||
| # python exps/NAS-Bench-102/test-correlation.py --api_path $HOME/.torch/NAS-Bench-102-v1_0-e61699.pth | ||||
| # python exps/NAS-Bench-201/test-correlation.py --api_path $HOME/.torch/NAS-Bench-201-v1_0-e61699.pth | ||||
| ######################################################## | ||||
| import os, sys, time, glob, random, argparse | ||||
| import numpy as np | ||||
| @@ -18,7 +18,7 @@ from procedures   import prepare_seed, prepare_logger, save_checkpoint, copy_che | ||||
| from utils        import get_model_infos, obtain_accuracy | ||||
| from log_utils    import AverageMeter, time_string, convert_secs2time | ||||
| from models       import get_cell_based_tiny_net, get_search_spaces, CellStructure | ||||
| from nas_102_api  import NASBench102API as API | ||||
| from nas_201_api  import NASBench201API as API | ||||
| 
 | ||||
|    | ||||
| def valid_func(xloader, network, criterion): | ||||
| @@ -197,9 +197,9 @@ def check_cor_for_bandit_v2(meta_file, test_epoch, use_less_or_not, is_rand): | ||||
| 
 | ||||
| 
 | ||||
| if __name__ == '__main__': | ||||
|   parser = argparse.ArgumentParser("Analysis of NAS-Bench-102") | ||||
|   parser.add_argument('--save_dir',  type=str, default='./output/search-cell-nas-bench-102/visuals', help='The base-name of folder to save checkpoints and log.') | ||||
|   parser.add_argument('--api_path',  type=str, default=None,                                         help='The path to the NAS-Bench-102 benchmark file.') | ||||
|   parser = argparse.ArgumentParser("Analysis of NAS-Bench-201") | ||||
|   parser.add_argument('--save_dir',  type=str, default='./output/search-cell-nas-bench-201/visuals', help='The base-name of folder to save checkpoints and log.') | ||||
|   parser.add_argument('--api_path',  type=str, default=None,                                         help='The path to the NAS-Bench-201 benchmark file.') | ||||
|   args = parser.parse_args() | ||||
| 
 | ||||
|   vis_save_dir = Path(args.save_dir) | ||||
| @@ -1,7 +1,7 @@ | ||||
| ################################################## | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 # | ||||
| ################################################## | ||||
| # python exps/NAS-Bench-102/visualize.py --api_path $HOME/.torch/NAS-Bench-102-v1_0-e61699.pth | ||||
| # python exps/NAS-Bench-201/visualize.py --api_path $HOME/.torch/NAS-Bench-201-v1_0-e61699.pth | ||||
| ################################################## | ||||
| import os, sys, time, argparse, collections | ||||
| from tqdm import tqdm | ||||
| @@ -19,7 +19,7 @@ import matplotlib.pyplot as plt | ||||
| lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve() | ||||
| if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) | ||||
| from log_utils    import time_string | ||||
| from nas_102_api  import NASBench102API as API | ||||
| from nas_201_api  import NASBench201API as API | ||||
| 
 | ||||
| 
 | ||||
| 
 | ||||
| @@ -367,13 +367,66 @@ def write_video(save_dir): | ||||
| 
 | ||||
| 
 | ||||
| 
 | ||||
| def plot_results_nas_v2(api, dataset_xset_a, dataset_xset_b, root, file_name, y_lims): | ||||
|   #print ('root-path={:}, dataset={:}, xset={:}'.format(root, dataset, xset)) | ||||
|   print ('root-path : {:} and {:}'.format(dataset_xset_a, dataset_xset_b)) | ||||
|   checkpoints = ['./output/search-cell-nas-bench-201/R-EA-cifar10/results.pth', | ||||
|                  './output/search-cell-nas-bench-201/REINFORCE-cifar10/results.pth', | ||||
|                  './output/search-cell-nas-bench-201/RAND-cifar10/results.pth', | ||||
|                  './output/search-cell-nas-bench-201/BOHB-cifar10/results.pth' | ||||
|                 ] | ||||
|   legends, indexes = ['REA', 'REINFORCE', 'RANDOM', 'BOHB'], None | ||||
|   All_Accs_A, All_Accs_B = OrderedDict(), OrderedDict() | ||||
|   for legend, checkpoint in zip(legends, checkpoints): | ||||
|     all_indexes = torch.load(checkpoint, map_location='cpu') | ||||
|     accuracies_A, accuracies_B = [], [] | ||||
|     accuracies = [] | ||||
|     for x in all_indexes: | ||||
|       info = api.arch2infos_full[ x ] | ||||
|       metrics = info.get_metrics(dataset_xset_a[0], dataset_xset_a[1], None, False) | ||||
|       accuracies_A.append( metrics['accuracy'] ) | ||||
|       metrics = info.get_metrics(dataset_xset_b[0], dataset_xset_b[1], None, False) | ||||
|       accuracies_B.append( metrics['accuracy'] ) | ||||
|       accuracies.append( (accuracies_A[-1], accuracies_B[-1]) ) | ||||
|     if indexes is None: indexes = list(range(len(all_indexes))) | ||||
|     accuracies = sorted(accuracies) | ||||
|     All_Accs_A[legend] = [x[0] for x in accuracies] | ||||
|     All_Accs_B[legend] = [x[1] for x in accuracies] | ||||
| 
 | ||||
|   color_set = ['r', 'b', 'g', 'c', 'm', 'y', 'k'] | ||||
|   dpi, width, height = 300, 3400, 2600 | ||||
|   LabelSize, LegendFontsize = 28, 28 | ||||
|   figsize = width / float(dpi), height / float(dpi) | ||||
|   fig = plt.figure(figsize=figsize) | ||||
|   x_axis = np.arange(0, 600) | ||||
|   plt.xlim(0, max(indexes)) | ||||
|   plt.ylim(y_lims[0], y_lims[1]) | ||||
|   interval_x, interval_y = 100, y_lims[2] | ||||
|   plt.xticks(np.arange(0, max(indexes), interval_x), fontsize=LegendFontsize) | ||||
|   plt.yticks(np.arange(y_lims[0],y_lims[1], interval_y), fontsize=LegendFontsize) | ||||
|   plt.grid() | ||||
|   plt.xlabel('The index of runs', fontsize=LabelSize) | ||||
|   plt.ylabel('The accuracy (%)', fontsize=LabelSize) | ||||
| 
 | ||||
|   for idx, legend in enumerate(legends): | ||||
|     plt.plot(indexes, All_Accs_B[legend], color=color_set[idx], linestyle='--', label='{:}'.format(legend), lw=1, alpha=0.5) | ||||
|     plt.plot(indexes, All_Accs_A[legend], color=color_set[idx], linestyle='-', lw=1) | ||||
|     for All_Accs in [All_Accs_A, All_Accs_B]: | ||||
|       print ('{:} : mean = {:}, std = {:} :: {:.2f}$\\pm${:.2f}'.format(legend, np.mean(All_Accs[legend]), np.std(All_Accs[legend]), np.mean(All_Accs[legend]), np.std(All_Accs[legend]))) | ||||
|   plt.legend(loc=4, fontsize=LegendFontsize) | ||||
|   save_path = root / '{:}'.format(file_name) | ||||
|   print('save figure into {:}\n'.format(save_path)) | ||||
|   fig.savefig(str(save_path), dpi=dpi, bbox_inches='tight', format='pdf') | ||||
| 
 | ||||
| 
 | ||||
| 
 | ||||
| 
 | ||||
| def plot_results_nas(api, dataset, xset, root, file_name, y_lims): | ||||
|   print ('root-path={:}, dataset={:}, xset={:}'.format(root, dataset, xset)) | ||||
|   checkpoints = ['./output/search-cell-nas-bench-102/R-EA-cifar10/results.pth', | ||||
|                  './output/search-cell-nas-bench-102/REINFORCE-cifar10/results.pth', | ||||
|                  './output/search-cell-nas-bench-102/RAND-cifar10/results.pth', | ||||
|                  './output/search-cell-nas-bench-102/BOHB-cifar10/results.pth' | ||||
|   checkpoints = ['./output/search-cell-nas-bench-201/R-EA-cifar10/results.pth', | ||||
|                  './output/search-cell-nas-bench-201/REINFORCE-cifar10/results.pth', | ||||
|                  './output/search-cell-nas-bench-201/RAND-cifar10/results.pth', | ||||
|                  './output/search-cell-nas-bench-201/BOHB-cifar10/results.pth' | ||||
|                 ] | ||||
|   legends, indexes = ['REA', 'REINFORCE', 'RANDOM', 'BOHB'], None | ||||
|   All_Accs = OrderedDict() | ||||
| @@ -422,19 +475,19 @@ def just_show(api): | ||||
|     xlist = np.array(xlist) | ||||
|     print ('{:4s} : mean-time={:.2f} s'.format(xkey, xlist.mean())) | ||||
| 
 | ||||
|   xpaths = {'RSPS'    : 'output/search-cell-nas-bench-102/RANDOM-NAS-cifar10/checkpoint/', | ||||
|             'DARTS-V1': 'output/search-cell-nas-bench-102/DARTS-V1-cifar10/checkpoint/', | ||||
|             'DARTS-V2': 'output/search-cell-nas-bench-102/DARTS-V2-cifar10/checkpoint/', | ||||
|             'GDAS'    : 'output/search-cell-nas-bench-102/GDAS-cifar10/checkpoint/', | ||||
|             'SETN'    : 'output/search-cell-nas-bench-102/SETN-cifar10/checkpoint/', | ||||
|             'ENAS'    : 'output/search-cell-nas-bench-102/ENAS-cifar10/checkpoint/', | ||||
|   xpaths = {'RSPS'    : 'output/search-cell-nas-bench-201/RANDOM-NAS-cifar10/checkpoint/', | ||||
|             'DARTS-V1': 'output/search-cell-nas-bench-201/DARTS-V1-cifar10/checkpoint/', | ||||
|             'DARTS-V2': 'output/search-cell-nas-bench-201/DARTS-V2-cifar10/checkpoint/', | ||||
|             'GDAS'    : 'output/search-cell-nas-bench-201/GDAS-cifar10/checkpoint/', | ||||
|             'SETN'    : 'output/search-cell-nas-bench-201/SETN-cifar10/checkpoint/', | ||||
|             'ENAS'    : 'output/search-cell-nas-bench-201/ENAS-cifar10/checkpoint/', | ||||
|            } | ||||
|   xseeds = {'RSPS'    : [5349, 59613, 5983], | ||||
|             'DARTS-V1': [11416, 72873, 81184], | ||||
|             'DARTS-V2': [43330, 79405, 79423], | ||||
|             'GDAS'    : [19677, 884, 95950], | ||||
|             'SETN'    : [20518, 61817, 89144], | ||||
|             'ENAS'    : [30801, 75610, 97745], | ||||
|             'ENAS'    : [3231, 34238, 96929], | ||||
|            } | ||||
| 
 | ||||
|   def get_accs(xdata, index=-1): | ||||
| @@ -480,24 +533,27 @@ def show_nas_sharing_w(api, dataset, subset, vis_save_dir, file_name, y_lims, x_ | ||||
|   plt.xlabel('The searching epoch', fontsize=LabelSize) | ||||
|   plt.ylabel('The accuracy (%)', fontsize=LabelSize) | ||||
| 
 | ||||
|   xpaths = {'RSPS'    : 'output/search-cell-nas-bench-102/RANDOM-NAS-cifar10/checkpoint/', | ||||
|             'DARTS-V1': 'output/search-cell-nas-bench-102/DARTS-V1-cifar10/checkpoint/', | ||||
|             'DARTS-V2': 'output/search-cell-nas-bench-102/DARTS-V2-cifar10/checkpoint/', | ||||
|             'GDAS'    : 'output/search-cell-nas-bench-102/GDAS-cifar10/checkpoint/', | ||||
|             'SETN'    : 'output/search-cell-nas-bench-102/SETN-cifar10/checkpoint/', | ||||
|             'ENAS'    : 'output/search-cell-nas-bench-102/ENAS-cifar10/checkpoint/', | ||||
|   xpaths = {'RSPS'    : 'output/search-cell-nas-bench-201/RANDOM-NAS-cifar10/checkpoint/', | ||||
|             'DARTS-V1': 'output/search-cell-nas-bench-201/DARTS-V1-cifar10/checkpoint/', | ||||
|             'DARTS-V2': 'output/search-cell-nas-bench-201/DARTS-V2-cifar10/checkpoint/', | ||||
|             'GDAS'    : 'output/search-cell-nas-bench-201/GDAS-cifar10/checkpoint/', | ||||
|             'SETN'    : 'output/search-cell-nas-bench-201/SETN-cifar10/checkpoint/', | ||||
|             'ENAS'    : 'output/search-cell-nas-bench-201/ENAS-cifar10/checkpoint/', | ||||
|            } | ||||
|   xseeds = {'RSPS'    : [5349, 59613, 5983], | ||||
|             'DARTS-V1': [11416, 72873, 81184], | ||||
|             'DARTS-V1': [11416, 72873, 81184, 28640], | ||||
|             'DARTS-V2': [43330, 79405, 79423], | ||||
|             'GDAS'    : [19677, 884, 95950], | ||||
|             'SETN'    : [20518, 61817, 89144], | ||||
|             'ENAS'    : [30801, 75610, 97745], | ||||
|             'ENAS'    : [3231, 34238, 96929], | ||||
|            } | ||||
| 
 | ||||
|   def get_accs(xdata): | ||||
|     epochs, xresults = xdata['epoch'], [] | ||||
|     metrics = api.arch2infos_full[ api.random() ].get_metrics(dataset, subset, None, False) | ||||
|     if -1 in xdata['genotypes']: | ||||
|       metrics = api.arch2infos_full[ api.query_index_by_arch(xdata['genotypes'][-1]) ].get_metrics(dataset, subset, None, False) | ||||
|     else: | ||||
|       metrics = api.arch2infos_full[ api.random() ].get_metrics(dataset, subset, None, False) | ||||
|     xresults.append( metrics['accuracy'] ) | ||||
|     for iepoch in range(epochs): | ||||
|       genotype = xdata['genotypes'][iepoch] | ||||
| @@ -528,12 +584,120 @@ def show_nas_sharing_w(api, dataset, subset, vis_save_dir, file_name, y_lims, x_ | ||||
|   fig.savefig(str(save_path), dpi=dpi, bbox_inches='tight', format='pdf') | ||||
| 
 | ||||
| 
 | ||||
| def show_nas_sharing_w_v2(api, data_sub_a, data_sub_b, vis_save_dir, file_name, y_lims, x_maxs): | ||||
|   color_set = ['r', 'b', 'g', 'c', 'm', 'y', 'k'] | ||||
|   dpi, width, height = 300, 3400, 2600 | ||||
|   LabelSize, LegendFontsize = 28, 28 | ||||
|   figsize = width / float(dpi), height / float(dpi) | ||||
|   fig = plt.figure(figsize=figsize) | ||||
|   #x_maxs = 250 | ||||
|   plt.xlim(0, x_maxs+1) | ||||
|   plt.ylim(y_lims[0], y_lims[1]) | ||||
|   interval_x, interval_y = x_maxs // 5, y_lims[2] | ||||
|   plt.xticks(np.arange(0, x_maxs+1, interval_x), fontsize=LegendFontsize) | ||||
|   plt.yticks(np.arange(y_lims[0],y_lims[1], interval_y), fontsize=LegendFontsize) | ||||
|   plt.grid() | ||||
|   plt.xlabel('The searching epoch', fontsize=LabelSize) | ||||
|   plt.ylabel('The accuracy (%)', fontsize=LabelSize) | ||||
| 
 | ||||
|   xpaths = {'RSPS'    : 'output/search-cell-nas-bench-201/RANDOM-NAS-cifar10/checkpoint/', | ||||
|             'DARTS-V1': 'output/search-cell-nas-bench-201/DARTS-V1-cifar10/checkpoint/', | ||||
|             'DARTS-V2': 'output/search-cell-nas-bench-201/DARTS-V2-cifar10/checkpoint/', | ||||
|             'GDAS'    : 'output/search-cell-nas-bench-201/GDAS-cifar10/checkpoint/', | ||||
|             'SETN'    : 'output/search-cell-nas-bench-201/SETN-cifar10/checkpoint/', | ||||
|             'ENAS'    : 'output/search-cell-nas-bench-201/ENAS-cifar10/checkpoint/', | ||||
|            } | ||||
|   xseeds = {'RSPS'    : [5349, 59613, 5983], | ||||
|             'DARTS-V1': [11416, 72873, 81184, 28640], | ||||
|             'DARTS-V2': [43330, 79405, 79423], | ||||
|             'GDAS'    : [19677, 884, 95950], | ||||
|             'SETN'    : [20518, 61817, 89144], | ||||
|             'ENAS'    : [3231, 34238, 96929], | ||||
|            } | ||||
| 
 | ||||
|   def get_accs(xdata, dataset, subset): | ||||
|     epochs, xresults = xdata['epoch'], [] | ||||
|     if -1 in xdata['genotypes']: | ||||
|       metrics = api.arch2infos_full[ api.query_index_by_arch(xdata['genotypes'][-1]) ].get_metrics(dataset, subset, None, False) | ||||
|     else: | ||||
|       metrics = api.arch2infos_full[ api.random() ].get_metrics(dataset, subset, None, False) | ||||
|     xresults.append( metrics['accuracy'] ) | ||||
|     for iepoch in range(epochs): | ||||
|       genotype = xdata['genotypes'][iepoch] | ||||
|       index = api.query_index_by_arch(genotype) | ||||
|       metrics = api.arch2infos_full[index].get_metrics(dataset, subset, None, False) | ||||
|       xresults.append( metrics['accuracy'] ) | ||||
|     return xresults | ||||
| 
 | ||||
|   if x_maxs == 50: | ||||
|     xox, xxxstrs = 'v2', ['DARTS-V1', 'DARTS-V2'] | ||||
|   elif x_maxs == 250: | ||||
|     xox, xxxstrs = 'v1', ['RSPS', 'GDAS', 'SETN', 'ENAS'] | ||||
|   else: raise ValueError('invalid x_maxs={:}'.format(x_maxs)) | ||||
| 
 | ||||
|   for idx, method in enumerate(xxxstrs): | ||||
|     xkey = method | ||||
|     all_paths = [ '{:}/seed-{:}-basic.pth'.format(xpaths[xkey], seed) for seed in xseeds[xkey] ] | ||||
|     all_datas = [torch.load(xpath, map_location='cpu') for xpath in all_paths] | ||||
|     accyss_A = np.array( [get_accs(xdatas, data_sub_a[0], data_sub_a[1]) for xdatas in all_datas] ) | ||||
|     accyss_B = np.array( [get_accs(xdatas, data_sub_b[0], data_sub_b[1]) for xdatas in all_datas] ) | ||||
|     epochs = list(range(accyss_A.shape[1])) | ||||
|     for j, accyss in enumerate([accyss_A, accyss_B]): | ||||
|       plt.plot(epochs, [accyss[:,i].mean() for i in epochs], color=color_set[idx*2+j], linestyle='-' if j==0 else '--', label='{:} ({:})'.format(method, 'VALID' if j == 0 else 'TEST'), lw=2, alpha=0.9) | ||||
|       plt.fill_between(epochs, [accyss[:,i].mean()-accyss[:,i].std() for i in epochs], [accyss[:,i].mean()+accyss[:,i].std() for i in epochs], alpha=0.2, color=color_set[idx*2+j]) | ||||
|   #plt.legend(loc=4, fontsize=LegendFontsize) | ||||
|   plt.legend(loc=0, fontsize=LegendFontsize) | ||||
|   save_path = vis_save_dir / '{:}-{:}'.format(xox, file_name) | ||||
|   print('save figure into {:}\n'.format(save_path)) | ||||
|   fig.savefig(str(save_path), dpi=dpi, bbox_inches='tight', format='pdf') | ||||
| 
 | ||||
| 
 | ||||
| def show_reinforce(api, root, dataset, xset, file_name, y_lims): | ||||
|   print ('root-path={:}, dataset={:}, xset={:}'.format(root, dataset, xset)) | ||||
|   LRs = ['0.01', '0.02', '0.1', '0.2', '0.5', '1.0', '1.5', '2.0', '2.5', '3.0'] | ||||
|   checkpoints = ['./output/search-cell-nas-bench-201/REINFORCE-cifar10-{:}/results.pth'.format(x) for x in LRs] | ||||
|   acc_lr_dict, indexes = {}, None | ||||
|   for lr, checkpoint in zip(LRs, checkpoints): | ||||
|     all_indexes, accuracies = torch.load(checkpoint, map_location='cpu'), [] | ||||
|     for x in all_indexes: | ||||
|       info = api.arch2infos_full[ x ] | ||||
|       metrics = info.get_metrics(dataset, xset, None, False) | ||||
|       accuracies.append( metrics['accuracy'] ) | ||||
|     if indexes is None: indexes = list(range(len(accuracies))) | ||||
|     acc_lr_dict[lr] = np.array( sorted(accuracies) ) | ||||
|     print ('LR={:.3f}, mean={:}, std={:}'.format(float(lr), acc_lr_dict[lr].mean(), acc_lr_dict[lr].std())) | ||||
|    | ||||
|   color_set = ['r', 'b', 'g', 'c', 'm', 'y', 'k'] | ||||
|   dpi, width, height = 300, 3400, 2600 | ||||
|   LabelSize, LegendFontsize = 28, 22 | ||||
|   figsize = width / float(dpi), height / float(dpi) | ||||
|   fig = plt.figure(figsize=figsize) | ||||
|   x_axis = np.arange(0, 600) | ||||
|   plt.xlim(0, max(indexes)) | ||||
|   plt.ylim(y_lims[0], y_lims[1]) | ||||
|   interval_x, interval_y = 100, y_lims[2] | ||||
|   plt.xticks(np.arange(0, max(indexes), interval_x), fontsize=LegendFontsize) | ||||
|   plt.yticks(np.arange(y_lims[0],y_lims[1], interval_y), fontsize=LegendFontsize) | ||||
|   plt.grid() | ||||
|   plt.xlabel('The index of runs', fontsize=LabelSize) | ||||
|   plt.ylabel('The accuracy (%)', fontsize=LabelSize) | ||||
| 
 | ||||
|   for idx, LR in enumerate(LRs): | ||||
|     legend = 'LR={:.2f}'.format(float(LR)) | ||||
|     color, linestyle = color_set[idx // 2], '-' if idx % 2 == 0 else '-.' | ||||
|     plt.plot(indexes, acc_lr_dict[LR], color=color, linestyle=linestyle, label=legend, lw=2, alpha=0.8) | ||||
|     print ('{:} : mean = {:}, std = {:} :: {:.2f}$\\pm${:.2f}'.format(legend, np.mean(acc_lr_dict[LR]), np.std(acc_lr_dict[LR]), np.mean(acc_lr_dict[LR]), np.std(acc_lr_dict[LR]))) | ||||
|   plt.legend(loc=4, fontsize=LegendFontsize) | ||||
|   save_path = root / '{:}-{:}-{:}.pdf'.format(dataset, xset, file_name) | ||||
|   print('save figure into {:}\n'.format(save_path)) | ||||
|   fig.savefig(str(save_path), dpi=dpi, bbox_inches='tight', format='pdf') | ||||
| 
 | ||||
| 
 | ||||
| if __name__ == '__main__': | ||||
| 
 | ||||
|   parser = argparse.ArgumentParser(description='NAS-Bench-102', formatter_class=argparse.ArgumentDefaultsHelpFormatter) | ||||
|   parser.add_argument('--save_dir',  type=str, default='./output/search-cell-nas-bench-102/visuals', help='The base-name of folder to save checkpoints and log.') | ||||
|   parser.add_argument('--api_path',  type=str, default=None,                                         help='The path to the NAS-Bench-102 benchmark file.') | ||||
|   parser = argparse.ArgumentParser(description='NAS-Bench-201', formatter_class=argparse.ArgumentDefaultsHelpFormatter) | ||||
|   parser.add_argument('--save_dir',  type=str, default='./output/search-cell-nas-bench-201/visuals', help='The base-name of folder to save checkpoints and log.') | ||||
|   parser.add_argument('--api_path',  type=str, default=None,                                         help='The path to the NAS-Bench-201 benchmark file.') | ||||
|   args = parser.parse_args() | ||||
|    | ||||
|   vis_save_dir = Path(args.save_dir) | ||||
| @@ -548,6 +712,9 @@ if __name__ == '__main__': | ||||
|   #visualize_relative_ranking(vis_save_dir) | ||||
| 
 | ||||
|   api = API(args.api_path) | ||||
|   show_reinforce(api, vis_save_dir, 'cifar10-valid' , 'x-valid', 'REINFORCE-CIFAR-10', (75, 95, 5)) | ||||
|   import pdb; pdb.set_trace() | ||||
| 
 | ||||
|   for x_maxs in [50, 250]: | ||||
|     show_nas_sharing_w(api, 'cifar10-valid' , 'x-valid' , vis_save_dir, 'nas-plot.pdf', (0, 100,10), x_maxs) | ||||
|     show_nas_sharing_w(api, 'cifar10'       , 'ori-test', vis_save_dir, 'nas-plot.pdf', (0, 100,10), x_maxs) | ||||
| @@ -555,12 +722,19 @@ if __name__ == '__main__': | ||||
|     show_nas_sharing_w(api, 'cifar100'      , 'x-test'  , vis_save_dir, 'nas-plot.pdf', (0, 100,10), x_maxs) | ||||
|     show_nas_sharing_w(api, 'ImageNet16-120', 'x-valid' , vis_save_dir, 'nas-plot.pdf', (0, 100,10), x_maxs) | ||||
|     show_nas_sharing_w(api, 'ImageNet16-120', 'x-test'  , vis_save_dir, 'nas-plot.pdf', (0, 100,10), x_maxs) | ||||
|    | ||||
|   show_nas_sharing_w_v2(api, ('cifar10-valid' , 'x-valid'), ('cifar10'       , 'ori-test') , vis_save_dir, 'DARTS-CIFAR010.pdf', (0, 100,10), 50) | ||||
|   show_nas_sharing_w_v2(api, ('cifar100'      , 'x-valid'), ('cifar100'      , 'x-test'  ) , vis_save_dir, 'DARTS-CIFAR100.pdf', (0, 100,10), 50) | ||||
|   show_nas_sharing_w_v2(api, ('ImageNet16-120', 'x-valid'), ('ImageNet16-120', 'x-test'  ) , vis_save_dir, 'DARTS-ImageNet.pdf', (0, 100,10), 50) | ||||
|   #just_show(api) | ||||
|   """ | ||||
|   just_show(api) | ||||
|   plot_results_nas(api, 'cifar10-valid' , 'x-valid' , vis_save_dir, 'nas-com.pdf', (85,95, 1)) | ||||
|   plot_results_nas(api, 'cifar10'       , 'ori-test', vis_save_dir, 'nas-com.pdf', (85,95, 1)) | ||||
|   plot_results_nas(api, 'cifar100'      , 'x-valid' , vis_save_dir, 'nas-com.pdf', (55,75, 3)) | ||||
|   plot_results_nas(api, 'cifar100'      , 'x-test'  , vis_save_dir, 'nas-com.pdf', (55,75, 3)) | ||||
|   plot_results_nas(api, 'ImageNet16-120', 'x-valid' , vis_save_dir, 'nas-com.pdf', (35,50, 3)) | ||||
|   plot_results_nas(api, 'ImageNet16-120', 'x-test'  , vis_save_dir, 'nas-com.pdf', (35,50, 3)) | ||||
|   plot_results_nas_v2(api, ('cifar10-valid' , 'x-valid'), ('cifar10'       , 'ori-test'), vis_save_dir, 'nas-com-v2-cifar010.pdf', (85,95, 1)) | ||||
|   plot_results_nas_v2(api, ('cifar100'      , 'x-valid'), ('cifar100'      , 'x-test'  ), vis_save_dir, 'nas-com-v2-cifar100.pdf', (60,75, 3)) | ||||
|   plot_results_nas_v2(api, ('ImageNet16-120', 'x-valid'), ('ImageNet16-120', 'x-test'  ), vis_save_dir, 'nas-com-v2-imagenet.pdf', (35,48, 2)) | ||||
|   """ | ||||
| @@ -1,9 +1,10 @@ | ||||
| ################################################## | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 # | ||||
| ################################################## | ||||
| # required to install hpbandster ################# | ||||
| # bash ./scripts-search/algos/BOHB.sh -1         # | ||||
| ################################################## | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020 # | ||||
| ################################################################### | ||||
| # BOHB: Robust and Efficient Hyperparameter Optimization at Scale # | ||||
| # required to install hpbandster ################################## | ||||
| # bash ./scripts-search/algos/BOHB.sh -1         ################## | ||||
| ################################################################### | ||||
| import os, sys, time, glob, random, argparse | ||||
| import numpy as np, collections | ||||
| from copy import deepcopy | ||||
| @@ -17,7 +18,7 @@ from datasets     import get_datasets, SearchDataset | ||||
| from procedures   import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint, get_optim_scheduler | ||||
| from utils        import get_model_infos, obtain_accuracy | ||||
| from log_utils    import AverageMeter, time_string, convert_secs2time | ||||
| from nas_102_api  import NASBench102API as API | ||||
| from nas_201_api  import NASBench201API as API | ||||
| from models       import CellStructure, get_search_spaces | ||||
| # BOHB: Robust and Efficient Hyperparameter Optimization at Scale, ICML 2018 | ||||
| import ConfigSpace | ||||
|   | ||||
| @@ -1,5 +1,5 @@ | ||||
| ################################################## | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 # | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020 # | ||||
| ######################################################## | ||||
| # DARTS: Differentiable Architecture Search, ICLR 2019 # | ||||
| ######################################################## | ||||
| @@ -17,7 +17,7 @@ from procedures   import prepare_seed, prepare_logger, save_checkpoint, copy_che | ||||
| from utils        import get_model_infos, obtain_accuracy | ||||
| from log_utils    import AverageMeter, time_string, convert_secs2time | ||||
| from models       import get_cell_based_tiny_net, get_search_spaces | ||||
| from nas_102_api  import NASBench102API as API | ||||
| from nas_201_api  import NASBench201API as API | ||||
|  | ||||
|  | ||||
| def search_func(xloader, network, criterion, scheduler, w_optimizer, a_optimizer, epoch_str, print_freq, logger): | ||||
|   | ||||
| @@ -1,5 +1,5 @@ | ||||
| ################################################## | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 # | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020 # | ||||
| ######################################################## | ||||
| # DARTS: Differentiable Architecture Search, ICLR 2019 # | ||||
| ######################################################## | ||||
| @@ -17,7 +17,7 @@ from procedures   import prepare_seed, prepare_logger, save_checkpoint, copy_che | ||||
| from utils        import get_model_infos, obtain_accuracy | ||||
| from log_utils    import AverageMeter, time_string, convert_secs2time | ||||
| from models       import get_cell_based_tiny_net, get_search_spaces | ||||
| from nas_102_api  import NASBench102API as API | ||||
| from nas_201_api  import NASBench201API as API | ||||
|  | ||||
|  | ||||
| def _concat(xs): | ||||
|   | ||||
| @@ -1,6 +1,8 @@ | ||||
| ################################################## | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 # | ||||
| ################################################## | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020 # | ||||
| ########################################################################## | ||||
| # Efficient Neural Architecture Search via Parameters Sharing, ICML 2018 # | ||||
| ########################################################################## | ||||
| import os, sys, time, glob, random, argparse | ||||
| import numpy as np | ||||
| from copy import deepcopy | ||||
| @@ -15,7 +17,7 @@ from procedures   import prepare_seed, prepare_logger, save_checkpoint, copy_che | ||||
| from utils        import get_model_infos, obtain_accuracy | ||||
| from log_utils    import AverageMeter, time_string, convert_secs2time | ||||
| from models       import get_cell_based_tiny_net, get_search_spaces | ||||
| from nas_102_api  import NASBench102API as API | ||||
| from nas_201_api  import NASBench201API as API | ||||
|  | ||||
|  | ||||
| def train_shared_cnn(xloader, shared_cnn, controller, criterion, scheduler, optimizer, epoch_str, print_freq, logger): | ||||
|   | ||||
| @@ -1,5 +1,5 @@ | ||||
| ################################################## | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 # | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020 # | ||||
| ########################################################################### | ||||
| # Searching for A Robust Neural Architecture in Four GPU Hours, CVPR 2019 # | ||||
| ########################################################################### | ||||
| @@ -17,7 +17,7 @@ from procedures   import prepare_seed, prepare_logger, save_checkpoint, copy_che | ||||
| from utils        import get_model_infos, obtain_accuracy | ||||
| from log_utils    import AverageMeter, time_string, convert_secs2time | ||||
| from models       import get_cell_based_tiny_net, get_search_spaces | ||||
| from nas_102_api  import NASBench102API as API | ||||
| from nas_201_api  import NASBench201API as API | ||||
|  | ||||
|  | ||||
| def search_func(xloader, network, criterion, scheduler, w_optimizer, a_optimizer, epoch_str, print_freq, logger): | ||||
|   | ||||
| @@ -1,6 +1,8 @@ | ||||
| ################################################## | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 # | ||||
| ################################################## | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020 # | ||||
| ############################################################################## | ||||
| # Random Search and Reproducibility for Neural Architecture Search, UAI 2019 # | ||||
| ############################################################################## | ||||
| import os, sys, time, glob, random, argparse | ||||
| import numpy as np | ||||
| from copy import deepcopy | ||||
| @@ -15,7 +17,7 @@ from procedures   import prepare_seed, prepare_logger, save_checkpoint, copy_che | ||||
| from utils        import get_model_infos, obtain_accuracy | ||||
| from log_utils    import AverageMeter, time_string, convert_secs2time | ||||
| from models       import get_cell_based_tiny_net, get_search_spaces | ||||
| from nas_102_api  import NASBench102API as API | ||||
| from nas_201_api  import NASBench201API as API | ||||
|  | ||||
|  | ||||
| def search_func(xloader, network, criterion, scheduler, w_optimizer, epoch_str, print_freq, logger): | ||||
|   | ||||
| @@ -1,6 +1,6 @@ | ||||
| ################################################## | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 # | ||||
| ################################################## | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020 # | ||||
| ############################################################################## | ||||
| import os, sys, time, glob, random, argparse | ||||
| import numpy as np, collections | ||||
| from copy import deepcopy | ||||
| @@ -15,7 +15,7 @@ from procedures   import prepare_seed, prepare_logger, save_checkpoint, copy_che | ||||
| from utils        import get_model_infos, obtain_accuracy | ||||
| from log_utils    import AverageMeter, time_string, convert_secs2time | ||||
| from models       import get_search_spaces | ||||
| from nas_102_api  import NASBench102API as API | ||||
| from nas_201_api  import NASBench201API as API | ||||
| from R_EA         import train_and_eval, random_architecture_func | ||||
|  | ||||
|  | ||||
|   | ||||
| @@ -1,5 +1,5 @@ | ||||
| ################################################## | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 # | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020 # | ||||
| ################################################################## | ||||
| # Regularized Evolution for Image Classifier Architecture Search # | ||||
| ################################################################## | ||||
| @@ -16,7 +16,7 @@ from datasets     import get_datasets, SearchDataset | ||||
| from procedures   import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint, get_optim_scheduler | ||||
| from utils        import get_model_infos, obtain_accuracy | ||||
| from log_utils    import AverageMeter, time_string, convert_secs2time | ||||
| from nas_102_api  import NASBench102API as API | ||||
| from nas_201_api  import NASBench201API as API | ||||
| from models       import CellStructure, get_search_spaces | ||||
|  | ||||
|  | ||||
| @@ -31,30 +31,8 @@ class Model(object): | ||||
|     return '{:}'.format(self.arch) | ||||
|    | ||||
|  | ||||
| def valid_func(xloader, network, criterion): | ||||
|   data_time, batch_time = AverageMeter(), AverageMeter() | ||||
|   arch_losses, arch_top1, arch_top5 = AverageMeter(), AverageMeter(), AverageMeter() | ||||
|   network.train() | ||||
|   end = time.time() | ||||
|   with torch.no_grad(): | ||||
|     for step, (arch_inputs, arch_targets) in enumerate(xloader): | ||||
|       arch_targets = arch_targets.cuda(non_blocking=True) | ||||
|       # measure data loading time | ||||
|       data_time.update(time.time() - end) | ||||
|       # prediction | ||||
|       _, logits = network(arch_inputs) | ||||
|       arch_loss = criterion(logits, arch_targets) | ||||
|       # record | ||||
|       arch_prec1, arch_prec5 = obtain_accuracy(logits.data, arch_targets.data, topk=(1, 5)) | ||||
|       arch_losses.update(arch_loss.item(),  arch_inputs.size(0)) | ||||
|       arch_top1.update  (arch_prec1.item(), arch_inputs.size(0)) | ||||
|       arch_top5.update  (arch_prec5.item(), arch_inputs.size(0)) | ||||
|       # measure elapsed time | ||||
|       batch_time.update(time.time() - end) | ||||
|       end = time.time() | ||||
|   return arch_losses.avg, arch_top1.avg, arch_top5.avg | ||||
|  | ||||
|  | ||||
| # This function mimics the training and evaluation procedure for a single architecture `arch`. | ||||
| # The time_cost is calculated as the total training time for a few epochs (e.g., 12) plus the evaluation time for one epoch. | ||||
| def train_and_eval(arch, nas_bench, extra_info): | ||||
|   if nas_bench is not None: | ||||
|     arch_index = nas_bench.query_index_by_arch( arch ) | ||||
|   | ||||
| @@ -1,5 +1,5 @@ | ||||
| ################################################## | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 # | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020 # | ||||
| ###################################################################################### | ||||
| # One-Shot Neural Architecture Search via Self-Evaluated Template Network, ICCV 2019 # | ||||
| ###################################################################################### | ||||
| @@ -17,7 +17,7 @@ from procedures   import prepare_seed, prepare_logger, save_checkpoint, copy_che | ||||
| from utils        import get_model_infos, obtain_accuracy | ||||
| from log_utils    import AverageMeter, time_string, convert_secs2time | ||||
| from models       import get_cell_based_tiny_net, get_search_spaces | ||||
| from nas_102_api  import NASBench102API as API | ||||
| from nas_201_api  import NASBench201API as API | ||||
|  | ||||
|  | ||||
| def search_func(xloader, network, criterion, scheduler, w_optimizer, a_optimizer, epoch_str, print_freq, logger): | ||||
|   | ||||
| @@ -17,7 +17,7 @@ from datasets     import get_datasets, SearchDataset | ||||
| from procedures   import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint, get_optim_scheduler | ||||
| from utils        import get_model_infos, obtain_accuracy | ||||
| from log_utils    import AverageMeter, time_string, convert_secs2time | ||||
| from nas_102_api  import NASBench102API as API | ||||
| from nas_201_api  import NASBench201API as API | ||||
| from models       import CellStructure, get_search_spaces | ||||
| from R_EA import train_and_eval | ||||
|  | ||||
| @@ -128,6 +128,7 @@ def main(xargs, nas_bench): | ||||
|   search_space = get_search_spaces('cell', xargs.search_space_name) | ||||
|   policy    = Policy(xargs.max_nodes, search_space) | ||||
|   optimizer = torch.optim.Adam(policy.parameters(), lr=xargs.learning_rate) | ||||
|   #optimizer = torch.optim.SGD(policy.parameters(), lr=xargs.learning_rate) | ||||
|   eps       = np.finfo(np.float32).eps.item() | ||||
|   baseline  = ExponentialMovingAverage(xargs.EMA_momentum) | ||||
|   logger.log('policy    : {:}'.format(policy)) | ||||
| @@ -141,13 +142,14 @@ def main(xargs, nas_bench): | ||||
|   # attempts = 0 | ||||
|   x_start_time = time.time() | ||||
|   logger.log('Will start searching with time budget of {:} s.'.format(xargs.time_budget)) | ||||
|   total_steps, total_costs = 0, 0 | ||||
|   total_steps, total_costs, trace = 0, 0, [] | ||||
|   #for istep in range(xargs.RL_steps): | ||||
|   while total_costs < xargs.time_budget: | ||||
|     start_time = time.time() | ||||
|     log_prob, action = select_action( policy ) | ||||
|     arch   = policy.generate_arch( action ) | ||||
|     reward, cost_time = train_and_eval(arch, nas_bench, extra_info) | ||||
|     trace.append( (reward, arch) ) | ||||
|     # accumulate time | ||||
|     if total_costs + cost_time < xargs.time_budget: | ||||
|       total_costs += cost_time | ||||
| @@ -166,7 +168,8 @@ def main(xargs, nas_bench): | ||||
|     #logger.log('----> {:}'.format(policy.arch_parameters)) | ||||
|     #logger.log('') | ||||
|  | ||||
|   best_arch = policy.genotype() | ||||
|   # best_arch = policy.genotype() # first version | ||||
|   best_arch = max(trace, key=lambda x: x[0])[1] | ||||
|   logger.log('REINFORCE finish with {:} steps and {:.1f} s (real cost={:.3f}).'.format(total_steps, total_costs, time.time()-x_start_time)) | ||||
|   info = nas_bench.query_by_arch( best_arch ) | ||||
|   if info is None: logger.log('Did not find this architecture : {:}.'.format(best_arch)) | ||||
|   | ||||
| @@ -8,11 +8,11 @@ from collections import OrderedDict | ||||
| lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve() | ||||
| if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) | ||||
|  | ||||
| from nas_102_api import NASBench102API as API | ||||
| from nas_201_api import NASBench201API as API | ||||
|  | ||||
| def test_nas_api(): | ||||
|   from nas_102_api import ArchResults | ||||
|   xdata   = torch.load('/home/dxy/FOR-RELEASE/NAS-Projects/output/NAS-BENCH-102-4/simplifies/architectures/000157-FULL.pth') | ||||
|   from nas_201_api import ArchResults | ||||
|   xdata   = torch.load('/home/dxy/FOR-RELEASE/NAS-Projects/output/NAS-BENCH-201-4/simplifies/architectures/000157-FULL.pth') | ||||
|   for key in ['full', 'less']: | ||||
|     print ('\n------------------------- {:} -------------------------'.format(key)) | ||||
|     archRes = ArchResults.create_from_state_dict(xdata[key]) | ||||
| @@ -81,8 +81,8 @@ def test_one_shot_model(ckpath, use_train): | ||||
|   from config_utils import load_config, dict2config | ||||
|   from utils.nas_utils import evaluate_one_shot | ||||
|   use_train = int(use_train) > 0 | ||||
|   #ckpath = 'output/search-cell-nas-bench-102/DARTS-V1-cifar10/checkpoint/seed-11416-basic.pth' | ||||
|   #ckpath = 'output/search-cell-nas-bench-102/DARTS-V1-cifar10/checkpoint/seed-28640-basic.pth' | ||||
|   #ckpath = 'output/search-cell-nas-bench-201/DARTS-V1-cifar10/checkpoint/seed-11416-basic.pth' | ||||
|   #ckpath = 'output/search-cell-nas-bench-201/DARTS-V1-cifar10/checkpoint/seed-28640-basic.pth' | ||||
|   print ('ckpath : {:}'.format(ckpath)) | ||||
|   ckp = torch.load(ckpath) | ||||
|   xargs = ckp['args'] | ||||
| @@ -103,7 +103,7 @@ def test_one_shot_model(ckpath, use_train): | ||||
|   search_model = get_cell_based_tiny_net(model_config) | ||||
|   search_model.load_state_dict( ckp['search_model'] ) | ||||
|   search_model = search_model.cuda() | ||||
|   api = API('/home/dxy/.torch/NAS-Bench-102-v1_0-e61699.pth') | ||||
|   api = API('/home/dxy/.torch/NAS-Bench-201-v1_0-e61699.pth') | ||||
|   archs, probs, accuracies = evaluate_one_shot(search_model, valid_loader, api, use_train) | ||||
|  | ||||
|  | ||||
|   | ||||
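Among the "fix bugs in NAS baselines" changes above, REINFORCE.py now records every sampled (reward, arch) pair in a trace and reports the best-reward architecture instead of the final policy genotype. A self-contained sketch of that pattern, with a placeholder sampler and reward function standing in for the repository's policy and NAS-Bench-201 lookup:

    import random

    def train_and_eval(arch):
        # placeholder reward; the repository queries NAS-Bench-201 here
        return random.random()

    def search(num_samples=100):
        trace = []  # (reward, arch) pairs, as added in this commit
        for _ in range(num_samples):
            arch = 'arch-{:05d}'.format(random.randrange(10000))  # placeholder architecture sampler
            reward = train_and_eval(arch)
            trace.append((reward, arch))
        # return the architecture with the highest observed reward,
        # rather than whatever the final policy would decode to
        return max(trace, key=lambda x: x[0])[1]

    if __name__ == '__main__':
        print(search())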