naswot/pycls/models/anynet.py
Jack Turner b74255e1f3 v2
2021-02-26 16:12:51 +00:00

407 lines
14 KiB
Python

#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""AnyNet models."""
import pycls.core.net as net
import torch.nn as nn
from pycls.core.config import cfg
def get_stem_fun(stem_type):
    """Look up a stem class by its config name.

    Args:
        stem_type: Name of the stem; one of "res_stem_cifar", "res_stem_in",
            or "simple_stem_in".

    Returns:
        The stem class (not an instance) registered under ``stem_type``.

    Raises:
        AssertionError: If ``stem_type`` is unknown (with assertions enabled;
            under ``python -O`` the dictionary lookup raises KeyError instead).
    """
    registry = dict(
        res_stem_cifar=ResStemCifar,
        res_stem_in=ResStemIN,
        simple_stem_in=SimpleStemIN,
    )
    assert stem_type in registry, "Stem type '{}' not supported".format(stem_type)
    return registry[stem_type]
def get_block_fun(block_type):
    """Look up a block class by its config name.

    Args:
        block_type: Name of the block; one of "vanilla_block",
            "res_basic_block", or "res_bottleneck_block".

    Returns:
        The block class (not an instance) registered under ``block_type``.

    Raises:
        AssertionError: If ``block_type`` is unknown (with assertions enabled;
            under ``python -O`` the dictionary lookup raises KeyError instead).
    """
    registry = dict(
        vanilla_block=VanillaBlock,
        res_basic_block=ResBasicBlock,
        res_bottleneck_block=ResBottleneckBlock,
    )
    assert block_type in registry, "Block type '{}' not supported".format(block_type)
    return registry[block_type]
class AnyHead(nn.Module):
    """AnyNet head: global average pool to 1x1, flatten, then a linear layer."""

    def __init__(self, w_in, nc):
        """
        Args:
            w_in: Number of input channels.
            nc: Number of output classes.
        """
        super(AnyHead, self).__init__()
        # Registration order (avg_pool, fc) fixes the state_dict key order.
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(w_in, nc, bias=True)

    def forward(self, x):
        """Pool ``x`` to 1x1, flatten everything but dim 0, apply the FC layer."""
        batch = x.size(0)
        flat = self.avg_pool(x).view(batch, -1)
        return self.fc(flat)

    @staticmethod
    def complexity(cx, w_in, nc):
        """Add the head's cost to ``cx``; the FC is costed as a 1x1 conv on a 1x1 map."""
        cx["h"] = 1
        cx["w"] = 1
        return net.complexity_conv2d(cx, w_in, nc, 1, 1, 0, bias=True)
class VanillaBlock(nn.Module):
    """Vanilla block: [3x3 conv, BN, ReLU] x2 with no skip connection."""

    def __init__(self, w_in, w_out, stride, bm=None, gw=None, se_r=None):
        """
        Args:
            w_in: Input channels.
            w_out: Output channels.
            stride: Stride of the first 3x3 conv (the second uses stride 1).
            bm, gw, se_r: Accepted only so all blocks share one signature;
                each must be None.
        """
        assert all(opt is None for opt in (bm, gw, se_r)), (
            "Vanilla block does not support bm, gw, and se_r options"
        )
        super(VanillaBlock, self).__init__()
        bn_eps, bn_mom = cfg.BN.EPS, cfg.BN.MOM
        inplace = cfg.MEM.RELU_INPLACE
        # forward() runs the children in registration order, so keep this order.
        self.a = nn.Conv2d(w_in, w_out, 3, stride=stride, padding=1, bias=False)
        self.a_bn = nn.BatchNorm2d(w_out, eps=bn_eps, momentum=bn_mom)
        self.a_relu = nn.ReLU(inplace=inplace)
        self.b = nn.Conv2d(w_out, w_out, 3, stride=1, padding=1, bias=False)
        self.b_bn = nn.BatchNorm2d(w_out, eps=bn_eps, momentum=bn_mom)
        self.b_relu = nn.ReLU(inplace=inplace)

    def forward(self, x):
        """Apply every child module in registration order."""
        out = x
        for stage in self.children():
            out = stage(out)
        return out

    @staticmethod
    def complexity(cx, w_in, w_out, stride, bm=None, gw=None, se_r=None):
        """Add conv + BN cost of both stages to ``cx`` (ReLU is not counted)."""
        assert all(opt is None for opt in (bm, gw, se_r)), (
            "Vanilla block does not support bm, gw, and se_r options"
        )
        # (input channels, stride) of each 3x3 conv, in execution order.
        for c_in, s in ((w_in, stride), (w_out, 1)):
            cx = net.complexity_conv2d(cx, c_in, w_out, 3, s, 1)
            cx = net.complexity_batchnorm2d(cx, w_out)
        return cx
class BasicTransform(nn.Module):
    """Basic transformation: 3x3 conv, BN, ReLU, 3x3 conv, BN.

    There is no ReLU after the second BN; the enclosing residual block applies
    it after the skip addition. The last BN is tagged ``final_bn = True`` —
    presumably read by the weight initializer (``net.init_weights``); confirm.
    """

    def __init__(self, w_in, w_out, stride):
        """
        Args:
            w_in: Input channels.
            w_out: Output channels.
            stride: Stride of the first 3x3 conv.
        """
        super(BasicTransform, self).__init__()
        bn_eps, bn_mom = cfg.BN.EPS, cfg.BN.MOM
        # forward() runs the children in registration order, so keep this order.
        self.a = nn.Conv2d(w_in, w_out, 3, stride=stride, padding=1, bias=False)
        self.a_bn = nn.BatchNorm2d(w_out, eps=bn_eps, momentum=bn_mom)
        self.a_relu = nn.ReLU(inplace=cfg.MEM.RELU_INPLACE)
        self.b = nn.Conv2d(w_out, w_out, 3, stride=1, padding=1, bias=False)
        self.b_bn = nn.BatchNorm2d(w_out, eps=bn_eps, momentum=bn_mom)
        self.b_bn.final_bn = True

    def forward(self, x):
        """Apply every child module in registration order."""
        out = x
        for stage in self.children():
            out = stage(out)
        return out

    @staticmethod
    def complexity(cx, w_in, w_out, stride):
        """Add the cost of both conv + BN pairs to ``cx``."""
        for c_in, s in ((w_in, stride), (w_out, 1)):
            cx = net.complexity_conv2d(cx, c_in, w_out, 3, s, 1)
            cx = net.complexity_batchnorm2d(cx, w_out)
        return cx
class ResBasicBlock(nn.Module):
    """Residual basic block: relu(shortcut(x) + F(x)), F = BasicTransform.

    The shortcut is the identity unless the width or the stride changes, in
    which case it is a strided 1x1 conv followed by BN.
    """

    def __init__(self, w_in, w_out, stride, bm=None, gw=None, se_r=None):
        """
        Args:
            w_in: Input channels.
            w_out: Output channels.
            stride: Stride applied by both the transform and the projection.
            bm, gw, se_r: Accepted only for a uniform block signature; each
                must be None.
        """
        assert all(opt is None for opt in (bm, gw, se_r)), (
            "Basic transform does not support bm, gw, and se_r options"
        )
        super(ResBasicBlock, self).__init__()
        # A projection shortcut is needed whenever the output shape differs.
        self.proj_block = w_in != w_out or stride != 1
        if self.proj_block:
            self.proj = nn.Conv2d(w_in, w_out, 1, stride=stride, padding=0, bias=False)
            self.bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
        self.f = BasicTransform(w_in, w_out, stride)
        self.relu = nn.ReLU(inplace=cfg.MEM.RELU_INPLACE)

    def forward(self, x):
        """Sum the shortcut and the transform, then apply ReLU."""
        shortcut = self.bn(self.proj(x)) if self.proj_block else x
        return self.relu(shortcut + self.f(x))

    @staticmethod
    def complexity(cx, w_in, w_out, stride, bm=None, gw=None, se_r=None):
        """Add the cost of the (optional) projection and the transform to ``cx``."""
        assert all(opt is None for opt in (bm, gw, se_r)), (
            "Basic transform does not support bm, gw, and se_r options"
        )
        if w_in != w_out or stride != 1:
            # The projection is a parallel branch: restore the spatial size
            # before accounting for F, which starts from the same input.
            saved_hw = (cx["h"], cx["w"])
            cx = net.complexity_conv2d(cx, w_in, w_out, 1, stride, 0)
            cx = net.complexity_batchnorm2d(cx, w_out)
            cx["h"], cx["w"] = saved_hw
        return BasicTransform.complexity(cx, w_in, w_out, stride)
class SE(nn.Module):
    """Squeeze-and-Excitation block.

    Computes a per-channel gate AvgPool -> 1x1 conv -> ReLU -> 1x1 conv ->
    Sigmoid and returns ``x * gate``.
    """

    def __init__(self, w_in, w_se):
        """
        Args:
            w_in: Channels of the gated input (and of the gate).
            w_se: Channels of the squeezed hidden layer.
        """
        super(SE, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        squeeze = nn.Conv2d(w_in, w_se, 1, bias=True)
        relu = nn.ReLU(inplace=cfg.MEM.RELU_INPLACE)
        excite = nn.Conv2d(w_se, w_in, 1, bias=True)
        self.f_ex = nn.Sequential(squeeze, relu, excite, nn.Sigmoid())

    def forward(self, x):
        """Scale ``x`` channel-wise by the excitation gate."""
        gate = self.f_ex(self.avg_pool(x))
        return x * gate

    @staticmethod
    def complexity(cx, w_in, w_se):
        """Add the cost of both 1x1 convs (run on the pooled 1x1 map) to ``cx``."""
        saved_hw = (cx["h"], cx["w"])
        cx["h"], cx["w"] = 1, 1
        for c_in, c_out in ((w_in, w_se), (w_se, w_in)):
            cx = net.complexity_conv2d(cx, c_in, c_out, 1, 1, 0, bias=True)
        # The gate does not change the spatial size of the main path.
        cx["h"], cx["w"] = saved_hw
        return cx
class BottleneckTransform(nn.Module):
    """Bottleneck transformation: 1x1, grouped strided 3x3 [+SE], 1x1.

    Bottleneck width is ``round(w_out * bm)``; the 3x3 conv uses
    ``w_b // gw`` groups. When enabled, the SE hidden width is computed from
    ``w_in`` (not ``w_b``).
    """

    def __init__(self, w_in, w_out, stride, bm, gw, se_r):
        """
        Args:
            w_in: Input channels.
            w_out: Output channels.
            stride: Stride of the 3x3 conv.
            bm: Bottleneck multiplier applied to ``w_out``.
            gw: Channels per group in the 3x3 conv.
            se_r: SE ratio; falsy disables the SE module.
        """
        super(BottleneckTransform, self).__init__()
        w_b = int(round(w_out * bm))
        groups = w_b // gw
        bn_eps, bn_mom = cfg.BN.EPS, cfg.BN.MOM
        inplace = cfg.MEM.RELU_INPLACE
        # forward() runs the children in registration order, so keep this order.
        self.a = nn.Conv2d(w_in, w_b, 1, stride=1, padding=0, bias=False)
        self.a_bn = nn.BatchNorm2d(w_b, eps=bn_eps, momentum=bn_mom)
        self.a_relu = nn.ReLU(inplace=inplace)
        self.b = nn.Conv2d(w_b, w_b, 3, stride=stride, padding=1, groups=groups, bias=False)
        self.b_bn = nn.BatchNorm2d(w_b, eps=bn_eps, momentum=bn_mom)
        self.b_relu = nn.ReLU(inplace=inplace)
        if se_r:
            self.se = SE(w_b, int(round(w_in * se_r)))
        self.c = nn.Conv2d(w_b, w_out, 1, stride=1, padding=0, bias=False)
        self.c_bn = nn.BatchNorm2d(w_out, eps=bn_eps, momentum=bn_mom)
        # Tag the last BN; presumably used by net.init_weights — confirm.
        self.c_bn.final_bn = True

    def forward(self, x):
        """Apply every child module in registration order."""
        out = x
        for stage in self.children():
            out = stage(out)
        return out

    @staticmethod
    def complexity(cx, w_in, w_out, stride, bm, gw, se_r):
        """Add the cost of all convs, BNs and the optional SE to ``cx``."""
        w_b = int(round(w_out * bm))
        groups = w_b // gw
        cx = net.complexity_conv2d(cx, w_in, w_b, 1, 1, 0)
        cx = net.complexity_batchnorm2d(cx, w_b)
        cx = net.complexity_conv2d(cx, w_b, w_b, 3, stride, 1, groups)
        cx = net.complexity_batchnorm2d(cx, w_b)
        if se_r:
            cx = SE.complexity(cx, w_b, int(round(w_in * se_r)))
        cx = net.complexity_conv2d(cx, w_b, w_out, 1, 1, 0)
        return net.complexity_batchnorm2d(cx, w_out)
class ResBottleneckBlock(nn.Module):
    """Residual bottleneck block: relu(shortcut(x) + F(x)), F = BottleneckTransform.

    The shortcut is the identity unless the width or the stride changes, in
    which case it is a strided 1x1 conv followed by BN.
    """

    def __init__(self, w_in, w_out, stride, bm=1.0, gw=1, se_r=None):
        """
        Args:
            w_in: Input channels.
            w_out: Output channels.
            stride: Stride applied by both the transform and the projection.
            bm: Bottleneck multiplier forwarded to the transform.
            gw: Group width forwarded to the transform.
            se_r: SE ratio forwarded to the transform (falsy disables SE).
        """
        super(ResBottleneckBlock, self).__init__()
        # A projection shortcut is needed whenever the output shape differs.
        self.proj_block = w_in != w_out or stride != 1
        if self.proj_block:
            self.proj = nn.Conv2d(w_in, w_out, 1, stride=stride, padding=0, bias=False)
            self.bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
        self.f = BottleneckTransform(w_in, w_out, stride, bm, gw, se_r)
        self.relu = nn.ReLU(inplace=cfg.MEM.RELU_INPLACE)

    def forward(self, x):
        """Sum the shortcut and the transform, then apply ReLU."""
        shortcut = self.bn(self.proj(x)) if self.proj_block else x
        return self.relu(shortcut + self.f(x))

    @staticmethod
    def complexity(cx, w_in, w_out, stride, bm=1.0, gw=1, se_r=None):
        """Add the cost of the (optional) projection and the transform to ``cx``."""
        if w_in != w_out or stride != 1:
            # The projection is a parallel branch: restore the spatial size
            # before accounting for F, which starts from the same input.
            saved_hw = (cx["h"], cx["w"])
            cx = net.complexity_conv2d(cx, w_in, w_out, 1, stride, 0)
            cx = net.complexity_batchnorm2d(cx, w_out)
            cx["h"], cx["w"] = saved_hw
        return BottleneckTransform.complexity(cx, w_in, w_out, stride, bm, gw, se_r)
class ResStemCifar(nn.Module):
    """ResNet stem for CIFAR: 3x3 conv (stride 1), BN, ReLU."""

    def __init__(self, w_in, w_out):
        """
        Args:
            w_in: Input channels (AnyNet passes 3).
            w_out: Stem output channels.
        """
        super(ResStemCifar, self).__init__()
        # forward() runs the children in registration order: conv, bn, relu.
        self.conv = nn.Conv2d(w_in, w_out, 3, stride=1, padding=1, bias=False)
        self.bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
        self.relu = nn.ReLU(inplace=cfg.MEM.RELU_INPLACE)

    def forward(self, x):
        """Apply every child module in registration order."""
        out = x
        for op in self.children():
            out = op(out)
        return out

    @staticmethod
    def complexity(cx, w_in, w_out):
        """Add the conv and BN cost to ``cx`` (ReLU is not counted)."""
        cx = net.complexity_conv2d(cx, w_in, w_out, 3, 1, 1)
        return net.complexity_batchnorm2d(cx, w_out)
class ResStemIN(nn.Module):
    """ResNet stem for ImageNet: 7x7 conv (stride 2), BN, ReLU, 3x3 MaxPool (stride 2)."""

    def __init__(self, w_in, w_out):
        """
        Args:
            w_in: Input channels (AnyNet passes 3).
            w_out: Stem output channels.
        """
        super(ResStemIN, self).__init__()
        # forward() runs the children in registration order: conv, bn, relu, pool.
        self.conv = nn.Conv2d(w_in, w_out, 7, stride=2, padding=3, bias=False)
        self.bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
        self.relu = nn.ReLU(inplace=cfg.MEM.RELU_INPLACE)
        self.pool = nn.MaxPool2d(3, stride=2, padding=1)

    def forward(self, x):
        """Apply every child module in registration order."""
        out = x
        for op in self.children():
            out = op(out)
        return out

    @staticmethod
    def complexity(cx, w_in, w_out):
        """Add the conv, BN and max-pool cost to ``cx`` (ReLU is not counted)."""
        cx = net.complexity_conv2d(cx, w_in, w_out, 7, 2, 3)
        cx = net.complexity_batchnorm2d(cx, w_out)
        return net.complexity_maxpool2d(cx, 3, 2, 1)
class SimpleStemIN(nn.Module):
    """Simple stem for ImageNet: 3x3 conv (stride 2), BN, ReLU."""

    def __init__(self, w_in, w_out):
        """
        Args:
            w_in: Input channels (AnyNet passes 3).
            w_out: Stem output channels.
        """
        super(SimpleStemIN, self).__init__()
        # forward() runs the children in registration order: conv, bn, relu.
        self.conv = nn.Conv2d(w_in, w_out, 3, stride=2, padding=1, bias=False)
        self.bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
        self.relu = nn.ReLU(inplace=cfg.MEM.RELU_INPLACE)

    def forward(self, x):
        """Apply every child module in registration order."""
        out = x
        for op in self.children():
            out = op(out)
        return out

    @staticmethod
    def complexity(cx, w_in, w_out):
        """Add the conv and BN cost to ``cx`` (ReLU is not counted)."""
        cx = net.complexity_conv2d(cx, w_in, w_out, 3, 2, 1)
        return net.complexity_batchnorm2d(cx, w_out)
class AnyStage(nn.Module):
    """AnyNet stage: ``d`` blocks that all produce ``w_out`` channels.

    Only the first block receives ``w_in`` channels and applies ``stride``;
    the rest take ``w_out`` channels with stride 1. Blocks are registered as
    children named "b1" .. "b<d>".
    """

    def __init__(self, w_in, w_out, stride, d, block_fun, bm, gw, se_r):
        """
        Args:
            w_in: Channels entering the stage.
            w_out: Channels produced by every block.
            stride: Stride of the first block.
            d: Number of blocks.
            block_fun: Block constructor called as
                ``block_fun(w_in, w_out, stride, bm, gw, se_r)``.
            bm, gw, se_r: Forwarded unchanged to every block.
        """
        super(AnyStage, self).__init__()
        for i in range(d):
            is_first = i == 0
            block = block_fun(
                w_in if is_first else w_out,
                w_out,
                stride if is_first else 1,
                bm,
                gw,
                se_r,
            )
            self.add_module("b{}".format(i + 1), block)

    def forward(self, x):
        """Run the blocks in registration order."""
        out = x
        for blk in self.children():
            out = blk(out)
        return out

    @staticmethod
    def complexity(cx, w_in, w_out, stride, d, block_fun, bm, gw, se_r):
        """Accumulate ``block_fun.complexity`` for each of the ``d`` blocks."""
        for i in range(d):
            is_first = i == 0
            cx = block_fun.complexity(
                cx,
                w_in if is_first else w_out,
                w_out,
                stride if is_first else 1,
                bm,
                gw,
                se_r,
            )
        return cx
class AnyNet(nn.Module):
    """AnyNet model: stem, stages "s1" .. "sN", then a classification head.

    Architecture parameters come from keyword arguments or, when none are
    given at all, from the global config via ``get_args``.
    """

    @staticmethod
    def get_args():
        """Return the full constructor kwargs read from ``cfg.ANYNET`` / ``cfg.MODEL``."""
        return {
            "stem_type": cfg.ANYNET.STEM_TYPE,
            "stem_w": cfg.ANYNET.STEM_W,
            "block_type": cfg.ANYNET.BLOCK_TYPE,
            "ds": cfg.ANYNET.DEPTHS,
            "ws": cfg.ANYNET.WIDTHS,
            "ss": cfg.ANYNET.STRIDES,
            "bms": cfg.ANYNET.BOT_MULS,
            "gws": cfg.ANYNET.GROUP_WS,
            "se_r": cfg.ANYNET.SE_R if cfg.ANYNET.SE_ON else None,
            "nc": cfg.MODEL.NUM_CLASSES,
        }

    def __init__(self, **kwargs):
        """Build the network and initialize its weights.

        Args:
            **kwargs: The complete set of ``_construct`` arguments. When empty,
                ``get_args()`` is used instead; a partial set is NOT merged
                with the config values.
        """
        super(AnyNet, self).__init__()
        kwargs = self.get_args() if not kwargs else kwargs
        self._construct(**kwargs)
        self.apply(net.init_weights)

    @staticmethod
    def _stage_params(ds, ws, ss, bms, gws):
        """Zip per-stage settings into a list of (d, w, s, bm, gw) tuples.

        Empty or None ``bms``/``gws`` (block types that do not use them) are
        replaced by one ``None`` per stage. As with ``zip``, the result is
        truncated to the shortest input.
        """
        bms = bms if bms else [None for _ in ds]
        gws = gws if gws else [None for _ in ds]
        return list(zip(ds, ws, ss, bms, gws))

    def _construct(self, stem_type, stem_w, block_type, ds, ws, ss, bms, gws, se_r, nc):
        """Register the stem, one AnyStage per entry of ``ds``, and the head.

        Children are registered in execution order, which ``forward`` relies on.
        """
        stage_params = AnyNet._stage_params(ds, ws, ss, bms, gws)
        stem_fun = get_stem_fun(stem_type)
        self.stem = stem_fun(3, stem_w)  # models take 3-channel input
        block_fun = get_block_fun(block_type)
        prev_w = stem_w
        for i, (d, w, s, bm, gw) in enumerate(stage_params):
            name = "s{}".format(i + 1)
            self.add_module(name, AnyStage(prev_w, w, s, d, block_fun, bm, gw, se_r))
            prev_w = w
        self.head = AnyHead(w_in=prev_w, nc=nc)

    def forward(self, x, get_ints=False):
        """Run stem, stages and head in order.

        ``get_ints`` is accepted for call compatibility but is currently ignored.
        """
        for module in self.children():
            x = module(x)
        return x

    @staticmethod
    def complexity(cx, **kwargs):
        """Computes model complexity. If you alter the model, make sure to update."""
        kwargs = AnyNet.get_args() if not kwargs else kwargs
        return AnyNet._complexity(cx, **kwargs)

    @staticmethod
    def _complexity(cx, stem_type, stem_w, block_type, ds, ws, ss, bms, gws, se_r, nc):
        """Accumulate stem, stage and head complexity; mirrors ``_construct``."""
        stage_params = AnyNet._stage_params(ds, ws, ss, bms, gws)
        stem_fun = get_stem_fun(stem_type)
        cx = stem_fun.complexity(cx, 3, stem_w)
        block_fun = get_block_fun(block_type)
        prev_w = stem_w
        for d, w, s, bm, gw in stage_params:
            cx = AnyStage.complexity(cx, prev_w, w, s, d, block_fun, bm, gw, se_r)
            prev_w = w
        cx = AnyHead.complexity(cx, prev_w, nc)
        return cx