Prototype generic nas model (cont.).

This commit is contained in:
D-X-Y 2020-07-19 08:29:08 +00:00
parent c34620ab1b
commit 31a896346a
2 changed files with 7 additions and 2 deletions

View File

@ -2,7 +2,7 @@
"scheduler": ["str", "cos"],
"LR" : ["float", "0.025"],
"eta_min" : ["float", "0.001"],
"epochs" : ["int", "150"],
"epochs" : ["int", "100"],
"warmup" : ["int", "0"],
"optim" : ["str", "SGD"],
"decay" : ["float", "0.0005"],

View File

@ -12,6 +12,10 @@
# python ./exps/algos-v2/search-cell.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo gdas --rand_seed 1
# python ./exps/algos-v2/search-cell.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo gdas
# python ./exps/algos-v2/search-cell.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo gdas
####
# python ./exps/algos-v2/search-cell.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo setn --rand_seed 1
# python ./exps/algos-v2/search-cell.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo setn
# python ./exps/algos-v2/search-cell.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo setn
######################################################################################
import os, sys, time, random, argparse
import numpy as np
@ -252,6 +256,7 @@ def main(xargs):
logger.log('model config : {:}'.format(model_config))
search_model = get_cell_based_tiny_net(model_config)
search_model.set_algo(xargs.algo)
logger.log('{:}'.format(search_model))
w_optimizer, w_scheduler, criterion = get_optim_scheduler(search_model.weights, config)
a_optimizer = torch.optim.Adam(search_model.alphas, lr=xargs.arch_learning_rate, betas=(0.5, 0.999), weight_decay=xargs.arch_weight_decay)
@ -396,6 +401,6 @@ if __name__ == '__main__':
parser.add_argument('--rand_seed', type=int, help='manual seed')
args = parser.parse_args()
if args.rand_seed is None or args.rand_seed < 0: args.rand_seed = random.randint(1, 100000)
args.save_dir = os.path.join('{:}-{:}'.format(args.save_dir, args.search_space), args.dataset, args.algo)
args.save_dir = os.path.join('{:}-{:}'.format(args.save_dir, args.search_space), args.dataset, '{:}-{:}'.format(args.algo, args.drop_path_rate))
main(args)