106 lines
3.6 KiB
Python
106 lines
3.6 KiB
Python
##################################################
|
|
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
|
|
##################################################
|
|
import math, torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
from .initialization import initialize_resnet
|
|
|
|
|
|
class Bottleneck(nn.Module):
    """DenseNet-BC composite layer: BN-ReLU-Conv1x1, then BN-ReLU-Conv3x3.

    The 1x1 convolution widens the input to ``4 * growthRate`` channels. The
    3x3 convolution (padding 1, so spatial size is kept) then produces
    ``growthRate`` new feature maps. Those maps are concatenated after the
    unchanged input along dim 1, so the output has
    ``nChannels + growthRate`` channels.

    Args:
        nChannels: number of input channels.
        growthRate: number of new channels this layer contributes.
    """

    def __init__(self, nChannels, growthRate):
        super().__init__()
        width = growthRate * 4
        # Creation order (bn1, conv1, bn2, conv2) fixes the state_dict layout
        # and the order in which parameters consume the RNG at init time.
        self.bn1 = nn.BatchNorm2d(nChannels)
        self.conv1 = nn.Conv2d(nChannels, width, kernel_size=1, bias=False)
        self.bn2 = nn.BatchNorm2d(width)
        self.conv2 = nn.Conv2d(width, growthRate, kernel_size=3, padding=1, bias=False)

    def forward(self, x):
        """Return ``cat([x, new_features], dim=1)``."""
        hidden = self.conv1(F.relu(self.bn1(x)))
        new_features = self.conv2(F.relu(self.bn2(hidden)))
        return torch.cat([x, new_features], dim=1)
|
|
|
|
|
|
class SingleLayer(nn.Module):
    """Basic (non-bottleneck) dense layer: BN-ReLU-Conv3x3.

    The 3x3 convolution (padding 1, spatial size kept) produces
    ``growthRate`` new maps. They are appended to the input along dim 1, so
    the output has ``nChannels + growthRate`` channels.

    Args:
        nChannels: number of input channels.
        growthRate: number of new channels this layer contributes.
    """

    def __init__(self, nChannels, growthRate):
        super().__init__()
        self.bn1 = nn.BatchNorm2d(nChannels)
        self.conv1 = nn.Conv2d(
            nChannels, growthRate, kernel_size=3, padding=1, bias=False
        )

    def forward(self, x):
        """Return ``cat([x, conv1(relu(bn1(x)))], dim=1)``."""
        activated = F.relu(self.bn1(x))
        return torch.cat([x, self.conv1(activated)], dim=1)
|
|
|
|
|
|
class Transition(nn.Module):
    """Transition between dense blocks: BN-ReLU-Conv1x1, then 2x2 average pooling.

    The 1x1 convolution maps ``nChannels`` to ``nOutChannels``. The pooling
    (kernel 2, stride defaulting to the kernel size) halves both spatial
    dimensions, rounding down.

    Args:
        nChannels: number of input channels.
        nOutChannels: number of output channels.
    """

    def __init__(self, nChannels, nOutChannels):
        super().__init__()
        self.bn1 = nn.BatchNorm2d(nChannels)
        self.conv1 = nn.Conv2d(nChannels, nOutChannels, kernel_size=1, bias=False)

    def forward(self, x):
        """Compress channels with the 1x1 conv, then downsample by 2."""
        compressed = self.conv1(F.relu(self.bn1(x)))
        return F.avg_pool2d(compressed, 2)
|
|
|
|
|
|
class DenseNet(nn.Module):
    """Three-stage DenseNet classifier (reported as "CifarDenseNet").

    Pipeline: 3x3 stem conv -> dense1 -> trans1 -> dense2 -> trans2 -> dense3
    -> BN -> ReLU -> global average pool -> linear classifier.

    Args:
        growthRate: channels added by every dense layer; the stem conv emits
            ``2 * growthRate`` channels.
        depth: nominal depth. ``depth - 4`` is split evenly over the three
            dense blocks, with a bottleneck layer counting as two convs. Each
            block therefore holds ``int((depth - 4) / 6)`` bottleneck layers
            or ``int((depth - 4) / 3)`` basic layers.
        reduction: transition compression factor; a transition maps ``C``
            channels to ``floor(C * reduction)``.
        nClasses: size of the classifier output.
        bottleneck: build dense blocks from ``Bottleneck`` layers when true,
            otherwise from ``SingleLayer``.
    """

    def __init__(self, growthRate, depth, reduction, nClasses, bottleneck):
        super(DenseNet, self).__init__()

        # int() truncates toward zero; kept as-is so the channel bookkeeping
        # below is unchanged for every depth value.
        if bottleneck:
            nDenseBlocks = int((depth - 4) / 6)
        else:
            nDenseBlocks = int((depth - 4) / 3)

        self.message = 'CifarDenseNet : block : {:}, depth : {:}, reduction : {:}, growth-rate = {:}, class = {:}'.format('bottleneck' if bottleneck else 'basic', depth, reduction, growthRate, nClasses)

        # Module creation order is preserved: it fixes the state_dict layout
        # and the order in which parameters consume the RNG at init time.
        nChannels = 2 * growthRate
        self.conv1 = nn.Conv2d(3, nChannels, kernel_size=3, padding=1, bias=False)

        # Each dense layer appends growthRate channels to its input.
        self.dense1 = self._make_dense(nChannels, growthRate, nDenseBlocks, bottleneck)
        nChannels += nDenseBlocks * growthRate
        nOutChannels = int(math.floor(nChannels * reduction))
        self.trans1 = Transition(nChannels, nOutChannels)

        nChannels = nOutChannels
        self.dense2 = self._make_dense(nChannels, growthRate, nDenseBlocks, bottleneck)
        nChannels += nDenseBlocks * growthRate
        nOutChannels = int(math.floor(nChannels * reduction))
        self.trans2 = Transition(nChannels, nOutChannels)

        nChannels = nOutChannels
        self.dense3 = self._make_dense(nChannels, growthRate, nDenseBlocks, bottleneck)
        nChannels += nDenseBlocks * growthRate

        # Global average pooling down to 1x1.
        # - The stem and dense layers keep the spatial size; each transition
        #   halves it. For 32x32 inputs the map here is therefore 8x8, and
        #   this is identical to the former nn.AvgPool2d(8).
        # - Unlike a fixed 8x8 kernel, it also pools the whole map at other
        #   resolutions. The fixed kernel would either crash in self.fc or
        #   silently pool only a corner of the map.
        # - It has no parameters, so state_dict keys and self.act indices
        #   are unchanged.
        self.act = nn.Sequential(
            nn.BatchNorm2d(nChannels),
            nn.ReLU(inplace=True),
            nn.AdaptiveAvgPool2d(1),
        )
        self.fc = nn.Linear(nChannels, nClasses)

        self.apply(initialize_resnet)

    def get_message(self):
        """Return the one-line configuration summary built in ``__init__``."""
        return self.message

    def _make_dense(self, nChannels, growthRate, nDenseBlocks, bottleneck):
        """Stack ``nDenseBlocks`` dense layers into an ``nn.Sequential``.

        Layer ``i`` receives ``nChannels + i * growthRate`` input channels,
        because every earlier layer concatenated ``growthRate`` new maps.
        A non-positive ``nDenseBlocks`` yields an empty Sequential.
        """
        layer_cls = Bottleneck if bottleneck else SingleLayer
        layers = [
            layer_cls(nChannels + i * growthRate, growthRate)
            for i in range(int(nDenseBlocks))
        ]
        return nn.Sequential(*layers)

    def forward(self, inputs):
        """Run the network on a batch of 3-channel images.

        Returns:
            ``(features, logits)``. ``features`` is the pooled representation
            flattened to ``(N, nChannels)``, and ``logits = self.fc(features)``.
        """
        out = self.conv1(inputs)
        out = self.trans1(self.dense1(out))
        out = self.trans2(self.dense2(out))
        out = self.dense3(out)
        features = self.act(out).flatten(1)
        out = self.fc(features)
        return features, out
|