update ENAS
This commit is contained in:
		
							
								
								
									
										17
									
								
								configs/nas-benchmark/algos/ENAS.config
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										17
									
								
								configs/nas-benchmark/algos/ENAS.config
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,17 @@ | ||||
| { | ||||
|   "scheduler": ["str",   "cos"], | ||||
|   "LR"       : ["float", "0.05"], | ||||
|   "eta_min"  : ["float", "0.0005"], | ||||
|   "epochs"   : ["int",   "310"], | ||||
|   "T_max"    : ["int",   "10"], | ||||
|   "warmup"   : ["int",   "0"], | ||||
|   "optim"    : ["str",   "SGD"], | ||||
|   "decay"    : ["float", "0.00025"], | ||||
|   "momentum" : ["float", "0.9"], | ||||
|   "nesterov" : ["bool",  "1"], | ||||
|   "controller_lr"   : ["float", "0.001"], | ||||
|   "controller_betas": ["float", [0, 0.999]], | ||||
|   "controller_eps"  : ["float", 0.001], | ||||
|   "criterion": ["str",   "Softmax"], | ||||
|   "batch_size": ["int",  "128"] | ||||
| } | ||||
							
								
								
									
										347
									
								
								exps/algos/ENAS.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										347
									
								
								exps/algos/ENAS.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,347 @@ | ||||
| ################################################## | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 # | ||||
| ################################################## | ||||
| import os, sys, time, glob, random, argparse | ||||
| import numpy as np | ||||
| from copy import deepcopy | ||||
| import torch | ||||
| import torch.nn as nn | ||||
| from pathlib import Path | ||||
| lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve() | ||||
| if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) | ||||
| from config_utils import load_config, dict2config, configure2str | ||||
| from datasets     import get_datasets, SearchDataset | ||||
| from procedures   import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint, get_optim_scheduler | ||||
| from utils        import get_model_infos, obtain_accuracy | ||||
| from log_utils    import AverageMeter, time_string, convert_secs2time | ||||
| from models       import get_cell_based_tiny_net, get_search_spaces | ||||
|  | ||||
|  | ||||
| def train_shared_cnn(xloader, shared_cnn, controller, criterion, scheduler, optimizer, epoch_str, print_freq, logger): | ||||
|   data_time, batch_time = AverageMeter(), AverageMeter() | ||||
|   losses, top1s, top5s, xend = AverageMeter(), AverageMeter(), AverageMeter(), time.time() | ||||
|    | ||||
|   shared_cnn.train() | ||||
|   controller.eval() | ||||
|  | ||||
|   for step, (inputs, targets) in enumerate(xloader): | ||||
|     scheduler.update(None, 1.0 * step / len(xloader)) | ||||
|     targets = targets.cuda(non_blocking=True) | ||||
|     # measure data loading time | ||||
|     data_time.update(time.time() - xend) | ||||
|      | ||||
|     with torch.no_grad(): | ||||
|       _, _, sampled_arch = controller() | ||||
|  | ||||
|     optimizer.zero_grad() | ||||
|     shared_cnn.module.update_arch(sampled_arch) | ||||
|     _, logits = shared_cnn(inputs) | ||||
|     loss      = criterion(logits, targets) | ||||
|     loss.backward() | ||||
|     torch.nn.utils.clip_grad_norm_(shared_cnn.parameters(), 5) | ||||
|     optimizer.step() | ||||
|     # record | ||||
|     prec1, prec5 = obtain_accuracy(logits.data, targets.data, topk=(1, 5)) | ||||
|     losses.update(loss.item(),  inputs.size(0)) | ||||
|     top1s.update (prec1.item(), inputs.size(0)) | ||||
|     top5s.update (prec5.item(), inputs.size(0)) | ||||
|  | ||||
|     # measure elapsed time | ||||
|     batch_time.update(time.time() - xend) | ||||
|     xend = time.time() | ||||
|  | ||||
|     if step % print_freq == 0 or step + 1 == len(xloader): | ||||
|       Sstr = '*Train-Shared-CNN* ' + time_string() + ' [{:}][{:03d}/{:03d}]'.format(epoch_str, step, len(xloader)) | ||||
|       Tstr = 'Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})'.format(batch_time=batch_time, data_time=data_time) | ||||
|       Wstr = '[Loss {loss.val:.3f} ({loss.avg:.3f})  Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})]'.format(loss=losses, top1=top1s, top5=top5s) | ||||
|       logger.log(Sstr + ' ' + Tstr + ' ' + Wstr) | ||||
|   return losses.avg, top1s.avg, top5s.avg | ||||
|  | ||||
|  | ||||
| def train_controller(xloader, shared_cnn, controller, criterion, optimizer, config, epoch_str, print_freq, logger): | ||||
|   # config. (containing some necessary arg) | ||||
|   #   baseline: The baseline score (i.e. average val_acc) from the previous epoch | ||||
|   data_time, batch_time = AverageMeter(), AverageMeter() | ||||
|   GradnormMeter, LossMeter, ValAccMeter, BaselineMeter, RewardMeter, xend = AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter(), time.time() | ||||
|    | ||||
|   shared_cnn.eval() | ||||
|   controller.train() | ||||
|   controller.zero_grad() | ||||
|   #for step, (inputs, targets) in enumerate(xloader): | ||||
|   loader_iter = iter(xloader) | ||||
|   for step in range(config.ctl_train_steps * config.ctl_num_aggre): | ||||
|     try: | ||||
|       inputs, targets = next(loader_iter) | ||||
|     except: | ||||
|       loader_iter = iter(xloader) | ||||
|       inputs, targets = next(loader_iter) | ||||
|     targets = targets.cuda(non_blocking=True) | ||||
|     # measure data loading time | ||||
|     data_time.update(time.time() - xend) | ||||
|      | ||||
|     log_prob, entropy, sampled_arch = controller() | ||||
|     with torch.no_grad(): | ||||
|       shared_cnn.module.update_arch(sampled_arch) | ||||
|       _, logits = shared_cnn(inputs) | ||||
|       val_top1, val_top5 = obtain_accuracy(logits.data, targets.data, topk=(1, 5)) | ||||
|       val_top1  = val_top1.view(-1) / 100 | ||||
|     reward = val_top1 + config.ctl_entropy_w * entropy | ||||
|     if config.baseline is None: | ||||
|       baseline = val_top1 | ||||
|     else: | ||||
|       baseline = config.baseline - (1 - config.ctl_bl_dec) * (config.baseline - reward) | ||||
|     | ||||
|     loss = -1 * log_prob * (reward - baseline) | ||||
|      | ||||
|     # account | ||||
|     RewardMeter.update(reward.item()) | ||||
|     BaselineMeter.update(baseline.item()) | ||||
|     ValAccMeter.update(val_top1.item()) | ||||
|     LossMeter.update(loss.item()) | ||||
|    | ||||
|     # Average gradient over controller_num_aggregate samples | ||||
|     loss = loss / config.ctl_num_aggre | ||||
|     loss.backward(retain_graph=True) | ||||
|  | ||||
|     # measure elapsed time | ||||
|     batch_time.update(time.time() - xend) | ||||
|     xend = time.time() | ||||
|     if (step+1) % config.ctl_num_aggre == 0: | ||||
|       grad_norm = torch.nn.utils.clip_grad_norm_(controller.parameters(), 5.0) | ||||
|       GradnormMeter.update(grad_norm) | ||||
|       optimizer.step() | ||||
|       controller.zero_grad() | ||||
|  | ||||
|     if step % print_freq == 0: | ||||
|       Sstr = '*Train-Controller* ' + time_string() + ' [{:}][{:03d}/{:03d}]'.format(epoch_str, step, config.ctl_train_steps * config.ctl_num_aggre) | ||||
|       Tstr = 'Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})'.format(batch_time=batch_time, data_time=data_time) | ||||
|       Wstr = '[Loss {loss.val:.3f} ({loss.avg:.3f})  Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Reward {reward.val:.2f} ({reward.avg:.2f})] Baseline {basel.val:.2f} ({basel.avg:.2f})'.format(loss=LossMeter, top1=ValAccMeter, reward=RewardMeter, basel=BaselineMeter) | ||||
|       logger.log(Sstr + ' ' + Tstr + ' ' + Wstr) | ||||
|  | ||||
|   return LossMeter.avg, ValAccMeter.avg, BaselineMeter.avg, RewardMeter.avg, baseline.item() | ||||
|  | ||||
|  | ||||
| def get_best_arch(controller, shared_cnn, xloader, n_samples=10): | ||||
|   with torch.no_grad(): | ||||
|     controller.eval() | ||||
|     shared_cnn.eval() | ||||
|     archs, valid_accs = [], [] | ||||
|     loader_iter = iter(xloader) | ||||
|     for i in range(n_samples): | ||||
|       try: | ||||
|         inputs, targets = next(loader_iter) | ||||
|       except: | ||||
|         loader_iter = iter(xloader) | ||||
|         inputs, targets = next(loader_iter) | ||||
|  | ||||
|       _, _, sampled_arch = controller() | ||||
|       arch = shared_cnn.module.update_arch(sampled_arch) | ||||
|       _, logits = shared_cnn(inputs) | ||||
|       val_top1, val_top5 = obtain_accuracy(logits.cpu().data, targets.data, topk=(1, 5)) | ||||
|  | ||||
|       archs.append( arch ) | ||||
|       valid_accs.append( val_top1.item() ) | ||||
|  | ||||
|     best_idx = np.argmax(valid_accs) | ||||
|     best_arch, best_valid_acc = archs[best_idx], valid_accs[best_idx] | ||||
|     return best_arch, best_valid_acc | ||||
|  | ||||
|  | ||||
| def valid_func(xloader, network, criterion): | ||||
|   data_time, batch_time = AverageMeter(), AverageMeter() | ||||
|   arch_losses, arch_top1, arch_top5 = AverageMeter(), AverageMeter(), AverageMeter() | ||||
|   network.eval() | ||||
|   end = time.time() | ||||
|   with torch.no_grad(): | ||||
|     for step, (arch_inputs, arch_targets) in enumerate(xloader): | ||||
|       arch_targets = arch_targets.cuda(non_blocking=True) | ||||
|       # measure data loading time | ||||
|       data_time.update(time.time() - end) | ||||
|       # prediction | ||||
|       _, logits = network(arch_inputs) | ||||
|       arch_loss = criterion(logits, arch_targets) | ||||
|       # record | ||||
|       arch_prec1, arch_prec5 = obtain_accuracy(logits.data, arch_targets.data, topk=(1, 5)) | ||||
|       arch_losses.update(arch_loss.item(),  arch_inputs.size(0)) | ||||
|       arch_top1.update  (arch_prec1.item(), arch_inputs.size(0)) | ||||
|       arch_top5.update  (arch_prec5.item(), arch_inputs.size(0)) | ||||
|       # measure elapsed time | ||||
|       batch_time.update(time.time() - end) | ||||
|       end = time.time() | ||||
|   return arch_losses.avg, arch_top1.avg, arch_top5.avg | ||||
|  | ||||
|  | ||||
| def main(xargs): | ||||
|   assert torch.cuda.is_available(), 'CUDA is not available.' | ||||
|   torch.backends.cudnn.enabled   = True | ||||
|   torch.backends.cudnn.benchmark = False | ||||
|   torch.backends.cudnn.deterministic = True | ||||
|   torch.set_num_threads( xargs.workers ) | ||||
|   prepare_seed(xargs.rand_seed) | ||||
|   logger = prepare_logger(args) | ||||
|  | ||||
|   train_data, test_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1) | ||||
|   if xargs.dataset == 'cifar10' or xargs.dataset == 'cifar100': | ||||
|     split_Fpath = 'configs/nas-benchmark/cifar-split.txt' | ||||
|     cifar_split = load_config(split_Fpath, None, None) | ||||
|     train_split, valid_split = cifar_split.train, cifar_split.valid | ||||
|     logger.log('Load split file from {:}'.format(split_Fpath)) | ||||
|   elif xargs.dataset.startswith('ImageNet16'): | ||||
|     split_Fpath = 'configs/nas-benchmark/{:}-split.txt'.format(xargs.dataset) | ||||
|     imagenet16_split = load_config(split_Fpath, None, None) | ||||
|     train_split, valid_split = imagenet16_split.train, imagenet16_split.valid | ||||
|     logger.log('Load split file from {:}'.format(split_Fpath)) | ||||
|   else: | ||||
|     raise ValueError('invalid dataset : {:}'.format(xargs.dataset)) | ||||
|   logger.log('use config from : {:}'.format(xargs.config_path)) | ||||
|   config = load_config(xargs.config_path, {'class_num': class_num, 'xshape': xshape}, logger) | ||||
|   logger.log('config: {:}'.format(config)) | ||||
|   # To split data | ||||
|   train_data_v2 = deepcopy(train_data) | ||||
|   train_data_v2.transform = test_data.transform | ||||
|   valid_data    = train_data_v2 | ||||
|   # data loader | ||||
|   train_loader  = torch.utils.data.DataLoader(train_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(train_split), num_workers=xargs.workers, pin_memory=True) | ||||
|   valid_loader  = torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(valid_split), num_workers=xargs.workers, pin_memory=True) | ||||
|   logger.log('||||||| {:10s} ||||||| Train-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'.format(xargs.dataset, len(train_loader), len(valid_loader), config.batch_size)) | ||||
|   logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config)) | ||||
|  | ||||
|   search_space = get_search_spaces('cell', xargs.search_space_name) | ||||
|   model_config = dict2config({'name': 'ENAS', 'C': xargs.channel, 'N': xargs.num_cells, | ||||
|                               'max_nodes': xargs.max_nodes, 'num_classes': class_num, | ||||
|                               'space'    : search_space}, None) | ||||
|   shared_cnn = get_cell_based_tiny_net(model_config) | ||||
|   controller = shared_cnn.create_controller() | ||||
|    | ||||
|   w_optimizer, w_scheduler, criterion = get_optim_scheduler(shared_cnn.parameters(), config) | ||||
|   a_optimizer = torch.optim.Adam(controller.parameters(), lr=config.controller_lr, betas=config.controller_betas, eps=config.controller_eps) | ||||
|   logger.log('w-optimizer : {:}'.format(w_optimizer)) | ||||
|   logger.log('a-optimizer : {:}'.format(a_optimizer)) | ||||
|   logger.log('w-scheduler : {:}'.format(w_scheduler)) | ||||
|   logger.log('criterion   : {:}'.format(criterion)) | ||||
|   #flop, param  = get_model_infos(shared_cnn, xshape) | ||||
|   #logger.log('{:}'.format(shared_cnn)) | ||||
|   #logger.log('FLOP = {:.2f} M, Params = {:.2f} MB'.format(flop, param)) | ||||
|   shared_cnn, controller, criterion = torch.nn.DataParallel(shared_cnn).cuda(), controller.cuda(), criterion.cuda() | ||||
|  | ||||
|   last_info, model_base_path, model_best_path = logger.path('info'), logger.path('model'), logger.path('best') | ||||
|  | ||||
|   if last_info.exists(): # automatically resume from previous checkpoint | ||||
|     logger.log("=> loading checkpoint of the last-info '{:}' start".format(last_info)) | ||||
|     last_info   = torch.load(last_info) | ||||
|     start_epoch = last_info['epoch'] | ||||
|     checkpoint  = torch.load(last_info['last_checkpoint']) | ||||
|     genotypes   = checkpoint['genotypes'] | ||||
|     baseline    = checkpoint['baseline'] | ||||
|     valid_accuracies = checkpoint['valid_accuracies'] | ||||
|     shared_cnn.load_state_dict( checkpoint['shared_cnn'] ) | ||||
|     controller.load_state_dict( checkpoint['controller'] ) | ||||
|     w_scheduler.load_state_dict ( checkpoint['w_scheduler'] ) | ||||
|     w_optimizer.load_state_dict ( checkpoint['w_optimizer'] ) | ||||
|     a_optimizer.load_state_dict ( checkpoint['a_optimizer'] ) | ||||
|     logger.log("=> loading checkpoint of the last-info '{:}' start with {:}-th epoch.".format(last_info, start_epoch)) | ||||
|   else: | ||||
|     logger.log("=> do not find the last-info file : {:}".format(last_info)) | ||||
|     start_epoch, valid_accuracies, genotypes, baseline = 0, {'best': -1}, {}, None | ||||
|  | ||||
|   # start training | ||||
|   start_time, epoch_time, total_epoch = time.time(), AverageMeter(), config.epochs + config.warmup | ||||
|   for epoch in range(start_epoch, total_epoch): | ||||
|     w_scheduler.update(epoch, 0.0) | ||||
|     need_time = 'Time Left: {:}'.format( convert_secs2time(epoch_time.val * (total_epoch-epoch), True) ) | ||||
|     epoch_str = '{:03d}-{:03d}'.format(epoch, total_epoch) | ||||
|     logger.log('\n[Search the {:}-th epoch] {:}, LR={:}'.format(epoch_str, need_time, min(w_scheduler.get_lr()))) | ||||
|  | ||||
|     cnn_loss, cnn_top1, cnn_top5 = train_shared_cnn(train_loader, shared_cnn, controller, criterion, w_scheduler, w_optimizer, epoch_str, xargs.print_freq, logger) | ||||
|     logger.log('[{:}] shared-cnn : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'.format(epoch_str, cnn_loss, cnn_top1, cnn_top5)) | ||||
|     ctl_loss, ctl_acc, ctl_baseline, ctl_reward, baseline \ | ||||
|                                  = train_controller(valid_loader, shared_cnn, controller, criterion, a_optimizer, \ | ||||
|                                                         dict2config({'baseline': baseline, | ||||
|                                                                      'ctl_train_steps': xargs.controller_train_steps, 'ctl_num_aggre': xargs.controller_num_aggregate, | ||||
|                                                                      'ctl_entropy_w': xargs.controller_entropy_weight,  | ||||
|                                                                      'ctl_bl_dec'   : xargs.controller_bl_dec}, None), \ | ||||
|                                                         epoch_str, xargs.print_freq, logger) | ||||
|     logger.log('[{:}] controller : loss={:.2f}, accuracy={:.2f}%, baseline={:.2f}, reward={:.2f}, current-baseline={:.4f}'.format(epoch_str, ctl_loss, ctl_acc, ctl_baseline, ctl_reward, baseline)) | ||||
|     best_arch, _ = get_best_arch(controller, shared_cnn, valid_loader) | ||||
|     shared_cnn.module.update_arch(best_arch) | ||||
|     best_valid_acc = valid_func(valid_loader, shared_cnn, criterion) | ||||
|  | ||||
|     genotypes[epoch] = best_arch | ||||
|     # check the best accuracy | ||||
|     valid_accuracies[epoch] = best_valid_acc | ||||
|     if best_valid_acc > valid_accuracies['best']: | ||||
|       valid_accuracies['best'] = best_valid_acc | ||||
|       genotypes['best']        = best_arch | ||||
|       find_best = True | ||||
|     else: find_best = False | ||||
|  | ||||
|     logger.log('<<<--->>> The {:}-th epoch : {:}'.format(epoch_str, genotypes[epoch])) | ||||
|     # save checkpoint | ||||
|     save_path = save_checkpoint({'epoch' : epoch + 1, | ||||
|                 'args'  : deepcopy(xargs), | ||||
|                 'baseline'    : baseline, | ||||
|                 'shared_cnn'  : shared_cnn.state_dict(), | ||||
|                 'controller'  : controller.state_dict(), | ||||
|                 'w_optimizer' : w_optimizer.state_dict(), | ||||
|                 'a_optimizer' : a_optimizer.state_dict(), | ||||
|                 'w_scheduler' : w_scheduler.state_dict(), | ||||
|                 'genotypes'   : genotypes, | ||||
|                 'valid_accuracies' : valid_accuracies}, | ||||
|                 model_base_path, logger) | ||||
|     last_info = save_checkpoint({ | ||||
|           'epoch': epoch + 1, | ||||
|           'args' : deepcopy(args), | ||||
|           'last_checkpoint': save_path, | ||||
|           }, logger.path('info'), logger) | ||||
|     if find_best: | ||||
|       logger.log('<<<--->>> The {:}-th epoch : find the highest validation accuracy : {:.2f}%.'.format(epoch_str, best_valid_acc)) | ||||
|       copy_checkpoint(model_base_path, model_best_path, logger) | ||||
|     # measure elapsed time | ||||
|     epoch_time.update(time.time() - start_time) | ||||
|     start_time = time.time() | ||||
|  | ||||
|   logger.log('\n' + '-'*100) | ||||
|   # check the performance from the architecture dataset | ||||
|   #if xargs.arch_nas_dataset is None or not os.path.isfile(xargs.arch_nas_dataset): | ||||
|   #  logger.log('Can not find the architecture dataset : {:}.'.format(xargs.arch_nas_dataset)) | ||||
|   #else: | ||||
|   #  nas_bench = NASBenchmarkAPI(xargs.arch_nas_dataset) | ||||
|   #  geno = genotypes[total_epoch-1] | ||||
|   #  logger.log('The last model is {:}'.format(geno)) | ||||
|   #  info = nas_bench.query_by_arch( geno ) | ||||
|   #  if info is None: logger.log('Did not find this architecture : {:}.'.format(geno)) | ||||
|   #  else           : logger.log('{:}'.format(info)) | ||||
|   #  logger.log('-'*100) | ||||
|   #  geno = genotypes['best'] | ||||
|   #  logger.log('The best model is {:}'.format(geno)) | ||||
|   #  info = nas_bench.query_by_arch( geno ) | ||||
|   #  if info is None: logger.log('Did not find this architecture : {:}.'.format(geno)) | ||||
|   #  else           : logger.log('{:}'.format(info)) | ||||
|   logger.close() | ||||
|    | ||||
|  | ||||
|  | ||||
| if __name__ == '__main__': | ||||
|   parser = argparse.ArgumentParser("ENAS") | ||||
|   parser.add_argument('--data_path',          type=str,   help='Path to dataset') | ||||
|   parser.add_argument('--dataset',            type=str,   choices=['cifar10', 'cifar100', 'ImageNet16-120'], help='Choose between Cifar10/100 and ImageNet-16.') | ||||
|   # channels and number-of-cells | ||||
|   parser.add_argument('--search_space_name',  type=str,   help='The search space name.') | ||||
|   parser.add_argument('--max_nodes',          type=int,   help='The maximum number of nodes.') | ||||
|   parser.add_argument('--channel',            type=int,   help='The number of channels.') | ||||
|   parser.add_argument('--num_cells',          type=int,   help='The number of cells in one stage.') | ||||
|   parser.add_argument('--config_path',        type=str,   help='The config file to train ENAS.') | ||||
|   parser.add_argument('--controller_train_steps',    type=int,     help='.') | ||||
|   parser.add_argument('--controller_num_aggregate',  type=int,     help='.') | ||||
|   parser.add_argument('--controller_entropy_weight', type=float,   help='The weight for the entropy of the controller.') | ||||
|   parser.add_argument('--controller_bl_dec'        , type=float,   help='.') | ||||
|   parser.add_argument('--controller_num_samples'   , type=int,     help='.') | ||||
|   # log | ||||
|   parser.add_argument('--workers',            type=int,   default=2,    help='number of data loading workers (default: 2)') | ||||
|   parser.add_argument('--save_dir',           type=str,   help='Folder to save checkpoints and log.') | ||||
|   parser.add_argument('--arch_nas_dataset',   type=str,   help='The path to load the architecture dataset (nas-benchmark).') | ||||
|   parser.add_argument('--print_freq',         type=int,   help='print frequency (default: 200)') | ||||
|   parser.add_argument('--rand_seed',          type=int,   help='manual seed') | ||||
|   args = parser.parse_args() | ||||
|   if args.rand_seed is None or args.rand_seed < 0: args.rand_seed = random.randint(1, 100000) | ||||
|   main(args) | ||||
| @@ -16,18 +16,10 @@ from .cell_searchs import CellStructure, CellArchitectures | ||||
|  | ||||
| # Cell-based NAS Models | ||||
| def get_cell_based_tiny_net(config): | ||||
|   if config.name == 'DARTS-V1': | ||||
|     from .cell_searchs import TinyNetworkDartsV1 | ||||
|     return TinyNetworkDartsV1(config.C, config.N, config.max_nodes, config.num_classes, config.space) | ||||
|   elif config.name == 'DARTS-V2': | ||||
|     from .cell_searchs import TinyNetworkDartsV2 | ||||
|     return TinyNetworkDartsV2(config.C, config.N, config.max_nodes, config.num_classes, config.space) | ||||
|   elif config.name == 'GDAS': | ||||
|     from .cell_searchs import TinyNetworkGDAS | ||||
|     return TinyNetworkGDAS(config.C, config.N, config.max_nodes, config.num_classes, config.space) | ||||
|   elif config.name == 'SETN': | ||||
|     from .cell_searchs import TinyNetworkSETN | ||||
|     return TinyNetworkSETN(config.C, config.N, config.max_nodes, config.num_classes, config.space) | ||||
|   group_names = ['DARTS-V1', 'DARTS-V2', 'GDAS', 'SETN', 'ENAS'] | ||||
|   from .cell_searchs import nas_super_nets | ||||
|   if config.name in group_names: | ||||
|     return nas_super_nets[config.name](config.C, config.N, config.max_nodes, config.num_classes, config.space) | ||||
|   elif config.name == 'infer.tiny': | ||||
|     from .cell_infers import TinyNetwork | ||||
|     return TinyNetwork(config.C, config.N, config.genotype, config.num_classes) | ||||
|   | ||||
| @@ -2,4 +2,11 @@ from .search_model_darts_v1 import TinyNetworkDartsV1 | ||||
| from .search_model_darts_v2 import TinyNetworkDartsV2 | ||||
| from .search_model_gdas     import TinyNetworkGDAS | ||||
| from .search_model_setn     import TinyNetworkSETN | ||||
| from .search_model_enas     import TinyNetworkENAS | ||||
| from .genotypes             import Structure as CellStructure, architectures as CellArchitectures | ||||
|  | ||||
| nas_super_nets = {'DARTS-V1': TinyNetworkDartsV1, | ||||
|                   'DARTS-V2': TinyNetworkDartsV2, | ||||
|                   'GDAS'    : TinyNetworkGDAS, | ||||
|                   'SETN'    : TinyNetworkSETN, | ||||
|                   'ENAS'    : TinyNetworkENAS} | ||||
|   | ||||
							
								
								
									
										9
									
								
								lib/models/cell_searchs/_test_module.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										9
									
								
								lib/models/cell_searchs/_test_module.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,9 @@ | ||||
| import torch | ||||
| from search_model_enas_utils import Controller | ||||
|  | ||||
| def main(): | ||||
|   controller = Controller(6, 4) | ||||
|   predictions = controller() | ||||
|  | ||||
| if __name__ == '__main__': | ||||
|   main() | ||||
							
								
								
									
										94
									
								
								lib/models/cell_searchs/search_model_enas.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										94
									
								
								lib/models/cell_searchs/search_model_enas.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,94 @@ | ||||
| ################################################## | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 # | ||||
| ########################################################################## | ||||
| # Efficient Neural Architecture Search via Parameters Sharing, ICML 2018 # | ||||
| ########################################################################## | ||||
| import torch | ||||
| import torch.nn as nn | ||||
| from copy import deepcopy | ||||
| from ..cell_operations import ResNetBasicblock | ||||
| from .search_cells     import SearchCell | ||||
| from .genotypes        import Structure | ||||
| from .search_model_enas_utils import Controller | ||||
|  | ||||
|  | ||||
| class TinyNetworkENAS(nn.Module): | ||||
|  | ||||
|   def __init__(self, C, N, max_nodes, num_classes, search_space): | ||||
|     super(TinyNetworkENAS, self).__init__() | ||||
|     self._C        = C | ||||
|     self._layerN   = N | ||||
|     self.max_nodes = max_nodes | ||||
|     self.stem = nn.Sequential( | ||||
|                     nn.Conv2d(3, C, kernel_size=3, padding=1, bias=False), | ||||
|                     nn.BatchNorm2d(C)) | ||||
|    | ||||
|     layer_channels   = [C    ] * N + [C*2 ] + [C*2  ] * N + [C*4 ] + [C*4  ] * N     | ||||
|     layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N | ||||
|  | ||||
|     C_prev, num_edge, edge2index = C, None, None | ||||
|     self.cells = nn.ModuleList() | ||||
|     for index, (C_curr, reduction) in enumerate(zip(layer_channels, layer_reductions)): | ||||
|       if reduction: | ||||
|         cell = ResNetBasicblock(C_prev, C_curr, 2) | ||||
|       else: | ||||
|         cell = SearchCell(C_prev, C_curr, 1, max_nodes, search_space) | ||||
|         if num_edge is None: num_edge, edge2index = cell.num_edges, cell.edge2index | ||||
|         else: assert num_edge == cell.num_edges and edge2index == cell.edge2index, 'invalid {:} vs. {:}.'.format(num_edge, cell.num_edges) | ||||
|       self.cells.append( cell ) | ||||
|       C_prev = cell.out_dim | ||||
|     self.op_names   = deepcopy( search_space ) | ||||
|     self._Layer     = len(self.cells) | ||||
|     self.edge2index = edge2index | ||||
|     self.lastact    = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True)) | ||||
|     self.global_pooling = nn.AdaptiveAvgPool2d(1) | ||||
|     self.classifier = nn.Linear(C_prev, num_classes) | ||||
|     # to maintain the sampled architecture | ||||
|     self.sampled_arch = None | ||||
|  | ||||
|   def update_arch(self, _arch): | ||||
|     if _arch is None: | ||||
|       self.sampled_arch = None | ||||
|     elif isinstance(_arch, Structure): | ||||
|       self.sampled_arch = _arch | ||||
|     elif isinstance(_arch, (list, tuple)): | ||||
|       genotypes = [] | ||||
|       for i in range(1, self.max_nodes): | ||||
|         xlist = [] | ||||
|         for j in range(i): | ||||
|           node_str = '{:}<-{:}'.format(i, j) | ||||
|           op_index = _arch[ self.edge2index[node_str] ] | ||||
|           op_name  = self.op_names[ op_index ] | ||||
|           xlist.append((op_name, j)) | ||||
|         genotypes.append( tuple(xlist) ) | ||||
|       self.sampled_arch = Structure(genotypes) | ||||
|     else: | ||||
|       raise ValueError('invalid type of input architecture : {:}'.format(_arch)) | ||||
|     return self.sampled_arch | ||||
|      | ||||
|   def create_controller(self): | ||||
|     return Controller(len(self.edge2index), len(self.op_names)) | ||||
|  | ||||
|   def get_message(self): | ||||
|     string = self.extra_repr() | ||||
|     for i, cell in enumerate(self.cells): | ||||
|       string += '\n {:02d}/{:02d} :: {:}'.format(i, len(self.cells), cell.extra_repr()) | ||||
|     return string | ||||
|  | ||||
|   def extra_repr(self): | ||||
|     return ('{name}(C={_C}, Max-Nodes={max_nodes}, N={_layerN}, L={_Layer})'.format(name=self.__class__.__name__, **self.__dict__)) | ||||
|  | ||||
|   def forward(self, inputs): | ||||
|  | ||||
|     feature = self.stem(inputs) | ||||
|     for i, cell in enumerate(self.cells): | ||||
|       if isinstance(cell, SearchCell): | ||||
|         feature = cell.forward_dynamic(feature, self.sampled_arch) | ||||
|       else: feature = cell(feature) | ||||
|  | ||||
|     out = self.lastact(feature) | ||||
|     out = self.global_pooling( out ) | ||||
|     out = out.view(out.size(0), -1) | ||||
|     logits = self.classifier(out) | ||||
|  | ||||
|     return out, logits | ||||
							
								
								
									
										55
									
								
								lib/models/cell_searchs/search_model_enas_utils.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										55
									
								
								lib/models/cell_searchs/search_model_enas_utils.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,55 @@ | ||||
| ################################################## | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 # | ||||
| ########################################################################## | ||||
| # Efficient Neural Architecture Search via Parameters Sharing, ICML 2018 # | ||||
| ########################################################################## | ||||
| import torch | ||||
| import torch.nn as nn | ||||
| from torch.distributions.categorical import Categorical | ||||
|  | ||||
| class Controller(nn.Module): | ||||
|   # we refer to https://github.com/TDeVries/enas_pytorch/blob/master/models/controller.py | ||||
|   def __init__(self, num_edge, num_ops, lstm_size=32, lstm_num_layers=2, tanh_constant=2.5, temperature=5.0): | ||||
|     super(Controller, self).__init__() | ||||
|     # assign the attributes | ||||
|     self.num_edge  = num_edge | ||||
|     self.num_ops   = num_ops | ||||
|     self.lstm_size = lstm_size | ||||
|     self.lstm_N    = lstm_num_layers | ||||
|     self.tanh_constant = tanh_constant | ||||
|     self.temperature   = temperature | ||||
|     # create parameters | ||||
|     self.register_parameter('input_vars', nn.Parameter(torch.Tensor(1, 1, lstm_size))) | ||||
|     self.w_lstm = nn.LSTM(input_size=self.lstm_size, hidden_size=self.lstm_size, num_layers=self.lstm_N) | ||||
|     self.w_embd = nn.Embedding(self.num_ops, self.lstm_size) | ||||
|     self.w_pred = nn.Linear(self.lstm_size, self.num_ops) | ||||
|  | ||||
|     nn.init.uniform_(self.input_vars         , -0.1, 0.1) | ||||
|     nn.init.uniform_(self.w_lstm.weight_hh_l0, -0.1, 0.1) | ||||
|     nn.init.uniform_(self.w_lstm.weight_ih_l0, -0.1, 0.1) | ||||
|     nn.init.uniform_(self.w_embd.weight      , -0.1, 0.1) | ||||
|     nn.init.uniform_(self.w_pred.weight      , -0.1, 0.1) | ||||
|  | ||||
|   def forward(self): | ||||
|  | ||||
|     inputs, h0 = self.input_vars, None | ||||
|     log_probs, entropys, sampled_arch = [], [], [] | ||||
|     for iedge in range(self.num_edge): | ||||
|       outputs, h0 = self.w_lstm(inputs, h0) | ||||
|        | ||||
|       logits = self.w_pred(outputs) | ||||
|       logits = logits / self.temperature | ||||
|       logits = self.tanh_constant * torch.tanh(logits) | ||||
|       # distribution | ||||
|       op_distribution = Categorical(logits=logits) | ||||
|       op_index    = op_distribution.sample() | ||||
|       sampled_arch.append( op_index.item() ) | ||||
|  | ||||
|       op_log_prob = op_distribution.log_prob(op_index) | ||||
|       log_probs.append( op_log_prob.view(-1) ) | ||||
|       op_entropy  = op_distribution.entropy() | ||||
|       entropys.append( op_entropy.view(-1) ) | ||||
|        | ||||
|       # obtain the input embedding for the next step | ||||
|       inputs = self.w_embd(op_index) | ||||
|     return torch.sum(torch.cat(log_probs)), torch.sum(torch.cat(entropys)), sampled_arch | ||||
		Reference in New Issue
	
	Block a user