Prototype generic nas model (cont.).
This commit is contained in:
		| @@ -12,6 +12,10 @@ | ||||
| # python ./exps/algos-v2/search-cell.py --dataset cifar10  --data_path $TORCH_HOME/cifar.python --algo gdas --rand_seed 1 | ||||
| # python ./exps/algos-v2/search-cell.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo gdas | ||||
| # python ./exps/algos-v2/search-cell.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo gdas | ||||
| #### | ||||
| # python ./exps/algos-v2/search-cell.py --dataset cifar10  --data_path $TORCH_HOME/cifar.python --algo setn --rand_seed 1 | ||||
| # python ./exps/algos-v2/search-cell.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo setn | ||||
| # python ./exps/algos-v2/search-cell.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo setn | ||||
| ###################################################################################### | ||||
| import os, sys, time, random, argparse | ||||
| import numpy as np | ||||
| @@ -252,6 +256,7 @@ def main(xargs): | ||||
|   logger.log('model config : {:}'.format(model_config)) | ||||
|   search_model = get_cell_based_tiny_net(model_config) | ||||
|   search_model.set_algo(xargs.algo) | ||||
|   logger.log('{:}'.format(search_model)) | ||||
|  | ||||
|   w_optimizer, w_scheduler, criterion = get_optim_scheduler(search_model.weights, config) | ||||
|   a_optimizer = torch.optim.Adam(search_model.alphas, lr=xargs.arch_learning_rate, betas=(0.5, 0.999), weight_decay=xargs.arch_weight_decay) | ||||
| @@ -396,6 +401,6 @@ if __name__ == '__main__': | ||||
|   parser.add_argument('--rand_seed',          type=int,   help='manual seed') | ||||
|   args = parser.parse_args() | ||||
|   if args.rand_seed is None or args.rand_seed < 0: args.rand_seed = random.randint(1, 100000) | ||||
|   args.save_dir = os.path.join('{:}-{:}'.format(args.save_dir, args.search_space), args.dataset, args.algo) | ||||
|   args.save_dir = os.path.join('{:}-{:}'.format(args.save_dir, args.search_space), args.dataset, '{:}-{:}'.format(args.algo, args.drop_path_rate)) | ||||
|  | ||||
|   main(args) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user