change batch size in DARTS-NASNet to 64; add some type checking
This commit is contained in:
parent 923b0fe9cf
commit 1efe3cbccf
@@ -9,5 +9,5 @@
   "momentum" : ["float", "0.9"],
   "nesterov" : ["bool", "1"],
   "criterion": ["str", "Softmax"],
-  "batch_size": ["int", "256"]
+  "batch_size": ["int", "64"]
 }
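Each entry in this config encodes a value as a ["type", "value"] pair, so a loader can cast the string to a typed Python value. A minimal parsing sketch; the function name and type map below are assumptions for illustration, not part of this commit:

# Hypothetical loader for ["type", "value"] config entries.
def cast_entry(xtype, value):
  casts = {'int': int, 'float': float, 'str': str,
           'bool': lambda v: bool(int(v))}  # "1" -> True
  return casts[xtype](value)

assert cast_entry('int', '64') == 64
assert cast_entry('bool', '1') is True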
@@ -2,6 +2,7 @@
 # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
 ##################################################
 from os import path as osp
+from typing import List, Text
 
 __all__ = ['change_key', 'get_cell_based_tiny_net', 'get_search_spaces', 'get_cifar_models', 'get_imagenet_models', \
            'obtain_model', 'obtain_search_model', 'load_net_from_checkpoint', \
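Note that on Python 3, typing.Text is simply an alias of str, so the List[Text] annotations introduced in this commit are equivalent to List[str]:

from typing import Text
assert Text is str  # holds on Python 3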
@@ -42,7 +43,7 @@ def get_cell_based_tiny_net(config):
 
 
 # obtain the search space, i.e., a dict mapping the operation name into a python-function for this op
-def get_search_spaces(xtype, name):
+def get_search_spaces(xtype, name) -> List[Text]:
   if xtype == 'cell':
     from .cell_operations import SearchSpaceNames
     assert name in SearchSpaceNames, 'invalid name [{:}] in {:}'.format(name, SearchSpaceNames.keys())
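A hedged usage sketch of the newly annotated function; the space name 'darts' is an assumption about the keys of SearchSpaceNames, not something this commit shows:

# Returns the list of operation names for a cell-based search space.
space = get_search_spaces('cell', 'darts')
print(space)  # e.g. ['none', 'skip_connect', ...]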
@@ -4,6 +4,7 @@
 import torch
 import torch.nn as nn
 from copy import deepcopy
+from typing import List, Text, Dict
 from .search_cells import NASNetSearchCell as SearchCell
 from .genotypes import Structure
 
@@ -11,7 +12,7 @@ from .genotypes import Structure
 # The macro structure is based on NASNet
 class NASNetworkDARTS(nn.Module):
 
-  def __init__(self, C, N, steps, multiplier, stem_multiplier, num_classes, search_space, affine, track_running_stats):
+  def __init__(self, C: int, N: int, steps: int, multiplier: int, stem_multiplier: int, num_classes: int, search_space: List[Text], affine: bool, track_running_stats: bool):
     super(NASNetworkDARTS, self).__init__()
     self._C = C
     self._layerN = N
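With the annotated constructor, a call site looks like the sketch below; all hyper-parameter values and the operation list are illustrative assumptions, not values from this commit:

# Hypothetical instantiation of the DARTS-style NASNet search model.
net = NASNetworkDARTS(C=16, N=2, steps=4, multiplier=4, stem_multiplier=3,
                      num_classes=10,
                      search_space=['none', 'skip_connect', 'sep_conv_3x3'],
                      affine=False, track_running_stats=True)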
@@ -44,31 +45,31 @@ class NASNetworkDARTS(nn.Module):
     self.arch_normal_parameters = nn.Parameter( 1e-3*torch.randn(num_edge, len(search_space)) )
     self.arch_reduce_parameters = nn.Parameter( 1e-3*torch.randn(num_edge, len(search_space)) )
 
-  def get_weights(self):
+  def get_weights(self) -> List[torch.nn.Parameter]:
     xlist = list( self.stem.parameters() ) + list( self.cells.parameters() )
     xlist+= list( self.lastact.parameters() ) + list( self.global_pooling.parameters() )
     xlist+= list( self.classifier.parameters() )
     return xlist
 
-  def get_alphas(self):
+  def get_alphas(self) -> List[torch.nn.Parameter]:
     return [self.arch_normal_parameters, self.arch_reduce_parameters]
 
-  def show_alphas(self):
+  def show_alphas(self) -> Text:
     with torch.no_grad():
       A = 'arch-normal-parameters :\n{:}'.format( nn.functional.softmax(self.arch_normal_parameters, dim=-1).cpu() )
       B = 'arch-reduce-parameters :\n{:}'.format( nn.functional.softmax(self.arch_reduce_parameters, dim=-1).cpu() )
     return '{:}\n{:}'.format(A, B)
 
-  def get_message(self):
+  def get_message(self) -> Text:
     string = self.extra_repr()
     for i, cell in enumerate(self.cells):
       string += '\n {:02d}/{:02d} :: {:}'.format(i, len(self.cells), cell.extra_repr())
     return string
 
-  def extra_repr(self):
+  def extra_repr(self) -> Text:
     return ('{name}(C={_C}, N={_layerN}, steps={_steps}, multiplier={_multiplier}, L={_Layer})'.format(name=self.__class__.__name__, **self.__dict__))
 
-  def genotype(self):
+  def genotype(self) -> Dict[Text, List]:
     def _parse(weights):
       gene = []
       for i in range(self._steps):
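The get_weights/get_alphas split matches the bi-level optimization used in DARTS-style search: one optimizer updates the network weights, another updates the architecture parameters. A sketch under that assumption; the optimizer settings are illustrative, not from this commit:

import torch
w_optimizer = torch.optim.SGD(net.get_weights(), lr=0.025, momentum=0.9, nesterov=True)
a_optimizer = torch.optim.Adam(net.get_alphas(), lr=3e-4, betas=(0.5, 0.999), weight_decay=1e-3)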
@@ -37,9 +37,12 @@ def print_information(information, extra_info=None, show=False):
   if show: print('\n'.join(strings))
   return strings
 
+"""
+This is the class for the API of NAS-Bench-201.
+"""
 class NASBench201API(object):
 
+  """ The initialization function that takes the dataset file path (or a dict loaded from that path) as input. """
   def __init__(self, file_path_or_dict, verbose=True):
     if isinstance(file_path_or_dict, str):
       if verbose: print('try to create the NAS-Bench-201 api from {:}'.format(file_path_or_dict))
@@ -49,6 +52,7 @@ class NASBench201API(object):
       file_path_or_dict = copy.deepcopy( file_path_or_dict )
     else: raise ValueError('invalid type : {:} not in [str, dict]'.format(type(file_path_or_dict)))
     assert isinstance(file_path_or_dict, dict), 'It should be a dict instead of {:}'.format(type(file_path_or_dict))
+    self.verbose = verbose # [TODO] a flag indicating whether to print more logs
     keys = ('meta_archs', 'arch2infos', 'evaluated_indexes')
     for key in keys: assert key in file_path_or_dict, 'Can not find key[{:}] in the dict'.format(key)
     self.meta_archs = copy.deepcopy( file_path_or_dict['meta_archs'] )
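A hedged usage sketch for the API class; the checkpoint filename below is an illustrative placeholder, not something this commit specifies:

# Load the benchmark file; pass verbose=False to silence the loading prints.
api = NASBench201API('NAS-Bench-201-v1_0-e61699.pth', verbose=False)
# The commit stores verbose on the instance (marked [TODO] in the diff),
# presumably so later methods can gate their logging on api.verbose.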