beta-0.1
This commit is contained in:
		| @@ -3,10 +3,16 @@ | ||||
| ################################################## | ||||
| import torch | ||||
| from os import path as osp | ||||
|  | ||||
| __all__ = ['change_key', 'get_cell_based_tiny_net', 'get_search_spaces', 'get_cifar_models', 'get_imagenet_models', \ | ||||
|            'obtain_model', 'obtain_search_model', 'load_net_from_checkpoint', \ | ||||
|            'CellStructure', 'CellArchitectures' | ||||
|            ] | ||||
|  | ||||
| # useful modules | ||||
| from config_utils import dict2config | ||||
| from .SharedUtils import change_key | ||||
| from .clone_weights import init_from_model | ||||
| from .cell_searchs import CellStructure, CellArchitectures | ||||
|  | ||||
| # Cell-based NAS Models | ||||
| def get_cell_based_tiny_net(config): | ||||
| @@ -22,9 +28,13 @@ def get_cell_based_tiny_net(config): | ||||
|   elif config.name == 'SETN': | ||||
|     from .cell_searchs import TinyNetworkSETN | ||||
|     return TinyNetworkSETN(config.C, config.N, config.max_nodes, config.num_classes, config.space) | ||||
|   elif config.name == 'infer.tiny': | ||||
|     from .cell_infers import TinyNetwork | ||||
|     return TinyNetwork(config.C, config.N, config.genotype, config.num_classes) | ||||
|   else: | ||||
|     raise ValueError('invalid network name : {:}'.format(config.name)) | ||||
|  | ||||
|  | ||||
| # obtain the search space, i.e., a dict mapping the operation name into a python-function for this op | ||||
| def get_search_spaces(xtype, name): | ||||
|   if xtype == 'cell': | ||||
|   | ||||
							
								
								
									
										1
									
								
								lib/models/cell_infers/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										1
									
								
								lib/models/cell_infers/__init__.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1 @@ | ||||
| from .tiny_network import TinyNetwork | ||||
							
								
								
									
										51
									
								
								lib/models/cell_infers/cells.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										51
									
								
								lib/models/cell_infers/cells.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,51 @@ | ||||
| ################################################## | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 # | ||||
| ################################################## | ||||
| import torch | ||||
| import torch.nn as nn | ||||
| from copy import deepcopy | ||||
| from ..cell_operations import OPS | ||||
|  | ||||
|  | ||||
| class InferCell(nn.Module): | ||||
|  | ||||
|   def __init__(self, genotype, C_in, C_out, stride): | ||||
|     super(InferCell, self).__init__() | ||||
|  | ||||
|     self.layers  = nn.ModuleList() | ||||
|     self.node_IN = [] | ||||
|     self.node_IX = [] | ||||
|     self.genotype = deepcopy(genotype) | ||||
|     for i in range(1, len(genotype)): | ||||
|       node_info = genotype[i-1] | ||||
|       cur_index = [] | ||||
|       cur_innod = [] | ||||
|       for (op_name, op_in) in node_info: | ||||
|         if op_in == 0: | ||||
|           layer = OPS[op_name](C_in , C_out, stride) | ||||
|         else: | ||||
|           layer = OPS[op_name](C_out, C_out,      1) | ||||
|         cur_index.append( len(self.layers) ) | ||||
|         cur_innod.append( op_in ) | ||||
|         self.layers.append( layer ) | ||||
|       self.node_IX.append( cur_index ) | ||||
|       self.node_IN.append( cur_innod ) | ||||
|     self.nodes   = len(genotype) | ||||
|     self.in_dim  = C_in | ||||
|     self.out_dim = C_out | ||||
|  | ||||
|   def extra_repr(self): | ||||
|     string = 'info :: nodes={nodes}, inC={in_dim}, outC={out_dim}'.format(**self.__dict__) | ||||
|     laystr = [] | ||||
|     for i, (node_layers, node_innods) in enumerate(zip(self.node_IX,self.node_IN)): | ||||
|       y = ['I{:}-L{:}'.format(_ii, _il) for _il, _ii in zip(node_layers, node_innods)] | ||||
|       x = '{:}<-({:})'.format(i+1, ','.join(y)) | ||||
|       laystr.append( x ) | ||||
|     return string + ', [{:}]'.format( ' | '.join(laystr) ) + ', {:}'.format(self.genotype.tostr()) | ||||
|  | ||||
|   def forward(self, inputs): | ||||
|     nodes = [inputs] | ||||
|     for i, (node_layers, node_innods) in enumerate(zip(self.node_IX,self.node_IN)): | ||||
|       node_feature = sum( self.layers[_il](nodes[_ii]) for _il, _ii in zip(node_layers, node_innods) ) | ||||
|       nodes.append( node_feature ) | ||||
|     return nodes[-1] | ||||
							
								
								
									
										58
									
								
								lib/models/cell_infers/tiny_network.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										58
									
								
								lib/models/cell_infers/tiny_network.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,58 @@ | ||||
| ################################################## | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 # | ||||
| ################################################## | ||||
| import torch | ||||
| import torch.nn as nn | ||||
| from ..cell_operations import ResNetBasicblock | ||||
| from .cells import InferCell | ||||
|  | ||||
|  | ||||
| class TinyNetwork(nn.Module): | ||||
|  | ||||
|   def __init__(self, C, N, genotype, num_classes): | ||||
|     super(TinyNetwork, self).__init__() | ||||
|     self._C               = C | ||||
|     self._layerN          = N | ||||
|  | ||||
|     self.stem = nn.Sequential( | ||||
|                     nn.Conv2d(3, C, kernel_size=3, padding=1, bias=False), | ||||
|                     nn.BatchNorm2d(C)) | ||||
|    | ||||
|     layer_channels   = [C    ] * N + [C*2 ] + [C*2  ] * N + [C*4 ] + [C*4  ] * N     | ||||
|     layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N | ||||
|  | ||||
|     C_prev = C | ||||
|     self.cells = nn.ModuleList() | ||||
|     for index, (C_curr, reduction) in enumerate(zip(layer_channels, layer_reductions)): | ||||
|       if reduction: | ||||
|         cell = ResNetBasicblock(C_prev, C_curr, 2) | ||||
|       else: | ||||
|         cell = InferCell(genotype, C_prev, C_curr, 1) | ||||
|       self.cells.append( cell ) | ||||
|       C_prev = cell.out_dim | ||||
|     self._Layer= len(self.cells) | ||||
|  | ||||
|     self.lastact = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True)) | ||||
|     self.global_pooling = nn.AdaptiveAvgPool2d(1) | ||||
|     self.classifier = nn.Linear(C_prev, num_classes) | ||||
|  | ||||
|   def get_message(self): | ||||
|     string = self.extra_repr() | ||||
|     for i, cell in enumerate(self.cells): | ||||
|       string += '\n {:02d}/{:02d} :: {:}'.format(i, len(self.cells), cell.extra_repr()) | ||||
|     return string | ||||
|  | ||||
|   def extra_repr(self): | ||||
|     return ('{name}(C={_C}, N={_layerN}, L={_Layer})'.format(name=self.__class__.__name__, **self.__dict__)) | ||||
|  | ||||
|   def forward(self, inputs): | ||||
|     feature = self.stem(inputs) | ||||
|     for i, cell in enumerate(self.cells): | ||||
|       feature = cell(feature) | ||||
|  | ||||
|     out = self.lastact(feature) | ||||
|     out = self.global_pooling( out ) | ||||
|     out = out.view(out.size(0), -1) | ||||
|     logits = self.classifier(out) | ||||
|  | ||||
|     return out, logits | ||||
| @@ -17,7 +17,8 @@ CONNECT_NAS_BENCHMARK  = ['none', 'skip_connect', 'nor_conv_3x3'] | ||||
| AA_NAS_BENCHMARK       = ['none', 'skip_connect', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3'] | ||||
|  | ||||
| SearchSpaceNames = {'connect-nas' : CONNECT_NAS_BENCHMARK, | ||||
|                     'aa-nas'      : AA_NAS_BENCHMARK} | ||||
|                     'aa-nas'      : AA_NAS_BENCHMARK, | ||||
|                     'full'        : sorted(list(OPS.keys()))} | ||||
|  | ||||
|  | ||||
| class ReLUConvBN(nn.Module): | ||||
|   | ||||
| @@ -2,3 +2,4 @@ from .search_model_darts_v1 import TinyNetworkDartsV1 | ||||
| from .search_model_darts_v2 import TinyNetworkDartsV2 | ||||
| from .search_model_gdas     import TinyNetworkGDAS | ||||
| from .search_model_setn     import TinyNetworkSETN | ||||
| from .genotypes             import Structure as CellStructure, architectures as CellArchitectures | ||||
|   | ||||
| @@ -60,6 +60,13 @@ class Structure: | ||||
|       strings.append( string ) | ||||
|     return '+'.join(strings) | ||||
|  | ||||
|   def check_valid_op(self, op_names): | ||||
|     for node_info in self.nodes: | ||||
|       for inode_edge in node_info: | ||||
|         #assert inode_edge[0] in op_names, 'invalid op-name : {:}'.format(inode_edge[0]) | ||||
|         if inode_edge[0] not in op_names: return False | ||||
|     return True | ||||
|  | ||||
|   def __repr__(self): | ||||
|     return ('{name}({node_num} nodes with {node_info})'.format(name=self.__class__.__name__, node_info=self.tostr(), **self.__dict__)) | ||||
|  | ||||
|   | ||||
		Reference in New Issue
	
	Block a user