This commit is contained in: v2

pycls/models/__init__.py | 0 lines (new file)

pycls/models/anynet.py | 406 lines (new file)
| @@ -0,0 +1,406 @@ | ||||
| #!/usr/bin/env python3 | ||||
|  | ||||
| # Copyright (c) Facebook, Inc. and its affiliates. | ||||
| # | ||||
| # This source code is licensed under the MIT license found in the | ||||
| # LICENSE file in the root directory of this source tree. | ||||
|  | ||||
| """AnyNet models.""" | ||||
|  | ||||
| import pycls.core.net as net | ||||
| import torch.nn as nn | ||||
| from pycls.core.config import cfg | ||||
|  | ||||
|  | ||||
| def get_stem_fun(stem_type): | ||||
|     """Retrieves the stem function by name.""" | ||||
|     stem_funs = { | ||||
|         "res_stem_cifar": ResStemCifar, | ||||
|         "res_stem_in": ResStemIN, | ||||
|         "simple_stem_in": SimpleStemIN, | ||||
|     } | ||||
|     err_str = "Stem type '{}' not supported" | ||||
|     assert stem_type in stem_funs, err_str.format(stem_type) | ||||
|     return stem_funs[stem_type] | ||||
|  | ||||
|  | ||||
| def get_block_fun(block_type): | ||||
|     """Retrieves the block function by name.""" | ||||
|     block_funs = { | ||||
|         "vanilla_block": VanillaBlock, | ||||
|         "res_basic_block": ResBasicBlock, | ||||
|         "res_bottleneck_block": ResBottleneckBlock, | ||||
|     } | ||||
|     err_str = "Block type '{}' not supported" | ||||
|     assert block_type in block_funs, err_str.format(block_type) | ||||
|     return block_funs[block_type] | ||||
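|  | ||||
| # Note: the stem/block classes referenced in the two lookup tables above are | ||||
| # defined later in this file; each dict is built inside the function at call | ||||
| # time, so the forward references resolve only when a model is constructed. | ||||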
|  | ||||
|  | ||||
| class AnyHead(nn.Module): | ||||
|     """AnyNet head: AvgPool, 1x1.""" | ||||
|  | ||||
|     def __init__(self, w_in, nc): | ||||
|         super(AnyHead, self).__init__() | ||||
|         self.avg_pool = nn.AdaptiveAvgPool2d((1, 1)) | ||||
|         self.fc = nn.Linear(w_in, nc, bias=True) | ||||
|  | ||||
|     def forward(self, x): | ||||
|         x = self.avg_pool(x) | ||||
|         x = x.view(x.size(0), -1) | ||||
|         x = self.fc(x) | ||||
|         return x | ||||
|  | ||||
|     @staticmethod | ||||
|     def complexity(cx, w_in, nc): | ||||
|         cx["h"], cx["w"] = 1, 1 | ||||
|         cx = net.complexity_conv2d(cx, w_in, nc, 1, 1, 0, bias=True) | ||||
|         return cx | ||||
|  | ||||
|  | ||||
| class VanillaBlock(nn.Module): | ||||
|     """Vanilla block: [3x3 conv, BN, Relu] x2.""" | ||||
|  | ||||
|     def __init__(self, w_in, w_out, stride, bm=None, gw=None, se_r=None): | ||||
|         err_str = "Vanilla block does not support bm, gw, and se_r options" | ||||
|         assert bm is None and gw is None and se_r is None, err_str | ||||
|         super(VanillaBlock, self).__init__() | ||||
|         self.a = nn.Conv2d(w_in, w_out, 3, stride=stride, padding=1, bias=False) | ||||
|         self.a_bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM) | ||||
|         self.a_relu = nn.ReLU(inplace=cfg.MEM.RELU_INPLACE) | ||||
|         self.b = nn.Conv2d(w_out, w_out, 3, stride=1, padding=1, bias=False) | ||||
|         self.b_bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM) | ||||
|         self.b_relu = nn.ReLU(inplace=cfg.MEM.RELU_INPLACE) | ||||
|  | ||||
|     def forward(self, x): | ||||
|         for layer in self.children(): | ||||
|             x = layer(x) | ||||
|         return x | ||||
|  | ||||
|     @staticmethod | ||||
|     def complexity(cx, w_in, w_out, stride, bm=None, gw=None, se_r=None): | ||||
|         err_str = "Vanilla block does not support bm, gw, and se_r options" | ||||
|         assert bm is None and gw is None and se_r is None, err_str | ||||
|         cx = net.complexity_conv2d(cx, w_in, w_out, 3, stride, 1) | ||||
|         cx = net.complexity_batchnorm2d(cx, w_out) | ||||
|         cx = net.complexity_conv2d(cx, w_out, w_out, 3, 1, 1) | ||||
|         cx = net.complexity_batchnorm2d(cx, w_out) | ||||
|         return cx | ||||
|  | ||||
|  | ||||
| class BasicTransform(nn.Module): | ||||
|     """Basic transformation: [3x3 conv, BN, Relu] x2.""" | ||||
|  | ||||
|     def __init__(self, w_in, w_out, stride): | ||||
|         super(BasicTransform, self).__init__() | ||||
|         self.a = nn.Conv2d(w_in, w_out, 3, stride=stride, padding=1, bias=False) | ||||
|         self.a_bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM) | ||||
|         self.a_relu = nn.ReLU(inplace=cfg.MEM.RELU_INPLACE) | ||||
|         self.b = nn.Conv2d(w_out, w_out, 3, stride=1, padding=1, bias=False) | ||||
|         self.b_bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM) | ||||
|         self.b_bn.final_bn = True | ||||
|  | ||||
|     def forward(self, x): | ||||
|         for layer in self.children(): | ||||
|             x = layer(x) | ||||
|         return x | ||||
|  | ||||
|     @staticmethod | ||||
|     def complexity(cx, w_in, w_out, stride): | ||||
|         cx = net.complexity_conv2d(cx, w_in, w_out, 3, stride, 1) | ||||
|         cx = net.complexity_batchnorm2d(cx, w_out) | ||||
|         cx = net.complexity_conv2d(cx, w_out, w_out, 3, 1, 1) | ||||
|         cx = net.complexity_batchnorm2d(cx, w_out) | ||||
|         return cx | ||||
|  | ||||
|  | ||||
| class ResBasicBlock(nn.Module): | ||||
|     """Residual basic block: x + F(x), F = basic transform.""" | ||||
|  | ||||
|     def __init__(self, w_in, w_out, stride, bm=None, gw=None, se_r=None): | ||||
|         err_str = "Basic transform does not support bm, gw, and se_r options" | ||||
|         assert bm is None and gw is None and se_r is None, err_str | ||||
|         super(ResBasicBlock, self).__init__() | ||||
|         self.proj_block = (w_in != w_out) or (stride != 1) | ||||
|         if self.proj_block: | ||||
|             self.proj = nn.Conv2d(w_in, w_out, 1, stride=stride, padding=0, bias=False) | ||||
|             self.bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM) | ||||
|         self.f = BasicTransform(w_in, w_out, stride) | ||||
|         self.relu = nn.ReLU(cfg.MEM.RELU_INPLACE) | ||||
|  | ||||
|     def forward(self, x): | ||||
|         if self.proj_block: | ||||
|             x = self.bn(self.proj(x)) + self.f(x) | ||||
|         else: | ||||
|             x = x + self.f(x) | ||||
|         x = self.relu(x) | ||||
|         return x | ||||
|  | ||||
|     @staticmethod | ||||
|     def complexity(cx, w_in, w_out, stride, bm=None, gw=None, se_r=None): | ||||
|         err_str = "Basic transform does not support bm, gw, and se_r options" | ||||
|         assert bm is None and gw is None and se_r is None, err_str | ||||
|         proj_block = (w_in != w_out) or (stride != 1) | ||||
|         if proj_block: | ||||
|             h, w = cx["h"], cx["w"] | ||||
|             cx = net.complexity_conv2d(cx, w_in, w_out, 1, stride, 0) | ||||
|             cx = net.complexity_batchnorm2d(cx, w_out) | ||||
|             cx["h"], cx["w"] = h, w  # parallel branch | ||||
|         cx = BasicTransform.complexity(cx, w_in, w_out, stride) | ||||
|         return cx | ||||
|  | ||||
|  | ||||
| class SE(nn.Module): | ||||
|     """Squeeze-and-Excitation (SE) block: AvgPool, FC, ReLU, FC, Sigmoid.""" | ||||
|  | ||||
|     def __init__(self, w_in, w_se): | ||||
|         super(SE, self).__init__() | ||||
|         self.avg_pool = nn.AdaptiveAvgPool2d((1, 1)) | ||||
|         self.f_ex = nn.Sequential( | ||||
|             nn.Conv2d(w_in, w_se, 1, bias=True), | ||||
|             nn.ReLU(inplace=cfg.MEM.RELU_INPLACE), | ||||
|             nn.Conv2d(w_se, w_in, 1, bias=True), | ||||
|             nn.Sigmoid(), | ||||
|         ) | ||||
|  | ||||
|     def forward(self, x): | ||||
|         return x * self.f_ex(self.avg_pool(x)) | ||||
|  | ||||
|     @staticmethod | ||||
|     def complexity(cx, w_in, w_se): | ||||
|         h, w = cx["h"], cx["w"] | ||||
|         cx["h"], cx["w"] = 1, 1 | ||||
|         cx = net.complexity_conv2d(cx, w_in, w_se, 1, 1, 0, bias=True) | ||||
|         cx = net.complexity_conv2d(cx, w_se, w_in, 1, 1, 0, bias=True) | ||||
|         cx["h"], cx["w"] = h, w | ||||
|         return cx | ||||
|  | ||||
|  | ||||
| class BottleneckTransform(nn.Module): | ||||
|     """Bottleneck transformation: 1x1, 3x3 [+SE], 1x1.""" | ||||
|  | ||||
|     def __init__(self, w_in, w_out, stride, bm, gw, se_r): | ||||
|         super(BottleneckTransform, self).__init__() | ||||
|         w_b = int(round(w_out * bm)) | ||||
|         g = w_b // gw | ||||
|         self.a = nn.Conv2d(w_in, w_b, 1, stride=1, padding=0, bias=False) | ||||
|         self.a_bn = nn.BatchNorm2d(w_b, eps=cfg.BN.EPS, momentum=cfg.BN.MOM) | ||||
|         self.a_relu = nn.ReLU(inplace=cfg.MEM.RELU_INPLACE) | ||||
|         self.b = nn.Conv2d(w_b, w_b, 3, stride=stride, padding=1, groups=g, bias=False) | ||||
|         self.b_bn = nn.BatchNorm2d(w_b, eps=cfg.BN.EPS, momentum=cfg.BN.MOM) | ||||
|         self.b_relu = nn.ReLU(inplace=cfg.MEM.RELU_INPLACE) | ||||
|         if se_r: | ||||
|             w_se = int(round(w_in * se_r)) | ||||
|             self.se = SE(w_b, w_se) | ||||
|         self.c = nn.Conv2d(w_b, w_out, 1, stride=1, padding=0, bias=False) | ||||
|         self.c_bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM) | ||||
|         self.c_bn.final_bn = True | ||||
|  | ||||
|     def forward(self, x): | ||||
|         for layer in self.children(): | ||||
|             x = layer(x) | ||||
|         return x | ||||
|  | ||||
|     @staticmethod | ||||
|     def complexity(cx, w_in, w_out, stride, bm, gw, se_r): | ||||
|         w_b = int(round(w_out * bm)) | ||||
|         g = w_b // gw | ||||
|         cx = net.complexity_conv2d(cx, w_in, w_b, 1, 1, 0) | ||||
|         cx = net.complexity_batchnorm2d(cx, w_b) | ||||
|         cx = net.complexity_conv2d(cx, w_b, w_b, 3, stride, 1, g) | ||||
|         cx = net.complexity_batchnorm2d(cx, w_b) | ||||
|         if se_r: | ||||
|             w_se = int(round(w_in * se_r)) | ||||
|             cx = SE.complexity(cx, w_b, w_se) | ||||
|         cx = net.complexity_conv2d(cx, w_b, w_out, 1, 1, 0) | ||||
|         cx = net.complexity_batchnorm2d(cx, w_out) | ||||
|         return cx | ||||
|  | ||||
|  | ||||
| class ResBottleneckBlock(nn.Module): | ||||
|     """Residual bottleneck block: x + F(x), F = bottleneck transform.""" | ||||
|  | ||||
|     def __init__(self, w_in, w_out, stride, bm=1.0, gw=1, se_r=None): | ||||
|         super(ResBottleneckBlock, self).__init__() | ||||
|         # Use skip connection with projection if shape changes | ||||
|         self.proj_block = (w_in != w_out) or (stride != 1) | ||||
|         if self.proj_block: | ||||
|             self.proj = nn.Conv2d(w_in, w_out, 1, stride=stride, padding=0, bias=False) | ||||
|             self.bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM) | ||||
|         self.f = BottleneckTransform(w_in, w_out, stride, bm, gw, se_r) | ||||
|         self.relu = nn.ReLU(cfg.MEM.RELU_INPLACE) | ||||
|  | ||||
|     def forward(self, x): | ||||
|         if self.proj_block: | ||||
|             x = self.bn(self.proj(x)) + self.f(x) | ||||
|         else: | ||||
|             x = x + self.f(x) | ||||
|         x = self.relu(x) | ||||
|         return x | ||||
|  | ||||
|     @staticmethod | ||||
|     def complexity(cx, w_in, w_out, stride, bm=1.0, gw=1, se_r=None): | ||||
|         proj_block = (w_in != w_out) or (stride != 1) | ||||
|         if proj_block: | ||||
|             h, w = cx["h"], cx["w"] | ||||
|             cx = net.complexity_conv2d(cx, w_in, w_out, 1, stride, 0) | ||||
|             cx = net.complexity_batchnorm2d(cx, w_out) | ||||
|             cx["h"], cx["w"] = h, w  # parallel branch | ||||
|         cx = BottleneckTransform.complexity(cx, w_in, w_out, stride, bm, gw, se_r) | ||||
|         return cx | ||||
|  | ||||
|  | ||||
| class ResStemCifar(nn.Module): | ||||
|     """ResNet stem for CIFAR: 3x3, BN, ReLU.""" | ||||
|  | ||||
|     def __init__(self, w_in, w_out): | ||||
|         super(ResStemCifar, self).__init__() | ||||
|         self.conv = nn.Conv2d(w_in, w_out, 3, stride=1, padding=1, bias=False) | ||||
|         self.bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM) | ||||
|         self.relu = nn.ReLU(cfg.MEM.RELU_INPLACE) | ||||
|  | ||||
|     def forward(self, x): | ||||
|         for layer in self.children(): | ||||
|             x = layer(x) | ||||
|         return x | ||||
|  | ||||
|     @staticmethod | ||||
|     def complexity(cx, w_in, w_out): | ||||
|         cx = net.complexity_conv2d(cx, w_in, w_out, 3, 1, 1) | ||||
|         cx = net.complexity_batchnorm2d(cx, w_out) | ||||
|         return cx | ||||
|  | ||||
|  | ||||
| class ResStemIN(nn.Module): | ||||
|     """ResNet stem for ImageNet: 7x7, BN, ReLU, MaxPool.""" | ||||
|  | ||||
|     def __init__(self, w_in, w_out): | ||||
|         super(ResStemIN, self).__init__() | ||||
|         self.conv = nn.Conv2d(w_in, w_out, 7, stride=2, padding=3, bias=False) | ||||
|         self.bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM) | ||||
|         self.relu = nn.ReLU(cfg.MEM.RELU_INPLACE) | ||||
|         self.pool = nn.MaxPool2d(3, stride=2, padding=1) | ||||
|  | ||||
|     def forward(self, x): | ||||
|         for layer in self.children(): | ||||
|             x = layer(x) | ||||
|         return x | ||||
|  | ||||
|     @staticmethod | ||||
|     def complexity(cx, w_in, w_out): | ||||
|         cx = net.complexity_conv2d(cx, w_in, w_out, 7, 2, 3) | ||||
|         cx = net.complexity_batchnorm2d(cx, w_out) | ||||
|         cx = net.complexity_maxpool2d(cx, 3, 2, 1) | ||||
|         return cx | ||||
|  | ||||
|  | ||||
| class SimpleStemIN(nn.Module): | ||||
|     """Simple stem for ImageNet: 3x3, BN, ReLU.""" | ||||
|  | ||||
|     def __init__(self, w_in, w_out): | ||||
|         super(SimpleStemIN, self).__init__() | ||||
|         self.conv = nn.Conv2d(w_in, w_out, 3, stride=2, padding=1, bias=False) | ||||
|         self.bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM) | ||||
|         self.relu = nn.ReLU(cfg.MEM.RELU_INPLACE) | ||||
|  | ||||
|     def forward(self, x): | ||||
|         for layer in self.children(): | ||||
|             x = layer(x) | ||||
|         return x | ||||
|  | ||||
|     @staticmethod | ||||
|     def complexity(cx, w_in, w_out): | ||||
|         cx = net.complexity_conv2d(cx, w_in, w_out, 3, 2, 1) | ||||
|         cx = net.complexity_batchnorm2d(cx, w_out) | ||||
|         return cx | ||||
|  | ||||
|  | ||||
| class AnyStage(nn.Module): | ||||
|     """AnyNet stage (sequence of blocks w/ the same output shape).""" | ||||
|  | ||||
|     def __init__(self, w_in, w_out, stride, d, block_fun, bm, gw, se_r): | ||||
|         super(AnyStage, self).__init__() | ||||
|         for i in range(d): | ||||
|             b_stride = stride if i == 0 else 1 | ||||
|             b_w_in = w_in if i == 0 else w_out | ||||
|             name = "b{}".format(i + 1) | ||||
|             self.add_module(name, block_fun(b_w_in, w_out, b_stride, bm, gw, se_r)) | ||||
|  | ||||
|     def forward(self, x): | ||||
|         for block in self.children(): | ||||
|             x = block(x) | ||||
|         return x | ||||
|  | ||||
|     @staticmethod | ||||
|     def complexity(cx, w_in, w_out, stride, d, block_fun, bm, gw, se_r): | ||||
|         for i in range(d): | ||||
|             b_stride = stride if i == 0 else 1 | ||||
|             b_w_in = w_in if i == 0 else w_out | ||||
|             cx = block_fun.complexity(cx, b_w_in, w_out, b_stride, bm, gw, se_r) | ||||
|         return cx | ||||
|  | ||||
|  | ||||
| class AnyNet(nn.Module): | ||||
|     """AnyNet model.""" | ||||
|  | ||||
|     @staticmethod | ||||
|     def get_args(): | ||||
|         return { | ||||
|             "stem_type": cfg.ANYNET.STEM_TYPE, | ||||
|             "stem_w": cfg.ANYNET.STEM_W, | ||||
|             "block_type": cfg.ANYNET.BLOCK_TYPE, | ||||
|             "ds": cfg.ANYNET.DEPTHS, | ||||
|             "ws": cfg.ANYNET.WIDTHS, | ||||
|             "ss": cfg.ANYNET.STRIDES, | ||||
|             "bms": cfg.ANYNET.BOT_MULS, | ||||
|             "gws": cfg.ANYNET.GROUP_WS, | ||||
|             "se_r": cfg.ANYNET.SE_R if cfg.ANYNET.SE_ON else None, | ||||
|             "nc": cfg.MODEL.NUM_CLASSES, | ||||
|         } | ||||
|  | ||||
|     def __init__(self, **kwargs): | ||||
|         super(AnyNet, self).__init__() | ||||
|         kwargs = self.get_args() if not kwargs else kwargs | ||||
|         self._construct(**kwargs) | ||||
|         self.apply(net.init_weights) | ||||
|  | ||||
|     def _construct(self, stem_type, stem_w, block_type, ds, ws, ss, bms, gws, se_r, nc): | ||||
|         # Generate dummy bot muls and gs for models that do not use them | ||||
|         bms = bms if bms else [None for _d in ds] | ||||
|         gws = gws if gws else [None for _d in ds] | ||||
|         stage_params = list(zip(ds, ws, ss, bms, gws)) | ||||
|         stem_fun = get_stem_fun(stem_type) | ||||
|         self.stem = stem_fun(3, stem_w) | ||||
|         block_fun = get_block_fun(block_type) | ||||
|         prev_w = stem_w | ||||
|         for i, (d, w, s, bm, gw) in enumerate(stage_params): | ||||
|             name = "s{}".format(i + 1) | ||||
|             self.add_module(name, AnyStage(prev_w, w, s, d, block_fun, bm, gw, se_r)) | ||||
|             prev_w = w | ||||
|         self.head = AnyHead(w_in=prev_w, nc=nc) | ||||
|  | ||||
|     def forward(self, x, get_ints=False): | ||||
|         for module in self.children(): | ||||
|             x = module(x) | ||||
|         return x | ||||
|  | ||||
|     @staticmethod | ||||
|     def complexity(cx, **kwargs): | ||||
|         """Computes model complexity. If you alter the model, make sure to update.""" | ||||
|         kwargs = AnyNet.get_args() if not kwargs else kwargs | ||||
|         return AnyNet._complexity(cx, **kwargs) | ||||
|  | ||||
|     @staticmethod | ||||
|     def _complexity(cx, stem_type, stem_w, block_type, ds, ws, ss, bms, gws, se_r, nc): | ||||
|         bms = bms if bms else [None for _d in ds] | ||||
|         gws = gws if gws else [None for _d in ds] | ||||
|         stage_params = list(zip(ds, ws, ss, bms, gws)) | ||||
|         stem_fun = get_stem_fun(stem_type) | ||||
|         cx = stem_fun.complexity(cx, 3, stem_w) | ||||
|         block_fun = get_block_fun(block_type) | ||||
|         prev_w = stem_w | ||||
|         for d, w, s, bm, gw in stage_params: | ||||
|             cx = AnyStage.complexity(cx, prev_w, w, s, d, block_fun, bm, gw, se_r) | ||||
|             prev_w = w | ||||
|         cx = AnyHead.complexity(cx, prev_w, nc) | ||||
|         return cx | ||||
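
A usage note for this file: AnyNet reads its configuration from the global cfg via get_args() when constructed with no arguments, but _construct() can be driven directly through keyword arguments. A minimal sketch (not part of the commit; the depths/widths below are illustrative RegNet-style values, and any setting that satisfies the asserts above works):

```python
from pycls.models.anynet import AnyNet

net = AnyNet(
    stem_type="simple_stem_in",         # see get_stem_fun for options
    stem_w=32,
    block_type="res_bottleneck_block",  # see get_block_fun for options
    ds=[1, 2, 4, 7],                    # blocks per stage
    ws=[56, 152, 368, 896],             # output width of each stage
    ss=[2, 2, 2, 2],                    # stride of each stage's first block
    bms=[1.0, 1.0, 1.0, 1.0],           # bottleneck multipliers
    gws=[8, 8, 8, 8],                   # group widths
    se_r=0.25,                          # SE ratio, or None to disable SE
    nc=1000,                            # number of classes
)
```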

pycls/models/common.py | 108 lines (new file)
| @@ -0,0 +1,108 @@ | ||||
| #!/usr/bin/env python3 | ||||
|  | ||||
| # Copyright (c) Facebook, Inc. and its affiliates. | ||||
| # | ||||
| # This source code is licensed under the MIT license found in the | ||||
| # LICENSE file in the root directory of this source tree. | ||||
|  | ||||
| import torch | ||||
| import torch.nn as nn | ||||
|  | ||||
| from pycls.core.config import cfg | ||||
|  | ||||
|  | ||||
| def Preprocess(x): | ||||
|     if cfg.TASK == 'jig': | ||||
|         assert len(x.shape) == 5, 'Wrong tensor dimension for jigsaw' | ||||
|         assert x.shape[1] == cfg.JIGSAW_GRID ** 2, 'Wrong grid for jigsaw' | ||||
|         x = x.view([x.shape[0] * x.shape[1], x.shape[2], x.shape[3], x.shape[4]]) | ||||
|     return x | ||||
|  | ||||
|  | ||||
| class Classifier(nn.Module): | ||||
|     def __init__(self, channels, num_classes): | ||||
|         super(Classifier, self).__init__() | ||||
|         if cfg.TASK == 'jig': | ||||
|             self.jig_sq = cfg.JIGSAW_GRID ** 2 | ||||
|             self.pooling = nn.AdaptiveAvgPool2d(1) | ||||
|             self.classifier = nn.Linear(channels * self.jig_sq, num_classes) | ||||
|         elif cfg.TASK == 'col': | ||||
|             self.classifier = nn.Conv2d(channels, num_classes, kernel_size=1, stride=1) | ||||
|         elif cfg.TASK == 'seg': | ||||
|             self.classifier = ASPP(channels, cfg.MODEL.ASPP_CHANNELS, num_classes, cfg.MODEL.ASPP_RATES) | ||||
|         else: | ||||
|             self.pooling = nn.AdaptiveAvgPool2d(1) | ||||
|             self.classifier = nn.Linear(channels, num_classes) | ||||
|  | ||||
|     def forward(self, x, shape): | ||||
|         if cfg.TASK == 'jig': | ||||
|             x = self.pooling(x) | ||||
|             x = x.view([x.shape[0] // self.jig_sq, x.shape[1] * self.jig_sq, x.shape[2], x.shape[3]]) | ||||
|             x = self.classifier(x.view(x.size(0), -1)) | ||||
|         elif cfg.TASK in ['col', 'seg']: | ||||
|             x = self.classifier(x) | ||||
|             x = nn.Upsample(shape, mode='bilinear', align_corners=True)(x) | ||||
|         else: | ||||
|             x = self.pooling(x) | ||||
|             x = self.classifier(x.view(x.size(0), -1)) | ||||
|         return x | ||||
|  | ||||
|  | ||||
| class ASPP(nn.Module): | ||||
|     def __init__(self, in_channels, out_channels, num_classes, rates): | ||||
|         super(ASPP, self).__init__() | ||||
|         assert len(rates) in [1, 3] | ||||
|         self.rates = rates | ||||
|         self.global_pooling = nn.AdaptiveAvgPool2d(1) | ||||
|         self.aspp1 = nn.Sequential( | ||||
|             nn.Conv2d(in_channels, out_channels, 1, bias=False), | ||||
|             nn.BatchNorm2d(out_channels), | ||||
|             nn.ReLU(inplace=True) | ||||
|         ) | ||||
|         self.aspp2 = nn.Sequential( | ||||
|             nn.Conv2d(in_channels, out_channels, 3, dilation=rates[0], | ||||
|                 padding=rates[0], bias=False), | ||||
|             nn.BatchNorm2d(out_channels), | ||||
|             nn.ReLU(inplace=True) | ||||
|         ) | ||||
|         if len(self.rates) == 3: | ||||
|             self.aspp3 = nn.Sequential( | ||||
|                 nn.Conv2d(in_channels, out_channels, 3, dilation=rates[1], | ||||
|                     padding=rates[1], bias=False), | ||||
|                 nn.BatchNorm2d(out_channels), | ||||
|                 nn.ReLU(inplace=True) | ||||
|             ) | ||||
|             self.aspp4 = nn.Sequential( | ||||
|                 nn.Conv2d(in_channels, out_channels, 3, dilation=rates[2], | ||||
|                     padding=rates[2], bias=False), | ||||
|                 nn.BatchNorm2d(out_channels), | ||||
|                 nn.ReLU(inplace=True) | ||||
|             ) | ||||
|         self.aspp5 = nn.Sequential( | ||||
|             nn.Conv2d(in_channels, out_channels, 1, bias=False), | ||||
|             nn.BatchNorm2d(out_channels), | ||||
|             nn.ReLU(inplace=True) | ||||
|         ) | ||||
|         self.classifier = nn.Sequential( | ||||
|             nn.Conv2d(out_channels * (len(rates) + 2), out_channels, 1, | ||||
|                 bias=False), | ||||
|             nn.BatchNorm2d(out_channels), | ||||
|             nn.ReLU(inplace=True), | ||||
|             nn.Conv2d(out_channels, num_classes, 1) | ||||
|         ) | ||||
|  | ||||
|     def forward(self, x): | ||||
|         x1 = self.aspp1(x) | ||||
|         x2 = self.aspp2(x) | ||||
|         x5 = self.global_pooling(x) | ||||
|         x5 = self.aspp5(x5) | ||||
|         x5 = nn.Upsample((x.shape[2], x.shape[3]), mode='bilinear', | ||||
|                 align_corners=True)(x5) | ||||
|         if len(self.rates) == 3: | ||||
|             x3 = self.aspp3(x) | ||||
|             x4 = self.aspp4(x) | ||||
|             x = torch.cat((x1, x2, x3, x4, x5), 1) | ||||
|         else: | ||||
|             x = torch.cat((x1, x2, x5), 1) | ||||
|         x = self.classifier(x) | ||||
|         return x | ||||
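
The jigsaw branch is the only task here that reshapes tensors on both ends. A sketch of the round trip (hypothetical shapes; assumes cfg.TASK == 'jig' and cfg.JIGSAW_GRID == 3, i.e. nine patches per image):

```python
import torch
from pycls.models.common import Preprocess, Classifier

x = torch.randn(8, 9, 3, 64, 64)    # (batch, grid**2, C, H, W)
x = Preprocess(x)                   # -> (72, 3, 64, 64): patches fold into the batch
feats = torch.randn(72, 512, 2, 2)  # stand-in for backbone features
head = Classifier(channels=512, num_classes=1000)
out = head(feats, shape=None)       # regroups the 9 patch embeddings -> (8, 1000)
```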

pycls/models/effnet.py | 232 lines (new file)
| @@ -0,0 +1,232 @@ | ||||
| #!/usr/bin/env python3 | ||||
|  | ||||
| # Copyright (c) Facebook, Inc. and its affiliates. | ||||
| # | ||||
| # This source code is licensed under the MIT license found in the | ||||
| # LICENSE file in the root directory of this source tree. | ||||
|  | ||||
| """EfficientNet models.""" | ||||
|  | ||||
| import pycls.core.net as net | ||||
| import torch | ||||
| import torch.nn as nn | ||||
| from pycls.core.config import cfg | ||||
|  | ||||
|  | ||||
| class EffHead(nn.Module): | ||||
|     """EfficientNet head: 1x1, BN, Swish, AvgPool, Dropout, FC.""" | ||||
|  | ||||
|     def __init__(self, w_in, w_out, nc): | ||||
|         super(EffHead, self).__init__() | ||||
|         self.conv = nn.Conv2d(w_in, w_out, 1, stride=1, padding=0, bias=False) | ||||
|         self.conv_bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM) | ||||
|         self.conv_swish = Swish() | ||||
|         self.avg_pool = nn.AdaptiveAvgPool2d((1, 1)) | ||||
|         if cfg.EN.DROPOUT_RATIO > 0.0: | ||||
|             self.dropout = nn.Dropout(p=cfg.EN.DROPOUT_RATIO) | ||||
|         self.fc = nn.Linear(w_out, nc, bias=True) | ||||
|  | ||||
|     def forward(self, x): | ||||
|         x = self.conv_swish(self.conv_bn(self.conv(x))) | ||||
|         x = self.avg_pool(x) | ||||
|         x = x.view(x.size(0), -1) | ||||
|         x = self.dropout(x) if hasattr(self, "dropout") else x | ||||
|         x = self.fc(x) | ||||
|         return x | ||||
|  | ||||
|     @staticmethod | ||||
|     def complexity(cx, w_in, w_out, nc): | ||||
|         cx = net.complexity_conv2d(cx, w_in, w_out, 1, 1, 0) | ||||
|         cx = net.complexity_batchnorm2d(cx, w_out) | ||||
|         cx["h"], cx["w"] = 1, 1 | ||||
|         cx = net.complexity_conv2d(cx, w_out, nc, 1, 1, 0, bias=True) | ||||
|         return cx | ||||
|  | ||||
|  | ||||
| class Swish(nn.Module): | ||||
|     """Swish activation function: x * sigmoid(x).""" | ||||
|  | ||||
|     def __init__(self): | ||||
|         super(Swish, self).__init__() | ||||
|  | ||||
|     def forward(self, x): | ||||
|         return x * torch.sigmoid(x) | ||||
|  | ||||
|  | ||||
| class SE(nn.Module): | ||||
|     """Squeeze-and-Excitation (SE) block w/ Swish: AvgPool, FC, Swish, FC, Sigmoid.""" | ||||
|  | ||||
|     def __init__(self, w_in, w_se): | ||||
|         super(SE, self).__init__() | ||||
|         self.avg_pool = nn.AdaptiveAvgPool2d((1, 1)) | ||||
|         self.f_ex = nn.Sequential( | ||||
|             nn.Conv2d(w_in, w_se, 1, bias=True), | ||||
|             Swish(), | ||||
|             nn.Conv2d(w_se, w_in, 1, bias=True), | ||||
|             nn.Sigmoid(), | ||||
|         ) | ||||
|  | ||||
|     def forward(self, x): | ||||
|         return x * self.f_ex(self.avg_pool(x)) | ||||
|  | ||||
|     @staticmethod | ||||
|     def complexity(cx, w_in, w_se): | ||||
|         h, w = cx["h"], cx["w"] | ||||
|         cx["h"], cx["w"] = 1, 1 | ||||
|         cx = net.complexity_conv2d(cx, w_in, w_se, 1, 1, 0, bias=True) | ||||
|         cx = net.complexity_conv2d(cx, w_se, w_in, 1, 1, 0, bias=True) | ||||
|         cx["h"], cx["w"] = h, w | ||||
|         return cx | ||||
|  | ||||
|  | ||||
| class MBConv(nn.Module): | ||||
|     """Mobile inverted bottleneck block w/ SE (MBConv).""" | ||||
|  | ||||
|     def __init__(self, w_in, exp_r, kernel, stride, se_r, w_out): | ||||
|         # expansion (1x1), kxk depthwise, BN, Swish, SE, 1x1 linear, BN, skip connection | ||||
|         super(MBConv, self).__init__() | ||||
|         self.exp = None | ||||
|         w_exp = int(w_in * exp_r) | ||||
|         if w_exp != w_in: | ||||
|             self.exp = nn.Conv2d(w_in, w_exp, 1, stride=1, padding=0, bias=False) | ||||
|             self.exp_bn = nn.BatchNorm2d(w_exp, eps=cfg.BN.EPS, momentum=cfg.BN.MOM) | ||||
|             self.exp_swish = Swish() | ||||
|         dwise_args = {"groups": w_exp, "padding": (kernel - 1) // 2, "bias": False} | ||||
|         self.dwise = nn.Conv2d(w_exp, w_exp, kernel, stride=stride, **dwise_args) | ||||
|         self.dwise_bn = nn.BatchNorm2d(w_exp, eps=cfg.BN.EPS, momentum=cfg.BN.MOM) | ||||
|         self.dwise_swish = Swish() | ||||
|         self.se = SE(w_exp, int(w_in * se_r)) | ||||
|         self.lin_proj = nn.Conv2d(w_exp, w_out, 1, stride=1, padding=0, bias=False) | ||||
|         self.lin_proj_bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM) | ||||
|         # Skip connection if in and out shapes are the same (MN-V2 style) | ||||
|         self.has_skip = stride == 1 and w_in == w_out | ||||
|  | ||||
|     def forward(self, x): | ||||
|         f_x = x | ||||
|         if self.exp: | ||||
|             f_x = self.exp_swish(self.exp_bn(self.exp(f_x))) | ||||
|         f_x = self.dwise_swish(self.dwise_bn(self.dwise(f_x))) | ||||
|         f_x = self.se(f_x) | ||||
|         f_x = self.lin_proj_bn(self.lin_proj(f_x)) | ||||
|         if self.has_skip: | ||||
|             if self.training and cfg.EN.DC_RATIO > 0.0: | ||||
|                 f_x = net.drop_connect(f_x, cfg.EN.DC_RATIO) | ||||
|             f_x = x + f_x | ||||
|         return f_x | ||||
|  | ||||
|     @staticmethod | ||||
|     def complexity(cx, w_in, exp_r, kernel, stride, se_r, w_out): | ||||
|         w_exp = int(w_in * exp_r) | ||||
|         if w_exp != w_in: | ||||
|             cx = net.complexity_conv2d(cx, w_in, w_exp, 1, 1, 0) | ||||
|             cx = net.complexity_batchnorm2d(cx, w_exp) | ||||
|         padding = (kernel - 1) // 2 | ||||
|         cx = net.complexity_conv2d(cx, w_exp, w_exp, kernel, stride, padding, w_exp) | ||||
|         cx = net.complexity_batchnorm2d(cx, w_exp) | ||||
|         cx = SE.complexity(cx, w_exp, int(w_in * se_r)) | ||||
|         cx = net.complexity_conv2d(cx, w_exp, w_out, 1, 1, 0) | ||||
|         cx = net.complexity_batchnorm2d(cx, w_out) | ||||
|         return cx | ||||
|  | ||||
|  | ||||
| class EffStage(nn.Module): | ||||
|     """EfficientNet stage.""" | ||||
|  | ||||
|     def __init__(self, w_in, exp_r, kernel, stride, se_r, w_out, d): | ||||
|         super(EffStage, self).__init__() | ||||
|         for i in range(d): | ||||
|             b_stride = stride if i == 0 else 1 | ||||
|             b_w_in = w_in if i == 0 else w_out | ||||
|             name = "b{}".format(i + 1) | ||||
|             self.add_module(name, MBConv(b_w_in, exp_r, kernel, b_stride, se_r, w_out)) | ||||
|  | ||||
|     def forward(self, x): | ||||
|         for block in self.children(): | ||||
|             x = block(x) | ||||
|         return x | ||||
|  | ||||
|     @staticmethod | ||||
|     def complexity(cx, w_in, exp_r, kernel, stride, se_r, w_out, d): | ||||
|         for i in range(d): | ||||
|             b_stride = stride if i == 0 else 1 | ||||
|             b_w_in = w_in if i == 0 else w_out | ||||
|             cx = MBConv.complexity(cx, b_w_in, exp_r, kernel, b_stride, se_r, w_out) | ||||
|         return cx | ||||
|  | ||||
|  | ||||
| class StemIN(nn.Module): | ||||
|     """EfficientNet stem for ImageNet: 3x3, BN, Swish.""" | ||||
|  | ||||
|     def __init__(self, w_in, w_out): | ||||
|         super(StemIN, self).__init__() | ||||
|         self.conv = nn.Conv2d(w_in, w_out, 3, stride=2, padding=1, bias=False) | ||||
|         self.bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM) | ||||
|         self.swish = Swish() | ||||
|  | ||||
|     def forward(self, x): | ||||
|         for layer in self.children(): | ||||
|             x = layer(x) | ||||
|         return x | ||||
|  | ||||
|     @staticmethod | ||||
|     def complexity(cx, w_in, w_out): | ||||
|         cx = net.complexity_conv2d(cx, w_in, w_out, 3, 2, 1) | ||||
|         cx = net.complexity_batchnorm2d(cx, w_out) | ||||
|         return cx | ||||
|  | ||||
|  | ||||
| class EffNet(nn.Module): | ||||
|     """EfficientNet model.""" | ||||
|  | ||||
|     @staticmethod | ||||
|     def get_args(): | ||||
|         return { | ||||
|             "stem_w": cfg.EN.STEM_W, | ||||
|             "ds": cfg.EN.DEPTHS, | ||||
|             "ws": cfg.EN.WIDTHS, | ||||
|             "exp_rs": cfg.EN.EXP_RATIOS, | ||||
|             "se_r": cfg.EN.SE_R, | ||||
|             "ss": cfg.EN.STRIDES, | ||||
|             "ks": cfg.EN.KERNELS, | ||||
|             "head_w": cfg.EN.HEAD_W, | ||||
|             "nc": cfg.MODEL.NUM_CLASSES, | ||||
|         } | ||||
|  | ||||
|     def __init__(self): | ||||
|         err_str = "Dataset {} is not supported" | ||||
|         assert cfg.TRAIN.DATASET in ["imagenet"], err_str.format(cfg.TRAIN.DATASET) | ||||
|         assert cfg.TEST.DATASET in ["imagenet"], err_str.format(cfg.TEST.DATASET) | ||||
|         super(EffNet, self).__init__() | ||||
|         self._construct(**EffNet.get_args()) | ||||
|         self.apply(net.init_weights) | ||||
|  | ||||
|     def _construct(self, stem_w, ds, ws, exp_rs, se_r, ss, ks, head_w, nc): | ||||
|         stage_params = list(zip(ds, ws, exp_rs, ss, ks)) | ||||
|         self.stem = StemIN(3, stem_w) | ||||
|         prev_w = stem_w | ||||
|         for i, (d, w, exp_r, stride, kernel) in enumerate(stage_params): | ||||
|             name = "s{}".format(i + 1) | ||||
|             self.add_module(name, EffStage(prev_w, exp_r, kernel, stride, se_r, w, d)) | ||||
|             prev_w = w | ||||
|         self.head = EffHead(prev_w, head_w, nc) | ||||
|  | ||||
|     def forward(self, x): | ||||
|         for module in self.children(): | ||||
|             x = module(x) | ||||
|         return x | ||||
|  | ||||
|     @staticmethod | ||||
|     def complexity(cx): | ||||
|         """Computes model complexity. If you alter the model, make sure to update.""" | ||||
|         return EffNet._complexity(cx, **EffNet.get_args()) | ||||
|  | ||||
|     @staticmethod | ||||
|     def _complexity(cx, stem_w, ds, ws, exp_rs, se_r, ss, ks, head_w, nc): | ||||
|         stage_params = list(zip(ds, ws, exp_rs, ss, ks)) | ||||
|         cx = StemIN.complexity(cx, 3, stem_w) | ||||
|         prev_w = stem_w | ||||
|         for d, w, exp_r, stride, kernel in stage_params: | ||||
|             cx = EffStage.complexity(cx, prev_w, exp_r, kernel, stride, se_r, w, d) | ||||
|             prev_w = w | ||||
|         cx = EffHead.complexity(cx, prev_w, head_w, nc) | ||||
|         return cx | ||||
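
The complexity() entry points thread a mutable dict cx through every module, mirroring the forward structure; the code above only shows it tracking spatial size via cx["h"] and cx["w"]. A hedged sketch of seeding it (the counter keys flops/params/acts live in pycls.core.net, which is not part of this diff, so they are an assumption here):

```python
from pycls.models.effnet import EffNet

cx = {"h": 224, "w": 224, "flops": 0, "params": 0, "acts": 0}  # assumed keys
cx = EffNet.complexity(cx)  # walks stem -> stages -> head, as in _construct
print(cx["flops"], cx["params"], cx["acts"])
```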

pycls/models/nas/genotypes.py | 634 lines (new file)
| @@ -0,0 +1,634 @@ | ||||
| #!/usr/bin/env python3 | ||||
|  | ||||
| # Copyright (c) Facebook, Inc. and its affiliates. | ||||
| # | ||||
| # This source code is licensed under the MIT license found in the | ||||
| # LICENSE file in the root directory of this source tree. | ||||
|  | ||||
| """NAS genotypes (adopted from DARTS).""" | ||||
|  | ||||
| from collections import namedtuple | ||||
|  | ||||
|  | ||||
| Genotype = namedtuple('Genotype', 'normal normal_concat reduce reduce_concat') | ||||
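| # A Genotype lists, for the normal and reduction cells, (op_name, input_index) | ||||
| # pairs (two per intermediate node; see Cell._compile in nas.py) plus the | ||||
| # indices of the intermediate states concatenated to form the cell output. | ||||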
|  | ||||
|  | ||||
| # NASNet ops | ||||
| NASNET_OPS = [ | ||||
|     'skip_connect', | ||||
|     'conv_3x1_1x3', | ||||
|     'conv_7x1_1x7', | ||||
|     'dil_conv_3x3', | ||||
|     'avg_pool_3x3', | ||||
|     'max_pool_3x3', | ||||
|     'max_pool_5x5', | ||||
|     'max_pool_7x7', | ||||
|     'conv_1x1', | ||||
|     'conv_3x3', | ||||
|     'sep_conv_3x3', | ||||
|     'sep_conv_5x5', | ||||
|     'sep_conv_7x7', | ||||
| ] | ||||
|  | ||||
| # ENAS ops | ||||
| ENAS_OPS = [ | ||||
|     'skip_connect', | ||||
|     'sep_conv_3x3', | ||||
|     'sep_conv_5x5', | ||||
|     'avg_pool_3x3', | ||||
|     'max_pool_3x3', | ||||
| ] | ||||
|  | ||||
| # AmoebaNet ops | ||||
| AMOEBA_OPS = [ | ||||
|     'skip_connect', | ||||
|     'sep_conv_3x3', | ||||
|     'sep_conv_5x5', | ||||
|     'sep_conv_7x7', | ||||
|     'avg_pool_3x3', | ||||
|     'max_pool_3x3', | ||||
|     'dil_sep_conv_3x3', | ||||
|     'conv_7x1_1x7', | ||||
| ] | ||||
|  | ||||
| # NAO ops | ||||
| NAO_OPS = [ | ||||
|     'skip_connect', | ||||
|     'conv_1x1', | ||||
|     'conv_3x3', | ||||
|     'conv_3x1_1x3', | ||||
|     'conv_7x1_1x7', | ||||
|     'max_pool_2x2', | ||||
|     'max_pool_3x3', | ||||
|     'max_pool_5x5', | ||||
|     'avg_pool_2x2', | ||||
|     'avg_pool_3x3', | ||||
|     'avg_pool_5x5', | ||||
| ] | ||||
|  | ||||
| # PNAS ops | ||||
| PNAS_OPS = [ | ||||
|     'sep_conv_3x3', | ||||
|     'sep_conv_5x5', | ||||
|     'sep_conv_7x7', | ||||
|     'conv_7x1_1x7', | ||||
|     'skip_connect', | ||||
|     'avg_pool_3x3', | ||||
|     'max_pool_3x3', | ||||
|     'dil_conv_3x3', | ||||
| ] | ||||
|  | ||||
| # DARTS ops | ||||
| DARTS_OPS = [ | ||||
|     'none', | ||||
|     'max_pool_3x3', | ||||
|     'avg_pool_3x3', | ||||
|     'skip_connect', | ||||
|     'sep_conv_3x3', | ||||
|     'sep_conv_5x5', | ||||
|     'dil_conv_3x3', | ||||
|     'dil_conv_5x5', | ||||
| ] | ||||
|  | ||||
|  | ||||
| NASNet = Genotype( | ||||
|     normal=[ | ||||
|         ('sep_conv_5x5', 1), | ||||
|         ('sep_conv_3x3', 0), | ||||
|         ('sep_conv_5x5', 0), | ||||
|         ('sep_conv_3x3', 0), | ||||
|         ('avg_pool_3x3', 1), | ||||
|         ('skip_connect', 0), | ||||
|         ('avg_pool_3x3', 0), | ||||
|         ('avg_pool_3x3', 0), | ||||
|         ('sep_conv_3x3', 1), | ||||
|         ('skip_connect', 1), | ||||
|     ], | ||||
|     normal_concat=[2, 3, 4, 5, 6], | ||||
|     reduce=[ | ||||
|         ('sep_conv_5x5', 1), | ||||
|         ('sep_conv_7x7', 0), | ||||
|         ('max_pool_3x3', 1), | ||||
|         ('sep_conv_7x7', 0), | ||||
|         ('avg_pool_3x3', 1), | ||||
|         ('sep_conv_5x5', 0), | ||||
|         ('skip_connect', 3), | ||||
|         ('avg_pool_3x3', 2), | ||||
|         ('sep_conv_3x3', 2), | ||||
|         ('max_pool_3x3', 1), | ||||
|     ], | ||||
|     reduce_concat=[4, 5, 6], | ||||
| ) | ||||
|  | ||||
|  | ||||
| PNASNet = Genotype( | ||||
|     normal=[ | ||||
|         ('sep_conv_5x5', 0), | ||||
|         ('max_pool_3x3', 0), | ||||
|         ('sep_conv_7x7', 1), | ||||
|         ('max_pool_3x3', 1), | ||||
|         ('sep_conv_5x5', 1), | ||||
|         ('sep_conv_3x3', 1), | ||||
|         ('sep_conv_3x3', 4), | ||||
|         ('max_pool_3x3', 1), | ||||
|         ('sep_conv_3x3', 0), | ||||
|         ('skip_connect', 1), | ||||
|     ], | ||||
|     normal_concat=[2, 3, 4, 5, 6], | ||||
|     reduce=[ | ||||
|         ('sep_conv_5x5', 0), | ||||
|         ('max_pool_3x3', 0), | ||||
|         ('sep_conv_7x7', 1), | ||||
|         ('max_pool_3x3', 1), | ||||
|         ('sep_conv_5x5', 1), | ||||
|         ('sep_conv_3x3', 1), | ||||
|         ('sep_conv_3x3', 4), | ||||
|         ('max_pool_3x3', 1), | ||||
|         ('sep_conv_3x3', 0), | ||||
|         ('skip_connect', 1), | ||||
|     ], | ||||
|     reduce_concat=[2, 3, 4, 5, 6], | ||||
| ) | ||||
|  | ||||
|  | ||||
| AmoebaNet = Genotype( | ||||
|     normal=[ | ||||
|         ('avg_pool_3x3', 0), | ||||
|         ('max_pool_3x3', 1), | ||||
|         ('sep_conv_3x3', 0), | ||||
|         ('sep_conv_5x5', 2), | ||||
|         ('sep_conv_3x3', 0), | ||||
|         ('avg_pool_3x3', 3), | ||||
|         ('sep_conv_3x3', 1), | ||||
|         ('skip_connect', 1), | ||||
|         ('skip_connect', 0), | ||||
|         ('avg_pool_3x3', 1), | ||||
|     ], | ||||
|     normal_concat=[4, 5, 6], | ||||
|     reduce=[ | ||||
|         ('avg_pool_3x3', 0), | ||||
|         ('sep_conv_3x3', 1), | ||||
|         ('max_pool_3x3', 0), | ||||
|         ('sep_conv_7x7', 2), | ||||
|         ('sep_conv_7x7', 0), | ||||
|         ('avg_pool_3x3', 1), | ||||
|         ('max_pool_3x3', 0), | ||||
|         ('max_pool_3x3', 1), | ||||
|         ('conv_7x1_1x7', 0), | ||||
|         ('sep_conv_3x3', 5), | ||||
|     ], | ||||
|     reduce_concat=[3, 4, 6] | ||||
| ) | ||||
|  | ||||
|  | ||||
| DARTS_V1 = Genotype( | ||||
|     normal=[ | ||||
|         ('sep_conv_3x3', 1), | ||||
|         ('sep_conv_3x3', 0), | ||||
|         ('skip_connect', 0), | ||||
|         ('sep_conv_3x3', 1), | ||||
|         ('skip_connect', 0), | ||||
|         ('sep_conv_3x3', 1), | ||||
|         ('sep_conv_3x3', 0), | ||||
|         ('skip_connect', 2) | ||||
|     ], | ||||
|     normal_concat=[2, 3, 4, 5], | ||||
|     reduce=[ | ||||
|         ('max_pool_3x3', 0), | ||||
|         ('max_pool_3x3', 1), | ||||
|         ('skip_connect', 2), | ||||
|         ('max_pool_3x3', 0), | ||||
|         ('max_pool_3x3', 0), | ||||
|         ('skip_connect', 2), | ||||
|         ('skip_connect', 2), | ||||
|         ('avg_pool_3x3', 0) | ||||
|     ], | ||||
|     reduce_concat=[2, 3, 4, 5] | ||||
| ) | ||||
|  | ||||
|  | ||||
| DARTS_V2 = Genotype( | ||||
|     normal=[ | ||||
|         ('sep_conv_3x3', 0), | ||||
|         ('sep_conv_3x3', 1), | ||||
|         ('sep_conv_3x3', 0), | ||||
|         ('sep_conv_3x3', 1), | ||||
|         ('sep_conv_3x3', 1), | ||||
|         ('skip_connect', 0), | ||||
|         ('skip_connect', 0), | ||||
|         ('dil_conv_3x3', 2) | ||||
|     ], | ||||
|     normal_concat=[2, 3, 4, 5], | ||||
|     reduce=[ | ||||
|         ('max_pool_3x3', 0), | ||||
|         ('max_pool_3x3', 1), | ||||
|         ('skip_connect', 2), | ||||
|         ('max_pool_3x3', 1), | ||||
|         ('max_pool_3x3', 0), | ||||
|         ('skip_connect', 2), | ||||
|         ('skip_connect', 2), | ||||
|         ('max_pool_3x3', 1) | ||||
|     ], | ||||
|     reduce_concat=[2, 3, 4, 5] | ||||
| ) | ||||
|  | ||||
| PDARTS = Genotype( | ||||
|     normal=[ | ||||
|         ('skip_connect', 0), | ||||
|         ('dil_conv_3x3', 1), | ||||
|         ('skip_connect', 0), | ||||
|         ('sep_conv_3x3', 1), | ||||
|         ('sep_conv_3x3', 1), | ||||
|         ('sep_conv_3x3', 3), | ||||
|         ('sep_conv_3x3', 0), | ||||
|         ('dil_conv_5x5', 4) | ||||
|     ], | ||||
|     normal_concat=range(2, 6), | ||||
|     reduce=[ | ||||
|         ('avg_pool_3x3', 0), | ||||
|         ('sep_conv_5x5', 1), | ||||
|         ('sep_conv_3x3', 0), | ||||
|         ('dil_conv_5x5', 2), | ||||
|         ('max_pool_3x3', 0), | ||||
|         ('dil_conv_3x3', 1), | ||||
|         ('dil_conv_3x3', 1), | ||||
|         ('dil_conv_5x5', 3) | ||||
|     ], | ||||
|     reduce_concat=range(2, 6) | ||||
| ) | ||||
|  | ||||
| PCDARTS_C10 = Genotype( | ||||
|     normal=[ | ||||
|         ('sep_conv_3x3', 1), | ||||
|         ('skip_connect', 0), | ||||
|         ('sep_conv_3x3', 0), | ||||
|         ('dil_conv_3x3', 1), | ||||
|         ('sep_conv_5x5', 0), | ||||
|         ('sep_conv_3x3', 1), | ||||
|         ('avg_pool_3x3', 0), | ||||
|         ('dil_conv_3x3', 1) | ||||
|     ], | ||||
|     normal_concat=range(2, 6), | ||||
|     reduce=[ | ||||
|         ('sep_conv_5x5', 1), | ||||
|         ('max_pool_3x3', 0), | ||||
|         ('sep_conv_5x5', 1), | ||||
|         ('sep_conv_5x5', 2), | ||||
|         ('sep_conv_3x3', 0), | ||||
|         ('sep_conv_3x3', 3), | ||||
|         ('sep_conv_3x3', 1), | ||||
|         ('sep_conv_3x3', 2) | ||||
|     ], | ||||
|     reduce_concat=range(2, 6) | ||||
| ) | ||||
|  | ||||
| PCDARTS_IN1K = Genotype( | ||||
|     normal=[ | ||||
|         ('skip_connect', 1), | ||||
|         ('sep_conv_3x3', 0), | ||||
|         ('sep_conv_3x3', 0), | ||||
|         ('skip_connect', 1), | ||||
|         ('sep_conv_3x3', 1), | ||||
|         ('sep_conv_3x3', 3), | ||||
|         ('sep_conv_3x3', 1), | ||||
|         ('dil_conv_5x5', 4) | ||||
|     ], | ||||
|     normal_concat=range(2, 6), | ||||
|     reduce=[ | ||||
|         ('sep_conv_3x3', 0), | ||||
|         ('skip_connect', 1), | ||||
|         ('dil_conv_5x5', 2), | ||||
|         ('max_pool_3x3', 1), | ||||
|         ('sep_conv_3x3', 2), | ||||
|         ('sep_conv_3x3', 1), | ||||
|         ('sep_conv_5x5', 0), | ||||
|         ('sep_conv_3x3', 3) | ||||
|     ], | ||||
|     reduce_concat=range(2, 6) | ||||
| ) | ||||
|  | ||||
| UNNAS_IMAGENET_CLS = Genotype( | ||||
|     normal=[ | ||||
|         ('sep_conv_3x3', 1), | ||||
|         ('sep_conv_3x3', 0), | ||||
|         ('sep_conv_3x3', 1), | ||||
|         ('sep_conv_3x3', 0), | ||||
|         ('sep_conv_3x3', 0), | ||||
|         ('sep_conv_3x3', 2), | ||||
|         ('sep_conv_5x5', 1), | ||||
|         ('sep_conv_3x3', 0) | ||||
|     ], | ||||
|     normal_concat=range(2, 6), | ||||
|     reduce=[ | ||||
|         ('max_pool_3x3', 0), | ||||
|         ('skip_connect', 1), | ||||
|         ('max_pool_3x3', 0), | ||||
|         ('dil_conv_5x5', 2), | ||||
|         ('max_pool_3x3', 0), | ||||
|         ('sep_conv_3x3', 2), | ||||
|         ('sep_conv_3x3', 4), | ||||
|         ('dil_conv_5x5', 3) | ||||
|     ], | ||||
|     reduce_concat=range(2, 6) | ||||
| ) | ||||
|  | ||||
| UNNAS_IMAGENET_ROT = Genotype( | ||||
|     normal=[ | ||||
|         ('sep_conv_3x3', 0), | ||||
|         ('sep_conv_3x3', 1), | ||||
|         ('sep_conv_3x3', 0), | ||||
|         ('sep_conv_3x3', 1), | ||||
|         ('sep_conv_3x3', 1), | ||||
|         ('sep_conv_3x3', 0), | ||||
|         ('sep_conv_3x3', 0), | ||||
|         ('sep_conv_3x3', 1) | ||||
|     ], | ||||
|     normal_concat=range(2, 6), | ||||
|     reduce=[ | ||||
|         ('sep_conv_3x3', 1), | ||||
|         ('sep_conv_3x3', 0), | ||||
|         ('sep_conv_3x3', 2), | ||||
|         ('sep_conv_3x3', 1), | ||||
|         ('sep_conv_3x3', 1), | ||||
|         ('sep_conv_3x3', 2), | ||||
|         ('sep_conv_3x3', 4), | ||||
|         ('sep_conv_5x5', 2) | ||||
|     ], | ||||
|     reduce_concat=range(2, 6) | ||||
| ) | ||||
|  | ||||
| UNNAS_IMAGENET_COL = Genotype( | ||||
|     normal=[ | ||||
|         ('skip_connect', 0), | ||||
|         ('sep_conv_3x3', 1), | ||||
|         ('sep_conv_3x3', 1), | ||||
|         ('skip_connect', 0), | ||||
|         ('sep_conv_3x3', 0), | ||||
|         ('sep_conv_3x3', 3), | ||||
|         ('sep_conv_3x3', 0), | ||||
|         ('sep_conv_3x3', 2) | ||||
|     ], | ||||
|     normal_concat=range(2, 6), | ||||
|     reduce=[ | ||||
|         ('max_pool_3x3', 0), | ||||
|         ('sep_conv_3x3', 1), | ||||
|         ('max_pool_3x3', 0), | ||||
|         ('sep_conv_3x3', 1), | ||||
|         ('max_pool_3x3', 0), | ||||
|         ('sep_conv_5x5', 3), | ||||
|         ('max_pool_3x3', 0), | ||||
|         ('sep_conv_3x3', 4) | ||||
|     ], | ||||
|     reduce_concat=range(2, 6) | ||||
| ) | ||||
|  | ||||
| UNNAS_IMAGENET_JIG = Genotype( | ||||
|     normal=[ | ||||
|         ('sep_conv_3x3', 0), | ||||
|         ('sep_conv_3x3', 1), | ||||
|         ('sep_conv_3x3', 0), | ||||
|         ('sep_conv_3x3', 1), | ||||
|         ('sep_conv_3x3', 0), | ||||
|         ('sep_conv_3x3', 3), | ||||
|         ('sep_conv_3x3', 1), | ||||
|         ('sep_conv_5x5', 0) | ||||
|     ], | ||||
|     normal_concat=range(2, 6), | ||||
|     reduce=[ | ||||
|         ('sep_conv_5x5', 0), | ||||
|         ('sep_conv_3x3', 1), | ||||
|         ('sep_conv_3x3', 0), | ||||
|         ('sep_conv_3x3', 1), | ||||
|         ('sep_conv_3x3', 0), | ||||
|         ('sep_conv_3x3', 1), | ||||
|         ('sep_conv_5x5', 0), | ||||
|         ('sep_conv_3x3', 1) | ||||
|     ], | ||||
|     reduce_concat=range(2, 6) | ||||
| ) | ||||
|  | ||||
| UNNAS_IMAGENET22K_CLS = Genotype( | ||||
|     normal=[ | ||||
|         ('sep_conv_3x3', 1), | ||||
|         ('skip_connect', 0), | ||||
|         ('sep_conv_3x3', 1), | ||||
|         ('sep_conv_3x3', 2), | ||||
|         ('sep_conv_3x3', 1), | ||||
|         ('sep_conv_3x3', 2), | ||||
|         ('sep_conv_3x3', 1), | ||||
|         ('sep_conv_3x3', 0) | ||||
|     ], | ||||
|     normal_concat=range(2, 6), | ||||
|     reduce=[ | ||||
|         ('max_pool_3x3', 0), | ||||
|         ('max_pool_3x3', 1), | ||||
|         ('dil_conv_5x5', 2), | ||||
|         ('max_pool_3x3', 0), | ||||
|         ('dil_conv_5x5', 3), | ||||
|         ('dil_conv_5x5', 2), | ||||
|         ('dil_conv_5x5', 4), | ||||
|         ('dil_conv_5x5', 3) | ||||
|     ], | ||||
|     reduce_concat=range(2, 6) | ||||
| ) | ||||
|  | ||||
| UNNAS_IMAGENET22K_ROT = Genotype( | ||||
|     normal=[ | ||||
|         ('sep_conv_3x3', 0), | ||||
|         ('sep_conv_3x3', 1), | ||||
|         ('sep_conv_3x3', 0), | ||||
|         ('sep_conv_3x3', 1), | ||||
|         ('sep_conv_3x3', 0), | ||||
|         ('sep_conv_3x3', 1), | ||||
|         ('sep_conv_3x3', 0), | ||||
|         ('sep_conv_3x3', 1) | ||||
|     ], | ||||
|     normal_concat=range(2, 6), | ||||
|     reduce=[ | ||||
|         ('max_pool_3x3', 0), | ||||
|         ('sep_conv_5x5', 1), | ||||
|         ('dil_conv_5x5', 2), | ||||
|         ('sep_conv_5x5', 0), | ||||
|         ('dil_conv_5x5', 3), | ||||
|         ('sep_conv_3x3', 2), | ||||
|         ('sep_conv_3x3', 4), | ||||
|         ('sep_conv_3x3', 3) | ||||
|     ], | ||||
|     reduce_concat=range(2, 6) | ||||
| ) | ||||
|  | ||||
| UNNAS_IMAGENET22K_COL = Genotype( | ||||
|     normal=[ | ||||
|         ('sep_conv_3x3', 0), | ||||
|         ('sep_conv_3x3', 1), | ||||
|         ('sep_conv_3x3', 0), | ||||
|         ('sep_conv_3x3', 1), | ||||
|         ('sep_conv_3x3', 2), | ||||
|         ('sep_conv_3x3', 1), | ||||
|         ('sep_conv_3x3', 3), | ||||
|         ('sep_conv_3x3', 0) | ||||
|     ], | ||||
|     normal_concat=range(2, 6), | ||||
|     reduce=[ | ||||
|         ('max_pool_3x3', 0), | ||||
|         ('skip_connect', 1), | ||||
|         ('dil_conv_5x5', 2), | ||||
|         ('sep_conv_3x3', 0), | ||||
|         ('sep_conv_3x3', 3), | ||||
|         ('sep_conv_3x3', 0), | ||||
|         ('sep_conv_3x3', 4), | ||||
|         ('sep_conv_5x5', 1) | ||||
|     ], | ||||
|     reduce_concat=range(2, 6) | ||||
| ) | ||||
|  | ||||
| UNNAS_IMAGENET22K_JIG = Genotype( | ||||
|     normal=[ | ||||
|         ('sep_conv_3x3', 0), | ||||
|         ('sep_conv_3x3', 1), | ||||
|         ('sep_conv_3x3', 0), | ||||
|         ('sep_conv_3x3', 1), | ||||
|         ('sep_conv_3x3', 1), | ||||
|         ('sep_conv_3x3', 0), | ||||
|         ('sep_conv_3x3', 0), | ||||
|         ('sep_conv_3x3', 4) | ||||
|     ], | ||||
|     normal_concat=range(2, 6), | ||||
|     reduce=[ | ||||
|         ('sep_conv_5x5', 0), | ||||
|         ('skip_connect', 1), | ||||
|         ('sep_conv_5x5', 0), | ||||
|         ('sep_conv_3x3', 2), | ||||
|         ('sep_conv_5x5', 0), | ||||
|         ('sep_conv_5x5', 3), | ||||
|         ('sep_conv_5x5', 0), | ||||
|         ('sep_conv_5x5', 4) | ||||
|     ], | ||||
|     reduce_concat=range(2, 6) | ||||
| ) | ||||
|  | ||||
| UNNAS_CITYSCAPES_SEG = Genotype( | ||||
|     normal=[ | ||||
|         ('skip_connect', 0), | ||||
|         ('sep_conv_5x5', 1), | ||||
|         ('sep_conv_3x3', 1), | ||||
|         ('sep_conv_3x3', 0), | ||||
|         ('sep_conv_3x3', 0), | ||||
|         ('sep_conv_3x3', 1), | ||||
|         ('sep_conv_3x3', 0), | ||||
|         ('sep_conv_3x3', 1) | ||||
|     ], | ||||
|     normal_concat=range(2, 6), | ||||
|     reduce=[ | ||||
|         ('sep_conv_3x3', 0), | ||||
|         ('avg_pool_3x3', 1), | ||||
|         ('avg_pool_3x3', 1), | ||||
|         ('sep_conv_5x5', 0), | ||||
|         ('sep_conv_3x3', 2), | ||||
|         ('sep_conv_5x5', 0), | ||||
|         ('sep_conv_3x3', 4), | ||||
|         ('sep_conv_5x5', 2) | ||||
|     ], | ||||
|     reduce_concat=range(2, 6) | ||||
| ) | ||||
|  | ||||
| UNNAS_CITYSCAPES_ROT = Genotype( | ||||
|     normal=[ | ||||
|         ('sep_conv_3x3', 1), | ||||
|         ('sep_conv_3x3', 0), | ||||
|         ('sep_conv_3x3', 2), | ||||
|         ('sep_conv_3x3', 1), | ||||
|         ('sep_conv_3x3', 0), | ||||
|         ('sep_conv_3x3', 3), | ||||
|         ('sep_conv_3x3', 1), | ||||
|         ('sep_conv_3x3', 0) | ||||
|     ], | ||||
|     normal_concat=range(2, 6), | ||||
|     reduce=[ | ||||
|         ('max_pool_3x3', 0), | ||||
|         ('sep_conv_5x5', 1), | ||||
|         ('sep_conv_5x5', 2), | ||||
|         ('sep_conv_5x5', 1), | ||||
|         ('sep_conv_5x5', 3), | ||||
|         ('dil_conv_5x5', 2), | ||||
|         ('sep_conv_5x5', 2), | ||||
|         ('sep_conv_5x5', 0) | ||||
|     ], | ||||
|     reduce_concat=range(2, 6) | ||||
| ) | ||||
|  | ||||
| UNNAS_CITYSCAPES_COL = Genotype( | ||||
|     normal=[ | ||||
|         ('dil_conv_3x3', 1), | ||||
|         ('sep_conv_3x3', 0), | ||||
|         ('skip_connect', 0), | ||||
|         ('sep_conv_5x5', 2), | ||||
|         ('dil_conv_3x3', 3), | ||||
|         ('skip_connect', 0), | ||||
|         ('skip_connect', 0), | ||||
|         ('sep_conv_3x3', 1) | ||||
|     ], | ||||
|     normal_concat=range(2, 6), | ||||
|     reduce=[ | ||||
|         ('avg_pool_3x3', 1), | ||||
|         ('avg_pool_3x3', 0), | ||||
|         ('avg_pool_3x3', 1), | ||||
|         ('avg_pool_3x3', 0), | ||||
|         ('avg_pool_3x3', 1), | ||||
|         ('avg_pool_3x3', 0), | ||||
|         ('avg_pool_3x3', 1), | ||||
|         ('skip_connect', 4) | ||||
|     ], | ||||
|     reduce_concat=range(2, 6) | ||||
| ) | ||||
|  | ||||
| UNNAS_CITYSCAPES_JIG = Genotype( | ||||
|     normal=[ | ||||
|         ('dil_conv_5x5', 1), | ||||
|         ('sep_conv_5x5', 0), | ||||
|         ('sep_conv_3x3', 0), | ||||
|         ('sep_conv_3x3', 1), | ||||
|         ('sep_conv_3x3', 0), | ||||
|         ('sep_conv_3x3', 2), | ||||
|         ('sep_conv_3x3', 0), | ||||
|         ('dil_conv_5x5', 1) | ||||
|     ], | ||||
|     normal_concat=range(2, 6), | ||||
|     reduce=[ | ||||
|         ('avg_pool_3x3', 0), | ||||
|         ('skip_connect', 1), | ||||
|         ('dil_conv_5x5', 1), | ||||
|         ('dil_conv_5x5', 2), | ||||
|         ('dil_conv_5x5', 2), | ||||
|         ('dil_conv_5x5', 0), | ||||
|         ('dil_conv_5x5', 3), | ||||
|         ('dil_conv_5x5', 2) | ||||
|     ], | ||||
|     reduce_concat=range(2, 6) | ||||
| ) | ||||
|  | ||||
|  | ||||
| # Supported genotypes | ||||
| GENOTYPES = { | ||||
|     'nas': NASNet, | ||||
|     'pnas': PNASNet, | ||||
|     'amoeba': AmoebaNet, | ||||
|     'darts_v1': DARTS_V1, | ||||
|     'darts_v2': DARTS_V2, | ||||
|     'pdarts': PDARTS, | ||||
|     'pcdarts_c10': PCDARTS_C10, | ||||
|     'pcdarts_in1k': PCDARTS_IN1K, | ||||
|     'unnas_imagenet_cls': UNNAS_IMAGENET_CLS, | ||||
|     'unnas_imagenet_rot': UNNAS_IMAGENET_ROT, | ||||
|     'unnas_imagenet_col': UNNAS_IMAGENET_COL, | ||||
|     'unnas_imagenet_jig': UNNAS_IMAGENET_JIG, | ||||
|     'unnas_imagenet22k_cls': UNNAS_IMAGENET22K_CLS, | ||||
|     'unnas_imagenet22k_rot': UNNAS_IMAGENET22K_ROT, | ||||
|     'unnas_imagenet22k_col': UNNAS_IMAGENET22K_COL, | ||||
|     'unnas_imagenet22k_jig': UNNAS_IMAGENET22K_JIG, | ||||
|     'unnas_cityscapes_seg': UNNAS_CITYSCAPES_SEG, | ||||
|     'unnas_cityscapes_rot': UNNAS_CITYSCAPES_ROT, | ||||
|     'unnas_cityscapes_col': UNNAS_CITYSCAPES_COL, | ||||
|     'unnas_cityscapes_jig': UNNAS_CITYSCAPES_JIG, | ||||
|     'custom': None, | ||||
| } | ||||
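
A quick lookup sketch; note that 'custom' maps to None, so callers selecting it presumably supply their own Genotype:

```python
from pycls.models.nas.genotypes import GENOTYPES, Genotype

g = GENOTYPES['darts_v2']
assert isinstance(g, Genotype)
for op, idx in g.normal[:2]:
    print(op, idx)            # e.g. sep_conv_3x3 0
print(list(g.normal_concat))  # states concatenated into the cell output
```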

pycls/models/nas/nas.py | 299 lines (new file)
| @@ -0,0 +1,299 @@ | ||||
| #!/usr/bin/env python3 | ||||
|  | ||||
| # Copyright (c) Facebook, Inc. and its affiliates. | ||||
| # | ||||
| # This source code is licensed under the MIT license found in the | ||||
| # LICENSE file in the root directory of this source tree. | ||||
|  | ||||
| """NAS network (adopted from DARTS).""" | ||||
|  | ||||
| import torch | ||||
| import torch.nn as nn | ||||
|  | ||||
| import pycls.core.logging as logging | ||||
|  | ||||
| from pycls.core.config import cfg | ||||
| from pycls.models.common import Preprocess | ||||
| from pycls.models.common import Classifier | ||||
| from pycls.models.nas.genotypes import GENOTYPES | ||||
| from pycls.models.nas.genotypes import Genotype | ||||
| from pycls.models.nas.operations import FactorizedReduce | ||||
| from pycls.models.nas.operations import OPS | ||||
| from pycls.models.nas.operations import ReLUConvBN | ||||
| from pycls.models.nas.operations import Identity | ||||
|  | ||||
|  | ||||
| logger = logging.get_logger(__name__) | ||||
|  | ||||
|  | ||||
| def drop_path(x, drop_prob): | ||||
|     """Drop path (ported from DARTS).""" | ||||
|     if drop_prob > 0.: | ||||
|         keep_prob = 1. - drop_prob | ||||
|         # Per-example binary mask, created on the same device as x | ||||
|         mask = torch.empty(x.size(0), 1, 1, 1, device=x.device).bernoulli_(keep_prob) | ||||
|         # Scale surviving paths by 1 / keep_prob to preserve the expectation | ||||
|         x.div_(keep_prob) | ||||
|         x.mul_(mask) | ||||
|     return x | ||||
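|  | ||||
|  | ||||
| # Behavior sketch (illustrative): with drop_prob=0.2 each example is zeroed | ||||
| # with probability 0.2 and survivors are scaled by 1/0.8, keeping the | ||||
| # expected activation unchanged (note the update is in place): | ||||
| # | ||||
| #   x = torch.ones(4, 3, 8, 8) | ||||
| #   y = drop_path(x, 0.2)  # each example is now all zeros or all 1.25 | ||||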
|  | ||||
|  | ||||
| class Cell(nn.Module): | ||||
|     """NAS cell (ported from DARTS).""" | ||||
|  | ||||
|     def __init__(self, genotype, C_prev_prev, C_prev, C, reduction, reduction_prev): | ||||
|         super(Cell, self).__init__() | ||||
|         logger.info('Cell channels (prev_prev, prev, curr): {}, {}, {}'.format(C_prev_prev, C_prev, C)) | ||||
|  | ||||
|         if reduction_prev: | ||||
|             self.preprocess0 = FactorizedReduce(C_prev_prev, C) | ||||
|         else: | ||||
|             self.preprocess0 = ReLUConvBN(C_prev_prev, C, 1, 1, 0) | ||||
|         self.preprocess1 = ReLUConvBN(C_prev, C, 1, 1, 0) | ||||
|  | ||||
|         if reduction: | ||||
|             op_names, indices = zip(*genotype.reduce) | ||||
|             concat = genotype.reduce_concat | ||||
|         else: | ||||
|             op_names, indices = zip(*genotype.normal) | ||||
|             concat = genotype.normal_concat | ||||
|         self._compile(C, op_names, indices, concat, reduction) | ||||
|  | ||||
|     def _compile(self, C, op_names, indices, concat, reduction): | ||||
|         assert len(op_names) == len(indices) | ||||
|         self._steps = len(op_names) // 2 | ||||
|         self._concat = concat | ||||
|         self.multiplier = len(concat) | ||||
|  | ||||
|         self._ops = nn.ModuleList() | ||||
|         for name, index in zip(op_names, indices): | ||||
|             stride = 2 if reduction and index < 2 else 1 | ||||
|             op = OPS[name](C, stride, True) | ||||
|             self._ops += [op] | ||||
|         self._indices = indices | ||||
|  | ||||
|     def forward(self, s0, s1, drop_prob): | ||||
|         s0 = self.preprocess0(s0) | ||||
|         s1 = self.preprocess1(s1) | ||||
|  | ||||
|         states = [s0, s1] | ||||
|         for i in range(self._steps): | ||||
|             h1 = states[self._indices[2*i]] | ||||
|             h2 = states[self._indices[2*i+1]] | ||||
|             op1 = self._ops[2*i] | ||||
|             op2 = self._ops[2*i+1] | ||||
|             h1 = op1(h1) | ||||
|             h2 = op2(h2) | ||||
|             if self.training and drop_prob > 0.: | ||||
|                 if not isinstance(op1, Identity): | ||||
|                     h1 = drop_path(h1, drop_prob) | ||||
|                 if not isinstance(op2, Identity): | ||||
|                     h2 = drop_path(h2, drop_prob) | ||||
|             s = h1 + h2 | ||||
|             states += [s] | ||||
|         return torch.cat([states[i] for i in self._concat], dim=1) | ||||
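|  | ||||
| # Bookkeeping sketch (illustrative): a genotype lists two ops per intermediate | ||||
| # node, so 8 ops give _steps = 4 nodes; concatenating the nodes in _concat | ||||
| # yields multiplier * C = 4 * C output channels, which is why the networks | ||||
| # below track C_prev = cell.multiplier * C_curr. | ||||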
|  | ||||
|  | ||||
| class AuxiliaryHeadCIFAR(nn.Module): | ||||
|     """Auxiliary head for CIFAR (ported from DARTS).""" | ||||
|  | ||||
|     def __init__(self, C, num_classes): | ||||
|         """Assumes an 8x8 input feature map.""" | ||||
|         super(AuxiliaryHeadCIFAR, self).__init__() | ||||
|         self.features = nn.Sequential( | ||||
|             nn.ReLU(inplace=True), | ||||
|             nn.AvgPool2d(5, stride=3, padding=0, count_include_pad=False), # image size = 2 x 2 | ||||
|             nn.Conv2d(C, 128, 1, bias=False), | ||||
|             nn.BatchNorm2d(128), | ||||
|             nn.ReLU(inplace=True), | ||||
|             nn.Conv2d(128, 768, 2, bias=False), | ||||
|             nn.BatchNorm2d(768), | ||||
|             nn.ReLU(inplace=True) | ||||
|         ) | ||||
|         self.classifier = nn.Linear(768, num_classes) | ||||
|  | ||||
|     def forward(self, x): | ||||
|         x = self.features(x) | ||||
|         x = self.classifier(x.view(x.size(0), -1)) | ||||
|         return x | ||||
|  | ||||
|  | ||||
| class AuxiliaryHeadImageNet(nn.Module): | ||||
|     """Auxiliary head for ImageNet (ported from DARTS).""" | ||||
|  | ||||
|     def __init__(self, C, num_classes): | ||||
|         """Assumes a 14x14 input feature map.""" | ||||
|         super(AuxiliaryHeadImageNet, self).__init__() | ||||
|         self.features = nn.Sequential( | ||||
|             nn.ReLU(inplace=True), | ||||
|             nn.AvgPool2d(5, stride=2, padding=0, count_include_pad=False), | ||||
|             nn.Conv2d(C, 128, 1, bias=False), | ||||
|             nn.BatchNorm2d(128), | ||||
|             nn.ReLU(inplace=True), | ||||
|             nn.Conv2d(128, 768, 2, bias=False), | ||||
|             # NOTE: This batchnorm was omitted in my earlier implementation due to a typo. | ||||
|             # Commenting it out for consistency with the experiments in the paper. | ||||
|             # nn.BatchNorm2d(768), | ||||
|             nn.ReLU(inplace=True) | ||||
|         ) | ||||
|         self.classifier = nn.Linear(768, num_classes) | ||||
|  | ||||
|     def forward(self, x): | ||||
|         x = self.features(x) | ||||
|         x = self.classifier(x.view(x.size(0), -1)) | ||||
|         return x | ||||
|  | ||||
|  | ||||
| class NetworkCIFAR(nn.Module): | ||||
|     """CIFAR network (ported from DARTS).""" | ||||
|  | ||||
|     def __init__(self, C, num_classes, layers, auxiliary, genotype): | ||||
|         super(NetworkCIFAR, self).__init__() | ||||
|         self._layers = layers | ||||
|         self._auxiliary = auxiliary | ||||
|  | ||||
|         stem_multiplier = 3 | ||||
|         C_curr = stem_multiplier*C | ||||
|         self.stem = nn.Sequential( | ||||
|             nn.Conv2d(cfg.MODEL.INPUT_CHANNELS, C_curr, 3, padding=1, bias=False), | ||||
|             nn.BatchNorm2d(C_curr) | ||||
|         ) | ||||
|  | ||||
|         C_prev_prev, C_prev, C_curr = C_curr, C_curr, C | ||||
|         self.cells = nn.ModuleList() | ||||
|         reduction_prev = False | ||||
|         for i in range(layers): | ||||
|             if i in [layers//3, 2*layers//3]: | ||||
|                 C_curr *= 2 | ||||
|                 reduction = True | ||||
|             else: | ||||
|                 reduction = False | ||||
|             cell = Cell(genotype, C_prev_prev, C_prev, C_curr, reduction, reduction_prev) | ||||
|             reduction_prev = reduction | ||||
|             self.cells += [cell] | ||||
|             C_prev_prev, C_prev = C_prev, cell.multiplier*C_curr | ||||
|             if i == 2*layers//3: | ||||
|                 C_to_auxiliary = C_prev | ||||
|  | ||||
|         if auxiliary: | ||||
|             self.auxiliary_head = AuxiliaryHeadCIFAR(C_to_auxiliary, num_classes) | ||||
|         self.classifier = Classifier(C_prev, num_classes) | ||||
|  | ||||
|     def forward(self, input): | ||||
|         input = Preprocess(input) | ||||
|         logits_aux = None | ||||
|         s0 = s1 = self.stem(input) | ||||
|         for i, cell in enumerate(self.cells): | ||||
|             s0, s1 = s1, cell(s0, s1, self.drop_path_prob) | ||||
|             if i == 2*self._layers//3: | ||||
|                 if self._auxiliary and self.training: | ||||
|                     logits_aux = self.auxiliary_head(s1) | ||||
|         logits = self.classifier(s1, input.shape[2:]) | ||||
|         if self._auxiliary and self.training: | ||||
|             return logits, logits_aux | ||||
|         return logits | ||||
|  | ||||
|  | ||||
| class NetworkImageNet(nn.Module): | ||||
|     """ImageNet network (ported from DARTS).""" | ||||
|  | ||||
|     def __init__(self, C, num_classes, layers, auxiliary, genotype): | ||||
|         super(NetworkImageNet, self).__init__() | ||||
|         self._layers = layers | ||||
|         self._auxiliary = auxiliary | ||||
|  | ||||
|         self.stem0 = nn.Sequential( | ||||
|             nn.Conv2d(cfg.MODEL.INPUT_CHANNELS, C // 2, kernel_size=3, stride=2, padding=1, bias=False), | ||||
|             nn.BatchNorm2d(C // 2), | ||||
|             nn.ReLU(inplace=True), | ||||
|             nn.Conv2d(C // 2, C, 3, stride=2, padding=1, bias=False), | ||||
|             nn.BatchNorm2d(C), | ||||
|         ) | ||||
|  | ||||
|         self.stem1 = nn.Sequential( | ||||
|             nn.ReLU(inplace=True), | ||||
|             nn.Conv2d(C, C, 3, stride=2, padding=1, bias=False), | ||||
|             nn.BatchNorm2d(C), | ||||
|         ) | ||||
|  | ||||
|         C_prev_prev, C_prev, C_curr = C, C, C | ||||
|  | ||||
|         self.cells = nn.ModuleList() | ||||
|         reduction_prev = True | ||||
|         reduction_layers = [layers//3] if cfg.TASK == 'seg' else [layers//3, 2*layers//3] | ||||
|         for i in range(layers): | ||||
|             if i in reduction_layers: | ||||
|                 C_curr *= 2 | ||||
|                 reduction = True | ||||
|             else: | ||||
|                 reduction = False | ||||
|             cell = Cell(genotype, C_prev_prev, C_prev, C_curr, reduction, reduction_prev) | ||||
|             reduction_prev = reduction | ||||
|             self.cells += [cell] | ||||
|             C_prev_prev, C_prev = C_prev, cell.multiplier * C_curr | ||||
|             if i == 2 * layers // 3: | ||||
|                 C_to_auxiliary = C_prev | ||||
|  | ||||
|         if auxiliary: | ||||
|             self.auxiliary_head = AuxiliaryHeadImageNet(C_to_auxiliary, num_classes) | ||||
|         self.classifier = Classifier(C_prev, num_classes) | ||||
|  | ||||
|     def forward(self, input): | ||||
|         input = Preprocess(input) | ||||
|         logits_aux = None | ||||
|         s0 = self.stem0(input) | ||||
|         s1 = self.stem1(s0) | ||||
|         for i, cell in enumerate(self.cells): | ||||
|             s0, s1 = s1, cell(s0, s1, self.drop_path_prob) | ||||
|             if i == 2 * self._layers // 3: | ||||
|                 if self._auxiliary and self.training: | ||||
|                     logits_aux = self.auxiliary_head(s1) | ||||
|         logits = self.classifier(s1, input.shape[2:]) | ||||
|         if self._auxiliary and self.training: | ||||
|             return logits, logits_aux | ||||
|         return logits | ||||
|  | ||||
|  | ||||
| class NAS(nn.Module): | ||||
|     """NAS net wrapper (delegates to nets from DARTS).""" | ||||
|  | ||||
|     def __init__(self): | ||||
|         assert cfg.TRAIN.DATASET in ['cifar10', 'imagenet', 'cityscapes'], \ | ||||
|             'Training on {} is not supported'.format(cfg.TRAIN.DATASET) | ||||
|         assert cfg.TEST.DATASET in ['cifar10', 'imagenet', 'cityscapes'], \ | ||||
|             'Testing on {} is not supported'.format(cfg.TEST.DATASET) | ||||
|         assert cfg.NAS.GENOTYPE in GENOTYPES, \ | ||||
|             'Genotype {} not supported'.format(cfg.NAS.GENOTYPE) | ||||
|         super(NAS, self).__init__() | ||||
|         logger.info('Constructing NAS: {}'.format(cfg.NAS)) | ||||
|         # Use a custom or predefined genotype | ||||
|         if cfg.NAS.GENOTYPE == 'custom': | ||||
|             genotype = Genotype( | ||||
|                 normal=cfg.NAS.CUSTOM_GENOTYPE[0], | ||||
|                 normal_concat=cfg.NAS.CUSTOM_GENOTYPE[1], | ||||
|                 reduce=cfg.NAS.CUSTOM_GENOTYPE[2], | ||||
|                 reduce_concat=cfg.NAS.CUSTOM_GENOTYPE[3], | ||||
|             ) | ||||
|         else: | ||||
|             genotype = GENOTYPES[cfg.NAS.GENOTYPE] | ||||
|         # Determine the network constructor for dataset | ||||
|         if 'cifar' in cfg.TRAIN.DATASET: | ||||
|             net_ctor = NetworkCIFAR | ||||
|         else: | ||||
|             net_ctor = NetworkImageNet | ||||
|         # Construct the network | ||||
|         self.net_ = net_ctor( | ||||
|             C=cfg.NAS.WIDTH, | ||||
|             num_classes=cfg.MODEL.NUM_CLASSES, | ||||
|             layers=cfg.NAS.DEPTH, | ||||
|             auxiliary=cfg.NAS.AUX, | ||||
|             genotype=genotype | ||||
|         ) | ||||
|         # Drop path probability (set / annealed based on epoch) | ||||
|         self.net_.drop_path_prob = 0.0 | ||||
|  | ||||
|     def set_drop_path_prob(self, drop_path_prob): | ||||
|         self.net_.drop_path_prob = drop_path_prob | ||||
|  | ||||
|     def forward(self, x): | ||||
|         return self.net_.forward(x) | ||||
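|  | ||||
|  | ||||
| # Usage sketch (illustrative; assumes the cfg has been set up, e.g. | ||||
| # NAS.GENOTYPE='darts_v2', NAS.WIDTH=36, NAS.DEPTH=20, TRAIN.DATASET='cifar10'; | ||||
| # epoch and max_epoch are placeholder names): | ||||
| # | ||||
| #   model = NAS() | ||||
| #   model.set_drop_path_prob(0.2 * epoch / max_epoch)  # anneal per epoch | ||||
| #   logits = model(images) | ||||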
201  pycls/models/nas/operations.py  Normal file
							| @@ -0,0 +1,201 @@ | ||||
| #!/usr/bin/env python3 | ||||
|  | ||||
| # Copyright (c) Facebook, Inc. and its affiliates. | ||||
| # | ||||
| # This source code is licensed under the MIT license found in the | ||||
| # LICENSE file in the root directory of this source tree. | ||||
|  | ||||
|  | ||||
| """NAS ops (adopted from DARTS).""" | ||||
|  | ||||
| import torch | ||||
| import torch.nn as nn | ||||
|  | ||||
|  | ||||
| OPS = { | ||||
|     'none': lambda C, stride, affine: | ||||
|         Zero(stride), | ||||
|     'avg_pool_2x2': lambda C, stride, affine: | ||||
|         nn.AvgPool2d(2, stride=stride, padding=0, count_include_pad=False), | ||||
|     'avg_pool_3x3': lambda C, stride, affine: | ||||
|         nn.AvgPool2d(3, stride=stride, padding=1, count_include_pad=False), | ||||
|     'avg_pool_5x5': lambda C, stride, affine: | ||||
|         nn.AvgPool2d(5, stride=stride, padding=2, count_include_pad=False), | ||||
|     'max_pool_2x2': lambda C, stride, affine: | ||||
|         nn.MaxPool2d(2, stride=stride, padding=0), | ||||
|     'max_pool_3x3': lambda C, stride, affine: | ||||
|         nn.MaxPool2d(3, stride=stride, padding=1), | ||||
|     'max_pool_5x5': lambda C, stride, affine: | ||||
|         nn.MaxPool2d(5, stride=stride, padding=2), | ||||
|     'max_pool_7x7': lambda C, stride, affine: | ||||
|         nn.MaxPool2d(7, stride=stride, padding=3), | ||||
|     'skip_connect': lambda C, stride, affine: | ||||
|         Identity() if stride == 1 else FactorizedReduce(C, C, affine=affine), | ||||
|     'conv_1x1': lambda C, stride, affine: | ||||
|         nn.Sequential( | ||||
|             nn.ReLU(inplace=False), | ||||
|             nn.Conv2d(C, C, 1, stride=stride, padding=0, bias=False), | ||||
|             nn.BatchNorm2d(C, affine=affine) | ||||
|         ), | ||||
|     'conv_3x3': lambda C, stride, affine: | ||||
|         nn.Sequential( | ||||
|             nn.ReLU(inplace=False), | ||||
|             nn.Conv2d(C, C, 3, stride=stride, padding=1, bias=False), | ||||
|             nn.BatchNorm2d(C, affine=affine) | ||||
|         ), | ||||
|     'sep_conv_3x3': lambda C, stride, affine: | ||||
|         SepConv(C, C, 3, stride, 1, affine=affine), | ||||
|     'sep_conv_5x5': lambda C, stride, affine: | ||||
|         SepConv(C, C, 5, stride, 2, affine=affine), | ||||
|     'sep_conv_7x7': lambda C, stride, affine: | ||||
|         SepConv(C, C, 7, stride, 3, affine=affine), | ||||
|     'dil_conv_3x3': lambda C, stride, affine: | ||||
|         DilConv(C, C, 3, stride, 2, 2, affine=affine), | ||||
|     'dil_conv_5x5': lambda C, stride, affine: | ||||
|         DilConv(C, C, 5, stride, 4, 2, affine=affine), | ||||
|     'dil_sep_conv_3x3': lambda C, stride, affine: | ||||
|         DilSepConv(C, C, 3, stride, 2, 2, affine=affine), | ||||
|     'conv_3x1_1x3': lambda C, stride, affine: | ||||
|         nn.Sequential( | ||||
|             nn.ReLU(inplace=False), | ||||
|             nn.Conv2d(C, C, (1,3), stride=(1, stride), padding=(0, 1), bias=False), | ||||
|             nn.Conv2d(C, C, (3,1), stride=(stride, 1), padding=(1, 0), bias=False), | ||||
|             nn.BatchNorm2d(C, affine=affine) | ||||
|         ), | ||||
|     'conv_7x1_1x7': lambda C, stride, affine: | ||||
|         nn.Sequential( | ||||
|             nn.ReLU(inplace=False), | ||||
|             nn.Conv2d(C, C, (1,7), stride=(1, stride), padding=(0, 3), bias=False), | ||||
|             nn.Conv2d(C, C, (7,1), stride=(stride, 1), padding=(3, 0), bias=False), | ||||
|             nn.BatchNorm2d(C, affine=affine) | ||||
|         ), | ||||
| } | ||||
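|  | ||||
|  | ||||
| # Construction sketch (illustrative): every entry shares the signature | ||||
| # (C, stride, affine) and preserves the channel count C, e.g. | ||||
| # | ||||
| #   op = OPS['sep_conv_3x3'](C=16, stride=2, affine=True) | ||||
| #   y = op(torch.randn(1, 16, 32, 32))  # -> shape (1, 16, 16, 16) | ||||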
|  | ||||
|  | ||||
| class ReLUConvBN(nn.Module): | ||||
|     """ReLU, conv, batchnorm.""" | ||||
|  | ||||
|     def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True): | ||||
|         super(ReLUConvBN, self).__init__() | ||||
|         self.op = nn.Sequential( | ||||
|             nn.ReLU(inplace=False), | ||||
|             nn.Conv2d( | ||||
|                 C_in, C_out, kernel_size, stride=stride, | ||||
|                 padding=padding, bias=False | ||||
|             ), | ||||
|             nn.BatchNorm2d(C_out, affine=affine) | ||||
|         ) | ||||
|  | ||||
|     def forward(self, x): | ||||
|         return self.op(x) | ||||
|  | ||||
|  | ||||
| class DilConv(nn.Module): | ||||
|     """Dilated separable conv: ReLU, dilated depthwise conv, 1x1 conv, BN.""" | ||||
|  | ||||
|     def __init__( | ||||
|         self, C_in, C_out, kernel_size, stride, padding, dilation, affine=True | ||||
|     ): | ||||
|         super(DilConv, self).__init__() | ||||
|         self.op = nn.Sequential( | ||||
|             nn.ReLU(inplace=False), | ||||
|             nn.Conv2d( | ||||
|                 C_in, C_in, kernel_size=kernel_size, stride=stride, | ||||
|                 padding=padding, dilation=dilation, groups=C_in, bias=False | ||||
|             ), | ||||
|             nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False), | ||||
|             nn.BatchNorm2d(C_out, affine=affine), | ||||
|         ) | ||||
|  | ||||
|     def forward(self, x): | ||||
|         return self.op(x) | ||||
|  | ||||
|  | ||||
| class SepConv(nn.Module): | ||||
|     """Separable conv: (ReLU, depthwise conv, 1x1 conv, BN) applied twice.""" | ||||
|  | ||||
|     def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True): | ||||
|         super(SepConv, self).__init__() | ||||
|         self.op = nn.Sequential( | ||||
|             nn.ReLU(inplace=False), | ||||
|             nn.Conv2d( | ||||
|                 C_in, C_in, kernel_size=kernel_size, stride=stride, | ||||
|                 padding=padding, groups=C_in, bias=False | ||||
|             ), | ||||
|             nn.Conv2d(C_in, C_in, kernel_size=1, padding=0, bias=False), | ||||
|             nn.BatchNorm2d(C_in, affine=affine), | ||||
|             nn.ReLU(inplace=False), | ||||
|             nn.Conv2d( | ||||
|                 C_in, C_in, kernel_size=kernel_size, stride=1, | ||||
|                 padding=padding, groups=C_in, bias=False | ||||
|             ), | ||||
|             nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False), | ||||
|             nn.BatchNorm2d(C_out, affine=affine), | ||||
|         ) | ||||
|  | ||||
|     def forward(self, x): | ||||
|         return self.op(x) | ||||
|  | ||||
|  | ||||
| class DilSepConv(nn.Module): | ||||
|     """Dilated separable conv (see SepConv) applied twice.""" | ||||
|  | ||||
|     def __init__( | ||||
|         self, C_in, C_out, kernel_size, stride, padding, dilation, affine=True | ||||
|     ): | ||||
|         super(DilSepConv, self).__init__() | ||||
|         self.op = nn.Sequential( | ||||
|             nn.ReLU(inplace=False), | ||||
|             nn.Conv2d( | ||||
|                 C_in, C_in, kernel_size=kernel_size, stride=stride, | ||||
|                 padding=padding, dilation=dilation, groups=C_in, bias=False | ||||
|             ), | ||||
|             nn.Conv2d(C_in, C_in, kernel_size=1, padding=0, bias=False), | ||||
|             nn.BatchNorm2d(C_in, affine=affine), | ||||
|             nn.ReLU(inplace=False), | ||||
|             nn.Conv2d( | ||||
|                 C_in, C_in, kernel_size=kernel_size, stride=1, | ||||
|                 padding=padding, dilation=dilation, groups=C_in, bias=False | ||||
|             ), | ||||
|             nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False), | ||||
|             nn.BatchNorm2d(C_out, affine=affine), | ||||
|         ) | ||||
|  | ||||
|     def forward(self, x): | ||||
|         return self.op(x) | ||||
|  | ||||
|  | ||||
| class Identity(nn.Module): | ||||
|     """Identity (pass-through) op.""" | ||||
|  | ||||
|     def __init__(self): | ||||
|         super(Identity, self).__init__() | ||||
|  | ||||
|     def forward(self, x): | ||||
|         return x | ||||
|  | ||||
|  | ||||
| class Zero(nn.Module): | ||||
|     """Zero op: zeros the input, subsampling spatially when stride > 1.""" | ||||
|  | ||||
|     def __init__(self, stride): | ||||
|         super(Zero, self).__init__() | ||||
|         self.stride = stride | ||||
|  | ||||
|     def forward(self, x): | ||||
|         if self.stride == 1: | ||||
|             return x.mul(0.) | ||||
|         return x[:, :, ::self.stride, ::self.stride].mul(0.) | ||||
|  | ||||
|  | ||||
| class FactorizedReduce(nn.Module): | ||||
|     """Halves the resolution with two parallel stride-2 1x1 convs on offset grids.""" | ||||
|  | ||||
|     def __init__(self, C_in, C_out, affine=True): | ||||
|         super(FactorizedReduce, self).__init__() | ||||
|         assert C_out % 2 == 0 | ||||
|         self.relu = nn.ReLU(inplace=False) | ||||
|         self.conv_1 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False) | ||||
|         self.conv_2 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False) | ||||
|         self.bn = nn.BatchNorm2d(C_out, affine=affine) | ||||
|         self.pad = nn.ConstantPad2d((0, 1, 0, 1), 0) | ||||
|  | ||||
|     def forward(self, x): | ||||
|         x = self.relu(x) | ||||
|         y = self.pad(x) | ||||
|         out = torch.cat([self.conv_1(x), self.conv_2(y[:, :, 1:, 1:])], dim=1) | ||||
|         out = self.bn(out) | ||||
|         return out | ||||
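|  | ||||
|  | ||||
| # Shape sketch (illustrative): the two stride-2 1x1 convs sample grids offset | ||||
| # by one pixel (via the pad/slice in forward), and their concatenation halves | ||||
| # the resolution while mapping C_in -> C_out: | ||||
| # | ||||
| #   fr = FactorizedReduce(16, 32) | ||||
| #   fr(torch.randn(1, 16, 32, 32)).shape  # -> (1, 32, 16, 16) | ||||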
89  pycls/models/regnet.py  Normal file
							| @@ -0,0 +1,89 @@ | ||||
| #!/usr/bin/env python3 | ||||
|  | ||||
| # Copyright (c) Facebook, Inc. and its affiliates. | ||||
| # | ||||
| # This source code is licensed under the MIT license found in the | ||||
| # LICENSE file in the root directory of this source tree. | ||||
|  | ||||
| """RegNet models.""" | ||||
|  | ||||
| import numpy as np | ||||
| from pycls.core.config import cfg | ||||
| from pycls.models.anynet import AnyNet | ||||
|  | ||||
|  | ||||
| def quantize_float(f, q): | ||||
|     """Converts a float to closest non-zero int divisible by q.""" | ||||
|     return int(round(f / q) * q) | ||||
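|  | ||||
|  | ||||
| # Worked example (illustrative): quantize_float(30, 8) = round(3.75) * 8 = 32 | ||||
| # and quantize_float(34, 8) = round(4.25) * 8 = 32. | ||||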
|  | ||||
|  | ||||
| def adjust_ws_gs_comp(ws, bms, gs): | ||||
|     """Adjusts the compatibility of widths and groups.""" | ||||
|     ws_bot = [int(w * b) for w, b in zip(ws, bms)] | ||||
|     gs = [min(g, w_bot) for g, w_bot in zip(gs, ws_bot)] | ||||
|     ws_bot = [quantize_float(w_bot, g) for w_bot, g in zip(ws_bot, gs)] | ||||
|     ws = [int(w_bot / b) for w_bot, b in zip(ws_bot, bms)] | ||||
|     return ws, gs | ||||
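|  | ||||
|  | ||||
| # Worked example (illustrative): ws=[64], bms=[1.0], gs=[24] -> ws_bot=[64], | ||||
| # gs=[24], quantize_float(64, 24)=72, so the width is nudged to ws=[72] to be | ||||
| # divisible by the group width. | ||||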
|  | ||||
|  | ||||
| def get_stages_from_blocks(ws, rs): | ||||
|     """Gets ws/ds of network at each stage from per block values.""" | ||||
|     ts_temp = zip(ws + [0], [0] + ws, rs + [0], [0] + rs) | ||||
|     ts = [w != wp or r != rp for w, wp, r, rp in ts_temp] | ||||
|     s_ws = [w for w, t in zip(ws, ts[:-1]) if t] | ||||
|     s_ds = np.diff([d for d, t in zip(range(len(ts)), ts) if t]).tolist() | ||||
|     return s_ws, s_ds | ||||
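|  | ||||
|  | ||||
| # Worked example (illustrative): ws = rs = [16, 16, 32, 32, 32] has width | ||||
| # transitions at blocks 0 and 2, giving s_ws = [16, 32] and s_ds = [2, 3]. | ||||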
|  | ||||
|  | ||||
| def generate_regnet(w_a, w_0, w_m, d, q=8): | ||||
|     """Generates per block ws from RegNet parameters.""" | ||||
|     assert w_a >= 0 and w_0 > 0 and w_m > 1 and w_0 % q == 0 | ||||
|     ws_cont = np.arange(d) * w_a + w_0 | ||||
|     ks = np.round(np.log(ws_cont / w_0) / np.log(w_m)) | ||||
|     ws = w_0 * np.power(w_m, ks) | ||||
|     ws = np.round(np.divide(ws, q)) * q | ||||
|     num_stages, max_stage = len(np.unique(ws)), ks.max() + 1 | ||||
|     ws, ws_cont = ws.astype(int).tolist(), ws_cont.tolist() | ||||
|     return ws, num_stages, max_stage, ws_cont | ||||
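|  | ||||
|  | ||||
| # Worked example (illustrative): generate_regnet(w_a=8., w_0=56, w_m=2., d=4) | ||||
| # gives ws_cont = [56, 64, 72, 80], ks = [0, 0, 0, 1], and quantized per-block | ||||
| # widths ws = [56, 56, 56, 112], so num_stages = 2. | ||||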
|  | ||||
|  | ||||
| class RegNet(AnyNet): | ||||
|     """RegNet model.""" | ||||
|  | ||||
|     @staticmethod | ||||
|     def get_args(): | ||||
|         """Convert RegNet to AnyNet parameter format.""" | ||||
|         # Generate RegNet ws per block | ||||
|         w_a, w_0, w_m, d = cfg.REGNET.WA, cfg.REGNET.W0, cfg.REGNET.WM, cfg.REGNET.DEPTH | ||||
|         ws, num_stages, _, _ = generate_regnet(w_a, w_0, w_m, d) | ||||
|         # Convert to per stage format | ||||
|         s_ws, s_ds = get_stages_from_blocks(ws, ws) | ||||
|         # Use the same gw, bm and ss for each stage | ||||
|         s_gs = [cfg.REGNET.GROUP_W for _ in range(num_stages)] | ||||
|         s_bs = [cfg.REGNET.BOT_MUL for _ in range(num_stages)] | ||||
|         s_ss = [cfg.REGNET.STRIDE for _ in range(num_stages)] | ||||
|         # Adjust the compatibility of ws and gws | ||||
|         s_ws, s_gs = adjust_ws_gs_comp(s_ws, s_bs, s_gs) | ||||
|         # Get AnyNet arguments defining the RegNet | ||||
|         return { | ||||
|             "stem_type": cfg.REGNET.STEM_TYPE, | ||||
|             "stem_w": cfg.REGNET.STEM_W, | ||||
|             "block_type": cfg.REGNET.BLOCK_TYPE, | ||||
|             "ds": s_ds, | ||||
|             "ws": s_ws, | ||||
|             "ss": s_ss, | ||||
|             "bms": s_bs, | ||||
|             "gws": s_gs, | ||||
|             "se_r": cfg.REGNET.SE_R if cfg.REGNET.SE_ON else None, | ||||
|             "nc": cfg.MODEL.NUM_CLASSES, | ||||
|         } | ||||
|  | ||||
|     def __init__(self): | ||||
|         kwargs = RegNet.get_args() | ||||
|         super(RegNet, self).__init__(**kwargs) | ||||
|  | ||||
|     @staticmethod | ||||
|     def complexity(cx, **kwargs): | ||||
|         """Computes model complexity. If you alter the model, make sure to update.""" | ||||
|         kwargs = RegNet.get_args() if not kwargs else kwargs | ||||
|         return AnyNet.complexity(cx, **kwargs) | ||||
280  pycls/models/resnet.py  Normal file
							| @@ -0,0 +1,280 @@ | ||||
| #!/usr/bin/env python3 | ||||
|  | ||||
| # Copyright (c) Facebook, Inc. and its affiliates. | ||||
| # | ||||
| # This source code is licensed under the MIT license found in the | ||||
| # LICENSE file in the root directory of this source tree. | ||||
|  | ||||
| """ResNe(X)t models.""" | ||||
|  | ||||
| import pycls.core.net as net | ||||
| import torch.nn as nn | ||||
| from pycls.core.config import cfg | ||||
|  | ||||
|  | ||||
| # Stage depths for ImageNet models | ||||
| _IN_STAGE_DS = {50: (3, 4, 6, 3), 101: (3, 4, 23, 3), 152: (3, 8, 36, 3)} | ||||
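|  | ||||
| # E.g. depth 50 -> (3, 4, 6, 3): 16 bottleneck blocks of 3 convs each, plus | ||||
| # the stem conv and the FC head, totals the 50 weighted layers. | ||||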
|  | ||||
|  | ||||
| def get_trans_fun(name): | ||||
|     """Retrieves the transformation function by name.""" | ||||
|     trans_funs = { | ||||
|         "basic_transform": BasicTransform, | ||||
|         "bottleneck_transform": BottleneckTransform, | ||||
|     } | ||||
|     err_str = "Transformation function '{}' not supported" | ||||
|     assert name in trans_funs.keys(), err_str.format(name) | ||||
|     return trans_funs[name] | ||||
|  | ||||
|  | ||||
| class ResHead(nn.Module): | ||||
|     """ResNet head: AvgPool, 1x1.""" | ||||
|  | ||||
|     def __init__(self, w_in, nc): | ||||
|         super(ResHead, self).__init__() | ||||
|         self.avg_pool = nn.AdaptiveAvgPool2d((1, 1)) | ||||
|         self.fc = nn.Linear(w_in, nc, bias=True) | ||||
|  | ||||
|     def forward(self, x): | ||||
|         x = self.avg_pool(x) | ||||
|         x = x.view(x.size(0), -1) | ||||
|         x = self.fc(x) | ||||
|         return x | ||||
|  | ||||
|     @staticmethod | ||||
|     def complexity(cx, w_in, nc): | ||||
|         cx["h"], cx["w"] = 1, 1 | ||||
|         cx = net.complexity_conv2d(cx, w_in, nc, 1, 1, 0, bias=True) | ||||
|         return cx | ||||
|  | ||||
|  | ||||
| class BasicTransform(nn.Module): | ||||
|     """Basic transformation: 3x3, BN, ReLU, 3x3, BN.""" | ||||
|  | ||||
|     def __init__(self, w_in, w_out, stride, w_b=None, num_gs=1): | ||||
|         err_str = "Basic transform does not support w_b and num_gs options" | ||||
|         assert w_b is None and num_gs == 1, err_str | ||||
|         super(BasicTransform, self).__init__() | ||||
|         self.a = nn.Conv2d(w_in, w_out, 3, stride=stride, padding=1, bias=False) | ||||
|         self.a_bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM) | ||||
|         self.a_relu = nn.ReLU(inplace=cfg.MEM.RELU_INPLACE) | ||||
|         self.b = nn.Conv2d(w_out, w_out, 3, stride=1, padding=1, bias=False) | ||||
|         self.b_bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM) | ||||
|         self.b_bn.final_bn = True | ||||
|  | ||||
|     def forward(self, x): | ||||
|         for layer in self.children(): | ||||
|             x = layer(x) | ||||
|         return x | ||||
|  | ||||
|     @staticmethod | ||||
|     def complexity(cx, w_in, w_out, stride, w_b=None, num_gs=1): | ||||
|         err_str = "Basic transform does not support w_b and num_gs options" | ||||
|         assert w_b is None and num_gs == 1, err_str | ||||
|         cx = net.complexity_conv2d(cx, w_in, w_out, 3, stride, 1) | ||||
|         cx = net.complexity_batchnorm2d(cx, w_out) | ||||
|         cx = net.complexity_conv2d(cx, w_out, w_out, 3, 1, 1) | ||||
|         cx = net.complexity_batchnorm2d(cx, w_out) | ||||
|         return cx | ||||
|  | ||||
|  | ||||
| class BottleneckTransform(nn.Module): | ||||
|     """Bottleneck transformation: 1x1, BN, ReLU, 3x3, BN, ReLU, 1x1, BN.""" | ||||
|  | ||||
|     def __init__(self, w_in, w_out, stride, w_b, num_gs): | ||||
|         super(BottleneckTransform, self).__init__() | ||||
|         # MSRA -> stride=2 is on 1x1; TH/C2 -> stride=2 is on 3x3 | ||||
|         (s1, s3) = (stride, 1) if cfg.RESNET.STRIDE_1X1 else (1, stride) | ||||
|         self.a = nn.Conv2d(w_in, w_b, 1, stride=s1, padding=0, bias=False) | ||||
|         self.a_bn = nn.BatchNorm2d(w_b, eps=cfg.BN.EPS, momentum=cfg.BN.MOM) | ||||
|         self.a_relu = nn.ReLU(inplace=cfg.MEM.RELU_INPLACE) | ||||
|         self.b = nn.Conv2d(w_b, w_b, 3, stride=s3, padding=1, groups=num_gs, bias=False) | ||||
|         self.b_bn = nn.BatchNorm2d(w_b, eps=cfg.BN.EPS, momentum=cfg.BN.MOM) | ||||
|         self.b_relu = nn.ReLU(inplace=cfg.MEM.RELU_INPLACE) | ||||
|         self.c = nn.Conv2d(w_b, w_out, 1, stride=1, padding=0, bias=False) | ||||
|         self.c_bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM) | ||||
|         self.c_bn.final_bn = True | ||||
|  | ||||
|     def forward(self, x): | ||||
|         for layer in self.children(): | ||||
|             x = layer(x) | ||||
|         return x | ||||
|  | ||||
|     @staticmethod | ||||
|     def complexity(cx, w_in, w_out, stride, w_b, num_gs): | ||||
|         (s1, s3) = (stride, 1) if cfg.RESNET.STRIDE_1X1 else (1, stride) | ||||
|         cx = net.complexity_conv2d(cx, w_in, w_b, 1, s1, 0) | ||||
|         cx = net.complexity_batchnorm2d(cx, w_b) | ||||
|         cx = net.complexity_conv2d(cx, w_b, w_b, 3, s3, 1, num_gs) | ||||
|         cx = net.complexity_batchnorm2d(cx, w_b) | ||||
|         cx = net.complexity_conv2d(cx, w_b, w_out, 1, 1, 0) | ||||
|         cx = net.complexity_batchnorm2d(cx, w_out) | ||||
|         return cx | ||||
|  | ||||
|  | ||||
| class ResBlock(nn.Module): | ||||
|     """Residual block: x + F(x).""" | ||||
|  | ||||
|     def __init__(self, w_in, w_out, stride, trans_fun, w_b=None, num_gs=1): | ||||
|         super(ResBlock, self).__init__() | ||||
|         # Use skip connection with projection if shape changes | ||||
|         self.proj_block = (w_in != w_out) or (stride != 1) | ||||
|         if self.proj_block: | ||||
|             self.proj = nn.Conv2d(w_in, w_out, 1, stride=stride, padding=0, bias=False) | ||||
|             self.bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM) | ||||
|         self.f = trans_fun(w_in, w_out, stride, w_b, num_gs) | ||||
|         self.relu = nn.ReLU(cfg.MEM.RELU_INPLACE) | ||||
|  | ||||
|     def forward(self, x): | ||||
|         if self.proj_block: | ||||
|             x = self.bn(self.proj(x)) + self.f(x) | ||||
|         else: | ||||
|             x = x + self.f(x) | ||||
|         x = self.relu(x) | ||||
|         return x | ||||
|  | ||||
|     @staticmethod | ||||
|     def complexity(cx, w_in, w_out, stride, trans_fun, w_b, num_gs): | ||||
|         proj_block = (w_in != w_out) or (stride != 1) | ||||
|         if proj_block: | ||||
|             h, w = cx["h"], cx["w"] | ||||
|             cx = net.complexity_conv2d(cx, w_in, w_out, 1, stride, 0) | ||||
|             cx = net.complexity_batchnorm2d(cx, w_out) | ||||
|             cx["h"], cx["w"] = h, w  # parallel branch | ||||
|         cx = trans_fun.complexity(cx, w_in, w_out, stride, w_b, num_gs) | ||||
|         return cx | ||||
|  | ||||
|  | ||||
| class ResStage(nn.Module): | ||||
|     """Stage of ResNet.""" | ||||
|  | ||||
|     def __init__(self, w_in, w_out, stride, d, w_b=None, num_gs=1): | ||||
|         super(ResStage, self).__init__() | ||||
|         for i in range(d): | ||||
|             b_stride = stride if i == 0 else 1 | ||||
|             b_w_in = w_in if i == 0 else w_out | ||||
|             trans_fun = get_trans_fun(cfg.RESNET.TRANS_FUN) | ||||
|             res_block = ResBlock(b_w_in, w_out, b_stride, trans_fun, w_b, num_gs) | ||||
|             self.add_module("b{}".format(i + 1), res_block) | ||||
|  | ||||
|     def forward(self, x): | ||||
|         for block in self.children(): | ||||
|             x = block(x) | ||||
|         return x | ||||
|  | ||||
|     @staticmethod | ||||
|     def complexity(cx, w_in, w_out, stride, d, w_b=None, num_gs=1): | ||||
|         for i in range(d): | ||||
|             b_stride = stride if i == 0 else 1 | ||||
|             b_w_in = w_in if i == 0 else w_out | ||||
|             trans_f = get_trans_fun(cfg.RESNET.TRANS_FUN) | ||||
|             cx = ResBlock.complexity(cx, b_w_in, w_out, b_stride, trans_f, w_b, num_gs) | ||||
|         return cx | ||||
|  | ||||
|  | ||||
| class ResStemCifar(nn.Module): | ||||
|     """ResNet stem for CIFAR: 3x3, BN, ReLU.""" | ||||
|  | ||||
|     def __init__(self, w_in, w_out): | ||||
|         super(ResStemCifar, self).__init__() | ||||
|         self.conv = nn.Conv2d(w_in, w_out, 3, stride=1, padding=1, bias=False) | ||||
|         self.bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM) | ||||
|         self.relu = nn.ReLU(cfg.MEM.RELU_INPLACE) | ||||
|  | ||||
|     def forward(self, x): | ||||
|         for layer in self.children(): | ||||
|             x = layer(x) | ||||
|         return x | ||||
|  | ||||
|     @staticmethod | ||||
|     def complexity(cx, w_in, w_out): | ||||
|         cx = net.complexity_conv2d(cx, w_in, w_out, 3, 1, 1) | ||||
|         cx = net.complexity_batchnorm2d(cx, w_out) | ||||
|         return cx | ||||
|  | ||||
|  | ||||
| class ResStemIN(nn.Module): | ||||
|     """ResNet stem for ImageNet: 7x7, BN, ReLU, MaxPool.""" | ||||
|  | ||||
|     def __init__(self, w_in, w_out): | ||||
|         super(ResStemIN, self).__init__() | ||||
|         self.conv = nn.Conv2d(w_in, w_out, 7, stride=2, padding=3, bias=False) | ||||
|         self.bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM) | ||||
|         self.relu = nn.ReLU(cfg.MEM.RELU_INPLACE) | ||||
|         self.pool = nn.MaxPool2d(3, stride=2, padding=1) | ||||
|  | ||||
|     def forward(self, x): | ||||
|         for layer in self.children(): | ||||
|             x = layer(x) | ||||
|         return x | ||||
|  | ||||
|     @staticmethod | ||||
|     def complexity(cx, w_in, w_out): | ||||
|         cx = net.complexity_conv2d(cx, w_in, w_out, 7, 2, 3) | ||||
|         cx = net.complexity_batchnorm2d(cx, w_out) | ||||
|         cx = net.complexity_maxpool2d(cx, 3, 2, 1) | ||||
|         return cx | ||||
|  | ||||
|  | ||||
| class ResNet(nn.Module): | ||||
|     """ResNet model.""" | ||||
|  | ||||
|     def __init__(self): | ||||
|         datasets = ["cifar10", "imagenet"] | ||||
|         err_str = "Dataset {} is not supported" | ||||
|         assert cfg.TRAIN.DATASET in datasets, err_str.format(cfg.TRAIN.DATASET) | ||||
|         assert cfg.TEST.DATASET in datasets, err_str.format(cfg.TEST.DATASET) | ||||
|         super(ResNet, self).__init__() | ||||
|         if "cifar" in cfg.TRAIN.DATASET: | ||||
|             self._construct_cifar() | ||||
|         else: | ||||
|             self._construct_imagenet() | ||||
|         self.apply(net.init_weights) | ||||
|  | ||||
|     def _construct_cifar(self): | ||||
|         err_str = "Model depth should be of the format 6n + 2 for cifar" | ||||
|         assert (cfg.MODEL.DEPTH - 2) % 6 == 0, err_str | ||||
|         d = int((cfg.MODEL.DEPTH - 2) / 6) | ||||
|         self.stem = ResStemCifar(3, 16) | ||||
|         self.s1 = ResStage(16, 16, stride=1, d=d) | ||||
|         self.s2 = ResStage(16, 32, stride=2, d=d) | ||||
|         self.s3 = ResStage(32, 64, stride=2, d=d) | ||||
|         self.head = ResHead(64, nc=cfg.MODEL.NUM_CLASSES) | ||||
|  | ||||
|     def _construct_imagenet(self): | ||||
|         g, gw = cfg.RESNET.NUM_GROUPS, cfg.RESNET.WIDTH_PER_GROUP | ||||
|         (d1, d2, d3, d4) = _IN_STAGE_DS[cfg.MODEL.DEPTH] | ||||
|         w_b = gw * g | ||||
|         self.stem = ResStemIN(3, 64) | ||||
|         self.s1 = ResStage(64, 256, stride=1, d=d1, w_b=w_b, num_gs=g) | ||||
|         self.s2 = ResStage(256, 512, stride=2, d=d2, w_b=w_b * 2, num_gs=g) | ||||
|         self.s3 = ResStage(512, 1024, stride=2, d=d3, w_b=w_b * 4, num_gs=g) | ||||
|         self.s4 = ResStage(1024, 2048, stride=2, d=d4, w_b=w_b * 8, num_gs=g) | ||||
|         self.head = ResHead(2048, nc=cfg.MODEL.NUM_CLASSES) | ||||
|  | ||||
|     def forward(self, x): | ||||
|         for module in self.children(): | ||||
|             x = module(x) | ||||
|         return x | ||||
|  | ||||
|     @staticmethod | ||||
|     def complexity(cx): | ||||
|         """Computes model complexity. If you alter the model, make sure to update.""" | ||||
|         if "cifar" in cfg.TRAIN.DATASET: | ||||
|             d = int((cfg.MODEL.DEPTH - 2) / 6) | ||||
|             cx = ResStemCifar.complexity(cx, 3, 16) | ||||
|             cx = ResStage.complexity(cx, 16, 16, stride=1, d=d) | ||||
|             cx = ResStage.complexity(cx, 16, 32, stride=2, d=d) | ||||
|             cx = ResStage.complexity(cx, 32, 64, stride=2, d=d) | ||||
|             cx = ResHead.complexity(cx, 64, nc=cfg.MODEL.NUM_CLASSES) | ||||
|         else: | ||||
|             g, gw = cfg.RESNET.NUM_GROUPS, cfg.RESNET.WIDTH_PER_GROUP | ||||
|             (d1, d2, d3, d4) = _IN_STAGE_DS[cfg.MODEL.DEPTH] | ||||
|             w_b = gw * g | ||||
|             cx = ResStemIN.complexity(cx, 3, 64) | ||||
|             cx = ResStage.complexity(cx, 64, 256, 1, d=d1, w_b=w_b, num_gs=g) | ||||
|             cx = ResStage.complexity(cx, 256, 512, 2, d=d2, w_b=w_b * 2, num_gs=g) | ||||
|             cx = ResStage.complexity(cx, 512, 1024, 2, d=d3, w_b=w_b * 4, num_gs=g) | ||||
|             cx = ResStage.complexity(cx, 1024, 2048, 2, d=d4, w_b=w_b * 8, num_gs=g) | ||||
|             cx = ResHead.complexity(cx, 2048, nc=cfg.MODEL.NUM_CLASSES) | ||||
|         return cx | ||||
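|  | ||||
|  | ||||
| # Usage sketch (illustrative; assumes cx carries the running 'h', 'w', 'flops', | ||||
| # 'params', and 'acts' counters that pycls.core.net updates): | ||||
| # | ||||
| #   cx = {'h': 224, 'w': 224, 'flops': 0, 'params': 0, 'acts': 0} | ||||
| #   cx = ResNet.complexity(cx) | ||||
| #   print(cx['flops'], cx['params']) | ||||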