update README
README.md
@@ -5,7 +5,7 @@ This project contains the following neural architecture search algorithms, imple
 - Network Pruning via Transformable Architecture Search, NeurIPS 2019
 - One-Shot Neural Architecture Search via Self-Evaluated Template Network, ICCV 2019
 - Searching for A Robust Neural Architecture in Four GPU Hours, CVPR 2019
-- several typical classification models, e.g., ResNet and DenseNet (see BASELINE.md)
+- several typical classification models, e.g., ResNet and DenseNet (see [BASELINE.md](https://github.com/D-X-Y/NAS-Projects/blob/master/BASELINE.md))


 ## Requirements and Preparation
lib/models/CifarDenseNet.py (new file, 105 lines)
@@ -0,0 +1,105 @@
+##################################################
+# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
+##################################################
+import math, torch
+import torch.nn as nn
+import torch.nn.functional as F
+from .initialization import initialize_resnet
+
+
+class Bottleneck(nn.Module):
+  def __init__(self, nChannels, growthRate):
+    super(Bottleneck, self).__init__()
+    interChannels = 4*growthRate
+    self.bn1 = nn.BatchNorm2d(nChannels)
+    self.conv1 = nn.Conv2d(nChannels, interChannels, kernel_size=1, bias=False)
+    self.bn2 = nn.BatchNorm2d(interChannels)
+    self.conv2 = nn.Conv2d(interChannels, growthRate, kernel_size=3, padding=1, bias=False)
+
+  def forward(self, x):
+    out = self.conv1(F.relu(self.bn1(x)))
+    out = self.conv2(F.relu(self.bn2(out)))
+    out = torch.cat((x, out), 1)
+    return out
+
+
+class SingleLayer(nn.Module):
+  def __init__(self, nChannels, growthRate):
+    super(SingleLayer, self).__init__()
+    self.bn1 = nn.BatchNorm2d(nChannels)
+    self.conv1 = nn.Conv2d(nChannels, growthRate, kernel_size=3, padding=1, bias=False)
+
+  def forward(self, x):
+    out = self.conv1(F.relu(self.bn1(x)))
+    out = torch.cat((x, out), 1)
+    return out
+
+
+class Transition(nn.Module):
+  def __init__(self, nChannels, nOutChannels):
+    super(Transition, self).__init__()
+    self.bn1 = nn.BatchNorm2d(nChannels)
+    self.conv1 = nn.Conv2d(nChannels, nOutChannels, kernel_size=1, bias=False)
+
+  def forward(self, x):
+    out = self.conv1(F.relu(self.bn1(x)))
+    out = F.avg_pool2d(out, 2)
+    return out
+
+
+class DenseNet(nn.Module):
+  def __init__(self, growthRate, depth, reduction, nClasses, bottleneck):
+    super(DenseNet, self).__init__()
+
+    if bottleneck:  nDenseBlocks = int( (depth-4) / 6 )
+    else         :  nDenseBlocks = int( (depth-4) / 3 )
+
+    self.message = 'CifarDenseNet : block : {:}, depth : {:}, reduction : {:}, growth-rate = {:}, class = {:}'.format('bottleneck' if bottleneck else 'basic', depth, reduction, growthRate, nClasses)
+
+    nChannels = 2*growthRate
+    self.conv1 = nn.Conv2d(3, nChannels, kernel_size=3, padding=1, bias=False)
+
+    self.dense1 = self._make_dense(nChannels, growthRate, nDenseBlocks, bottleneck)
+    nChannels += nDenseBlocks*growthRate
+    nOutChannels = int(math.floor(nChannels*reduction))
+    self.trans1 = Transition(nChannels, nOutChannels)
+
+    nChannels = nOutChannels
+    self.dense2 = self._make_dense(nChannels, growthRate, nDenseBlocks, bottleneck)
+    nChannels += nDenseBlocks*growthRate
+    nOutChannels = int(math.floor(nChannels*reduction))
+    self.trans2 = Transition(nChannels, nOutChannels)
+
+    nChannels = nOutChannels
+    self.dense3 = self._make_dense(nChannels, growthRate, nDenseBlocks, bottleneck)
+    nChannels += nDenseBlocks*growthRate
+
+    self.act = nn.Sequential(
+                  nn.BatchNorm2d(nChannels), nn.ReLU(inplace=True),
+                  nn.AvgPool2d(8))
+    self.fc  = nn.Linear(nChannels, nClasses)
+
+    self.apply(initialize_resnet)
+
+  def get_message(self):
+    return self.message
+
+  def _make_dense(self, nChannels, growthRate, nDenseBlocks, bottleneck):
+    layers = []
+    for i in range(int(nDenseBlocks)):
+      if bottleneck:
+        layers.append(Bottleneck(nChannels, growthRate))
+      else:
+        layers.append(SingleLayer(nChannels, growthRate))
+      nChannels += growthRate
+    return nn.Sequential(*layers)
+
+  def forward(self, inputs):
+    out = self.conv1( inputs )
+    out = self.trans1(self.dense1(out))
+    out = self.trans2(self.dense2(out))
+    out = self.dense3(out)
+    features = self.act(out)
+    features = features.view(features.size(0), -1)
+    out = self.fc(features)
+    return features, out
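As a sanity check on the new model, here is a minimal usage sketch (not part of the commit). It assumes the repository root is on `PYTHONPATH`, and uses DenseNet-BC-style hyper-parameters (depth 100, growth rate 12) that are illustrative rather than taken from this repo's config files:

```python
import torch
from lib.models.CifarDenseNet import DenseNet

# Illustrative DenseNet-BC settings for CIFAR-10 (assumed, not from this commit's configs).
net = DenseNet(growthRate=12, depth=100, reduction=0.5, nClasses=10, bottleneck=True)
print(net.get_message())

x = torch.randn(2, 3, 32, 32)        # a batch of CIFAR-sized 32x32 RGB images
features, logits = net(x)            # forward returns (features, logits)
print(features.shape, logits.shape)  # torch.Size([2, 342]) torch.Size([2, 10])
```

With these settings each dense block has (100-4)/6 = 16 layers, and the channel count grows 24 → 216 → 108 → 300 → 150 → 342 through the three blocks and two transitions, which is where the 342-dimensional feature vector comes from.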
@@ -38,12 +38,15 @@ def get_search_spaces(xtype, name):
 
 def get_cifar_models(config):
   from .CifarResNet      import CifarResNet
+  from .CifarDenseNet    import DenseNet
   from .CifarWideResNet  import CifarWideResNet
 
   super_type = getattr(config, 'super_type', 'basic')
   if super_type == 'basic':
     if config.arch == 'resnet':
       return CifarResNet(config.module, config.depth, config.class_num, config.zero_init_residual)
+    elif config.arch == 'densenet':
+      return DenseNet(config.growthRate, config.depth, config.reduction, config.class_num, config.bottleneck)
+    elif config.arch == 'wideresnet':
     elif config.arch == 'wideresnet':
       return CifarWideResNet(config.depth, config.wide_factor, config.class_num, config.dropout)
     else:
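The new `'densenet'` branch reads five fields off the config object. This diff does not show the repo's actual config format, so the following stand-in sketch (a `SimpleNamespace` with assumed values) only makes the required fields explicit:

```python
from types import SimpleNamespace
from lib.models import get_cifar_models  # assuming the factory lives in lib/models/__init__.py

# Hypothetical config object: field names match the attributes the
# 'densenet' branch reads above; the values themselves are assumptions.
config = SimpleNamespace(super_type='basic', arch='densenet',
                         growthRate=12, depth=100, reduction=0.5,
                         class_num=10, bottleneck=True)
net = get_cifar_models(config)  # dispatches to DenseNet(growthRate, depth, ...)
```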
@@ -68,8 +71,13 @@ def get_cifar_models(config):
 
 def get_imagenet_models(config):
   super_type = getattr(config, 'super_type', 'basic')
-  # NAS searched architecture
-  if super_type.startswith('infer'):
+  if super_type == 'basic':
+    from .ImagenetResNet import ResNet
+    if config.arch == 'resnet':
+      return ResNet(config.block_name, config.layers, config.deep_stem, config.class_num, config.zero_init_residual, config.groups, config.width_per_group)
+    else:
+      raise ValueError('invalid arch : {:}'.format( config.arch ))
+  elif super_type.startswith('infer'): # NAS searched architecture
     assert len(super_type.split('-')) == 2, 'invalid super_type : {:}'.format(super_type)
     infer_mode = super_type.split('-')[1]
     if infer_mode == 'shape':
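This hunk restructures `get_imagenet_models`: the NAS-inference branch previously guarded the whole function, and the commit adds an explicit `super_type == 'basic'` branch for standard ImageNet ResNets ahead of it. A hypothetical config for the new branch might look as follows; the field names mirror the attributes read above, while the ResNet-50-style values and the `block_name`/`layers` semantics are assumptions, since `ImagenetResNet.ResNet` is not shown in this diff:

```python
from types import SimpleNamespace
from lib.models import get_imagenet_models  # assuming the same module as above

# Hypothetical ImageNet config; values are ResNet-50-style guesses, not from the repo.
config = SimpleNamespace(super_type='basic', arch='resnet',
                         block_name='Bottleneck', layers=[3, 4, 6, 3],
                         deep_stem=False, class_num=1000,
                         zero_init_residual=True, groups=1, width_per_group=64)
net = get_imagenet_models(config)
```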