Prototype generic nas model (cont.).
This commit is contained in:
		| @@ -2,7 +2,7 @@ | |||||||
|   "scheduler": ["str",   "cos"], |   "scheduler": ["str",   "cos"], | ||||||
|   "LR"       : ["float", "0.025"], |   "LR"       : ["float", "0.025"], | ||||||
|   "eta_min"  : ["float", "0.001"], |   "eta_min"  : ["float", "0.001"], | ||||||
|   "epochs"   : ["int",   "150"], |   "epochs"   : ["int",   "100"], | ||||||
|   "warmup"   : ["int",   "0"], |   "warmup"   : ["int",   "0"], | ||||||
|   "optim"    : ["str",   "SGD"], |   "optim"    : ["str",   "SGD"], | ||||||
|   "decay"    : ["float", "0.0005"], |   "decay"    : ["float", "0.0005"], | ||||||
|   | |||||||
| @@ -12,6 +12,10 @@ | |||||||
| # python ./exps/algos-v2/search-cell.py --dataset cifar10  --data_path $TORCH_HOME/cifar.python --algo gdas --rand_seed 1 | # python ./exps/algos-v2/search-cell.py --dataset cifar10  --data_path $TORCH_HOME/cifar.python --algo gdas --rand_seed 1 | ||||||
| # python ./exps/algos-v2/search-cell.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo gdas | # python ./exps/algos-v2/search-cell.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo gdas | ||||||
| # python ./exps/algos-v2/search-cell.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo gdas | # python ./exps/algos-v2/search-cell.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo gdas | ||||||
|  | #### | ||||||
|  | # python ./exps/algos-v2/search-cell.py --dataset cifar10  --data_path $TORCH_HOME/cifar.python --algo setn --rand_seed 1 | ||||||
|  | # python ./exps/algos-v2/search-cell.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo setn | ||||||
|  | # python ./exps/algos-v2/search-cell.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo setn | ||||||
| ###################################################################################### | ###################################################################################### | ||||||
| import os, sys, time, random, argparse | import os, sys, time, random, argparse | ||||||
| import numpy as np | import numpy as np | ||||||
| @@ -252,6 +256,7 @@ def main(xargs): | |||||||
|   logger.log('model config : {:}'.format(model_config)) |   logger.log('model config : {:}'.format(model_config)) | ||||||
|   search_model = get_cell_based_tiny_net(model_config) |   search_model = get_cell_based_tiny_net(model_config) | ||||||
|   search_model.set_algo(xargs.algo) |   search_model.set_algo(xargs.algo) | ||||||
|  |   logger.log('{:}'.format(search_model)) | ||||||
|  |  | ||||||
|   w_optimizer, w_scheduler, criterion = get_optim_scheduler(search_model.weights, config) |   w_optimizer, w_scheduler, criterion = get_optim_scheduler(search_model.weights, config) | ||||||
|   a_optimizer = torch.optim.Adam(search_model.alphas, lr=xargs.arch_learning_rate, betas=(0.5, 0.999), weight_decay=xargs.arch_weight_decay) |   a_optimizer = torch.optim.Adam(search_model.alphas, lr=xargs.arch_learning_rate, betas=(0.5, 0.999), weight_decay=xargs.arch_weight_decay) | ||||||
| @@ -396,6 +401,6 @@ if __name__ == '__main__': | |||||||
|   parser.add_argument('--rand_seed',          type=int,   help='manual seed') |   parser.add_argument('--rand_seed',          type=int,   help='manual seed') | ||||||
|   args = parser.parse_args() |   args = parser.parse_args() | ||||||
|   if args.rand_seed is None or args.rand_seed < 0: args.rand_seed = random.randint(1, 100000) |   if args.rand_seed is None or args.rand_seed < 0: args.rand_seed = random.randint(1, 100000) | ||||||
|   args.save_dir = os.path.join('{:}-{:}'.format(args.save_dir, args.search_space), args.dataset, args.algo) |   args.save_dir = os.path.join('{:}-{:}'.format(args.save_dir, args.search_space), args.dataset, '{:}-{:}'.format(args.algo, args.drop_path_rate)) | ||||||
|  |  | ||||||
|   main(args) |   main(args) | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user