Update NATS (sss) algorithms -- warmup

D-X-Y 2020-10-06 20:44:15 +11:00
parent a306fd4562
commit ad5d6e28b9
4 changed files with 78 additions and 12 deletions

View File

@@ -1,6 +1,11 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020 #
######################################################################################
+# In this file, we aim to evaluate three kinds of channel searching strategies:
+# - feature interpolation (TAS), masking + Gumbel-Softmax (FBNetV2), and masking + sampling (TuNAS)
+####
+# python ./exps/NATS-algos/search-size.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo tunas --arch_weight_decay 0 --warmup_ratio 0.25
+####
# python ./exps/NATS-algos/search-size.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo tas --rand_seed 777
# python ./exps/NATS-algos/search-size.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo tas --rand_seed 777
# python ./exps/NATS-algos/search-size.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo tas --rand_seed 777
@@ -51,7 +56,7 @@ class ExponentialMovingAverage(object):
RL_BASELINE_EMA = ExponentialMovingAverage(0.95)
-def search_func(xloader, network, criterion, scheduler, w_optimizer, a_optimizer, algo, epoch_str, print_freq, logger):
+def search_func(xloader, network, criterion, scheduler, w_optimizer, a_optimizer, enable_controller, algo, epoch_str, print_freq, logger):
  data_time, batch_time = AverageMeter(), AverageMeter()
  base_losses, base_top1, base_top5 = AverageMeter(), AverageMeter(), AverageMeter()
  arch_losses, arch_top1, arch_top5 = AverageMeter(), AverageMeter(), AverageMeter()
@@ -80,6 +85,7 @@ def search_func(xloader, network, criterion, scheduler, w_optimizer, a_optimizer
    # update the architecture-weight
    network.zero_grad()
+    a_optimizer.zero_grad()
    _, logits, log_probs = network(arch_inputs)
    arch_prec1, arch_prec5 = obtain_accuracy(logits.data, arch_targets.data, topk=(1, 5))
    if algo == 'tunas':
@@ -92,8 +98,9 @@ def search_func(xloader, network, criterion, scheduler, w_optimizer, a_optimizer
      arch_loss = criterion(logits, arch_targets)
    else:
      raise ValueError('invalid algorithm name: {:}'.format(algo))
-    arch_loss.backward()
-    a_optimizer.step()
+    if enable_controller:
+      arch_loss.backward()
+      a_optimizer.step()
    # record
    arch_losses.update(arch_loss.item(), arch_inputs.size(0))
    arch_top1.update (arch_prec1.item(), arch_inputs.size(0))
@@ -208,13 +215,22 @@ def main(xargs):
    w_scheduler.update(epoch, 0.0)
    need_time = 'Time Left: {:}'.format(convert_secs2time(epoch_time.val * (total_epoch-epoch), True))
    epoch_str = '{:03d}-{:03d}'.format(epoch, total_epoch)
-    logger.log('\n[Search the {:}-th epoch] {:}, LR={:}'.format(epoch_str, need_time, min(w_scheduler.get_lr())))
+    if xargs.warmup_ratio is None or xargs.warmup_ratio <= float(epoch) / total_epoch:
+      enable_controller = True
+      network.set_warmup_ratio(None)
+    else:
+      enable_controller = False
+      network.set_warmup_ratio(1.0 - float(epoch) / total_epoch / xargs.warmup_ratio)
+    logger.log('\n[Search the {:}-th epoch] {:}, LR={:}, controller-warmup={:}, enable_controller={:}'.format(epoch_str, need_time, min(w_scheduler.get_lr()), network.warmup_ratio, enable_controller))
    if xargs.algo == 'fbv2' or xargs.algo == 'tas':
-      network.set_tau( xargs.tau_max - (xargs.tau_max-xargs.tau_min) * epoch / (total_epoch-1) )
+      network.set_tau(xargs.tau_max - (xargs.tau_max-xargs.tau_min) * epoch / (total_epoch-1))
      logger.log('[RESET tau as : {:}]'.format(network.tau))
    search_w_loss, search_w_top1, search_w_top5, search_a_loss, search_a_top1, search_a_top5 \
-      = search_func(search_loader, network, criterion, w_scheduler, w_optimizer, a_optimizer, xargs.algo, epoch_str, xargs.print_freq, logger)
+      = search_func(search_loader, network, criterion, w_scheduler,
+                    w_optimizer, a_optimizer, enable_controller, xargs.algo, epoch_str, xargs.print_freq, logger)
    search_time.update(time.time() - start_time)
    logger.log('[{:}] search [base] : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%, time-cost={:.1f} s'.format(epoch_str, search_w_loss, search_w_top1, search_w_top5, search_time.sum))
    logger.log('[{:}] search [arch] : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'.format(epoch_str, search_a_loss, search_a_top1, search_a_top5))
@@ -275,6 +291,8 @@ if __name__ == '__main__':
  # FOR GDAS
  parser.add_argument('--tau_min', type=float, default=0.1, help='The minimum tau for Gumbel Softmax.')
  parser.add_argument('--tau_max', type=float, default=10, help='The maximum tau for Gumbel Softmax.')
+  # FOR ALL
+  parser.add_argument('--warmup_ratio', type=float, help='The warmup ratio; if None, warmup is not used.')
  #
  parser.add_argument('--track_running_stats', type=int, default=0, choices=[0,1], help='Whether use track_running_stats or not in the BN layer.')
  parser.add_argument('--affine', type=int, default=0, choices=[0,1], help='Whether use affine=True or False in the BN layer.')
@@ -291,7 +309,7 @@ if __name__ == '__main__':
  parser.add_argument('--rand_seed', type=int, help='manual seed')
  args = parser.parse_args()
  if args.rand_seed is None or args.rand_seed < 0: args.rand_seed = random.randint(1, 100000)
-  dirname = '{:}-affine{:}_BN{:}-AWD{:}'.format(args.algo, args.affine, args.track_running_stats, args.arch_weight_decay)
+  dirname = '{:}-affine{:}_BN{:}-AWD{:}-WARM{:}'.format(args.algo, args.affine, args.track_running_stats, args.arch_weight_decay, args.warmup_ratio)
  if args.overwite_epochs is not None:
    dirname = dirname + '-E{:}'.format(args.overwite_epochs)
  args.save_dir = os.path.join('{:}-{:}'.format(args.save_dir, args.search_space), args.dataset, dirname)
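
The net effect of the warmup changes in this file: the architecture controller (a_optimizer and arch_loss.backward) is skipped until epoch/total_epoch reaches --warmup_ratio, and the supernet is told how far warmup has progressed. A minimal standalone sketch of this schedule (illustrative only, not part of the commit; it assumes the same semantics for total_epoch and warmup_ratio as above):

def warmup_schedule(epoch, total_epoch, warmup_ratio):
  # Returns (enable_controller, network_warmup_ratio) for one epoch,
  # mirroring the per-epoch logic added to main() above.
  if warmup_ratio is None or warmup_ratio <= float(epoch) / total_epoch:
    return True, None     # warmup finished: train the controller normally
  else:
    # still warming up: freeze the controller; the ratio decays linearly from 1 to 0
    return False, 1.0 - float(epoch) / total_epoch / warmup_ratio

# e.g. with total_epoch=10 and warmup_ratio=0.25, epochs 0-2 keep the controller
# frozen with ratios 1.0, 0.6, 0.2, and the controller is enabled from epoch 3 on.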

View File

@@ -26,7 +26,7 @@ from nats_bench import create
from log_utils import time_string
-def fetch_data(root_dir='./output/search', search_space='tss', dataset=None):
+def fetch_data(root_dir='./output/search', search_space='tss', dataset=None, suffix='-AWD0.0-WARMNone'):
  ss_dir = '{:}-{:}'.format(root_dir, search_space)
  alg2name, alg2path = OrderedDict(), OrderedDict()
  seeds = [777, 888, 999]
@@ -39,9 +39,9 @@ def fetch_data(root_dir='./output/search', search_space='tss', dataset=None):
    alg2name['ENAS'] = 'enas-affine0_BN0-None'
    alg2name['SETN'] = 'setn-affine0_BN0-None'
  else:
-    alg2name['TAS'] = 'tas-affine0_BN0'
-    alg2name['FBNetV2'] = 'fbv2-affine0_BN0'
-    alg2name['TuNAS'] = 'tunas-affine0_BN0'
+    alg2name['TAS'] = 'tas-affine0_BN0{:}'.format(suffix)
+    alg2name['FBNetV2'] = 'fbv2-affine0_BN0{:}'.format(suffix)
+    alg2name['TuNAS'] = 'tunas-affine0_BN0{:}'.format(suffix)
  for alg, name in alg2name.items():
    alg2path[alg] = os.path.join(ss_dir, dataset, name, 'seed-{:}-last-info.pth')
  alg2data = OrderedDict()

View File

@@ -1,6 +1,10 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
#####################################################
+# Here, we utilize three techniques to search for the number of channels:
+# - feature interpolation from "Network Pruning via Transformable Architecture Search, NeurIPS 2019"
+# - masking + Gumbel-Softmax from "FBNetV2: Differentiable Neural Architecture Search for Spatial and Channel Dimensions, CVPR 2020"
+# - masking + sampling from "Can Weight Sharing Outperform Random Architecture Search? An Investigation With TuNAS, CVPR 2020"
from typing import List, Text, Any
import random, torch
import torch.nn as nn
@@ -43,6 +47,7 @@ class GenericNAS301Model(nn.Module):
    # algorithm related
    self.register_buffer('_tau', torch.zeros(1))
    self._algo = None
+    self._warmup_ratio = None

  def set_algo(self, algo: Text):
    # used for searching
@@ -62,6 +67,13 @@ class GenericNAS301Model(nn.Module):
  def set_tau(self, tau):
    self._tau.data[:] = tau

+  @property
+  def warmup_ratio(self):
+    return self._warmup_ratio
+
+  def set_warmup_ratio(self, ratio: float):
+    self._warmup_ratio = ratio
+
  @property
  def weights(self):
    xlist = list(self._cells.parameters())
@@ -112,7 +124,13 @@ class GenericNAS301Model(nn.Module):
      feature = cell(feature)
      # apply different searching algorithms
      idx = max(0, i-1)
-      if self._algo == 'fbv2':
+      if self._warmup_ratio is not None:
+        if random.random() < self._warmup_ratio:
+          mask = self._masks[-1]
+        else:
+          mask = self._masks[random.randint(0, len(self._masks)-1)]
+        feature = feature * mask.view(1, -1, 1, 1)
+      elif self._algo == 'fbv2':
        weights = nn.functional.gumbel_softmax(self._arch_parameters[idx:idx+1], tau=self.tau, dim=-1)
        mask = torch.matmul(weights, self._masks).view(1, -1, 1, 1)
        feature = feature * mask
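
During warmup the supernet therefore bypasses the learned architecture distribution: with probability warmup_ratio it applies the widest channel mask, otherwise a uniformly random candidate mask, which trains the shared weights across widths before the controller takes over. A standalone sketch of this branch (illustrative only, not part of the commit; tensor shapes and mask ordering are assumptions):

import random
import torch

def warmup_channel_mask(feature, masks, warmup_ratio):
  # feature: (N, C, H, W); masks: candidate 0/1 vectors of length C,
  # ordered from narrowest to widest, so masks[-1] keeps every channel.
  if random.random() < warmup_ratio:
    mask = masks[-1]                                   # favour the full-width network
  else:
    mask = masks[random.randint(0, len(masks) - 1)]    # uniformly random width
  return feature * mask.view(1, -1, 1, 1)

feature = torch.randn(2, 8, 4, 4)
masks = [torch.cat([torch.ones(c), torch.zeros(8 - c)]) for c in (2, 4, 8)]
print(warmup_channel_mask(feature, masks, warmup_ratio=0.8).shape)  # torch.Size([2, 8, 4, 4])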

View File

@@ -0,0 +1,30 @@
#!/bin/bash
# bash ./NATS/search-size.sh 0 777
echo script name: $0
echo $# arguments
if [ "$#" -ne 2 ] ;then
echo "Input illegal number of parameters " $#
echo "Need 2 parameters for GPU-device and seed"
exit 1
fi
if [ "$TORCH_HOME" = "" ]; then
echo "Must set TORCH_HOME envoriment variable for data dir saving"
exit 1
else
echo "TORCH_HOME : $TORCH_HOME"
fi
device=$1
seed=$2
CUDA_VISIBLE_DEVICES=${device} python ./exps/NATS-algos/search-size.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo tas --rand_seed ${seed}
CUDA_VISIBLE_DEVICES=${device} python ./exps/NATS-algos/search-size.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo tas --rand_seed ${seed}
CUDA_VISIBLE_DEVICES=${device} python ./exps/NATS-algos/search-size.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo tas --rand_seed ${seed}
CUDA_VISIBLE_DEVICES=${device} python ./exps/NATS-algos/search-size.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo fbv2 --rand_seed ${seed}
CUDA_VISIBLE_DEVICES=${device} python ./exps/NATS-algos/search-size.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo fbv2 --rand_seed ${seed}
CUDA_VISIBLE_DEVICES=${device} python ./exps/NATS-algos/search-size.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo fbv2 --rand_seed ${seed}
CUDA_VISIBLE_DEVICES=${device} python ./exps/NATS-algos/search-size.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo tunas --arch_weight_decay 0 --rand_seed ${seed}
CUDA_VISIBLE_DEVICES=${device} python ./exps/NATS-algos/search-size.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo tunas --arch_weight_decay 0 --rand_seed ${seed}
CUDA_VISIBLE_DEVICES=${device} python ./exps/NATS-algos/search-size.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo tunas --arch_weight_decay 0 --rand_seed ${seed}