##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

from .initialization import initialize_resnet


class Bottleneck(nn.Module):
    """Dense unit with a 1x1 bottleneck convolution before the 3x3 convolution.

    The input goes through BN -> ReLU -> 1x1 conv (to ``4 * growthRate``
    channels) and then BN -> ReLU -> 3x3 conv (to ``growthRate`` channels).
    The result is concatenated with the input along dim 1, so the output
    carries ``growthRate`` more channels than the input.
    """

    def __init__(self, nChannels, growthRate):
        super().__init__()
        hidden = 4 * growthRate
        self.bn1 = nn.BatchNorm2d(nChannels)
        self.conv1 = nn.Conv2d(nChannels, hidden, kernel_size=1, bias=False)
        self.bn2 = nn.BatchNorm2d(hidden)
        self.conv2 = nn.Conv2d(
            hidden, growthRate, kernel_size=3, padding=1, bias=False
        )

    def forward(self, x):
        squeezed = self.conv1(F.relu(self.bn1(x)))
        grown = self.conv2(F.relu(self.bn2(squeezed)))
        # Dense connectivity: keep the input and append the new feature maps.
        return torch.cat((x, grown), 1)


class SingleLayer(nn.Module):
    """Dense unit without a bottleneck: BN -> ReLU -> 3x3 conv.

    The ``growthRate`` new feature maps are concatenated with the input
    along dim 1.
    """

    def __init__(self, nChannels, growthRate):
        super().__init__()
        self.bn1 = nn.BatchNorm2d(nChannels)
        self.conv1 = nn.Conv2d(
            nChannels, growthRate, kernel_size=3, padding=1, bias=False
        )

    def forward(self, x):
        grown = self.conv1(F.relu(self.bn1(x)))
        return torch.cat((x, grown), 1)


class Transition(nn.Module):
    """Layer placed between dense stages.

    Applies BN -> ReLU -> 1x1 conv (``nChannels`` -> ``nOutChannels``)
    followed by average pooling with a window of 2.
    """

    def __init__(self, nChannels, nOutChannels):
        super().__init__()
        self.bn1 = nn.BatchNorm2d(nChannels)
        self.conv1 = nn.Conv2d(nChannels, nOutChannels, kernel_size=1, bias=False)

    def forward(self, x):
        projected = self.conv1(F.relu(self.bn1(x)))
        return F.avg_pool2d(projected, 2)


class DenseNet(nn.Module):
    """DenseNet with three dense stages separated by two transitions.

    Args:
        growthRate: number of feature maps each dense unit adds.
        depth: nominal network depth; the number of units per dense stage
            is derived from it (see ``__init__``).
        reduction: factor applied to the channel count at each transition
            (the result is floored to an integer).
        nClasses: output size of the final linear layer.
        bottleneck: if truthy, stages are built from ``Bottleneck`` units,
            otherwise from ``SingleLayer`` units.
    """

    def __init__(self, growthRate, depth, reduction, nClasses, bottleneck):
        super().__init__()
        # Outside the dense stages the network has four weight layers (stem
        # conv, two transition convs, the linear layer).  The remaining depth
        # is spread over three stages; a bottleneck unit holds two convs and
        # a basic unit holds one, hence the divisor of 6 or 3.
        convs_per_unit_index = 6 if bottleneck else 3
        nDenseBlocks = int((depth - 4) / convs_per_unit_index)

        block_kind = 'bottleneck' if bottleneck else 'basic'
        self.message = 'CifarDenseNet : block : {:}, depth : {:}, reduction : {:}, growth-rate = {:}, class = {:}'.format(block_kind, depth, reduction, growthRate, nClasses)

        # Stem: 3 input channels -> 2 * growthRate feature maps.
        width = 2 * growthRate
        self.conv1 = nn.Conv2d(3, width, kernel_size=3, padding=1, bias=False)

        # Stage 1 followed by its transition.
        self.dense1 = self._make_dense(width, growthRate, nDenseBlocks, bottleneck)
        width += nDenseBlocks * growthRate
        reduced = int(math.floor(width * reduction))
        self.trans1 = Transition(width, reduced)
        width = reduced

        # Stage 2 followed by its transition.
        self.dense2 = self._make_dense(width, growthRate, nDenseBlocks, bottleneck)
        width += nDenseBlocks * growthRate
        reduced = int(math.floor(width * reduction))
        self.trans2 = Transition(width, reduced)
        width = reduced

        # Stage 3 has no transition after it.
        self.dense3 = self._make_dense(width, growthRate, nDenseBlocks, bottleneck)
        width += nDenseBlocks * growthRate

        # Head: BN -> ReLU -> 8x8 average pooling.  The linear layer takes
        # exactly `width` inputs, so the pooled map must be 1x1 per channel.
        # NOTE(review): presumably sized for 32x32 inputs -- confirm.
        self.act = nn.Sequential(
            nn.BatchNorm2d(width),
            nn.ReLU(inplace=True),
            nn.AvgPool2d(8),
        )
        self.fc = nn.Linear(width, nClasses)

        # Run the project's initializer on this module and every submodule.
        self.apply(initialize_resnet)

    def get_message(self):
        """Return the configuration summary string built in ``__init__``."""
        return self.message

    def _make_dense(self, nChannels, growthRate, nDenseBlocks, bottleneck):
        """Build one dense stage as an ``nn.Sequential`` of dense units.

        Each unit sees the stage input plus the ``growthRate`` channels
        added by every unit before it.
        """
        unit_cls = Bottleneck if bottleneck else SingleLayer
        units = [
            unit_cls(nChannels + idx * growthRate, growthRate)
            for idx in range(int(nDenseBlocks))
        ]
        return nn.Sequential(*units)

    def forward(self, inputs):
        """Run the network.

        Returns:
            A tuple ``(features, out)``: the flattened output of the pooling
            head, and the output of the final linear layer applied to it.
        """
        x = self.conv1(inputs)
        x = self.dense1(x)
        x = self.trans1(x)
        x = self.dense2(x)
        x = self.trans2(x)
        x = self.dense3(x)
        pooled = self.act(x)
        features = pooled.view(pooled.size(0), -1)
        out = self.fc(features)
        return features, out