diff --git a/exps/NATS-algos/search-size.py b/exps/NATS-algos/search-size.py index 78727ee..f40a95f 100644 --- a/exps/NATS-algos/search-size.py +++ b/exps/NATS-algos/search-size.py @@ -1,28 +1,30 @@ ################################################## # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020 # ########################################################################################################################################### +# # In this file, we aims to evaluate three kinds of channel searching strategies: # - channel-wise interpolation from "Network Pruning via Transformable Architecture Search, NeurIPS 2019" -# - masking + Gumbel-Softmax from "FBNetV2: Differentiable Neural Architecture Search for Spatial and Channel Dimensions, CVPR 2020" -# - masking + sampling from "Can Weight Sharing Outperform Random Architecture Search? An Investigation With TuNAS, CVPR 2020" -# For simplicity, we use tas, fbv2, and tunas to refer these three strategies. Their official implementations are at the following links: +# - masking + Gumbel-Softmax (mask_gumbel) from "FBNetV2: Differentiable Neural Architecture Search for Spatial and Channel Dimensions, CVPR 2020" +# - masking + sampling (mask_rl) from "Can Weight Sharing Outperform Random Architecture Search? An Investigation With TuNAS, CVPR 2020" +# +# For simplicity, we use tas, mask_gumbel, and mask_rl to refer these three strategies. Their official implementations are at the following links: # - TAS: https://github.com/D-X-Y/AutoDL-Projects/blob/master/docs/NeurIPS-2019-TAS.md # - FBNetV2: https://github.com/facebookresearch/mobile-vision # - TuNAS: https://github.com/google-research/google-research/tree/master/tunas #### -# python ./exps/NATS-algos/search-size.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo tunas --arch_weight_decay 0 --warmup_ratio 0.25 +# python ./exps/NATS-algos/search-size.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo mask_rl --arch_weight_decay 0 --warmup_ratio 0.25 #### # python ./exps/NATS-algos/search-size.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo tas --rand_seed 777 # python ./exps/NATS-algos/search-size.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo tas --rand_seed 777 # python ./exps/NATS-algos/search-size.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo tas --rand_seed 777 #### -# python ./exps/NATS-algos/search-size.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo fbv2 --rand_seed 777 -# python ./exps/NATS-algos/search-size.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo fbv2 --rand_seed 777 -# python ./exps/NATS-algos/search-size.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo fbv2 --rand_seed 777 +# python ./exps/NATS-algos/search-size.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo mask_gumbel --rand_seed 777 +# python ./exps/NATS-algos/search-size.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo mask_gumbel --rand_seed 777 +# python ./exps/NATS-algos/search-size.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo mask_gumbel --rand_seed 777 #### -# python ./exps/NATS-algos/search-size.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo tunas --arch_weight_decay 0 --rand_seed 777 --use_api 0 -# python ./exps/NATS-algos/search-size.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo tunas --arch_weight_decay 0 --rand_seed 777 -# python ./exps/NATS-algos/search-size.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo tunas --arch_weight_decay 0 --rand_seed 777 +# python ./exps/NATS-algos/search-size.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo mask_rl --arch_weight_decay 0 --rand_seed 777 --use_api 0 +# python ./exps/NATS-algos/search-size.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo mask_rl --arch_weight_decay 0 --rand_seed 777 +# python ./exps/NATS-algos/search-size.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo mask_rl --arch_weight_decay 0 --rand_seed 777 ########################################################################################################################################### import os, sys, time, random, argparse import numpy as np @@ -94,13 +96,13 @@ def search_func(xloader, network, criterion, scheduler, w_optimizer, a_optimizer a_optimizer.zero_grad() _, logits, log_probs = network(arch_inputs) arch_prec1, arch_prec5 = obtain_accuracy(logits.data, arch_targets.data, topk=(1, 5)) - if algo == 'tunas': + if algo == 'mask_rl': with torch.no_grad(): RL_BASELINE_EMA.update(arch_prec1.item()) rl_advantage = arch_prec1 - RL_BASELINE_EMA.value rl_log_prob = sum(log_probs) arch_loss = - rl_advantage * rl_log_prob - elif algo == 'tas' or algo == 'fbv2': + elif algo == 'tas' or algo == 'mask_gumbel': arch_loss = criterion(logits, arch_targets) else: raise ValueError('invalid algorightm name: {:}'.format(algo)) @@ -231,7 +233,7 @@ def main(xargs): logger.log('\n[Search the {:}-th epoch] {:}, LR={:}, controller-warmup={:}, enable_controller={:}'.format(epoch_str, need_time, min(w_scheduler.get_lr()), network.warmup_ratio, enable_controller)) - if xargs.algo == 'fbv2' or xargs.algo == 'tas': + if xargs.algo == 'mask_gumbel' or xargs.algo == 'tas': network.set_tau(xargs.tau_max - (xargs.tau_max-xargs.tau_min) * epoch / (total_epoch-1)) logger.log('[RESET tau as : {:}]'.format(network.tau)) search_w_loss, search_w_top1, search_w_top5, search_a_loss, search_a_top1, search_a_top5 \ @@ -291,7 +293,7 @@ if __name__ == '__main__': parser.add_argument('--data_path' , type=str, help='Path to dataset') parser.add_argument('--dataset' , type=str, choices=['cifar10', 'cifar100', 'ImageNet16-120'], help='Choose between Cifar10/100 and ImageNet-16.') parser.add_argument('--search_space', type=str, default='sss', choices=['sss'], help='The search space name.') - parser.add_argument('--algo' , type=str, choices=['tas', 'fbv2', 'tunas'], help='The search space name.') + parser.add_argument('--algo' , type=str, choices=['tas', 'mask_gumbel', 'mask_rl'], help='The search space name.') parser.add_argument('--genotype' , type=str, default='|nor_conv_3x3~0|+|nor_conv_3x3~0|nor_conv_3x3~1|+|skip_connect~0|nor_conv_3x3~1|nor_conv_3x3~2|', help='The genotype.') parser.add_argument('--use_api' , type=int, default=1, choices=[0,1], help='Whether use API or not (which will cost much memory).') # FOR GDAS