Update CVPR-2019-GDAS: re-train models searched in the NASNet search space

D-X-Y 2020-03-06 19:29:07 +11:00
parent 8b6df42f1f
commit 9a83814a46
17 changed files with 278 additions and 21 deletions

View File

@ -0,0 +1,10 @@
{
"super_type": ["str", "infer-nasnet.cifar"],
"genotype" : ["none", "none"],
"dataset" : ["str", "cifar"],
"ichannel" : ["int", 33],
"layers" : ["int", 6],
"stem_multi": ["int", 3],
"auxiliary" : ["bool", 1],
"drop_path_prob": ["float", 0.2]
}

View File

@ -0,0 +1,9 @@
{
"super_type": ["str", "infer-nasnet.imagenet"],
"genotype" : ["none", "none"],
"dataset" : ["str", "imagenet"],
"ichannel" : ["int", 50],
"layers" : ["int", 4],
"auxiliary" : ["bool", 1],
"drop_path_prob": ["float", 0]
}

View File

@ -41,7 +41,16 @@ Please use the following scripts to use GDAS to search as in the original paper:
```
CUDA_VISIBLE_DEVICES=0 bash ./scripts-search/GDAS-search-NASNet-space.sh cifar10 1 -1
```
-If you want to train the searched architecture found by the above scripts, you need to add the config of that architecture (will be printed in log) in [genotypes.py](https://github.com/D-X-Y/AutoDL-Projects/blob/master/lib/nas_infer_model/DXYs/genotypes.py).
**After searching**, if you want to re-train the searched architecture found by the above script, you can use the following script:
```
CUDA_VISIBLE_DEVICES=0 bash ./scripts/retrain-searched-net.sh cifar10 gdas-searched \
output/search-cell-darts/GDAS-cifar10-BN1/checkpoint/seed-945-basic.pth 96 -1
```
Note that `gdas-searched` is a string indicating the name of the save directory, and `output/search-cell-darts/GDAS-cifar10-BN1/checkpoint/seed-945-basic.pth` is the path of the checkpoint generated by the searching algorithm.
The above script does not apply heavy augmentation when training the model, so the accuracy will be lower than that reported in the original paper.
If you want to change the default hyper-parameters for re-training, please have a look at `./scripts/retrain-searched-net.sh` and `configs/archs/NAS-*-none.config`.
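For reference, the search checkpoint keeps one genotype per epoch under its `genotypes` key, which is exactly what the re-training code added in this commit reads; a quick way to inspect the final searched cells:
```
import torch

# 'epoch' and 'genotypes' are the keys consumed by the re-training code in this commit.
xdata = torch.load('output/search-cell-darts/GDAS-cifar10-BN1/checkpoint/seed-945-basic.pth')
genotype = xdata['genotypes'][xdata['epoch'] - 1]  # genotype found at the last search epoch
print(genotype['normal'], genotype['normal_concat'])
print(genotype['reduce'], genotype['reduce_concat'])
```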
### Searching on a small search space (NAS-Bench-201)
The GDAS searching codes on a small search space:

View File

@ -39,7 +39,9 @@ def main(args):
  if args.model_source == 'normal':
    base_model = obtain_model(model_config)
  elif args.model_source == 'nas':
-   base_model = obtain_nas_infer_model(model_config)
+   base_model = obtain_nas_infer_model(model_config, args.extra_model_path)
  elif args.model_source == 'autodl-searched':
    base_model = obtain_model(model_config, args.extra_model_path)
  else:
    raise ValueError('invalid model-source : {:}'.format(args.model_source))
  flop, param = get_model_infos(base_model, xshape)

View File

@ -12,6 +12,7 @@ def obtain_basic_args():
  parser.add_argument('--optim_config', type=str, help='The path to the optimizer configuration')
  parser.add_argument('--procedure'   , type=str, help='The procedure basic prefix.')
  parser.add_argument('--model_source', type=str, default='normal', help='The source of the model definition.')
  parser.add_argument('--extra_model_path', type=str, default=None, help='The extra model checkpoint file (used to indicate the searched architecture).')
  add_shared_args( parser )
  # Optimization options
  parser.add_argument('--batch_size', type=int, default=2, help='Batch size for training.')

View File

@ -29,7 +29,8 @@ def convert_param(original_lists):
    elif ctype == 'float':
      x = float(x)
    elif ctype == 'none':
-     assert x == 'None', 'for none type, the value must be None instead of {:}'.format(x)
+     if x.lower() != 'none':
+       raise ValueError('For the none type, the value must be none instead of {:}'.format(x))
      x = None
    else:
      raise TypeError('Does not know this type : {:}'.format(ctype))
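The `configs/archs/NAS-*-none.config` files added above store every value as a `[type, value]` pair, which `convert_param` decodes. A minimal self-contained sketch of that decoding (the repo's full `convert_param` also handles list-valued entries), assuming the CIFAR config above is saved as `configs/archs/NAS-CIFAR-none.config`:
```
import json

def convert_param(pair):
  # Simplified sketch: each config entry is [type-string, raw-value].
  ctype, x = pair
  if ctype == 'str'  : return str(x)
  if ctype == 'int'  : return int(x)
  if ctype == 'bool' : return bool(int(x))    # e.g. "auxiliary": ["bool", 1]
  if ctype == 'float': return float(x)
  if ctype == 'none' :
    if str(x).lower() != 'none':
      raise ValueError('For the none type, the value must be none instead of {:}'.format(x))
    return None
  raise TypeError('Does not know this type : {:}'.format(ctype))

with open('configs/archs/NAS-CIFAR-none.config') as xfile:
  config = {key: convert_param(value) for key, value in json.load(xfile).items()}
# -> {'super_type': 'infer-nasnet.cifar', 'genotype': None, 'ichannel': 33, ...}
```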

View File

@ -3,6 +3,7 @@
##################################################
from os import path as osp
from typing import List, Text
import torch

__all__ = ['change_key', 'get_cell_based_tiny_net', 'get_search_spaces', 'get_cifar_models', 'get_imagenet_models', \
           'obtain_model', 'obtain_search_model', 'load_net_from_checkpoint', \
@ -38,6 +39,9 @@ def get_cell_based_tiny_net(config):
      genotype = CellStructure.str2structure(config.arch_str)
    else: raise ValueError('Can not find genotype from this config : {:}'.format(config))
    return TinyNetwork(config.C, config.N, genotype, config.num_classes)
  elif config.name == 'infer.nasnet-cifar':
    from .cell_infers import NASNetonCIFAR
    raise NotImplementedError
  else:
    raise ValueError('invalid network name : {:}'.format(config.name))
@ -52,13 +56,12 @@ def get_search_spaces(xtype, name) -> List[Text]:
    raise ValueError('invalid search-space type is {:}'.format(xtype))

-def get_cifar_models(config):
-  from .CifarResNet import CifarResNet
-  from .CifarDenseNet import DenseNet
-  from .CifarWideResNet import CifarWideResNet
+def get_cifar_models(config, extra_path=None):
  super_type = getattr(config, 'super_type', 'basic')
  if super_type == 'basic':
+   from .CifarResNet import CifarResNet
+   from .CifarDenseNet import DenseNet
+   from .CifarWideResNet import CifarWideResNet
    if config.arch == 'resnet':
      return CifarResNet(config.module, config.depth, config.class_num, config.zero_init_residual)
    elif config.arch == 'densenet':
@ -71,6 +74,7 @@ def get_cifar_models(config):
    from .shape_infers import InferWidthCifarResNet
    from .shape_infers import InferDepthCifarResNet
    from .shape_infers import InferCifarResNet
    from .cell_infers import NASNetonCIFAR
    assert len(super_type.split('-')) == 2, 'invalid super_type : {:}'.format(super_type)
    infer_mode = super_type.split('-')[1]
    if infer_mode == 'width':
@ -79,6 +83,16 @@ def get_cifar_models(config):
      return InferDepthCifarResNet(config.module, config.depth, config.xblocks, config.class_num, config.zero_init_residual)
    elif infer_mode == 'shape':
      return InferCifarResNet(config.module, config.depth, config.xblocks, config.xchannels, config.class_num, config.zero_init_residual)
    elif infer_mode == 'nasnet.cifar':
      genotype = config.genotype
      if extra_path is not None:  # reload genotype by extra_path
        if not osp.isfile(extra_path): raise ValueError('invalid extra_path : {:}'.format(extra_path))
        xdata = torch.load(extra_path)
        current_epoch = xdata['epoch']
        genotype = xdata['genotypes'][current_epoch-1]
      C = config.C if hasattr(config, 'C') else config.ichannel
      N = config.N if hasattr(config, 'N') else config.layers
      return NASNetonCIFAR(C, N, config.stem_multi, config.class_num, genotype, config.auxiliary)
    else:
      raise ValueError('invalid infer-mode : {:}'.format(infer_mode))
  else:
@ -111,9 +125,10 @@ def get_imagenet_models(config):
    raise ValueError('invalid super-type : {:}'.format(super_type))

-def obtain_model(config):
+# Try to obtain the network by config.
+def obtain_model(config, extra_path=None):
  if config.dataset == 'cifar':
-   return get_cifar_models(config)
+   return get_cifar_models(config, extra_path)
  elif config.dataset == 'imagenet':
    return get_imagenet_models(config)
  else:
@ -152,7 +167,6 @@ def obtain_search_model(config):
def load_net_from_checkpoint(checkpoint):
- import torch
  assert osp.isfile(checkpoint), 'checkpoint {:} does not exist'.format(checkpoint)
  checkpoint = torch.load(checkpoint)
  model_config = dict2config(checkpoint['model-config'], None)
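Putting the pieces together, a hedged usage sketch of the extended `obtain_model`; it assumes the repo's `lib` directory is on `PYTHONPATH` and that `load_config` from `config_utils` takes `(path, extra, logger)` as in `exps/basic-main.py`:
```
from config_utils import load_config   # assumption: called as in exps/basic-main.py
from models import obtain_model

# genotype is ["none", "none"] in this config, so the architecture is
# reloaded from the search checkpoint passed via extra_path.
model_config = load_config('configs/archs/NAS-CIFAR-none.config', {'class_num': 10}, None)
net = obtain_model(model_config, 'output/search-cell-darts/GDAS-cifar10-BN1/checkpoint/seed-945-basic.pth')
print(net.get_message())
```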

View File

@ -2,3 +2,4 @@
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
#####################################################
from .tiny_network import TinyNetwork
from .nasnet_cifar import NASNetonCIFAR

View File

@ -2,6 +2,7 @@
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
#####################################################
import torch
import torch.nn as nn
from copy import deepcopy
from ..cell_operations import OPS
@ -50,3 +51,70 @@ class InferCell(nn.Module):
      node_feature = sum( self.layers[_il](nodes[_ii]) for _il, _ii in zip(node_layers, node_innods) )
      nodes.append( node_feature )
    return nodes[-1]
# Learning Transferable Architectures for Scalable Image Recognition, CVPR 2018
class NASNetInferCell(nn.Module):

  def __init__(self, genotype, C_prev_prev, C_prev, C, reduction, reduction_prev, affine, track_running_stats):
    super(NASNetInferCell, self).__init__()
    self.reduction = reduction
    if reduction_prev: self.preprocess0 = OPS['skip_connect'](C_prev_prev, C, 2, affine, track_running_stats)
    else             : self.preprocess0 = OPS['nor_conv_1x1'](C_prev_prev, C, 1, affine, track_running_stats)
    self.preprocess1 = OPS['nor_conv_1x1'](C_prev, C, 1, affine, track_running_stats)

    if not reduction:
      nodes, concats = genotype['normal'], genotype['normal_concat']
    else:
      nodes, concats = genotype['reduce'], genotype['reduce_concat']
    self._multiplier = len(concats)
    self._concats = concats
    self._steps = len(nodes)
    self._nodes = nodes
    self.edges = nn.ModuleDict()
    for i, node in enumerate(nodes):
      for in_node in node:
        name, j = in_node[0], in_node[1]
        stride = 2 if reduction and j < 2 else 1
        node_str = '{:}<-{:}'.format(i+2, j)
        self.edges[node_str] = OPS[name](C, C, stride, affine, track_running_stats)

  # [TODO] to support drop_prob in this function..
  def forward(self, s0, s1, unused_drop_prob):
    s0 = self.preprocess0(s0)
    s1 = self.preprocess1(s1)
    states = [s0, s1]
    for i, node in enumerate(self._nodes):
      clist = []
      for in_node in node:
        name, j = in_node[0], in_node[1]
        node_str = '{:}<-{:}'.format(i+2, j)
        op = self.edges[ node_str ]
        clist.append( op(states[j]) )
      states.append( sum(clist) )
    return torch.cat([states[x] for x in self._concats], dim=1)


class AuxiliaryHeadCIFAR(nn.Module):

  def __init__(self, C, num_classes):
    """assuming input size 8x8"""
    super(AuxiliaryHeadCIFAR, self).__init__()
    self.features = nn.Sequential(
      nn.ReLU(inplace=True),
      nn.AvgPool2d(5, stride=3, padding=0, count_include_pad=False), # image size = 2 x 2
      nn.Conv2d(C, 128, 1, bias=False),
      nn.BatchNorm2d(128),
      nn.ReLU(inplace=True),
      nn.Conv2d(128, 768, 2, bias=False),
      nn.BatchNorm2d(768),
      nn.ReLU(inplace=True)
    )
    self.classifier = nn.Linear(768, num_classes)

  def forward(self, x):
    x = self.features(x)
    x = self.classifier(x.view(x.size(0),-1))
    return x
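The `[TODO]` above refers to drop-path regularization; for reference, a minimal sketch of the standard DARTS-style `drop_path` helper that such a forward pass would typically apply to each edge output (not part of this commit):
```
import torch

def drop_path(x, drop_prob):
  # Zero out each sample's path with probability drop_prob, rescaling survivors
  # so the expected activation magnitude is unchanged.
  if drop_prob > 0.:
    keep_prob = 1. - drop_prob
    mask = torch.zeros(x.size(0), 1, 1, 1, device=x.device).bernoulli_(keep_prob)
    x = x.div(keep_prob).mul(mask)
  return x
```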

View File

@ -0,0 +1,71 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
#####################################################
import torch
import torch.nn as nn
from copy import deepcopy
from .cells import NASNetInferCell as InferCell, AuxiliaryHeadCIFAR


# The macro structure is based on NASNet
class NASNetonCIFAR(nn.Module):

  def __init__(self, C, N, stem_multiplier, num_classes, genotype, auxiliary, affine=True, track_running_stats=True):
    super(NASNetonCIFAR, self).__init__()
    self._C = C
    self._layerN = N
    self.stem = nn.Sequential(
      nn.Conv2d(3, C*stem_multiplier, kernel_size=3, padding=1, bias=False),
      nn.BatchNorm2d(C*stem_multiplier))
    # config for each layer
    layer_channels   = [C    ] * N + [C*2 ] + [C*2  ] * (N-1) + [C*4 ] + [C*4  ] * (N-1)
    layer_reductions = [False] * N + [True] + [False] * (N-1) + [True] + [False] * (N-1)
    C_prev_prev, C_prev, C_curr, reduction_prev = C*stem_multiplier, C*stem_multiplier, C, False
    self.auxiliary_index = None
    self.auxiliary_head  = None
    self.cells = nn.ModuleList()
    for index, (C_curr, reduction) in enumerate(zip(layer_channels, layer_reductions)):
      cell = InferCell(genotype, C_prev_prev, C_prev, C_curr, reduction, reduction_prev, affine, track_running_stats)
      self.cells.append( cell )
      C_prev_prev, C_prev, reduction_prev = C_prev, cell._multiplier*C_curr, reduction
      if reduction and C_curr == C*4 and auxiliary:
        self.auxiliary_head  = AuxiliaryHeadCIFAR(C_prev, num_classes)
        self.auxiliary_index = index
    self._Layer = len(self.cells)
    self.lastact = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True))
    self.global_pooling = nn.AdaptiveAvgPool2d(1)
    self.classifier = nn.Linear(C_prev, num_classes)
    self.drop_path_prob = -1

  def update_drop_path(self, drop_path_prob):
    self.drop_path_prob = drop_path_prob

  def auxiliary_param(self):
    if self.auxiliary_head is None: return []
    else: return list( self.auxiliary_head.parameters() )

  def get_message(self):
    string = self.extra_repr()
    for i, cell in enumerate(self.cells):
      string += '\n {:02d}/{:02d} :: {:}'.format(i, len(self.cells), cell.extra_repr())
    return string

  def extra_repr(self):
    return ('{name}(C={_C}, N={_layerN}, L={_Layer})'.format(name=self.__class__.__name__, **self.__dict__))

  def forward(self, inputs):
    stem_feature, logits_aux = self.stem(inputs), None
    cell_results = [stem_feature, stem_feature]
    for i, cell in enumerate(self.cells):
      cell_feature = cell(cell_results[-2], cell_results[-1], self.drop_path_prob)
      cell_results.append( cell_feature )
      if self.auxiliary_index is not None and i == self.auxiliary_index and self.training:
        logits_aux = self.auxiliary_head( cell_results[-1] )
    out = self.lastact(cell_results[-1])
    out = self.global_pooling( out )
    out = out.view(out.size(0), -1)
    logits = self.classifier(out)
    if logits_aux is None: return out, logits
    else                 : return out, [logits, logits_aux]
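A hedged smoke test for the new macro network, reusing the CIFAR hyper-parameters from the config above (ichannel 33, layers 6, stem_multi 3) and a genotype loaded from a search checkpoint; it assumes the repo's `lib` directory is on `PYTHONPATH`:
```
import torch
from models.cell_infers import NASNetonCIFAR

xdata = torch.load('output/search-cell-darts/GDAS-cifar10-BN1/checkpoint/seed-945-basic.pth')
genotype = xdata['genotypes'][xdata['epoch'] - 1]

net = NASNetonCIFAR(33, 6, 3, 10, genotype, auxiliary=True)
net.update_drop_path(0.2)  # currently unused inside the cells (see the TODO above)

net.train()
features, logits = net(torch.randn(2, 3, 32, 32))  # logits == [main, auxiliary] in train mode
net.eval()
features, logits = net(torch.randn(2, 3, 32, 32))  # logits is a single tensor in eval mode
```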

View File

@ -155,7 +155,7 @@ class NASNetSearchCell(nn.Module):
    self.edges = nn.ModuleDict()
    for i in range(self._steps):
      for j in range(2+i):
-       node_str = '{:}<-{:}'.format(i, j)
+       node_str = '{:}<-{:}'.format(i, j)  # indicate the edge from node-(j) to node-(i+2)
        stride = 2 if reduction and j < 2 else 1
        op = MixedOp(space, C, stride, affine, track_running_stats)
        self.edges[ node_str ] = op
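Note the naming difference between the two cell implementations: the search cell keys edges as `'{i}<-{j}'` with intermediate nodes counted from 0, while `NASNetInferCell` above uses `'{i+2}<-{j}'`, counting the two input nodes first. A tiny expansion of the loops above for a 4-step cell:
```
# Edge keys generated by the search cell for _steps == 4:
for i in range(4):
  for j in range(2 + i):
    print('{:}<-{:}'.format(i, j))  # 0<-0, 0<-1, 1<-0, 1<-1, 1<-2, 2<-0, ...
```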

View File

@ -5,8 +5,7 @@ import torch
import torch.nn as nn
from copy import deepcopy
from typing import List, Text, Dict
from .search_cells import NASNetSearchCell as SearchCell
-from .genotypes import Structure
# The macro structure is based on NASNet

View File

@ -4,8 +4,7 @@
import torch
import torch.nn as nn
from copy import deepcopy
from .search_cells import NASNetSearchCell as SearchCell
-from .genotypes import Structure
# The macro structure is based on NASNet

View File

@ -168,5 +168,15 @@ Networks = {'DARTS_V1': DARTS_V1,
            'SETN'     : SETN,
           }


# This function will return a Genotype from a dict.
def build_genotype_from_dict(xdict):
- import pdb; pdb.set_trace()
  def remove_value(nodes):
    return [tuple([(x[0], x[1]) for x in node]) for node in nodes]
  genotype = Genotype(
    normal=remove_value(xdict['normal']),
    normal_concat=xdict['normal_concat'],
    reduce=remove_value(xdict['reduce']),
    reduce_concat=xdict['reduce_concat'],
    connectN=None, connects=None
  )
  return genotype
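A hypothetical example of the dict this function consumes (the format saved under `xdata['genotypes']` by the search code); each edge is assumed to be an `(op-name, input-index, score)` triple whose trailing score `remove_value` strips:
```
xdict = {
  'normal': [[('skip_connect', 0, 0.13), ('sep_conv_3x3', 1, 0.12)],
             [('sep_conv_3x3', 0, 0.11), ('sep_conv_3x3', 1, 0.10)]],
  'normal_concat': [2, 3],
  'reduce': [[('max_pool_3x3', 0, 0.14), ('skip_connect', 1, 0.09)],
             [('max_pool_3x3', 0, 0.13), ('skip_connect', 2, 0.08)]],
  'reduce_concat': [2, 3],
}
genotype = build_genotype_from_dict(xdict)
# genotype.normal -> [(('skip_connect', 0), ('sep_conv_3x3', 1)),
#                     (('sep_conv_3x3', 0), ('sep_conv_3x3', 1))]
```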

View File

@ -6,12 +6,22 @@
# Currently, this package is used to reproduce the results in GDAS (Searching for A Robust Neural Architecture in Four GPU Hours, CVPR 2019).
##################################################
-import torch
+import os, torch

-def obtain_nas_infer_model(config):
+def obtain_nas_infer_model(config, extra_model_path=None):
  if config.arch == 'dxys':
    from .DXYs import CifarNet, ImageNet, Networks
-   genotype = Networks[config.genotype]
    from .DXYs import build_genotype_from_dict
    if config.genotype is None:
      if extra_model_path is None or not os.path.isfile(extra_model_path):
        raise ValueError('When the genotype in config is None, extra_model_path must be set to a valid path instead of {:}'.format(extra_model_path))
      xdata = torch.load(extra_model_path)
      current_epoch = xdata['epoch']
      genotype_dict = xdata['genotypes'][current_epoch-1]
      genotype = build_genotype_from_dict(genotype_dict)
    else:
      genotype = Networks[config.genotype]
    if config.dataset == 'cifar':
      return CifarNet(config.ichannel, config.layers, config.stem_multi, config.auxiliary, genotype, config.class_num)
    elif config.dataset == 'imagenet':

View File

@ -4,7 +4,7 @@ echo script name: $0
echo $# arguments
if [ "$#" -ne 4 ] ;then
  echo "Input illegal number of parameters " $#
- echo "Need 4 parameters for dataset and the-model-name and epochs and LR and the-batch-size and the-random-seed"
+ echo "Need 4 parameters for dataset, the-model-name, the-batch-size and the-random-seed"
  exit 1
fi
if [ "$TORCH_HOME" = "" ]; then

View File

@ -0,0 +1,53 @@
#!/bin/bash
# bash ./scripts/retrain-searched-net.sh cifar10 ${NAME} ${PATH} 256 -1
echo script name: $0
echo $# arguments
if [ "$#" -ne 5 ] ;then
echo "Input illegal number of parameters " $#
echo "Need 5 parameters for dataset, the save dir base name, the model path, the batch size, the random seed"
exit 1
fi
if [ "$TORCH_HOME" = "" ]; then
echo "Must set TORCH_HOME envoriment variable for data dir saving"
exit 1
else
echo "TORCH_HOME : $TORCH_HOME"
fi
dataset=$1
save_name=$2
model_path=$3
batch=$4
rseed=$5
if [ ${dataset} == 'cifar10' ] || [ ${dataset} == 'cifar100' ]; then
  xpath=$TORCH_HOME/cifar.python
  base=CIFAR
  workers=4
  cutout_length=16
elif [ ${dataset} == 'imagenet-1k' ]; then
  xpath=$TORCH_HOME/ILSVRC2012
  base=IMAGENET
  workers=28
  cutout_length=-1
else
  echo 'Unknown dataset: '${dataset}
  exit 1
fi
SAVE_ROOT="./output"
save_dir=${SAVE_ROOT}/nas-infer/${dataset}-BS${batch}-${save_name}
python --version
python ./exps/basic-main.py --dataset ${dataset} \
--data_path ${xpath} --model_source autodl-searched \
--model_config ./configs/archs/NAS-${base}-none.config \
--optim_config ./configs/opts/NAS-${base}.config \
--extra_model_path ${model_path} \
--procedure basic \
--save_dir ${save_dir} \
--cutout_length ${cutout_length} \
--batch_size ${batch} --rand_seed ${rseed} --workers ${workers} \
--eval_frequency 1 --print_freq 500 --print_freq_eval 1000