134 lines
		
	
	
		
			3.9 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			134 lines
		
	
	
		
			3.9 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| import functools
 | |
| 
 | |
| import torch
 | |
| import torch.nn as nn
 | |
| 
 | |
| __all__ = ['ShuffleNetV2']
 | |
| 
 | |
| 
 | |
def channel_shuffle(x, groups):
  """Interleave the channels of ``x`` across ``groups`` groups.

  The channel axis is viewed as a ``(groups, channels_per_group)`` grid,
  transposed, and flattened back. As a result, input channel
  ``g * channels_per_group + k`` ends up at position ``k * groups + g``.

  Args:
    x: tensor laid out as ``(batch, channels, *spatial)``. The channel count
      must be divisible by ``groups``; otherwise ``view`` raises.
    groups: number of channel groups to interleave.

  Returns:
    A contiguous tensor with the same shape as ``x``.
  """
  # Read the shape directly; the legacy ``.data`` detour is unnecessary.
  batchsize, num_channels = x.shape[:2]
  spatial = x.shape[2:]
  channels_per_group = num_channels // groups

  # reshape: split the channel axis into (groups, channels_per_group)
  x = x.view(batchsize, groups, channels_per_group, *spatial)

  # swap the two channel sub-axes; contiguous() is required before view()
  x = torch.transpose(x, 1, 2).contiguous()

  # flatten back into one channel axis. The explicit channel count (not -1)
  # keeps this valid for an empty batch, where -1 would be ambiguous.
  return x.view(batchsize, num_channels, *spatial)
 | |
| 
 | |
| 
 | |
class InvertedResidual(nn.Module):
  """ShuffleNetV2 building unit.

  When ``stride == 1``, the input is split in half along dim 1. The first half
  is passed through unchanged and the second half goes through ``branch2``.
  When ``stride > 1``, the whole input feeds both ``branch1`` and ``branch2``.
  The two results are concatenated along dim 1 and passed through
  ``channel_shuffle(..., 2)``.

  Args:
    inp: input channel count. When ``stride == 1`` it must equal
      ``2 * (oup // 2)``.
    oup: requested output channel count; each branch produces ``oup // 2``.
    stride: stride of the 3x3 depthwise convolutions; must lie in [1, 3].

  Raises:
    ValueError: if ``stride`` is outside [1, 3].
    AssertionError: if ``stride == 1`` and ``inp != 2 * (oup // 2)``.
  """

  def __init__(self, inp, oup, stride):
    super(InvertedResidual, self).__init__()

    if not 1 <= stride <= 3:
      raise ValueError('illegal stride value')
    self.stride = stride

    half = oup // 2
    # Given 1 <= stride, "stride > 1" is the same test as "stride != 1".
    downsample = self.stride > 1
    # Stride-1 units split the input in two, so it must hold both halves.
    assert downsample or inp == (half << 1)

    def pointwise(cin, cout):
      # 1x1 convolution without bias (each is followed by BatchNorm).
      return nn.Conv2d(cin, cout, kernel_size=1, stride=1, padding=0, bias=False)

    def depthwise(channels):
      # 3x3 per-channel convolution carrying this unit's stride.
      return self.depthwise_conv(channels, channels, kernel_size=3,
                                 stride=self.stride, padding=1)

    # Modules are built in the same order as before, so seeded weight init
    # and state_dict keys stay unchanged.
    if downsample:
      self.branch1 = nn.Sequential(
        depthwise(inp),
        nn.BatchNorm2d(inp),
        pointwise(inp, half),
        nn.BatchNorm2d(half),
        nn.ReLU(inplace=True),
      )

    branch2_in = inp if downsample else half
    self.branch2 = nn.Sequential(
      pointwise(branch2_in, half),
      nn.BatchNorm2d(half),
      nn.ReLU(inplace=True),
      depthwise(half),
      nn.BatchNorm2d(half),
      pointwise(half, half),
      nn.BatchNorm2d(half),
      nn.ReLU(inplace=True),
    )

  @staticmethod
  def depthwise_conv(i, o, kernel_size, stride=1, padding=0, bias=False):
    """Build an ``nn.Conv2d`` whose ``groups`` equals its input channel count ``i``."""
    return nn.Conv2d(i, o, kernel_size, stride=stride, padding=padding,
                     bias=bias, groups=i)

  def forward(self, x):
    """Run both paths, concatenate them along dim 1, then shuffle with 2 groups."""
    if self.stride == 1:
      passthrough, processed = x.chunk(2, dim=1)
      pieces = (passthrough, self.branch2(processed))
    else:
      pieces = (self.branch1(x), self.branch2(x))
    return channel_shuffle(torch.cat(pieces, dim=1), 2)
 | |
| 
 | |
| 
 | |
class ShuffleNetV2(nn.Module):
  """ShuffleNetV2 image classifier.

  Layout: ``conv1`` (3x3, stride 2) -> ``maxpool`` (3x3, stride 2) ->
  ``stage2`` -> ``stage3`` -> ``stage4`` -> ``conv5`` (1x1) ->
  mean over dims 2 and 3 -> ``fc``. Each stage is one stride-2
  ``InvertedResidual`` followed by ``repeats - 1`` stride-1 units.

  Args:
    num_classes: output size of the final linear layer.
    stages: five channel counts ``[conv1, stage2, stage3, stage4, conv5]``.
      The stage2..stage4 counts should be even, because stride-1 units
      require ``inp == 2 * (oup // 2)``. Width-multiplier presets:
        0.5: [24,  48,  96, 192, 1024]
        1.0: [24, 116, 232, 464, 1024]
        1.5: [24, 176, 352, 704, 1024]
        2.0: [24, 244, 488, 976, 2048]
    stage_repeats: number of ``InvertedResidual`` units in stage2, stage3 and
      stage4 (three integers >= 1). Defaults to ``(4, 8, 4)``.

  Raises:
    AssertionError: if ``stages`` does not have exactly five entries.
    ValueError: if ``stage_repeats`` is not three values >= 1.
  """

  def __init__(self, num_classes, stages, stage_repeats=(4, 8, 4)):
    super(ShuffleNetV2, self).__init__()

    self.stage_out_channels = stages
    assert len(stages) == 5, 'invalid stages : {:}'.format(stages)
    stage_repeats = tuple(stage_repeats)
    if len(stage_repeats) != 3 or min(stage_repeats) < 1:
      raise ValueError('invalid stage_repeats : {:}'.format(stage_repeats))
    self.stage_repeats = stage_repeats
    self.message = 'stages: ' + ' '.join([str(x) for x in stages])

    input_channels = 3
    output_channels = self.stage_out_channels[0]
    self.conv1 = nn.Sequential(
      nn.Conv2d(input_channels, output_channels, 3, 2, 1, bias=False),
      nn.BatchNorm2d(output_channels),
      nn.ReLU(inplace=True),
    )
    input_channels = output_channels

    self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

    # stages[1:4] are the output widths of stage2..stage4; stages[4] is conv5.
    stage_names = ['stage{:}'.format(i) for i in [2, 3, 4]]
    for name, repeats, output_channels in zip(
        stage_names, stage_repeats, self.stage_out_channels[1:4]):
      # The first unit downsamples (stride 2) and changes the width; the
      # remaining stride-1 units keep it.
      seq = [InvertedResidual(input_channels, output_channels, 2)]
      for _ in range(repeats - 1):
        seq.append(InvertedResidual(output_channels, output_channels, 1))
      setattr(self, name, nn.Sequential(*seq))
      input_channels = output_channels

    output_channels = self.stage_out_channels[-1]
    self.conv5 = nn.Sequential(
      nn.Conv2d(input_channels, output_channels, 1, 1, 0, bias=False),
      nn.BatchNorm2d(output_channels),
      nn.ReLU(inplace=True),
    )

    self.fc = nn.Linear(output_channels, num_classes)

  def get_message(self):
    """Return the ``'stages: ...'`` summary string built in ``__init__``."""
    return self.message

  def forward(self, inputs):
    """Return ``(features, predicts)``.

    ``features`` is the ``conv5`` output averaged over dims 2 and 3 (global
    average pooling); ``predicts`` is ``fc(features)``. ``inputs`` must have
    3 channels in dim 1 (``conv1`` is built with 3 input channels).
    """
    x = self.conv1(inputs)
    x = self.maxpool(x)
    x = self.stage2(x)
    x = self.stage3(x)
    x = self.stage4(x)
    x = self.conv5(x)
    features = x.mean([2, 3])  # globalpool
    predicts = self.fc(features)
    return features, predicts
 |