change batch size in DARTS-NASNet to 64; add some type annotations
@@ -9,5 +9,5 @@
   "momentum" : ["float", "0.9"],
   "nesterov" : ["bool",  "1"],
   "criterion": ["str",   "Softmax"],
-  "batch_size": ["int",  "256"]
+  "batch_size": ["int",  "64"]
 }
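Each hyperparameter in this config file is stored as a `["type", "value"]` pair of strings, so whatever loads the file must cast each value according to its type tag before use. Below is a minimal sketch of such a caster; the helper and file names are hypothetical, not the repository's actual loader:

```python
import json

# Map type tags to casting functions; "1"/"0" encode booleans in this format.
_CASTS = {'int': int, 'float': float, 'str': str,
          'bool': lambda s: bool(int(s))}

def convert_pair(pair):
  """Cast a ["type", "value"] pair into a native Python value."""
  xtype, value = pair
  assert xtype in _CASTS, 'unknown type tag: {:}'.format(xtype)
  return _CASTS[xtype](value)

with open('DARTS-NASNet-CIFAR.config') as xfile:  # hypothetical file name
  raw = json.load(xfile)
options = {key: convert_pair(pair) for key, pair in raw.items()}
assert options['batch_size'] == 64 and options['nesterov'] is True
```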
@@ -2,6 +2,7 @@
 # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
 ##################################################
 from os import path as osp
+from typing import List, Text
 
 __all__ = ['change_key', 'get_cell_based_tiny_net', 'get_search_spaces', 'get_cifar_models', 'get_imagenet_models', \
            'obtain_model', 'obtain_search_model', 'load_net_from_checkpoint', \
@@ -42,7 +43,7 @@ def get_cell_based_tiny_net(config):
 
 
 # obtain the search space, i.e., a dict mapping the operation name into a python-function for this op
-def get_search_spaces(xtype, name):
+def get_search_spaces(xtype, name) -> List[Text]:
   if xtype == 'cell':
     from .cell_operations import SearchSpaceNames
     assert name in SearchSpaceNames, 'invalid name [{:}] in {:}'.format(name, SearchSpaceNames.keys())
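The new `-> List[Text]` annotation documents that, for `xtype == 'cell'`, the function returns the list of operation names registered under `name` in `SearchSpaceNames` (despite the older inline comment mentioning a dict). `typing.Text` is simply an alias of `str` on Python 3, commonly used in code that also targets Python 2. A usage sketch, assuming the repo's `lib` directory is on `PYTHONPATH` and that `'darts'` is one of the registered keys:

```python
from models import get_search_spaces

# Query the operation names of a registered cell-based search space.
op_names = get_search_spaces('cell', 'darts')
assert isinstance(op_names, list) and all(isinstance(op, str) for op in op_names)
print('search space with {:} operations: {:}'.format(len(op_names), op_names))
```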
@@ -4,6 +4,7 @@
 import torch
 import torch.nn as nn
 from copy import deepcopy
+from typing import List, Text, Dict
 from .search_cells     import NASNetSearchCell as SearchCell
 from .genotypes        import Structure
 
@@ -11,7 +12,7 @@ from .genotypes        import Structure
 # The macro structure is based on NASNet
 class NASNetworkDARTS(nn.Module):
 
-  def __init__(self, C, N, steps, multiplier, stem_multiplier, num_classes, search_space, affine, track_running_stats):
+  def __init__(self, C: int, N: int, steps: int, multiplier: int, stem_multiplier: int, num_classes: int, search_space: List[Text], affine: bool, track_running_stats: bool):
     super(NASNetworkDARTS, self).__init__()
     self._C        = C
     self._layerN   = N
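Annotating `__init__` lets a static checker such as mypy flag mismatched call sites, e.g. passing a string for `C` or a bare string instead of a list for `search_space`. For reference, a construction sketch; all numeric values are illustrative DARTS-style assumptions, not values taken from this commit:

```python
# Assumes get_search_spaces and NASNetworkDARTS are importable as above.
search_space = get_search_spaces('cell', 'darts')
network = NASNetworkDARTS(
    C=16, N=2, steps=4, multiplier=4, stem_multiplier=3,  # illustrative sizes
    num_classes=10, search_space=search_space,
    affine=False, track_running_stats=True)
```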
@@ -44,31 +45,31 @@ class NASNetworkDARTS(nn.Module):
     self.arch_normal_parameters = nn.Parameter( 1e-3*torch.randn(num_edge, len(search_space)) )
     self.arch_reduce_parameters = nn.Parameter( 1e-3*torch.randn(num_edge, len(search_space)) )
 
-  def get_weights(self):
+  def get_weights(self) -> List[torch.nn.Parameter]:
     xlist = list( self.stem.parameters() ) + list( self.cells.parameters() )
     xlist+= list( self.lastact.parameters() ) + list( self.global_pooling.parameters() )
     xlist+= list( self.classifier.parameters() )
     return xlist
 
-  def get_alphas(self):
+  def get_alphas(self) -> List[torch.nn.Parameter]:
     return [self.arch_normal_parameters, self.arch_reduce_parameters]
 
-  def show_alphas(self):
+  def show_alphas(self) -> Text:
     with torch.no_grad():
       A = 'arch-normal-parameters :\n{:}'.format( nn.functional.softmax(self.arch_normal_parameters, dim=-1).cpu() )
       B = 'arch-reduce-parameters :\n{:}'.format( nn.functional.softmax(self.arch_reduce_parameters, dim=-1).cpu() )
     return '{:}\n{:}'.format(A, B)
 
-  def get_message(self):
+  def get_message(self) -> Text:
     string = self.extra_repr()
     for i, cell in enumerate(self.cells):
       string += '\n {:02d}/{:02d} :: {:}'.format(i, len(self.cells), cell.extra_repr())
     return string
 
-  def extra_repr(self):
+  def extra_repr(self) -> Text:
     return ('{name}(C={_C}, N={_layerN}, steps={_steps}, multiplier={_multiplier}, L={_Layer})'.format(name=self.__class__.__name__, **self.__dict__))
 
-  def genotype(self):
+  def genotype(self) -> Dict[Text, List]:
     def _parse(weights):
       gene = []
       for i in range(self._steps):
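The `get_weights()` / `get_alphas()` split exists because DARTS optimizes the two parameter groups alternately: network weights on the training split, architecture parameters (the alphas) on the validation split. A minimal sketch of that alternating loop, reusing `network` from the sketch above; `train_loader`/`valid_loader` and the optimizer hyper-parameters are assumptions, and the forward pass is assumed to return a `(features, logits)` pair:

```python
import torch

w_optimizer = torch.optim.SGD(network.get_weights(), lr=0.025,
                              momentum=0.9, weight_decay=3e-4)
a_optimizer = torch.optim.Adam(network.get_alphas(), lr=3e-4, weight_decay=1e-3)
criterion = torch.nn.CrossEntropyLoss()

for (train_x, train_y), (valid_x, valid_y) in zip(train_loader, valid_loader):
  # 1) update the architecture parameters on a validation batch
  a_optimizer.zero_grad()
  _, logits = network(valid_x)           # assumed (features, logits) output
  criterion(logits, valid_y).backward()
  a_optimizer.step()
  # 2) update the network weights on a training batch
  w_optimizer.zero_grad()
  _, logits = network(train_x)
  criterion(logits, train_y).backward()
  w_optimizer.step()

print(network.show_alphas())             # softmax-normalized alphas, now typed as Text
```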
@@ -37,9 +37,12 @@ def print_information(information, extra_info=None, show=False):
   if show: print('\n'.join(strings))
   return strings
 
+"""
+This is the class for API of NAS-Bench-201.
+"""
 class NASBench201API(object):
 
+  """ The initialization function that takes the dataset file path (or a dict loaded from that path) as input. """
   def __init__(self, file_path_or_dict, verbose=True):
     if isinstance(file_path_or_dict, str):
       if verbose: print('try to create the NAS-Bench-201 api from {:}'.format(file_path_or_dict))
@@ -49,6 +52,7 @@ class NASBench201API(object):
       file_path_or_dict = copy.deepcopy( file_path_or_dict )
     else: raise ValueError('invalid type : {:} not in [str, dict]'.format(type(file_path_or_dict)))
     assert isinstance(file_path_or_dict, dict), 'It should be a dict instead of {:}'.format(type(file_path_or_dict))
+    self.verbose = verbose # [TODO] a flag indicating whether to print more logs
     keys = ('meta_archs', 'arch2infos', 'evaluated_indexes')
     for key in keys: assert key in file_path_or_dict, 'Can not find key[{:}] in the dict'.format(key)
     self.meta_archs = copy.deepcopy( file_path_or_dict['meta_archs'] )
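For reference, here are both construction paths that `__init__` accepts, matching the type checks above. The benchmark file name is an assumption; use whichever NAS-Bench-201 `.pth` release you downloaded:

```python
import torch

# Path construction: the API loads the benchmark file itself.
api = NASBench201API('NAS-Bench-201-v1_0-e61699.pth', verbose=True)

# Dict construction: hand over an already-loaded dict; per the assertions above
# it must contain 'meta_archs', 'arch2infos', and 'evaluated_indexes'.
data = torch.load('NAS-Bench-201-v1_0-e61699.pth')
api_quiet = NASBench201API(data, verbose=False)  # verbose=False suppresses the log prints
```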