Update VIS-CODES and SCRIPTS
This commit is contained in:
		| @@ -1,22 +1,46 @@ | ||||
| #!/bin/bash | ||||
| # bash ./exps/algos-v2/run-all.sh | ||||
| # bash ./exps/algos-v2/run-all.sh mul | ||||
| # bash ./exps/algos-v2/run-all.sh ws | ||||
| set -e | ||||
| echo script name: $0 | ||||
| echo $# arguments | ||||
| if [ "$#" -ne 1 ] ;then | ||||
|   echo "Input illegal number of parameters " $# | ||||
|   echo "Need 1 parameters for type of algorithms." | ||||
|   exit 1 | ||||
| fi | ||||
|  | ||||
|  | ||||
| datasets="cifar10 cifar100 ImageNet16-120" | ||||
| search_spaces="tss sss" | ||||
| alg_type=$1 | ||||
|  | ||||
| for dataset in ${datasets} | ||||
| do | ||||
|   for search_space in ${search_spaces} | ||||
| if [ "$alg_type" == "mul" ]; then | ||||
|   search_spaces="tss sss" | ||||
|  | ||||
|   for dataset in ${datasets} | ||||
|   do | ||||
|     python ./exps/algos-v2/reinforce.py --dataset ${dataset} --search_space ${search_space} --learning_rate 0.01 | ||||
|     python ./exps/algos-v2/regularized_ea.py --dataset ${dataset} --search_space ${search_space} --ea_cycles 200 --ea_population 10 --ea_sample_size 3 | ||||
|     python ./exps/algos-v2/random_wo_share.py --dataset ${dataset} --search_space ${search_space} | ||||
|     python ./exps/algos-v2/bohb.py --dataset ${dataset} --search_space ${search_space} --num_samples 4 --random_fraction 0.0 --bandwidth_factor 3 | ||||
|     for search_space in ${search_spaces} | ||||
|     do | ||||
|       python ./exps/algos-v2/reinforce.py --dataset ${dataset} --search_space ${search_space} --learning_rate 0.01 | ||||
|       python ./exps/algos-v2/regularized_ea.py --dataset ${dataset} --search_space ${search_space} --ea_cycles 200 --ea_population 10 --ea_sample_size 3 | ||||
|       python ./exps/algos-v2/random_wo_share.py --dataset ${dataset} --search_space ${search_space} | ||||
|       python ./exps/algos-v2/bohb.py --dataset ${dataset} --search_space ${search_space} --num_samples 4 --random_fraction 0.0 --bandwidth_factor 3 | ||||
|     done | ||||
|   done | ||||
| done | ||||
|  | ||||
| python exps/experimental/vis-bench-algos.py --search_space tss | ||||
| python exps/experimental/vis-bench-algos.py --search_space sss | ||||
|   python exps/experimental/vis-bench-algos.py --search_space tss | ||||
|   python exps/experimental/vis-bench-algos.py --search_space sss | ||||
| else | ||||
|   seeds="777 888 999" | ||||
|   epoch=200 | ||||
|   for seed in ${seeds} | ||||
|   do | ||||
|     for alg in "darts-v1 darts-v2 gdas setn random enas" | ||||
|     do | ||||
|     python ./exps/algos-v2/search-cell.py --dataset cifar10  --data_path $TORCH_HOME/cifar.python --algo ${alg} --rand_seed ${seed} --overwite_epochs ${epoch} | ||||
|     python ./exps/algos-v2/search-cell.py --dataset cifar100  --data_path $TORCH_HOME/cifar.python --algo ${alg} --rand_seed ${seed} --overwite_epochs ${epoch} | ||||
|     python ./exps/algos-v2/search-cell.py --dataset ImageNet16-120  --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo ${alg} --rand_seed ${seed} --overwite_epochs ${epoch} | ||||
|     done | ||||
|   done | ||||
| fi | ||||
|  | ||||
|   | ||||
| @@ -22,8 +22,8 @@ | ||||
| # python ./exps/algos-v2/search-cell.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo random | ||||
| #### | ||||
| # python ./exps/algos-v2/search-cell.py --dataset cifar10  --data_path $TORCH_HOME/cifar.python --algo enas --arch_weight_decay 0 --arch_learning_rate 0.001 --arch_eps 0.001 --rand_seed 777 | ||||
| # python ./exps/algos-v2/search-cell.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo enas | ||||
| # python ./exps/algos-v2/search-cell.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo enas | ||||
| # python ./exps/algos-v2/search-cell.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo enas --arch_weight_decay 0 --arch_learning_rate 0.001 --arch_eps 0.001 --rand_seed 777 | ||||
| # python ./exps/algos-v2/search-cell.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo enas --arch_weight_decay 0 --arch_learning_rate 0.001 --arch_eps 0.001 --rand_seed 777 | ||||
| ###################################################################################### | ||||
| import os, sys, time, random, argparse | ||||
| import numpy as np | ||||
| @@ -333,7 +333,11 @@ def main(xargs): | ||||
|   logger = prepare_logger(args) | ||||
|  | ||||
|   train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1) | ||||
|   config = load_config(xargs.config_path, {'class_num': class_num, 'xshape': xshape}, logger) | ||||
|   if xargs.overwite_epochs is None: | ||||
|     extra_info = {'class_num': class_num, 'xshape': xshape} | ||||
|   else: | ||||
|     extra_info = {'class_num': class_num, 'xshape': xshape, 'epochs': xargs.overwite_epochs} | ||||
|   config = load_config(xargs.config_path, extra_info, logger) | ||||
|   search_loader, train_loader, valid_loader = get_nas_search_loaders(train_data, valid_data, xargs.dataset, 'configs/nas-benchmark/', \ | ||||
|                                         (config.batch_size, config.test_batch_size), xargs.workers) | ||||
|   logger.log('||||||| {:10s} ||||||| Search-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'.format(xargs.dataset, len(search_loader), len(valid_loader), config.batch_size)) | ||||
| @@ -496,6 +500,7 @@ if __name__ == '__main__': | ||||
|   parser.add_argument('--track_running_stats',type=int,   default=0, choices=[0,1],help='Whether use track_running_stats or not in the BN layer.') | ||||
|   parser.add_argument('--affine'      ,       type=int,   default=0, choices=[0,1],help='Whether use affine=True or False in the BN layer.') | ||||
|   parser.add_argument('--config_path' ,       type=str,   default='./configs/nas-benchmark/algos/weight-sharing.config', help='The path of configuration.') | ||||
|   parser.add_argument('--overwite_epochs',    type=int,   help='The number of epochs to overwrite that value in config files.') | ||||
|   # architecture leraning rate | ||||
|   parser.add_argument('--arch_learning_rate', type=float, default=3e-4, help='learning rate for arch encoding') | ||||
|   parser.add_argument('--arch_weight_decay' , type=float, default=1e-3, help='weight decay for arch encoding') | ||||
| @@ -508,8 +513,13 @@ if __name__ == '__main__': | ||||
|   parser.add_argument('--rand_seed',          type=int,   help='manual seed') | ||||
|   args = parser.parse_args() | ||||
|   if args.rand_seed is None or args.rand_seed < 0: args.rand_seed = random.randint(1, 100000) | ||||
|   args.save_dir = os.path.join('{:}-{:}'.format(args.save_dir, args.search_space), | ||||
|                                args.dataset, | ||||
|                                '{:}-affine{:}_BN{:}-{:}'.format(args.algo, args.affine, args.track_running_stats, args.drop_path_rate)) | ||||
|   if args.overwite_epochs is None: | ||||
|     args.save_dir = os.path.join('{:}-{:}'.format(args.save_dir, args.search_space), | ||||
|         args.dataset, | ||||
|         '{:}-affine{:}_BN{:}-{:}'.format(args.algo, args.affine, args.track_running_stats, args.drop_path_rate)) | ||||
|   else: | ||||
|     args.save_dir = os.path.join('{:}-{:}'.format(args.save_dir, args.search_space), | ||||
|         args.dataset, | ||||
|         '{:}-affine{:}_BN{:}-E{:}-{:}'.format(args.algo, args.affine, args.track_running_stats, args.overwite_epochs, args.drop_path_rate)) | ||||
|  | ||||
|   main(args) | ||||
|   | ||||
| @@ -30,12 +30,12 @@ def fetch_data(root_dir='./output/search', search_space='tss', dataset=None): | ||||
|   ss_dir = '{:}-{:}'.format(root_dir, search_space) | ||||
|   alg2name, alg2path = OrderedDict(), OrderedDict() | ||||
|   seeds = [777] | ||||
|   alg2name['GDAS'] = 'gdas-affine1_BN0-None' | ||||
|   alg2name['GDAS'] = 'gdas-affine0_BN0-None' | ||||
|   alg2name['RSPS'] = 'random-affine0_BN0-None' | ||||
|   """ | ||||
|   alg2name['DARTS (1st)'] = 'darts-v1-affine1_BN0-None' | ||||
|   alg2name['DARTS (2nd)'] = 'darts-v2-affine1_BN0-None' | ||||
|   alg2name['SETN'] = 'setn-affine1_BN0-None' | ||||
|   alg2name['RSPS'] = 'random-affine1_BN0-None' | ||||
|   """ | ||||
|   for alg, name in alg2name.items(): | ||||
|     alg2path[alg] = os.path.join(ss_dir, dataset, name, 'seed-{:}-last-info.pth') | ||||
| @@ -76,7 +76,7 @@ def visualize_curve(api, vis_save_dir, search_space): | ||||
|   def sub_plot_fn(ax, dataset): | ||||
|     alg2data = fetch_data(search_space=search_space, dataset=dataset) | ||||
|     alg2accuracies = OrderedDict() | ||||
|     epochs = 20 | ||||
|     epochs = 100 | ||||
|     colors = ['b', 'g', 'c', 'm', 'y'] | ||||
|     ax.set_xlim(0, epochs) | ||||
|     # ax.set_ylim(y_min_s[(dataset, search_space)], y_max_s[(dataset, search_space)]) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user