update NAS-Bench-102 baselines
This commit is contained in:
		| @@ -20,7 +20,10 @@ def get_cell_based_tiny_net(config): | ||||
|   group_names = ['DARTS-V1', 'DARTS-V2', 'GDAS', 'SETN', 'ENAS', 'RANDOM'] | ||||
|   if super_type == 'basic' and config.name in group_names: | ||||
|     from .cell_searchs import nas_super_nets | ||||
|     return nas_super_nets[config.name](config.C, config.N, config.max_nodes, config.num_classes, config.space) | ||||
|     try: | ||||
|       return nas_super_nets[config.name](config.C, config.N, config.max_nodes, config.num_classes, config.space, config.affine, config.track_running_stats) | ||||
|     except: | ||||
|       return nas_super_nets[config.name](config.C, config.N, config.max_nodes, config.num_classes, config.space) | ||||
|   elif super_type == 'l2s-base' and config.name in group_names: | ||||
|     from .l2s_cell_searchs import nas_super_nets | ||||
|     return nas_super_nets[config.name](config.C, config.N, config.max_nodes, config.num_classes, config.space \ | ||||
|   | ||||
| @@ -11,7 +11,8 @@ from .genotypes        import Structure | ||||
|  | ||||
| class TinyNetworkGDAS(nn.Module): | ||||
|  | ||||
|   def __init__(self, C, N, max_nodes, num_classes, search_space, affine=False, track_running_stats=True): | ||||
|   #def __init__(self, C, N, max_nodes, num_classes, search_space, affine=False, track_running_stats=True): | ||||
|   def __init__(self, C, N, max_nodes, num_classes, search_space, affine, track_running_stats): | ||||
|     super(TinyNetworkGDAS, self).__init__() | ||||
|     self._C        = C | ||||
|     self._layerN   = N | ||||
|   | ||||
| @@ -13,7 +13,7 @@ from .genotypes        import Structure | ||||
|  | ||||
| class TinyNetworkSETN(nn.Module): | ||||
|  | ||||
|   def __init__(self, C, N, max_nodes, num_classes, search_space): | ||||
|   def __init__(self, C, N, max_nodes, num_classes, search_space, affine, track_running_stats): | ||||
|     super(TinyNetworkSETN, self).__init__() | ||||
|     self._C        = C | ||||
|     self._layerN   = N | ||||
| @@ -31,7 +31,7 @@ class TinyNetworkSETN(nn.Module): | ||||
|       if reduction: | ||||
|         cell = ResNetBasicblock(C_prev, C_curr, 2) | ||||
|       else: | ||||
|         cell = SearchCell(C_prev, C_curr, 1, max_nodes, search_space) | ||||
|         cell = SearchCell(C_prev, C_curr, 1, max_nodes, search_space, affine, track_running_stats) | ||||
|         if num_edge is None: num_edge, edge2index = cell.num_edges, cell.edge2index | ||||
|         else: assert num_edge == cell.num_edges and edge2index == cell.edge2index, 'invalid {:} vs. {:}.'.format(num_edge, cell.num_edges) | ||||
|       self.cells.append( cell ) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user