#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""AnyNet models."""

import pycls.core.net as net
import torch.nn as nn
from pycls.core.config import cfg


def get_stem_fun(stem_type):
    """Retrieves the stem function by name."""
    stem_funs = {
        "res_stem_cifar": ResStemCifar,
        "res_stem_in": ResStemIN,
        "simple_stem_in": SimpleStemIN,
    }
    err_str = "Stem type '{}' not supported"
    assert stem_type in stem_funs.keys(), err_str.format(stem_type)
    return stem_funs[stem_type]


def get_block_fun(block_type):
    """Retrieves the block function by name."""
    block_funs = {
        "vanilla_block": VanillaBlock,
        "res_basic_block": ResBasicBlock,
        "res_bottleneck_block": ResBottleneckBlock,
    }
    err_str = "Block type '{}' not supported"
    assert block_type in block_funs.keys(), err_str.format(block_type)
    return block_funs[block_type]


class AnyHead(nn.Module):
    """AnyNet head: AvgPool, 1x1."""

    def __init__(self, w_in, nc):
        super(AnyHead, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(w_in, nc, bias=True)

    def forward(self, x):
        x = self.avg_pool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x

    @staticmethod
    def complexity(cx, w_in, nc):
        cx["h"], cx["w"] = 1, 1
        cx = net.complexity_conv2d(cx, w_in, nc, 1, 1, 0, bias=True)
        return cx


class VanillaBlock(nn.Module):
    """Vanilla block: [3x3 conv, BN, ReLU] x2."""

    def __init__(self, w_in, w_out, stride, bm=None, gw=None, se_r=None):
        err_str = "Vanilla block does not support bm, gw, and se_r options"
        assert bm is None and gw is None and se_r is None, err_str
        super(VanillaBlock, self).__init__()
        self.a = nn.Conv2d(w_in, w_out, 3, stride=stride, padding=1, bias=False)
        self.a_bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
        self.a_relu = nn.ReLU(inplace=cfg.MEM.RELU_INPLACE)
        self.b = nn.Conv2d(w_out, w_out, 3, stride=1, padding=1, bias=False)
        self.b_bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
        self.b_relu = nn.ReLU(inplace=cfg.MEM.RELU_INPLACE)

    def forward(self, x):
        for layer in self.children():
            x = layer(x)
        return x

    @staticmethod
    def complexity(cx, w_in, w_out, stride, bm=None, gw=None, se_r=None):
        err_str = "Vanilla block does not support bm, gw, and se_r options"
        assert bm is None and gw is None and se_r is None, err_str
        cx = net.complexity_conv2d(cx, w_in, w_out, 3, stride, 1)
        cx = net.complexity_batchnorm2d(cx, w_out)
        cx = net.complexity_conv2d(cx, w_out, w_out, 3, 1, 1)
        cx = net.complexity_batchnorm2d(cx, w_out)
        return cx


class BasicTransform(nn.Module):
    """Basic transformation: [3x3 conv, BN, ReLU] x2."""

    def __init__(self, w_in, w_out, stride):
        super(BasicTransform, self).__init__()
        self.a = nn.Conv2d(w_in, w_out, 3, stride=stride, padding=1, bias=False)
        self.a_bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
        self.a_relu = nn.ReLU(inplace=cfg.MEM.RELU_INPLACE)
        self.b = nn.Conv2d(w_out, w_out, 3, stride=1, padding=1, bias=False)
        self.b_bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
        # Mark the final BN of the residual branch so weight init can zero its gamma
        self.b_bn.final_bn = True

    def forward(self, x):
        for layer in self.children():
            x = layer(x)
        return x

    @staticmethod
    def complexity(cx, w_in, w_out, stride):
        cx = net.complexity_conv2d(cx, w_in, w_out, 3, stride, 1)
        cx = net.complexity_batchnorm2d(cx, w_out)
        cx = net.complexity_conv2d(cx, w_out, w_out, 3, 1, 1)
        cx = net.complexity_batchnorm2d(cx, w_out)
        return cx
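

# Usage sketch for the residual block below (illustrative values; assumes
# `import torch` and that the pycls default cfg is loaded so cfg.BN.* and
# cfg.MEM.* resolve):
#
#   block = ResBasicBlock(w_in=64, w_out=128, stride=2)
#   y = block(torch.randn(8, 64, 32, 32))  # y.shape == (8, 128, 16, 16)
#
# The skip path gets a 1x1 projection automatically since w_in != w_out and
# stride != 1, keeping both branches shape-compatible for the addition.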


class ResBasicBlock(nn.Module):
    """Residual basic block: x + F(x), F = basic transform."""

    def __init__(self, w_in, w_out, stride, bm=None, gw=None, se_r=None):
        err_str = "Basic transform does not support bm, gw, and se_r options"
        assert bm is None and gw is None and se_r is None, err_str
        super(ResBasicBlock, self).__init__()
        self.proj_block = (w_in != w_out) or (stride != 1)
        if self.proj_block:
            self.proj = nn.Conv2d(w_in, w_out, 1, stride=stride, padding=0, bias=False)
            self.bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
        self.f = BasicTransform(w_in, w_out, stride)
        self.relu = nn.ReLU(cfg.MEM.RELU_INPLACE)

    def forward(self, x):
        if self.proj_block:
            x = self.bn(self.proj(x)) + self.f(x)
        else:
            x = x + self.f(x)
        x = self.relu(x)
        return x

    @staticmethod
    def complexity(cx, w_in, w_out, stride, bm=None, gw=None, se_r=None):
        err_str = "Basic transform does not support bm, gw, and se_r options"
        assert bm is None and gw is None and se_r is None, err_str
        proj_block = (w_in != w_out) or (stride != 1)
        if proj_block:
            h, w = cx["h"], cx["w"]
            cx = net.complexity_conv2d(cx, w_in, w_out, 1, stride, 0)
            cx = net.complexity_batchnorm2d(cx, w_out)
            cx["h"], cx["w"] = h, w  # parallel branch
        cx = BasicTransform.complexity(cx, w_in, w_out, stride)
        return cx


class SE(nn.Module):
    """Squeeze-and-Excitation (SE) block: AvgPool, FC, ReLU, FC, Sigmoid."""

    def __init__(self, w_in, w_se):
        super(SE, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.f_ex = nn.Sequential(
            nn.Conv2d(w_in, w_se, 1, bias=True),
            nn.ReLU(inplace=cfg.MEM.RELU_INPLACE),
            nn.Conv2d(w_se, w_in, 1, bias=True),
            nn.Sigmoid(),
        )

    def forward(self, x):
        return x * self.f_ex(self.avg_pool(x))

    @staticmethod
    def complexity(cx, w_in, w_se):
        h, w = cx["h"], cx["w"]
        cx["h"], cx["w"] = 1, 1
        cx = net.complexity_conv2d(cx, w_in, w_se, 1, 1, 0, bias=True)
        cx = net.complexity_conv2d(cx, w_se, w_in, 1, 1, 0, bias=True)
        cx["h"], cx["w"] = h, w
        return cx


class BottleneckTransform(nn.Module):
    """Bottleneck transformation: 1x1, 3x3 [+SE], 1x1."""

    def __init__(self, w_in, w_out, stride, bm, gw, se_r):
        super(BottleneckTransform, self).__init__()
        w_b = int(round(w_out * bm))
        g = w_b // gw
        self.a = nn.Conv2d(w_in, w_b, 1, stride=1, padding=0, bias=False)
        self.a_bn = nn.BatchNorm2d(w_b, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
        self.a_relu = nn.ReLU(inplace=cfg.MEM.RELU_INPLACE)
        self.b = nn.Conv2d(w_b, w_b, 3, stride=stride, padding=1, groups=g, bias=False)
        self.b_bn = nn.BatchNorm2d(w_b, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
        self.b_relu = nn.ReLU(inplace=cfg.MEM.RELU_INPLACE)
        if se_r:
            w_se = int(round(w_in * se_r))
            self.se = SE(w_b, w_se)
        self.c = nn.Conv2d(w_b, w_out, 1, stride=1, padding=0, bias=False)
        self.c_bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
        self.c_bn.final_bn = True

    def forward(self, x):
        for layer in self.children():
            x = layer(x)
        return x

    @staticmethod
    def complexity(cx, w_in, w_out, stride, bm, gw, se_r):
        w_b = int(round(w_out * bm))
        g = w_b // gw
        cx = net.complexity_conv2d(cx, w_in, w_b, 1, 1, 0)
        cx = net.complexity_batchnorm2d(cx, w_b)
        cx = net.complexity_conv2d(cx, w_b, w_b, 3, stride, 1, g)
        cx = net.complexity_batchnorm2d(cx, w_b)
        if se_r:
            w_se = int(round(w_in * se_r))
            cx = SE.complexity(cx, w_b, w_se)
        cx = net.complexity_conv2d(cx, w_b, w_out, 1, 1, 0)
        cx = net.complexity_batchnorm2d(cx, w_out)
        return cx
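

# Width arithmetic sketch for BottleneckTransform above (illustrative values):
# with w_out = 256, bm = 0.25, gw = 4, the bottleneck width is
# w_b = int(round(256 * 0.25)) = 64, split into g = 64 // 4 = 16 groups of
# width 4 in the grouped 3x3 conv. Note that se_r scales the block *input*
# width, not w_b: se_r = 0.25 with w_in = 128 gives w_se = 32 squeeze channels.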


class ResBottleneckBlock(nn.Module):
    """Residual bottleneck block: x + F(x), F = bottleneck transform."""

    def __init__(self, w_in, w_out, stride, bm=1.0, gw=1, se_r=None):
        super(ResBottleneckBlock, self).__init__()
        # Use skip connection with projection if shape changes
        self.proj_block = (w_in != w_out) or (stride != 1)
        if self.proj_block:
            self.proj = nn.Conv2d(w_in, w_out, 1, stride=stride, padding=0, bias=False)
            self.bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
        self.f = BottleneckTransform(w_in, w_out, stride, bm, gw, se_r)
        self.relu = nn.ReLU(cfg.MEM.RELU_INPLACE)

    def forward(self, x):
        if self.proj_block:
            x = self.bn(self.proj(x)) + self.f(x)
        else:
            x = x + self.f(x)
        x = self.relu(x)
        return x

    @staticmethod
    def complexity(cx, w_in, w_out, stride, bm=1.0, gw=1, se_r=None):
        proj_block = (w_in != w_out) or (stride != 1)
        if proj_block:
            h, w = cx["h"], cx["w"]
            cx = net.complexity_conv2d(cx, w_in, w_out, 1, stride, 0)
            cx = net.complexity_batchnorm2d(cx, w_out)
            cx["h"], cx["w"] = h, w  # parallel branch
        cx = BottleneckTransform.complexity(cx, w_in, w_out, stride, bm, gw, se_r)
        return cx


class ResStemCifar(nn.Module):
    """ResNet stem for CIFAR: 3x3, BN, ReLU."""

    def __init__(self, w_in, w_out):
        super(ResStemCifar, self).__init__()
        self.conv = nn.Conv2d(w_in, w_out, 3, stride=1, padding=1, bias=False)
        self.bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
        self.relu = nn.ReLU(cfg.MEM.RELU_INPLACE)

    def forward(self, x):
        for layer in self.children():
            x = layer(x)
        return x

    @staticmethod
    def complexity(cx, w_in, w_out):
        cx = net.complexity_conv2d(cx, w_in, w_out, 3, 1, 1)
        cx = net.complexity_batchnorm2d(cx, w_out)
        return cx


class ResStemIN(nn.Module):
    """ResNet stem for ImageNet: 7x7, BN, ReLU, MaxPool."""

    def __init__(self, w_in, w_out):
        super(ResStemIN, self).__init__()
        self.conv = nn.Conv2d(w_in, w_out, 7, stride=2, padding=3, bias=False)
        self.bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
        self.relu = nn.ReLU(cfg.MEM.RELU_INPLACE)
        self.pool = nn.MaxPool2d(3, stride=2, padding=1)

    def forward(self, x):
        for layer in self.children():
            x = layer(x)
        return x

    @staticmethod
    def complexity(cx, w_in, w_out):
        cx = net.complexity_conv2d(cx, w_in, w_out, 7, 2, 3)
        cx = net.complexity_batchnorm2d(cx, w_out)
        cx = net.complexity_maxpool2d(cx, 3, 2, 1)
        return cx


class SimpleStemIN(nn.Module):
    """Simple stem for ImageNet: 3x3, BN, ReLU."""

    def __init__(self, w_in, w_out):
        super(SimpleStemIN, self).__init__()
        self.conv = nn.Conv2d(w_in, w_out, 3, stride=2, padding=1, bias=False)
        self.bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
        self.relu = nn.ReLU(cfg.MEM.RELU_INPLACE)

    def forward(self, x):
        for layer in self.children():
            x = layer(x)
        return x

    @staticmethod
    def complexity(cx, w_in, w_out):
        cx = net.complexity_conv2d(cx, w_in, w_out, 3, 2, 1)
        cx = net.complexity_batchnorm2d(cx, w_out)
        return cx


class AnyStage(nn.Module):
    """AnyNet stage (sequence of blocks w/ the same output shape)."""

    def __init__(self, w_in, w_out, stride, d, block_fun, bm, gw, se_r):
        super(AnyStage, self).__init__()
        for i in range(d):
            b_stride = stride if i == 0 else 1
            b_w_in = w_in if i == 0 else w_out
            name = "b{}".format(i + 1)
            self.add_module(name, block_fun(b_w_in, w_out, b_stride, bm, gw, se_r))

    def forward(self, x):
        for block in self.children():
            x = block(x)
        return x

    @staticmethod
    def complexity(cx, w_in, w_out, stride, d, block_fun, bm, gw, se_r):
        for i in range(d):
            b_stride = stride if i == 0 else 1
            b_w_in = w_in if i == 0 else w_out
            cx = block_fun.complexity(cx, b_w_in, w_out, b_stride, bm, gw, se_r)
        return cx
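

# Stage composition sketch: only the first block of a stage downsamples and
# changes width; the remaining d - 1 blocks run at stride 1 on w_out channels.
# E.g., AnyStage(w_in=64, w_out=128, stride=2, d=3, block_fun=ResBasicBlock,
# bm=None, gw=None, se_r=None) builds blocks b1 (64->128, s=2),
# b2 (128->128, s=1), and b3 (128->128, s=1).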


class AnyNet(nn.Module):
    """AnyNet model."""

    @staticmethod
    def get_args():
        return {
            "stem_type": cfg.ANYNET.STEM_TYPE,
            "stem_w": cfg.ANYNET.STEM_W,
            "block_type": cfg.ANYNET.BLOCK_TYPE,
            "ds": cfg.ANYNET.DEPTHS,
            "ws": cfg.ANYNET.WIDTHS,
            "ss": cfg.ANYNET.STRIDES,
            "bms": cfg.ANYNET.BOT_MULS,
            "gws": cfg.ANYNET.GROUP_WS,
            "se_r": cfg.ANYNET.SE_R if cfg.ANYNET.SE_ON else None,
            "nc": cfg.MODEL.NUM_CLASSES,
        }

    def __init__(self, **kwargs):
        super(AnyNet, self).__init__()
        kwargs = self.get_args() if not kwargs else kwargs
        self._construct(**kwargs)
        self.apply(net.init_weights)

    def _construct(self, stem_type, stem_w, block_type, ds, ws, ss, bms, gws, se_r, nc):
        # Generate dummy bot muls and group widths for models that do not use them
        bms = bms if bms else [None for _d in ds]
        gws = gws if gws else [None for _d in ds]
        stage_params = list(zip(ds, ws, ss, bms, gws))
        stem_fun = get_stem_fun(stem_type)
        self.stem = stem_fun(3, stem_w)
        block_fun = get_block_fun(block_type)
        prev_w = stem_w
        for i, (d, w, s, bm, gw) in enumerate(stage_params):
            name = "s{}".format(i + 1)
            self.add_module(name, AnyStage(prev_w, w, s, d, block_fun, bm, gw, se_r))
            prev_w = w
        self.head = AnyHead(w_in=prev_w, nc=nc)

    def forward(self, x):
        for module in self.children():
            x = module(x)
        return x

    @staticmethod
    def complexity(cx, **kwargs):
        """Computes model complexity. If you alter the model, make sure to update."""
        kwargs = AnyNet.get_args() if not kwargs else kwargs
        return AnyNet._complexity(cx, **kwargs)

    @staticmethod
    def _complexity(cx, stem_type, stem_w, block_type, ds, ws, ss, bms, gws, se_r, nc):
        bms = bms if bms else [None for _d in ds]
        gws = gws if gws else [None for _d in ds]
        stage_params = list(zip(ds, ws, ss, bms, gws))
        stem_fun = get_stem_fun(stem_type)
        cx = stem_fun.complexity(cx, 3, stem_w)
        block_fun = get_block_fun(block_type)
        prev_w = stem_w
        for d, w, s, bm, gw in stage_params:
            cx = AnyStage.complexity(cx, prev_w, w, s, d, block_fun, bm, gw, se_r)
            prev_w = w
        cx = AnyHead.complexity(cx, prev_w, nc)
        return cx
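

if __name__ == "__main__":
    # Minimal smoke test (a sketch, not part of the pycls API): builds a tiny
    # AnyNet from explicit kwargs instead of cfg.ANYNET. The depths/widths
    # below are illustrative, not a published configuration; block internals
    # still read cfg defaults (cfg.BN.EPS, cfg.BN.MOM, cfg.MEM.RELU_INPLACE).
    import torch

    model = AnyNet(
        stem_type="simple_stem_in",
        stem_w=32,
        block_type="res_bottleneck_block",
        ds=[1, 1],
        ws=[32, 64],
        ss=[2, 2],
        bms=[1.0, 1.0],
        gws=[8, 8],
        se_r=0.25,
        nc=10,
    )
    out = model(torch.randn(1, 3, 64, 64))
    print(out.shape)  # expected: torch.Size([1, 10])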