xautodl/lib/nas/SE_Module.py
2019-02-01 01:27:38 +11:00

28 lines
762 B
Python

import torch
import torch.nn as nn
# Squeeze and Excitation module
class SqEx(nn.Module):
def __init__(self, n_features, reduction=16):
super(SqEx, self).__init__()
if n_features % reduction != 0:
raise ValueError('n_features must be divisible by reduction (default = 16)')
self.linear1 = nn.Linear(n_features, n_features // reduction, bias=True)
self.nonlin1 = nn.ReLU(inplace=True)
self.linear2 = nn.Linear(n_features // reduction, n_features, bias=True)
self.nonlin2 = nn.Sigmoid()
def forward(self, x):
y = F.avg_pool2d(x, kernel_size=x.size()[2:4])
y = y.permute(0, 2, 3, 1)
y = self.nonlin1(self.linear1(y))
y = self.nonlin2(self.linear2(y))
y = y.permute(0, 3, 1, 2)
y = x * y
return y