Revise names for compared #channel-search algorithms
This commit is contained in:
		| @@ -1,28 +1,30 @@ | |||||||
| ################################################## | ################################################## | ||||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020 # | # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020 # | ||||||
| ########################################################################################################################################### | ########################################################################################################################################### | ||||||
|  | # | ||||||
| # In this file, we aims to evaluate three kinds of channel searching strategies: | # In this file, we aims to evaluate three kinds of channel searching strategies: | ||||||
| # - channel-wise interpolation from "Network Pruning via Transformable Architecture Search, NeurIPS 2019" | # - channel-wise interpolation from "Network Pruning via Transformable Architecture Search, NeurIPS 2019" | ||||||
| # - masking + Gumbel-Softmax from "FBNetV2: Differentiable Neural Architecture Search for Spatial and Channel Dimensions, CVPR 2020" | # - masking + Gumbel-Softmax (mask_gumbel) from "FBNetV2: Differentiable Neural Architecture Search for Spatial and Channel Dimensions, CVPR 2020" | ||||||
| # - masking + sampling from "Can Weight Sharing Outperform Random Architecture Search? An Investigation With TuNAS, CVPR 2020" | # - masking + sampling (mask_rl) from "Can Weight Sharing Outperform Random Architecture Search? An Investigation With TuNAS, CVPR 2020" | ||||||
| # For simplicity, we use tas, fbv2, and tunas to refer these three strategies. Their official implementations are at the following links: | # | ||||||
|  | # For simplicity, we use tas, mask_gumbel, and mask_rl to refer these three strategies. Their official implementations are at the following links: | ||||||
| # - TAS: https://github.com/D-X-Y/AutoDL-Projects/blob/master/docs/NeurIPS-2019-TAS.md | # - TAS: https://github.com/D-X-Y/AutoDL-Projects/blob/master/docs/NeurIPS-2019-TAS.md | ||||||
| # - FBNetV2: https://github.com/facebookresearch/mobile-vision | # - FBNetV2: https://github.com/facebookresearch/mobile-vision | ||||||
| # - TuNAS: https://github.com/google-research/google-research/tree/master/tunas | # - TuNAS: https://github.com/google-research/google-research/tree/master/tunas | ||||||
| #### | #### | ||||||
| # python ./exps/NATS-algos/search-size.py --dataset cifar10  --data_path $TORCH_HOME/cifar.python --algo tunas --arch_weight_decay 0 --warmup_ratio 0.25 | # python ./exps/NATS-algos/search-size.py --dataset cifar10  --data_path $TORCH_HOME/cifar.python --algo mask_rl --arch_weight_decay 0 --warmup_ratio 0.25 | ||||||
| #### | #### | ||||||
| # python ./exps/NATS-algos/search-size.py --dataset cifar10  --data_path $TORCH_HOME/cifar.python --algo tas --rand_seed 777 | # python ./exps/NATS-algos/search-size.py --dataset cifar10  --data_path $TORCH_HOME/cifar.python --algo tas --rand_seed 777 | ||||||
| # python ./exps/NATS-algos/search-size.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo tas --rand_seed 777 | # python ./exps/NATS-algos/search-size.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo tas --rand_seed 777 | ||||||
| # python ./exps/NATS-algos/search-size.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo tas --rand_seed 777 | # python ./exps/NATS-algos/search-size.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo tas --rand_seed 777 | ||||||
| #### | #### | ||||||
| # python ./exps/NATS-algos/search-size.py --dataset cifar10  --data_path $TORCH_HOME/cifar.python --algo fbv2 --rand_seed 777 | # python ./exps/NATS-algos/search-size.py --dataset cifar10  --data_path $TORCH_HOME/cifar.python --algo mask_gumbel --rand_seed 777 | ||||||
| # python ./exps/NATS-algos/search-size.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo fbv2 --rand_seed 777 | # python ./exps/NATS-algos/search-size.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo mask_gumbel --rand_seed 777 | ||||||
| # python ./exps/NATS-algos/search-size.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo fbv2 --rand_seed 777 | # python ./exps/NATS-algos/search-size.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo mask_gumbel --rand_seed 777 | ||||||
| #### | #### | ||||||
| # python ./exps/NATS-algos/search-size.py --dataset cifar10  --data_path $TORCH_HOME/cifar.python --algo tunas --arch_weight_decay 0 --rand_seed 777 --use_api 0 | # python ./exps/NATS-algos/search-size.py --dataset cifar10  --data_path $TORCH_HOME/cifar.python --algo mask_rl --arch_weight_decay 0 --rand_seed 777 --use_api 0 | ||||||
| # python ./exps/NATS-algos/search-size.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo tunas --arch_weight_decay 0 --rand_seed 777 | # python ./exps/NATS-algos/search-size.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo mask_rl --arch_weight_decay 0 --rand_seed 777 | ||||||
| # python ./exps/NATS-algos/search-size.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo tunas --arch_weight_decay 0 --rand_seed 777 | # python ./exps/NATS-algos/search-size.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo mask_rl --arch_weight_decay 0 --rand_seed 777 | ||||||
| ########################################################################################################################################### | ########################################################################################################################################### | ||||||
| import os, sys, time, random, argparse | import os, sys, time, random, argparse | ||||||
| import numpy as np | import numpy as np | ||||||
| @@ -94,13 +96,13 @@ def search_func(xloader, network, criterion, scheduler, w_optimizer, a_optimizer | |||||||
|     a_optimizer.zero_grad() |     a_optimizer.zero_grad() | ||||||
|     _, logits, log_probs = network(arch_inputs) |     _, logits, log_probs = network(arch_inputs) | ||||||
|     arch_prec1, arch_prec5 = obtain_accuracy(logits.data, arch_targets.data, topk=(1, 5)) |     arch_prec1, arch_prec5 = obtain_accuracy(logits.data, arch_targets.data, topk=(1, 5)) | ||||||
|     if algo == 'tunas': |     if algo == 'mask_rl': | ||||||
|       with torch.no_grad(): |       with torch.no_grad(): | ||||||
|         RL_BASELINE_EMA.update(arch_prec1.item()) |         RL_BASELINE_EMA.update(arch_prec1.item()) | ||||||
|         rl_advantage = arch_prec1 - RL_BASELINE_EMA.value |         rl_advantage = arch_prec1 - RL_BASELINE_EMA.value | ||||||
|       rl_log_prob = sum(log_probs) |       rl_log_prob = sum(log_probs) | ||||||
|       arch_loss = - rl_advantage * rl_log_prob |       arch_loss = - rl_advantage * rl_log_prob | ||||||
|     elif algo == 'tas' or algo == 'fbv2': |     elif algo == 'tas' or algo == 'mask_gumbel': | ||||||
|       arch_loss = criterion(logits, arch_targets) |       arch_loss = criterion(logits, arch_targets) | ||||||
|     else: |     else: | ||||||
|       raise ValueError('invalid algorightm name: {:}'.format(algo)) |       raise ValueError('invalid algorightm name: {:}'.format(algo)) | ||||||
| @@ -231,7 +233,7 @@ def main(xargs): | |||||||
|  |  | ||||||
|     logger.log('\n[Search the {:}-th epoch] {:}, LR={:}, controller-warmup={:}, enable_controller={:}'.format(epoch_str, need_time, min(w_scheduler.get_lr()), network.warmup_ratio, enable_controller)) |     logger.log('\n[Search the {:}-th epoch] {:}, LR={:}, controller-warmup={:}, enable_controller={:}'.format(epoch_str, need_time, min(w_scheduler.get_lr()), network.warmup_ratio, enable_controller)) | ||||||
|  |  | ||||||
|     if xargs.algo == 'fbv2' or xargs.algo == 'tas': |     if xargs.algo == 'mask_gumbel' or xargs.algo == 'tas': | ||||||
|       network.set_tau(xargs.tau_max - (xargs.tau_max-xargs.tau_min) * epoch / (total_epoch-1)) |       network.set_tau(xargs.tau_max - (xargs.tau_max-xargs.tau_min) * epoch / (total_epoch-1)) | ||||||
|       logger.log('[RESET tau as : {:}]'.format(network.tau)) |       logger.log('[RESET tau as : {:}]'.format(network.tau)) | ||||||
|     search_w_loss, search_w_top1, search_w_top5, search_a_loss, search_a_top1, search_a_top5 \ |     search_w_loss, search_w_top1, search_w_top5, search_a_loss, search_a_top1, search_a_top5 \ | ||||||
| @@ -291,7 +293,7 @@ if __name__ == '__main__': | |||||||
|   parser.add_argument('--data_path'   ,       type=str,   help='Path to dataset') |   parser.add_argument('--data_path'   ,       type=str,   help='Path to dataset') | ||||||
|   parser.add_argument('--dataset'     ,       type=str,   choices=['cifar10', 'cifar100', 'ImageNet16-120'], help='Choose between Cifar10/100 and ImageNet-16.') |   parser.add_argument('--dataset'     ,       type=str,   choices=['cifar10', 'cifar100', 'ImageNet16-120'], help='Choose between Cifar10/100 and ImageNet-16.') | ||||||
|   parser.add_argument('--search_space',       type=str,   default='sss', choices=['sss'], help='The search space name.') |   parser.add_argument('--search_space',       type=str,   default='sss', choices=['sss'], help='The search space name.') | ||||||
|   parser.add_argument('--algo'        ,       type=str,   choices=['tas', 'fbv2', 'tunas'], help='The search space name.') |   parser.add_argument('--algo'        ,       type=str,   choices=['tas', 'mask_gumbel', 'mask_rl'], help='The search space name.') | ||||||
|   parser.add_argument('--genotype'    ,       type=str,   default='|nor_conv_3x3~0|+|nor_conv_3x3~0|nor_conv_3x3~1|+|skip_connect~0|nor_conv_3x3~1|nor_conv_3x3~2|', help='The genotype.') |   parser.add_argument('--genotype'    ,       type=str,   default='|nor_conv_3x3~0|+|nor_conv_3x3~0|nor_conv_3x3~1|+|skip_connect~0|nor_conv_3x3~1|nor_conv_3x3~2|', help='The genotype.') | ||||||
|   parser.add_argument('--use_api'     ,       type=int,   default=1, choices=[0,1], help='Whether use API or not (which will cost much memory).') |   parser.add_argument('--use_api'     ,       type=int,   default=1, choices=[0,1], help='Whether use API or not (which will cost much memory).') | ||||||
|   # FOR GDAS |   # FOR GDAS | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user