20 lines
		
	
	
		
			729 B
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			20 lines
		
	
	
		
			729 B
		
	
	
	
		
			Python
		
	
	
	
	
	
| import torch
 | |
| import torch.nn as nn
 | |
| 
 | |
| 
 | |
class ImageNetHEAD(nn.Sequential):
  """Two-stage convolutional stem: conv-BN-ReLU followed by conv-BN.

  Maps a 3-channel input to ``C`` channels. The first 3x3 convolution always
  uses stride 2 and produces ``C // 2`` channels. The second 3x3 convolution
  uses the caller-supplied ``stride``. Both convolutions are bias-free
  because a BatchNorm follows each one. No activation is applied after the
  final BatchNorm.

  Args:
    C: number of output channels.
    stride: stride of the second convolution (default 2).
  """

  def __init__(self, C, stride=2):
    super().__init__()
    mid = C // 2
    # The tuple is evaluated left to right, so the modules are constructed
    # (and their weights initialised) in the same order they are registered.
    layers = (
      ('conv1', nn.Conv2d(3, mid, kernel_size=3, stride=2, padding=1, bias=False)),
      ('bn1',   nn.BatchNorm2d(mid)),
      ('relu1', nn.ReLU(inplace=True)),
      ('conv2', nn.Conv2d(mid, C, kernel_size=3, stride=stride, padding=1, bias=False)),
      ('bn2',   nn.BatchNorm2d(C)),
    )
    for name, layer in layers:
      self.add_module(name, layer)
 | |
| 
 | |
| 
 | |
class CifarHEAD(nn.Sequential):
  """Single-stage convolutional stem: a 3x3 conv followed by BatchNorm.

  Maps a 3-channel input to ``C`` channels. The convolution uses stride 1
  and padding 1, so the spatial size is preserved. It is bias-free because
  the BatchNorm follows it. No activation follows the BatchNorm.

  Args:
    C: number of output channels.
  """

  def __init__(self, C):
    super().__init__()
    for name, layer in (
        ('conv', nn.Conv2d(3, C, kernel_size=3, padding=1, bias=False)),
        ('bn', nn.BatchNorm2d(C)),
    ):
      self.add_module(name, layer)
 |