102->201 / NAS->autoDL / more configs of TAS / reorganize docs / fix bugs in NAS baselines
This commit is contained in:
		| @@ -19,7 +19,7 @@ def get_cell_based_tiny_net(config): | ||||
|   super_type = getattr(config, 'super_type', 'basic') | ||||
|   group_names = ['DARTS-V1', 'DARTS-V2', 'GDAS', 'SETN', 'ENAS', 'RANDOM'] | ||||
|   if super_type == 'basic' and config.name in group_names: | ||||
|     from .cell_searchs import nas102_super_nets as nas_super_nets | ||||
|     from .cell_searchs import nas201_super_nets as nas_super_nets | ||||
|     try: | ||||
|       return nas_super_nets[config.name](config.C, config.N, config.max_nodes, config.num_classes, config.space, config.affine, config.track_running_stats) | ||||
|     except: | ||||
|   | ||||
| @@ -1,8 +1,13 @@ | ||||
| ################################################## | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020 # | ||||
| ################################################## | ||||
|  | ||||
| import torch.nn as nn | ||||
| from copy import deepcopy | ||||
| from ..cell_operations import OPS | ||||
|  | ||||
|  | ||||
| # Cell for NAS-Bench-201 | ||||
| class InferCell(nn.Module): | ||||
|  | ||||
|   def __init__(self, genotype, C_in, C_out, stride): | ||||
|   | ||||
| @@ -1,9 +1,13 @@ | ||||
| ################################################## | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020 # | ||||
| ################################################## | ||||
| import torch | ||||
| import torch.nn as nn | ||||
| from ..cell_operations import ResNetBasicblock | ||||
| from .cells import InferCell | ||||
|  | ||||
|  | ||||
| # The macro structure for architectures in NAS-Bench-201 | ||||
| class TinyNetwork(nn.Module): | ||||
|  | ||||
|   def __init__(self, C, N, genotype, num_classes): | ||||
|   | ||||
| @@ -21,12 +21,11 @@ OPS = { | ||||
| } | ||||
|  | ||||
| CONNECT_NAS_BENCHMARK = ['none', 'skip_connect', 'nor_conv_3x3'] | ||||
| NAS_BENCH_102         = ['none', 'skip_connect', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3'] | ||||
| NAS_BENCH_201         = ['none', 'skip_connect', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3'] | ||||
| DARTS_SPACE           = ['none', 'skip_connect', 'dua_sepc_3x3', 'dua_sepc_5x5', 'dil_sepc_3x3', 'dil_sepc_5x5', 'avg_pool_3x3', 'max_pool_3x3'] | ||||
|  | ||||
| SearchSpaceNames = {'connect-nas'  : CONNECT_NAS_BENCHMARK, | ||||
|                     'aa-nas'       : NAS_BENCH_102, | ||||
|                     'nas-bench-102': NAS_BENCH_102, | ||||
|                     'nas-bench-201': NAS_BENCH_201, | ||||
|                     'darts'        : DARTS_SPACE} | ||||
|  | ||||
|  | ||||
|   | ||||
| @@ -1,7 +1,7 @@ | ||||
| ################################################## | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 # | ||||
| ################################################## | ||||
| # The macro structure is defined in NAS-Bench-102 | ||||
| # The macro structure is defined in NAS-Bench-201 | ||||
| from .search_model_darts    import TinyNetworkDarts | ||||
| from .search_model_gdas     import TinyNetworkGDAS | ||||
| from .search_model_setn     import TinyNetworkSETN | ||||
| @@ -12,7 +12,7 @@ from .genotypes             import Structure as CellStructure, architectures as | ||||
| from .search_model_gdas_nasnet import NASNetworkGDAS | ||||
|  | ||||
|  | ||||
| nas102_super_nets = {'DARTS-V1': TinyNetworkDarts, | ||||
| nas201_super_nets = {'DARTS-V1': TinyNetworkDarts, | ||||
|                   'DARTS-V2': TinyNetworkDarts, | ||||
|                   'GDAS'    : TinyNetworkGDAS, | ||||
|                   'SETN'    : TinyNetworkSETN, | ||||
|   | ||||
| @@ -9,11 +9,11 @@ from copy import deepcopy | ||||
| from ..cell_operations import OPS | ||||
|  | ||||
|  | ||||
| # This module is used for NAS-Bench-102, represents a small search space with a complete DAG | ||||
| class NAS102SearchCell(nn.Module): | ||||
| # This module is used for NAS-Bench-201, represents a small search space with a complete DAG | ||||
| class NAS201SearchCell(nn.Module): | ||||
|  | ||||
|   def __init__(self, C_in, C_out, stride, max_nodes, op_names, affine=False, track_running_stats=True): | ||||
|     super(NAS102SearchCell, self).__init__() | ||||
|     super(NAS201SearchCell, self).__init__() | ||||
|  | ||||
|     self.op_names  = deepcopy(op_names) | ||||
|     self.edges     = nn.ModuleDict() | ||||
|   | ||||
| @@ -7,7 +7,7 @@ import torch | ||||
| import torch.nn as nn | ||||
| from copy import deepcopy | ||||
| from ..cell_operations import ResNetBasicblock | ||||
| from .search_cells     import NAS102SearchCell as SearchCell | ||||
| from .search_cells     import NAS201SearchCell as SearchCell | ||||
| from .genotypes        import Structure | ||||
|  | ||||
|  | ||||
|   | ||||
| @@ -7,7 +7,7 @@ import torch | ||||
| import torch.nn as nn | ||||
| from copy import deepcopy | ||||
| from ..cell_operations import ResNetBasicblock | ||||
| from .search_cells     import NAS102SearchCell as SearchCell | ||||
| from .search_cells     import NAS201SearchCell as SearchCell | ||||
| from .genotypes        import Structure | ||||
| from .search_model_enas_utils import Controller | ||||
|  | ||||
|   | ||||
| @@ -5,7 +5,7 @@ import torch | ||||
| import torch.nn as nn | ||||
| from copy import deepcopy | ||||
| from ..cell_operations import ResNetBasicblock | ||||
| from .search_cells     import NAS102SearchCell as SearchCell | ||||
| from .search_cells     import NAS201SearchCell as SearchCell | ||||
| from .genotypes        import Structure | ||||
|  | ||||
|  | ||||
|   | ||||
| @@ -7,7 +7,7 @@ import torch, random | ||||
| import torch.nn as nn | ||||
| from copy import deepcopy | ||||
| from ..cell_operations import ResNetBasicblock | ||||
| from .search_cells     import NAS102SearchCell as SearchCell | ||||
| from .search_cells     import NAS201SearchCell as SearchCell | ||||
| from .genotypes        import Structure | ||||
|  | ||||
|  | ||||
|   | ||||
| @@ -7,7 +7,7 @@ import torch, random | ||||
| import torch.nn as nn | ||||
| from copy import deepcopy | ||||
| from ..cell_operations import ResNetBasicblock | ||||
| from .search_cells     import NAS102SearchCell as SearchCell | ||||
| from .search_cells     import NAS201SearchCell as SearchCell | ||||
| from .genotypes        import Structure | ||||
|  | ||||
|  | ||||
|   | ||||
		Reference in New Issue
	
	Block a user