106 lines
		
	
	
		
			3.6 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			106 lines
		
	
	
		
			3.6 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
##################################################
 | 
						|
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
 | 
						|
##################################################
 | 
						|
import math, torch
 | 
						|
import torch.nn as nn
 | 
						|
import torch.nn.functional as F
 | 
						|
from .initialization import initialize_resnet
 | 
						|
 | 
						|
 | 
						|
class Bottleneck(nn.Module):
 | 
						|
  def __init__(self, nChannels, growthRate):
 | 
						|
    super(Bottleneck, self).__init__()
 | 
						|
    interChannels = 4*growthRate
 | 
						|
    self.bn1 = nn.BatchNorm2d(nChannels)
 | 
						|
    self.conv1 = nn.Conv2d(nChannels, interChannels, kernel_size=1, bias=False)
 | 
						|
    self.bn2 = nn.BatchNorm2d(interChannels)
 | 
						|
    self.conv2 = nn.Conv2d(interChannels, growthRate, kernel_size=3, padding=1, bias=False)
 | 
						|
 | 
						|
  def forward(self, x):
 | 
						|
    out = self.conv1(F.relu(self.bn1(x)))
 | 
						|
    out = self.conv2(F.relu(self.bn2(out)))
 | 
						|
    out = torch.cat((x, out), 1)
 | 
						|
    return out
 | 
						|
 | 
						|
 | 
						|
class SingleLayer(nn.Module):
 | 
						|
  def __init__(self, nChannels, growthRate):
 | 
						|
    super(SingleLayer, self).__init__()
 | 
						|
    self.bn1 = nn.BatchNorm2d(nChannels)
 | 
						|
    self.conv1 = nn.Conv2d(nChannels, growthRate, kernel_size=3, padding=1, bias=False)
 | 
						|
 | 
						|
  def forward(self, x):
 | 
						|
    out = self.conv1(F.relu(self.bn1(x)))
 | 
						|
    out = torch.cat((x, out), 1)
 | 
						|
    return out
 | 
						|
 | 
						|
 | 
						|
class Transition(nn.Module):
 | 
						|
  def __init__(self, nChannels, nOutChannels):
 | 
						|
    super(Transition, self).__init__()
 | 
						|
    self.bn1 = nn.BatchNorm2d(nChannels)
 | 
						|
    self.conv1 = nn.Conv2d(nChannels, nOutChannels, kernel_size=1, bias=False)
 | 
						|
 | 
						|
  def forward(self, x):
 | 
						|
    out = self.conv1(F.relu(self.bn1(x)))
 | 
						|
    out = F.avg_pool2d(out, 2)
 | 
						|
    return out
 | 
						|
 | 
						|
 | 
						|
class DenseNet(nn.Module):
 | 
						|
  def __init__(self, growthRate, depth, reduction, nClasses, bottleneck):
 | 
						|
    super(DenseNet, self).__init__()
 | 
						|
 | 
						|
    if bottleneck:  nDenseBlocks = int( (depth-4) / 6 )
 | 
						|
    else         :  nDenseBlocks = int( (depth-4) / 3 )
 | 
						|
 | 
						|
    self.message = 'CifarDenseNet : block : {:}, depth : {:}, reduction : {:}, growth-rate = {:}, class = {:}'.format('bottleneck' if bottleneck else 'basic', depth, reduction, growthRate, nClasses)
 | 
						|
 | 
						|
    nChannels = 2*growthRate
 | 
						|
    self.conv1 = nn.Conv2d(3, nChannels, kernel_size=3, padding=1, bias=False)
 | 
						|
 | 
						|
    self.dense1 = self._make_dense(nChannels, growthRate, nDenseBlocks, bottleneck)
 | 
						|
    nChannels += nDenseBlocks*growthRate
 | 
						|
    nOutChannels = int(math.floor(nChannels*reduction))
 | 
						|
    self.trans1 = Transition(nChannels, nOutChannels)
 | 
						|
 | 
						|
    nChannels = nOutChannels
 | 
						|
    self.dense2 = self._make_dense(nChannels, growthRate, nDenseBlocks, bottleneck)
 | 
						|
    nChannels += nDenseBlocks*growthRate
 | 
						|
    nOutChannels = int(math.floor(nChannels*reduction))
 | 
						|
    self.trans2 = Transition(nChannels, nOutChannels)
 | 
						|
 | 
						|
    nChannels = nOutChannels
 | 
						|
    self.dense3 = self._make_dense(nChannels, growthRate, nDenseBlocks, bottleneck)
 | 
						|
    nChannels += nDenseBlocks*growthRate
 | 
						|
 | 
						|
    self.act = nn.Sequential(
 | 
						|
                  nn.BatchNorm2d(nChannels), nn.ReLU(inplace=True),
 | 
						|
                  nn.AvgPool2d(8))
 | 
						|
    self.fc  = nn.Linear(nChannels, nClasses)
 | 
						|
 | 
						|
    self.apply(initialize_resnet)
 | 
						|
 | 
						|
  def get_message(self):
 | 
						|
    return self.message
 | 
						|
 | 
						|
  def _make_dense(self, nChannels, growthRate, nDenseBlocks, bottleneck):
 | 
						|
    layers = []
 | 
						|
    for i in range(int(nDenseBlocks)):
 | 
						|
      if bottleneck:
 | 
						|
        layers.append(Bottleneck(nChannels, growthRate))
 | 
						|
      else:
 | 
						|
        layers.append(SingleLayer(nChannels, growthRate))
 | 
						|
      nChannels += growthRate
 | 
						|
    return nn.Sequential(*layers)
 | 
						|
 | 
						|
  def forward(self, inputs):
 | 
						|
    out = self.conv1( inputs )
 | 
						|
    out = self.trans1(self.dense1(out))
 | 
						|
    out = self.trans2(self.dense2(out))
 | 
						|
    out = self.dense3(out)
 | 
						|
    features = self.act(out)
 | 
						|
    features = features.view(features.size(0), -1)
 | 
						|
    out = self.fc(features)
 | 
						|
    return features, out
 |