update CVPR-2019-GDAS: re-train NASNet-search-space searched models

commit 9a83814a46 (parent: 8b6df42f1f)
configs/archs/NAS-CIFAR-none.config (new file, 10 lines)
@@ -0,0 +1,10 @@
+{
+  "super_type": ["str", "infer-nasnet.cifar"],
+  "genotype"  : ["none", "none"],
+  "dataset"   : ["str", "cifar"],
+  "ichannel"  : ["int", 33],
+  "layers"    : ["int", 6],
+  "stem_multi": ["int", 3],
+  "auxiliary" : ["bool", 1],
+  "drop_path_prob": ["float", 0.2]
+}
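Each value in these configs is a `["type", value]` pair rather than a bare literal; the `convert_param` routine patched later in this commit turns every pair into a typed Python value, and the `none` type maps to `None`, which is what lets `"genotype"` stay unresolved until a search checkpoint is supplied. A minimal sketch of the convention (the repository parses these files via its own `config_utils`; `parse_entry` here is a hypothetical stand-in):

```python
import json
from argparse import Namespace

def parse_entry(ctype, value):
  # Hypothetical helper mirroring the ["type", value] convention.
  if ctype == 'str'  : return str(value)
  if ctype == 'int'  : return int(value)
  if ctype == 'bool' : return bool(value)
  if ctype == 'float': return float(value)
  if ctype == 'none' : return None   # e.g. "genotype": ["none", "none"]
  raise TypeError('Does not know this type : {:}'.format(ctype))

with open('configs/archs/NAS-CIFAR-none.config') as xfile:
  raw = json.load(xfile)
config = Namespace(**{k: parse_entry(t, v) for k, (t, v) in raw.items()})
assert config.ichannel == 33 and config.genotype is None
```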
configs/archs/NAS-IMAGENET-none.config (new file, 9 lines)
@@ -0,0 +1,9 @@
+{
+  "super_type": ["str", "infer-nasnet.imagenet"],
+  "genotype"  : ["none", "none"],
+  "dataset"   : ["str", "imagenet"],
+  "ichannel"  : ["int", 50],
+  "layers"    : ["int", 4],
+  "auxiliary" : ["bool", 1],
+  "drop_path_prob": ["float", 0]
+}
@@ -41,7 +41,16 @@ Please use the following scripts to use GDAS to search as in the original paper:
 ```
 CUDA_VISIBLE_DEVICES=0 bash ./scripts-search/GDAS-search-NASNet-space.sh cifar10 1 -1
 ```
-If you want to train the searched architecture found by the above scripts, you need to add the config of that architecture (will be printed in log) in [genotypes.py](https://github.com/D-X-Y/AutoDL-Projects/blob/master/lib/nas_infer_model/DXYs/genotypes.py).
+
+**After searching**, if you want to re-train the searched architecture found by the above script, you can use the following script:
+```
+CUDA_VISIBLE_DEVICES=0 bash ./scripts/retrain-searched-net.sh cifar10 gdas-searched \
+	output/search-cell-darts/GDAS-cifar10-BN1/checkpoint/seed-945-basic.pth 96 -1
+```
+Note that `gdas-searched` is a string indicating the name of the save directory, and `output/search-cell-darts/GDAS-cifar10-BN1/checkpoint/seed-945-basic.pth` is the checkpoint file generated by the searching algorithm.
+
+The above script does not apply heavy augmentation when training the model, so the accuracy will be lower than that reported in the original paper.
+If you want to change the default hyper-parameters for re-training, please have a look at `./scripts/retrain-searched-net.sh` and `configs/archs/NAS-*-none.config`.
+
 ### Searching on a small search space (NAS-Bench-201)
 The GDAS searching codes on a small search space:
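The `.pth` file passed above is the raw search checkpoint, not a weights-only file; as the loading code later in this commit shows, it stores the epoch counter and a per-epoch history of genotypes. A quick way to inspect what will be re-trained (a sketch; it assumes only the `'epoch'`/`'genotypes'` keys used by this commit):

```python
import torch

ckp = 'output/search-cell-darts/GDAS-cifar10-BN1/checkpoint/seed-945-basic.pth'
xdata = torch.load(ckp, map_location='cpu')
current_epoch = xdata['epoch']
genotype_dict = xdata['genotypes'][current_epoch - 1]  # genotype found at the last epoch
print('searched for {:} epochs'.format(current_epoch))
print(genotype_dict)
```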
@@ -39,7 +39,9 @@ def main(args):
   if args.model_source == 'normal':
     base_model = obtain_model(model_config)
   elif args.model_source == 'nas':
-    base_model = obtain_nas_infer_model(model_config)
+    base_model = obtain_nas_infer_model(model_config, args.extra_model_path)
+  elif args.model_source == 'autodl-searched':
+    base_model = obtain_model(model_config, args.extra_model_path)
   else:
     raise ValueError('invalid model-source : {:}'.format(args.model_source))
   flop, param = get_model_infos(base_model, xshape)
@@ -12,6 +12,7 @@ def obtain_basic_args():
   parser.add_argument('--optim_config', type=str, help='The path to the optimizer configuration')
   parser.add_argument('--procedure'   , type=str, help='The procedure basic prefix.')
   parser.add_argument('--model_source', type=str, default='normal', help='The source of the model definition.')
+  parser.add_argument('--extra_model_path', type=str, default=None, help='The extra model ckp file (helps to indicate the searched architecture).')
   add_shared_args( parser )
   # Optimization options
   parser.add_argument('--batch_size', type=int, default=2, help='Batch size for training.')
@@ -29,7 +29,8 @@ def convert_param(original_lists):
     elif ctype == 'float':
       x = float(x)
     elif ctype == 'none':
-      assert x == 'None', 'for none type, the value must be None instead of {:}'.format(x)
+      if x.lower() != 'none':
+        raise ValueError('For the none type, the value must be none instead of {:}'.format(x))
       x = None
     else:
       raise TypeError('Does not know this type : {:}'.format(ctype))
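This hunk also relaxes the accepted spelling: the old assertion required the literal string `'None'`, while the new check accepts any casing of `none` and raises a proper `ValueError` rather than an `AssertionError` otherwise. A self-contained mimic of the new branch:

```python
def parse_none(x):
  # Mirrors the new 'none' branch of convert_param.
  if x.lower() != 'none':
    raise ValueError('For the none type, the value must be none instead of {:}'.format(x))
  return None

assert parse_none('none') is None
assert parse_none('None') is None   # previously the only accepted form
try:
  parse_none('null')
except ValueError as e:
  print(e)                          # now a ValueError, not an AssertionError
```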
@@ -3,6 +3,7 @@
 ##################################################
 from os import path as osp
 from typing import List, Text
+import torch
 
 __all__ = ['change_key', 'get_cell_based_tiny_net', 'get_search_spaces', 'get_cifar_models', 'get_imagenet_models', \
            'obtain_model', 'obtain_search_model', 'load_net_from_checkpoint', \
@@ -38,6 +39,9 @@ def get_cell_based_tiny_net(config):
       genotype = CellStructure.str2structure(config.arch_str)
     else: raise ValueError('Can not find genotype from this config : {:}'.format(config))
     return TinyNetwork(config.C, config.N, genotype, config.num_classes)
+  elif config.name == 'infer.nasnet-cifar':
+    from .cell_infers import NASNetonCIFAR
+    raise NotImplementedError
   else:
     raise ValueError('invalid network name : {:}'.format(config.name))
 
@@ -52,13 +56,12 @@ def get_search_spaces(xtype, name) -> List[Text]:
     raise ValueError('invalid search-space type is {:}'.format(xtype))
 
 
-def get_cifar_models(config):
-  from .CifarResNet import CifarResNet
-  from .CifarDenseNet import DenseNet
-  from .CifarWideResNet import CifarWideResNet
-
+def get_cifar_models(config, extra_path=None):
   super_type = getattr(config, 'super_type', 'basic')
   if super_type == 'basic':
+    from .CifarResNet import CifarResNet
+    from .CifarDenseNet import DenseNet
+    from .CifarWideResNet import CifarWideResNet
     if config.arch == 'resnet':
       return CifarResNet(config.module, config.depth, config.class_num, config.zero_init_residual)
     elif config.arch == 'densenet':
@@ -71,6 +74,7 @@ def get_cifar_models(config):
     from .shape_infers import InferWidthCifarResNet
     from .shape_infers import InferDepthCifarResNet
     from .shape_infers import InferCifarResNet
+    from .cell_infers import NASNetonCIFAR
     assert len(super_type.split('-')) == 2, 'invalid super_type : {:}'.format(super_type)
     infer_mode = super_type.split('-')[1]
     if infer_mode == 'width':
@@ -79,6 +83,16 @@ def get_cifar_models(config):
       return InferDepthCifarResNet(config.module, config.depth, config.xblocks, config.class_num, config.zero_init_residual)
     elif infer_mode == 'shape':
       return InferCifarResNet(config.module, config.depth, config.xblocks, config.xchannels, config.class_num, config.zero_init_residual)
+    elif infer_mode == 'nasnet.cifar':
+      genotype = config.genotype
+      if extra_path is not None: # reload genotype by extra_path
+        if not osp.isfile(extra_path): raise ValueError('invalid extra_path : {:}'.format(extra_path))
+        xdata = torch.load(extra_path)
+        current_epoch = xdata['epoch']
+        genotype = xdata['genotypes'][current_epoch-1]
+      C = config.C if hasattr(config, 'C') else config.ichannel
+      N = config.N if hasattr(config, 'N') else config.layers
+      return NASNetonCIFAR(C, N, config.stem_multi, config.class_num, genotype, config.auxiliary)
     else:
       raise ValueError('invalid infer-mode : {:}'.format(infer_mode))
   else:
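Putting the pieces together, the new `nasnet.cifar` branch lets a caller build the searched network straight from a search checkpoint. A sketch of the call path (the `dict2config` usage mirrors `load_net_from_checkpoint` below; both import paths are assumptions):

```python
from config_utils import dict2config   # import path is an assumption
from models import obtain_model        # import path is an assumption

# Field names/values mirror configs/archs/NAS-CIFAR-none.config.
model_config = dict2config({
  'super_type': 'infer-nasnet.cifar', 'dataset': 'cifar', 'class_num': 10,
  'genotype': None, 'ichannel': 33, 'layers': 6, 'stem_multi': 3,
  'auxiliary': True, 'drop_path_prob': 0.2}, None)
# genotype is None, so it is reloaded from the search checkpoint:
net = obtain_model(model_config,
                   'output/search-cell-darts/GDAS-cifar10-BN1/checkpoint/seed-945-basic.pth')
```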
@@ -111,9 +125,10 @@ def get_imagenet_models(config):
     raise ValueError('invalid super-type : {:}'.format(super_type))
 
 
-def obtain_model(config):
+# Try to obtain the network by config.
+def obtain_model(config, extra_path=None):
   if config.dataset == 'cifar':
-    return get_cifar_models(config)
+    return get_cifar_models(config, extra_path)
   elif config.dataset == 'imagenet':
     return get_imagenet_models(config)
   else:
@@ -152,7 +167,6 @@ def obtain_search_model(config):
 
 
 def load_net_from_checkpoint(checkpoint):
-  import torch
   assert osp.isfile(checkpoint), 'checkpoint {:} does not exist'.format(checkpoint)
   checkpoint = torch.load(checkpoint)
   model_config = dict2config(checkpoint['model-config'], None)
@@ -2,3 +2,4 @@
 # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
 #####################################################
 from .tiny_network import TinyNetwork
+from .nasnet_cifar import NASNetonCIFAR
@@ -2,6 +2,7 @@
 # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
 #####################################################
 
+import torch
 import torch.nn as nn
 from copy import deepcopy
 from ..cell_operations import OPS
@@ -50,3 +51,70 @@ class InferCell(nn.Module):
     node_feature = sum( self.layers[_il](nodes[_ii]) for _il, _ii in zip(node_layers, node_innods) )
     nodes.append( node_feature )
     return nodes[-1]
+
+
+
+# Learning Transferable Architectures for Scalable Image Recognition, CVPR 2018
+class NASNetInferCell(nn.Module):
+
+  def __init__(self, genotype, C_prev_prev, C_prev, C, reduction, reduction_prev, affine, track_running_stats):
+    super(NASNetInferCell, self).__init__()
+    self.reduction = reduction
+    if reduction_prev: self.preprocess0 = OPS['skip_connect'](C_prev_prev, C, 2, affine, track_running_stats)
+    else             : self.preprocess0 = OPS['nor_conv_1x1'](C_prev_prev, C, 1, affine, track_running_stats)
+    self.preprocess1 = OPS['nor_conv_1x1'](C_prev, C, 1, affine, track_running_stats)
+
+    if not reduction:
+      nodes, concats = genotype['normal'], genotype['normal_concat']
+    else:
+      nodes, concats = genotype['reduce'], genotype['reduce_concat']
+    self._multiplier = len(concats)
+    self._concats = concats
+    self._steps = len(nodes)
+    self._nodes = nodes
+    self.edges = nn.ModuleDict()
+    for i, node in enumerate(nodes):
+      for in_node in node:
+        name, j = in_node[0], in_node[1]
+        stride = 2 if reduction and j < 2 else 1
+        node_str = '{:}<-{:}'.format(i+2, j)
+        self.edges[node_str] = OPS[name](C, C, stride, affine, track_running_stats)
+
+  # [TODO] to support drop_prob in this function..
+  def forward(self, s0, s1, unused_drop_prob):
+    s0 = self.preprocess0(s0)
+    s1 = self.preprocess1(s1)
+
+    states = [s0, s1]
+    for i, node in enumerate(self._nodes):
+      clist = []
+      for in_node in node:
+        name, j = in_node[0], in_node[1]
+        node_str = '{:}<-{:}'.format(i+2, j)
+        op = self.edges[ node_str ]
+        clist.append( op(states[j]) )
+      states.append( sum(clist) )
+    return torch.cat([states[x] for x in self._concats], dim=1)
+
+
+class AuxiliaryHeadCIFAR(nn.Module):
+
+  def __init__(self, C, num_classes):
+    """assuming input size 8x8"""
+    super(AuxiliaryHeadCIFAR, self).__init__()
+    self.features = nn.Sequential(
+      nn.ReLU(inplace=True),
+      nn.AvgPool2d(5, stride=3, padding=0, count_include_pad=False), # image size = 2 x 2
+      nn.Conv2d(C, 128, 1, bias=False),
+      nn.BatchNorm2d(128),
+      nn.ReLU(inplace=True),
+      nn.Conv2d(128, 768, 2, bias=False),
+      nn.BatchNorm2d(768),
+      nn.ReLU(inplace=True)
+    )
+    self.classifier = nn.Linear(768, num_classes)
+
+  def forward(self, x):
+    x = self.features(x)
+    x = self.classifier(x.view(x.size(0),-1))
+    return x
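The `forward` above deliberately ignores its last argument (note the `[TODO]` comment and the `unused_drop_prob` name), so `drop_path_prob` has no effect inside the cell yet. For reference, the standard DARTS-style `drop_path` that a later revision could apply to each `op(states[j])` output looks like this (a sketch, not part of this commit):

```python
def drop_path(x, drop_prob):
  # Randomly zeroes a whole sample's branch and rescales the survivors,
  # so the expected activation magnitude is unchanged.
  if drop_prob > 0.:
    keep_prob = 1. - drop_prob
    mask = x.new_zeros(x.size(0), 1, 1, 1).bernoulli_(keep_prob)
    x = x.div(keep_prob).mul(mask)
  return x
```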
lib/models/cell_infers/nasnet_cifar.py (new file, 71 lines)
@@ -0,0 +1,71 @@
+#####################################################
+# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
+#####################################################
+import torch
+import torch.nn as nn
+from copy import deepcopy
+from .cells import NASNetInferCell as InferCell, AuxiliaryHeadCIFAR
+
+
+# The macro structure is based on NASNet
+class NASNetonCIFAR(nn.Module):
+
+  def __init__(self, C, N, stem_multiplier, num_classes, genotype, auxiliary, affine=True, track_running_stats=True):
+    super(NASNetonCIFAR, self).__init__()
+    self._C = C
+    self._layerN = N
+    self.stem = nn.Sequential(
+      nn.Conv2d(3, C*stem_multiplier, kernel_size=3, padding=1, bias=False),
+      nn.BatchNorm2d(C*stem_multiplier))
+
+    # config for each layer
+    layer_channels   = [C    ] * N + [C*2 ] + [C*2 ] * (N-1) + [C*4 ] + [C*4 ] * (N-1)
+    layer_reductions = [False] * N + [True] + [False] * (N-1) + [True] + [False] * (N-1)
+
+    C_prev_prev, C_prev, C_curr, reduction_prev = C*stem_multiplier, C*stem_multiplier, C, False
+    self.auxiliary_index = None
+    self.auxiliary_head  = None
+    self.cells = nn.ModuleList()
+    for index, (C_curr, reduction) in enumerate(zip(layer_channels, layer_reductions)):
+      cell = InferCell(genotype, C_prev_prev, C_prev, C_curr, reduction, reduction_prev, affine, track_running_stats)
+      self.cells.append( cell )
+      C_prev_prev, C_prev, reduction_prev = C_prev, cell._multiplier*C_curr, reduction
+      if reduction and C_curr == C*4 and auxiliary:
+        self.auxiliary_head  = AuxiliaryHeadCIFAR(C_prev, num_classes)
+        self.auxiliary_index = index
+    self._Layer = len(self.cells)
+    self.lastact = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True))
+    self.global_pooling = nn.AdaptiveAvgPool2d(1)
+    self.classifier = nn.Linear(C_prev, num_classes)
+    self.drop_path_prob = -1
+
+  def update_drop_path(self, drop_path_prob):
+    self.drop_path_prob = drop_path_prob
+
+  def auxiliary_param(self):
+    if self.auxiliary_head is None: return []
+    else: return list( self.auxiliary_head.parameters() )
+
+  def get_message(self):
+    string = self.extra_repr()
+    for i, cell in enumerate(self.cells):
+      string += '\n {:02d}/{:02d} :: {:}'.format(i, len(self.cells), cell.extra_repr())
+    return string
+
+  def extra_repr(self):
+    return ('{name}(C={_C}, N={_layerN}, L={_Layer})'.format(name=self.__class__.__name__, **self.__dict__))
+
+  def forward(self, inputs):
+    stem_feature, logits_aux = self.stem(inputs), None
+    cell_results = [stem_feature, stem_feature]
+    for i, cell in enumerate(self.cells):
+      cell_feature = cell(cell_results[-2], cell_results[-1], self.drop_path_prob)
+      cell_results.append( cell_feature )
+      if self.auxiliary_index is not None and i == self.auxiliary_index and self.training:
+        logits_aux = self.auxiliary_head( cell_results[-1] )
+    out = self.lastact(cell_results[-1])
+    out = self.global_pooling( out )
+    out = out.view(out.size(0), -1)
+    logits = self.classifier(out)
+    if logits_aux is None: return out, logits
+    else: return out, [logits, logits_aux]
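A quick smoke test of the new class (a sketch; the import path and the `genotype_dict` variable, a dict with `normal`/`normal_concat`/`reduce`/`reduce_concat` entries such as one reloaded from a search checkpoint as above, are assumptions):

```python
import torch
from models.cell_infers.nasnet_cifar import NASNetonCIFAR  # import path is an assumption

# C/N/stem_multiplier/auxiliary mirror configs/archs/NAS-CIFAR-none.config.
net = NASNetonCIFAR(C=33, N=6, stem_multiplier=3, num_classes=10,
                    genotype=genotype_dict, auxiliary=True)
net.update_drop_path(0.2)   # stored, but unused until the cell-level TODO is resolved
out, logits = net(torch.randn(2, 3, 32, 32))
# In train mode with auxiliary=True, `logits` is [logits, logits_aux];
# after net.eval(), it is a single tensor.
```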
@@ -155,7 +155,7 @@ class NASNetSearchCell(nn.Module):
     self.edges = nn.ModuleDict()
     for i in range(self._steps):
       for j in range(2+i):
-        node_str = '{:}<-{:}'.format(i, j)
+        node_str = '{:}<-{:}'.format(i, j) # indicate the edge from node-(j) to node-(i+2)
         stride = 2 if reduction and j < 2 else 1
         op = MixedOp(space, C, stride, affine, track_running_stats)
         self.edges[ node_str ] = op
@@ -5,8 +5,7 @@ import torch
 import torch.nn as nn
 from copy import deepcopy
 from typing import List, Text, Dict
 from .search_cells import NASNetSearchCell as SearchCell
-from .genotypes import Structure
 
 
 # The macro structure is based on NASNet
@@ -4,8 +4,7 @@
 import torch
 import torch.nn as nn
 from copy import deepcopy
 from .search_cells import NASNetSearchCell as SearchCell
-from .genotypes import Structure
 
 
 # The macro structure is based on NASNet
@@ -168,5 +168,15 @@ Networks = {'DARTS_V1': DARTS_V1,
             'SETN'    : SETN,
            }
 
+# This function will return a Genotype from a dict.
 def build_genotype_from_dict(xdict):
-  import pdb; pdb.set_trace()
+  def remove_value(nodes):
+    return [tuple([(x[0], x[1]) for x in node]) for node in nodes]
+  genotype = Genotype(
+    normal=remove_value(xdict['normal']),
+    normal_concat=xdict['normal_concat'],
+    reduce=remove_value(xdict['reduce']),
+    reduce_concat=xdict['reduce_concat'],
+    connectN=None, connects=None
+  )
+  return genotype
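The searched genotypes are logged as plain dicts whose node entries carry an extra per-edge value (e.g. an architecture weight); `remove_value` keeps only the `(op_name, input_index)` pairs that the cell constructors consume. A hypothetical input illustrating the assumed shape (op names are taken from the OPS table used elsewhere in this commit):

```python
# Hypothetical dict, shaped like the per-epoch entries of xdata['genotypes'].
xdict = {
  'normal': [[('skip_connect', 0, 0.13), ('nor_conv_1x1', 1, 0.12)],
             [('nor_conv_3x3', 0, 0.11), ('skip_connect', 2, 0.10)]],
  'normal_concat': [2, 3],
  'reduce': [[('nor_conv_3x3', 0, 0.14), ('skip_connect', 1, 0.09)],
             [('nor_conv_3x3', 1, 0.12), ('nor_conv_1x1', 2, 0.08)]],
  'reduce_concat': [2, 3],
}
genotype = build_genotype_from_dict(xdict)  # a Genotype with the values stripped
```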
@@ -6,12 +6,22 @@
 # Currently, this package is used to reproduce the results in GDAS (Searching for A Robust Neural Architecture in Four GPU Hours, CVPR 2019).
 ##################################################
 
-import torch
+import os, torch
 
-def obtain_nas_infer_model(config):
+def obtain_nas_infer_model(config, extra_model_path=None):
+
   if config.arch == 'dxys':
     from .DXYs import CifarNet, ImageNet, Networks
-    genotype = Networks[config.genotype]
+    from .DXYs import build_genotype_from_dict
+    if config.genotype is None:
+      if extra_model_path is None or not os.path.isfile(extra_model_path):
+        raise ValueError('When the genotype in the config is None, extra_model_path must be set as a path instead of {:}'.format(extra_model_path))
+      xdata = torch.load(extra_model_path)
+      current_epoch = xdata['epoch']
+      genotype_dict = xdata['genotypes'][current_epoch-1]
+      genotype = build_genotype_from_dict(genotype_dict)
+    else:
+      genotype = Networks[config.genotype]
     if config.dataset == 'cifar':
       return CifarNet(config.ichannel, config.layers, config.stem_multi, config.auxiliary, genotype, config.class_num)
     elif config.dataset == 'imagenet':
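After this change the function resolves a genotype in one of two ways: by name from the `Networks` table, or, when the config's genotype is `None`, by reloading the last searched genotype from the checkpoint. A sketch of both call sites (assuming `config` is a namespace parsed from `configs/archs/NAS-CIFAR-none.config`):

```python
# 1) Named genotype from the Networks table (e.g. the bundled 'DARTS_V1'):
config.genotype = 'DARTS_V1'
net = obtain_nas_infer_model(config)

# 2) genotype is None -> reload from the search checkpoint:
config.genotype = None
net = obtain_nas_infer_model(
    config, 'output/search-cell-darts/GDAS-cifar10-BN1/checkpoint/seed-945-basic.pth')
```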
@@ -4,7 +4,7 @@ echo script name: $0
 echo $# arguments
 if [ "$#" -ne 4 ] ;then
   echo "Input illegal number of parameters " $#
-  echo "Need 4 parameters for dataset and the-model-name and epochs and LR and the-batch-size and the-random-seed"
+  echo "Need 4 parameters for dataset, the-model-name, the-batch-size and the-random-seed"
   exit 1
 fi
 if [ "$TORCH_HOME" = "" ]; then
scripts/retrain-searched-net.sh (new file, 53 lines)
@@ -0,0 +1,53 @@
+#!/bin/bash
+# bash ./scripts/retrain-searched-net.sh cifar10 ${NAME} ${PATH} 256 -1
+echo script name: $0
+echo $# arguments
+if [ "$#" -ne 5 ] ;then
+  echo "Input illegal number of parameters " $#
+  echo "Need 5 parameters for dataset, the save dir base name, the model path, the batch size, the random seed"
+  exit 1
+fi
+if [ "$TORCH_HOME" = "" ]; then
+  echo "Must set TORCH_HOME environment variable for data dir saving"
+  exit 1
+else
+  echo "TORCH_HOME : $TORCH_HOME"
+fi
+
+dataset=$1
+save_name=$2
+model_path=$3
+batch=$4
+rseed=$5
+
+if [ ${dataset} == 'cifar10' ] || [ ${dataset} == 'cifar100' ]; then
+  xpath=$TORCH_HOME/cifar.python
+  base=CIFAR
+  workers=4
+  cutout_length=16
+elif [ ${dataset} == 'imagenet-1k' ]; then
+  xpath=$TORCH_HOME/ILSVRC2012
+  base=IMAGENET
+  workers=28
+  cutout_length=-1
+else
+  echo 'Unknown dataset: '${dataset}
+  exit 1
+fi
+
+SAVE_ROOT="./output"
+
+save_dir=${SAVE_ROOT}/nas-infer/${dataset}-BS${batch}-${save_name}
+
+python --version
+
+python ./exps/basic-main.py --dataset ${dataset} \
+	--data_path ${xpath} --model_source autodl-searched \
+	--model_config ./configs/archs/NAS-${base}-none.config \
+	--optim_config ./configs/opts/NAS-${base}.config \
+	--extra_model_path ${model_path} \
+	--procedure basic \
+	--save_dir ${save_dir} \
+	--cutout_length ${cutout_length} \
+	--batch_size ${batch} --rand_seed ${rseed} --workers ${workers} \
+	--eval_frequency 1 --print_freq 500 --print_freq_eval 1000