Prototype generic nas model (cont.) for ENAS.
This commit is contained in:
		| @@ -20,6 +20,10 @@ | ||||
| # python ./exps/algos-v2/search-cell.py --dataset cifar10  --data_path $TORCH_HOME/cifar.python --algo random --rand_seed 777 | ||||
| # python ./exps/algos-v2/search-cell.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo random | ||||
| # python ./exps/algos-v2/search-cell.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo random | ||||
| #### | ||||
| # python ./exps/algos-v2/search-cell.py --dataset cifar10  --data_path $TORCH_HOME/cifar.python --algo enas --arch_weight_decay 0 --arch_learning_rate 0.001 --arch_eps 0.001 --rand_seed 777 | ||||
| # python ./exps/algos-v2/search-cell.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo enas | ||||
| # python ./exps/algos-v2/search-cell.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo enas | ||||
| ###################################################################################### | ||||
| import os, sys, time, random, argparse | ||||
| import numpy as np | ||||
| @@ -130,6 +134,11 @@ def search_func(xloader, network, criterion, scheduler, w_optimizer, a_optimizer | ||||
|       network.set_cal_mode('joint', None) | ||||
|     elif algo == 'random': | ||||
|       network.set_cal_mode('urs', None) | ||||
|     elif algo == 'enas': | ||||
|       with torch.no_grad(): | ||||
|         network.controller.eval() | ||||
|         _, _, sampled_arch = network.controller() | ||||
|       network.set_cal_mode('dynamic', sampled_arch) | ||||
|     else: | ||||
|       raise ValueError('Invalid algo name : {:}'.format(algo)) | ||||
|        | ||||
| @@ -153,16 +162,21 @@ def search_func(xloader, network, criterion, scheduler, w_optimizer, a_optimizer | ||||
|       network.set_cal_mode('joint', None) | ||||
|     elif algo == 'random': | ||||
|       network.set_cal_mode('urs', None) | ||||
|     else: | ||||
|     elif algo != 'enas': | ||||
|       raise ValueError('Invalid algo name : {:}'.format(algo)) | ||||
|     network.zero_grad() | ||||
|     if algo == 'darts-v2': | ||||
|       arch_loss, logits = backward_step_unrolled(network, criterion, base_inputs, base_targets, w_optimizer, arch_inputs, arch_targets) | ||||
|       a_optimizer.step() | ||||
|     elif algo == 'random' or algo == 'enas': | ||||
|       with torch.no_grad(): | ||||
|         _, logits = network(arch_inputs) | ||||
|         arch_loss = criterion(logits, arch_targets) | ||||
|     else: | ||||
|       _, logits = network(arch_inputs) | ||||
|       arch_loss = criterion(logits, arch_targets) | ||||
|       arch_loss.backward() | ||||
|     a_optimizer.step() | ||||
|       a_optimizer.step() | ||||
|     # record | ||||
|     arch_prec1, arch_prec5 = obtain_accuracy(logits.data, arch_targets.data, topk=(1, 5)) | ||||
|     arch_losses.update(arch_loss.item(),  arch_inputs.size(0)) | ||||
| @@ -182,6 +196,76 @@ def search_func(xloader, network, criterion, scheduler, w_optimizer, a_optimizer | ||||
|   return base_losses.avg, base_top1.avg, base_top5.avg, arch_losses.avg, arch_top1.avg, arch_top5.avg | ||||
|  | ||||
|  | ||||
| def train_controller(xloader, network, criterion, optimizer, prev_baseline, epoch_str, print_freq, logger): | ||||
|   # config. (containing some necessary arg) | ||||
|   #   baseline: The baseline score (i.e. average val_acc) from the previous epoch | ||||
|   data_time, batch_time = AverageMeter(), AverageMeter() | ||||
|   GradnormMeter, LossMeter, ValAccMeter, EntropyMeter, BaselineMeter, RewardMeter, xend = AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter(), time.time() | ||||
|    | ||||
|   controller_num_aggregate = 20 | ||||
|   controller_train_steps = 50 | ||||
|   controller_bl_dec = 0.99 | ||||
|   controller_entropy_weight = 0.0001 | ||||
|  | ||||
|   network.eval() | ||||
|   network.controller.train() | ||||
|   network.controller.zero_grad() | ||||
|   loader_iter = iter(xloader) | ||||
|   for step in range(controller_train_steps * controller_num_aggregate): | ||||
|     try: | ||||
|       inputs, targets = next(loader_iter) | ||||
|     except: | ||||
|       loader_iter = iter(xloader) | ||||
|       inputs, targets = next(loader_iter) | ||||
|     inputs  = inputs.cuda(non_blocking=True) | ||||
|     targets = targets.cuda(non_blocking=True) | ||||
|     # measure data loading time | ||||
|     data_time.update(time.time() - xend) | ||||
|      | ||||
|     log_prob, entropy, sampled_arch = network.controller() | ||||
|     with torch.no_grad(): | ||||
|       network.set_cal_mode('dynamic', sampled_arch) | ||||
|       _, logits = network(inputs) | ||||
|       val_top1, val_top5 = obtain_accuracy(logits.data, targets.data, topk=(1, 5)) | ||||
|       val_top1  = val_top1.view(-1) / 100 | ||||
|     reward = val_top1 + controller_entropy_weight * entropy | ||||
|     if prev_baseline is None: | ||||
|       baseline = val_top1 | ||||
|     else: | ||||
|       baseline = prev_baseline - (1 - controller_bl_dec) * (prev_baseline - reward) | ||||
|     | ||||
|     loss = -1 * log_prob * (reward - baseline) | ||||
|      | ||||
|     # account | ||||
|     RewardMeter.update(reward.item()) | ||||
|     BaselineMeter.update(baseline.item()) | ||||
|     ValAccMeter.update(val_top1.item()*100) | ||||
|     LossMeter.update(loss.item()) | ||||
|     EntropyMeter.update(entropy.item()) | ||||
|    | ||||
|     # Average gradient over controller_num_aggregate samples | ||||
|     loss = loss / controller_num_aggregate | ||||
|     loss.backward(retain_graph=True) | ||||
|  | ||||
|     # measure elapsed time | ||||
|     batch_time.update(time.time() - xend) | ||||
|     xend = time.time() | ||||
|     if (step+1) % controller_num_aggregate == 0: | ||||
|       grad_norm = torch.nn.utils.clip_grad_norm_(network.controller.parameters(), 5.0) | ||||
|       GradnormMeter.update(grad_norm) | ||||
|       optimizer.step() | ||||
|       network.controller.zero_grad() | ||||
|  | ||||
|     if step % print_freq == 0: | ||||
|       Sstr = '*Train-Controller* ' + time_string() + ' [{:}][{:03d}/{:03d}]'.format(epoch_str, step, controller_train_steps * controller_num_aggregate) | ||||
|       Tstr = 'Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})'.format(batch_time=batch_time, data_time=data_time) | ||||
|       Wstr = '[Loss {loss.val:.3f} ({loss.avg:.3f})  Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Reward {reward.val:.2f} ({reward.avg:.2f})] Baseline {basel.val:.2f} ({basel.avg:.2f})'.format(loss=LossMeter, top1=ValAccMeter, reward=RewardMeter, basel=BaselineMeter) | ||||
|       Estr = 'Entropy={:.4f} ({:.4f})'.format(EntropyMeter.val, EntropyMeter.avg) | ||||
|       logger.log(Sstr + ' ' + Tstr + ' ' + Wstr + ' ' + Estr) | ||||
|  | ||||
|   return LossMeter.avg, ValAccMeter.avg, BaselineMeter.avg, RewardMeter.avg | ||||
|  | ||||
|  | ||||
| def get_best_arch(xloader, network, n_samples, algo): | ||||
|   with torch.no_grad(): | ||||
|     network.eval() | ||||
| @@ -192,6 +276,11 @@ def get_best_arch(xloader, network, n_samples, algo): | ||||
|     elif algo.startswith('darts') or algo == 'gdas': | ||||
|       arch = network.genotype | ||||
|       archs, valid_accs = [arch], [] | ||||
|     elif algo == 'enas': | ||||
|       archs, valid_accs = [], [] | ||||
|       for _ in range(n_samples): | ||||
|         _, _, sampled_arch = network.controller() | ||||
|         archs.append(sampled_arch) | ||||
|     else: | ||||
|       raise ValueError('Invalid algorithm name : {:}'.format(algo)) | ||||
|     loader_iter = iter(xloader) | ||||
| @@ -245,7 +334,7 @@ def main(xargs): | ||||
|  | ||||
|   train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1) | ||||
|   config = load_config(xargs.config_path, {'class_num': class_num, 'xshape': xshape}, logger) | ||||
|   search_loader, _, valid_loader = get_nas_search_loaders(train_data, valid_data, xargs.dataset, 'configs/nas-benchmark/', \ | ||||
|   search_loader, train_loader, valid_loader = get_nas_search_loaders(train_data, valid_data, xargs.dataset, 'configs/nas-benchmark/', \ | ||||
|                                         (config.batch_size, config.test_batch_size), xargs.workers) | ||||
|   logger.log('||||||| {:10s} ||||||| Search-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'.format(xargs.dataset, len(search_loader), len(valid_loader), config.batch_size)) | ||||
|   logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config)) | ||||
| @@ -263,7 +352,7 @@ def main(xargs): | ||||
|   logger.log('{:}'.format(search_model)) | ||||
|  | ||||
|   w_optimizer, w_scheduler, criterion = get_optim_scheduler(search_model.weights, config) | ||||
|   a_optimizer = torch.optim.Adam(search_model.alphas, lr=xargs.arch_learning_rate, betas=(0.5, 0.999), weight_decay=xargs.arch_weight_decay) | ||||
|   a_optimizer = torch.optim.Adam(search_model.alphas, lr=xargs.arch_learning_rate, betas=(0.5, 0.999), weight_decay=xargs.arch_weight_decay, eps=xargs.arch_eps) | ||||
|   logger.log('w-optimizer : {:}'.format(w_optimizer)) | ||||
|   logger.log('a-optimizer : {:}'.format(a_optimizer)) | ||||
|   logger.log('w-scheduler : {:}'.format(w_scheduler)) | ||||
| @@ -288,6 +377,8 @@ def main(xargs): | ||||
|     start_epoch = last_info['epoch'] | ||||
|     checkpoint  = torch.load(last_info['last_checkpoint']) | ||||
|     genotypes   = checkpoint['genotypes'] | ||||
|     if xargs.algo == 'enas': | ||||
|       baseline  = checkpoint['baseline'] | ||||
|     valid_accuracies = checkpoint['valid_accuracies'] | ||||
|     search_model.load_state_dict( checkpoint['search_model'] ) | ||||
|     w_scheduler.load_state_dict ( checkpoint['w_scheduler'] ) | ||||
| @@ -297,6 +388,7 @@ def main(xargs): | ||||
|   else: | ||||
|     logger.log("=> do not find the last-info file : {:}".format(last_info)) | ||||
|     start_epoch, valid_accuracies, genotypes = 0, {'best': -1}, {-1: network.return_topK(1, True)[0]} | ||||
|     baseline = None | ||||
|  | ||||
|   # start training | ||||
|   start_time, search_time, epoch_time, total_epoch = time.time(), AverageMeter(), AverageMeter(), config.epochs + config.warmup | ||||
| @@ -312,9 +404,13 @@ def main(xargs): | ||||
|     search_time.update(time.time() - start_time) | ||||
|     logger.log('[{:}] search [base] : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%, time-cost={:.1f} s'.format(epoch_str, search_w_loss, search_w_top1, search_w_top5, search_time.sum)) | ||||
|     logger.log('[{:}] search [arch] : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'.format(epoch_str, search_a_loss, search_a_top1, search_a_top5)) | ||||
|     if xargs.algo == 'enas': | ||||
|       ctl_loss, ctl_acc, baseline, ctl_reward \ | ||||
|                                  = train_controller(valid_loader, network, criterion, a_optimizer, baseline, epoch_str, xargs.print_freq, logger) | ||||
|       logger.log('[{:}] controller : loss={:}, acc={:}, baseline={:}, reward={:}'.format(epoch_str, ctl_loss, ctl_acc, baseline, ctl_reward)) | ||||
|  | ||||
|     genotype, temp_accuracy = get_best_arch(valid_loader, network, xargs.eval_candidate_num, xargs.algo) | ||||
|     if xargs.algo == 'setn': | ||||
|     if xargs.algo == 'setn' or xargs.algo == 'enas': | ||||
|       network.set_cal_mode('dynamic', genotype) | ||||
|     elif xargs.algo == 'gdas': | ||||
|       network.set_cal_mode('gdas', None) | ||||
| @@ -333,6 +429,7 @@ def main(xargs): | ||||
|     # save checkpoint | ||||
|     save_path = save_checkpoint({'epoch' : epoch + 1, | ||||
|                 'args'  : deepcopy(xargs), | ||||
|                 'baseline'    : baseline, | ||||
|                 'search_model': search_model.state_dict(), | ||||
|                 'w_optimizer' : w_optimizer.state_dict(), | ||||
|                 'a_optimizer' : a_optimizer.state_dict(), | ||||
| @@ -377,7 +474,6 @@ def main(xargs): | ||||
|   logger.close() | ||||
|    | ||||
|  | ||||
|  | ||||
| if __name__ == '__main__': | ||||
|   parser = argparse.ArgumentParser("Weight sharing NAS methods to search for cells.") | ||||
|   parser.add_argument('--data_path'   ,       type=str,   help='Path to dataset') | ||||
| @@ -396,7 +492,8 @@ if __name__ == '__main__': | ||||
|   parser.add_argument('--config_path' ,       type=str,   default='./configs/nas-benchmark/algos/weight-sharing.config', help='The path of configuration.') | ||||
|   # architecture leraning rate | ||||
|   parser.add_argument('--arch_learning_rate', type=float, default=3e-4, help='learning rate for arch encoding') | ||||
|   parser.add_argument('--arch_weight_decay',  type=float, default=1e-3, help='weight decay for arch encoding') | ||||
|   parser.add_argument('--arch_weight_decay' , type=float, default=1e-3, help='weight decay for arch encoding') | ||||
|   parser.add_argument('--arch_eps'          , type=float, default=1e-8, help='weight decay for arch encoding') | ||||
|   parser.add_argument('--drop_path_rate'  ,  type=float, help='The drop path rate.') | ||||
|   # log | ||||
|   parser.add_argument('--workers',            type=int,   default=2,    help='number of data loading workers (default: 2)') | ||||
|   | ||||
		Reference in New Issue
	
	Block a user