278 lines
		
	
	
		
			8.3 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			278 lines
		
	
	
		
			8.3 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| #####################################################
 | |
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
 | |
| #####################################################
 | |
| import torch.nn as nn
 | |
| import torch.nn.functional as F
 | |
| from ..initialization import initialize_resnet
 | |
| 
 | |
| 
 | |
class ConvBNReLU(nn.Module):
    """Optional 2x2 avg-pool -> Conv2d -> optional BatchNorm2d -> optional ReLU.

    Args:
        nIn: input channel count of the convolution.
        nOut: output channel count of the convolution (and of the BatchNorm).
        kernel: convolution kernel size.
        stride: convolution stride.
        padding: convolution padding.
        bias: whether the convolution has a bias term.
        has_avg: prepend ``AvgPool2d(kernel_size=2, stride=2, padding=0)``.
        has_bn: apply ``BatchNorm2d(nOut)`` after the convolution.
        has_relu: finish with an in-place ``ReLU``.

    Disabled stages are stored as ``None`` attributes (``avg``, ``bn``, ``relu``).
    """

    def __init__(
        self, nIn, nOut, kernel, stride, padding, bias, has_avg, has_bn, has_relu
    ):
        super(ConvBNReLU, self).__init__()
        # Assignment order fixes submodule registration order (avg, conv, bn, relu).
        self.avg = (
            nn.AvgPool2d(kernel_size=2, stride=2, padding=0) if has_avg else None
        )
        self.conv = nn.Conv2d(
            nIn,
            nOut,
            kernel_size=kernel,
            stride=stride,
            padding=padding,
            dilation=1,
            groups=1,
            bias=bias,
        )
        self.bn = nn.BatchNorm2d(nOut) if has_bn else None
        self.relu = nn.ReLU(inplace=True) if has_relu else None

    def forward(self, inputs):
        """Run the enabled stages in order and return the resulting tensor."""
        x = inputs if self.avg is None else self.avg(inputs)
        x = self.conv(x)
        if self.bn is not None:
            x = self.bn(x)
        if self.relu is not None:
            # In-place is safe here: x is a fresh tensor produced by conv/bn.
            x = self.relu(x)
        return x
 | |
| 
 | |
| 
 | |
class ResNetBasicblock(nn.Module):
    """Two-conv residual block whose per-conv widths are given explicitly.

    ``iCs`` holds three channel counts: ``iCs[0]`` is the block input width,
    ``iCs[1]`` the width after the first 3x3 conv, and ``iCs[2]`` the width
    after the second 3x3 conv, which is also the block output (``out_dim``).

    Args:
        iCs: tuple or list of exactly three channel counts.
        stride: 1 or 2; applied by the first 3x3 conv.

    Raises:
        AssertionError: if ``stride`` or ``iCs`` is invalid.
    """

    num_conv = 2  # convs per block; consumed widths per block in InferWidthCifarResNet
    expansion = 1

    def __init__(self, iCs, stride):
        super(ResNetBasicblock, self).__init__()
        assert stride == 1 or stride == 2, "invalid stride {:}".format(stride)
        assert isinstance(iCs, (tuple, list)), "invalid type of iCs : {:}".format(
            iCs
        )
        assert len(iCs) == 3, "invalid lengths of iCs : {:}".format(iCs)

        self.conv_a = ConvBNReLU(
            iCs[0],
            iCs[1],
            3,
            stride,
            1,
            False,
            has_avg=False,
            has_bn=True,
            has_relu=True,
        )
        # No ReLU here: the activation is applied after the residual addition.
        self.conv_b = ConvBNReLU(
            iCs[1], iCs[2], 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=False
        )
        if stride == 2:
            # Shortcut halves resolution with a 2x2 average pool, then matches
            # width with a bias-free 1x1 conv (no BatchNorm).
            self.downsample = ConvBNReLU(
                iCs[0],
                iCs[2],
                1,
                1,
                0,
                False,
                has_avg=True,
                has_bn=False,
                has_relu=False,
            )
        elif iCs[0] != iCs[2]:
            # Same resolution but different width: 1x1 conv + BatchNorm projection.
            self.downsample = ConvBNReLU(
                iCs[0],
                iCs[2],
                1,
                1,
                0,
                False,
                has_avg=False,
                has_bn=True,
                has_relu=False,
            )
        else:
            # Identity shortcut.
            self.downsample = None
        self.out_dim = iCs[2]

    def forward(self, inputs):
        """Return ``relu(shortcut(inputs) + conv_b(conv_a(inputs)))``."""
        basicblock = self.conv_a(inputs)
        basicblock = self.conv_b(basicblock)

        if self.downsample is not None:
            residual = self.downsample(inputs)
        else:
            residual = inputs
        out = residual + basicblock
        # In-place is safe: `out` is a fresh tensor from the addition.
        return F.relu(out, inplace=True)
 | |
| 
 | |
| 
 | |
class ResNetBottleneck(nn.Module):
    """Three-conv (1x1 -> 3x3 -> 1x1) residual block with explicit widths.

    ``iCs`` holds four channel counts: ``iCs[0]`` is the block input width,
    ``iCs[1]`` the width after the first 1x1 conv, ``iCs[2]`` the width after
    the (possibly strided) 3x3 conv, and ``iCs[3]`` the width after the final
    1x1 conv, which is also the block output (``out_dim``).

    Args:
        iCs: tuple or list of exactly four channel counts.
        stride: 1 or 2; applied by the 3x3 conv.

    Raises:
        AssertionError: if ``stride`` or ``iCs`` is invalid.
    """

    expansion = 4
    num_conv = 3  # convs per block; consumed widths per block in InferWidthCifarResNet

    def __init__(self, iCs, stride):
        super(ResNetBottleneck, self).__init__()
        assert stride == 1 or stride == 2, "invalid stride {:}".format(stride)
        assert isinstance(iCs, (tuple, list)), "invalid type of iCs : {:}".format(
            iCs
        )
        assert len(iCs) == 4, "invalid lengths of iCs : {:}".format(iCs)
        self.conv_1x1 = ConvBNReLU(
            iCs[0], iCs[1], 1, 1, 0, False, has_avg=False, has_bn=True, has_relu=True
        )
        self.conv_3x3 = ConvBNReLU(
            iCs[1],
            iCs[2],
            3,
            stride,
            1,
            False,
            has_avg=False,
            has_bn=True,
            has_relu=True,
        )
        # Final 1x1 conv (attribute name kept for state_dict compatibility);
        # no ReLU because the activation follows the residual addition.
        self.conv_1x4 = ConvBNReLU(
            iCs[2], iCs[3], 1, 1, 0, False, has_avg=False, has_bn=True, has_relu=False
        )
        if stride == 2:
            # Shortcut halves resolution with a 2x2 average pool, then matches
            # width with a bias-free 1x1 conv (no BatchNorm).
            self.downsample = ConvBNReLU(
                iCs[0],
                iCs[3],
                1,
                1,
                0,
                False,
                has_avg=True,
                has_bn=False,
                has_relu=False,
            )
        elif iCs[0] != iCs[3]:
            # NOTE(review): unlike ResNetBasicblock, this width-only projection
            # has no BatchNorm. Kept as-is so the parameter set (and any saved
            # checkpoints) stays unchanged; confirm whether this is intentional.
            self.downsample = ConvBNReLU(
                iCs[0],
                iCs[3],
                1,
                1,
                0,
                False,
                has_avg=False,
                has_bn=False,
                has_relu=False,
            )
        else:
            # Identity shortcut.
            self.downsample = None
        self.out_dim = iCs[3]

    def forward(self, inputs):
        """Return ``relu(shortcut(inputs) + conv_1x4(conv_3x3(conv_1x1(inputs))))``."""
        bottleneck = self.conv_1x1(inputs)
        bottleneck = self.conv_3x3(bottleneck)
        bottleneck = self.conv_1x4(bottleneck)

        if self.downsample is not None:
            residual = self.downsample(inputs)
        else:
            residual = inputs
        out = residual + bottleneck
        # In-place is safe: `out` is a fresh tensor from the addition.
        return F.relu(out, inplace=True)
 | |
| 
 | |
| 
 | |
class InferWidthCifarResNet(nn.Module):
    """CIFAR-style ResNet whose every convolution width is given by ``xchannels``.

    Layout: a 3x3 stem ConvBNReLU followed by three stages of ``layer_blocks``
    residual blocks each. The first block of stages 1 and 2 uses stride 2.
    Then come global average pooling and a linear classifier.

    Args:
        block_name: ``"ResNetBasicblock"`` or ``"ResNetBottleneck"``.
        depth: total depth. Must satisfy ``(depth - 2) % 6 == 0`` for basic
            blocks or ``(depth - 2) % 9 == 0`` for bottlenecks.
        xchannels: flat sequence of channel counts.
            - ``xchannels[0]`` is the input channel count (presumably 3 for RGB
              images; confirm against callers).
            - ``xchannels[1]`` is the stem output width.
            - Each block then consumes the next ``block.num_conv`` entries, with
              its input width being the entry just before them.
            A list copy is stored in ``self.xchannels``, so the caller's sequence
            is never modified and tuples are accepted.
        num_classes: output size of the final linear classifier.
        zero_init_residual: if true, zero the weight of the last BatchNorm in
            every residual branch.

    Raises:
        ValueError: if ``block_name`` is not recognised.
        AssertionError: if ``depth`` does not fit the chosen block type, or if
            ``xchannels`` yields malformed per-block widths.
    """

    def __init__(self, block_name, depth, xchannels, num_classes, zero_init_residual):
        super(InferWidthCifarResNet, self).__init__()

        # Block type determines how many blocks each of the three stages has.
        if block_name == "ResNetBasicblock":
            block = ResNetBasicblock
            assert (depth - 2) % 6 == 0, "depth should be one of 20, 32, 44, 56, 110"
            layer_blocks = (depth - 2) // 6
        elif block_name == "ResNetBottleneck":
            block = ResNetBottleneck
            assert (depth - 2) % 9 == 0, "depth should be one of 164"
            layer_blocks = (depth - 2) // 9
        else:
            raise ValueError("invalid block : {:}".format(block_name))

        self.message = (
            "InferWidthCifarResNet : Depth : {:} , Layers for each block : {:}".format(
                depth, layer_blocks
            )
        )
        self.num_classes = num_classes
        # Copy: the loop below writes into this list. Aliasing the argument would
        # mutate the caller's object and raise TypeError for tuple input.
        self.xchannels = list(xchannels)
        self.layers = nn.ModuleList(
            [
                ConvBNReLU(
                    self.xchannels[0],
                    self.xchannels[1],
                    3,
                    1,
                    1,
                    False,
                    has_avg=False,
                    has_bn=True,
                    has_relu=True,
                )
            ]
        )
        num_conv = block.num_conv
        # Index of the current block's input width inside self.xchannels.
        last_channel_idx = 1
        for stage in range(3):
            for iL in range(layer_blocks):
                # num_conv + 1 entries: the input width plus one width per conv.
                iCs = self.xchannels[last_channel_idx : last_channel_idx + num_conv + 1]
                stride = 2 if stage > 0 and iL == 0 else 1
                module = block(iCs, stride)
                last_channel_idx += num_conv
                self.xchannels[last_channel_idx] = module.out_dim
                self.layers.append(module)
                self.message += "\nstage={:}, ilayer={:02d}/{:02d}, block={:03d}, iCs={:}, oC={:3d}, stride={:}".format(
                    stage,
                    iL,
                    layer_blocks,
                    len(self.layers) - 1,
                    iCs,
                    module.out_dim,
                    stride,
                )

        # Global average pooling to 1x1. For 32x32 inputs the final map is 8x8,
        # which matches the former AvgPool2d(8), and other input sizes also work.
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.classifier = nn.Linear(self.xchannels[-1], num_classes)

        self.apply(initialize_resnet)
        if zero_init_residual:
            # Zero the last BN scale of each residual branch so every block
            # starts out close to its shortcut path.
            for m in self.modules():
                if isinstance(m, ResNetBasicblock):
                    nn.init.constant_(m.conv_b.bn.weight, 0)
                elif isinstance(m, ResNetBottleneck):
                    nn.init.constant_(m.conv_1x4.bn.weight, 0)

    def get_message(self):
        """Return the human-readable architecture summary built in ``__init__``."""
        return self.message

    def forward(self, inputs):
        """Return ``(features, logits)``.

        ``features`` is the pooled activation flattened to ``(N, C)``, and
        ``logits`` is the classifier output.
        """
        x = inputs
        for layer in self.layers:
            x = layer(x)
        features = self.avgpool(x)
        features = features.view(features.size(0), -1)
        logits = self.classifier(features)
        return features, logits
 |