import torch import torch.nn as nn # Squeeze and Excitation module class SqEx(nn.Module): def __init__(self, n_features, reduction=16): super(SqEx, self).__init__() if n_features % reduction != 0: raise ValueError('n_features must be divisible by reduction (default = 16)') self.linear1 = nn.Linear(n_features, n_features // reduction, bias=True) self.nonlin1 = nn.ReLU(inplace=True) self.linear2 = nn.Linear(n_features // reduction, n_features, bias=True) self.nonlin2 = nn.Sigmoid() def forward(self, x): y = F.avg_pool2d(x, kernel_size=x.size()[2:4]) y = y.permute(0, 2, 3, 1) y = self.nonlin1(self.linear1(y)) y = self.nonlin2(self.linear2(y)) y = y.permute(0, 3, 1, 2) y = x * y return y