From ffd23a6cbdc4b1ed911bc3349681cee3e1333569 Mon Sep 17 00:00:00 2001 From: D-X-Y <280835372@qq.com> Date: Tue, 18 Aug 2020 00:50:33 +0000 Subject: [PATCH] Support GDAS (FRC), see details in docs/CVPR-2019-GDAS.md --- .../search-archs/GDASFRC-NASNet-CIFAR.config | 9 ++ docs/CVPR-2019-GDAS.md | 7 + lib/models/cell_operations.py | 44 +++--- lib/models/cell_searchs/__init__.py | 2 + lib/models/cell_searchs/search_cells.py | 4 + .../search_model_gdas_frc_nasnet.py | 125 ++++++++++++++++++ .../NASNet-space-search-by-GDAS-FRC.sh | 38 ++++++ 7 files changed, 212 insertions(+), 17 deletions(-) create mode 100644 configs/search-archs/GDASFRC-NASNet-CIFAR.config create mode 100644 lib/models/cell_searchs/search_model_gdas_frc_nasnet.py create mode 100644 scripts-search/NASNet-space-search-by-GDAS-FRC.sh diff --git a/configs/search-archs/GDASFRC-NASNet-CIFAR.config b/configs/search-archs/GDASFRC-NASNet-CIFAR.config new file mode 100644 index 0000000..9ddb9ae --- /dev/null +++ b/configs/search-archs/GDASFRC-NASNet-CIFAR.config @@ -0,0 +1,9 @@ +{ + "super_type" : ["str", "nasnet-super"], + "name" : ["str", "GDAS_FRC"], + "C" : ["int", "16" ], + "N" : ["int", "2" ], + "steps" : ["int", "4" ], + "multiplier" : ["int", "4" ], + "stem_multiplier" : ["int", "3" ] +} diff --git a/docs/CVPR-2019-GDAS.md b/docs/CVPR-2019-GDAS.md index ffafd26..f6f2be5 100644 --- a/docs/CVPR-2019-GDAS.md +++ b/docs/CVPR-2019-GDAS.md @@ -37,9 +37,14 @@ CUDA_VISIBLE_DEVICES=0,1,2,3 bash ./scripts/nas-infer-train.sh imagenet-1k GDAS_ If you are interested in the configs of each NAS-searched architecture, they are defined at [genotypes.py](https://github.com/D-X-Y/AutoDL-Projects/blob/master/lib/nas_infer_model/DXYs/genotypes.py). ### Searching on the NASNet search space + Please use the following scripts to use GDAS to search as in the original paper: ``` +# search for both normal and reduction cells CUDA_VISIBLE_DEVICES=0 bash ./scripts-search/NASNet-space-search-by-GDAS.sh cifar10 1 -1 + +# search for the normal cell while use a fixed reduction cell +CUDA_VISIBLE_DEVICES=0 bash ./scripts-search/NASNet-space-search-by-GDAS-FRC.sh cifar10 1 -1 ``` **After searching**, if you want to re-train the searched architecture found by the above script, you can use the following script: @@ -52,7 +57,9 @@ Note that `gdas-searched` is a string to indicate the name of the saved dir and The above script does not apply heavy augmentation to train the model, so the accuracy will be lower than the original paper. If you want to change the default hyper-parameter for re-training, please have a look at `./scripts/retrain-searched-net.sh` and `configs/archs/NAS-*-none.config`. + ### Searching on a small search space (NAS-Bench-201) + The GDAS searching codes on a small search space: ``` CUDA_VISIBLE_DEVICES=0 bash ./scripts-search/algos/GDAS.sh cifar10 1 -1 diff --git a/lib/models/cell_operations.py b/lib/models/cell_operations.py index a17fb44..ff1231a 100644 --- a/lib/models/cell_operations.py +++ b/lib/models/cell_operations.py @@ -4,7 +4,7 @@ import torch import torch.nn as nn -__all__ = ['OPS', 'ResNetBasicblock', 'SearchSpaceNames'] +__all__ = ['OPS', 'RAW_OP_CLASSES', 'ResNetBasicblock', 'SearchSpaceNames'] OPS = { 'none' : lambda C_in, C_out, stride, affine, track_running_stats: Zero(C_in, C_out, stride), @@ -175,7 +175,7 @@ class FactorizedReduce(nn.Module): self.convs.append(nn.Conv2d(C_in, C_outs[i], 1, stride=stride, padding=0, bias=not affine)) self.pad = nn.ConstantPad2d((0, 1, 0, 1), 0) elif stride == 1: - self.conv = nn.Conv2d(C_in, C_out, 1, stride=stride, padding=0, bias=False) + self.conv = nn.Conv2d(C_in, C_out, 1, stride=stride, padding=0, bias=not affine) else: raise ValueError('Invalid stride : {:}'.format(stride)) self.bn = nn.BatchNorm2d(C_out, affine=affine, track_running_stats=track_running_stats) @@ -256,41 +256,44 @@ def drop_path(x, drop_prob): # Searching for A Robust Neural Architecture in Four GPU Hours class GDAS_Reduction_Cell(nn.Module): - def __init__(self, C_prev_prev, C_prev, C, reduction_prev, multiplier, affine, track_running_stats): + def __init__(self, C_prev_prev, C_prev, C, reduction_prev, affine, track_running_stats): super(GDAS_Reduction_Cell, self).__init__() if reduction_prev: self.preprocess0 = FactorizedReduce(C_prev_prev, C, 2, affine, track_running_stats) else: self.preprocess0 = ReLUConvBN(C_prev_prev, C, 1, 1, 0, 1, affine, track_running_stats) self.preprocess1 = ReLUConvBN(C_prev, C, 1, 1, 0, 1, affine, track_running_stats) - self.multiplier = multiplier self.reduction = True self.ops1 = nn.ModuleList( [nn.Sequential( nn.ReLU(inplace=False), - nn.Conv2d(C, C, (1, 3), stride=(1, 2), padding=(0, 1), groups=8, bias=False), - nn.Conv2d(C, C, (3, 1), stride=(2, 1), padding=(1, 0), groups=8, bias=False), - nn.BatchNorm2d(C, affine=True), + nn.Conv2d(C, C, (1, 3), stride=(1, 2), padding=(0, 1), groups=8, bias=not affine), + nn.Conv2d(C, C, (3, 1), stride=(2, 1), padding=(1, 0), groups=8, bias=not affine), + nn.BatchNorm2d(C, affine=affine, track_running_stats=track_running_stats), nn.ReLU(inplace=False), - nn.Conv2d(C, C, 1, stride=1, padding=0, bias=False), - nn.BatchNorm2d(C, affine=True)), + nn.Conv2d(C, C, 1, stride=1, padding=0, bias=not affine), + nn.BatchNorm2d(C, affine=affine, track_running_stats=track_running_stats)), nn.Sequential( nn.ReLU(inplace=False), - nn.Conv2d(C, C, (1, 3), stride=(1, 2), padding=(0, 1), groups=8, bias=False), - nn.Conv2d(C, C, (3, 1), stride=(2, 1), padding=(1, 0), groups=8, bias=False), - nn.BatchNorm2d(C, affine=True), + nn.Conv2d(C, C, (1, 3), stride=(1, 2), padding=(0, 1), groups=8, bias=not affine), + nn.Conv2d(C, C, (3, 1), stride=(2, 1), padding=(1, 0), groups=8, bias=not affine), + nn.BatchNorm2d(C, affine=affine, track_running_stats=track_running_stats), nn.ReLU(inplace=False), - nn.Conv2d(C, C, 1, stride=1, padding=0, bias=False), - nn.BatchNorm2d(C, affine=True))]) + nn.Conv2d(C, C, 1, stride=1, padding=0, bias=not affine), + nn.BatchNorm2d(C, affine=affine, track_running_stats=track_running_stats))]) self.ops2 = nn.ModuleList( [nn.Sequential( - nn.MaxPool2d(3, stride=1, padding=1), - nn.BatchNorm2d(C, affine=True)), + nn.MaxPool2d(3, stride=2, padding=1), + nn.BatchNorm2d(C, affine=affine, track_running_stats=track_running_stats)), nn.Sequential( nn.MaxPool2d(3, stride=2, padding=1), - nn.BatchNorm2d(C, affine=True))]) + nn.BatchNorm2d(C, affine=affine, track_running_stats=track_running_stats))]) + + @property + def multiplier(self): + return 4 def forward(self, s0, s1, drop_prob = -1): s0 = self.preprocess0(s0) @@ -307,3 +310,10 @@ class GDAS_Reduction_Cell(nn.Module): if self.training and drop_prob > 0.: X2, X3 = drop_path(X2, drop_prob), drop_path(X3, drop_prob) return torch.cat([X0, X1, X2, X3], dim=1) + + +# To manage the useful classes in this file. +RAW_OP_CLASSES = { + 'gdas_reduction': GDAS_Reduction_Cell +} + diff --git a/lib/models/cell_searchs/__init__.py b/lib/models/cell_searchs/__init__.py index 968eb37..05a315c 100644 --- a/lib/models/cell_searchs/__init__.py +++ b/lib/models/cell_searchs/__init__.py @@ -11,6 +11,7 @@ from .generic_model import GenericNAS201Model from .genotypes import Structure as CellStructure, architectures as CellArchitectures # NASNet-based macro structure from .search_model_gdas_nasnet import NASNetworkGDAS +from .search_model_gdas_frc_nasnet import NASNetworkGDAS_FRC from .search_model_darts_nasnet import NASNetworkDARTS @@ -23,4 +24,5 @@ nas201_super_nets = {'DARTS-V1': TinyNetworkDarts, "generic": GenericNAS201Model} nasnet_super_nets = {"GDAS": NASNetworkGDAS, + "GDAS_FRC": NASNetworkGDAS_FRC, "DARTS": NASNetworkDARTS} diff --git a/lib/models/cell_searchs/search_cells.py b/lib/models/cell_searchs/search_cells.py index 5e10224..818a32c 100644 --- a/lib/models/cell_searchs/search_cells.py +++ b/lib/models/cell_searchs/search_cells.py @@ -163,6 +163,10 @@ class NASNetSearchCell(nn.Module): self.edge2index = {key:i for i, key in enumerate(self.edge_keys)} self.num_edges = len(self.edges) + @property + def multiplier(self): + return self._multiplier + def forward_gdas(self, s0, s1, weightss, indexs): s0 = self.preprocess0(s0) s1 = self.preprocess1(s1) diff --git a/lib/models/cell_searchs/search_model_gdas_frc_nasnet.py b/lib/models/cell_searchs/search_model_gdas_frc_nasnet.py new file mode 100644 index 0000000..06aaf93 --- /dev/null +++ b/lib/models/cell_searchs/search_model_gdas_frc_nasnet.py @@ -0,0 +1,125 @@ +########################################################################### +# Searching for A Robust Neural Architecture in Four GPU Hours, CVPR 2019 # +########################################################################### +import torch +import torch.nn as nn +from copy import deepcopy +from models.cell_searchs.search_cells import NASNetSearchCell as SearchCell +from models.cell_operations import RAW_OP_CLASSES + + +# The macro structure is based on NASNet +class NASNetworkGDAS_FRC(nn.Module): + + def __init__(self, C, N, steps, multiplier, stem_multiplier, num_classes, search_space, affine, track_running_stats): + super(NASNetworkGDAS_FRC, self).__init__() + self._C = C + self._layerN = N + self._steps = steps + self._multiplier = multiplier + self.stem = nn.Sequential( + nn.Conv2d(3, C*stem_multiplier, kernel_size=3, padding=1, bias=False), + nn.BatchNorm2d(C*stem_multiplier)) + + # config for each layer + layer_channels = [C ] * N + [C*2 ] + [C*2 ] * (N-1) + [C*4 ] + [C*4 ] * (N-1) + layer_reductions = [False] * N + [True] + [False] * (N-1) + [True] + [False] * (N-1) + + num_edge, edge2index = None, None + C_prev_prev, C_prev, C_curr, reduction_prev = C*stem_multiplier, C*stem_multiplier, C, False + + self.cells = nn.ModuleList() + for index, (C_curr, reduction) in enumerate(zip(layer_channels, layer_reductions)): + if reduction: + cell = RAW_OP_CLASSES['gdas_reduction'](C_prev_prev, C_prev, C_curr, reduction_prev, affine, track_running_stats) + else: + cell = SearchCell(search_space, steps, multiplier, C_prev_prev, C_prev, C_curr, reduction, reduction_prev, affine, track_running_stats) + if num_edge is None: num_edge, edge2index = cell.num_edges, cell.edge2index + else: assert reduction or num_edge == cell.num_edges and edge2index == cell.edge2index, 'invalid {:} vs. {:}.'.format(num_edge, cell.num_edges) + self.cells.append( cell ) + C_prev_prev, C_prev, reduction_prev = C_prev, cell.multiplier * C_curr, reduction + self.op_names = deepcopy( search_space ) + self._Layer = len(self.cells) + self.edge2index = edge2index + self.lastact = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True)) + self.global_pooling = nn.AdaptiveAvgPool2d(1) + self.classifier = nn.Linear(C_prev, num_classes) + self.arch_parameters = nn.Parameter( 1e-3*torch.randn(num_edge, len(search_space)) ) + self.tau = 10 + + def get_weights(self): + xlist = list( self.stem.parameters() ) + list( self.cells.parameters() ) + xlist+= list( self.lastact.parameters() ) + list( self.global_pooling.parameters() ) + xlist+= list( self.classifier.parameters() ) + return xlist + + def set_tau(self, tau): + self.tau = tau + + def get_tau(self): + return self.tau + + def get_alphas(self): + return [self.arch_parameters] + + def show_alphas(self): + with torch.no_grad(): + A = 'arch-normal-parameters :\n{:}'.format(nn.functional.softmax(self.arch_parameters, dim=-1).cpu()) + return '{:}'.format(A) + + def get_message(self): + string = self.extra_repr() + for i, cell in enumerate(self.cells): + string += '\n {:02d}/{:02d} :: {:}'.format(i, len(self.cells), cell.extra_repr()) + return string + + def extra_repr(self): + return ('{name}(C={_C}, N={_layerN}, steps={_steps}, multiplier={_multiplier}, L={_Layer})'.format(name=self.__class__.__name__, **self.__dict__)) + + def genotype(self): + def _parse(weights): + gene = [] + for i in range(self._steps): + edges = [] + for j in range(2+i): + node_str = '{:}<-{:}'.format(i, j) + ws = weights[ self.edge2index[node_str] ] + for k, op_name in enumerate(self.op_names): + if op_name == 'none': continue + edges.append( (op_name, j, ws[k]) ) + edges = sorted(edges, key=lambda x: -x[-1]) + selected_edges = edges[:2] + gene.append( tuple(selected_edges) ) + return gene + with torch.no_grad(): + gene_normal = _parse(torch.softmax(self.arch_parameters, dim=-1).cpu().numpy()) + return {'normal': gene_normal, 'normal_concat': list(range(2+self._steps-self._multiplier, self._steps+2))} + + def forward(self, inputs): + def get_gumbel_prob(xins): + while True: + gumbels = -torch.empty_like(xins).exponential_().log() + logits = (xins.log_softmax(dim=1) + gumbels) / self.tau + probs = nn.functional.softmax(logits, dim=1) + index = probs.max(-1, keepdim=True)[1] + one_h = torch.zeros_like(logits).scatter_(-1, index, 1.0) + hardwts = one_h - probs.detach() + probs + if (torch.isinf(gumbels).any()) or (torch.isinf(probs).any()) or (torch.isnan(probs).any()): + continue + else: break + return hardwts, index + + hardwts, index = get_gumbel_prob(self.arch_parameters) + + s0 = s1 = self.stem(inputs) + for i, cell in enumerate(self.cells): + if cell.reduction: + s0, s1 = s1, cell(s0, s1) + else: + s0, s1 = s1, cell.forward_gdas(s0, s1, hardwts, index) + out = self.lastact(s1) + out = self.global_pooling( out ) + out = out.view(out.size(0), -1) + logits = self.classifier(out) + + return out, logits diff --git a/scripts-search/NASNet-space-search-by-GDAS-FRC.sh b/scripts-search/NASNet-space-search-by-GDAS-FRC.sh new file mode 100644 index 0000000..22d26c9 --- /dev/null +++ b/scripts-search/NASNet-space-search-by-GDAS-FRC.sh @@ -0,0 +1,38 @@ +#!/bin/bash +# bash ./scripts-search/NASNet-space-search-by-GDAS-FRC.sh cifar10 1 -1 +echo script name: $0 +echo $# arguments +if [ "$#" -ne 3 ] ;then + echo "Input illegal number of parameters " $# + echo "Need 3 parameters for dataset, track_running_stats, and seed" + exit 1 +fi +if [ "$TORCH_HOME" = "" ]; then + echo "Must set TORCH_HOME envoriment variable for data dir saving" + exit 1 +else + echo "TORCH_HOME : $TORCH_HOME" +fi + +dataset=$1 +track_running_stats=$2 +seed=$3 +space=darts + +if [ "$dataset" == "cifar10" ] || [ "$dataset" == "cifar100" ]; then + data_path="$TORCH_HOME/cifar.python" +else + data_path="$TORCH_HOME/cifar.python/ImageNet16" +fi + +save_dir=./output/search-cell-${space}/GDAS-${dataset}-BN${track_running_stats} + +OMP_NUM_THREADS=4 python ./exps/algos/GDAS.py \ + --save_dir ${save_dir} \ + --dataset ${dataset} --data_path ${data_path} \ + --search_space_name ${space} \ + --config_path configs/search-opts/GDAS-NASNet-CIFAR.config \ + --model_config configs/search-archs/GDASFRC-NASNet-CIFAR.config \ + --tau_max 10 --tau_min 0.1 --track_running_stats ${track_running_stats} \ + --arch_learning_rate 0.0003 --arch_weight_decay 0.001 \ + --workers 4 --print_freq 200 --rand_seed ${seed}