Revise names for compared #channel-search algorithms

This commit is contained in:
D-X-Y 2020-10-15 20:40:27 +11:00
parent bc0ac65882
commit 10e5f05935

View File

@ -1,28 +1,30 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020 #
###########################################################################################################################################
#
# In this file, we aims to evaluate three kinds of channel searching strategies:
# - channel-wise interpolation from "Network Pruning via Transformable Architecture Search, NeurIPS 2019"
# - masking + Gumbel-Softmax from "FBNetV2: Differentiable Neural Architecture Search for Spatial and Channel Dimensions, CVPR 2020"
# - masking + sampling from "Can Weight Sharing Outperform Random Architecture Search? An Investigation With TuNAS, CVPR 2020"
# For simplicity, we use tas, fbv2, and tunas to refer these three strategies. Their official implementations are at the following links:
# - masking + Gumbel-Softmax (mask_gumbel) from "FBNetV2: Differentiable Neural Architecture Search for Spatial and Channel Dimensions, CVPR 2020"
# - masking + sampling (mask_rl) from "Can Weight Sharing Outperform Random Architecture Search? An Investigation With TuNAS, CVPR 2020"
#
# For simplicity, we use tas, mask_gumbel, and mask_rl to refer these three strategies. Their official implementations are at the following links:
# - TAS: https://github.com/D-X-Y/AutoDL-Projects/blob/master/docs/NeurIPS-2019-TAS.md
# - FBNetV2: https://github.com/facebookresearch/mobile-vision
# - TuNAS: https://github.com/google-research/google-research/tree/master/tunas
####
# python ./exps/NATS-algos/search-size.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo tunas --arch_weight_decay 0 --warmup_ratio 0.25
# python ./exps/NATS-algos/search-size.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo mask_rl --arch_weight_decay 0 --warmup_ratio 0.25
####
# python ./exps/NATS-algos/search-size.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo tas --rand_seed 777
# python ./exps/NATS-algos/search-size.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo tas --rand_seed 777
# python ./exps/NATS-algos/search-size.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo tas --rand_seed 777
####
# python ./exps/NATS-algos/search-size.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo fbv2 --rand_seed 777
# python ./exps/NATS-algos/search-size.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo fbv2 --rand_seed 777
# python ./exps/NATS-algos/search-size.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo fbv2 --rand_seed 777
# python ./exps/NATS-algos/search-size.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo mask_gumbel --rand_seed 777
# python ./exps/NATS-algos/search-size.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo mask_gumbel --rand_seed 777
# python ./exps/NATS-algos/search-size.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo mask_gumbel --rand_seed 777
####
# python ./exps/NATS-algos/search-size.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo tunas --arch_weight_decay 0 --rand_seed 777 --use_api 0
# python ./exps/NATS-algos/search-size.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo tunas --arch_weight_decay 0 --rand_seed 777
# python ./exps/NATS-algos/search-size.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo tunas --arch_weight_decay 0 --rand_seed 777
# python ./exps/NATS-algos/search-size.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo mask_rl --arch_weight_decay 0 --rand_seed 777 --use_api 0
# python ./exps/NATS-algos/search-size.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo mask_rl --arch_weight_decay 0 --rand_seed 777
# python ./exps/NATS-algos/search-size.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo mask_rl --arch_weight_decay 0 --rand_seed 777
###########################################################################################################################################
import os, sys, time, random, argparse
import numpy as np
@ -94,13 +96,13 @@ def search_func(xloader, network, criterion, scheduler, w_optimizer, a_optimizer
a_optimizer.zero_grad()
_, logits, log_probs = network(arch_inputs)
arch_prec1, arch_prec5 = obtain_accuracy(logits.data, arch_targets.data, topk=(1, 5))
if algo == 'tunas':
if algo == 'mask_rl':
with torch.no_grad():
RL_BASELINE_EMA.update(arch_prec1.item())
rl_advantage = arch_prec1 - RL_BASELINE_EMA.value
rl_log_prob = sum(log_probs)
arch_loss = - rl_advantage * rl_log_prob
elif algo == 'tas' or algo == 'fbv2':
elif algo == 'tas' or algo == 'mask_gumbel':
arch_loss = criterion(logits, arch_targets)
else:
raise ValueError('invalid algorightm name: {:}'.format(algo))
@ -231,7 +233,7 @@ def main(xargs):
logger.log('\n[Search the {:}-th epoch] {:}, LR={:}, controller-warmup={:}, enable_controller={:}'.format(epoch_str, need_time, min(w_scheduler.get_lr()), network.warmup_ratio, enable_controller))
if xargs.algo == 'fbv2' or xargs.algo == 'tas':
if xargs.algo == 'mask_gumbel' or xargs.algo == 'tas':
network.set_tau(xargs.tau_max - (xargs.tau_max-xargs.tau_min) * epoch / (total_epoch-1))
logger.log('[RESET tau as : {:}]'.format(network.tau))
search_w_loss, search_w_top1, search_w_top5, search_a_loss, search_a_top1, search_a_top5 \
@ -291,7 +293,7 @@ if __name__ == '__main__':
parser.add_argument('--data_path' , type=str, help='Path to dataset')
parser.add_argument('--dataset' , type=str, choices=['cifar10', 'cifar100', 'ImageNet16-120'], help='Choose between Cifar10/100 and ImageNet-16.')
parser.add_argument('--search_space', type=str, default='sss', choices=['sss'], help='The search space name.')
parser.add_argument('--algo' , type=str, choices=['tas', 'fbv2', 'tunas'], help='The search space name.')
parser.add_argument('--algo' , type=str, choices=['tas', 'mask_gumbel', 'mask_rl'], help='The search space name.')
parser.add_argument('--genotype' , type=str, default='|nor_conv_3x3~0|+|nor_conv_3x3~0|nor_conv_3x3~1|+|skip_connect~0|nor_conv_3x3~1|nor_conv_3x3~2|', help='The genotype.')
parser.add_argument('--use_api' , type=int, default=1, choices=[0,1], help='Whether use API or not (which will cost much memory).')
# FOR GDAS