naswot/pycls/models/resnet.py
Jack Turner b74255e1f3 v2
2021-02-26 16:12:51 +00:00

281 lines
10 KiB
Python

#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""ResNe(X)t models."""
import pycls.core.net as net
import torch.nn as nn
from pycls.core.config import cfg
# Number of residual blocks in each of the four ImageNet stages (d1, d2, d3, d4),
# keyed by total model depth. For bottleneck blocks (3 convs each) the depth is
# 3 * sum(stage depths) + 2, counting the stem conv and the final fc layer.
_IN_STAGE_DS = {
    50: (3, 4, 6, 3),
    101: (3, 4, 23, 3),
    152: (3, 8, 36, 3),
}
def get_trans_fun(name):
    """Look up a residual transformation class by its config name.

    Args:
        name: transform identifier (callers in this module pass
            cfg.RESNET.TRANS_FUN); one of "basic_transform" or
            "bottleneck_transform".

    Returns:
        The matching class, BasicTransform or BottleneckTransform.

    Raises:
        AssertionError: if ``name`` is not a known transform (under ``python -O``
            the assert is stripped and the lookup raises KeyError instead).
    """
    registry = dict(
        basic_transform=BasicTransform,
        bottleneck_transform=BottleneckTransform,
    )
    assert name in registry, "Transformation function '{}' not supported".format(name)
    return registry[name]
class ResHead(nn.Module):
    """Classification head: global average pool to 1x1, then a linear layer.

    Args:
        w_in: number of input channels.
        nc: number of output classes.
    """

    def __init__(self, w_in, nc):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(w_in, nc, bias=True)

    def forward(self, x):
        pooled = self.avg_pool(x)
        # Collapse (N, C, 1, 1) to (N, C) for the linear layer.
        flat = pooled.view(pooled.size(0), -1)
        return self.fc(flat)

    @staticmethod
    def complexity(cx, w_in, nc):
        """Add the head's cost to ``cx``; the fc is counted as a 1x1 conv on a 1x1 map."""
        cx["h"] = 1
        cx["w"] = 1
        return net.complexity_conv2d(cx, w_in, nc, 1, 1, 0, bias=True)
class BasicTransform(nn.Module):
    """Basic residual branch: 3x3 conv, BN, ReLU, 3x3 conv, BN.

    There is no bottleneck, so ``w_b`` and ``num_gs`` exist only for signature
    compatibility with BottleneckTransform and must keep their defaults.
    No ReLU follows the last BN; the enclosing residual block applies it.
    """

    def __init__(self, w_in, w_out, stride, w_b=None, num_gs=1):
        err_str = "Basic transform does not support w_b and num_gs options"
        assert w_b is None and num_gs == 1, err_str
        super().__init__()
        bn_eps, bn_mom = cfg.BN.EPS, cfg.BN.MOM
        # Only the first conv carries the stride (i.e. does any downsampling).
        self.a = nn.Conv2d(w_in, w_out, 3, stride=stride, padding=1, bias=False)
        self.a_bn = nn.BatchNorm2d(w_out, eps=bn_eps, momentum=bn_mom)
        self.a_relu = nn.ReLU(inplace=cfg.MEM.RELU_INPLACE)
        self.b = nn.Conv2d(w_out, w_out, 3, stride=1, padding=1, bias=False)
        self.b_bn = nn.BatchNorm2d(w_out, eps=bn_eps, momentum=bn_mom)
        # Marks the branch's last BN; presumably consumed by net.init_weights
        # (e.g. for special gamma init) -- TODO confirm.
        self.b_bn.final_bn = True

    def forward(self, x):
        # Submodules run in registration order: a, a_bn, a_relu, b, b_bn.
        for stage in self.children():
            x = stage(x)
        return x

    @staticmethod
    def complexity(cx, w_in, w_out, stride, w_b=None, num_gs=1):
        """Accumulate the cost of both conv+BN pairs into ``cx`` and return it."""
        err_str = "Basic transform does not support w_b and num_gs options"
        assert w_b is None and num_gs == 1, err_str
        for c_in, s in ((w_in, stride), (w_out, 1)):
            cx = net.complexity_conv2d(cx, c_in, w_out, 3, s, 1)
            cx = net.complexity_batchnorm2d(cx, w_out)
        return cx
class BottleneckTransform(nn.Module):
    """Bottleneck residual branch: 1x1, BN, ReLU, 3x3, BN, ReLU, 1x1, BN.

    Args:
        w_in: input channels.
        w_out: output channels.
        stride: spatial stride, placed on the first 1x1 conv when
            cfg.RESNET.STRIDE_1X1 is set, otherwise on the 3x3 conv.
        w_b: bottleneck (inner) width used by the 1x1 -> 3x3 -> 1x1 path.
        num_gs: number of groups of the 3x3 conv (ResNeXt-style grouping).
    """

    def __init__(self, w_in, w_out, stride, w_b, num_gs):
        super().__init__()
        # MSRA -> stride=2 is on 1x1; TH/C2 -> stride=2 is on 3x3
        s1, s3 = (stride, 1) if cfg.RESNET.STRIDE_1X1 else (1, stride)
        bn_eps, bn_mom = cfg.BN.EPS, cfg.BN.MOM
        relu_inplace = cfg.MEM.RELU_INPLACE
        self.a = nn.Conv2d(w_in, w_b, 1, stride=s1, padding=0, bias=False)
        self.a_bn = nn.BatchNorm2d(w_b, eps=bn_eps, momentum=bn_mom)
        self.a_relu = nn.ReLU(inplace=relu_inplace)
        self.b = nn.Conv2d(
            w_b, w_b, 3, stride=s3, padding=1, groups=num_gs, bias=False
        )
        self.b_bn = nn.BatchNorm2d(w_b, eps=bn_eps, momentum=bn_mom)
        self.b_relu = nn.ReLU(inplace=relu_inplace)
        self.c = nn.Conv2d(w_b, w_out, 1, stride=1, padding=0, bias=False)
        self.c_bn = nn.BatchNorm2d(w_out, eps=bn_eps, momentum=bn_mom)
        # Marks the branch's last BN; presumably consumed by net.init_weights
        # -- TODO confirm.
        self.c_bn.final_bn = True

    def forward(self, x):
        # Submodules run in registration order: a .. c_bn.
        for stage in self.children():
            x = stage(x)
        return x

    @staticmethod
    def complexity(cx, w_in, w_out, stride, w_b, num_gs):
        """Accumulate the cost of the three conv+BN pairs into ``cx``."""
        s1, s3 = (stride, 1) if cfg.RESNET.STRIDE_1X1 else (1, stride)
        cx = net.complexity_conv2d(cx, w_in, w_b, 1, s1, 0)
        cx = net.complexity_batchnorm2d(cx, w_b)
        # Grouped 3x3: the group count is passed positionally.
        cx = net.complexity_conv2d(cx, w_b, w_b, 3, s3, 1, num_gs)
        cx = net.complexity_batchnorm2d(cx, w_b)
        cx = net.complexity_conv2d(cx, w_b, w_out, 1, 1, 0)
        return net.complexity_batchnorm2d(cx, w_out)
class ResBlock(nn.Module):
    """Residual block computing relu(shortcut(x) + F(x)).

    The shortcut is the identity when the shape is preserved; otherwise it is a
    strided 1x1 conv followed by BN (projection shortcut).

    Args:
        w_in, w_out: input / output channels.
        stride: spatial stride of the block.
        trans_fun: residual branch class, constructed as
            ``trans_fun(w_in, w_out, stride, w_b, num_gs)``.
        w_b, num_gs: forwarded unchanged to ``trans_fun``.
    """

    def __init__(self, w_in, w_out, stride, trans_fun, w_b=None, num_gs=1):
        super().__init__()
        # A projection is required whenever channels or resolution change.
        self.proj_block = w_in != w_out or stride != 1
        if self.proj_block:
            self.proj = nn.Conv2d(
                w_in, w_out, 1, stride=stride, padding=0, bias=False
            )
            self.bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
        self.f = trans_fun(w_in, w_out, stride, w_b, num_gs)
        self.relu = nn.ReLU(inplace=cfg.MEM.RELU_INPLACE)

    def forward(self, x):
        # The shortcut is evaluated before the residual branch F.
        shortcut = self.bn(self.proj(x)) if self.proj_block else x
        return self.relu(shortcut + self.f(x))

    @staticmethod
    def complexity(cx, w_in, w_out, stride, trans_fun, w_b, num_gs):
        """Accumulate the shortcut (if any) and residual-branch cost into ``cx``."""
        if w_in != w_out or stride != 1:
            # The projection runs in parallel with F, so the spatial size it
            # produces must not leak into F's accounting: save and restore it.
            saved_hw = cx["h"], cx["w"]
            cx = net.complexity_conv2d(cx, w_in, w_out, 1, stride, 0)
            cx = net.complexity_batchnorm2d(cx, w_out)
            cx["h"], cx["w"] = saved_hw
        return trans_fun.complexity(cx, w_in, w_out, stride, w_b, num_gs)
class ResStage(nn.Module):
    """A stage of ``d`` residual blocks, registered as b1 .. bd.

    Only the first block sees ``w_in`` and ``stride``; the remaining blocks map
    ``w_out`` -> ``w_out`` at stride 1. The branch type comes from
    cfg.RESNET.TRANS_FUN.
    """

    def __init__(self, w_in, w_out, stride, d, w_b=None, num_gs=1):
        super().__init__()
        for idx in range(d):
            first = idx == 0
            block = ResBlock(
                w_in if first else w_out,
                w_out,
                stride if first else 1,
                get_trans_fun(cfg.RESNET.TRANS_FUN),
                w_b,
                num_gs,
            )
            self.add_module("b{}".format(idx + 1), block)

    def forward(self, x):
        for res_block in self.children():
            x = res_block(x)
        return x

    @staticmethod
    def complexity(cx, w_in, w_out, stride, d, w_b=None, num_gs=1):
        """Accumulate the cost of all ``d`` blocks of the stage into ``cx``."""
        for idx in range(d):
            first = idx == 0
            cx = ResBlock.complexity(
                cx,
                w_in if first else w_out,
                w_out,
                stride if first else 1,
                get_trans_fun(cfg.RESNET.TRANS_FUN),
                w_b,
                num_gs,
            )
        return cx
class ResStemCifar(nn.Module):
    """CIFAR stem: 3x3 conv (stride 1, pad 1, so spatial size is kept), BN, ReLU."""

    def __init__(self, w_in, w_out):
        super().__init__()
        self.conv = nn.Conv2d(
            w_in, w_out, kernel_size=3, stride=1, padding=1, bias=False
        )
        self.bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
        self.relu = nn.ReLU(inplace=cfg.MEM.RELU_INPLACE)

    def forward(self, x):
        # conv -> bn -> relu, in registration order.
        for op in self.children():
            x = op(x)
        return x

    @staticmethod
    def complexity(cx, w_in, w_out):
        """Accumulate the stem's conv and BN cost into ``cx`` (ReLU is not counted)."""
        cx = net.complexity_conv2d(cx, w_in, w_out, 3, 1, 1)
        return net.complexity_batchnorm2d(cx, w_out)
class ResStemIN(nn.Module):
    """ImageNet stem: 7x7 stride-2 conv, BN, ReLU, then 3x3 stride-2 max-pool."""

    def __init__(self, w_in, w_out):
        super().__init__()
        self.conv = nn.Conv2d(
            w_in, w_out, kernel_size=7, stride=2, padding=3, bias=False
        )
        self.bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
        self.relu = nn.ReLU(inplace=cfg.MEM.RELU_INPLACE)
        self.pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

    def forward(self, x):
        # conv -> bn -> relu -> pool, in registration order.
        for op in self.children():
            x = op(x)
        return x

    @staticmethod
    def complexity(cx, w_in, w_out):
        """Accumulate conv, BN and max-pool cost into ``cx`` (ReLU is not counted)."""
        cx = net.complexity_conv2d(cx, w_in, w_out, 7, 2, 3)
        cx = net.complexity_batchnorm2d(cx, w_out)
        return net.complexity_maxpool2d(cx, 3, 2, 1)
class ResNet(nn.Module):
    """ResNet model; the CIFAR or ImageNet variant is chosen by cfg.TRAIN.DATASET.

    CIFAR: 3x3 stem, three stages of 16/32/64 channels with ``n`` blocks each,
    where cfg.MODEL.DEPTH == 6n + 2.
    ImageNet: 7x7 stem, four stages of 256/512/1024/2048 channels whose block
    counts come from _IN_STAGE_DS[cfg.MODEL.DEPTH]. The bottleneck width starts
    at NUM_GROUPS * WIDTH_PER_GROUP and doubles every stage.
    """

    def __init__(self):
        datasets = ["cifar10", "imagenet"]
        err_str = "Dataset {} is not supported"
        assert cfg.TRAIN.DATASET in datasets, err_str.format(cfg.TRAIN.DATASET)
        assert cfg.TEST.DATASET in datasets, err_str.format(cfg.TEST.DATASET)
        super(ResNet, self).__init__()
        if "cifar" in cfg.TRAIN.DATASET:
            self._construct_cifar()
        else:
            self._construct_imagenet()
        self.apply(net.init_weights)

    @staticmethod
    def _cifar_stage_depth():
        """Return blocks per stage for a CIFAR ResNet; the depth must be 6n + 2."""
        err_str = "Model depth should be of the format 6n + 2 for cifar"
        assert (cfg.MODEL.DEPTH - 2) % 6 == 0, err_str
        # Exact integer division; no round-trip through float.
        return (cfg.MODEL.DEPTH - 2) // 6

    @staticmethod
    def _imagenet_stage_depths():
        """Return the (d1, d2, d3, d4) block counts for an ImageNet ResNet."""
        depth = cfg.MODEL.DEPTH
        err_str = "Model depth {} not supported for imagenet, expected one of {}"
        assert depth in _IN_STAGE_DS, err_str.format(depth, sorted(_IN_STAGE_DS))
        return _IN_STAGE_DS[depth]

    def _construct_cifar(self):
        d = self._cifar_stage_depth()
        self.stem = ResStemCifar(3, 16)
        self.s1 = ResStage(16, 16, stride=1, d=d)
        self.s2 = ResStage(16, 32, stride=2, d=d)
        self.s3 = ResStage(32, 64, stride=2, d=d)
        self.head = ResHead(64, nc=cfg.MODEL.NUM_CLASSES)

    def _construct_imagenet(self):
        g, gw = cfg.RESNET.NUM_GROUPS, cfg.RESNET.WIDTH_PER_GROUP
        (d1, d2, d3, d4) = self._imagenet_stage_depths()
        w_b = gw * g
        self.stem = ResStemIN(3, 64)
        self.s1 = ResStage(64, 256, stride=1, d=d1, w_b=w_b, num_gs=g)
        self.s2 = ResStage(256, 512, stride=2, d=d2, w_b=w_b * 2, num_gs=g)
        self.s3 = ResStage(512, 1024, stride=2, d=d3, w_b=w_b * 4, num_gs=g)
        self.s4 = ResStage(1024, 2048, stride=2, d=d4, w_b=w_b * 8, num_gs=g)
        self.head = ResHead(2048, nc=cfg.MODEL.NUM_CLASSES)

    def forward(self, x):
        # stem -> stages -> head, in registration order.
        for module in self.children():
            x = module(x)
        return x

    @staticmethod
    def complexity(cx):
        """Computes model complexity. If you alter the model, make sure to update.

        Applies the same depth validation as construction, so the complexity of
        an unbuildable configuration is never reported.
        """
        if "cifar" in cfg.TRAIN.DATASET:
            d = ResNet._cifar_stage_depth()
            cx = ResStemCifar.complexity(cx, 3, 16)
            cx = ResStage.complexity(cx, 16, 16, stride=1, d=d)
            cx = ResStage.complexity(cx, 16, 32, stride=2, d=d)
            cx = ResStage.complexity(cx, 32, 64, stride=2, d=d)
            cx = ResHead.complexity(cx, 64, nc=cfg.MODEL.NUM_CLASSES)
        else:
            g, gw = cfg.RESNET.NUM_GROUPS, cfg.RESNET.WIDTH_PER_GROUP
            (d1, d2, d3, d4) = ResNet._imagenet_stage_depths()
            w_b = gw * g
            cx = ResStemIN.complexity(cx, 3, 64)
            cx = ResStage.complexity(cx, 64, 256, 1, d=d1, w_b=w_b, num_gs=g)
            cx = ResStage.complexity(cx, 256, 512, 2, d=d2, w_b=w_b * 2, num_gs=g)
            cx = ResStage.complexity(cx, 512, 1024, 2, d=d3, w_b=w_b * 4, num_gs=g)
            cx = ResStage.complexity(cx, 1024, 2048, 2, d=d4, w_b=w_b * 8, num_gs=g)
            cx = ResHead.complexity(cx, 2048, nc=cfg.MODEL.NUM_CLASSES)
        return cx