Revise names for compared #channel-search algorithms
@@ -1,28 +1,30 @@
 ##################################################
 # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020 #
 ###########################################################################################################################################
+#
 # In this file, we aims to evaluate three kinds of channel searching strategies:
 # - channel-wise interpolation from "Network Pruning via Transformable Architecture Search, NeurIPS 2019"
-# - masking + Gumbel-Softmax from "FBNetV2: Differentiable Neural Architecture Search for Spatial and Channel Dimensions, CVPR 2020"
-# - masking + sampling from "Can Weight Sharing Outperform Random Architecture Search? An Investigation With TuNAS, CVPR 2020"
-# For simplicity, we use tas, fbv2, and tunas to refer these three strategies. Their official implementations are at the following links:
+# - masking + Gumbel-Softmax (mask_gumbel) from "FBNetV2: Differentiable Neural Architecture Search for Spatial and Channel Dimensions, CVPR 2020"
+# - masking + sampling (mask_rl) from "Can Weight Sharing Outperform Random Architecture Search? An Investigation With TuNAS, CVPR 2020"
+#
+# For simplicity, we use tas, mask_gumbel, and mask_rl to refer these three strategies. Their official implementations are at the following links:
 # - TAS: https://github.com/D-X-Y/AutoDL-Projects/blob/master/docs/NeurIPS-2019-TAS.md
 # - FBNetV2: https://github.com/facebookresearch/mobile-vision
 # - TuNAS: https://github.com/google-research/google-research/tree/master/tunas
 ####
-# python ./exps/NATS-algos/search-size.py --dataset cifar10  --data_path $TORCH_HOME/cifar.python --algo tunas --arch_weight_decay 0 --warmup_ratio 0.25
+# python ./exps/NATS-algos/search-size.py --dataset cifar10  --data_path $TORCH_HOME/cifar.python --algo mask_rl --arch_weight_decay 0 --warmup_ratio 0.25
 ####
 # python ./exps/NATS-algos/search-size.py --dataset cifar10  --data_path $TORCH_HOME/cifar.python --algo tas --rand_seed 777
 # python ./exps/NATS-algos/search-size.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo tas --rand_seed 777
 # python ./exps/NATS-algos/search-size.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo tas --rand_seed 777
 ####
-# python ./exps/NATS-algos/search-size.py --dataset cifar10  --data_path $TORCH_HOME/cifar.python --algo fbv2 --rand_seed 777
-# python ./exps/NATS-algos/search-size.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo fbv2 --rand_seed 777
-# python ./exps/NATS-algos/search-size.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo fbv2 --rand_seed 777
+# python ./exps/NATS-algos/search-size.py --dataset cifar10  --data_path $TORCH_HOME/cifar.python --algo mask_gumbel --rand_seed 777
+# python ./exps/NATS-algos/search-size.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo mask_gumbel --rand_seed 777
+# python ./exps/NATS-algos/search-size.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo mask_gumbel --rand_seed 777
 ####
-# python ./exps/NATS-algos/search-size.py --dataset cifar10  --data_path $TORCH_HOME/cifar.python --algo tunas --arch_weight_decay 0 --rand_seed 777 --use_api 0
-# python ./exps/NATS-algos/search-size.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo tunas --arch_weight_decay 0 --rand_seed 777
-# python ./exps/NATS-algos/search-size.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo tunas --arch_weight_decay 0 --rand_seed 777
+# python ./exps/NATS-algos/search-size.py --dataset cifar10  --data_path $TORCH_HOME/cifar.python --algo mask_rl --arch_weight_decay 0 --rand_seed 777 --use_api 0
+# python ./exps/NATS-algos/search-size.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo mask_rl --arch_weight_decay 0 --rand_seed 777
+# python ./exps/NATS-algos/search-size.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo mask_rl --arch_weight_decay 0 --rand_seed 777
 ###########################################################################################################################################
 import os, sys, time, random, argparse
 import numpy as np
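Note on the renamed strategies: per the header above, tas aggregates feature maps of different candidate widths via channel-wise interpolation, mask_gumbel softly masks channels using a Gumbel-Softmax over the candidate widths, and mask_rl samples a width and trains the controller with reinforcement learning. The sketch below only illustrates the mask_gumbel idea; it is not code from this repository, and the names (gumbel_masked_output, arch_logits, candidate_widths) are invented for the example.

import torch
import torch.nn.functional as F

def gumbel_masked_output(feature, arch_logits, candidate_widths, tau):
    # feature: (N, C_max, H, W); arch_logits: one logit per candidate width.
    probs = F.gumbel_softmax(arch_logits, tau=tau, hard=False)
    out = torch.zeros_like(feature)
    for prob, width in zip(probs, candidate_widths):
        mask = torch.zeros(1, feature.size(1), 1, 1, device=feature.device)
        mask[:, :width] = 1.0          # keep only the first `width` channels
        out = out + prob * feature * mask
    return out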
@@ -94,13 +96,13 @@ def search_func(xloader, network, criterion, scheduler, w_optimizer, a_optimizer
     a_optimizer.zero_grad()
     _, logits, log_probs = network(arch_inputs)
     arch_prec1, arch_prec5 = obtain_accuracy(logits.data, arch_targets.data, topk=(1, 5))
-    if algo == 'tunas':
+    if algo == 'mask_rl':
       with torch.no_grad():
         RL_BASELINE_EMA.update(arch_prec1.item())
         rl_advantage = arch_prec1 - RL_BASELINE_EMA.value
       rl_log_prob = sum(log_probs)
       arch_loss = - rl_advantage * rl_log_prob
     elif algo == 'tas' or algo == 'fbv2':
-    elif algo == 'tas' or algo == 'fbv2':
+    elif algo == 'tas' or algo == 'mask_gumbel':
       arch_loss = criterion(logits, arch_targets)
     else:
       raise ValueError('invalid algorightm name: {:}'.format(algo))
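The mask_rl branch above relies on RL_BASELINE_EMA, whose definition lies outside this hunk. Below is a minimal sketch of an exponential-moving-average baseline with the same update()/.value interface, written as an assumption for illustration; the class actually used in the repository (including its momentum) may differ.

class ExponentialMovingAverage:
    # Running average of the sampled architectures' accuracy (illustrative only).
    def __init__(self, momentum=0.95):
        self._momentum = momentum
        self._value = None

    def update(self, new_value):
        # The first observation initializes the average; later ones decay toward it.
        if self._value is None:
            self._value = new_value
        else:
            self._value = self._momentum * self._value + (1 - self._momentum) * new_value

    @property
    def value(self):
        return self._value

With such a baseline, arch_loss = -(arch_prec1 - baseline) * sum(log_probs) is the usual REINFORCE objective: the advantage (accuracy minus its running average) scales the log-probability of the sampled channel configuration, which reduces the variance of the policy-gradient estimate.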
@@ -231,7 +233,7 @@ def main(xargs):
 
     logger.log('\n[Search the {:}-th epoch] {:}, LR={:}, controller-warmup={:}, enable_controller={:}'.format(epoch_str, need_time, min(w_scheduler.get_lr()), network.warmup_ratio, enable_controller))
 
-    if xargs.algo == 'fbv2' or xargs.algo == 'tas':
+    if xargs.algo == 'mask_gumbel' or xargs.algo == 'tas':
       network.set_tau(xargs.tau_max - (xargs.tau_max-xargs.tau_min) * epoch / (total_epoch-1))
       logger.log('[RESET tau as : {:}]'.format(network.tau))
     search_w_loss, search_w_top1, search_w_top5, search_a_loss, search_a_top1, search_a_top5 \
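For tas and mask_gumbel, the set_tau call above anneals the softmax temperature linearly from tau_max at the first epoch down to tau_min at the last. A tiny standalone check of that schedule (10.0, 0.1 and 5 epochs are made-up example numbers, not the script's defaults):

tau_max, tau_min, total_epoch = 10.0, 0.1, 5
for epoch in range(total_epoch):
    tau = tau_max - (tau_max - tau_min) * epoch / (total_epoch - 1)
    print(epoch, round(tau, 3))   # 10.0, 7.525, 5.05, 2.575, 0.1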
@@ -291,7 +293,7 @@ if __name__ == '__main__':
   parser.add_argument('--data_path'   ,       type=str,   help='Path to dataset')
   parser.add_argument('--dataset'     ,       type=str,   choices=['cifar10', 'cifar100', 'ImageNet16-120'], help='Choose between Cifar10/100 and ImageNet-16.')
   parser.add_argument('--search_space',       type=str,   default='sss', choices=['sss'], help='The search space name.')
-  parser.add_argument('--algo'        ,       type=str,   choices=['tas', 'fbv2', 'tunas'], help='The search space name.')
+  parser.add_argument('--algo'        ,       type=str,   choices=['tas', 'mask_gumbel', 'mask_rl'], help='The search space name.')
   parser.add_argument('--genotype'    ,       type=str,   default='|nor_conv_3x3~0|+|nor_conv_3x3~0|nor_conv_3x3~1|+|skip_connect~0|nor_conv_3x3~1|nor_conv_3x3~2|', help='The genotype.')
   parser.add_argument('--use_api'     ,       type=int,   default=1, choices=[0,1], help='Whether use API or not (which will cost much memory).')
   # FOR GDAS