Prototype generic nas model.
This commit is contained in:
		
							
								
								
									
										14
									
								
								configs/nas-benchmark/algos/weight-sharing.config
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										14
									
								
								configs/nas-benchmark/algos/weight-sharing.config
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,14 @@ | |||||||
|  | { | ||||||
|  |   "scheduler": ["str",   "cos"], | ||||||
|  |   "LR"       : ["float", "0.025"], | ||||||
|  |   "eta_min"  : ["float", "0.001"], | ||||||
|  |   "epochs"   : ["int",   "250"], | ||||||
|  |   "warmup"   : ["int",   "0"], | ||||||
|  |   "optim"    : ["str",   "SGD"], | ||||||
|  |   "decay"    : ["float", "0.0005"], | ||||||
|  |   "momentum" : ["float", "0.9"], | ||||||
|  |   "nesterov" : ["bool",  "1"], | ||||||
|  |   "criterion": ["str",   "Softmax"], | ||||||
|  |   "batch_size": ["int",  "64"], | ||||||
|  |   "test_batch_size": ["int",  "512"] | ||||||
|  | } | ||||||
| @@ -14,7 +14,7 @@ do | |||||||
|     python ./exps/algos-v2/reinforce.py --dataset ${dataset} --search_space ${search_space} --learning_rate 0.001 |     python ./exps/algos-v2/reinforce.py --dataset ${dataset} --search_space ${search_space} --learning_rate 0.001 | ||||||
|     python ./exps/algos-v2/regularized_ea.py --dataset ${dataset} --search_space ${search_space} --ea_cycles 200 --ea_population 10 --ea_sample_size 3 |     python ./exps/algos-v2/regularized_ea.py --dataset ${dataset} --search_space ${search_space} --ea_cycles 200 --ea_population 10 --ea_sample_size 3 | ||||||
|     python ./exps/algos-v2/random_wo_share.py --dataset ${dataset} --search_space ${search_space} |     python ./exps/algos-v2/random_wo_share.py --dataset ${dataset} --search_space ${search_space} | ||||||
|     python exps/algos-v2/bohb.py --dataset ${dataset} --search_space ${search_space} --num_samples 4 --random_fraction 0.0 --bandwidth_factor 3 |     python ./exps/algos-v2/bohb.py --dataset ${dataset} --search_space ${search_space} --num_samples 4 --random_fraction 0.0 --bandwidth_factor 3 | ||||||
|   done |   done | ||||||
| done | done | ||||||
|  |  | ||||||
|   | |||||||
							
								
								
									
										290
									
								
								exps/algos-v2/search-cell.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										290
									
								
								exps/algos-v2/search-cell.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,290 @@ | |||||||
|  | ################################################## | ||||||
|  | # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020 # | ||||||
|  | ###################################################################################### | ||||||
|  | # python ./exps/algos-v2/search-cell.py --dataset cifar10  --data_path $TORCH_HOME/cifar.python --algo darts-v1 --rand_seed 1 | ||||||
|  | # python ./exps/algos-v2/search-cell.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo darts-v1 | ||||||
|  | # python ./exps/algos-v2/search-cell.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo darts-v1 | ||||||
|  | ###################################################################################### | ||||||
|  | import os, sys, time, random, argparse | ||||||
|  | import numpy as np | ||||||
|  | from copy import deepcopy | ||||||
|  | import torch | ||||||
|  | import torch.nn as nn | ||||||
|  | from pathlib import Path | ||||||
|  | lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve() | ||||||
|  | if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) | ||||||
|  | from config_utils import load_config, dict2config, configure2str | ||||||
|  | from datasets     import get_datasets, get_nas_search_loaders | ||||||
|  | from procedures   import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint, get_optim_scheduler | ||||||
|  | from utils        import count_parameters_in_MB, obtain_accuracy | ||||||
|  | from log_utils    import AverageMeter, time_string, convert_secs2time | ||||||
|  | from models       import get_cell_based_tiny_net, get_search_spaces | ||||||
|  | from nas_201_api  import NASBench201API as API | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def search_func(xloader, network, criterion, scheduler, w_optimizer, a_optimizer, epoch_str, print_freq, logger): | ||||||
|  |   data_time, batch_time = AverageMeter(), AverageMeter() | ||||||
|  |   base_losses, base_top1, base_top5 = AverageMeter(), AverageMeter(), AverageMeter() | ||||||
|  |   arch_losses, arch_top1, arch_top5 = AverageMeter(), AverageMeter(), AverageMeter() | ||||||
|  |   end = time.time() | ||||||
|  |   network.train() | ||||||
|  |   for step, (base_inputs, base_targets, arch_inputs, arch_targets) in enumerate(xloader): | ||||||
|  |     scheduler.update(None, 1.0 * step / len(xloader)) | ||||||
|  |     base_targets = base_targets.cuda(non_blocking=True) | ||||||
|  |     arch_targets = arch_targets.cuda(non_blocking=True) | ||||||
|  |     # measure data loading time | ||||||
|  |     data_time.update(time.time() - end) | ||||||
|  |      | ||||||
|  |     # update the weights | ||||||
|  |     sampled_arch = network.module.dync_genotype(True) | ||||||
|  |     network.module.set_cal_mode('dynamic', sampled_arch) | ||||||
|  |     #network.module.set_cal_mode( 'urs' ) | ||||||
|  |     network.zero_grad() | ||||||
|  |     _, logits = network(base_inputs) | ||||||
|  |     base_loss = criterion(logits, base_targets) | ||||||
|  |     base_loss.backward() | ||||||
|  |     w_optimizer.step() | ||||||
|  |     # record | ||||||
|  |     base_prec1, base_prec5 = obtain_accuracy(logits.data, base_targets.data, topk=(1, 5)) | ||||||
|  |     base_losses.update(base_loss.item(),  base_inputs.size(0)) | ||||||
|  |     base_top1.update  (base_prec1.item(), base_inputs.size(0)) | ||||||
|  |     base_top5.update  (base_prec5.item(), base_inputs.size(0)) | ||||||
|  |  | ||||||
|  |     # update the architecture-weight | ||||||
|  |     network.module.set_cal_mode( 'joint' ) | ||||||
|  |     network.zero_grad() | ||||||
|  |     _, logits = network(arch_inputs) | ||||||
|  |     arch_loss = criterion(logits, arch_targets) | ||||||
|  |     arch_loss.backward() | ||||||
|  |     a_optimizer.step() | ||||||
|  |     # record | ||||||
|  |     arch_prec1, arch_prec5 = obtain_accuracy(logits.data, arch_targets.data, topk=(1, 5)) | ||||||
|  |     arch_losses.update(arch_loss.item(),  arch_inputs.size(0)) | ||||||
|  |     arch_top1.update  (arch_prec1.item(), arch_inputs.size(0)) | ||||||
|  |     arch_top5.update  (arch_prec5.item(), arch_inputs.size(0)) | ||||||
|  |  | ||||||
|  |     # measure elapsed time | ||||||
|  |     batch_time.update(time.time() - end) | ||||||
|  |     end = time.time() | ||||||
|  |  | ||||||
|  |     if step % print_freq == 0 or step + 1 == len(xloader): | ||||||
|  |       Sstr = '*SEARCH* ' + time_string() + ' [{:}][{:03d}/{:03d}]'.format(epoch_str, step, len(xloader)) | ||||||
|  |       Tstr = 'Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})'.format(batch_time=batch_time, data_time=data_time) | ||||||
|  |       Wstr = 'Base [Loss {loss.val:.3f} ({loss.avg:.3f})  Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})]'.format(loss=base_losses, top1=base_top1, top5=base_top5) | ||||||
|  |       Astr = 'Arch [Loss {loss.val:.3f} ({loss.avg:.3f})  Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})]'.format(loss=arch_losses, top1=arch_top1, top5=arch_top5) | ||||||
|  |       logger.log(Sstr + ' ' + Tstr + ' ' + Wstr + ' ' + Astr) | ||||||
|  |       #print (nn.functional.softmax(network.module.arch_parameters, dim=-1)) | ||||||
|  |       #print (network.module.arch_parameters) | ||||||
|  |   return base_losses.avg, base_top1.avg, base_top5.avg, arch_losses.avg, arch_top1.avg, arch_top5.avg | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def get_best_arch(xloader, network, n_samples): | ||||||
|  |   with torch.no_grad(): | ||||||
|  |     network.eval() | ||||||
|  |     archs, valid_accs = network.module.return_topK(n_samples), [] | ||||||
|  |     #print ('obtain the top-{:} architectures'.format(n_samples)) | ||||||
|  |     loader_iter = iter(xloader) | ||||||
|  |     for i, sampled_arch in enumerate(archs): | ||||||
|  |       network.module.set_cal_mode('dynamic', sampled_arch) | ||||||
|  |       try: | ||||||
|  |         inputs, targets = next(loader_iter) | ||||||
|  |       except: | ||||||
|  |         loader_iter = iter(xloader) | ||||||
|  |         inputs, targets = next(loader_iter) | ||||||
|  |  | ||||||
|  |       _, logits = network(inputs) | ||||||
|  |       val_top1, val_top5 = obtain_accuracy(logits.cpu().data, targets.data, topk=(1, 5)) | ||||||
|  |  | ||||||
|  |       valid_accs.append(val_top1.item()) | ||||||
|  |  | ||||||
|  |     best_idx = np.argmax(valid_accs) | ||||||
|  |     best_arch, best_valid_acc = archs[best_idx], valid_accs[best_idx] | ||||||
|  |     return best_arch, best_valid_acc | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def valid_func(xloader, network, criterion): | ||||||
|  |   data_time, batch_time = AverageMeter(), AverageMeter() | ||||||
|  |   arch_losses, arch_top1, arch_top5 = AverageMeter(), AverageMeter(), AverageMeter() | ||||||
|  |   end = time.time() | ||||||
|  |   with torch.no_grad(): | ||||||
|  |     network.eval() | ||||||
|  |     for step, (arch_inputs, arch_targets) in enumerate(xloader): | ||||||
|  |       arch_targets = arch_targets.cuda(non_blocking=True) | ||||||
|  |       # measure data loading time | ||||||
|  |       data_time.update(time.time() - end) | ||||||
|  |       # prediction | ||||||
|  |       _, logits = network(arch_inputs) | ||||||
|  |       arch_loss = criterion(logits, arch_targets) | ||||||
|  |       # record | ||||||
|  |       arch_prec1, arch_prec5 = obtain_accuracy(logits.data, arch_targets.data, topk=(1, 5)) | ||||||
|  |       arch_losses.update(arch_loss.item(),  arch_inputs.size(0)) | ||||||
|  |       arch_top1.update  (arch_prec1.item(), arch_inputs.size(0)) | ||||||
|  |       arch_top5.update  (arch_prec5.item(), arch_inputs.size(0)) | ||||||
|  |       # measure elapsed time | ||||||
|  |       batch_time.update(time.time() - end) | ||||||
|  |       end = time.time() | ||||||
|  |   return arch_losses.avg, arch_top1.avg, arch_top5.avg | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def main(xargs): | ||||||
|  |   assert torch.cuda.is_available(), 'CUDA is not available.' | ||||||
|  |   torch.backends.cudnn.enabled   = True | ||||||
|  |   torch.backends.cudnn.benchmark = False | ||||||
|  |   torch.backends.cudnn.deterministic = True | ||||||
|  |   torch.set_num_threads( xargs.workers ) | ||||||
|  |   prepare_seed(xargs.rand_seed) | ||||||
|  |   logger = prepare_logger(args) | ||||||
|  |  | ||||||
|  |   train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1) | ||||||
|  |   config = load_config(xargs.config_path, {'class_num': class_num, 'xshape': xshape}, logger) | ||||||
|  |   search_loader, _, valid_loader = get_nas_search_loaders(train_data, valid_data, xargs.dataset, 'configs/nas-benchmark/', \ | ||||||
|  |                                         (config.batch_size, config.test_batch_size), xargs.workers) | ||||||
|  |   logger.log('||||||| {:10s} ||||||| Search-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'.format(xargs.dataset, len(search_loader), len(valid_loader), config.batch_size)) | ||||||
|  |   logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config)) | ||||||
|  |  | ||||||
|  |   search_space = get_search_spaces(xargs.search_space, 'nas-bench-301') | ||||||
|  |  | ||||||
|  |    | ||||||
|  |   model_config = dict2config( | ||||||
|  |       dict(name='generic', C=xargs.channel, N=xargs.num_cells, max_nodes=xargs.max_nodes, num_classes=class_num, | ||||||
|  |            space=search_space, affine=bool(xargs.affine), track_running_stats=bool(xargs.track_running_stats)), None) | ||||||
|  |   logger.log('search space : {:}'.format(search_space)) | ||||||
|  |   logger.log('model config : {:}'.format(model_config)) | ||||||
|  |   search_model = get_cell_based_tiny_net(model_config) | ||||||
|  |   search_model.set_algo(xargs.algo) | ||||||
|  |  | ||||||
|  |   w_optimizer, w_scheduler, criterion = get_optim_scheduler(search_model.weights, config) | ||||||
|  |   a_optimizer = torch.optim.Adam(search_model.alphas, lr=xargs.arch_learning_rate, betas=(0.5, 0.999), weight_decay=xargs.arch_weight_decay) | ||||||
|  |   logger.log('w-optimizer : {:}'.format(w_optimizer)) | ||||||
|  |   logger.log('a-optimizer : {:}'.format(a_optimizer)) | ||||||
|  |   logger.log('w-scheduler : {:}'.format(w_scheduler)) | ||||||
|  |   logger.log('criterion   : {:}'.format(criterion)) | ||||||
|  |   params = count_parameters_in_MB(search_model) | ||||||
|  |   logger.log('The parameters of the search model = {:.2f} MB'.format(params)) | ||||||
|  |   logger.log('search-space : {:}'.format(search_space)) | ||||||
|  |   api = API() | ||||||
|  |   logger.log('{:} create API = {:} done'.format(time_string(), api)) | ||||||
|  |  | ||||||
|  |   last_info, model_base_path, model_best_path = logger.path('info'), logger.path('model'), logger.path('best') | ||||||
|  |   # network, criterion = torch.nn.DataParallel(search_model).cuda(), criterion.cuda() | ||||||
|  |   network, criterion = search_model.cuda(), criterion.cuda()  # use a single GPU | ||||||
|  |  | ||||||
|  |   last_info, model_base_path, model_best_path = logger.path('info'), logger.path('model'), logger.path('best') | ||||||
|  |  | ||||||
|  |   if last_info.exists(): # automatically resume from previous checkpoint | ||||||
|  |     logger.log("=> loading checkpoint of the last-info '{:}' start".format(last_info)) | ||||||
|  |     last_info   = torch.load(last_info) | ||||||
|  |     start_epoch = last_info['epoch'] | ||||||
|  |     checkpoint  = torch.load(last_info['last_checkpoint']) | ||||||
|  |     genotypes   = checkpoint['genotypes'] | ||||||
|  |     valid_accuracies = checkpoint['valid_accuracies'] | ||||||
|  |     search_model.load_state_dict( checkpoint['search_model'] ) | ||||||
|  |     w_scheduler.load_state_dict ( checkpoint['w_scheduler'] ) | ||||||
|  |     w_optimizer.load_state_dict ( checkpoint['w_optimizer'] ) | ||||||
|  |     a_optimizer.load_state_dict ( checkpoint['a_optimizer'] ) | ||||||
|  |     logger.log("=> loading checkpoint of the last-info '{:}' start with {:}-th epoch.".format(last_info, start_epoch)) | ||||||
|  |   else: | ||||||
|  |     logger.log("=> do not find the last-info file : {:}".format(last_info)) | ||||||
|  |     start_epoch, valid_accuracies, genotypes = 0, {'best': -1}, {} | ||||||
|  |  | ||||||
|  |   # start training | ||||||
|  |   start_time, search_time, epoch_time, total_epoch = time.time(), AverageMeter(), AverageMeter(), config.epochs + config.warmup | ||||||
|  |   for epoch in range(start_epoch, total_epoch): | ||||||
|  |     w_scheduler.update(epoch, 0.0) | ||||||
|  |     need_time = 'Time Left: {:}'.format(convert_secs2time(epoch_time.val * (total_epoch-epoch), True)) | ||||||
|  |     epoch_str = '{:03d}-{:03d}'.format(epoch, total_epoch) | ||||||
|  |     logger.log('\n[Search the {:}-th epoch] {:}, LR={:}'.format(epoch_str, need_time, min(w_scheduler.get_lr()))) | ||||||
|  |  | ||||||
|  |     import pdb; pdb.set_trace() | ||||||
|  |    | ||||||
|  |     search_w_loss, search_w_top1, search_w_top5, search_a_loss, search_a_top1, search_a_top5 \ | ||||||
|  |                 = search_func(search_loader, network, criterion, w_scheduler, w_optimizer, a_optimizer, epoch_str, xargs.print_freq, logger) | ||||||
|  |     search_time.update(time.time() - start_time) | ||||||
|  |     logger.log('[{:}] search [base] : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%, time-cost={:.1f} s'.format(epoch_str, search_w_loss, search_w_top1, search_w_top5, search_time.sum)) | ||||||
|  |     logger.log('[{:}] search [arch] : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'.format(epoch_str, search_a_loss, search_a_top1, search_a_top5)) | ||||||
|  |  | ||||||
|  |     genotype, temp_accuracy = get_best_arch(valid_loader, network, xargs.select_num) | ||||||
|  |     network.module.set_cal_mode('dynamic', genotype) | ||||||
|  |     valid_a_loss , valid_a_top1 , valid_a_top5  = valid_func(valid_loader, network, criterion) | ||||||
|  |     logger.log('[{:}] evaluate : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}% | {:}'.format(epoch_str, valid_a_loss, valid_a_top1, valid_a_top5, genotype)) | ||||||
|  |     #search_model.set_cal_mode('urs') | ||||||
|  |     #valid_a_loss , valid_a_top1 , valid_a_top5  = valid_func(valid_loader, network, criterion) | ||||||
|  |     #logger.log('[{:}] URS---evaluate : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'.format(epoch_str, valid_a_loss, valid_a_top1, valid_a_top5)) | ||||||
|  |     #search_model.set_cal_mode('joint') | ||||||
|  |     #valid_a_loss , valid_a_top1 , valid_a_top5  = valid_func(valid_loader, network, criterion) | ||||||
|  |     #logger.log('[{:}] JOINT-evaluate : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'.format(epoch_str, valid_a_loss, valid_a_top1, valid_a_top5)) | ||||||
|  |     #search_model.set_cal_mode('select') | ||||||
|  |     #valid_a_loss , valid_a_top1 , valid_a_top5  = valid_func(valid_loader, network, criterion) | ||||||
|  |     #logger.log('[{:}] Selec-evaluate : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'.format(epoch_str, valid_a_loss, valid_a_top1, valid_a_top5)) | ||||||
|  |     # check the best accuracy | ||||||
|  |     valid_accuracies[epoch] = valid_a_top1 | ||||||
|  |  | ||||||
|  |     genotypes[epoch] = genotype | ||||||
|  |     logger.log('<<<--->>> The {:}-th epoch : {:}'.format(epoch_str, genotypes[epoch])) | ||||||
|  |     # save checkpoint | ||||||
|  |     save_path = save_checkpoint({'epoch' : epoch + 1, | ||||||
|  |                 'args'  : deepcopy(xargs), | ||||||
|  |                 'search_model': search_model.state_dict(), | ||||||
|  |                 'w_optimizer' : w_optimizer.state_dict(), | ||||||
|  |                 'a_optimizer' : a_optimizer.state_dict(), | ||||||
|  |                 'w_scheduler' : w_scheduler.state_dict(), | ||||||
|  |                 'genotypes'   : genotypes, | ||||||
|  |                 'valid_accuracies' : valid_accuracies}, | ||||||
|  |                 model_base_path, logger) | ||||||
|  |     last_info = save_checkpoint({ | ||||||
|  |           'epoch': epoch + 1, | ||||||
|  |           'args' : deepcopy(args), | ||||||
|  |           'last_checkpoint': save_path, | ||||||
|  |           }, logger.path('info'), logger) | ||||||
|  |     with torch.no_grad(): | ||||||
|  |       logger.log('{:}'.format(search_model.show_alphas())) | ||||||
|  |     if api is not None: logger.log('{:}'.format(api.query_by_arch(genotypes[epoch], '200'))) | ||||||
|  |     # measure elapsed time | ||||||
|  |     epoch_time.update(time.time() - start_time) | ||||||
|  |     start_time = time.time() | ||||||
|  |  | ||||||
|  |   # the final post procedure : count the time | ||||||
|  |   start_time = time.time() | ||||||
|  |   genotype, temp_accuracy = get_best_arch(valid_loader, network, xargs.select_num) | ||||||
|  |   search_time.update(time.time() - start_time) | ||||||
|  |   network.module.set_cal_mode('dynamic', genotype) | ||||||
|  |   valid_a_loss , valid_a_top1 , valid_a_top5 = valid_func(valid_loader, network, criterion) | ||||||
|  |   logger.log('Last : the gentotype is : {:}, with the validation accuracy of {:.3f}%.'.format(genotype, valid_a_top1)) | ||||||
|  |  | ||||||
|  |   logger.log('\n' + '-'*100) | ||||||
|  |   # check the performance from the architecture dataset | ||||||
|  |   logger.log('SETN : run {:} epochs, cost {:.1f} s, last-geno is {:}.'.format(total_epoch, search_time.sum, genotype)) | ||||||
|  |   if api is not None: logger.log('{:}'.format(api.query_by_arch(genotype, '200') )) | ||||||
|  |   logger.close() | ||||||
|  |    | ||||||
|  |  | ||||||
|  |  | ||||||
|  | if __name__ == '__main__': | ||||||
|  |   parser = argparse.ArgumentParser("Weight sharing NAS methods to search for cells.") | ||||||
|  |   parser.add_argument('--data_path'   ,       type=str,   help='Path to dataset') | ||||||
|  |   parser.add_argument('--dataset'     ,       type=str,   choices=['cifar10', 'cifar100', 'ImageNet16-120'], help='Choose between Cifar10/100 and ImageNet-16.') | ||||||
|  |   parser.add_argument('--search_space',       type=str,   default='tss', choices=['tss'], help='The search space name.') | ||||||
|  |   parser.add_argument('--algo'        ,       type=str,   choices=['darts-v1', 'darts-v2', 'gdas', 'setn', 'random', 'enas'], help='The search space name.') | ||||||
|  |   # channels and number-of-cells | ||||||
|  |   parser.add_argument('--max_nodes'   ,       type=int,   default=4,  help='The maximum number of nodes.') | ||||||
|  |   parser.add_argument('--channel'     ,       type=int,   default=16, help='The number of channels.') | ||||||
|  |   parser.add_argument('--num_cells'   ,       type=int,   default=5,  help='The number of cells in one stage.') | ||||||
|  |   # | ||||||
|  |   parser.add_argument('--eval_candidate_num', type=int,   help='The number of selected architectures to evaluate.') | ||||||
|  |   # | ||||||
|  |   parser.add_argument('--track_running_stats',type=int,   default=0, choices=[0,1],help='Whether use track_running_stats or not in the BN layer.') | ||||||
|  |   parser.add_argument('--affine'      ,       type=int,   default=0, choices=[0,1],help='Whether use affine=True or False in the BN layer.') | ||||||
|  |   parser.add_argument('--config_path' ,       type=str,   default='./configs/nas-benchmark/algos/weight-sharing.config', help='The path of configuration.') | ||||||
|  |   # architecture leraning rate | ||||||
|  |   parser.add_argument('--arch_learning_rate', type=float, default=3e-4, help='learning rate for arch encoding') | ||||||
|  |   parser.add_argument('--arch_weight_decay',  type=float, default=1e-3, help='weight decay for arch encoding') | ||||||
|  |   # log | ||||||
|  |   parser.add_argument('--workers',            type=int,   default=2,    help='number of data loading workers (default: 2)') | ||||||
|  |   parser.add_argument('--save_dir',           type=str,   default='./output/search', help='Folder to save checkpoints and log.') | ||||||
|  |   parser.add_argument('--print_freq',         type=int,   help='print frequency (default: 200)') | ||||||
|  |   parser.add_argument('--rand_seed',          type=int,   help='manual seed') | ||||||
|  |   args = parser.parse_args() | ||||||
|  |   if args.rand_seed is None or args.rand_seed < 0: args.rand_seed = random.randint(1, 100000) | ||||||
|  |   args.save_dir = os.path.join('{:}-{:}'.format(args.save_dir, args.search_space), args.dataset, args.algo) | ||||||
|  |  | ||||||
|  |   main(args) | ||||||
| @@ -20,7 +20,7 @@ from .cell_searchs import CellStructure, CellArchitectures | |||||||
| def get_cell_based_tiny_net(config): | def get_cell_based_tiny_net(config): | ||||||
|   if isinstance(config, dict): config = dict2config(config, None) # to support the argument being a dict |   if isinstance(config, dict): config = dict2config(config, None) # to support the argument being a dict | ||||||
|   super_type = getattr(config, 'super_type', 'basic') |   super_type = getattr(config, 'super_type', 'basic') | ||||||
|   group_names = ['DARTS-V1', 'DARTS-V2', 'GDAS', 'SETN', 'ENAS', 'RANDOM'] |   group_names = ['DARTS-V1', 'DARTS-V2', 'GDAS', 'SETN', 'ENAS', 'RANDOM', 'generic'] | ||||||
|   if super_type == 'basic' and config.name in group_names: |   if super_type == 'basic' and config.name in group_names: | ||||||
|     from .cell_searchs import nas201_super_nets as nas_super_nets |     from .cell_searchs import nas201_super_nets as nas_super_nets | ||||||
|     try: |     try: | ||||||
|   | |||||||
| @@ -7,6 +7,7 @@ from .search_model_gdas     import TinyNetworkGDAS | |||||||
| from .search_model_setn     import TinyNetworkSETN | from .search_model_setn     import TinyNetworkSETN | ||||||
| from .search_model_enas     import TinyNetworkENAS | from .search_model_enas     import TinyNetworkENAS | ||||||
| from .search_model_random   import TinyNetworkRANDOM | from .search_model_random   import TinyNetworkRANDOM | ||||||
|  | from .generic_model         import GenericNAS201Model | ||||||
| from .genotypes             import Structure as CellStructure, architectures as CellArchitectures | from .genotypes             import Structure as CellStructure, architectures as CellArchitectures | ||||||
| # NASNet-based macro structure | # NASNet-based macro structure | ||||||
| from .search_model_gdas_nasnet import NASNetworkGDAS | from .search_model_gdas_nasnet import NASNetworkGDAS | ||||||
| @@ -18,7 +19,8 @@ nas201_super_nets = {'DARTS-V1': TinyNetworkDarts, | |||||||
|                      "GDAS": TinyNetworkGDAS, |                      "GDAS": TinyNetworkGDAS, | ||||||
|                      "SETN": TinyNetworkSETN, |                      "SETN": TinyNetworkSETN, | ||||||
|                      "ENAS": TinyNetworkENAS, |                      "ENAS": TinyNetworkENAS, | ||||||
|                      "RANDOM": TinyNetworkRANDOM} |                      "RANDOM": TinyNetworkRANDOM, | ||||||
|  |                      "generic": GenericNAS201Model} | ||||||
|  |  | ||||||
| nasnet_super_nets = {"GDAS": NASNetworkGDAS, | nasnet_super_nets = {"GDAS": NASNetworkGDAS, | ||||||
|                      "DARTS": NASNetworkDARTS} |                      "DARTS": NASNetworkDARTS} | ||||||
|   | |||||||
							
								
								
									
										200
									
								
								lib/models/cell_searchs/generic_model.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										200
									
								
								lib/models/cell_searchs/generic_model.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,200 @@ | |||||||
|  | ##################################################### | ||||||
|  | # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.07 # | ||||||
|  | ##################################################### | ||||||
|  | import torch, random | ||||||
|  | import torch.nn as nn | ||||||
|  | from copy import deepcopy | ||||||
|  | from typing import Text | ||||||
|  |  | ||||||
|  | from ..cell_operations import ResNetBasicblock | ||||||
|  | from .search_cells     import NAS201SearchCell as SearchCell | ||||||
|  | from .genotypes        import Structure | ||||||
|  | from .search_model_enas_utils import Controller | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class GenericNAS201Model(nn.Module): | ||||||
|  |  | ||||||
|  |   def __init__(self, C, N, max_nodes, num_classes, search_space, affine, track_running_stats): | ||||||
|  |     super(GenericNAS201Model, self).__init__() | ||||||
|  |     self._C          = C | ||||||
|  |     self._layerN     = N | ||||||
|  |     self._max_nodes  = max_nodes | ||||||
|  |     self._stem       = nn.Sequential( | ||||||
|  |                          nn.Conv2d(3, C, kernel_size=3, padding=1, bias=False), | ||||||
|  |                          nn.BatchNorm2d(C)) | ||||||
|  |     layer_channels   = [C    ] * N + [C*2 ] + [C*2  ] * N + [C*4 ] + [C*4  ] * N     | ||||||
|  |     layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N | ||||||
|  |     C_prev, num_edge, edge2index = C, None, None | ||||||
|  |     self._cells      = nn.ModuleList() | ||||||
|  |     for index, (C_curr, reduction) in enumerate(zip(layer_channels, layer_reductions)): | ||||||
|  |       if reduction: | ||||||
|  |         cell = ResNetBasicblock(C_prev, C_curr, 2) | ||||||
|  |       else: | ||||||
|  |         cell = SearchCell(C_prev, C_curr, 1, max_nodes, search_space, affine, track_running_stats) | ||||||
|  |         if num_edge is None: num_edge, edge2index = cell.num_edges, cell.edge2index | ||||||
|  |         else: assert num_edge == cell.num_edges and edge2index == cell.edge2index, 'invalid {:} vs. {:}.'.format(num_edge, cell.num_edges) | ||||||
|  |       self._cells.append(cell) | ||||||
|  |       C_prev = cell.out_dim | ||||||
|  |     self._op_names   = deepcopy(search_space) | ||||||
|  |     self._Layer      = len(self._cells) | ||||||
|  |     self.edge2index  = edge2index | ||||||
|  |     self.lastact     = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True)) | ||||||
|  |     self.global_pooling = nn.AdaptiveAvgPool2d(1) | ||||||
|  |     self.classifier  = nn.Linear(C_prev, num_classes) | ||||||
|  |     self._num_edge   = num_edge | ||||||
|  |     # algorithm related | ||||||
|  |     self.arch_parameters = nn.Parameter( 1e-3*torch.randn(num_edge, len(search_space)) ) | ||||||
|  |     self._mode        = None | ||||||
|  |     self.dynamic_cell = None | ||||||
|  |     self._tau         = None | ||||||
|  |     self._algo        = None | ||||||
|  |  | ||||||
|  |   def set_algo(self, algo: Text): | ||||||
|  |     # used for searching | ||||||
|  |     assert self._algo is None, 'This functioin can only be called once.' | ||||||
|  |     self._algo = algo | ||||||
|  |     if algo == 'enas': | ||||||
|  |       self.controller = Controller(len(self.edge2index), len(self._op_names)) | ||||||
|  |     else: | ||||||
|  |       self.arch_parameters = nn.Parameter( 1e-3*torch.randn(self._num_edge, len(self._op_names)) ) | ||||||
|  |       if algo == 'gdas': | ||||||
|  |         self._tau         = 10 | ||||||
|  |      | ||||||
|  |   def set_cal_mode(self, mode, dynamic_cell=None): | ||||||
|  |     assert mode in ['gdas', 'enas', 'urs', 'joint', 'select', 'dynamic'] | ||||||
|  |     self.mode = mode | ||||||
|  |     if mode == 'dynamic': self.dynamic_cell = deepcopy(dynamic_cell) | ||||||
|  |     else                : self.dynamic_cell = None | ||||||
|  |  | ||||||
|  |   @property | ||||||
|  |   def mode(self): | ||||||
|  |     return self._mode | ||||||
|  |  | ||||||
|  |   @property | ||||||
|  |   def weights(self): | ||||||
|  |     xlist = list(self._stem.parameters()) | ||||||
|  |     xlist+= list(self._cells.parameters()) | ||||||
|  |     xlist+= list(self.lastact.parameters()) | ||||||
|  |     xlist+= list(self.global_pooling.parameters()) | ||||||
|  |     xlist+= list(self.classifier.parameters()) | ||||||
|  |     return xlist | ||||||
|  |  | ||||||
|  |   def set_tau(self, tau): | ||||||
|  |     self._tau = tau | ||||||
|  |  | ||||||
|  |   @property | ||||||
|  |   def tau(self): | ||||||
|  |     return self._tau | ||||||
|  |  | ||||||
|  |   @property | ||||||
|  |   def alphas(self): | ||||||
|  |     if self._algo == 'enas': | ||||||
|  |       return list(self.controller.parameters()) | ||||||
|  |     else: | ||||||
|  |       return [self.arch_parameters] | ||||||
|  |  | ||||||
|  |   @property | ||||||
|  |   def message(self): | ||||||
|  |     string = self.extra_repr() | ||||||
|  |     for i, cell in enumerate(self._cells): | ||||||
|  |       string += '\n {:02d}/{:02d} :: {:}'.format(i, len(self._cells), cell.extra_repr()) | ||||||
|  |     return string | ||||||
|  |  | ||||||
|  |   def extra_repr(self): | ||||||
|  |     return ('{name}(C={_C}, Max-Nodes={_max_nodes}, N={_layerN}, L={_Layer}, alg={_algo})'.format(name=self.__class__.__name__, **self.__dict__)) | ||||||
|  |  | ||||||
|  |   @property | ||||||
|  |   def genotype(self): | ||||||
|  |     genotypes = [] | ||||||
|  |     for i in range(1, self._max_nodes): | ||||||
|  |       xlist = [] | ||||||
|  |       for j in range(i): | ||||||
|  |         node_str = '{:}<-{:}'.format(i, j) | ||||||
|  |         with torch.no_grad(): | ||||||
|  |           weights = self.arch_parameters[ self.edge2index[node_str] ] | ||||||
|  |           op_name = self.op_names[ weights.argmax().item() ] | ||||||
|  |         xlist.append((op_name, j)) | ||||||
|  |       genotypes.append(tuple(xlist)) | ||||||
|  |     return Structure(genotypes) | ||||||
|  |  | ||||||
|  |   def dync_genotype(self, use_random=False): | ||||||
|  |     genotypes = [] | ||||||
|  |     with torch.no_grad(): | ||||||
|  |       alphas_cpu = nn.functional.softmax(self.arch_parameters, dim=-1) | ||||||
|  |     for i in range(1, self._max_nodes): | ||||||
|  |       xlist = [] | ||||||
|  |       for j in range(i): | ||||||
|  |         node_str = '{:}<-{:}'.format(i, j) | ||||||
|  |         if use_random: | ||||||
|  |           op_name  = random.choice(self.op_names) | ||||||
|  |         else: | ||||||
|  |           weights  = alphas_cpu[ self.edge2index[node_str] ] | ||||||
|  |           op_index = torch.multinomial(weights, 1).item() | ||||||
|  |           op_name  = self.op_names[ op_index ] | ||||||
|  |         xlist.append((op_name, j)) | ||||||
|  |       genotypes.append(tuple(xlist)) | ||||||
|  |     return Structure(genotypes) | ||||||
|  |  | ||||||
|  |   def get_log_prob(self, arch): | ||||||
|  |     with torch.no_grad(): | ||||||
|  |       logits = nn.functional.log_softmax(self.arch_parameters, dim=-1) | ||||||
|  |     select_logits = [] | ||||||
|  |     for i, node_info in enumerate(arch.nodes): | ||||||
|  |       for op, xin in node_info: | ||||||
|  |         node_str = '{:}<-{:}'.format(i+1, xin) | ||||||
|  |         op_index = self.op_names.index(op) | ||||||
|  |         select_logits.append( logits[self.edge2index[node_str], op_index] ) | ||||||
|  |     return sum(select_logits).item() | ||||||
|  |  | ||||||
|  |   def return_topK(self, K): | ||||||
|  |     archs = Structure.gen_all(self.op_names, self._max_nodes, False) | ||||||
|  |     pairs = [(self.get_log_prob(arch), arch) for arch in archs] | ||||||
|  |     if K < 0 or K >= len(archs): K = len(archs) | ||||||
|  |     sorted_pairs = sorted(pairs, key=lambda x: -x[0]) | ||||||
|  |     return_pairs = [sorted_pairs[_][1] for _ in range(K)] | ||||||
|  |     return return_pairs | ||||||
|  |  | ||||||
|  |   def normalize_archp(self): | ||||||
|  |     if self.mode == 'gdas': | ||||||
|  |       while True: | ||||||
|  |         gumbels = -torch.empty_like(self.arch_parameters).exponential_().log() | ||||||
|  |         logits  = (self.arch_parameters.log_softmax(dim=1) + gumbels) / self.tau | ||||||
|  |         probs   = nn.functional.softmax(logits, dim=1) | ||||||
|  |         index   = probs.max(-1, keepdim=True)[1] | ||||||
|  |         one_h   = torch.zeros_like(logits).scatter_(-1, index, 1.0) | ||||||
|  |         hardwts = one_h - probs.detach() + probs | ||||||
|  |         if (torch.isinf(gumbels).any()) or (torch.isinf(probs).any()) or (torch.isnan(probs).any()): | ||||||
|  |           continue | ||||||
|  |         else: break | ||||||
|  |       with torch.no_grad(): | ||||||
|  |         hardwts_cpu = hardwts.detach().cpu() | ||||||
|  |       return hardwts, hardwts_cpu, index | ||||||
|  |     else: | ||||||
|  |       alphas  = nn.functional.softmax(self.arch_parameters, dim=-1) | ||||||
|  |       index   = alphas.max(-1, keepdim=True)[1] | ||||||
|  |       with torch.no_grad(): | ||||||
|  |         alphas_cpu = alphas.detach().cpu() | ||||||
|  |       return alphas, alphas_cpu, index | ||||||
|  |  | ||||||
|  |   def forward(self, inputs): | ||||||
|  |     alphas, alphas_cpu, index = self.normalize_archp() | ||||||
|  |     feature = self._stem(inputs) | ||||||
|  |     for i, cell in enumerate(self._cells): | ||||||
|  |       if isinstance(cell, SearchCell): | ||||||
|  |         if self.mode == 'urs': | ||||||
|  |           feature = cell.forward_urs(feature) | ||||||
|  |         elif self.mode == 'select': | ||||||
|  |           feature = cell.forward_select(feature, alphas_cpu) | ||||||
|  |         elif self.mode == 'joint': | ||||||
|  |           feature = cell.forward_joint(feature, alphas) | ||||||
|  |         elif self.mode == 'dynamic': | ||||||
|  |           feature = cell.forward_dynamic(feature, self.dynamic_cell) | ||||||
|  |         elif self.mode == 'gdas': | ||||||
|  |           feature = cell.forward_gdas(feature, alphas, index) | ||||||
|  |         else: raise ValueError('invalid mode={:}'.format(self.mode)) | ||||||
|  |       else: feature = cell(feature) | ||||||
|  |     out = self.lastact(feature) | ||||||
|  |     out = self.global_pooling(out) | ||||||
|  |     out = out.view(out.size(0), -1) | ||||||
|  |     logits = self.classifier(out) | ||||||
|  |     return out, logits | ||||||
| @@ -1,5 +1,5 @@ | |||||||
| from .evaluation_utils import obtain_accuracy | from .evaluation_utils import obtain_accuracy | ||||||
| from .gpu_manager      import GPUManager | from .gpu_manager      import GPUManager | ||||||
| from .flop_benchmark   import get_model_infos | from .flop_benchmark   import get_model_infos, count_parameters_in_MB | ||||||
| from .affine_utils     import normalize_points, denormalize_points | from .affine_utils     import normalize_points, denormalize_points | ||||||
| from .affine_utils     import identity2affine, solve2theta, affine2image | from .affine_utils     import identity2affine, solve2theta, affine2image | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user