update NAS-Bench
This commit is contained in:
		| @@ -14,11 +14,11 @@ from .search_model_darts_nasnet import NASNetworkDARTS | ||||
|  | ||||
|  | ||||
| nas201_super_nets = {'DARTS-V1': TinyNetworkDarts, | ||||
|                   'DARTS-V2': TinyNetworkDarts, | ||||
|                   'GDAS'    : TinyNetworkGDAS, | ||||
|                   'SETN'    : TinyNetworkSETN, | ||||
|                   'ENAS'    : TinyNetworkENAS, | ||||
|                   'RANDOM'  : TinyNetworkRANDOM} | ||||
|                      "DARTS-V2": TinyNetworkDarts, | ||||
|                      "GDAS": TinyNetworkGDAS, | ||||
|                      "SETN": TinyNetworkSETN, | ||||
|                      "ENAS": TinyNetworkENAS, | ||||
|                      "RANDOM": TinyNetworkRANDOM} | ||||
|  | ||||
| nasnet_super_nets = {'GDAS' : NASNetworkGDAS, | ||||
|                      'DARTS': NASNetworkDARTS} | ||||
| nasnet_super_nets = {"GDAS": NASNetworkGDAS, | ||||
|                      "DARTS": NASNetworkDARTS} | ||||
|   | ||||
| @@ -1,5 +1,5 @@ | ||||
| #################### | ||||
| # DARTS, ICLR 2019 #  | ||||
| # DARTS, ICLR 2019 # | ||||
| #################### | ||||
| import torch | ||||
| import torch.nn as nn | ||||
| @@ -11,7 +11,8 @@ from .search_cells import NASNetSearchCell as SearchCell | ||||
| # The macro structure is based on NASNet | ||||
| class NASNetworkDARTS(nn.Module): | ||||
|  | ||||
|   def __init__(self, C: int, N: int, steps: int, multiplier: int, stem_multiplier: int, num_classes: int, search_space: List[Text], affine: bool, track_running_stats: bool): | ||||
|   def __init__(self, C: int, N: int, steps: int, multiplier: int, stem_multiplier: int, | ||||
|                num_classes: int, search_space: List[Text], affine: bool, track_running_stats: bool): | ||||
|     super(NASNetworkDARTS, self).__init__() | ||||
|     self._C        = C | ||||
|     self._layerN   = N | ||||
|   | ||||
| @@ -6,14 +6,15 @@ | ||||
| import torch | ||||
| import torch.nn as nn | ||||
| from copy import deepcopy | ||||
| from typing import List, Text, Dict | ||||
| from .search_cells     import NASNetSearchCell as SearchCell | ||||
| from .genotypes        import Structure | ||||
|  | ||||
|  | ||||
| # The macro structure is based on NASNet | ||||
| class NASNetworkSETN(nn.Module): | ||||
|  | ||||
|   def __init__(self, C, N, steps, multiplier, stem_multiplier, num_classes, search_space, affine, track_running_stats): | ||||
|   def __init__(self, C: int, N: int, steps: int, multiplier: int, stem_multiplier: int, | ||||
|                num_classes: int, search_space: List[Text], affine: bool, track_running_stats: bool): | ||||
|     super(NASNetworkSETN, self).__init__() | ||||
|     self._C        = C | ||||
|     self._layerN   = N | ||||
| @@ -45,6 +46,16 @@ class NASNetworkSETN(nn.Module): | ||||
|     self.classifier = nn.Linear(C_prev, num_classes) | ||||
|     self.arch_normal_parameters = nn.Parameter( 1e-3*torch.randn(num_edge, len(search_space)) ) | ||||
|     self.arch_reduce_parameters = nn.Parameter( 1e-3*torch.randn(num_edge, len(search_space)) ) | ||||
|     self.mode = 'urs' | ||||
|     self.dynamic_cell = None | ||||
|  | ||||
|   def set_cal_mode(self, mode, dynamic_cell=None): | ||||
|     assert mode in ['urs', 'joint', 'select', 'dynamic'] | ||||
|     self.mode = mode | ||||
|     if mode == 'dynamic': | ||||
|       self.dynamic_cell = deepcopy(dynamic_cell) | ||||
|     else: | ||||
|       self.dynamic_cell = None | ||||
|  | ||||
|   def get_weights(self): | ||||
|     xlist = list( self.stem.parameters() ) + list( self.cells.parameters() ) | ||||
| @@ -70,6 +81,24 @@ class NASNetworkSETN(nn.Module): | ||||
|   def extra_repr(self): | ||||
|     return ('{name}(C={_C}, N={_layerN}, steps={_steps}, multiplier={_multiplier}, L={_Layer})'.format(name=self.__class__.__name__, **self.__dict__)) | ||||
|  | ||||
|   def dync_genotype(self, use_random=False): | ||||
|     genotypes = [] | ||||
|     with torch.no_grad(): | ||||
|       alphas_cpu = nn.functional.softmax(self.arch_parameters, dim=-1) | ||||
|     for i in range(1, self.max_nodes): | ||||
|       xlist = [] | ||||
|       for j in range(i): | ||||
|         node_str = '{:}<-{:}'.format(i, j) | ||||
|         if use_random: | ||||
|           op_name  = random.choice(self.op_names) | ||||
|         else: | ||||
|           weights  = alphas_cpu[ self.edge2index[node_str] ] | ||||
|           op_index = torch.multinomial(weights, 1).item() | ||||
|           op_name  = self.op_names[ op_index ] | ||||
|         xlist.append((op_name, j)) | ||||
|       genotypes.append( tuple(xlist) ) | ||||
|     return Structure( genotypes ) | ||||
|  | ||||
|   def genotype(self): | ||||
|     def _parse(weights): | ||||
|       gene = [] | ||||
| @@ -94,9 +123,6 @@ class NASNetworkSETN(nn.Module): | ||||
|   def forward(self, inputs): | ||||
|     normal_hardwts = nn.functional.softmax(self.arch_normal_parameters, dim=-1) | ||||
|     reduce_hardwts = nn.functional.softmax(self.arch_reduce_parameters, dim=-1) | ||||
|     with torch.no_grad(): | ||||
|       normal_hardwts_cpu = normal_hardwts.detach().cpu() | ||||
|       reduce_hardwts_cpu = reduce_hardwts.detach().cpu() | ||||
|  | ||||
|     s0 = s1 = self.stem(inputs) | ||||
|     for i, cell in enumerate(self.cells): | ||||
|   | ||||
		Reference in New Issue
	
	Block a user