diff --git a/configs/archs/NAS-CIFAR-none.config b/configs/archs/NAS-CIFAR-none.config new file mode 100644 index 0000000..52b3bfd --- /dev/null +++ b/configs/archs/NAS-CIFAR-none.config @@ -0,0 +1,10 @@ +{ + "super_type": ["str", "infer-nasnet.cifar"], + "genotype" : ["none", "none"], + "dataset" : ["str", "cifar"], + "ichannel" : ["int", 33], + "layers" : ["int", 6], + "stem_multi": ["int", 3], + "auxiliary" : ["bool", 1], + "drop_path_prob": ["float", 0.2] +} diff --git a/configs/archs/NAS-IMAGENET-none.config b/configs/archs/NAS-IMAGENET-none.config new file mode 100644 index 0000000..192bbbc --- /dev/null +++ b/configs/archs/NAS-IMAGENET-none.config @@ -0,0 +1,9 @@ +{ + "super_type": ["str", "infer-nasnet.imagenet"], + "genotype" : ["none", "none"], + "dataset" : ["str", "imagenet"], + "ichannel" : ["int", 50], + "layers" : ["int", 4], + "auxiliary" : ["bool", 1], + "drop_path_prob": ["float", 0] +} diff --git a/docs/CVPR-2019-GDAS.md b/docs/CVPR-2019-GDAS.md index 4e276cf..bb00eb6 100644 --- a/docs/CVPR-2019-GDAS.md +++ b/docs/CVPR-2019-GDAS.md @@ -41,7 +41,16 @@ Please use the following scripts to use GDAS to search as in the original paper: ``` CUDA_VISIBLE_DEVICES=0 bash ./scripts-search/GDAS-search-NASNet-space.sh cifar10 1 -1 ``` -If you want to train the searched architecture found by the above scripts, you need to add the config of that architecture (will be printed in log) in [genotypes.py](https://github.com/D-X-Y/AutoDL-Projects/blob/master/lib/nas_infer_model/DXYs/genotypes.py). + +**After searching**, if you want to re-train the searched architecture found by the above script, you can use the following script: +``` +CUDA_VISIBLE_DEVICES=0 bash ./scripts/retrain-searched-net.sh cifar10 gdas-searched \ + output/search-cell-darts/GDAS-cifar10-BN1/checkpoint/seed-945-basic.pth 96 -1 +``` +Note that `gdas-searched` is a string that names the directory where the re-training results are saved, and `output/search-cell-darts/GDAS-cifar10-BN1/checkpoint/seed-945-basic.pth` is the path of the checkpoint file generated by the searching algorithm. + +The above script does not apply heavy data augmentation when training the model, so the accuracy will be lower than that reported in the original paper. +If you want to change the default hyper-parameters for re-training, please have a look at `./scripts/retrain-searched-net.sh` and `configs/archs/NAS-*-none.config`.
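For quick reference, every value in these `configs/archs/*.config` files is stored as a `[type, value]` pair and is converted into a typed Python value when the config is loaded (the conversion lives in `convert_param` inside `lib/config_utils/configure_utils.py`, which is touched later in this diff). Below is a minimal, self-contained sketch of that convention; `parse_typed_value` is a hypothetical helper written only for illustration, not the repo's actual implementation:
```
# Hypothetical sketch of the "[type, value]" convention used by configs/archs/*.config;
# the real conversion is done by convert_param in lib/config_utils/configure_utils.py.
import json

def parse_typed_value(ctype, x):
  # Map a declared type string and raw JSON value onto a typed Python value.
  if ctype == 'str'  : return str(x)
  if ctype == 'int'  : return int(x)
  if ctype == 'float': return float(x)
  if ctype == 'bool' : return bool(int(x))   # "auxiliary": ["bool", 1] -> True
  if ctype == 'none' :                       # "genotype": ["none", "none"] -> None
    if str(x).lower() != 'none':
      raise ValueError('for the none type, the value must be none instead of {:}'.format(x))
    return None
  raise TypeError('Does not know this type : {:}'.format(ctype))

with open('configs/archs/NAS-CIFAR-none.config') as xfile:
  raw = json.load(xfile)
config = {key: parse_typed_value(ctype, value) for key, (ctype, value) in raw.items()}
print(config)  # e.g. {'super_type': 'infer-nasnet.cifar', 'genotype': None, 'ichannel': 33, ...}
```
Running this against `NAS-CIFAR-none.config` should yield a plain dict whose `genotype` entry is `None`, which is exactly the case the new `--extra_model_path` option is meant to cover.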
### Searching on a small search space (NAS-Bench-201) The GDAS searching codes on a small search space: diff --git a/exps/basic-main.py b/exps/basic-main.py index f3dda20..ecdf2d9 100644 --- a/exps/basic-main.py +++ b/exps/basic-main.py @@ -39,7 +39,9 @@ def main(args): if args.model_source == 'normal': base_model = obtain_model(model_config) elif args.model_source == 'nas': - base_model = obtain_nas_infer_model(model_config) + base_model = obtain_nas_infer_model(model_config, args.extra_model_path) + elif args.model_source == 'autodl-searched': + base_model = obtain_model(model_config, args.extra_model_path) else: raise ValueError('invalid model-source : {:}'.format(args.model_source)) flop, param = get_model_infos(base_model, xshape) diff --git a/lib/config_utils/basic_args.py b/lib/config_utils/basic_args.py index 22a414f..45d7430 100644 --- a/lib/config_utils/basic_args.py +++ b/lib/config_utils/basic_args.py @@ -12,6 +12,7 @@ def obtain_basic_args(): parser.add_argument('--optim_config', type=str, help='The path to the optimizer configuration') parser.add_argument('--procedure' , type=str, help='The procedure basic prefix.') parser.add_argument('--model_source', type=str, default='normal',help='The source of model defination.') + parser.add_argument('--extra_model_path', type=str, default=None, help='The path of the extra model checkpoint file (used to indicate the searched architecture).') add_shared_args( parser ) # Optimization options parser.add_argument('--batch_size', type=int, default=2, help='Batch size for training.') diff --git a/lib/config_utils/configure_utils.py b/lib/config_utils/configure_utils.py index bae2a35..ef96f71 100644 --- a/lib/config_utils/configure_utils.py +++ b/lib/config_utils/configure_utils.py @@ -29,7 +29,8 @@ def convert_param(original_lists): elif ctype == 'float': x = float(x) elif ctype == 'none': - assert x == 'None', 'for none type, the value must be None instead of {:}'.format(x) + if x.lower() != 'none': + raise ValueError('For the none type, the value must be none instead of {:}'.format(x)) x = None else: raise TypeError('Does not know this type : {:}'.format(ctype)) diff --git a/lib/models/__init__.py b/lib/models/__init__.py index 27581fa..7e98999 100644 --- a/lib/models/__init__.py +++ b/lib/models/__init__.py @@ -3,6 +3,7 @@ ################################################## from os import path as osp from typing import List, Text +import torch __all__ = ['change_key', 'get_cell_based_tiny_net', 'get_search_spaces', 'get_cifar_models', 'get_imagenet_models', \ 'obtain_model', 'obtain_search_model', 'load_net_from_checkpoint', \ @@ -38,6 +39,9 @@ def get_cell_based_tiny_net(config): genotype = CellStructure.str2structure(config.arch_str) else: raise ValueError('Can not find genotype from this config : {:}'.format(config)) return TinyNetwork(config.C, config.N, genotype, config.num_classes) + elif config.name == 'infer.nasnet-cifar': + from .cell_infers import NASNetonCIFAR + raise NotImplementedError else: raise ValueError('invalid network name : {:}'.format(config.name)) @@ -52,13 +56,12 @@ def get_search_spaces(xtype, name) -> List[Text]: raise ValueError('invalid search-space type is {:}'.format(xtype)) -def get_cifar_models(config): - from .CifarResNet import CifarResNet - from .CifarDenseNet import DenseNet - from .CifarWideResNet import CifarWideResNet - +def get_cifar_models(config, extra_path=None): super_type = getattr(config, 'super_type', 'basic') if super_type == 'basic': + from .CifarResNet import CifarResNet + from .CifarDenseNet import DenseNet 
+ from .CifarWideResNet import CifarWideResNet if config.arch == 'resnet': return CifarResNet(config.module, config.depth, config.class_num, config.zero_init_residual) elif config.arch == 'densenet': @@ -71,6 +74,7 @@ def get_cifar_models(config): from .shape_infers import InferWidthCifarResNet from .shape_infers import InferDepthCifarResNet from .shape_infers import InferCifarResNet + from .cell_infers import NASNetonCIFAR assert len(super_type.split('-')) == 2, 'invalid super_type : {:}'.format(super_type) infer_mode = super_type.split('-')[1] if infer_mode == 'width': @@ -79,6 +83,16 @@ def get_cifar_models(config): return InferDepthCifarResNet(config.module, config.depth, config.xblocks, config.class_num, config.zero_init_residual) elif infer_mode == 'shape': return InferCifarResNet(config.module, config.depth, config.xblocks, config.xchannels, config.class_num, config.zero_init_residual) + elif infer_mode == 'nasnet.cifar': + genotype = config.genotype + if extra_path is not None: # reload genotype by extra_path + if not osp.isfile(extra_path): raise ValueError('invalid extra_path : {:}'.format(extra_path)) + xdata = torch.load(extra_path) + current_epoch = xdata['epoch'] + genotype = xdata['genotypes'][current_epoch-1] + C = config.C if hasattr(config, 'C') else config.ichannel + N = config.N if hasattr(config, 'N') else config.layers + return NASNetonCIFAR(C, N, config.stem_multi, config.class_num, genotype, config.auxiliary) else: raise ValueError('invalid infer-mode : {:}'.format(infer_mode)) else: @@ -111,9 +125,10 @@ def get_imagenet_models(config): raise ValueError('invalid super-type : {:}'.format(super_type)) -def obtain_model(config): +# Try to obtain the network by config. +def obtain_model(config, extra_path=None): if config.dataset == 'cifar': - return get_cifar_models(config) + return get_cifar_models(config, extra_path) elif config.dataset == 'imagenet': return get_imagenet_models(config) else: @@ -152,7 +167,6 @@ def obtain_search_model(config): def load_net_from_checkpoint(checkpoint): - import torch assert osp.isfile(checkpoint), 'checkpoint {:} does not exist'.format(checkpoint) checkpoint = torch.load(checkpoint) model_config = dict2config(checkpoint['model-config'], None) diff --git a/lib/models/cell_infers/__init__.py b/lib/models/cell_infers/__init__.py index 052b477..ac1a183 100644 --- a/lib/models/cell_infers/__init__.py +++ b/lib/models/cell_infers/__init__.py @@ -2,3 +2,4 @@ # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 # ##################################################### from .tiny_network import TinyNetwork +from .nasnet_cifar import NASNetonCIFAR diff --git a/lib/models/cell_infers/cells.py b/lib/models/cell_infers/cells.py index bd94676..2dbb925 100644 --- a/lib/models/cell_infers/cells.py +++ b/lib/models/cell_infers/cells.py @@ -2,6 +2,7 @@ # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 # ##################################################### +import torch import torch.nn as nn from copy import deepcopy from ..cell_operations import OPS @@ -50,3 +51,70 @@ class InferCell(nn.Module): node_feature = sum( self.layers[_il](nodes[_ii]) for _il, _ii in zip(node_layers, node_innods) ) nodes.append( node_feature ) return nodes[-1] + + + +# Learning Transferable Architectures for Scalable Image Recognition, CVPR 2018 +class NASNetInferCell(nn.Module): + + def __init__(self, genotype, C_prev_prev, C_prev, C, reduction, reduction_prev, affine, track_running_stats): + super(NASNetInferCell, self).__init__() + self.reduction = reduction + if 
reduction_prev: self.preprocess0 = OPS['skip_connect'](C_prev_prev, C, 2, affine, track_running_stats) + else : self.preprocess0 = OPS['nor_conv_1x1'](C_prev_prev, C, 1, affine, track_running_stats) + self.preprocess1 = OPS['nor_conv_1x1'](C_prev, C, 1, affine, track_running_stats) + + if not reduction: + nodes, concats = genotype['normal'], genotype['normal_concat'] + else: + nodes, concats = genotype['reduce'], genotype['reduce_concat'] + self._multiplier = len(concats) + self._concats = concats + self._steps = len(nodes) + self._nodes = nodes + self.edges = nn.ModuleDict() + for i, node in enumerate(nodes): + for in_node in node: + name, j = in_node[0], in_node[1] + stride = 2 if reduction and j < 2 else 1 + node_str = '{:}<-{:}'.format(i+2, j) + self.edges[node_str] = OPS[name](C, C, stride, affine, track_running_stats) + + # [TODO] to support drop_prob in this function.. + def forward(self, s0, s1, unused_drop_prob): + s0 = self.preprocess0(s0) + s1 = self.preprocess1(s1) + + states = [s0, s1] + for i, node in enumerate(self._nodes): + clist = [] + for in_node in node: + name, j = in_node[0], in_node[1] + node_str = '{:}<-{:}'.format(i+2, j) + op = self.edges[ node_str ] + clist.append( op(states[j]) ) + states.append( sum(clist) ) + return torch.cat([states[x] for x in self._concats], dim=1) + + +class AuxiliaryHeadCIFAR(nn.Module): + + def __init__(self, C, num_classes): + """assuming input size 8x8""" + super(AuxiliaryHeadCIFAR, self).__init__() + self.features = nn.Sequential( + nn.ReLU(inplace=True), + nn.AvgPool2d(5, stride=3, padding=0, count_include_pad=False), # image size = 2 x 2 + nn.Conv2d(C, 128, 1, bias=False), + nn.BatchNorm2d(128), + nn.ReLU(inplace=True), + nn.Conv2d(128, 768, 2, bias=False), + nn.BatchNorm2d(768), + nn.ReLU(inplace=True) + ) + self.classifier = nn.Linear(768, num_classes) + + def forward(self, x): + x = self.features(x) + x = self.classifier(x.view(x.size(0),-1)) + return x diff --git a/lib/models/cell_infers/nasnet_cifar.py b/lib/models/cell_infers/nasnet_cifar.py new file mode 100644 index 0000000..20b0f82 --- /dev/null +++ b/lib/models/cell_infers/nasnet_cifar.py @@ -0,0 +1,71 @@ +##################################################### +# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 # +##################################################### +import torch +import torch.nn as nn +from copy import deepcopy +from .cells import NASNetInferCell as InferCell, AuxiliaryHeadCIFAR + + +# The macro structure is based on NASNet +class NASNetonCIFAR(nn.Module): + + def __init__(self, C, N, stem_multiplier, num_classes, genotype, auxiliary, affine=True, track_running_stats=True): + super(NASNetonCIFAR, self).__init__() + self._C = C + self._layerN = N + self.stem = nn.Sequential( + nn.Conv2d(3, C*stem_multiplier, kernel_size=3, padding=1, bias=False), + nn.BatchNorm2d(C*stem_multiplier)) + + # config for each layer + layer_channels = [C ] * N + [C*2 ] + [C*2 ] * (N-1) + [C*4 ] + [C*4 ] * (N-1) + layer_reductions = [False] * N + [True] + [False] * (N-1) + [True] + [False] * (N-1) + + C_prev_prev, C_prev, C_curr, reduction_prev = C*stem_multiplier, C*stem_multiplier, C, False + self.auxiliary_index = None + self.auxiliary_head = None + self.cells = nn.ModuleList() + for index, (C_curr, reduction) in enumerate(zip(layer_channels, layer_reductions)): + cell = InferCell(genotype, C_prev_prev, C_prev, C_curr, reduction, reduction_prev, affine, track_running_stats) + self.cells.append( cell ) + C_prev_prev, C_prev, reduction_prev = C_prev, cell._multiplier*C_curr, 
reduction + if reduction and C_curr == C*4 and auxiliary: + self.auxiliary_head = AuxiliaryHeadCIFAR(C_prev, num_classes) + self.auxiliary_index = index + self._Layer = len(self.cells) + self.lastact = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True)) + self.global_pooling = nn.AdaptiveAvgPool2d(1) + self.classifier = nn.Linear(C_prev, num_classes) + self.drop_path_prob = -1 + + def update_drop_path(self, drop_path_prob): + self.drop_path_prob = drop_path_prob + + def auxiliary_param(self): + if self.auxiliary_head is None: return [] + else: return list( self.auxiliary_head.parameters() ) + + def get_message(self): + string = self.extra_repr() + for i, cell in enumerate(self.cells): + string += '\n {:02d}/{:02d} :: {:}'.format(i, len(self.cells), cell.extra_repr()) + return string + + def extra_repr(self): + return ('{name}(C={_C}, N={_layerN}, L={_Layer})'.format(name=self.__class__.__name__, **self.__dict__)) + + def forward(self, inputs): + stem_feature, logits_aux = self.stem(inputs), None + cell_results = [stem_feature, stem_feature] + for i, cell in enumerate(self.cells): + cell_feature = cell(cell_results[-2], cell_results[-1], self.drop_path_prob) + cell_results.append( cell_feature ) + if self.auxiliary_index is not None and i == self.auxiliary_index and self.training: + logits_aux = self.auxiliary_head( cell_results[-1] ) + out = self.lastact(cell_results[-1]) + out = self.global_pooling( out ) + out = out.view(out.size(0), -1) + logits = self.classifier(out) + if logits_aux is None: return out, logits + else: return out, [logits, logits_aux] diff --git a/lib/models/cell_searchs/search_cells.py b/lib/models/cell_searchs/search_cells.py index ba397a5..5e10224 100644 --- a/lib/models/cell_searchs/search_cells.py +++ b/lib/models/cell_searchs/search_cells.py @@ -155,7 +155,7 @@ class NASNetSearchCell(nn.Module): self.edges = nn.ModuleDict() for i in range(self._steps): for j in range(2+i): - node_str = '{:}<-{:}'.format(i, j) + node_str = '{:}<-{:}'.format(i, j) # indicate the edge from node-(j) to node-(i+2) stride = 2 if reduction and j < 2 else 1 op = MixedOp(space, C, stride, affine, track_running_stats) self.edges[ node_str ] = op diff --git a/lib/models/cell_searchs/search_model_darts_nasnet.py b/lib/models/cell_searchs/search_model_darts_nasnet.py index 85c275c..e31fb10 100644 --- a/lib/models/cell_searchs/search_model_darts_nasnet.py +++ b/lib/models/cell_searchs/search_model_darts_nasnet.py @@ -5,8 +5,7 @@ import torch import torch.nn as nn from copy import deepcopy from typing import List, Text, Dict -from .search_cells import NASNetSearchCell as SearchCell -from .genotypes import Structure +from .search_cells import NASNetSearchCell as SearchCell # The macro structure is based on NASNet diff --git a/lib/models/cell_searchs/search_model_gdas_nasnet.py b/lib/models/cell_searchs/search_model_gdas_nasnet.py index 24edffd..2115ce4 100644 --- a/lib/models/cell_searchs/search_model_gdas_nasnet.py +++ b/lib/models/cell_searchs/search_model_gdas_nasnet.py @@ -4,8 +4,7 @@ import torch import torch.nn as nn from copy import deepcopy -from .search_cells import NASNetSearchCell as SearchCell -from .genotypes import Structure +from .search_cells import NASNetSearchCell as SearchCell # The macro structure is based on NASNet diff --git a/lib/nas_infer_model/DXYs/genotypes.py b/lib/nas_infer_model/DXYs/genotypes.py index ec1f449..d1b5c4d 100644 --- a/lib/nas_infer_model/DXYs/genotypes.py +++ b/lib/nas_infer_model/DXYs/genotypes.py @@ -168,5 +168,15 @@ Networks = {'DARTS_V1': 
DARTS_V1, 'SETN' : SETN, } +# This function will return a Genotype from a dict. def build_genotype_from_dict(xdict): - import pdb; pdb.set_trace() + def remove_value(nodes): + return [tuple([(x[0], x[1]) for x in node]) for node in nodes] + genotype = Genotype( + normal=remove_value(xdict['normal']), + normal_concat=xdict['normal_concat'], + reduce=remove_value(xdict['reduce']), + reduce_concat=xdict['reduce_concat'], + connectN=None, connects=None + ) + return genotype diff --git a/lib/nas_infer_model/__init__.py b/lib/nas_infer_model/__init__.py index ef20f89..d542401 100644 --- a/lib/nas_infer_model/__init__.py +++ b/lib/nas_infer_model/__init__.py @@ -6,12 +6,22 @@ # Currently, this package is used to reproduce the results in GDAS (Searching for A Robust Neural Architecture in Four GPU Hours, CVPR 2019). ################################################## -import torch +import os, torch -def obtain_nas_infer_model(config): +def obtain_nas_infer_model(config, extra_model_path=None): + if config.arch == 'dxys': from .DXYs import CifarNet, ImageNet, Networks - genotype = Networks[config.genotype] + from .DXYs import build_genotype_from_dict + if config.genotype is None: + if extra_model_path is None or not os.path.isfile(extra_model_path): + raise ValueError('When genotype in config is None, extra_model_path must be set as a path instead of {:}'.format(extra_model_path)) + xdata = torch.load(extra_model_path) + current_epoch = xdata['epoch'] + genotype_dict = xdata['genotypes'][current_epoch-1] + genotype = build_genotype_from_dict(genotype_dict) + else: + genotype = Networks[config.genotype] if config.dataset == 'cifar': return CifarNet(config.ichannel, config.layers, config.stem_multi, config.auxiliary, genotype, config.class_num) elif config.dataset == 'imagenet': diff --git a/scripts/nas-infer-train.sh b/scripts/nas-infer-train.sh index 65ef501..8b2eea2 100644 --- a/scripts/nas-infer-train.sh +++ b/scripts/nas-infer-train.sh @@ -4,7 +4,7 @@ echo script name: $0 echo $# arguments if [ "$#" -ne 4 ] ;then echo "Input illegal number of parameters " $# - echo "Need 4 parameters for dataset and the-model-name and epochs and LR and the-batch-size and the-random-seed" + echo "Need 4 parameters for dataset, the-model-name, the-batch-size and the-random-seed" exit 1 fi if [ "$TORCH_HOME" = "" ]; then diff --git a/scripts/retrain-searched-net.sh b/scripts/retrain-searched-net.sh new file mode 100644 index 0000000..b4f207d --- /dev/null +++ b/scripts/retrain-searched-net.sh @@ -0,0 +1,53 @@ +#!/bin/bash +# bash ./scripts/retrain-searched-net.sh cifar10 ${NAME} ${PATH} 256 -1 +echo script name: $0 +echo $# arguments +if [ "$#" -ne 5 ] ;then + echo "Input illegal number of parameters " $# + echo "Need 5 parameters for dataset, the save dir base name, the model path, the batch size, the random seed" + exit 1 +fi +if [ "$TORCH_HOME" = "" ]; then + echo "Must set TORCH_HOME environment variable for data dir saving" + exit 1 +else + echo "TORCH_HOME : $TORCH_HOME" +fi + +dataset=$1 +save_name=$2 +model_path=$3 +batch=$4 +rseed=$5 + +if [ ${dataset} == 'cifar10' ] || [ ${dataset} == 'cifar100' ]; then + xpath=$TORCH_HOME/cifar.python + base=CIFAR + workers=4 + cutout_length=16 +elif [ ${dataset} == 'imagenet-1k' ]; then + xpath=$TORCH_HOME/ILSVRC2012 + base=IMAGENET + workers=28 + cutout_length=-1 +else + echo 'Unknown dataset: '${dataset} + exit 1 +fi + +SAVE_ROOT="./output" + +save_dir=${SAVE_ROOT}/nas-infer/${dataset}-BS${batch}-${save_name} + +python --version + +python ./exps/basic-main.py 
--dataset ${dataset} \ + --data_path ${xpath} --model_source autodl-searched \ + --model_config ./configs/archs/NAS-${base}-none.config \ + --optim_config ./configs/opts/NAS-${base}.config \ + --extra_model_path ${model_path} \ + --procedure basic \ + --save_dir ${save_dir} \ + --cutout_length ${cutout_length} \ + --batch_size ${batch} --rand_seed ${rseed} --workers ${workers} \ + --eval_frequency 1 --print_freq 500 --print_freq_eval 1000
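For reference, the entry point above loads the searched cell from the checkpoint given via `--extra_model_path`. The following is a hedged sketch of how one could inspect such a checkpoint before re-training; the `epoch` and `genotypes` keys mirror the loading code added in `lib/models/__init__.py` and `lib/nas_infer_model/__init__.py`, and the checkpoint path is just the example from `docs/CVPR-2019-GDAS.md` — the exact contents depend on what the search run saved:
```
# Hedged sketch: peek at the search checkpoint that --extra_model_path points to.
# The 'epoch' / 'genotypes' keys mirror the loading code added in this diff; the
# path below is only the example used in docs/CVPR-2019-GDAS.md.
import torch

ckpt_path = 'output/search-cell-darts/GDAS-cifar10-BN1/checkpoint/seed-945-basic.pth'
xdata = torch.load(ckpt_path, map_location='cpu')
current_epoch = xdata['epoch']
genotype_dict = xdata['genotypes'][current_epoch - 1]   # the cell found at the last epoch
print('searched genotype at epoch {:}'.format(current_epoch))
for key in ('normal', 'normal_concat', 'reduce', 'reduce_concat'):
  print('  {:13s} : {:}'.format(key, genotype_dict[key]))
```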