109 lines
3.9 KiB
Python
109 lines
3.9 KiB
Python
#!/usr/bin/env python3
|
|
|
|
# Copyright (c) Facebook, Inc. and its affiliates.
|
|
#
|
|
# This source code is licensed under the MIT license found in the
|
|
# LICENSE file in the root directory of this source tree.
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
|
|
from pycls.core.config import cfg
|
|
|
|
|
|
def Preprocess(x):
    """Prepare a network input according to ``cfg.TASK``.

    For the ``'jig'`` task, ``x`` must be a 5-D tensor whose second
    dimension equals ``cfg.JIGSAW_GRID ** 2``. Its first two dimensions are
    merged, giving a 4-D tensor of shape
    ``(x.shape[0] * x.shape[1], x.shape[2], x.shape[3], x.shape[4])``.
    For every other task, ``x`` is returned unchanged.

    NOTE(review): the 5-D layout is presumably
    (batch, tiles, channels, height, width) -- confirm against the loader.

    Args:
        x: input tensor.

    Returns:
        The (possibly reshaped) tensor.

    Raises:
        AssertionError: for the ``'jig'`` task, if ``x`` is not 5-D or its
            second dimension does not match the configured grid.
    """
    if cfg.TASK != 'jig':
        return x
    assert len(x.shape) == 5, 'Wrong tensor dimension for jigsaw'
    assert x.shape[1] == cfg.JIGSAW_GRID ** 2, 'Wrong grid for jigsaw'
    # flatten() returns a view when the strides allow it and copies
    # otherwise, so inputs whose layout cannot be viewed (where view()
    # would raise) are handled too.
    return x.flatten(0, 1)
|
|
|
|
|
|
class Classifier(nn.Module):
    """Prediction head whose layout is selected by ``cfg.TASK``.

    * ``'jig'``: global average pooling followed by a linear layer over
      ``channels * cfg.JIGSAW_GRID ** 2`` features.
    * ``'col'``: a 1x1 convolution producing ``num_classes`` channels.
    * ``'seg'``: an ``ASPP`` module configured from
      ``cfg.MODEL.ASPP_CHANNELS`` and ``cfg.MODEL.ASPP_RATES``.
    * anything else: global average pooling followed by a linear layer
      over ``channels`` features.

    Args:
        channels: number of channels of the incoming feature map.
        num_classes: number of outputs (channels or logits) to produce.
    """

    def __init__(self, channels, num_classes):
        super().__init__()
        if cfg.TASK == 'jig':
            self.jig_sq = cfg.JIGSAW_GRID ** 2
            self.pooling = nn.AdaptiveAvgPool2d(1)
            self.classifier = nn.Linear(channels * self.jig_sq, num_classes)
        elif cfg.TASK == 'col':
            self.classifier = nn.Conv2d(channels, num_classes, kernel_size=1, stride=1)
        elif cfg.TASK == 'seg':
            self.classifier = ASPP(channels, cfg.MODEL.ASPP_CHANNELS, num_classes, cfg.MODEL.ASPP_RATES)
        else:
            self.pooling = nn.AdaptiveAvgPool2d(1)
            self.classifier = nn.Linear(channels, num_classes)

    def forward(self, x, shape):
        """Apply the head to a feature map.

        Args:
            x: 4-D feature map.
            shape: target spatial size handed to bilinear interpolation;
                only used for the ``'col'`` and ``'seg'`` tasks.

        Returns:
            For ``'col'``/``'seg'``: the classifier output resized to
            ``shape``. For ``'jig'``: one row of logits per group of
            ``jig_sq`` consecutive batch entries. Otherwise: one row of
            logits per batch entry.
        """
        if cfg.TASK == 'jig':
            x = self.pooling(x)
            # After pooling the spatial size is 1x1, so the jig_sq
            # consecutive batch entries of one sample can be folded
            # straight into the feature dimension.
            x = x.view(x.shape[0] // self.jig_sq, x.shape[1] * self.jig_sq)
            x = self.classifier(x)
        elif cfg.TASK in ['col', 'seg']:
            x = self.classifier(x)
            # Functional form avoids constructing an nn.Upsample module
            # on every forward pass; the computation is the same.
            x = nn.functional.interpolate(
                x, size=shape, mode='bilinear', align_corners=True)
        else:
            x = self.pooling(x)
            x = self.classifier(x.view(x.size(0), -1))
        return x
|
|
|
|
|
|
class ASPP(nn.Module):
    """Multi-branch head: parallel (dilated) convolutions plus a pooled branch.

    Branches, each Conv -> BatchNorm -> ReLU with ``out_channels`` outputs:

    * ``aspp1``: 1x1 convolution.
    * ``aspp2`` (and ``aspp3``/``aspp4`` when three rates are given):
      3x3 convolution with dilation and padding equal to the rate.
    * ``aspp5``: 1x1 convolution applied to the globally average-pooled
      input, then resized back to the input's spatial size.

    The branch outputs are concatenated along the channel dimension and fed
    to ``classifier`` (1x1 Conv -> BatchNorm -> ReLU -> 1x1 Conv), which
    produces ``num_classes`` channels.

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels produced by every branch.
        num_classes: channels of the final output.
        rates: sequence of 1 or 3 dilation rates.

    Raises:
        AssertionError: if ``rates`` does not contain 1 or 3 entries.
    """

    def __init__(self, in_channels, out_channels, num_classes, rates):
        super().__init__()
        assert len(rates) in [1, 3], 'ASPP expects 1 or 3 dilation rates'
        self.rates = rates
        # Sub-modules are created in a fixed order so parameter
        # registration (and therefore state_dict key order) stays stable.
        self.global_pooling = nn.AdaptiveAvgPool2d(1)
        self.aspp1 = self._branch(in_channels, out_channels)
        self.aspp2 = self._branch(in_channels, out_channels, rates[0])
        if len(self.rates) == 3:
            self.aspp3 = self._branch(in_channels, out_channels, rates[1])
            self.aspp4 = self._branch(in_channels, out_channels, rates[2])
        self.aspp5 = self._branch(in_channels, out_channels)
        # One slot per rate, plus the 1x1 branch and the pooled branch.
        self.classifier = nn.Sequential(
            nn.Conv2d(out_channels * (len(rates) + 2), out_channels, 1,
                      bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, num_classes, 1),
        )

    @staticmethod
    def _branch(in_channels, out_channels, dilation=None):
        """Build one Conv -> BatchNorm -> ReLU branch.

        A 1x1 convolution is used when ``dilation`` is None; otherwise a
        3x3 convolution with ``dilation`` and matching padding, which keeps
        the spatial size unchanged.
        """
        if dilation is None:
            conv = nn.Conv2d(in_channels, out_channels, 1, bias=False)
        else:
            conv = nn.Conv2d(in_channels, out_channels, 3, dilation=dilation,
                             padding=dilation, bias=False)
        return nn.Sequential(
            conv,
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        """Run all branches on ``x`` and classify their concatenation.

        Args:
            x: 4-D feature map with ``in_channels`` channels.

        Returns:
            Tensor with ``num_classes`` channels and the same spatial size
            as ``x``.
        """
        branches = [self.aspp1(x), self.aspp2(x)]
        pooled = self.aspp5(self.global_pooling(x))
        # Functional form avoids constructing an nn.Upsample module on
        # every forward pass; the computation is the same.
        pooled = nn.functional.interpolate(
            pooled, size=(x.shape[2], x.shape[3]), mode='bilinear',
            align_corners=True)
        if len(self.rates) == 3:
            branches.append(self.aspp3(x))
            branches.append(self.aspp4(x))
        branches.append(pooled)
        return self.classifier(torch.cat(branches, 1))
|