28 lines
		
	
	
		
			762 B
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			28 lines
		
	
	
		
			762 B
		
	
	
	
		
			Python
		
	
	
	
	
	
| import torch
 | |
| import torch.nn as nn
 | |
| # Squeeze and Excitation module
 | |
| 
 | |
class SqEx(nn.Module):
    """Squeeze-and-Excitation block: rescale each channel by a learned gate.

    The input is averaged over its two trailing (spatial) dimensions, passed
    through a two-layer bottleneck (Linear -> ReLU -> Linear -> Sigmoid), and
    the resulting per-channel weights in (0, 1) multiply the original input.

    Args:
        n_features: Number of channels (size of dimension 1) of the input.
        reduction: Bottleneck ratio; the hidden layer has
            ``n_features // reduction`` units.

    Raises:
        ValueError: If ``n_features`` is not divisible by ``reduction``.
    """

    def __init__(self, n_features, reduction=16):
        super(SqEx, self).__init__()

        if n_features % reduction != 0:
            raise ValueError('n_features must be divisible by reduction (default = 16)')

        self.linear1 = nn.Linear(n_features, n_features // reduction, bias=True)
        self.nonlin1 = nn.ReLU(inplace=True)
        self.linear2 = nn.Linear(n_features // reduction, n_features, bias=True)
        self.nonlin2 = nn.Sigmoid()

    def forward(self, x):
        """Return ``x`` scaled channel-wise by the excitation weights.

        Args:
            x: 4-D tensor laid out as (batch, n_features, height, width).

        Returns:
            Tensor with the same shape as ``x``.
        """
        # Squeeze: a pooling window covering the full spatial extent yields a
        # (batch, channels, 1, 1) tensor. Accessed through ``nn.functional``
        # because the module never imports the ``F`` alias.
        y = nn.functional.avg_pool2d(x, kernel_size=x.size()[2:4])

        # nn.Linear acts on the last dimension, so move channels there.
        y = y.permute(0, 2, 3, 1)
        y = self.nonlin1(self.linear1(y))
        y = self.nonlin2(self.linear2(y))

        # Restore channels-first layout so the gate broadcasts over H and W.
        y = y.permute(0, 3, 1, 2)
        return x * y
 | |
| 
 |