Update new version of BOHB
This commit is contained in:
		| @@ -3,7 +3,8 @@ | ||||
| ############################################################### | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.06           # | ||||
| ############################################################### | ||||
| # Usage: python exps/experimental/vis-bench-algos.py          # | ||||
| # Usage: python exps/experimental/vis-bench-algos.py --search_space tss | ||||
| # Usage: python exps/experimental/vis-bench-algos.py --search_space sss | ||||
| ############################################################### | ||||
| import os, gc, sys, time, torch, argparse | ||||
| import numpy as np | ||||
| @@ -115,15 +116,17 @@ def visualize_curve(api, vis_save_dir, search_space, max_time): | ||||
|  | ||||
| if __name__ == '__main__': | ||||
|   parser = argparse.ArgumentParser(description='NAS-Bench-X', formatter_class=argparse.ArgumentDefaultsHelpFormatter) | ||||
|   parser.add_argument('--save_dir',    type=str, default='output/vis-nas-bench/nas-algos', help='Folder to save checkpoints and log.') | ||||
|   parser.add_argument('--max_time',    type=float, default=20000, help='The maximum time budget.') | ||||
|   parser.add_argument('--save_dir',     type=str,   default='output/vis-nas-bench/nas-algos', help='Folder to save checkpoints and log.') | ||||
|   parser.add_argument('--search_space', type=str,   choices=['tss', 'sss'], help='Choose the search space.') | ||||
|   parser.add_argument('--max_time',     type=float, default=20000, help='The maximum time budget.') | ||||
|   args = parser.parse_args() | ||||
|  | ||||
|   save_dir = Path(args.save_dir) | ||||
|  | ||||
|   api201 = NASBench201API(verbose=False) | ||||
|   visualize_curve(api201, save_dir, 'tss', args.max_time) | ||||
|   del api201 | ||||
|   gc.collect() | ||||
|   api301 = NASBench301API(verbose=False) | ||||
|   visualize_curve(api301, save_dir, 'sss', args.max_time) | ||||
|   if args.search_space == 'tss': | ||||
|     api = NASBench201API(verbose=False) | ||||
|   elif args.search_space == 'sss': | ||||
|     api = NASBench301API(verbose=False) | ||||
|   else: | ||||
|     raise ValueError('Invalid search space : {:}'.format(args.search_space)) | ||||
|   visualize_curve(api, save_dir, args.search_space, args.max_time) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user