#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""AnyNet models."""

import pycls.core.net as net
import torch.nn as nn
from pycls.core.config import cfg


def get_stem_fun(stem_type):
    """Retrieves the stem function by name."""
    stem_funs = {
        "res_stem_cifar": ResStemCifar,
        "res_stem_in": ResStemIN,
        "simple_stem_in": SimpleStemIN,
    }
    err_str = "Stem type '{}' not supported"
    assert stem_type in stem_funs, err_str.format(stem_type)
    return stem_funs[stem_type]


def get_block_fun(block_type):
    """Retrieves the block function by name."""
    block_funs = {
        "vanilla_block": VanillaBlock,
        "res_basic_block": ResBasicBlock,
        "res_bottleneck_block": ResBottleneckBlock,
    }
    err_str = "Block type '{}' not supported"
    assert block_type in block_funs, err_str.format(block_type)
    return block_funs[block_type]


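# Lookup sketch (illustrative arguments, not from any shipped config):
#
#     block_fun = get_block_fun("res_bottleneck_block")
#     block = block_fun(w_in=64, w_out=256, stride=1, bm=1.0, gw=16, se_r=0.25)

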
class AnyHead(nn.Module):
    """AnyNet head: AvgPool, 1x1."""

    def __init__(self, w_in, nc):
        super(AnyHead, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(w_in, nc, bias=True)

    def forward(self, x):
        x = self.avg_pool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x

    @staticmethod
    def complexity(cx, w_in, nc):
        cx["h"], cx["w"] = 1, 1
        cx = net.complexity_conv2d(cx, w_in, nc, 1, 1, 0, bias=True)
        return cx


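# Shape sketch (assumes `import torch`): AnyHead(w_in=64, nc=10) maps a
# (N, 64, H, W) feature map to (N, 10) logits for any spatial size H x W:
#
#     head = AnyHead(w_in=64, nc=10)
#     logits = head(torch.randn(2, 64, 7, 7))  # -> (2, 10)

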
class VanillaBlock(nn.Module):
    """Vanilla block: [3x3 conv, BN, ReLU] x2."""

    def __init__(self, w_in, w_out, stride, bm=None, gw=None, se_r=None):
        err_str = "Vanilla block does not support bm, gw, and se_r options"
        assert bm is None and gw is None and se_r is None, err_str
        super(VanillaBlock, self).__init__()
        self.a = nn.Conv2d(w_in, w_out, 3, stride=stride, padding=1, bias=False)
        self.a_bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
        self.a_relu = nn.ReLU(inplace=cfg.MEM.RELU_INPLACE)
        self.b = nn.Conv2d(w_out, w_out, 3, stride=1, padding=1, bias=False)
        self.b_bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
        self.b_relu = nn.ReLU(inplace=cfg.MEM.RELU_INPLACE)

    def forward(self, x):
        for layer in self.children():
            x = layer(x)
        return x

    @staticmethod
    def complexity(cx, w_in, w_out, stride, bm=None, gw=None, se_r=None):
        err_str = "Vanilla block does not support bm, gw, and se_r options"
        assert bm is None and gw is None and se_r is None, err_str
        cx = net.complexity_conv2d(cx, w_in, w_out, 3, stride, 1)
        cx = net.complexity_batchnorm2d(cx, w_out)
        cx = net.complexity_conv2d(cx, w_out, w_out, 3, 1, 1)
        cx = net.complexity_batchnorm2d(cx, w_out)
        return cx


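# Note: VanillaBlock stacks two 3x3 conv-BN-ReLU units with no shortcut;
# ResBasicBlock below wraps the same transform in a residual x + F(x).

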
class BasicTransform(nn.Module):
    """Basic transformation: [3x3 conv, BN, ReLU] x2."""

    def __init__(self, w_in, w_out, stride):
        super(BasicTransform, self).__init__()
        self.a = nn.Conv2d(w_in, w_out, 3, stride=stride, padding=1, bias=False)
        self.a_bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
        self.a_relu = nn.ReLU(inplace=cfg.MEM.RELU_INPLACE)
        self.b = nn.Conv2d(w_out, w_out, 3, stride=1, padding=1, bias=False)
        self.b_bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
        # Flag the last BN so net.init_weights can zero-init its gamma
        # (cfg.BN.ZERO_INIT_FINAL_GAMMA), letting the block start near identity
        self.b_bn.final_bn = True

    def forward(self, x):
        for layer in self.children():
            x = layer(x)
        return x

    @staticmethod
    def complexity(cx, w_in, w_out, stride):
        cx = net.complexity_conv2d(cx, w_in, w_out, 3, stride, 1)
        cx = net.complexity_batchnorm2d(cx, w_out)
        cx = net.complexity_conv2d(cx, w_out, w_out, 3, 1, 1)
        cx = net.complexity_batchnorm2d(cx, w_out)
        return cx


class ResBasicBlock(nn.Module):
    """Residual basic block: x + F(x), F = basic transform."""

    def __init__(self, w_in, w_out, stride, bm=None, gw=None, se_r=None):
        err_str = "Basic transform does not support bm, gw, and se_r options"
        assert bm is None and gw is None and se_r is None, err_str
        super(ResBasicBlock, self).__init__()
        self.proj_block = (w_in != w_out) or (stride != 1)
        if self.proj_block:
            self.proj = nn.Conv2d(w_in, w_out, 1, stride=stride, padding=0, bias=False)
            self.bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
        self.f = BasicTransform(w_in, w_out, stride)
        self.relu = nn.ReLU(inplace=cfg.MEM.RELU_INPLACE)

    def forward(self, x):
        if self.proj_block:
            x = self.bn(self.proj(x)) + self.f(x)
        else:
            x = x + self.f(x)
        x = self.relu(x)
        return x

    @staticmethod
    def complexity(cx, w_in, w_out, stride, bm=None, gw=None, se_r=None):
        err_str = "Basic transform does not support bm, gw, and se_r options"
        assert bm is None and gw is None and se_r is None, err_str
        proj_block = (w_in != w_out) or (stride != 1)
        if proj_block:
            h, w = cx["h"], cx["w"]
            cx = net.complexity_conv2d(cx, w_in, w_out, 1, stride, 0)
            cx = net.complexity_batchnorm2d(cx, w_out)
            cx["h"], cx["w"] = h, w  # parallel branch
        cx = BasicTransform.complexity(cx, w_in, w_out, stride)
        return cx


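# Note: when the block changes width or stride, the identity path cannot be
# summed with F(x) directly, so the skip goes through a 1x1 projection conv
# plus BN (self.proj, self.bn) to match the output shape.

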
class SE(nn.Module):
    """Squeeze-and-Excitation (SE) block: AvgPool, FC, ReLU, FC, Sigmoid."""

    def __init__(self, w_in, w_se):
        super(SE, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.f_ex = nn.Sequential(
            nn.Conv2d(w_in, w_se, 1, bias=True),
            nn.ReLU(inplace=cfg.MEM.RELU_INPLACE),
            nn.Conv2d(w_se, w_in, 1, bias=True),
            nn.Sigmoid(),
        )

    def forward(self, x):
        return x * self.f_ex(self.avg_pool(x))

    @staticmethod
    def complexity(cx, w_in, w_se):
        h, w = cx["h"], cx["w"]
        cx["h"], cx["w"] = 1, 1
        cx = net.complexity_conv2d(cx, w_in, w_se, 1, 1, 0, bias=True)
        cx = net.complexity_conv2d(cx, w_se, w_in, 1, 1, 0, bias=True)
        cx["h"], cx["w"] = h, w
        return cx


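# Gating sketch (assumes `import torch`): SE(w_in=64, w_se=16) pools to 1x1,
# produces per-channel weights in (0, 1), and rescales the input:
#
#     se = SE(w_in=64, w_se=16)
#     y = se(torch.randn(2, 64, 7, 7))  # same shape as the input

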
class BottleneckTransform(nn.Module):
    """Bottleneck transformation: 1x1, 3x3 [+SE], 1x1."""

    def __init__(self, w_in, w_out, stride, bm, gw, se_r):
        super(BottleneckTransform, self).__init__()
        w_b = int(round(w_out * bm))
        g = w_b // gw
        self.a = nn.Conv2d(w_in, w_b, 1, stride=1, padding=0, bias=False)
        self.a_bn = nn.BatchNorm2d(w_b, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
        self.a_relu = nn.ReLU(inplace=cfg.MEM.RELU_INPLACE)
        self.b = nn.Conv2d(w_b, w_b, 3, stride=stride, padding=1, groups=g, bias=False)
        self.b_bn = nn.BatchNorm2d(w_b, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
        self.b_relu = nn.ReLU(inplace=cfg.MEM.RELU_INPLACE)
        if se_r:
            # Note: the SE width is derived from the block input width w_in,
            # not from the bottleneck width w_b
            w_se = int(round(w_in * se_r))
            self.se = SE(w_b, w_se)
        self.c = nn.Conv2d(w_b, w_out, 1, stride=1, padding=0, bias=False)
        self.c_bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
        # Flag the last BN so net.init_weights can zero-init its gamma
        self.c_bn.final_bn = True

    def forward(self, x):
        for layer in self.children():
            x = layer(x)
        return x

    @staticmethod
    def complexity(cx, w_in, w_out, stride, bm, gw, se_r):
        w_b = int(round(w_out * bm))
        g = w_b // gw
        cx = net.complexity_conv2d(cx, w_in, w_b, 1, 1, 0)
        cx = net.complexity_batchnorm2d(cx, w_b)
        cx = net.complexity_conv2d(cx, w_b, w_b, 3, stride, 1, g)
        cx = net.complexity_batchnorm2d(cx, w_b)
        if se_r:
            w_se = int(round(w_in * se_r))
            cx = SE.complexity(cx, w_b, w_se)
        cx = net.complexity_conv2d(cx, w_b, w_out, 1, 1, 0)
        cx = net.complexity_batchnorm2d(cx, w_out)
        return cx


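# Width arithmetic (illustrative): with w_out=256, bm=0.25, gw=16, the
# bottleneck width is w_b = round(256 * 0.25) = 64 and the grouped 3x3 conv
# uses g = 64 // 16 = 4 groups of 16 channels each; configs are expected to
# pick gw so that it divides w_b evenly.

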
class ResBottleneckBlock(nn.Module):
    """Residual bottleneck block: x + F(x), F = bottleneck transform."""

    def __init__(self, w_in, w_out, stride, bm=1.0, gw=1, se_r=None):
        super(ResBottleneckBlock, self).__init__()
        # Use skip connection with projection if shape changes
        self.proj_block = (w_in != w_out) or (stride != 1)
        if self.proj_block:
            self.proj = nn.Conv2d(w_in, w_out, 1, stride=stride, padding=0, bias=False)
            self.bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
        self.f = BottleneckTransform(w_in, w_out, stride, bm, gw, se_r)
        self.relu = nn.ReLU(inplace=cfg.MEM.RELU_INPLACE)

    def forward(self, x):
        if self.proj_block:
            x = self.bn(self.proj(x)) + self.f(x)
        else:
            x = x + self.f(x)
        x = self.relu(x)
        return x

    @staticmethod
    def complexity(cx, w_in, w_out, stride, bm=1.0, gw=1, se_r=None):
        proj_block = (w_in != w_out) or (stride != 1)
        if proj_block:
            h, w = cx["h"], cx["w"]
            cx = net.complexity_conv2d(cx, w_in, w_out, 1, stride, 0)
            cx = net.complexity_batchnorm2d(cx, w_out)
            cx["h"], cx["w"] = h, w  # parallel branch
        cx = BottleneckTransform.complexity(cx, w_in, w_out, stride, bm, gw, se_r)
        return cx


class ResStemCifar(nn.Module):
    """ResNet stem for CIFAR: 3x3, BN, ReLU."""

    def __init__(self, w_in, w_out):
        super(ResStemCifar, self).__init__()
        self.conv = nn.Conv2d(w_in, w_out, 3, stride=1, padding=1, bias=False)
        self.bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
        self.relu = nn.ReLU(inplace=cfg.MEM.RELU_INPLACE)

    def forward(self, x):
        for layer in self.children():
            x = layer(x)
        return x

    @staticmethod
    def complexity(cx, w_in, w_out):
        cx = net.complexity_conv2d(cx, w_in, w_out, 3, 1, 1)
        cx = net.complexity_batchnorm2d(cx, w_out)
        return cx


class ResStemIN(nn.Module):
    """ResNet stem for ImageNet: 7x7, BN, ReLU, MaxPool."""

    def __init__(self, w_in, w_out):
        super(ResStemIN, self).__init__()
        self.conv = nn.Conv2d(w_in, w_out, 7, stride=2, padding=3, bias=False)
        self.bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
        self.relu = nn.ReLU(inplace=cfg.MEM.RELU_INPLACE)
        self.pool = nn.MaxPool2d(3, stride=2, padding=1)

    def forward(self, x):
        for layer in self.children():
            x = layer(x)
        return x

    @staticmethod
    def complexity(cx, w_in, w_out):
        cx = net.complexity_conv2d(cx, w_in, w_out, 7, 2, 3)
        cx = net.complexity_batchnorm2d(cx, w_out)
        cx = net.complexity_maxpool2d(cx, 3, 2, 1)
        return cx


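# Resolution note: ResStemIN reduces the spatial size 4x (stride-2 7x7 conv,
# then stride-2 max pool), e.g. 224x224 -> 56x56; SimpleStemIN below reduces
# it 2x with a single stride-2 3x3 conv.

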
class SimpleStemIN(nn.Module):
    """Simple stem for ImageNet: 3x3, BN, ReLU."""

    def __init__(self, w_in, w_out):
        super(SimpleStemIN, self).__init__()
        self.conv = nn.Conv2d(w_in, w_out, 3, stride=2, padding=1, bias=False)
        self.bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
        self.relu = nn.ReLU(inplace=cfg.MEM.RELU_INPLACE)

    def forward(self, x):
        for layer in self.children():
            x = layer(x)
        return x

    @staticmethod
    def complexity(cx, w_in, w_out):
        cx = net.complexity_conv2d(cx, w_in, w_out, 3, 2, 1)
        cx = net.complexity_batchnorm2d(cx, w_out)
        return cx


class AnyStage(nn.Module):
    """AnyNet stage (sequence of blocks w/ the same output shape)."""

    def __init__(self, w_in, w_out, stride, d, block_fun, bm, gw, se_r):
        super(AnyStage, self).__init__()
        for i in range(d):
            b_stride = stride if i == 0 else 1
            b_w_in = w_in if i == 0 else w_out
            name = "b{}".format(i + 1)
            self.add_module(name, block_fun(b_w_in, w_out, b_stride, bm, gw, se_r))

    def forward(self, x):
        for block in self.children():
            x = block(x)
        return x

    @staticmethod
    def complexity(cx, w_in, w_out, stride, d, block_fun, bm, gw, se_r):
        for i in range(d):
            b_stride = stride if i == 0 else 1
            b_w_in = w_in if i == 0 else w_out
            cx = block_fun.complexity(cx, b_w_in, w_out, b_stride, bm, gw, se_r)
        return cx


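# Stage layout (illustrative): AnyStage(64, 128, stride=2, d=3, ...) builds
# blocks b1..b3 where only b1 downsamples and widens (stride=2, 64 -> 128);
# b2 and b3 run at stride=1 with matching input/output widths.

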
class AnyNet(nn.Module):
    """AnyNet model."""

    @staticmethod
    def get_args():
        return {
            "stem_type": cfg.ANYNET.STEM_TYPE,
            "stem_w": cfg.ANYNET.STEM_W,
            "block_type": cfg.ANYNET.BLOCK_TYPE,
            "ds": cfg.ANYNET.DEPTHS,
            "ws": cfg.ANYNET.WIDTHS,
            "ss": cfg.ANYNET.STRIDES,
            "bms": cfg.ANYNET.BOT_MULS,
            "gws": cfg.ANYNET.GROUP_WS,
            "se_r": cfg.ANYNET.SE_R if cfg.ANYNET.SE_ON else None,
            "nc": cfg.MODEL.NUM_CLASSES,
        }

    def __init__(self, **kwargs):
        super(AnyNet, self).__init__()
        kwargs = self.get_args() if not kwargs else kwargs
        self._construct(**kwargs)
        self.apply(net.init_weights)

    def _construct(self, stem_type, stem_w, block_type, ds, ws, ss, bms, gws, se_r, nc):
        # Generate dummy bot muls and group widths for models that do not use them
        bms = bms if bms else [None for _d in ds]
        gws = gws if gws else [None for _d in ds]
        stage_params = list(zip(ds, ws, ss, bms, gws))
        stem_fun = get_stem_fun(stem_type)
        self.stem = stem_fun(3, stem_w)
        block_fun = get_block_fun(block_type)
        prev_w = stem_w
        for i, (d, w, s, bm, gw) in enumerate(stage_params):
            name = "s{}".format(i + 1)
            self.add_module(name, AnyStage(prev_w, w, s, d, block_fun, bm, gw, se_r))
            prev_w = w
        self.head = AnyHead(w_in=prev_w, nc=nc)

    def forward(self, x):
        for module in self.children():
            x = module(x)
        return x

    @staticmethod
    def complexity(cx, **kwargs):
        """Computes model complexity. If you alter the model, make sure to update."""
        kwargs = AnyNet.get_args() if not kwargs else kwargs
        return AnyNet._complexity(cx, **kwargs)

    @staticmethod
    def _complexity(cx, stem_type, stem_w, block_type, ds, ws, ss, bms, gws, se_r, nc):
        bms = bms if bms else [None for _d in ds]
        gws = gws if gws else [None for _d in ds]
        stage_params = list(zip(ds, ws, ss, bms, gws))
        stem_fun = get_stem_fun(stem_type)
        cx = stem_fun.complexity(cx, 3, stem_w)
        block_fun = get_block_fun(block_type)
        prev_w = stem_w
        for d, w, s, bm, gw in stage_params:
            cx = AnyStage.complexity(cx, prev_w, w, s, d, block_fun, bm, gw, se_r)
            prev_w = w
        cx = AnyHead.complexity(cx, prev_w, nc)
        return cx
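

# Usage sketch (hypothetical depths/widths, not from any shipped config;
# assumes `import torch` and that `cfg` supplies the BN/MEM defaults):
#
#     model = AnyNet(
#         stem_type="simple_stem_in", stem_w=32,
#         block_type="res_bottleneck_block",
#         ds=[1, 2], ws=[64, 128], ss=[2, 2],
#         bms=[1.0, 1.0], gws=[8, 8], se_r=None, nc=1000,
#     )
#     out = model(torch.randn(1, 3, 224, 224))  # -> (1, 1000)
#
# Complexity accounting (the counter keys are an assumption about
# pycls.core.net, which reads and updates "h" and "w" as seen above):
#
#     cx = {"h": 224, "w": 224, "flops": 0, "params": 0, "acts": 0}
#     cx = AnyNet.complexity(cx, **AnyNet.get_args())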