update codes
@@ -122,6 +122,12 @@ CUDA_VISIBLE_DEVICES=0 bash ./scripts/nas-infer-train.sh cifar100 GDAS_V1 96 -1
 CUDA_VISIBLE_DEVICES=0,1,2,3 bash ./scripts/nas-infer-train.sh imagenet-1k GDAS_V1 256 -1
 ```
 
+#### Searching on the NASNet search space
+Please use the following script to run the GDAS search as in the original paper:
+```
+CUDA_VISIBLE_DEVICES=0 bash ./scripts-search/GDAS-search-NASNet-space.sh cifar10 1 -1
+```
+
 #### Searching on a small search space (NAS-Bench-102)
 The GDAS searching codes on a small search space:
 ```
configs/search-archs/GDAS-NASNet-CIFAR.config (new file, 9 lines)
@@ -0,0 +1,9 @@
+{
+  "super_type"      : ["str",  "nasnet-super"],
+  "name"            : ["str",  "GDAS"],
+  "C"               : ["int",  "16"  ],
+  "N"               : ["int",  "2"  ],
+  "steps"           : ["int",  "4"  ],
+  "multiplier"      : ["int",  "4"  ],
+  "stem_multiplier" : ["int",  "3"  ]
+}
configs/search-opts/GDAS-NASNet-CIFAR.config (new file, 13 lines)
@@ -0,0 +1,13 @@
+{
+  "scheduler": ["str",   "cos"],
+  "LR"       : ["float", "0.025"],
+  "eta_min"  : ["float", "0.001"],
+  "epochs"   : ["int",   "250"],
+  "warmup"   : ["int",   "0"],
+  "optim"    : ["str",   "SGD"],
+  "decay"    : ["float", "0.0005"],
+  "momentum" : ["float", "0.9"],
+  "nesterov" : ["bool",  "1"],
+  "criterion": ["str",   "Softmax"],
+  "batch_size": ["int",  "256"]
+}
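With scheduler "cos" and warmup 0, the learning rate presumably decays from LR to eta_min over all 250 epochs. A quick check of the values such a schedule would produce, assuming the textbook cosine-annealing formula (an assumption; the repo's get_optim_scheduler may differ in detail):

```
import math

# Assumed cosine annealing: lr(t) = eta_min + 0.5*(LR - eta_min)*(1 + cos(pi*t/epochs))
LR, eta_min, epochs = 0.025, 0.001, 250  # values from the config above
for t in (0, 125, 250):
  lr = eta_min + 0.5 * (LR - eta_min) * (1 + math.cos(math.pi * t / epochs))
  print(t, round(lr, 4))  # 0 -> 0.025, 125 -> 0.013, 250 -> 0.001
```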
@@ -88,12 +88,17 @@ def main(xargs):
   logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config))
 
   search_space = get_search_spaces('cell', xargs.search_space_name)
-  model_config = dict2config({'name': 'GDAS', 'C': xargs.channel, 'N': xargs.num_cells,
-                              'max_nodes': xargs.max_nodes, 'num_classes': class_num,
-                              'space'    : search_space,
-                              'affine'   : False, 'track_running_stats': bool(xargs.track_running_stats)}, None)
+  if xargs.model_config is None:
+    model_config = dict2config({'name': 'GDAS', 'C': xargs.channel, 'N': xargs.num_cells,
+                                'max_nodes': xargs.max_nodes, 'num_classes': class_num,
+                                'space'    : search_space,
+                                'affine'   : False, 'track_running_stats': bool(xargs.track_running_stats)}, None)
+  else:
+    model_config = load_config(xargs.model_config, {'num_classes': class_num, 'space'    : search_space,
+                                                    'affine'     : False, 'track_running_stats': bool(xargs.track_running_stats)}, None)
   search_model = get_cell_based_tiny_net(model_config)
   logger.log('search-model :\n{:}'.format(search_model))
+  logger.log('model-config : {:}'.format(model_config))
 
   w_optimizer, w_scheduler, criterion = get_optim_scheduler(search_model.get_weights(), config)
   a_optimizer = torch.optim.Adam(search_model.get_alphas(), lr=xargs.arch_learning_rate, betas=(0.5, 0.999), weight_decay=xargs.arch_weight_decay)
@@ -104,7 +109,7 @@ def main(xargs):
   flop, param  = get_model_infos(search_model, xshape)
   #logger.log('{:}'.format(search_model))
   logger.log('FLOP = {:.2f} M, Params = {:.2f} MB'.format(flop, param))
-  logger.log('search-space : {:}'.format(search_space))
+  logger.log('search-space [{:} ops] : {:}'.format(len(search_space), search_space))
   if xargs.arch_nas_dataset is None:
     api = None
   else:
@@ -173,7 +178,7 @@ def main(xargs):
       logger.log('<<<--->>> The {:}-th epoch : find the highest validation accuracy : {:.2f}%.'.format(epoch_str, valid_a_top1))
       copy_checkpoint(model_base_path, model_best_path, logger)
     with torch.no_grad():
-      logger.log('arch-parameters :\n{:}'.format( nn.functional.softmax(search_model.arch_parameters, dim=-1).cpu() ))
+      logger.log('{:}'.format(search_model.show_alphas()))
     if api is not None: logger.log('{:}'.format(api.query_by_arch( genotypes[epoch] )))
     # measure elapsed time
     epoch_time.update(time.time() - start_time)
@@ -198,6 +203,7 @@ if __name__ == '__main__':
   parser.add_argument('--num_cells',          type=int,   help='The number of cells in one stage.')
   parser.add_argument('--track_running_stats',type=int,   choices=[0,1],help='Whether to use track_running_stats in the BN layers or not.')
   parser.add_argument('--config_path',        type=str,   help='The path of the configuration.')
+  parser.add_argument('--model_config',       type=str,   help='The path of the model configuration; when set, it overrides max_nodes / channels / num_cells.')
   # architecture learning rate
   parser.add_argument('--arch_learning_rate', type=float, default=3e-4, help='learning rate for arch encoding')
   parser.add_argument('--arch_weight_decay',  type=float, default=1e-3, help='weight decay for arch encoding')
@@ -13,20 +13,21 @@ from config_utils import dict2config
 from .SharedUtils import change_key
 from .cell_searchs import CellStructure, CellArchitectures
 
+
 # Cell-based NAS Models
 def get_cell_based_tiny_net(config):
   super_type = getattr(config, 'super_type', 'basic')
   group_names = ['DARTS-V1', 'DARTS-V2', 'GDAS', 'SETN', 'ENAS', 'RANDOM']
   if super_type == 'basic' and config.name in group_names:
-    from .cell_searchs import nas_super_nets
+    from .cell_searchs import nas102_super_nets as nas_super_nets
     try:
       return nas_super_nets[config.name](config.C, config.N, config.max_nodes, config.num_classes, config.space, config.affine, config.track_running_stats)
     except:
       return nas_super_nets[config.name](config.C, config.N, config.max_nodes, config.num_classes, config.space)
-  elif super_type == 'l2s-base' and config.name in group_names:
-    from .l2s_cell_searchs import nas_super_nets
-    return nas_super_nets[config.name](config.C, config.N, config.max_nodes, config.num_classes, config.space \
-                                      ,config.n_piece)
+  elif super_type == 'nasnet-super':
+    from .cell_searchs import nasnet_super_nets as nas_super_nets
+    return nas_super_nets[config.name](config.C, config.N, config.steps, config.multiplier, \
+                    config.stem_multiplier, config.num_classes, config.space, config.affine, config.track_running_stats)
   elif config.name == 'infer.tiny':
     from .cell_infers import TinyNetwork
     return TinyNetwork(config.C, config.N, config.genotype, config.num_classes)

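With the new config above, the 'nasnet-super' branch can be exercised directly. A sketch, assuming lib/ is on the Python path (the operation list here is an illustrative subset — the real one comes from get_search_spaces('cell', 'darts')):

```
from config_utils import dict2config
from models import get_cell_based_tiny_net

# Mirrors configs/search-archs/GDAS-NASNet-CIFAR.config; 'space' and
# 'num_classes' are injected at runtime, as in the GDAS search script.
model_config = dict2config({'super_type': 'nasnet-super', 'name': 'GDAS',
                            'C': 16, 'N': 2, 'steps': 4, 'multiplier': 4,
                            'stem_multiplier': 3, 'num_classes': 10,
                            'space': ['none', 'skip_connect', 'nor_conv_1x1'],
                            'affine': False, 'track_running_stats': True}, None)
network = get_cell_based_tiny_net(model_config)  # -> NASNetworkGDAS
```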
@@ -28,7 +28,6 @@ SearchSpaceNames = {'connect-nas'  : CONNECT_NAS_BENCHMARK,
                     'aa-nas'       : NAS_BENCH_102,
                     'nas-bench-102': NAS_BENCH_102,
                     'darts'        : DARTS_SPACE}
-                    #'full'         : sorted(list(OPS.keys()))}
 
 
 class ReLUConvBN(nn.Module):

@@ -1,16 +1,22 @@
 ##################################################
 # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
 ##################################################
+# The macro structure is defined in NAS-Bench-102
 from .search_model_darts    import TinyNetworkDarts
 from .search_model_gdas     import TinyNetworkGDAS
 from .search_model_setn     import TinyNetworkSETN
 from .search_model_enas     import TinyNetworkENAS
 from .search_model_random   import TinyNetworkRANDOM
 from .genotypes             import Structure as CellStructure, architectures as CellArchitectures
+# NASNet-based macro structure
+from .search_model_gdas_nasnet import NASNetworkGDAS
 
-nas_super_nets = {'DARTS-V1': TinyNetworkDarts,
+
+nas102_super_nets = {'DARTS-V1': TinyNetworkDarts,
                   'DARTS-V2': TinyNetworkDarts,
                   'GDAS'    : TinyNetworkGDAS,
                   'SETN'    : TinyNetworkSETN,
                   'ENAS'    : TinyNetworkENAS,
                   'RANDOM'  : TinyNetworkRANDOM}
+
+nasnet_super_nets = {'GDAS' : NASNetworkGDAS}

@@ -9,10 +9,11 @@ from copy import deepcopy
 from ..cell_operations import OPS
 
 
-class SearchCell(nn.Module):
+# This module is used for NAS-Bench-102; it represents a small search space with a complete DAG
+class NAS102SearchCell(nn.Module):
 
   def __init__(self, C_in, C_out, stride, max_nodes, op_names, affine=False, track_running_stats=True):
-    super(SearchCell, self).__init__()
+    super(NAS102SearchCell, self).__init__()
 
     self.op_names  = deepcopy(op_names)
     self.edges     = nn.ModuleDict()
@@ -74,7 +75,7 @@ class SearchCell(nn.Module):
       nodes.append( sum(inter_nodes) )
     return nodes[-1]
 
-  # uniform random sampling per iteration
+  # uniform random sampling per iteration, used by SETN
   def forward_urs(self, inputs):
     nodes = [inputs]
     for i in range(1, self.max_nodes):
@@ -118,3 +119,61 @@ class SearchCell(nn.Module):
         inter_nodes.append( self.edges[node_str][op_index]( nodes[j] ) )
       nodes.append( sum(inter_nodes) )
     return nodes[-1]
+
+
+
+class MixedOp(nn.Module):
+
+  def __init__(self, space, C, stride, affine, track_running_stats):
+    super(MixedOp, self).__init__()
+    self._ops = nn.ModuleList()
+    for primitive in space:
+      op = OPS[primitive](C, C, stride, affine, track_running_stats)
+      self._ops.append(op)
+
+  def forward(self, x, weights, index):
+    #return sum(w * op(x) for w, op in zip(weights, self._ops))
+    return self._ops[index](x) * weights[index]
+
+
+# Learning Transferable Architectures for Scalable Image Recognition, CVPR 2018
+class NASNetSearchCell(nn.Module):
+
+  def __init__(self, space, steps, multiplier, C_prev_prev, C_prev, C, reduction, reduction_prev, affine, track_running_stats):
+    super(NASNetSearchCell, self).__init__()
+    self.reduction = reduction
+    self.op_names  = deepcopy(space)
+    if reduction_prev: self.preprocess0 = OPS['skip_connect'](C_prev_prev, C, 2, affine, track_running_stats)
+    else             : self.preprocess0 = OPS['nor_conv_1x1'](C_prev_prev, C, 1, affine, track_running_stats)
+    self.preprocess1 = OPS['nor_conv_1x1'](C_prev, C, 1, affine, track_running_stats)
+    self._steps = steps
+    self._multiplier = multiplier
+
+    self._ops = nn.ModuleList()
+    self.edges     = nn.ModuleDict()
+    for i in range(self._steps):
+      for j in range(2+i):
+        node_str = '{:}<-{:}'.format(i, j)
+        stride = 2 if reduction and j < 2 else 1
+        op = MixedOp(space, C, stride, affine, track_running_stats)
+        self.edges[ node_str ] = op
+    self.edge_keys  = sorted(list(self.edges.keys()))
+    self.edge2index = {key:i for i, key in enumerate(self.edge_keys)}
+    self.num_edges  = len(self.edges)
+
+  def forward_gdas(self, s0, s1, weightss, indexs):
+    s0 = self.preprocess0(s0)
+    s1 = self.preprocess1(s1)
+
+    states = [s0, s1]
+    for i in range(self._steps):
+      clist = []
+      for j, h in enumerate(states):
+        node_str = '{:}<-{:}'.format(i, j)
+        op = self.edges[ node_str ]
+        weights = weightss[ self.edge2index[node_str] ]
+        index   = indexs[ self.edge2index[node_str] ].item()
+        clist.append( op(h, weights, index) )
+      states.append( sum(clist) )
+
+    return torch.cat(states[-self._multiplier:], dim=1)

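MixedOp.forward evaluates only the sampled candidate and scales it by weights[index]; this stays differentiable because forward_gdas receives straight-through Gumbel-softmax weights (see get_gumbel_prob in the new search_model_gdas_nasnet.py below): the weights are exactly one-hot in the forward pass, while their gradient is that of the soft probabilities. A minimal standalone sketch of that trick:

```
import torch

tau     = 10.0                                              # initial temperature, as in the model
logits  = torch.randn(5, requires_grad=True)                # one edge, five candidate ops
gumbels = -torch.empty_like(logits).exponential_().log()    # Gumbel(0, 1) noise
probs   = torch.softmax((logits.log_softmax(dim=-1) + gumbels) / tau, dim=-1)
index   = probs.argmax(dim=-1, keepdim=True)
one_h   = torch.zeros_like(probs).scatter_(-1, index, 1.0)
hardwts = one_h - probs.detach() + probs                    # value == one_h, gradient == probs'

op_outputs = torch.randn(5)                  # stand-ins for the candidate ops' outputs
loss = (hardwts * op_outputs).sum()          # only the sampled op contributes to the value
loss.backward()
assert torch.equal(hardwts.detach(), one_h)  # forward pass is exactly one-hot
assert logits.grad is not None               # yet the architecture logits get gradients
```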
@@ -7,7 +7,7 @@ import torch
 import torch.nn as nn
 from copy import deepcopy
 from ..cell_operations import ResNetBasicblock
-from .search_cells     import SearchCell
+from .search_cells     import NAS102SearchCell as SearchCell
 from .genotypes        import Structure
 
 
@@ -7,7 +7,7 @@ import torch
 import torch.nn as nn
 from copy import deepcopy
 from ..cell_operations import ResNetBasicblock
-from .search_cells     import SearchCell
+from .search_cells     import NAS102SearchCell as SearchCell
 from .genotypes        import Structure
 from .search_model_enas_utils import Controller
 
@@ -5,7 +5,7 @@ import torch
 import torch.nn as nn
 from copy import deepcopy
 from ..cell_operations import ResNetBasicblock
-from .search_cells     import SearchCell
+from .search_cells     import NAS102SearchCell as SearchCell
 from .genotypes        import Structure
 
@@ -59,6 +59,10 @@ class TinyNetworkGDAS(nn.Module):
   def get_alphas(self):
     return [self.arch_parameters]
 
+  def show_alphas(self):
+    with torch.no_grad():
+      return 'arch-parameters :\n{:}'.format( nn.functional.softmax(self.arch_parameters, dim=-1).cpu() )
+
   def get_message(self):
     string = self.extra_repr()
     for i, cell in enumerate(self.cells):

lib/models/cell_searchs/search_model_gdas_nasnet.py (new file, 126 lines)
@@ -0,0 +1,126 @@
+###########################################################################
+# Searching for A Robust Neural Architecture in Four GPU Hours, CVPR 2019 #
+###########################################################################
+import torch
+import torch.nn as nn
+from copy import deepcopy
+from .search_cells     import NASNetSearchCell as SearchCell
+from .genotypes        import Structure
+
+
+# The macro structure is based on NASNet
+class NASNetworkGDAS(nn.Module):
+
+  def __init__(self, C, N, steps, multiplier, stem_multiplier, num_classes, search_space, affine, track_running_stats):
+    super(NASNetworkGDAS, self).__init__()
+    self._C        = C
+    self._layerN   = N
+    self._steps    = steps
+    self._multiplier = multiplier
+    self.stem = nn.Sequential(
+                    nn.Conv2d(3, C*stem_multiplier, kernel_size=3, padding=1, bias=False),
+                    nn.BatchNorm2d(C*stem_multiplier))
+
+    # config for each layer
+    layer_channels   = [C    ] * N + [C*2 ] + [C*2  ] * (N-1) + [C*4 ] + [C*4  ] * (N-1)
+    layer_reductions = [False] * N + [True] + [False] * (N-1) + [True] + [False] * (N-1)
+
+    num_edge, edge2index = None, None
+    C_prev_prev, C_prev, C_curr, reduction_prev = C*stem_multiplier, C*stem_multiplier, C, False
+
+    self.cells = nn.ModuleList()
+    for index, (C_curr, reduction) in enumerate(zip(layer_channels, layer_reductions)):
+      cell = SearchCell(search_space, steps, multiplier, C_prev_prev, C_prev, C_curr, reduction, reduction_prev, affine, track_running_stats)
+      if num_edge is None: num_edge, edge2index = cell.num_edges, cell.edge2index
+      else: assert num_edge == cell.num_edges and edge2index == cell.edge2index, 'invalid {:} vs. {:}.'.format(num_edge, cell.num_edges)
+      self.cells.append( cell )
+      C_prev_prev, C_prev, reduction_prev = C_prev, multiplier*C_curr, reduction
+    self.op_names   = deepcopy( search_space )
+    self._Layer     = len(self.cells)
+    self.edge2index = edge2index
+    self.lastact    = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True))
+    self.global_pooling = nn.AdaptiveAvgPool2d(1)
+    self.classifier = nn.Linear(C_prev, num_classes)
+    self.arch_normal_parameters = nn.Parameter( 1e-3*torch.randn(num_edge, len(search_space)) )
+    self.arch_reduce_parameters = nn.Parameter( 1e-3*torch.randn(num_edge, len(search_space)) )
+    self.tau        = 10
+
+  def get_weights(self):
+    xlist = list( self.stem.parameters() ) + list( self.cells.parameters() )
+    xlist+= list( self.lastact.parameters() ) + list( self.global_pooling.parameters() )
+    xlist+= list( self.classifier.parameters() )
+    return xlist
+
+  def set_tau(self, tau):
+    self.tau = tau
+
+  def get_tau(self):
+    return self.tau
+
+  def get_alphas(self):
+    return [self.arch_normal_parameters, self.arch_reduce_parameters]
+
+  def show_alphas(self):
+    with torch.no_grad():
+      A = 'arch-normal-parameters :\n{:}'.format( nn.functional.softmax(self.arch_normal_parameters, dim=-1).cpu() )
+      B = 'arch-reduce-parameters :\n{:}'.format( nn.functional.softmax(self.arch_reduce_parameters, dim=-1).cpu() )
+    return '{:}\n{:}'.format(A, B)
+
+  def get_message(self):
+    string = self.extra_repr()
+    for i, cell in enumerate(self.cells):
+      string += '\n {:02d}/{:02d} :: {:}'.format(i, len(self.cells), cell.extra_repr())
+    return string
+
+  def extra_repr(self):
+    return ('{name}(C={_C}, N={_layerN}, steps={_steps}, multiplier={_multiplier}, L={_Layer})'.format(name=self.__class__.__name__, **self.__dict__))
+
+  def genotype(self):
+    def _parse(weights):
+      gene = []
+      for i in range(self._steps):
+        edges = []
+        for j in range(2+i):
+          node_str = '{:}<-{:}'.format(i, j)
+          ws = weights[ self.edge2index[node_str] ]
+          for k, op_name in enumerate(self.op_names):
+            if op_name == 'none': continue
+            edges.append( (op_name, j, ws[k]) )
+        edges = sorted(edges, key=lambda x: -x[-1])
+        selected_edges = edges[:2]
+        gene.append( tuple(selected_edges) )
+      return gene
+    with torch.no_grad():
+      gene_normal = _parse(torch.softmax(self.arch_normal_parameters, dim=-1).cpu().numpy())
+      gene_reduce = _parse(torch.softmax(self.arch_reduce_parameters, dim=-1).cpu().numpy())
+    return {'normal': gene_normal, 'normal_concat': list(range(2+self._steps-self._multiplier, self._steps+2)),
+            'reduce': gene_reduce, 'reduce_concat': list(range(2+self._steps-self._multiplier, self._steps+2))}
+
+  def forward(self, inputs):
+    def get_gumbel_prob(xins):
+      while True:
+        gumbels = -torch.empty_like(xins).exponential_().log()
+        logits  = (xins.log_softmax(dim=1) + gumbels) / self.tau
+        probs   = nn.functional.softmax(logits, dim=1)
+        index   = probs.max(-1, keepdim=True)[1]
+        one_h   = torch.zeros_like(logits).scatter_(-1, index, 1.0)
+        hardwts = one_h - probs.detach() + probs
+        if (torch.isinf(gumbels).any()) or (torch.isinf(probs).any()) or (torch.isnan(probs).any()):
+          continue
+        else: break
+      return hardwts, index
+
+    normal_hardwts, normal_index = get_gumbel_prob(self.arch_normal_parameters)
+    reduce_hardwts, reduce_index = get_gumbel_prob(self.arch_reduce_parameters)
+
+    s0 = s1 = self.stem(inputs)
+    for i, cell in enumerate(self.cells):
+      if cell.reduction: hardwts, index = reduce_hardwts, reduce_index
+      else             : hardwts, index = normal_hardwts, normal_index
+      s0, s1 = s1, cell.forward_gdas(s0, s1, hardwts, index)
+    out = self.lastact(s1)
+    out = self.global_pooling( out )
+    out = out.view(out.size(0), -1)
+    logits = self.classifier(out)
+
+    return out, logits
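For the shipped config (C=16, N=2), the macro skeleton above expands to six cells, doubling channels at each of the two reduction cells. A quick sanity check of those two list comprehensions:

```
C, N = 16, 2  # values from GDAS-NASNet-CIFAR.config
layer_channels   = [C    ] * N + [C*2 ] + [C*2  ] * (N-1) + [C*4 ] + [C*4  ] * (N-1)
layer_reductions = [False] * N + [True] + [False] * (N-1) + [True] + [False] * (N-1)
print(layer_channels)    # [16, 16, 32, 32, 64, 64]
print(layer_reductions)  # [False, False, True, False, True, False]
```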
@@ -7,7 +7,7 @@ import torch, random
 import torch.nn as nn
 from copy import deepcopy
 from ..cell_operations import ResNetBasicblock
-from .search_cells     import SearchCell
+from .search_cells     import NAS102SearchCell as SearchCell
 from .genotypes        import Structure
 
 
@@ -7,7 +7,7 @@ import torch, random
 import torch.nn as nn
 from copy import deepcopy
 from ..cell_operations import ResNetBasicblock
-from .search_cells     import SearchCell
+from .search_cells     import NAS102SearchCell as SearchCell
 from .genotypes        import Structure
 
 
scripts-search/GDAS-search-NASNet-space.sh (new file, 38 lines)
@@ -0,0 +1,38 @@
+#!/bin/bash
+# bash ./scripts-search/GDAS-search-NASNet-space.sh cifar10 1 -1
+echo script name: $0
+echo $# arguments
+if [ "$#" -ne 3 ] ;then
+  echo "Illegal number of parameters: $#"
+  echo "Need 3 parameters: dataset, track_running_stats, and seed"
+  exit 1
+fi
+if [ "$TORCH_HOME" = "" ]; then
+  echo "Must set the TORCH_HOME environment variable for the data directory"
+  exit 1
+else
+  echo "TORCH_HOME : $TORCH_HOME"
+fi
+
+dataset=$1
+track_running_stats=$2
+seed=$3
+space=darts
+
+if [ "$dataset" == "cifar10" ] || [ "$dataset" == "cifar100" ]; then
+  data_path="$TORCH_HOME/cifar.python"
+else
+  data_path="$TORCH_HOME/cifar.python/ImageNet16"
+fi
+
+save_dir=./output/search-cell-${space}/GDAS-${dataset}-BN${track_running_stats}
+
+OMP_NUM_THREADS=4 python ./exps/algos/GDAS.py \
+	--save_dir ${save_dir} \
+	--dataset ${dataset} --data_path ${data_path} \
+	--search_space_name ${space} \
+	--config_path  configs/search-opts/GDAS-NASNet-CIFAR.config \
+	--model_config configs/search-archs/GDAS-NASNet-CIFAR.config \
+	--tau_max 10 --tau_min 0.1 --track_running_stats ${track_running_stats} \
+	--arch_learning_rate 0.0003 --arch_weight_decay 0.001 \
+	--workers 4 --print_freq 200 --rand_seed ${seed}