MeCo/pycls/models/common.py
HamsterMimi 189df25fd3 upload
2023-05-04 13:09:03 +08:00

108 lines
3.9 KiB
Python

#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch
import torch.nn as nn
from pycls.core.config import cfg
def Preprocess(x):
    """Reshape the network input according to the configured task.

    When ``cfg.TASK`` is ``'jig'`` the input must be a 5-D tensor whose
    second dimension equals ``cfg.JIGSAW_GRID ** 2``; its first two
    dimensions are merged into one, yielding a 4-D tensor.  For every other
    task the input is returned unchanged.

    Args:
        x: input tensor.

    Returns:
        The (possibly re-viewed) tensor.

    Raises:
        AssertionError: in jigsaw mode, if ``x`` is not 5-D or its second
            dimension does not match the configured grid.
    """
    if cfg.TASK != 'jig':
        return x

    assert len(x.shape) == 5, 'Wrong tensor dimension for jigsaw'
    assert x.shape[1] == cfg.JIGSAW_GRID ** 2, 'Wrong grid for jigsaw'

    # Presumably (batch, tiles, channels, height, width) -- the code only
    # fixes the rank and the size of dimension 1; confirm against the loader.
    dim0, dim1, dim2, dim3, dim4 = x.shape
    return x.view([dim0 * dim1, dim2, dim3, dim4])
class Classifier(nn.Module):
    """Task-dependent prediction head, selected by ``cfg.TASK``.

    * ``'jig'``: global average pooling, then a single linear layer applied
      to the pooled features of ``cfg.JIGSAW_GRID ** 2`` consecutive batch
      entries concatenated along the channel dimension.
    * ``'col'``: 1x1 convolution, output resized to ``shape``.
    * ``'seg'``: :class:`ASPP` head, output resized to ``shape``.
    * any other value: global average pooling followed by a linear layer.

    Args:
        channels: number of channels of the incoming feature map.
        num_classes: number of output classes / channels.
    """

    def __init__(self, channels, num_classes):
        super().__init__()
        task = cfg.TASK
        if task == 'col':
            self.classifier = nn.Conv2d(channels, num_classes, kernel_size=1, stride=1)
        elif task == 'seg':
            self.classifier = ASPP(channels, cfg.MODEL.ASPP_CHANNELS, num_classes, cfg.MODEL.ASPP_RATES)
        else:
            # Jigsaw and plain classification share the pooled-linear layout;
            # jigsaw only widens the linear layer's input.
            in_features = channels
            if task == 'jig':
                self.jig_sq = cfg.JIGSAW_GRID ** 2
                in_features = channels * self.jig_sq
            self.pooling = nn.AdaptiveAvgPool2d(1)
            self.classifier = nn.Linear(in_features, num_classes)

    def forward(self, x, shape):
        """Compute predictions from a feature map.

        Args:
            x: 4-D feature tensor.
            shape: target spatial size of the output; only used for the
                ``'col'`` and ``'seg'`` tasks and ignored otherwise.

        Returns:
            A 2-D tensor of per-sample scores for the pooled tasks, or a 4-D
            tensor resized to ``shape`` for ``'col'`` / ``'seg'``.
        """
        if cfg.TASK in ['col', 'seg']:
            x = self.classifier(x)
            # Functional form of nn.Upsample(shape, ...): same result without
            # allocating a throw-away module on every forward pass.
            return nn.functional.interpolate(
                x, size=shape, mode='bilinear', align_corners=True)

        x = self.pooling(x)
        if cfg.TASK == 'jig':
            # Fold each group of jig_sq consecutive batch entries into the
            # channel dimension so the linear layer sees them jointly.
            n, c, h, w = x.shape
            x = x.view([n // self.jig_sq, c * self.jig_sq, h, w])
        return self.classifier(x.view(x.size(0), -1))
class ASPP(nn.Module):
    """Atrous Spatial Pyramid Pooling head.

    Runs several branches over the same input in parallel -- a 1x1
    convolution, one or three dilated 3x3 convolutions, and a globally
    pooled branch -- concatenates them along the channel dimension and maps
    the result to ``num_classes`` channels.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by every branch.
        num_classes: channels of the final prediction.
        rates: dilation rates of the 3x3 branches; must contain exactly one
            or exactly three entries.

    Raises:
        AssertionError: if ``rates`` has neither 1 nor 3 entries.
    """

    def __init__(self, in_channels, out_channels, num_classes, rates):
        super().__init__()
        assert len(rates) in [1, 3], 'ASPP expects exactly 1 or 3 dilation rates'
        self.rates = rates
        self.global_pooling = nn.AdaptiveAvgPool2d(1)
        # Submodule names and creation order are kept stable so existing
        # checkpoints and seeded initialisation are unaffected.
        self.aspp1 = self._conv_bn_relu(in_channels, out_channels)
        self.aspp2 = self._conv_bn_relu(in_channels, out_channels, rates[0])
        if len(self.rates) == 3:
            self.aspp3 = self._conv_bn_relu(in_channels, out_channels, rates[1])
            self.aspp4 = self._conv_bn_relu(in_channels, out_channels, rates[2])
        # Applied to the globally pooled (1x1) features.
        self.aspp5 = self._conv_bn_relu(in_channels, out_channels)
        # Concatenated input: 1x1 branch + pooled branch + one per rate.
        self.classifier = nn.Sequential(
            nn.Conv2d(out_channels * (len(rates) + 2), out_channels, 1,
                      bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, num_classes, 1)
        )

    @staticmethod
    def _conv_bn_relu(in_channels, out_channels, rate=None):
        """Build a bias-free Conv -> BatchNorm -> ReLU stack.

        With ``rate=None`` the convolution is 1x1; otherwise it is a 3x3
        convolution with dilation and padding both equal to ``rate``.
        """
        if rate is None:
            conv = nn.Conv2d(in_channels, out_channels, 1, bias=False)
        else:
            conv = nn.Conv2d(in_channels, out_channels, 3, dilation=rate,
                             padding=rate, bias=False)
        return nn.Sequential(
            conv,
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        """Apply all branches to ``x`` and fuse them.

        Args:
            x: 4-D feature tensor with ``in_channels`` channels.

        Returns:
            Tensor with ``num_classes`` channels and the same spatial size
            as ``x``.
        """
        branches = [self.aspp1(x), self.aspp2(x)]

        pooled = self.aspp5(self.global_pooling(x))
        # Bring the 1x1 pooled features back to the input's spatial size.
        # Functional form of nn.Upsample: no module is built per call.
        pooled = nn.functional.interpolate(
            pooled, size=(x.shape[2], x.shape[3]), mode='bilinear',
            align_corners=True)

        if len(self.rates) == 3:
            branches.append(self.aspp3(x))
            branches.append(self.aspp4(x))
        branches.append(pooled)

        return self.classifier(torch.cat(branches, 1))