From 654015bf9daf62c501b4f4f8090370f907517b86 Mon Sep 17 00:00:00 2001 From: D-X-Y <280835372@qq.com> Date: Sat, 11 Jan 2020 18:46:31 +1100 Subject: [PATCH] simplify DARTS codes and update affine/track --- NAS-Bench-102.md | 1 + README.md | 3 +- exps/algos/DARTS-V1.py | 4 +- exps/algos/DARTS-V2.py | 4 +- exps/algos/ENAS.py | 4 +- exps/algos/RANDOM-NAS.py | 4 +- lib/models/cell_searchs/__init__.py | 7 +- ...odel_darts_v1.py => search_model_darts.py} | 8 +- .../cell_searchs/search_model_darts_v2.py | 93 ------------------- lib/models/cell_searchs/search_model_enas.py | 4 +- .../cell_searchs/search_model_random.py | 4 +- scripts-search/algos/DARTS-V1.sh | 1 + scripts-search/algos/DARTS-V2.sh | 1 + scripts-search/algos/ENAS.sh | 1 + scripts-search/algos/RANDOM-NAS.sh | 1 + 15 files changed, 30 insertions(+), 110 deletions(-) rename lib/models/cell_searchs/{search_model_darts_v1.py => search_model_darts.py} (94%) delete mode 100644 lib/models/cell_searchs/search_model_darts_v2.py diff --git a/NAS-Bench-102.md b/NAS-Bench-102.md index d29007a..7bb9545 100644 --- a/NAS-Bench-102.md +++ b/NAS-Bench-102.md @@ -194,6 +194,7 @@ If you find that NAS-Bench-102 helps your research, please consider citing it: title = {NAS-Bench-102: Extending the Scope of Reproducible Neural Architecture Search}, author = {Dong, Xuanyi and Yang, Yi}, booktitle = {International Conference on Learning Representations (ICLR)}, + url = {https://openreview.net/forum?id=HJxyZkBKDr}, year = {2020} } ``` diff --git a/README.md b/README.md index 926e8ad..28a55f4 100644 --- a/README.md +++ b/README.md @@ -15,7 +15,7 @@ More NAS resources can be found in [Awesome-NAS](https://github.com/D-X-Y/Awesom Please install `PyTorch>=1.2.0`, `Python>=3.6`, and `opencv`. -The CIFAR and ImageNet should be downloaded and extracted into `$TORCH_HOME`. +CIFAR and ImageNet should be downloaded and extracted into `$TORCH_HOME`. Some methods use knowledge distillation (KD), which require pre-trained models. Please download these models from [Google Driver](https://drive.google.com/open?id=1ANmiYEGX-IQZTfH8w0aSpj-Wypg-0DR-) (or train by yourself) and save into `.latent-data`. ### Usefull tools @@ -150,6 +150,7 @@ If you find that this project helps your research, please consider citing some o title = {NAS-Bench-102: Extending the Scope of Reproducible Neural Architecture Search}, author = {Dong, Xuanyi and Yang, Yi}, booktitle = {International Conference on Learning Representations (ICLR)}, + url = {https://openreview.net/forum?id=HJxyZkBKDr}, year = {2020} } @inproceedings{dong2019tas, diff --git a/exps/algos/DARTS-V1.py b/exps/algos/DARTS-V1.py index 0681dd2..870f97b 100644 --- a/exps/algos/DARTS-V1.py +++ b/exps/algos/DARTS-V1.py @@ -114,7 +114,8 @@ def main(xargs): search_space = get_search_spaces('cell', xargs.search_space_name) model_config = dict2config({'name': 'DARTS-V1', 'C': xargs.channel, 'N': xargs.num_cells, 'max_nodes': xargs.max_nodes, 'num_classes': class_num, - 'space' : search_space}, None) + 'space' : search_space, + 'affine' : False, 'track_running_stats': bool(xargs.track_running_stats)}, None) search_model = get_cell_based_tiny_net(model_config) logger.log('search-model :\n{:}'.format(search_model)) @@ -217,6 +218,7 @@ if __name__ == '__main__': parser.add_argument('--max_nodes', type=int, help='The maximum number of nodes.') parser.add_argument('--channel', type=int, help='The number of channels.') parser.add_argument('--num_cells', type=int, help='The number of cells in one stage.') + parser.add_argument('--track_running_stats',type=int, choices=[0,1],help='Whether use track_running_stats or not in the BN layer.') # architecture leraning rate parser.add_argument('--arch_learning_rate', type=float, default=3e-4, help='learning rate for arch encoding') parser.add_argument('--arch_weight_decay', type=float, default=1e-3, help='weight decay for arch encoding') diff --git a/exps/algos/DARTS-V2.py b/exps/algos/DARTS-V2.py index 4972a91..3893bcb 100644 --- a/exps/algos/DARTS-V2.py +++ b/exps/algos/DARTS-V2.py @@ -177,7 +177,8 @@ def main(xargs): search_space = get_search_spaces('cell', xargs.search_space_name) model_config = dict2config({'name': 'DARTS-V2', 'C': xargs.channel, 'N': xargs.num_cells, 'max_nodes': xargs.max_nodes, 'num_classes': class_num, - 'space' : search_space}, None) + 'space' : search_space, + 'affine' : False, 'track_running_stats': bool(xargs.track_running_stats)}, None) search_model = get_cell_based_tiny_net(model_config) logger.log('search-model :\n{:}'.format(search_model)) @@ -282,6 +283,7 @@ if __name__ == '__main__': parser.add_argument('--max_nodes', type=int, help='The maximum number of nodes.') parser.add_argument('--channel', type=int, help='The number of channels.') parser.add_argument('--num_cells', type=int, help='The number of cells in one stage.') + parser.add_argument('--track_running_stats',type=int, choices=[0,1],help='Whether use track_running_stats or not in the BN layer.') # architecture leraning rate parser.add_argument('--arch_learning_rate', type=float, default=3e-4, help='learning rate for arch encoding') parser.add_argument('--arch_weight_decay', type=float, default=1e-3, help='weight decay for arch encoding') diff --git a/exps/algos/ENAS.py b/exps/algos/ENAS.py index a0fc88a..71ef2f7 100644 --- a/exps/algos/ENAS.py +++ b/exps/algos/ENAS.py @@ -198,7 +198,8 @@ def main(xargs): search_space = get_search_spaces('cell', xargs.search_space_name) model_config = dict2config({'name': 'ENAS', 'C': xargs.channel, 'N': xargs.num_cells, 'max_nodes': xargs.max_nodes, 'num_classes': class_num, - 'space' : search_space}, None) + 'space' : search_space, + 'affine' : False, 'track_running_stats': bool(xargs.track_running_stats)}, None) shared_cnn = get_cell_based_tiny_net(model_config) controller = shared_cnn.create_controller() @@ -319,6 +320,7 @@ if __name__ == '__main__': parser.add_argument('--data_path', type=str, help='Path to dataset') parser.add_argument('--dataset', type=str, choices=['cifar10', 'cifar100', 'ImageNet16-120'], help='Choose between Cifar10/100 and ImageNet-16.') # channels and number-of-cells + parser.add_argument('--track_running_stats',type=int, choices=[0,1],help='Whether use track_running_stats or not in the BN layer.') parser.add_argument('--search_space_name', type=str, help='The search space name.') parser.add_argument('--max_nodes', type=int, help='The maximum number of nodes.') parser.add_argument('--channel', type=int, help='The number of channels.') diff --git a/exps/algos/RANDOM-NAS.py b/exps/algos/RANDOM-NAS.py index b61b5e5..b06a570 100644 --- a/exps/algos/RANDOM-NAS.py +++ b/exps/algos/RANDOM-NAS.py @@ -126,7 +126,8 @@ def main(xargs): search_space = get_search_spaces('cell', xargs.search_space_name) model_config = dict2config({'name': 'RANDOM', 'C': xargs.channel, 'N': xargs.num_cells, 'max_nodes': xargs.max_nodes, 'num_classes': class_num, - 'space' : search_space}, None) + 'space' : search_space, + 'affine' : False, 'track_running_stats': bool(xargs.track_running_stats)}, None) search_model = get_cell_based_tiny_net(model_config) w_optimizer, w_scheduler, criterion = get_optim_scheduler(search_model.parameters(), config) @@ -222,6 +223,7 @@ if __name__ == '__main__': parser.add_argument('--channel', type=int, help='The number of channels.') parser.add_argument('--num_cells', type=int, help='The number of cells in one stage.') parser.add_argument('--select_num', type=int, help='The number of selected architectures to evaluate.') + parser.add_argument('--track_running_stats',type=int, choices=[0,1],help='Whether use track_running_stats or not in the BN layer.') # log parser.add_argument('--workers', type=int, default=2, help='number of data loading workers (default: 2)') parser.add_argument('--save_dir', type=str, help='Folder to save checkpoints and log.') diff --git a/lib/models/cell_searchs/__init__.py b/lib/models/cell_searchs/__init__.py index 2133795..2df49ca 100644 --- a/lib/models/cell_searchs/__init__.py +++ b/lib/models/cell_searchs/__init__.py @@ -1,16 +1,15 @@ ################################################## # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 # ################################################## -from .search_model_darts_v1 import TinyNetworkDartsV1 -from .search_model_darts_v2 import TinyNetworkDartsV2 +from .search_model_darts import TinyNetworkDarts from .search_model_gdas import TinyNetworkGDAS from .search_model_setn import TinyNetworkSETN from .search_model_enas import TinyNetworkENAS from .search_model_random import TinyNetworkRANDOM from .genotypes import Structure as CellStructure, architectures as CellArchitectures -nas_super_nets = {'DARTS-V1': TinyNetworkDartsV1, - 'DARTS-V2': TinyNetworkDartsV2, +nas_super_nets = {'DARTS-V1': TinyNetworkDarts, + 'DARTS-V2': TinyNetworkDarts, 'GDAS' : TinyNetworkGDAS, 'SETN' : TinyNetworkSETN, 'ENAS' : TinyNetworkENAS, diff --git a/lib/models/cell_searchs/search_model_darts_v1.py b/lib/models/cell_searchs/search_model_darts.py similarity index 94% rename from lib/models/cell_searchs/search_model_darts_v1.py rename to lib/models/cell_searchs/search_model_darts.py index 61ef8ea..32ffffd 100644 --- a/lib/models/cell_searchs/search_model_darts_v1.py +++ b/lib/models/cell_searchs/search_model_darts.py @@ -11,10 +11,10 @@ from .search_cells import SearchCell from .genotypes import Structure -class TinyNetworkDartsV1(nn.Module): +class TinyNetworkDarts(nn.Module): - def __init__(self, C, N, max_nodes, num_classes, search_space): - super(TinyNetworkDartsV1, self).__init__() + def __init__(self, C, N, max_nodes, num_classes, search_space, affine, track_running_stats): + super(TinyNetworkDarts, self).__init__() self._C = C self._layerN = N self.max_nodes = max_nodes @@ -31,7 +31,7 @@ class TinyNetworkDartsV1(nn.Module): if reduction: cell = ResNetBasicblock(C_prev, C_curr, 2) else: - cell = SearchCell(C_prev, C_curr, 1, max_nodes, search_space) + cell = SearchCell(C_prev, C_curr, 1, max_nodes, search_space, affine, track_running_stats) if num_edge is None: num_edge, edge2index = cell.num_edges, cell.edge2index else: assert num_edge == cell.num_edges and edge2index == cell.edge2index, 'invalid {:} vs. {:}.'.format(num_edge, cell.num_edges) self.cells.append( cell ) diff --git a/lib/models/cell_searchs/search_model_darts_v2.py b/lib/models/cell_searchs/search_model_darts_v2.py deleted file mode 100644 index cb996ff..0000000 --- a/lib/models/cell_searchs/search_model_darts_v2.py +++ /dev/null @@ -1,93 +0,0 @@ -################################################## -# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 # -######################################################## -# DARTS: Differentiable Architecture Search, ICLR 2019 # -######################################################## -import torch -import torch.nn as nn -from copy import deepcopy -from ..cell_operations import ResNetBasicblock -from .search_cells import SearchCell -from .genotypes import Structure - - -class TinyNetworkDartsV2(nn.Module): - - def __init__(self, C, N, max_nodes, num_classes, search_space): - super(TinyNetworkDartsV2, self).__init__() - self._C = C - self._layerN = N - self.max_nodes = max_nodes - self.stem = nn.Sequential( - nn.Conv2d(3, C, kernel_size=3, padding=1, bias=False), - nn.BatchNorm2d(C)) - - layer_channels = [C ] * N + [C*2 ] + [C*2 ] * N + [C*4 ] + [C*4 ] * N - layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N - - C_prev, num_edge, edge2index = C, None, None - self.cells = nn.ModuleList() - for index, (C_curr, reduction) in enumerate(zip(layer_channels, layer_reductions)): - if reduction: - cell = ResNetBasicblock(C_prev, C_curr, 2) - else: - cell = SearchCell(C_prev, C_curr, 1, max_nodes, search_space) - if num_edge is None: num_edge, edge2index = cell.num_edges, cell.edge2index - else: assert num_edge == cell.num_edges and edge2index == cell.edge2index, 'invalid {:} vs. {:}.'.format(num_edge, cell.num_edges) - self.cells.append( cell ) - C_prev = cell.out_dim - self.op_names = deepcopy( search_space ) - self._Layer = len(self.cells) - self.edge2index = edge2index - self.lastact = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True)) - self.global_pooling = nn.AdaptiveAvgPool2d(1) - self.classifier = nn.Linear(C_prev, num_classes) - self.arch_parameters = nn.Parameter( 1e-3*torch.randn(num_edge, len(search_space)) ) - - def get_weights(self): - xlist = list( self.stem.parameters() ) + list( self.cells.parameters() ) - xlist+= list( self.lastact.parameters() ) + list( self.global_pooling.parameters() ) - xlist+= list( self.classifier.parameters() ) - return xlist - - def get_alphas(self): - return [self.arch_parameters] - - def get_message(self): - string = self.extra_repr() - for i, cell in enumerate(self.cells): - string += '\n {:02d}/{:02d} :: {:}'.format(i, len(self.cells), cell.extra_repr()) - return string - - def extra_repr(self): - return ('{name}(C={_C}, Max-Nodes={max_nodes}, N={_layerN}, L={_Layer})'.format(name=self.__class__.__name__, **self.__dict__)) - - def genotype(self): - genotypes = [] - for i in range(1, self.max_nodes): - xlist = [] - for j in range(i): - node_str = '{:}<-{:}'.format(i, j) - with torch.no_grad(): - weights = self.arch_parameters[ self.edge2index[node_str] ] - op_name = self.op_names[ weights.argmax().item() ] - xlist.append((op_name, j)) - genotypes.append( tuple(xlist) ) - return Structure( genotypes ) - - def forward(self, inputs): - alphas = nn.functional.softmax(self.arch_parameters, dim=-1) - - feature = self.stem(inputs) - for i, cell in enumerate(self.cells): - if isinstance(cell, SearchCell): - feature = cell(feature, alphas) - else: - feature = cell(feature) - - out = self.lastact(feature) - out = self.global_pooling( out ) - out = out.view(out.size(0), -1) - logits = self.classifier(out) - - return out, logits diff --git a/lib/models/cell_searchs/search_model_enas.py b/lib/models/cell_searchs/search_model_enas.py index 2422b52..3d89f37 100644 --- a/lib/models/cell_searchs/search_model_enas.py +++ b/lib/models/cell_searchs/search_model_enas.py @@ -14,7 +14,7 @@ from .search_model_enas_utils import Controller class TinyNetworkENAS(nn.Module): - def __init__(self, C, N, max_nodes, num_classes, search_space): + def __init__(self, C, N, max_nodes, num_classes, search_space, affine, track_running_stats): super(TinyNetworkENAS, self).__init__() self._C = C self._layerN = N @@ -32,7 +32,7 @@ class TinyNetworkENAS(nn.Module): if reduction: cell = ResNetBasicblock(C_prev, C_curr, 2) else: - cell = SearchCell(C_prev, C_curr, 1, max_nodes, search_space) + cell = SearchCell(C_prev, C_curr, 1, max_nodes, search_space, affine, track_running_stats) if num_edge is None: num_edge, edge2index = cell.num_edges, cell.edge2index else: assert num_edge == cell.num_edges and edge2index == cell.edge2index, 'invalid {:} vs. {:}.'.format(num_edge, cell.num_edges) self.cells.append( cell ) diff --git a/lib/models/cell_searchs/search_model_random.py b/lib/models/cell_searchs/search_model_random.py index c2f83f9..e4f69e2 100644 --- a/lib/models/cell_searchs/search_model_random.py +++ b/lib/models/cell_searchs/search_model_random.py @@ -13,7 +13,7 @@ from .genotypes import Structure class TinyNetworkRANDOM(nn.Module): - def __init__(self, C, N, max_nodes, num_classes, search_space): + def __init__(self, C, N, max_nodes, num_classes, search_space, affine, track_running_stats): super(TinyNetworkRANDOM, self).__init__() self._C = C self._layerN = N @@ -31,7 +31,7 @@ class TinyNetworkRANDOM(nn.Module): if reduction: cell = ResNetBasicblock(C_prev, C_curr, 2) else: - cell = SearchCell(C_prev, C_curr, 1, max_nodes, search_space) + cell = SearchCell(C_prev, C_curr, 1, max_nodes, search_space, affine, track_running_stats) if num_edge is None: num_edge, edge2index = cell.num_edges, cell.edge2index else: assert num_edge == cell.num_edges and edge2index == cell.edge2index, 'invalid {:} vs. {:}.'.format(num_edge, cell.num_edges) self.cells.append( cell ) diff --git a/scripts-search/algos/DARTS-V1.sh b/scripts-search/algos/DARTS-V1.sh index a096b54..f25f37a 100644 --- a/scripts-search/algos/DARTS-V1.sh +++ b/scripts-search/algos/DARTS-V1.sh @@ -35,5 +35,6 @@ OMP_NUM_THREADS=4 python ./exps/algos/DARTS-V1.py \ --search_space_name ${space} \ --config_path configs/nas-benchmark/algos/DARTS.config \ --arch_nas_dataset ${TORCH_HOME}/NAS-Bench-102-v1_0-e61699.pth \ + --track_running_stats 1 \ --arch_learning_rate 0.0003 --arch_weight_decay 0.001 \ --workers 4 --print_freq 200 --rand_seed ${seed} diff --git a/scripts-search/algos/DARTS-V2.sh b/scripts-search/algos/DARTS-V2.sh index 5b81268..f6d17da 100644 --- a/scripts-search/algos/DARTS-V2.sh +++ b/scripts-search/algos/DARTS-V2.sh @@ -35,5 +35,6 @@ OMP_NUM_THREADS=4 python ./exps/algos/DARTS-V2.py \ --search_space_name ${space} \ --config_path configs/nas-benchmark/algos/DARTS.config \ --arch_nas_dataset ${TORCH_HOME}/NAS-Bench-102-v1_0-e61699.pth \ + --track_running_stats 1 \ --arch_learning_rate 0.0003 --arch_weight_decay 0.001 \ --workers 4 --print_freq 200 --rand_seed ${seed} diff --git a/scripts-search/algos/ENAS.sh b/scripts-search/algos/ENAS.sh index 18a378f..fc39361 100644 --- a/scripts-search/algos/ENAS.sh +++ b/scripts-search/algos/ENAS.sh @@ -35,6 +35,7 @@ OMP_NUM_THREADS=4 python ./exps/algos/ENAS.py \ --dataset ${dataset} --data_path ${data_path} \ --search_space_name ${space} \ --arch_nas_dataset ${TORCH_HOME}/NAS-Bench-102-v1_0-e61699.pth \ + --track_running_stats 1 \ --config_path ./configs/nas-benchmark/algos/ENAS.config \ --controller_entropy_weight 0.0001 \ --controller_bl_dec 0.99 \ diff --git a/scripts-search/algos/RANDOM-NAS.sh b/scripts-search/algos/RANDOM-NAS.sh index ba8e326..d964958 100644 --- a/scripts-search/algos/RANDOM-NAS.sh +++ b/scripts-search/algos/RANDOM-NAS.sh @@ -34,6 +34,7 @@ OMP_NUM_THREADS=4 python ./exps/algos/RANDOM-NAS.py \ --save_dir ${save_dir} --max_nodes ${max_nodes} --channel ${channel} --num_cells ${num_cells} \ --dataset ${dataset} --data_path ${data_path} \ --search_space_name ${space} \ + --track_running_stats 1 \ --arch_nas_dataset ${TORCH_HOME}/NAS-Bench-102-v1_0-e61699.pth \ --config_path ./configs/nas-benchmark/algos/RANDOM.config \ --select_num 100 \