Update NATS (sss) algorithms -- warmup

commit ad5d6e28b9 (parent a306fd4562)
@@ -1,6 +1,11 @@
 ##################################################
 # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020 #
 ######################################################################################
 # In this file, we aim to evaluate three kinds of channel searching strategies:
 # -
 ####
+# python ./exps/NATS-algos/search-size.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo tunas --arch_weight_decay 0 --warmup_ratio 0.25
 ####
 # python ./exps/NATS-algos/search-size.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo tas --rand_seed 777
 # python ./exps/NATS-algos/search-size.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo tas --rand_seed 777
 # python ./exps/NATS-algos/search-size.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo tas --rand_seed 777
@@ -51,7 +56,7 @@ class ExponentialMovingAverage(object):
 RL_BASELINE_EMA = ExponentialMovingAverage(0.95)


-def search_func(xloader, network, criterion, scheduler, w_optimizer, a_optimizer, algo, epoch_str, print_freq, logger):
+def search_func(xloader, network, criterion, scheduler, w_optimizer, a_optimizer, enable_controller, algo, epoch_str, print_freq, logger):
   data_time, batch_time = AverageMeter(), AverageMeter()
   base_losses, base_top1, base_top5 = AverageMeter(), AverageMeter(), AverageMeter()
   arch_losses, arch_top1, arch_top5 = AverageMeter(), AverageMeter(), AverageMeter()
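For context, `RL_BASELINE_EMA` supplies a moving-average reward baseline for the REINFORCE-style `tunas` update further below. A minimal sketch of such an exponential moving average (illustrative only; the class defined in this file may differ in detail):

class ExponentialMovingAverage(object):
  """Maintain an exponential moving average of observed values (sketch)."""

  def __init__(self, momentum):
    self._momentum = momentum
    self._value = None

  def update(self, value):
    # the first observation initializes the average; later ones are blended in
    if self._value is None:
      self._value = value
    else:
      self._value = self._momentum * self._value + (1 - self._momentum) * value

  @property
  def value(self):
    return self._value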
@@ -80,6 +85,7 @@ def search_func(xloader, network, criterion, scheduler, w_optimizer, a_optimizer

     # update the architecture-weight
+    network.zero_grad()
     a_optimizer.zero_grad()
     _, logits, log_probs = network(arch_inputs)
     arch_prec1, arch_prec5 = obtain_accuracy(logits.data, arch_targets.data, topk=(1, 5))
     if algo == 'tunas':
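The `tunas` branch applies a REINFORCE-style update; a hypothetical sketch of the gist (variable names beyond `log_probs`, `arch_prec1`, and `RL_BASELINE_EMA` are illustrative, and the exact loss in the file may differ):

# reward = top-1 accuracy of the sampled architecture, scaled into [0, 1]
reward = arch_prec1.item() / 100.0
RL_BASELINE_EMA.update(reward)
# subtracting the EMA baseline reduces the variance of the policy gradient
advantage = reward - RL_BASELINE_EMA.value
# maximizing expected reward == minimizing -log_prob * advantage
arch_loss = -log_probs.sum() * advantage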
@@ -92,8 +98,9 @@ def search_func(xloader, network, criterion, scheduler, w_optimizer, a_optimizer
       arch_loss = criterion(logits, arch_targets)
     else:
       raise ValueError('invalid algorithm name: {:}'.format(algo))
-    arch_loss.backward()
-    a_optimizer.step()
+    if enable_controller:
+      arch_loss.backward()
+      a_optimizer.step()
     # record
     arch_losses.update(arch_loss.item(), arch_inputs.size(0))
     arch_top1.update(arch_prec1.item(), arch_inputs.size(0))
@@ -208,13 +215,22 @@ def main(xargs):
     w_scheduler.update(epoch, 0.0)
     need_time = 'Time Left: {:}'.format(convert_secs2time(epoch_time.val * (total_epoch-epoch), True))
     epoch_str = '{:03d}-{:03d}'.format(epoch, total_epoch)
-    logger.log('\n[Search the {:}-th epoch] {:}, LR={:}'.format(epoch_str, need_time, min(w_scheduler.get_lr())))
+
+    if xargs.warmup_ratio is None or xargs.warmup_ratio <= float(epoch) / total_epoch:
+      enable_controller = True
+      network.set_warmup_ratio(None)
+    else:
+      enable_controller = False
+      network.set_warmup_ratio(1.0 - float(epoch) / total_epoch / xargs.warmup_ratio)
+
+    logger.log('\n[Search the {:}-th epoch] {:}, LR={:}, controller-warmup={:}, enable_controller={:}'.format(epoch_str, need_time, min(w_scheduler.get_lr()), network.warmup_ratio, enable_controller))
+
     if xargs.algo == 'fbv2' or xargs.algo == 'tas':
-      network.set_tau( xargs.tau_max - (xargs.tau_max-xargs.tau_min) * epoch / (total_epoch-1) )
+      network.set_tau(xargs.tau_max - (xargs.tau_max-xargs.tau_min) * epoch / (total_epoch-1))
       logger.log('[RESET tau as : {:}]'.format(network.tau))
     search_w_loss, search_w_top1, search_w_top5, search_a_loss, search_a_top1, search_a_top5 \
-      = search_func(search_loader, network, criterion, w_scheduler, w_optimizer, a_optimizer, xargs.algo, epoch_str, xargs.print_freq, logger)
+      = search_func(search_loader, network, criterion, w_scheduler,
+                    w_optimizer, a_optimizer, enable_controller, xargs.algo, epoch_str, xargs.print_freq, logger)
     search_time.update(time.time() - start_time)
     logger.log('[{:}] search [base] : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%, time-cost={:.1f} s'.format(epoch_str, search_w_loss, search_w_top1, search_w_top5, search_time.sum))
     logger.log('[{:}] search [arch] : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'.format(epoch_str, search_a_loss, search_a_top1, search_a_top5))
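A standalone numerical check of the two schedules above (illustrative values; `--warmup_ratio 0.25` as in the header example):

total_epoch, warmup_ratio = 100, 0.25
tau_max, tau_min = 10.0, 0.1
for epoch in (0, 10, 24, 25, 99):
  if warmup_ratio is None or warmup_ratio <= epoch / total_epoch:
    ratio = None  # warmup finished: the controller (a_optimizer) is enabled
  else:
    # decays linearly from 1.0 at epoch 0 to 0.0 when epoch/total == warmup_ratio
    ratio = 1.0 - (epoch / total_epoch) / warmup_ratio
  # tau anneals linearly from tau_max down to tau_min over training
  tau = tau_max - (tau_max - tau_min) * epoch / (total_epoch - 1)
  print(epoch, ratio, round(tau, 2))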
@@ -275,6 +291,8 @@ if __name__ == '__main__':
   # FOR GDAS
   parser.add_argument('--tau_min', type=float, default=0.1, help='The minimum tau for Gumbel Softmax.')
   parser.add_argument('--tau_max', type=float, default=10, help='The maximum tau for Gumbel Softmax.')
+  # FOR ALL
+  parser.add_argument('--warmup_ratio', type=float, help='The warmup ratio; if None, warmup is not used.')
   #
   parser.add_argument('--track_running_stats', type=int, default=0, choices=[0, 1], help='Whether to use track_running_stats in the BN layer.')
   parser.add_argument('--affine',              type=int, default=0, choices=[0, 1], help='Whether to use affine=True in the BN layer.')
@@ -291,7 +309,7 @@ if __name__ == '__main__':
   parser.add_argument('--rand_seed', type=int, help='manual seed')
   args = parser.parse_args()
   if args.rand_seed is None or args.rand_seed < 0: args.rand_seed = random.randint(1, 100000)
-  dirname = '{:}-affine{:}_BN{:}-AWD{:}'.format(args.algo, args.affine, args.track_running_stats, args.arch_weight_decay)
+  dirname = '{:}-affine{:}_BN{:}-AWD{:}-WARM{:}'.format(args.algo, args.affine, args.track_running_stats, args.arch_weight_decay, args.warmup_ratio)
   if args.overwite_epochs is not None:
     dirname = dirname + '-E{:}'.format(args.overwite_epochs)
   args.save_dir = os.path.join('{:}-{:}'.format(args.save_dir, args.search_space), args.dataset, dirname)
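For example, a `tunas` run with `--warmup_ratio 0.25` now lands in a directory named as below; a run without warmup yields the `-AWD0.0-WARMNone` suffix that `fetch_data` in the next file assumes by default:

dirname = '{:}-affine{:}_BN{:}-AWD{:}-WARM{:}'.format('tunas', 0, 0, 0.0, 0.25)
assert dirname == 'tunas-affine0_BN0-AWD0.0-WARM0.25'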
@@ -26,7 +26,7 @@ from nats_bench import create
 from log_utils import time_string


-def fetch_data(root_dir='./output/search', search_space='tss', dataset=None):
+def fetch_data(root_dir='./output/search', search_space='tss', dataset=None, suffix='-AWD0.0-WARMNone'):
   ss_dir = '{:}-{:}'.format(root_dir, search_space)
   alg2name, alg2path = OrderedDict(), OrderedDict()
   seeds = [777, 888, 999]
@@ -39,9 +39,9 @@ def fetch_data(root_dir='./output/search', search_space='tss', dataset=None):
     alg2name['ENAS'] = 'enas-affine0_BN0-None'
     alg2name['SETN'] = 'setn-affine0_BN0-None'
   else:
-    alg2name['TAS'] = 'tas-affine0_BN0'
-    alg2name['FBNetV2'] = 'fbv2-affine0_BN0'
-    alg2name['TuNAS'] = 'tunas-affine0_BN0'
+    alg2name['TAS'] = 'tas-affine0_BN0{:}'.format(suffix)
+    alg2name['FBNetV2'] = 'fbv2-affine0_BN0{:}'.format(suffix)
+    alg2name['TuNAS'] = 'tunas-affine0_BN0{:}'.format(suffix)
   for alg, name in alg2name.items():
     alg2path[alg] = os.path.join(ss_dir, dataset, name, 'seed-{:}-last-info.pth')
   alg2data = OrderedDict()
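A hypothetical call that would pick up runs launched with warmup (the suffix must match the directory-name pattern built in search-size.py above):

alg2data = fetch_data(root_dir='./output/search', search_space='sss',
                      dataset='cifar10', suffix='-AWD0.0-WARM0.25')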
@@ -1,6 +1,10 @@
 #####################################################
 # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
 #####################################################
+# Here, we utilize three techniques to search for the number of channels:
+# - feature interpolation from "Network Pruning via Transformable Architecture Search, NeurIPS 2019"
+# - masking + Gumbel-Softmax from "FBNetV2: Differentiable Neural Architecture Search for Spatial and Channel Dimensions, CVPR 2020"
+# - masking + sampling from "Can Weight Sharing Outperform Random Architecture Search? An Investigation With TuNAS, CVPR 2020"
 from typing import List, Text, Any
 import random, torch
 import torch.nn as nn
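To make the first technique concrete, a simplified, hypothetical sketch of blending features computed with two candidate channel counts (the actual TAS implementation aligns widths by channel-wise interpolation; zero-padding is used here only to keep the sketch short):

import torch
import torch.nn.functional as F

def mix_channel_candidates(feat_small, feat_large, alphas):
  """Blend features from two candidate widths (simplified sketch)."""
  # align the narrow feature map (N, C1, H, W) to the wide one (N, C2, H, W)
  pad_channels = feat_large.size(1) - feat_small.size(1)
  feat_small = F.pad(feat_small, (0, 0, 0, 0, 0, pad_channels))
  # softmax over the two architecture logits gives the mixing weights
  w = torch.softmax(alphas, dim=-1)
  return w[0] * feat_small + w[1] * feat_large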
@@ -43,6 +47,7 @@ class GenericNAS301Model(nn.Module):
     # algorithm related
     self.register_buffer('_tau', torch.zeros(1))
     self._algo = None
+    self._warmup_ratio = None

   def set_algo(self, algo: Text):
     # used for searching
@@ -62,6 +67,13 @@ class GenericNAS301Model(nn.Module):
   def set_tau(self, tau):
     self._tau.data[:] = tau

+  @property
+  def warmup_ratio(self):
+    return self._warmup_ratio
+
+  def set_warmup_ratio(self, ratio: float):
+    self._warmup_ratio = ratio
+
   @property
   def weights(self):
     xlist = list(self._cells.parameters())
@@ -112,7 +124,13 @@ class GenericNAS301Model(nn.Module):
       feature = cell(feature)
       # apply different searching algorithms
       idx = max(0, i-1)
-      if self._algo == 'fbv2':
+      if self._warmup_ratio is not None:
+        if random.random() < self._warmup_ratio:
+          mask = self._masks[-1]
+        else:
+          mask = self._masks[random.randint(0, len(self._masks)-1)]
+        feature = feature * mask.view(1, -1, 1, 1)
+      elif self._algo == 'fbv2':
         weights = nn.functional.gumbel_softmax(self._arch_parameters[idx:idx+1], tau=self.tau, dim=-1)
         mask = torch.matmul(weights, self._masks).view(1, -1, 1, 1)
         feature = feature * mask
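In isolation, the new warmup branch behaves as below: with probability `_warmup_ratio` the last mask is chosen (assuming, as the indexing suggests, that it keeps all channels), otherwise a candidate width is sampled uniformly (illustrative standalone snippet):

import random

masks = ['keep-8', 'keep-16', 'keep-32', 'keep-64']  # stand-ins for self._masks
warmup_ratio = 0.75  # early in training, mostly exercise the full-width network
if random.random() < warmup_ratio:
  mask = masks[-1]                                   # widest candidate
else:
  mask = masks[random.randint(0, len(masks) - 1)]    # any candidate, uniformly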
scripts-search/NATS/search-size.sh (new file, 30 lines)
@@ -0,0 +1,30 @@
+#!/bin/bash
+# bash ./scripts-search/NATS/search-size.sh 0 777
+echo script name: $0
+echo $# arguments
+if [ "$#" -ne 2 ]; then
+  echo "Illegal number of parameters: $#"
+  echo "Need 2 parameters: GPU device and seed"
+  exit 1
+fi
+if [ "$TORCH_HOME" = "" ]; then
+  echo "Must set the TORCH_HOME environment variable for the data directory"
+  exit 1
+else
+  echo "TORCH_HOME : $TORCH_HOME"
+fi
+
+device=$1
+seed=$2
+
+CUDA_VISIBLE_DEVICES=${device} python ./exps/NATS-algos/search-size.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo tas --rand_seed ${seed}
+CUDA_VISIBLE_DEVICES=${device} python ./exps/NATS-algos/search-size.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo tas --rand_seed ${seed}
+CUDA_VISIBLE_DEVICES=${device} python ./exps/NATS-algos/search-size.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo tas --rand_seed ${seed}
+
+CUDA_VISIBLE_DEVICES=${device} python ./exps/NATS-algos/search-size.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo fbv2 --rand_seed ${seed}
+CUDA_VISIBLE_DEVICES=${device} python ./exps/NATS-algos/search-size.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo fbv2 --rand_seed ${seed}
+CUDA_VISIBLE_DEVICES=${device} python ./exps/NATS-algos/search-size.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo fbv2 --rand_seed ${seed}
+
+CUDA_VISIBLE_DEVICES=${device} python ./exps/NATS-algos/search-size.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo tunas --arch_weight_decay 0 --rand_seed ${seed}
+CUDA_VISIBLE_DEVICES=${device} python ./exps/NATS-algos/search-size.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo tunas --arch_weight_decay 0 --rand_seed ${seed}
+CUDA_VISIBLE_DEVICES=${device} python ./exps/NATS-algos/search-size.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo tunas --arch_weight_decay 0 --rand_seed ${seed}