Support GDAS (FRC), see details in docs/CVPR-2019-GDAS.md

D-X-Y 2020-08-18 00:50:33 +00:00
parent 75eefa3d44
commit ffd23a6cbd
7 changed files with 212 additions and 17 deletions

configs/search-archs/GDASFRC-NASNet-CIFAR.config
@@ -0,0 +1,9 @@
{
"super_type" : ["str", "nasnet-super"],
"name" : ["str", "GDAS_FRC"],
"C" : ["int", "16" ],
"N" : ["int", "2" ],
"steps" : ["int", "4" ],
"multiplier" : ["int", "4" ],
"stem_multiplier" : ["int", "3" ]
}
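
Each field pairs a type tag with a string value. The repository's config loader is not part of this diff, so the following is a minimal sketch, with a hypothetical helper name `load_typed_config`, of how such a `["type", "value"]` file could be parsed:

```
# Minimal sketch (hypothetical helper, not the repo's actual loader):
# cast each ["type", "value"] pair to a native Python value.
import json

_CASTS = {'str': str, 'int': int, 'float': float,
          'bool': lambda v: v.lower() in ('true', '1')}

def load_typed_config(path):
  with open(path) as xfile:
    raw = json.load(xfile)
  return {key: _CASTS[typ](val) for key, (typ, val) in raw.items()}

# load_typed_config('configs/search-archs/GDASFRC-NASNet-CIFAR.config')
# -> {'super_type': 'nasnet-super', 'name': 'GDAS_FRC', 'C': 16, 'N': 2, ...}
```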

docs/CVPR-2019-GDAS.md
@@ -37,9 +37,14 @@ CUDA_VISIBLE_DEVICES=0,1,2,3 bash ./scripts/nas-infer-train.sh imagenet-1k GDAS_
If you are interested in the configs of each NAS-searched architecture, they are defined in [genotypes.py](https://github.com/D-X-Y/AutoDL-Projects/blob/master/lib/nas_infer_model/DXYs/genotypes.py).
### Searching on the NASNet search space
Please use the following scripts to run the GDAS search, as in the original paper:
```
# search for both normal and reduction cells
CUDA_VISIBLE_DEVICES=0 bash ./scripts-search/NASNet-space-search-by-GDAS.sh cifar10 1 -1
# search for the normal cell while using a fixed reduction cell
CUDA_VISIBLE_DEVICES=0 bash ./scripts-search/NASNet-space-search-by-GDAS-FRC.sh cifar10 1 -1
```
**After searching**, if you want to re-train the architecture found by the above script, you can use the following script:
@@ -52,7 +57,9 @@ Note that `gdas-searched` is a string to indicate the name of the saved dir and
The above script does not apply heavy augmentation when training the model, so the accuracy will be lower than that reported in the original paper.
If you want to change the default hyper-parameter for re-training, please have a look at `./scripts/retrain-searched-net.sh` and `configs/archs/NAS-*-none.config`.
### Searching on a small search space (NAS-Bench-201)
The following command runs the GDAS search on the small NAS-Bench-201 search space:
```
CUDA_VISIBLE_DEVICES=0 bash ./scripts-search/algos/GDAS.sh cifar10 1 -1
```

lib/models/cell_operations.py
@@ -4,7 +4,7 @@
import torch
import torch.nn as nn

__all__ = ['OPS', 'ResNetBasicblock', 'SearchSpaceNames']
__all__ = ['OPS', 'RAW_OP_CLASSES', 'ResNetBasicblock', 'SearchSpaceNames']

OPS = {
  'none' : lambda C_in, C_out, stride, affine, track_running_stats: Zero(C_in, C_out, stride),
@@ -175,7 +175,7 @@ class FactorizedReduce(nn.Module):
        self.convs.append(nn.Conv2d(C_in, C_outs[i], 1, stride=stride, padding=0, bias=not affine))
      self.pad = nn.ConstantPad2d((0, 1, 0, 1), 0)
    elif stride == 1:
      self.conv = nn.Conv2d(C_in, C_out, 1, stride=stride, padding=0, bias=False)
      self.conv = nn.Conv2d(C_in, C_out, 1, stride=stride, padding=0, bias=not affine)
    else:
      raise ValueError('Invalid stride : {:}'.format(stride))
    self.bn = nn.BatchNorm2d(C_out, affine=affine, track_running_stats=track_running_stats)
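
A convention recurs throughout this commit: hard-coded `bias=False` / `affine=True` literals become `bias=not affine` / `affine=affine`. When the following BatchNorm learns its own scale and shift (`affine=True`), a convolution bias is redundant, so the bias is enabled only when the affine transform is disabled. A minimal sketch of the pattern:

```
# Sketch of the conv->BN convention adopted here: the convolution carries a
# bias only when the following BatchNorm does not learn an affine transform.
import torch.nn as nn

def conv_bn(C_in, C_out, stride, affine, track_running_stats):
  return nn.Sequential(
    nn.Conv2d(C_in, C_out, 1, stride=stride, padding=0, bias=not affine),
    nn.BatchNorm2d(C_out, affine=affine, track_running_stats=track_running_stats))
```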
@@ -256,41 +256,44 @@ def drop_path(x, drop_prob):
# Searching for A Robust Neural Architecture in Four GPU Hours
class GDAS_Reduction_Cell(nn.Module):

  def __init__(self, C_prev_prev, C_prev, C, reduction_prev, multiplier, affine, track_running_stats):
  def __init__(self, C_prev_prev, C_prev, C, reduction_prev, affine, track_running_stats):
    super(GDAS_Reduction_Cell, self).__init__()
    if reduction_prev:
      self.preprocess0 = FactorizedReduce(C_prev_prev, C, 2, affine, track_running_stats)
    else:
      self.preprocess0 = ReLUConvBN(C_prev_prev, C, 1, 1, 0, 1, affine, track_running_stats)
    self.preprocess1 = ReLUConvBN(C_prev, C, 1, 1, 0, 1, affine, track_running_stats)
    self.multiplier = multiplier
    self.reduction = True

    self.ops1 = nn.ModuleList(
                  [nn.Sequential(
                      nn.ReLU(inplace=False),
                      nn.Conv2d(C, C, (1, 3), stride=(1, 2), padding=(0, 1), groups=8, bias=False),
                      nn.Conv2d(C, C, (3, 1), stride=(2, 1), padding=(1, 0), groups=8, bias=False),
                      nn.BatchNorm2d(C, affine=True),
                      nn.Conv2d(C, C, (1, 3), stride=(1, 2), padding=(0, 1), groups=8, bias=not affine),
                      nn.Conv2d(C, C, (3, 1), stride=(2, 1), padding=(1, 0), groups=8, bias=not affine),
                      nn.BatchNorm2d(C, affine=affine, track_running_stats=track_running_stats),
                      nn.ReLU(inplace=False),
                      nn.Conv2d(C, C, 1, stride=1, padding=0, bias=False),
                      nn.BatchNorm2d(C, affine=True)),
                      nn.Conv2d(C, C, 1, stride=1, padding=0, bias=not affine),
                      nn.BatchNorm2d(C, affine=affine, track_running_stats=track_running_stats)),
                   nn.Sequential(
                      nn.ReLU(inplace=False),
                      nn.Conv2d(C, C, (1, 3), stride=(1, 2), padding=(0, 1), groups=8, bias=False),
                      nn.Conv2d(C, C, (3, 1), stride=(2, 1), padding=(1, 0), groups=8, bias=False),
                      nn.BatchNorm2d(C, affine=True),
                      nn.Conv2d(C, C, (1, 3), stride=(1, 2), padding=(0, 1), groups=8, bias=not affine),
                      nn.Conv2d(C, C, (3, 1), stride=(2, 1), padding=(1, 0), groups=8, bias=not affine),
                      nn.BatchNorm2d(C, affine=affine, track_running_stats=track_running_stats),
                      nn.ReLU(inplace=False),
                      nn.Conv2d(C, C, 1, stride=1, padding=0, bias=False),
                      nn.BatchNorm2d(C, affine=True))])
                      nn.Conv2d(C, C, 1, stride=1, padding=0, bias=not affine),
                      nn.BatchNorm2d(C, affine=affine, track_running_stats=track_running_stats))])

    self.ops2 = nn.ModuleList(
                  [nn.Sequential(
                      nn.MaxPool2d(3, stride=1, padding=1),
                      nn.BatchNorm2d(C, affine=True)),
                      nn.MaxPool2d(3, stride=2, padding=1),
                      nn.BatchNorm2d(C, affine=affine, track_running_stats=track_running_stats)),
                   nn.Sequential(
                      nn.MaxPool2d(3, stride=2, padding=1),
                      nn.BatchNorm2d(C, affine=True))])
                      nn.BatchNorm2d(C, affine=affine, track_running_stats=track_running_stats))])

  @property
  def multiplier(self):
    return 4

  def forward(self, s0, s1, drop_prob = -1):
    s0 = self.preprocess0(s0)
@@ -307,3 +310,10 @@ class GDAS_Reduction_Cell(nn.Module):
    if self.training and drop_prob > 0.:
      X2, X3 = drop_path(X2, drop_prob), drop_path(X3, drop_prob)
    return torch.cat([X0, X1, X2, X3], dim=1)


# To manage the useful classes in this file.
RAW_OP_CLASSES = {
  'gdas_reduction': GDAS_Reduction_Cell
}
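
Besides the BN/bias handling, two things change here: `multiplier` becomes a read-only property fixed at 4, matching what `forward` actually does (it always concatenates four intermediate tensors, so the output width is always `4*C`), and the class is exported through the new `RAW_OP_CLASSES` registry. A short sketch of how the registry is consumed (mirroring the FRC super-net added below; the example values are illustrative and assume `lib/` is on `PYTHONPATH`):

```
# Look up the fixed reduction cell by name and instantiate it; the arguments
# follow the __init__ signature above.
from models.cell_operations import RAW_OP_CLASSES

cell_cls = RAW_OP_CLASSES['gdas_reduction']       # -> GDAS_Reduction_Cell
cell = cell_cls(48, 48, 16, False, False, True)   # C_prev_prev, C_prev, C, reduction_prev, affine, track_running_stats
```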

lib/models/cell_searchs/__init__.py
@@ -11,6 +11,7 @@ from .generic_model import GenericNAS201Model
from .genotypes import Structure as CellStructure, architectures as CellArchitectures
# NASNet-based macro structure
from .search_model_gdas_nasnet import NASNetworkGDAS
from .search_model_gdas_frc_nasnet import NASNetworkGDAS_FRC
from .search_model_darts_nasnet import NASNetworkDARTS
@@ -23,4 +24,5 @@ nas201_super_nets = {'DARTS-V1': TinyNetworkDarts,
                     "generic": GenericNAS201Model}

nasnet_super_nets = {"GDAS": NASNetworkGDAS,
                     "GDAS_FRC": NASNetworkGDAS_FRC,
                     "DARTS": NASNetworkDARTS}

lib/models/cell_searchs/search_cells.py
@@ -163,6 +163,10 @@ class NASNetSearchCell(nn.Module):
    self.edge2index = {key:i for i, key in enumerate(self.edge_keys)}
    self.num_edges = len(self.edges)

  @property
  def multiplier(self):
    return self._multiplier

  def forward_gdas(self, s0, s1, weightss, indexs):
    s0 = self.preprocess0(s0)
    s1 = self.preprocess1(s1)
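
With `multiplier` exposed as a property on the search cell as well, the macro network below can compute every cell's output width uniformly as `cell.multiplier * C_curr`, whether the cell is a searchable normal cell or the fixed FRC reduction cell.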

lib/models/cell_searchs/search_model_gdas_frc_nasnet.py
@@ -0,0 +1,125 @@
###########################################################################
# Searching for A Robust Neural Architecture in Four GPU Hours, CVPR 2019 #
###########################################################################
import torch
import torch.nn as nn
from copy import deepcopy
from models.cell_searchs.search_cells import NASNetSearchCell as SearchCell
from models.cell_operations import RAW_OP_CLASSES


# The macro structure is based on NASNet
class NASNetworkGDAS_FRC(nn.Module):

  def __init__(self, C, N, steps, multiplier, stem_multiplier, num_classes, search_space, affine, track_running_stats):
    super(NASNetworkGDAS_FRC, self).__init__()
    self._C          = C
    self._layerN     = N
    self._steps      = steps
    self._multiplier = multiplier
    self.stem = nn.Sequential(
                   nn.Conv2d(3, C*stem_multiplier, kernel_size=3, padding=1, bias=False),
                   nn.BatchNorm2d(C*stem_multiplier))

    # config for each layer
    layer_channels   = [C    ] * N + [C*2 ] + [C*2 ] * (N-1) + [C*4 ] + [C*4 ] * (N-1)
    layer_reductions = [False] * N + [True] + [False] * (N-1) + [True] + [False] * (N-1)

    num_edge, edge2index = None, None
    C_prev_prev, C_prev, C_curr, reduction_prev = C*stem_multiplier, C*stem_multiplier, C, False

    self.cells = nn.ModuleList()
    for index, (C_curr, reduction) in enumerate(zip(layer_channels, layer_reductions)):
      if reduction:
        # FRC: the reduction cell is the fixed, hand-crafted GDAS cell rather than a searchable one.
        cell = RAW_OP_CLASSES['gdas_reduction'](C_prev_prev, C_prev, C_curr, reduction_prev, affine, track_running_stats)
      else:
        cell = SearchCell(search_space, steps, multiplier, C_prev_prev, C_prev, C_curr, reduction, reduction_prev, affine, track_running_stats)
      if num_edge is None: num_edge, edge2index = cell.num_edges, cell.edge2index
      else: assert reduction or num_edge == cell.num_edges and edge2index == cell.edge2index, 'invalid {:} vs. {:}.'.format(num_edge, cell.num_edges)
      self.cells.append( cell )
      C_prev_prev, C_prev, reduction_prev = C_prev, cell.multiplier * C_curr, reduction
    self.op_names   = deepcopy( search_space )
    self._Layer     = len(self.cells)
    self.edge2index = edge2index
    self.lastact    = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True))
    self.global_pooling = nn.AdaptiveAvgPool2d(1)
    self.classifier = nn.Linear(C_prev, num_classes)
    self.arch_parameters = nn.Parameter( 1e-3*torch.randn(num_edge, len(search_space)) )
    self.tau = 10

  def get_weights(self):
    xlist = list( self.stem.parameters() ) + list( self.cells.parameters() )
    xlist+= list( self.lastact.parameters() ) + list( self.global_pooling.parameters() )
    xlist+= list( self.classifier.parameters() )
    return xlist

  def set_tau(self, tau):
    self.tau = tau

  def get_tau(self):
    return self.tau

  def get_alphas(self):
    return [self.arch_parameters]

  def show_alphas(self):
    with torch.no_grad():
      A = 'arch-normal-parameters :\n{:}'.format(nn.functional.softmax(self.arch_parameters, dim=-1).cpu())
    return '{:}'.format(A)

  def get_message(self):
    string = self.extra_repr()
    for i, cell in enumerate(self.cells):
      string += '\n {:02d}/{:02d} :: {:}'.format(i, len(self.cells), cell.extra_repr())
    return string

  def extra_repr(self):
    return ('{name}(C={_C}, N={_layerN}, steps={_steps}, multiplier={_multiplier}, L={_Layer})'.format(name=self.__class__.__name__, **self.__dict__))

  def genotype(self):
    def _parse(weights):
      gene = []
      for i in range(self._steps):
        edges = []
        for j in range(2+i):
          node_str = '{:}<-{:}'.format(i, j)
          ws = weights[ self.edge2index[node_str] ]
          for k, op_name in enumerate(self.op_names):
            if op_name == 'none': continue
            edges.append( (op_name, j, ws[k]) )
        # keep the two strongest incoming edges for this node
        edges = sorted(edges, key=lambda x: -x[-1])
        selected_edges = edges[:2]
        gene.append( tuple(selected_edges) )
      return gene
    with torch.no_grad():
      gene_normal = _parse(torch.softmax(self.arch_parameters, dim=-1).cpu().numpy())
    return {'normal': gene_normal, 'normal_concat': list(range(2+self._steps-self._multiplier, self._steps+2))}

  def forward(self, inputs):
    def get_gumbel_prob(xins):
      # Straight-through Gumbel-softmax: the forward pass uses the hard one-hot
      # sample (hardwts) while gradients flow through the soft probabilities;
      # re-sample whenever the draw produces inf/nan values.
      while True:
        gumbels = -torch.empty_like(xins).exponential_().log()
        logits  = (xins.log_softmax(dim=1) + gumbels) / self.tau
        probs   = nn.functional.softmax(logits, dim=1)
        index   = probs.max(-1, keepdim=True)[1]
        one_h   = torch.zeros_like(logits).scatter_(-1, index, 1.0)
        hardwts = one_h - probs.detach() + probs
        if (torch.isinf(gumbels).any()) or (torch.isinf(probs).any()) or (torch.isnan(probs).any()):
          continue
        else: break
      return hardwts, index

    hardwts, index = get_gumbel_prob(self.arch_parameters)

    s0 = s1 = self.stem(inputs)
    for i, cell in enumerate(self.cells):
      if cell.reduction:
        # the fixed reduction cell takes no architecture weights
        s0, s1 = s1, cell(s0, s1)
      else:
        s0, s1 = s1, cell.forward_gdas(s0, s1, hardwts, index)
    out = self.lastact(s1)
    out = self.global_pooling( out )
    out = out.view(out.size(0), -1)
    logits = self.classifier(out)

    return out, logits
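
For context, a minimal smoke-test sketch of the new super-net, assuming `lib/` is on `PYTHONPATH`; the constructor arguments mirror the GDASFRC-NASNet-CIFAR config above, and the operation list shown is the standard DARTS one, given only for illustration (the codebase defines its own lists in `SearchSpaceNames`):

```
# Minimal sketch: build the FRC super-net and run one forward pass.
import torch
from models.cell_searchs import nasnet_super_nets

ops = ['none', 'skip_connect', 'dua_sepc_3x3', 'dua_sepc_5x5',
       'dil_sepc_3x3', 'dil_sepc_5x5', 'avg_pool_3x3', 'max_pool_3x3']
net = nasnet_super_nets['GDAS_FRC'](
    C=16, N=2, steps=4, multiplier=4, stem_multiplier=3, num_classes=10,
    search_space=ops, affine=False, track_running_stats=True)
net.set_tau(10)                                    # Gumbel-softmax temperature
features, logits = net(torch.randn(2, 3, 32, 32))  # CIFAR-sized input
print(net.genotype())                              # derive the normal cell
```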

scripts-search/NASNet-space-search-by-GDAS-FRC.sh
@@ -0,0 +1,38 @@
#!/bin/bash
# bash ./scripts-search/NASNet-space-search-by-GDAS-FRC.sh cifar10 1 -1
echo script name: $0
echo $# arguments
if [ "$#" -ne 3 ] ;then
echo "Input illegal number of parameters " $#
echo "Need 3 parameters for dataset, track_running_stats, and seed"
exit 1
fi
if [ "$TORCH_HOME" = "" ]; then
echo "Must set TORCH_HOME envoriment variable for data dir saving"
exit 1
else
echo "TORCH_HOME : $TORCH_HOME"
fi
dataset=$1
track_running_stats=$2
seed=$3
space=darts
if [ "$dataset" == "cifar10" ] || [ "$dataset" == "cifar100" ]; then
data_path="$TORCH_HOME/cifar.python"
else
data_path="$TORCH_HOME/cifar.python/ImageNet16"
fi
save_dir=./output/search-cell-${space}/GDAS-FRC-${dataset}-BN${track_running_stats}
OMP_NUM_THREADS=4 python ./exps/algos/GDAS.py \
    --save_dir ${save_dir} \
    --dataset ${dataset} --data_path ${data_path} \
    --search_space_name ${space} \
    --config_path configs/search-opts/GDAS-NASNet-CIFAR.config \
    --model_config configs/search-archs/GDASFRC-NASNet-CIFAR.config \
    --tau_max 10 --tau_min 0.1 --track_running_stats ${track_running_stats} \
    --arch_learning_rate 0.0003 --arch_weight_decay 0.001 \
    --workers 4 --print_freq 200 --rand_seed ${seed}