Prototype generic nas model (cont.).
This commit is contained in:
parent
c34620ab1b
commit
31a896346a
@ -2,7 +2,7 @@
|
||||
"scheduler": ["str", "cos"],
|
||||
"LR" : ["float", "0.025"],
|
||||
"eta_min" : ["float", "0.001"],
|
||||
"epochs" : ["int", "150"],
|
||||
"epochs" : ["int", "100"],
|
||||
"warmup" : ["int", "0"],
|
||||
"optim" : ["str", "SGD"],
|
||||
"decay" : ["float", "0.0005"],
|
||||
|
@ -12,6 +12,10 @@
|
||||
# python ./exps/algos-v2/search-cell.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo gdas --rand_seed 1
|
||||
# python ./exps/algos-v2/search-cell.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo gdas
|
||||
# python ./exps/algos-v2/search-cell.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo gdas
|
||||
####
|
||||
# python ./exps/algos-v2/search-cell.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo setn --rand_seed 1
|
||||
# python ./exps/algos-v2/search-cell.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo setn
|
||||
# python ./exps/algos-v2/search-cell.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo setn
|
||||
######################################################################################
|
||||
import os, sys, time, random, argparse
|
||||
import numpy as np
|
||||
@ -252,6 +256,7 @@ def main(xargs):
|
||||
logger.log('model config : {:}'.format(model_config))
|
||||
search_model = get_cell_based_tiny_net(model_config)
|
||||
search_model.set_algo(xargs.algo)
|
||||
logger.log('{:}'.format(search_model))
|
||||
|
||||
w_optimizer, w_scheduler, criterion = get_optim_scheduler(search_model.weights, config)
|
||||
a_optimizer = torch.optim.Adam(search_model.alphas, lr=xargs.arch_learning_rate, betas=(0.5, 0.999), weight_decay=xargs.arch_weight_decay)
|
||||
@ -396,6 +401,6 @@ if __name__ == '__main__':
|
||||
parser.add_argument('--rand_seed', type=int, help='manual seed')
|
||||
args = parser.parse_args()
|
||||
if args.rand_seed is None or args.rand_seed < 0: args.rand_seed = random.randint(1, 100000)
|
||||
args.save_dir = os.path.join('{:}-{:}'.format(args.save_dir, args.search_space), args.dataset, args.algo)
|
||||
args.save_dir = os.path.join('{:}-{:}'.format(args.save_dir, args.search_space), args.dataset, '{:}-{:}'.format(args.algo, args.drop_path_rate))
|
||||
|
||||
main(args)
|
||||
|
Loading…
Reference in New Issue
Block a user