Update VIS-CODES and SCRIPTS
This commit is contained in:
parent
8d27050f6f
commit
a2a1abcb7d
@ -1,22 +1,46 @@
|
|||||||
#!/bin/bash
|
#!/bin/bash
|
||||||
# bash ./exps/algos-v2/run-all.sh
|
# bash ./exps/algos-v2/run-all.sh mul
|
||||||
|
# bash ./exps/algos-v2/run-all.sh ws
|
||||||
set -e
|
set -e
|
||||||
echo script name: $0
|
echo script name: $0
|
||||||
echo $# arguments
|
echo $# arguments
|
||||||
|
if [ "$#" -ne 1 ] ;then
|
||||||
|
echo "Input illegal number of parameters " $#
|
||||||
|
echo "Need 1 parameters for type of algorithms."
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
|
||||||
datasets="cifar10 cifar100 ImageNet16-120"
|
datasets="cifar10 cifar100 ImageNet16-120"
|
||||||
search_spaces="tss sss"
|
alg_type=$1
|
||||||
|
|
||||||
for dataset in ${datasets}
|
if [ "$alg_type" == "mul" ]; then
|
||||||
do
|
search_spaces="tss sss"
|
||||||
for search_space in ${search_spaces}
|
|
||||||
|
for dataset in ${datasets}
|
||||||
do
|
do
|
||||||
python ./exps/algos-v2/reinforce.py --dataset ${dataset} --search_space ${search_space} --learning_rate 0.01
|
for search_space in ${search_spaces}
|
||||||
python ./exps/algos-v2/regularized_ea.py --dataset ${dataset} --search_space ${search_space} --ea_cycles 200 --ea_population 10 --ea_sample_size 3
|
do
|
||||||
python ./exps/algos-v2/random_wo_share.py --dataset ${dataset} --search_space ${search_space}
|
python ./exps/algos-v2/reinforce.py --dataset ${dataset} --search_space ${search_space} --learning_rate 0.01
|
||||||
python ./exps/algos-v2/bohb.py --dataset ${dataset} --search_space ${search_space} --num_samples 4 --random_fraction 0.0 --bandwidth_factor 3
|
python ./exps/algos-v2/regularized_ea.py --dataset ${dataset} --search_space ${search_space} --ea_cycles 200 --ea_population 10 --ea_sample_size 3
|
||||||
|
python ./exps/algos-v2/random_wo_share.py --dataset ${dataset} --search_space ${search_space}
|
||||||
|
python ./exps/algos-v2/bohb.py --dataset ${dataset} --search_space ${search_space} --num_samples 4 --random_fraction 0.0 --bandwidth_factor 3
|
||||||
|
done
|
||||||
done
|
done
|
||||||
done
|
|
||||||
|
|
||||||
python exps/experimental/vis-bench-algos.py --search_space tss
|
python exps/experimental/vis-bench-algos.py --search_space tss
|
||||||
python exps/experimental/vis-bench-algos.py --search_space sss
|
python exps/experimental/vis-bench-algos.py --search_space sss
|
||||||
|
else
|
||||||
|
seeds="777 888 999"
|
||||||
|
epoch=200
|
||||||
|
for seed in ${seeds}
|
||||||
|
do
|
||||||
|
for alg in "darts-v1 darts-v2 gdas setn random enas"
|
||||||
|
do
|
||||||
|
python ./exps/algos-v2/search-cell.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo ${alg} --rand_seed ${seed} --overwite_epochs ${epoch}
|
||||||
|
python ./exps/algos-v2/search-cell.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo ${alg} --rand_seed ${seed} --overwite_epochs ${epoch}
|
||||||
|
python ./exps/algos-v2/search-cell.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo ${alg} --rand_seed ${seed} --overwite_epochs ${epoch}
|
||||||
|
done
|
||||||
|
done
|
||||||
|
fi
|
||||||
|
|
||||||
|
@ -22,8 +22,8 @@
|
|||||||
# python ./exps/algos-v2/search-cell.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo random
|
# python ./exps/algos-v2/search-cell.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo random
|
||||||
####
|
####
|
||||||
# python ./exps/algos-v2/search-cell.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo enas --arch_weight_decay 0 --arch_learning_rate 0.001 --arch_eps 0.001 --rand_seed 777
|
# python ./exps/algos-v2/search-cell.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo enas --arch_weight_decay 0 --arch_learning_rate 0.001 --arch_eps 0.001 --rand_seed 777
|
||||||
# python ./exps/algos-v2/search-cell.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo enas
|
# python ./exps/algos-v2/search-cell.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo enas --arch_weight_decay 0 --arch_learning_rate 0.001 --arch_eps 0.001 --rand_seed 777
|
||||||
# python ./exps/algos-v2/search-cell.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo enas
|
# python ./exps/algos-v2/search-cell.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo enas --arch_weight_decay 0 --arch_learning_rate 0.001 --arch_eps 0.001 --rand_seed 777
|
||||||
######################################################################################
|
######################################################################################
|
||||||
import os, sys, time, random, argparse
|
import os, sys, time, random, argparse
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@ -333,7 +333,11 @@ def main(xargs):
|
|||||||
logger = prepare_logger(args)
|
logger = prepare_logger(args)
|
||||||
|
|
||||||
train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1)
|
train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1)
|
||||||
config = load_config(xargs.config_path, {'class_num': class_num, 'xshape': xshape}, logger)
|
if xargs.overwite_epochs is None:
|
||||||
|
extra_info = {'class_num': class_num, 'xshape': xshape}
|
||||||
|
else:
|
||||||
|
extra_info = {'class_num': class_num, 'xshape': xshape, 'epochs': xargs.overwite_epochs}
|
||||||
|
config = load_config(xargs.config_path, extra_info, logger)
|
||||||
search_loader, train_loader, valid_loader = get_nas_search_loaders(train_data, valid_data, xargs.dataset, 'configs/nas-benchmark/', \
|
search_loader, train_loader, valid_loader = get_nas_search_loaders(train_data, valid_data, xargs.dataset, 'configs/nas-benchmark/', \
|
||||||
(config.batch_size, config.test_batch_size), xargs.workers)
|
(config.batch_size, config.test_batch_size), xargs.workers)
|
||||||
logger.log('||||||| {:10s} ||||||| Search-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'.format(xargs.dataset, len(search_loader), len(valid_loader), config.batch_size))
|
logger.log('||||||| {:10s} ||||||| Search-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'.format(xargs.dataset, len(search_loader), len(valid_loader), config.batch_size))
|
||||||
@ -496,6 +500,7 @@ if __name__ == '__main__':
|
|||||||
parser.add_argument('--track_running_stats',type=int, default=0, choices=[0,1],help='Whether use track_running_stats or not in the BN layer.')
|
parser.add_argument('--track_running_stats',type=int, default=0, choices=[0,1],help='Whether use track_running_stats or not in the BN layer.')
|
||||||
parser.add_argument('--affine' , type=int, default=0, choices=[0,1],help='Whether use affine=True or False in the BN layer.')
|
parser.add_argument('--affine' , type=int, default=0, choices=[0,1],help='Whether use affine=True or False in the BN layer.')
|
||||||
parser.add_argument('--config_path' , type=str, default='./configs/nas-benchmark/algos/weight-sharing.config', help='The path of configuration.')
|
parser.add_argument('--config_path' , type=str, default='./configs/nas-benchmark/algos/weight-sharing.config', help='The path of configuration.')
|
||||||
|
parser.add_argument('--overwite_epochs', type=int, help='The number of epochs to overwrite that value in config files.')
|
||||||
# architecture leraning rate
|
# architecture leraning rate
|
||||||
parser.add_argument('--arch_learning_rate', type=float, default=3e-4, help='learning rate for arch encoding')
|
parser.add_argument('--arch_learning_rate', type=float, default=3e-4, help='learning rate for arch encoding')
|
||||||
parser.add_argument('--arch_weight_decay' , type=float, default=1e-3, help='weight decay for arch encoding')
|
parser.add_argument('--arch_weight_decay' , type=float, default=1e-3, help='weight decay for arch encoding')
|
||||||
@ -508,8 +513,13 @@ if __name__ == '__main__':
|
|||||||
parser.add_argument('--rand_seed', type=int, help='manual seed')
|
parser.add_argument('--rand_seed', type=int, help='manual seed')
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
if args.rand_seed is None or args.rand_seed < 0: args.rand_seed = random.randint(1, 100000)
|
if args.rand_seed is None or args.rand_seed < 0: args.rand_seed = random.randint(1, 100000)
|
||||||
args.save_dir = os.path.join('{:}-{:}'.format(args.save_dir, args.search_space),
|
if args.overwite_epochs is None:
|
||||||
args.dataset,
|
args.save_dir = os.path.join('{:}-{:}'.format(args.save_dir, args.search_space),
|
||||||
'{:}-affine{:}_BN{:}-{:}'.format(args.algo, args.affine, args.track_running_stats, args.drop_path_rate))
|
args.dataset,
|
||||||
|
'{:}-affine{:}_BN{:}-{:}'.format(args.algo, args.affine, args.track_running_stats, args.drop_path_rate))
|
||||||
|
else:
|
||||||
|
args.save_dir = os.path.join('{:}-{:}'.format(args.save_dir, args.search_space),
|
||||||
|
args.dataset,
|
||||||
|
'{:}-affine{:}_BN{:}-E{:}-{:}'.format(args.algo, args.affine, args.track_running_stats, args.overwite_epochs, args.drop_path_rate))
|
||||||
|
|
||||||
main(args)
|
main(args)
|
||||||
|
@ -30,12 +30,12 @@ def fetch_data(root_dir='./output/search', search_space='tss', dataset=None):
|
|||||||
ss_dir = '{:}-{:}'.format(root_dir, search_space)
|
ss_dir = '{:}-{:}'.format(root_dir, search_space)
|
||||||
alg2name, alg2path = OrderedDict(), OrderedDict()
|
alg2name, alg2path = OrderedDict(), OrderedDict()
|
||||||
seeds = [777]
|
seeds = [777]
|
||||||
alg2name['GDAS'] = 'gdas-affine1_BN0-None'
|
alg2name['GDAS'] = 'gdas-affine0_BN0-None'
|
||||||
|
alg2name['RSPS'] = 'random-affine0_BN0-None'
|
||||||
"""
|
"""
|
||||||
alg2name['DARTS (1st)'] = 'darts-v1-affine1_BN0-None'
|
alg2name['DARTS (1st)'] = 'darts-v1-affine1_BN0-None'
|
||||||
alg2name['DARTS (2nd)'] = 'darts-v2-affine1_BN0-None'
|
alg2name['DARTS (2nd)'] = 'darts-v2-affine1_BN0-None'
|
||||||
alg2name['SETN'] = 'setn-affine1_BN0-None'
|
alg2name['SETN'] = 'setn-affine1_BN0-None'
|
||||||
alg2name['RSPS'] = 'random-affine1_BN0-None'
|
|
||||||
"""
|
"""
|
||||||
for alg, name in alg2name.items():
|
for alg, name in alg2name.items():
|
||||||
alg2path[alg] = os.path.join(ss_dir, dataset, name, 'seed-{:}-last-info.pth')
|
alg2path[alg] = os.path.join(ss_dir, dataset, name, 'seed-{:}-last-info.pth')
|
||||||
@ -76,7 +76,7 @@ def visualize_curve(api, vis_save_dir, search_space):
|
|||||||
def sub_plot_fn(ax, dataset):
|
def sub_plot_fn(ax, dataset):
|
||||||
alg2data = fetch_data(search_space=search_space, dataset=dataset)
|
alg2data = fetch_data(search_space=search_space, dataset=dataset)
|
||||||
alg2accuracies = OrderedDict()
|
alg2accuracies = OrderedDict()
|
||||||
epochs = 20
|
epochs = 100
|
||||||
colors = ['b', 'g', 'c', 'm', 'y']
|
colors = ['b', 'g', 'c', 'm', 'y']
|
||||||
ax.set_xlim(0, epochs)
|
ax.set_xlim(0, epochs)
|
||||||
# ax.set_ylim(y_min_s[(dataset, search_space)], y_max_s[(dataset, search_space)])
|
# ax.set_ylim(y_min_s[(dataset, search_space)], y_max_s[(dataset, search_space)])
|
||||||
|
Loading…
Reference in New Issue
Block a user