update ENAS
configs/nas-benchmark/algos/ENAS.config (new file, 17 lines)
@@ -0,0 +1,17 @@
{
  "scheduler": ["str",   "cos"],
  "LR"       : ["float", "0.05"],
  "eta_min"  : ["float", "0.0005"],
  "epochs"   : ["int",   "310"],
  "T_max"    : ["int",   "10"],
  "warmup"   : ["int",   "0"],
  "optim"    : ["str",   "SGD"],
  "decay"    : ["float", "0.00025"],
  "momentum" : ["float", "0.9"],
  "nesterov" : ["bool",  "1"],
  "controller_lr"   : ["float", "0.001"],
  "controller_betas": ["float", [0, 0.999]],
  "controller_eps"  : ["float", 0.001],
  "criterion": ["str",   "Softmax"],
  "batch_size": ["int",  "128"]
}
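Each field in this config pairs a type tag with a value; the experiment scripts read it through load_config from lib/config_utils, which resolves these pairs into typed attributes. A minimal sketch of that convention, assuming a plain JSON load (illustrative only; the real parser in config_utils may handle nested lists such as controller_betas differently):

import json

def parse_typed_config(path):
  # Every entry is stored as ["type", value]; cast it to a native Python value.
  casts = {'int': int, 'float': float, 'str': str, 'bool': lambda v: bool(int(v))}
  with open(path) as f:
    raw = json.load(f)
  parsed = {}
  for key, (dtype, value) in raw.items():
    if isinstance(value, list):  # e.g. "controller_betas": a list of floats
      parsed[key] = [casts[dtype](v) for v in value]
    else:
      parsed[key] = casts[dtype](value)
  return parsed

# e.g. parse_typed_config('configs/nas-benchmark/algos/ENAS.config')['LR'] == 0.05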
							
								
								
									
exps/algos/ENAS.py (new file, 347 lines)
@@ -0,0 +1,347 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
import os, sys, time, glob, random, argparse
import numpy as np
from copy import deepcopy
import torch
import torch.nn as nn
from pathlib import Path
lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve()
if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
from config_utils import load_config, dict2config, configure2str
from datasets     import get_datasets, SearchDataset
from procedures   import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint, get_optim_scheduler
from utils        import get_model_infos, obtain_accuracy
from log_utils    import AverageMeter, time_string, convert_secs2time
from models       import get_cell_based_tiny_net, get_search_spaces


def train_shared_cnn(xloader, shared_cnn, controller, criterion, scheduler, optimizer, epoch_str, print_freq, logger):
  data_time, batch_time = AverageMeter(), AverageMeter()
  losses, top1s, top5s, xend = AverageMeter(), AverageMeter(), AverageMeter(), time.time()

  shared_cnn.train()
  controller.eval()

  for step, (inputs, targets) in enumerate(xloader):
    scheduler.update(None, 1.0 * step / len(xloader))
    targets = targets.cuda(non_blocking=True)
    # measure data loading time
    data_time.update(time.time() - xend)

    with torch.no_grad():
      _, _, sampled_arch = controller()

    optimizer.zero_grad()
    shared_cnn.module.update_arch(sampled_arch)
    _, logits = shared_cnn(inputs)
    loss      = criterion(logits, targets)
    loss.backward()
    torch.nn.utils.clip_grad_norm_(shared_cnn.parameters(), 5)
    optimizer.step()
    # record
    prec1, prec5 = obtain_accuracy(logits.data, targets.data, topk=(1, 5))
    losses.update(loss.item(),  inputs.size(0))
    top1s.update (prec1.item(), inputs.size(0))
    top5s.update (prec5.item(), inputs.size(0))

    # measure elapsed time
    batch_time.update(time.time() - xend)
    xend = time.time()

    if step % print_freq == 0 or step + 1 == len(xloader):
      Sstr = '*Train-Shared-CNN* ' + time_string() + ' [{:}][{:03d}/{:03d}]'.format(epoch_str, step, len(xloader))
      Tstr = 'Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})'.format(batch_time=batch_time, data_time=data_time)
      Wstr = '[Loss {loss.val:.3f} ({loss.avg:.3f})  Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})]'.format(loss=losses, top1=top1s, top5=top5s)
      logger.log(Sstr + ' ' + Tstr + ' ' + Wstr)
  return losses.avg, top1s.avg, top5s.avg


def train_controller(xloader, shared_cnn, controller, criterion, optimizer, config, epoch_str, print_freq, logger):
  # config. (containing some necessary arg)
  #   baseline: The baseline score (i.e. average val_acc) from the previous epoch
  data_time, batch_time = AverageMeter(), AverageMeter()
  GradnormMeter, LossMeter, ValAccMeter, BaselineMeter, RewardMeter, xend = AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter(), time.time()

  shared_cnn.eval()
  controller.train()
  controller.zero_grad()
  #for step, (inputs, targets) in enumerate(xloader):
  loader_iter = iter(xloader)
  for step in range(config.ctl_train_steps * config.ctl_num_aggre):
    try:
      inputs, targets = next(loader_iter)
    except:
      loader_iter = iter(xloader)
      inputs, targets = next(loader_iter)
    targets = targets.cuda(non_blocking=True)
    # measure data loading time
    data_time.update(time.time() - xend)

    log_prob, entropy, sampled_arch = controller()
    with torch.no_grad():
      shared_cnn.module.update_arch(sampled_arch)
      _, logits = shared_cnn(inputs)
      val_top1, val_top5 = obtain_accuracy(logits.data, targets.data, topk=(1, 5))
      val_top1  = val_top1.view(-1) / 100
    reward = val_top1 + config.ctl_entropy_w * entropy
    if config.baseline is None:
      baseline = val_top1
    else:
      baseline = config.baseline - (1 - config.ctl_bl_dec) * (config.baseline - reward)

    loss = -1 * log_prob * (reward - baseline)

    # account
    RewardMeter.update(reward.item())
    BaselineMeter.update(baseline.item())
    ValAccMeter.update(val_top1.item())
    LossMeter.update(loss.item())

    # Average gradient over controller_num_aggregate samples
    loss = loss / config.ctl_num_aggre
    loss.backward(retain_graph=True)

    # measure elapsed time
    batch_time.update(time.time() - xend)
    xend = time.time()
    if (step+1) % config.ctl_num_aggre == 0:
      grad_norm = torch.nn.utils.clip_grad_norm_(controller.parameters(), 5.0)
      GradnormMeter.update(grad_norm)
      optimizer.step()
      controller.zero_grad()

    if step % print_freq == 0:
      Sstr = '*Train-Controller* ' + time_string() + ' [{:}][{:03d}/{:03d}]'.format(epoch_str, step, config.ctl_train_steps * config.ctl_num_aggre)
      Tstr = 'Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})'.format(batch_time=batch_time, data_time=data_time)
      Wstr = '[Loss {loss.val:.3f} ({loss.avg:.3f})  Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Reward {reward.val:.2f} ({reward.avg:.2f})] Baseline {basel.val:.2f} ({basel.avg:.2f})'.format(loss=LossMeter, top1=ValAccMeter, reward=RewardMeter, basel=BaselineMeter)
      logger.log(Sstr + ' ' + Tstr + ' ' + Wstr)

  return LossMeter.avg, ValAccMeter.avg, BaselineMeter.avg, RewardMeter.avg, baseline.item()


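The controller update above is REINFORCE with an exponentially decayed baseline: the expression config.baseline - (1 - config.ctl_bl_dec) * (config.baseline - reward) is algebraically the usual moving average ctl_bl_dec * baseline + (1 - ctl_bl_dec) * reward. A tiny self-contained check of that identity, with placeholder numbers (not part of the commit):

def ema_baseline(prev_baseline, reward, bl_dec):
  # Form used in train_controller above.
  return prev_baseline - (1 - bl_dec) * (prev_baseline - reward)

prev, reward, bl_dec = 0.70, 0.76, 0.99  # hypothetical accuracy-scale values
assert abs(ema_baseline(prev, reward, bl_dec) - (bl_dec * prev + (1 - bl_dec) * reward)) < 1e-12
# The per-sample policy-gradient loss is then: loss = -log_prob * (reward - baseline)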
def get_best_arch(controller, shared_cnn, xloader, n_samples=10):
  with torch.no_grad():
    controller.eval()
    shared_cnn.eval()
    archs, valid_accs = [], []
    loader_iter = iter(xloader)
    for i in range(n_samples):
      try:
        inputs, targets = next(loader_iter)
      except:
        loader_iter = iter(xloader)
        inputs, targets = next(loader_iter)

      _, _, sampled_arch = controller()
      arch = shared_cnn.module.update_arch(sampled_arch)
      _, logits = shared_cnn(inputs)
      val_top1, val_top5 = obtain_accuracy(logits.cpu().data, targets.data, topk=(1, 5))

      archs.append( arch )
      valid_accs.append( val_top1.item() )

    best_idx = np.argmax(valid_accs)
    best_arch, best_valid_acc = archs[best_idx], valid_accs[best_idx]
    return best_arch, best_valid_acc


def valid_func(xloader, network, criterion):
  data_time, batch_time = AverageMeter(), AverageMeter()
  arch_losses, arch_top1, arch_top5 = AverageMeter(), AverageMeter(), AverageMeter()
  network.eval()
  end = time.time()
  with torch.no_grad():
    for step, (arch_inputs, arch_targets) in enumerate(xloader):
      arch_targets = arch_targets.cuda(non_blocking=True)
      # measure data loading time
      data_time.update(time.time() - end)
      # prediction
      _, logits = network(arch_inputs)
      arch_loss = criterion(logits, arch_targets)
      # record
      arch_prec1, arch_prec5 = obtain_accuracy(logits.data, arch_targets.data, topk=(1, 5))
      arch_losses.update(arch_loss.item(),  arch_inputs.size(0))
      arch_top1.update  (arch_prec1.item(), arch_inputs.size(0))
      arch_top5.update  (arch_prec5.item(), arch_inputs.size(0))
      # measure elapsed time
      batch_time.update(time.time() - end)
      end = time.time()
  return arch_losses.avg, arch_top1.avg, arch_top5.avg


def main(xargs):
  assert torch.cuda.is_available(), 'CUDA is not available.'
  torch.backends.cudnn.enabled   = True
  torch.backends.cudnn.benchmark = False
  torch.backends.cudnn.deterministic = True
  torch.set_num_threads( xargs.workers )
  prepare_seed(xargs.rand_seed)
  logger = prepare_logger(args)

  train_data, test_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1)
  if xargs.dataset == 'cifar10' or xargs.dataset == 'cifar100':
    split_Fpath = 'configs/nas-benchmark/cifar-split.txt'
    cifar_split = load_config(split_Fpath, None, None)
    train_split, valid_split = cifar_split.train, cifar_split.valid
    logger.log('Load split file from {:}'.format(split_Fpath))
  elif xargs.dataset.startswith('ImageNet16'):
    split_Fpath = 'configs/nas-benchmark/{:}-split.txt'.format(xargs.dataset)
    imagenet16_split = load_config(split_Fpath, None, None)
    train_split, valid_split = imagenet16_split.train, imagenet16_split.valid
    logger.log('Load split file from {:}'.format(split_Fpath))
  else:
    raise ValueError('invalid dataset : {:}'.format(xargs.dataset))
  logger.log('use config from : {:}'.format(xargs.config_path))
  config = load_config(xargs.config_path, {'class_num': class_num, 'xshape': xshape}, logger)
  logger.log('config: {:}'.format(config))
  # To split data
  train_data_v2 = deepcopy(train_data)
  train_data_v2.transform = test_data.transform
  valid_data    = train_data_v2
  # data loader
  train_loader  = torch.utils.data.DataLoader(train_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(train_split), num_workers=xargs.workers, pin_memory=True)
  valid_loader  = torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(valid_split), num_workers=xargs.workers, pin_memory=True)
  logger.log('||||||| {:10s} ||||||| Train-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'.format(xargs.dataset, len(train_loader), len(valid_loader), config.batch_size))
  logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config))

  search_space = get_search_spaces('cell', xargs.search_space_name)
  model_config = dict2config({'name': 'ENAS', 'C': xargs.channel, 'N': xargs.num_cells,
                              'max_nodes': xargs.max_nodes, 'num_classes': class_num,
                              'space'    : search_space}, None)
  shared_cnn = get_cell_based_tiny_net(model_config)
  controller = shared_cnn.create_controller()

  w_optimizer, w_scheduler, criterion = get_optim_scheduler(shared_cnn.parameters(), config)
  a_optimizer = torch.optim.Adam(controller.parameters(), lr=config.controller_lr, betas=config.controller_betas, eps=config.controller_eps)
  logger.log('w-optimizer : {:}'.format(w_optimizer))
  logger.log('a-optimizer : {:}'.format(a_optimizer))
  logger.log('w-scheduler : {:}'.format(w_scheduler))
  logger.log('criterion   : {:}'.format(criterion))
  #flop, param  = get_model_infos(shared_cnn, xshape)
  #logger.log('{:}'.format(shared_cnn))
  #logger.log('FLOP = {:.2f} M, Params = {:.2f} MB'.format(flop, param))
  shared_cnn, controller, criterion = torch.nn.DataParallel(shared_cnn).cuda(), controller.cuda(), criterion.cuda()

  last_info, model_base_path, model_best_path = logger.path('info'), logger.path('model'), logger.path('best')

  if last_info.exists(): # automatically resume from previous checkpoint
    logger.log("=> loading checkpoint of the last-info '{:}' start".format(last_info))
    last_info   = torch.load(last_info)
    start_epoch = last_info['epoch']
    checkpoint  = torch.load(last_info['last_checkpoint'])
    genotypes   = checkpoint['genotypes']
    baseline    = checkpoint['baseline']
    valid_accuracies = checkpoint['valid_accuracies']
    shared_cnn.load_state_dict( checkpoint['shared_cnn'] )
    controller.load_state_dict( checkpoint['controller'] )
    w_scheduler.load_state_dict ( checkpoint['w_scheduler'] )
    w_optimizer.load_state_dict ( checkpoint['w_optimizer'] )
    a_optimizer.load_state_dict ( checkpoint['a_optimizer'] )
    logger.log("=> loading checkpoint of the last-info '{:}' start with {:}-th epoch.".format(last_info, start_epoch))
  else:
    logger.log("=> do not find the last-info file : {:}".format(last_info))
    start_epoch, valid_accuracies, genotypes, baseline = 0, {'best': -1}, {}, None

  # start training
  start_time, epoch_time, total_epoch = time.time(), AverageMeter(), config.epochs + config.warmup
  for epoch in range(start_epoch, total_epoch):
    w_scheduler.update(epoch, 0.0)
    need_time = 'Time Left: {:}'.format( convert_secs2time(epoch_time.val * (total_epoch-epoch), True) )
    epoch_str = '{:03d}-{:03d}'.format(epoch, total_epoch)
    logger.log('\n[Search the {:}-th epoch] {:}, LR={:}'.format(epoch_str, need_time, min(w_scheduler.get_lr())))

    cnn_loss, cnn_top1, cnn_top5 = train_shared_cnn(train_loader, shared_cnn, controller, criterion, w_scheduler, w_optimizer, epoch_str, xargs.print_freq, logger)
    logger.log('[{:}] shared-cnn : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'.format(epoch_str, cnn_loss, cnn_top1, cnn_top5))
    ctl_loss, ctl_acc, ctl_baseline, ctl_reward, baseline \
                                 = train_controller(valid_loader, shared_cnn, controller, criterion, a_optimizer, \
                                                        dict2config({'baseline': baseline,
                                                                     'ctl_train_steps': xargs.controller_train_steps, 'ctl_num_aggre': xargs.controller_num_aggregate,
                                                                     'ctl_entropy_w': xargs.controller_entropy_weight,
                                                                     'ctl_bl_dec'   : xargs.controller_bl_dec}, None), \
                                                        epoch_str, xargs.print_freq, logger)
    logger.log('[{:}] controller : loss={:.2f}, accuracy={:.2f}%, baseline={:.2f}, reward={:.2f}, current-baseline={:.4f}'.format(epoch_str, ctl_loss, ctl_acc, ctl_baseline, ctl_reward, baseline))
    best_arch, _ = get_best_arch(controller, shared_cnn, valid_loader)
    shared_cnn.module.update_arch(best_arch)
    _, best_valid_acc, _ = valid_func(valid_loader, shared_cnn, criterion) # valid_func returns (loss, top-1, top-5); keep the top-1 accuracy
					
 | 
				
			||||||
 | 
					    genotypes[epoch] = best_arch
 | 
				
			||||||
 | 
					    # check the best accuracy
 | 
				
			||||||
 | 
					    valid_accuracies[epoch] = best_valid_acc
 | 
				
			||||||
 | 
					    if best_valid_acc > valid_accuracies['best']:
 | 
				
			||||||
 | 
					      valid_accuracies['best'] = best_valid_acc
 | 
				
			||||||
 | 
					      genotypes['best']        = best_arch
 | 
				
			||||||
 | 
					      find_best = True
 | 
				
			||||||
 | 
					    else: find_best = False
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    logger.log('<<<--->>> The {:}-th epoch : {:}'.format(epoch_str, genotypes[epoch]))
 | 
				
			||||||
 | 
					    # save checkpoint
 | 
				
			||||||
 | 
					    save_path = save_checkpoint({'epoch' : epoch + 1,
 | 
				
			||||||
 | 
					                'args'  : deepcopy(xargs),
 | 
				
			||||||
 | 
					                'baseline'    : baseline,
 | 
				
			||||||
 | 
					                'shared_cnn'  : shared_cnn.state_dict(),
 | 
				
			||||||
 | 
					                'controller'  : controller.state_dict(),
 | 
				
			||||||
 | 
					                'w_optimizer' : w_optimizer.state_dict(),
 | 
				
			||||||
 | 
					                'a_optimizer' : a_optimizer.state_dict(),
 | 
				
			||||||
 | 
					                'w_scheduler' : w_scheduler.state_dict(),
 | 
				
			||||||
 | 
					                'genotypes'   : genotypes,
 | 
				
			||||||
 | 
					                'valid_accuracies' : valid_accuracies},
 | 
				
			||||||
 | 
					                model_base_path, logger)
 | 
				
			||||||
 | 
					    last_info = save_checkpoint({
 | 
				
			||||||
 | 
					          'epoch': epoch + 1,
 | 
				
			||||||
 | 
					          'args' : deepcopy(args),
 | 
				
			||||||
 | 
					          'last_checkpoint': save_path,
 | 
				
			||||||
 | 
					          }, logger.path('info'), logger)
 | 
				
			||||||
 | 
					    if find_best:
 | 
				
			||||||
 | 
					      logger.log('<<<--->>> The {:}-th epoch : find the highest validation accuracy : {:.2f}%.'.format(epoch_str, best_valid_acc))
 | 
				
			||||||
 | 
					      copy_checkpoint(model_base_path, model_best_path, logger)
 | 
				
			||||||
 | 
					    # measure elapsed time
 | 
				
			||||||
 | 
					    epoch_time.update(time.time() - start_time)
 | 
				
			||||||
 | 
					    start_time = time.time()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
  logger.log('\n' + '-'*100)
  # check the performance from the architecture dataset
  #if xargs.arch_nas_dataset is None or not os.path.isfile(xargs.arch_nas_dataset):
  #  logger.log('Can not find the architecture dataset : {:}.'.format(xargs.arch_nas_dataset))
  #else:
  #  nas_bench = NASBenchmarkAPI(xargs.arch_nas_dataset)
  #  geno = genotypes[total_epoch-1]
  #  logger.log('The last model is {:}'.format(geno))
  #  info = nas_bench.query_by_arch( geno )
  #  if info is None: logger.log('Did not find this architecture : {:}.'.format(geno))
  #  else           : logger.log('{:}'.format(info))
  #  logger.log('-'*100)
  #  geno = genotypes['best']
  #  logger.log('The best model is {:}'.format(geno))
  #  info = nas_bench.query_by_arch( geno )
  #  if info is None: logger.log('Did not find this architecture : {:}.'.format(geno))
  #  else           : logger.log('{:}'.format(info))
  logger.close()


if __name__ == '__main__':
  parser = argparse.ArgumentParser("ENAS")
  parser.add_argument('--data_path',          type=str,   help='Path to dataset')
  parser.add_argument('--dataset',            type=str,   choices=['cifar10', 'cifar100', 'ImageNet16-120'], help='Choose between Cifar10/100 and ImageNet-16.')
  # channels and number-of-cells
  parser.add_argument('--search_space_name',  type=str,   help='The search space name.')
  parser.add_argument('--max_nodes',          type=int,   help='The maximum number of nodes.')
  parser.add_argument('--channel',            type=int,   help='The number of channels.')
  parser.add_argument('--num_cells',          type=int,   help='The number of cells in one stage.')
  parser.add_argument('--config_path',        type=str,   help='The config file to train ENAS.')
  parser.add_argument('--controller_train_steps',    type=int,     help='.')
  parser.add_argument('--controller_num_aggregate',  type=int,     help='.')
  parser.add_argument('--controller_entropy_weight', type=float,   help='The weight for the entropy of the controller.')
  parser.add_argument('--controller_bl_dec'        , type=float,   help='.')
  parser.add_argument('--controller_num_samples'   , type=int,     help='.')
  # log
  parser.add_argument('--workers',            type=int,   default=2,    help='number of data loading workers (default: 2)')
  parser.add_argument('--save_dir',           type=str,   help='Folder to save checkpoints and log.')
  parser.add_argument('--arch_nas_dataset',   type=str,   help='The path to load the architecture dataset (nas-benchmark).')
  parser.add_argument('--print_freq',         type=int,   help='print frequency (default: 200)')
  parser.add_argument('--rand_seed',          type=int,   help='manual seed')
  args = parser.parse_args()
  if args.rand_seed is None or args.rand_seed < 0: args.rand_seed = random.randint(1, 100000)
  main(args)
@@ -16,18 +16,10 @@ from .cell_searchs import CellStructure, CellArchitectures
 # Cell-based NAS Models
 def get_cell_based_tiny_net(config):
-  if config.name == 'DARTS-V1':
-    from .cell_searchs import TinyNetworkDartsV1
-    return TinyNetworkDartsV1(config.C, config.N, config.max_nodes, config.num_classes, config.space)
-  elif config.name == 'DARTS-V2':
-    from .cell_searchs import TinyNetworkDartsV2
-    return TinyNetworkDartsV2(config.C, config.N, config.max_nodes, config.num_classes, config.space)
-  elif config.name == 'GDAS':
-    from .cell_searchs import TinyNetworkGDAS
-    return TinyNetworkGDAS(config.C, config.N, config.max_nodes, config.num_classes, config.space)
-  elif config.name == 'SETN':
-    from .cell_searchs import TinyNetworkSETN
-    return TinyNetworkSETN(config.C, config.N, config.max_nodes, config.num_classes, config.space)
+  group_names = ['DARTS-V1', 'DARTS-V2', 'GDAS', 'SETN', 'ENAS']
+  from .cell_searchs import nas_super_nets
+  if config.name in group_names:
+    return nas_super_nets[config.name](config.C, config.N, config.max_nodes, config.num_classes, config.space)
   elif config.name == 'infer.tiny':
     from .cell_infers import TinyNetwork
     return TinyNetwork(config.C, config.N, config.genotype, config.num_classes)
 
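With this refactor, every cell-based search model (including the new ENAS supernet) is constructed through the nas_super_nets registry. A sketch of the call pattern used by exps/algos/ENAS.py; the search-space name and the channel/cell numbers below are placeholders, and the imports assume the repo's lib directory is on sys.path as that script arranges:

from config_utils import dict2config
from models import get_cell_based_tiny_net, get_search_spaces

search_space = get_search_spaces('cell', 'some-search-space')  # space name is a placeholder
model_config = dict2config({'name': 'ENAS', 'C': 16, 'N': 5, 'max_nodes': 4,
                            'num_classes': 10, 'space': search_space}, None)
shared_cnn = get_cell_based_tiny_net(model_config)  # dispatches to TinyNetworkENAS via nas_super_nets
controller = shared_cnn.create_controller()         # LSTM controller sized by the cell's edges and ops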
@@ -2,4 +2,11 @@ from .search_model_darts_v1 import TinyNetworkDartsV1
 from .search_model_darts_v2 import TinyNetworkDartsV2
 from .search_model_gdas     import TinyNetworkGDAS
 from .search_model_setn     import TinyNetworkSETN
+from .search_model_enas     import TinyNetworkENAS
 from .genotypes             import Structure as CellStructure, architectures as CellArchitectures
+
+nas_super_nets = {'DARTS-V1': TinyNetworkDartsV1,
+                  'DARTS-V2': TinyNetworkDartsV2,
+                  'GDAS'    : TinyNetworkGDAS,
+                  'SETN'    : TinyNetworkSETN,
+                  'ENAS'    : TinyNetworkENAS}
lib/models/cell_searchs/_test_module.py (new file, 9 lines)
@@ -0,0 +1,9 @@
import torch
from search_model_enas_utils import Controller

def main():
  controller = Controller(6, 4)
  predictions = controller()

if __name__ == '__main__':
  main()
lib/models/cell_searchs/search_model_enas.py (new file, 94 lines)
@@ -0,0 +1,94 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##########################################################################
# Efficient Neural Architecture Search via Parameters Sharing, ICML 2018 #
##########################################################################
import torch
import torch.nn as nn
from copy import deepcopy
from ..cell_operations import ResNetBasicblock
from .search_cells     import SearchCell
from .genotypes        import Structure
from .search_model_enas_utils import Controller


class TinyNetworkENAS(nn.Module):

  def __init__(self, C, N, max_nodes, num_classes, search_space):
    super(TinyNetworkENAS, self).__init__()
    self._C        = C
    self._layerN   = N
    self.max_nodes = max_nodes
    self.stem = nn.Sequential(
                    nn.Conv2d(3, C, kernel_size=3, padding=1, bias=False),
                    nn.BatchNorm2d(C))

    layer_channels   = [C    ] * N + [C*2 ] + [C*2  ] * N + [C*4 ] + [C*4  ] * N
    layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N

    C_prev, num_edge, edge2index = C, None, None
    self.cells = nn.ModuleList()
    for index, (C_curr, reduction) in enumerate(zip(layer_channels, layer_reductions)):
      if reduction:
        cell = ResNetBasicblock(C_prev, C_curr, 2)
      else:
        cell = SearchCell(C_prev, C_curr, 1, max_nodes, search_space)
        if num_edge is None: num_edge, edge2index = cell.num_edges, cell.edge2index
        else: assert num_edge == cell.num_edges and edge2index == cell.edge2index, 'invalid {:} vs. {:}.'.format(num_edge, cell.num_edges)
      self.cells.append( cell )
      C_prev = cell.out_dim
    self.op_names   = deepcopy( search_space )
    self._Layer     = len(self.cells)
    self.edge2index = edge2index
    self.lastact    = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True))
    self.global_pooling = nn.AdaptiveAvgPool2d(1)
    self.classifier = nn.Linear(C_prev, num_classes)
    # to maintain the sampled architecture
    self.sampled_arch = None

  def update_arch(self, _arch):
    if _arch is None:
      self.sampled_arch = None
    elif isinstance(_arch, Structure):
      self.sampled_arch = _arch
    elif isinstance(_arch, (list, tuple)):
      genotypes = []
      for i in range(1, self.max_nodes):
        xlist = []
        for j in range(i):
          node_str = '{:}<-{:}'.format(i, j)
          op_index = _arch[ self.edge2index[node_str] ]
          op_name  = self.op_names[ op_index ]
          xlist.append((op_name, j))
        genotypes.append( tuple(xlist) )
      self.sampled_arch = Structure(genotypes)
    else:
      raise ValueError('invalid type of input architecture : {:}'.format(_arch))
    return self.sampled_arch

  def create_controller(self):
    return Controller(len(self.edge2index), len(self.op_names))

  def get_message(self):
    string = self.extra_repr()
    for i, cell in enumerate(self.cells):
      string += '\n {:02d}/{:02d} :: {:}'.format(i, len(self.cells), cell.extra_repr())
    return string

  def extra_repr(self):
    return ('{name}(C={_C}, Max-Nodes={max_nodes}, N={_layerN}, L={_Layer})'.format(name=self.__class__.__name__, **self.__dict__))

  def forward(self, inputs):

    feature = self.stem(inputs)
    for i, cell in enumerate(self.cells):
      if isinstance(cell, SearchCell):
        feature = cell.forward_dynamic(feature, self.sampled_arch)
      else: feature = cell(feature)

    out = self.lastact(feature)
    out = self.global_pooling( out )
    out = out.view(out.size(0), -1)
    logits = self.classifier(out)

    return out, logits
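update_arch accepts either a Structure genotype or a flat list with one operation index per edge; edge2index['i<-j'] selects the op applied on the edge from node j to node i. A worked sketch of that decoding, assuming max_nodes=4 (so six edges) and a hypothetical five-operation search space (the op names below are placeholders, not taken from this commit):

# Illustrative decoding of a sampled op-index list into a genotype (not part of the commit).
op_names   = ['none', 'skip_connect', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3']  # assumed op set
edge2index = {'1<-0': 0, '2<-0': 1, '2<-1': 2, '3<-0': 3, '3<-1': 4, '3<-2': 5}
sampled    = [3, 1, 0, 2, 4, 1]  # one op index per edge, e.g. returned by the ENAS controller

genotypes = []
for i in range(1, 4):            # max_nodes = 4 -> intermediate nodes 1..3
  xlist = []
  for j in range(i):
    op_index = sampled[edge2index['{:}<-{:}'.format(i, j)]]
    xlist.append((op_names[op_index], j))
  genotypes.append(tuple(xlist))
# genotypes == [(('nor_conv_3x3', 0),),
#               (('skip_connect', 0), ('none', 1)),
#               (('nor_conv_1x1', 0), ('avg_pool_3x3', 1), ('skip_connect', 2))]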
							
								
								
									
lib/models/cell_searchs/search_model_enas_utils.py (new file, 55 lines)
@@ -0,0 +1,55 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##########################################################################
# Efficient Neural Architecture Search via Parameters Sharing, ICML 2018 #
##########################################################################
import torch
import torch.nn as nn
from torch.distributions.categorical import Categorical

class Controller(nn.Module):
  # we refer to https://github.com/TDeVries/enas_pytorch/blob/master/models/controller.py
  def __init__(self, num_edge, num_ops, lstm_size=32, lstm_num_layers=2, tanh_constant=2.5, temperature=5.0):
    super(Controller, self).__init__()
    # assign the attributes
    self.num_edge  = num_edge
    self.num_ops   = num_ops
    self.lstm_size = lstm_size
    self.lstm_N    = lstm_num_layers
    self.tanh_constant = tanh_constant
    self.temperature   = temperature
    # create parameters
    self.register_parameter('input_vars', nn.Parameter(torch.Tensor(1, 1, lstm_size)))
    self.w_lstm = nn.LSTM(input_size=self.lstm_size, hidden_size=self.lstm_size, num_layers=self.lstm_N)
    self.w_embd = nn.Embedding(self.num_ops, self.lstm_size)
    self.w_pred = nn.Linear(self.lstm_size, self.num_ops)

    nn.init.uniform_(self.input_vars         , -0.1, 0.1)
    nn.init.uniform_(self.w_lstm.weight_hh_l0, -0.1, 0.1)
    nn.init.uniform_(self.w_lstm.weight_ih_l0, -0.1, 0.1)
    nn.init.uniform_(self.w_embd.weight      , -0.1, 0.1)
    nn.init.uniform_(self.w_pred.weight      , -0.1, 0.1)

  def forward(self):

    inputs, h0 = self.input_vars, None
    log_probs, entropys, sampled_arch = [], [], []
    for iedge in range(self.num_edge):
      outputs, h0 = self.w_lstm(inputs, h0)

      logits = self.w_pred(outputs)
      logits = logits / self.temperature
      logits = self.tanh_constant * torch.tanh(logits)
      # distribution
      op_distribution = Categorical(logits=logits)
      op_index    = op_distribution.sample()
      sampled_arch.append( op_index.item() )

      op_log_prob = op_distribution.log_prob(op_index)
      log_probs.append( op_log_prob.view(-1) )
      op_entropy  = op_distribution.entropy()
      entropys.append( op_entropy.view(-1) )

      # obtain the input embedding for the next step
      inputs = self.w_embd(op_index)
    return torch.sum(torch.cat(log_probs)), torch.sum(torch.cat(entropys)), sampled_arch
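For context, this is how the controller is consumed by train_controller in exps/algos/ENAS.py: each forward() call samples one operation index per edge and returns the summed log-probability and entropy of that sample, which form the REINFORCE loss. A short usage sketch with placeholder reward and baseline values (not part of the commit):

# Illustrative only: sample an architecture and form the policy-gradient loss.
controller = Controller(num_edge=6, num_ops=5)   # sizes are placeholders
log_prob, entropy, sampled_arch = controller()   # sampled_arch: list of 6 op indices
reward   = 0.73 + 0.0001 * entropy               # val_top1 plus a small entropy bonus (made-up numbers)
baseline = 0.70
loss     = -1 * log_prob * (reward - baseline)
loss.backward()                                  # gradients flow only into the controller's parameters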