change batchsize in DARTS-NASNet to 64 ; add some type checking
This commit is contained in:
		| @@ -9,5 +9,5 @@ | ||||
|   "momentum" : ["float", "0.9"], | ||||
|   "nesterov" : ["bool",  "1"], | ||||
|   "criterion": ["str",   "Softmax"], | ||||
|   "batch_size": ["int",  "256"] | ||||
|   "batch_size": ["int",  "64"] | ||||
| } | ||||
|   | ||||
| @@ -2,6 +2,7 @@ | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 # | ||||
| ################################################## | ||||
| from os import path as osp | ||||
| from typing import List, Text | ||||
|  | ||||
| __all__ = ['change_key', 'get_cell_based_tiny_net', 'get_search_spaces', 'get_cifar_models', 'get_imagenet_models', \ | ||||
|            'obtain_model', 'obtain_search_model', 'load_net_from_checkpoint', \ | ||||
| @@ -42,7 +43,7 @@ def get_cell_based_tiny_net(config): | ||||
|  | ||||
|  | ||||
| # obtain the search space, i.e., a dict mapping the operation name into a python-function for this op | ||||
| def get_search_spaces(xtype, name): | ||||
| def get_search_spaces(xtype, name) -> List[Text]: | ||||
|   if xtype == 'cell': | ||||
|     from .cell_operations import SearchSpaceNames | ||||
|     assert name in SearchSpaceNames, 'invalid name [{:}] in {:}'.format(name, SearchSpaceNames.keys()) | ||||
|   | ||||
| @@ -4,6 +4,7 @@ | ||||
| import torch | ||||
| import torch.nn as nn | ||||
| from copy import deepcopy | ||||
| from typing import List, Text, Dict | ||||
| from .search_cells     import NASNetSearchCell as SearchCell | ||||
| from .genotypes        import Structure | ||||
|  | ||||
| @@ -11,7 +12,7 @@ from .genotypes        import Structure | ||||
| # The macro structure is based on NASNet | ||||
| class NASNetworkDARTS(nn.Module): | ||||
|  | ||||
|   def __init__(self, C, N, steps, multiplier, stem_multiplier, num_classes, search_space, affine, track_running_stats): | ||||
|   def __init__(self, C: int, N: int, steps: int, multiplier: int, stem_multiplier: int, num_classes: int, search_space: List[Text], affine: bool, track_running_stats: bool): | ||||
|     super(NASNetworkDARTS, self).__init__() | ||||
|     self._C        = C | ||||
|     self._layerN   = N | ||||
| @@ -44,31 +45,31 @@ class NASNetworkDARTS(nn.Module): | ||||
|     self.arch_normal_parameters = nn.Parameter( 1e-3*torch.randn(num_edge, len(search_space)) ) | ||||
|     self.arch_reduce_parameters = nn.Parameter( 1e-3*torch.randn(num_edge, len(search_space)) ) | ||||
|  | ||||
|   def get_weights(self): | ||||
|   def get_weights(self) -> List[torch.nn.Parameter]: | ||||
|     xlist = list( self.stem.parameters() ) + list( self.cells.parameters() ) | ||||
|     xlist+= list( self.lastact.parameters() ) + list( self.global_pooling.parameters() ) | ||||
|     xlist+= list( self.classifier.parameters() ) | ||||
|     return xlist | ||||
|  | ||||
|   def get_alphas(self): | ||||
|   def get_alphas(self) -> List[torch.nn.Parameter]: | ||||
|     return [self.arch_normal_parameters, self.arch_reduce_parameters] | ||||
|  | ||||
|   def show_alphas(self): | ||||
|   def show_alphas(self) -> Text: | ||||
|     with torch.no_grad(): | ||||
|       A = 'arch-normal-parameters :\n{:}'.format( nn.functional.softmax(self.arch_normal_parameters, dim=-1).cpu() ) | ||||
|       B = 'arch-reduce-parameters :\n{:}'.format( nn.functional.softmax(self.arch_reduce_parameters, dim=-1).cpu() ) | ||||
|     return '{:}\n{:}'.format(A, B) | ||||
|  | ||||
|   def get_message(self): | ||||
|   def get_message(self) -> Text: | ||||
|     string = self.extra_repr() | ||||
|     for i, cell in enumerate(self.cells): | ||||
|       string += '\n {:02d}/{:02d} :: {:}'.format(i, len(self.cells), cell.extra_repr()) | ||||
|     return string | ||||
|  | ||||
|   def extra_repr(self): | ||||
|   def extra_repr(self) -> Text: | ||||
|     return ('{name}(C={_C}, N={_layerN}, steps={_steps}, multiplier={_multiplier}, L={_Layer})'.format(name=self.__class__.__name__, **self.__dict__)) | ||||
|  | ||||
|   def genotype(self): | ||||
|   def genotype(self) -> Dict[Text, List]: | ||||
|     def _parse(weights): | ||||
|       gene = [] | ||||
|       for i in range(self._steps): | ||||
|   | ||||
| @@ -37,9 +37,12 @@ def print_information(information, extra_info=None, show=False): | ||||
|   if show: print('\n'.join(strings)) | ||||
|   return strings | ||||
|  | ||||
|  | ||||
| """ | ||||
| This is the class for API of NAS-Bench-201. | ||||
| """ | ||||
| class NASBench201API(object): | ||||
|  | ||||
|   """ The initialization function that takes the dataset file path (or a dict loaded from that path) as input. """ | ||||
|   def __init__(self, file_path_or_dict, verbose=True): | ||||
|     if isinstance(file_path_or_dict, str): | ||||
|       if verbose: print('try to create the NAS-Bench-201 api from {:}'.format(file_path_or_dict)) | ||||
| @@ -49,6 +52,7 @@ class NASBench201API(object): | ||||
|       file_path_or_dict = copy.deepcopy( file_path_or_dict ) | ||||
|     else: raise ValueError('invalid type : {:} not in [str, dict]'.format(type(file_path_or_dict))) | ||||
|     assert isinstance(file_path_or_dict, dict), 'It should be a dict instead of {:}'.format(type(file_path_or_dict)) | ||||
|     self.verbose = verbose # [TODO] a flag indicating whether to print more logs | ||||
|     keys = ('meta_archs', 'arch2infos', 'evaluated_indexes') | ||||
|     for key in keys: assert key in file_path_or_dict, 'Can not find key[{:}] in the dict'.format(key) | ||||
|     self.meta_archs = copy.deepcopy( file_path_or_dict['meta_archs'] ) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user