Revise names for compared #channel-search algorithms
This commit is contained in:
parent
bc0ac65882
commit
10e5f05935
@ -1,28 +1,30 @@
|
|||||||
##################################################
|
##################################################
|
||||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020 #
|
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020 #
|
||||||
###########################################################################################################################################
|
###########################################################################################################################################
|
||||||
|
#
|
||||||
# In this file, we aim to evaluate three kinds of channel searching strategies:
|
# In this file, we aim to evaluate three kinds of channel searching strategies:
|
||||||
# - channel-wise interpolation from "Network Pruning via Transformable Architecture Search, NeurIPS 2019"
|
# - channel-wise interpolation from "Network Pruning via Transformable Architecture Search, NeurIPS 2019"
|
||||||
# - masking + Gumbel-Softmax from "FBNetV2: Differentiable Neural Architecture Search for Spatial and Channel Dimensions, CVPR 2020"
|
# - masking + Gumbel-Softmax (mask_gumbel) from "FBNetV2: Differentiable Neural Architecture Search for Spatial and Channel Dimensions, CVPR 2020"
|
||||||
# - masking + sampling from "Can Weight Sharing Outperform Random Architecture Search? An Investigation With TuNAS, CVPR 2020"
|
# - masking + sampling (mask_rl) from "Can Weight Sharing Outperform Random Architecture Search? An Investigation With TuNAS, CVPR 2020"
|
||||||
# For simplicity, we use tas, fbv2, and tunas to refer these three strategies. Their official implementations are at the following links:
|
#
|
||||||
|
# For simplicity, we use tas, mask_gumbel, and mask_rl to refer to these three strategies. Their official implementations are at the following links:
|
||||||
# - TAS: https://github.com/D-X-Y/AutoDL-Projects/blob/master/docs/NeurIPS-2019-TAS.md
|
# - TAS: https://github.com/D-X-Y/AutoDL-Projects/blob/master/docs/NeurIPS-2019-TAS.md
|
||||||
# - FBNetV2: https://github.com/facebookresearch/mobile-vision
|
# - FBNetV2: https://github.com/facebookresearch/mobile-vision
|
||||||
# - TuNAS: https://github.com/google-research/google-research/tree/master/tunas
|
# - TuNAS: https://github.com/google-research/google-research/tree/master/tunas
|
||||||
####
|
####
|
||||||
# python ./exps/NATS-algos/search-size.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo tunas --arch_weight_decay 0 --warmup_ratio 0.25
|
# python ./exps/NATS-algos/search-size.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo mask_rl --arch_weight_decay 0 --warmup_ratio 0.25
|
||||||
####
|
####
|
||||||
# python ./exps/NATS-algos/search-size.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo tas --rand_seed 777
|
# python ./exps/NATS-algos/search-size.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo tas --rand_seed 777
|
||||||
# python ./exps/NATS-algos/search-size.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo tas --rand_seed 777
|
# python ./exps/NATS-algos/search-size.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo tas --rand_seed 777
|
||||||
# python ./exps/NATS-algos/search-size.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo tas --rand_seed 777
|
# python ./exps/NATS-algos/search-size.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo tas --rand_seed 777
|
||||||
####
|
####
|
||||||
# python ./exps/NATS-algos/search-size.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo fbv2 --rand_seed 777
|
# python ./exps/NATS-algos/search-size.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo mask_gumbel --rand_seed 777
|
||||||
# python ./exps/NATS-algos/search-size.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo fbv2 --rand_seed 777
|
# python ./exps/NATS-algos/search-size.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo mask_gumbel --rand_seed 777
|
||||||
# python ./exps/NATS-algos/search-size.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo fbv2 --rand_seed 777
|
# python ./exps/NATS-algos/search-size.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo mask_gumbel --rand_seed 777
|
||||||
####
|
####
|
||||||
# python ./exps/NATS-algos/search-size.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo tunas --arch_weight_decay 0 --rand_seed 777 --use_api 0
|
# python ./exps/NATS-algos/search-size.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo mask_rl --arch_weight_decay 0 --rand_seed 777 --use_api 0
|
||||||
# python ./exps/NATS-algos/search-size.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo tunas --arch_weight_decay 0 --rand_seed 777
|
# python ./exps/NATS-algos/search-size.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo mask_rl --arch_weight_decay 0 --rand_seed 777
|
||||||
# python ./exps/NATS-algos/search-size.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo tunas --arch_weight_decay 0 --rand_seed 777
|
# python ./exps/NATS-algos/search-size.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo mask_rl --arch_weight_decay 0 --rand_seed 777
|
||||||
###########################################################################################################################################
|
###########################################################################################################################################
|
||||||
import os, sys, time, random, argparse
|
import os, sys, time, random, argparse
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@ -94,13 +96,13 @@ def search_func(xloader, network, criterion, scheduler, w_optimizer, a_optimizer
|
|||||||
a_optimizer.zero_grad()
|
a_optimizer.zero_grad()
|
||||||
_, logits, log_probs = network(arch_inputs)
|
_, logits, log_probs = network(arch_inputs)
|
||||||
arch_prec1, arch_prec5 = obtain_accuracy(logits.data, arch_targets.data, topk=(1, 5))
|
arch_prec1, arch_prec5 = obtain_accuracy(logits.data, arch_targets.data, topk=(1, 5))
|
||||||
if algo == 'tunas':
|
if algo == 'mask_rl':
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
RL_BASELINE_EMA.update(arch_prec1.item())
|
RL_BASELINE_EMA.update(arch_prec1.item())
|
||||||
rl_advantage = arch_prec1 - RL_BASELINE_EMA.value
|
rl_advantage = arch_prec1 - RL_BASELINE_EMA.value
|
||||||
rl_log_prob = sum(log_probs)
|
rl_log_prob = sum(log_probs)
|
||||||
arch_loss = - rl_advantage * rl_log_prob
|
arch_loss = - rl_advantage * rl_log_prob
|
||||||
elif algo == 'tas' or algo == 'fbv2':
|
elif algo == 'tas' or algo == 'mask_gumbel':
|
||||||
arch_loss = criterion(logits, arch_targets)
|
arch_loss = criterion(logits, arch_targets)
|
||||||
else:
|
else:
|
||||||
raise ValueError('invalid algorightm name: {:}'.format(algo))
|
raise ValueError('invalid algorightm name: {:}'.format(algo))
|
||||||
@ -231,7 +233,7 @@ def main(xargs):
|
|||||||
|
|
||||||
logger.log('\n[Search the {:}-th epoch] {:}, LR={:}, controller-warmup={:}, enable_controller={:}'.format(epoch_str, need_time, min(w_scheduler.get_lr()), network.warmup_ratio, enable_controller))
|
logger.log('\n[Search the {:}-th epoch] {:}, LR={:}, controller-warmup={:}, enable_controller={:}'.format(epoch_str, need_time, min(w_scheduler.get_lr()), network.warmup_ratio, enable_controller))
|
||||||
|
|
||||||
if xargs.algo == 'fbv2' or xargs.algo == 'tas':
|
if xargs.algo == 'mask_gumbel' or xargs.algo == 'tas':
|
||||||
network.set_tau(xargs.tau_max - (xargs.tau_max-xargs.tau_min) * epoch / (total_epoch-1))
|
network.set_tau(xargs.tau_max - (xargs.tau_max-xargs.tau_min) * epoch / (total_epoch-1))
|
||||||
logger.log('[RESET tau as : {:}]'.format(network.tau))
|
logger.log('[RESET tau as : {:}]'.format(network.tau))
|
||||||
search_w_loss, search_w_top1, search_w_top5, search_a_loss, search_a_top1, search_a_top5 \
|
search_w_loss, search_w_top1, search_w_top5, search_a_loss, search_a_top1, search_a_top5 \
|
||||||
@ -291,7 +293,7 @@ if __name__ == '__main__':
|
|||||||
parser.add_argument('--data_path' , type=str, help='Path to dataset')
|
parser.add_argument('--data_path' , type=str, help='Path to dataset')
|
||||||
parser.add_argument('--dataset' , type=str, choices=['cifar10', 'cifar100', 'ImageNet16-120'], help='Choose between Cifar10/100 and ImageNet-16.')
|
parser.add_argument('--dataset' , type=str, choices=['cifar10', 'cifar100', 'ImageNet16-120'], help='Choose between Cifar10/100 and ImageNet-16.')
|
||||||
parser.add_argument('--search_space', type=str, default='sss', choices=['sss'], help='The search space name.')
|
parser.add_argument('--search_space', type=str, default='sss', choices=['sss'], help='The search space name.')
|
||||||
parser.add_argument('--algo' , type=str, choices=['tas', 'fbv2', 'tunas'], help='The search space name.')
|
parser.add_argument('--algo' , type=str, choices=['tas', 'mask_gumbel', 'mask_rl'], help='The search space name.')
|
||||||
parser.add_argument('--genotype' , type=str, default='|nor_conv_3x3~0|+|nor_conv_3x3~0|nor_conv_3x3~1|+|skip_connect~0|nor_conv_3x3~1|nor_conv_3x3~2|', help='The genotype.')
|
parser.add_argument('--genotype' , type=str, default='|nor_conv_3x3~0|+|nor_conv_3x3~0|nor_conv_3x3~1|+|skip_connect~0|nor_conv_3x3~1|nor_conv_3x3~2|', help='The genotype.')
|
||||||
parser.add_argument('--use_api' , type=int, default=1, choices=[0,1], help='Whether use API or not (which will cost much memory).')
|
parser.add_argument('--use_api' , type=int, default=1, choices=[0,1], help='Whether use API or not (which will cost much memory).')
|
||||||
# FOR GDAS
|
# FOR GDAS
|
||||||
|
Loading…
Reference in New Issue
Block a user