update README
@@ -5,7 +5,7 @@ This project contains the following neural architecture search algorithms, imple
 - Network Pruning via Transformable Architecture Search, NeurIPS 2019
 - One-Shot Neural Architecture Search via Self-Evaluated Template Network, ICCV 2019
 - Searching for A Robust Neural Architecture in Four GPU Hours, CVPR 2019
-- several typical classification models, e.g., ResNet and DenseNet (see BASELINE.md)
+- several typical classification models, e.g., ResNet and DenseNet (see [BASELINE.md](https://github.com/D-X-Y/NAS-Projects/blob/master/BASELINE.md))
 
 
 ## Requirements and Preparation
lib/models/CifarDenseNet.py (new file, 105 lines)
@@ -0,0 +1,105 @@
+##################################################
+# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
+##################################################
+import math, torch
+import torch.nn as nn
+import torch.nn.functional as F
+from .initialization import initialize_resnet
+
+
+class Bottleneck(nn.Module):
+  def __init__(self, nChannels, growthRate):
+    super(Bottleneck, self).__init__()
+    interChannels = 4*growthRate
+    self.bn1 = nn.BatchNorm2d(nChannels)
+    self.conv1 = nn.Conv2d(nChannels, interChannels, kernel_size=1, bias=False)
+    self.bn2 = nn.BatchNorm2d(interChannels)
+    self.conv2 = nn.Conv2d(interChannels, growthRate, kernel_size=3, padding=1, bias=False)
+
+  def forward(self, x):
+    out = self.conv1(F.relu(self.bn1(x)))
+    out = self.conv2(F.relu(self.bn2(out)))
+    out = torch.cat((x, out), 1)
+    return out
+
+
+class SingleLayer(nn.Module):
+  def __init__(self, nChannels, growthRate):
+    super(SingleLayer, self).__init__()
+    self.bn1 = nn.BatchNorm2d(nChannels)
+    self.conv1 = nn.Conv2d(nChannels, growthRate, kernel_size=3, padding=1, bias=False)
+
+  def forward(self, x):
+    out = self.conv1(F.relu(self.bn1(x)))
+    out = torch.cat((x, out), 1)
+    return out
+
+
+class Transition(nn.Module):
+  def __init__(self, nChannels, nOutChannels):
+    super(Transition, self).__init__()
+    self.bn1 = nn.BatchNorm2d(nChannels)
+    self.conv1 = nn.Conv2d(nChannels, nOutChannels, kernel_size=1, bias=False)
+
+  def forward(self, x):
+    out = self.conv1(F.relu(self.bn1(x)))
+    out = F.avg_pool2d(out, 2)
+    return out
+
+
+class DenseNet(nn.Module):
+  def __init__(self, growthRate, depth, reduction, nClasses, bottleneck):
+    super(DenseNet, self).__init__()
+
+    if bottleneck:  nDenseBlocks = int( (depth-4) / 6 )
+    else         :  nDenseBlocks = int( (depth-4) / 3 )
+
+    self.message = 'CifarDenseNet : block : {:}, depth : {:}, reduction : {:}, growth-rate = {:}, class = {:}'.format('bottleneck' if bottleneck else 'basic', depth, reduction, growthRate, nClasses)
+
+    nChannels = 2*growthRate
+    self.conv1 = nn.Conv2d(3, nChannels, kernel_size=3, padding=1, bias=False)
+
+    self.dense1 = self._make_dense(nChannels, growthRate, nDenseBlocks, bottleneck)
+    nChannels += nDenseBlocks*growthRate
+    nOutChannels = int(math.floor(nChannels*reduction))
+    self.trans1 = Transition(nChannels, nOutChannels)
+
+    nChannels = nOutChannels
+    self.dense2 = self._make_dense(nChannels, growthRate, nDenseBlocks, bottleneck)
+    nChannels += nDenseBlocks*growthRate
+    nOutChannels = int(math.floor(nChannels*reduction))
+    self.trans2 = Transition(nChannels, nOutChannels)
+
+    nChannels = nOutChannels
+    self.dense3 = self._make_dense(nChannels, growthRate, nDenseBlocks, bottleneck)
+    nChannels += nDenseBlocks*growthRate
+
+    self.act = nn.Sequential(
+                  nn.BatchNorm2d(nChannels), nn.ReLU(inplace=True),
+                  nn.AvgPool2d(8))
+    self.fc  = nn.Linear(nChannels, nClasses)
+
+    self.apply(initialize_resnet)
+
+  def get_message(self):
+    return self.message
+
+  def _make_dense(self, nChannels, growthRate, nDenseBlocks, bottleneck):
+    layers = []
+    for i in range(int(nDenseBlocks)):
+      if bottleneck:
+        layers.append(Bottleneck(nChannels, growthRate))
+      else:
+        layers.append(SingleLayer(nChannels, growthRate))
+      nChannels += growthRate
+    return nn.Sequential(*layers)
+
+  def forward(self, inputs):
+    out = self.conv1( inputs )
+    out = self.trans1(self.dense1(out))
+    out = self.trans2(self.dense2(out))
+    out = self.dense3(out)
+    features = self.act(out)
+    features = features.view(features.size(0), -1)
+    out = self.fc(features)
+    return features, out
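
For orientation, a minimal usage sketch of the new DenseNet class follows. The hyper-parameter values (a DenseNet-BC-style setup: depth 100, growth rate 12, compression 0.5) and the import path are illustrative assumptions, not taken from the repository's configs; note that forward returns the pooled features together with the logits.

import torch
# Assumed import path; in the repository the class is reached via get_cifar_models.
from lib.models.CifarDenseNet import DenseNet

net = DenseNet(growthRate=12, depth=100, reduction=0.5, nClasses=10, bottleneck=True)
print(net.get_message())               # human-readable architecture summary
x = torch.randn(2, 3, 32, 32)          # a CIFAR-sized input batch
features, logits = net(x)              # forward returns (features, logits)
print(features.shape, logits.shape)    # [2, 342] and [2, 10] for this setting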
@@ -38,12 +38,15 @@ def get_search_spaces(xtype, name):
 
 def get_cifar_models(config):
   from .CifarResNet      import CifarResNet
+  from .CifarDenseNet    import DenseNet
   from .CifarWideResNet  import CifarWideResNet
 
   super_type = getattr(config, 'super_type', 'basic')
   if super_type == 'basic':
     if config.arch == 'resnet':
       return CifarResNet(config.module, config.depth, config.class_num, config.zero_init_residual)
+    elif config.arch == 'densenet':
+      return DenseNet(config.growthRate, config.depth, config.reduction, config.class_num, config.bottleneck)
     elif config.arch == 'wideresnet':
       return CifarWideResNet(config.depth, config.wide_factor, config.class_num, config.dropout)
     else:
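
The new 'densenet' branch reads five fields from the config object. A hypothetical stand-in config is sketched below; the repository loads real configs from files, and SimpleNamespace is used here purely for illustration.

from types import SimpleNamespace

config = SimpleNamespace(
  super_type='basic',   # also the default via getattr(config, 'super_type', 'basic')
  arch='densenet',
  growthRate=12, depth=100, reduction=0.5,
  class_num=10, bottleneck=True,
)
net = get_cifar_models(config)  # dispatches to DenseNet(12, 100, 0.5, 10, True)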
@@ -68,8 +71,13 @@ def get_cifar_models(config):
 
 def get_imagenet_models(config):
   super_type = getattr(config, 'super_type', 'basic')
-  # NAS searched architecture
-  if super_type.startswith('infer'):
+  if super_type == 'basic':
+    from .ImagenetResNet import ResNet
+    if config.arch == 'resnet':
+      return ResNet(config.block_name, config.layers, config.deep_stem, config.class_num, config.zero_init_residual, config.groups, config.width_per_group)
+    else:
+      raise ValueError('invalid arch : {:}'.format( config.arch ))
+  elif super_type.startswith('infer'): # NAS searched architecture
     assert len(super_type.split('-')) == 2, 'invalid super_type : {:}'.format(super_type)
     infer_mode = super_type.split('-')[1]
     if infer_mode == 'shape':
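
The reworked get_imagenet_models now dispatches on super_type: plain configs take the new 'basic' branch, while values such as 'infer-shape' keep taking the NAS-searched branch. Below is a hedged sketch of the 'basic' path with illustrative field values; the block_name/layers pair mimics a ResNet-50-style spec, and the exact accepted values depend on ImagenetResNet.

from types import SimpleNamespace

cfg = SimpleNamespace(
  super_type='basic', arch='resnet',
  block_name='Bottleneck', layers=[3, 4, 6, 3],   # assumed ResNet-50-style spec
  deep_stem=False, class_num=1000,
  zero_init_residual=True, groups=1, width_per_group=64,
)
net = get_imagenet_models(cfg)
# A super_type such as 'infer-shape' would instead pass the 'x-y' format
# assert and select the builder via super_type.split('-')[1].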