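"""Masked variants of standard PyTorch layers.

Each layer registers a `weight_mask` (and, where a bias exists, a `bias_mask`)
buffer that is multiplied elementwise into the parameters on every forward
pass, plus a `score` buffer in which a pruning routine can record per-weight
saliencies. Buffers travel with `.to()` and `state_dict()` but are not
parameters, so the masks and scores are never updated by the optimizer.
"""
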
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.nn import init
from torch.nn.parameter import Parameter
from torch.nn.modules.utils import _pair


class Linear(nn.Linear):
    def __init__(self, in_features, out_features, bias=True):
        super(Linear, self).__init__(in_features, out_features, bias)
        # Masks and scores live in buffers: they follow the module across
        # devices and checkpoints but are not returned by .parameters().
        self.register_buffer('weight_mask', torch.ones(self.weight.shape))
        self.register_buffer('score', torch.zeros(self.weight.shape))
        if self.bias is not None:
            self.register_buffer('bias_mask', torch.ones(self.bias.shape))

    def forward(self, input):
        # Apply the elementwise masks before the usual affine transform.
        W = self.weight_mask * self.weight
        if self.bias is not None:
            b = self.bias_mask * self.bias
        else:
            b = self.bias
        return F.linear(input, W, b)
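
# ---------------------------------------------------------------------------
# A minimal sketch (not part of the original module) of how a pruning step
# might drive these buffers: score weights by magnitude, then keep roughly
# the top half. The magnitude criterion and the function name are
# illustrative assumptions only.
def _example_prune_linear():
    layer = Linear(4, 3)
    with torch.no_grad():
        layer.score.copy_(layer.weight.abs())        # fill the saliency buffer
        threshold = layer.score.flatten().median()   # keep roughly the top half
        layer.weight_mask.copy_((layer.score > threshold).float())
    return layer(torch.randn(2, 4))                  # masked forward pass
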
class Conv2d(nn.Conv2d):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1,
                 bias=True, padding_mode='zeros'):
        super(Conv2d, self).__init__(
            in_channels, out_channels, kernel_size, stride, padding,
            dilation, groups, bias, padding_mode)
        self.register_buffer('weight_mask', torch.ones(self.weight.shape))
        self.register_buffer('score', torch.zeros(self.weight.shape))
        if self.bias is not None:
            self.register_buffer('bias_mask', torch.ones(self.bias.shape))

    def _conv_forward(self, input, weight, bias):
        # Mirrors nn.Conv2d._conv_forward so that non-zero padding modes
        # still work with the masked weight.
        if self.padding_mode != 'zeros':
            return F.conv2d(F.pad(input, self._padding_repeated_twice, mode=self.padding_mode),
                            weight, bias, self.stride,
                            _pair(0), self.dilation, self.groups)
        return F.conv2d(input, weight, bias, self.stride,
                        self.padding, self.dilation, self.groups)

    def forward(self, input):
        W = self.weight_mask * self.weight
        if self.bias is not None:
            b = self.bias_mask * self.bias
        else:
            b = self.bias
        return self._conv_forward(input, W, b)
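
# ---------------------------------------------------------------------------
# Another illustrative sketch (an assumption, not code from this file):
# structured pruning of a Conv2d by zeroing whole output-filter masks. A real
# criterion for which filters to drop would come from a scoring pass.
def _example_prune_conv_filters():
    conv = Conv2d(3, 8, kernel_size=3, padding=1)
    conv.weight_mask[::2] = 0.0        # zero every other output filter
    conv.bias_mask[::2] = 0.0          # and its bias, so pruned channels emit 0
    y = conv(torch.randn(1, 3, 16, 16))
    assert y.shape == (1, 8, 16, 16)   # shape unchanged; pruned channels are all zeros
    return y
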
class BatchNorm1d(nn.BatchNorm1d):
    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True,
                 track_running_stats=True):
        super(BatchNorm1d, self).__init__(
            num_features, eps, momentum, affine, track_running_stats)
        if self.affine:
            self.register_buffer('weight_mask', torch.ones(self.weight.shape))
            self.register_buffer('bias_mask', torch.ones(self.bias.shape))
            self.register_buffer('score', torch.zeros(self.weight.shape))

    def forward(self, input):
        self._check_input_dim(input)

        # exponential_average_factor is set to self.momentum
        # (when it is available) only so that it gets updated
        # in the ONNX graph when this node is exported to ONNX.
        if self.momentum is None:
            exponential_average_factor = 0.0
        else:
            exponential_average_factor = self.momentum

        if self.training and self.track_running_stats:
            # TODO: if statement only here to tell the jit to skip emitting this when it is None
            if self.num_batches_tracked is not None:
                self.num_batches_tracked = self.num_batches_tracked + 1
                if self.momentum is None:  # use cumulative moving average
                    exponential_average_factor = 1.0 / float(self.num_batches_tracked)
                else:  # use exponential moving average
                    exponential_average_factor = self.momentum

        # Only the affine parameters are masked; running statistics are left
        # untouched so normalization behaves as usual.
        if self.affine:
            W = self.weight_mask * self.weight
            b = self.bias_mask * self.bias
        else:
            W = self.weight
            b = self.bias

        return F.batch_norm(
            input, self.running_mean, self.running_var, W, b,
            self.training or not self.track_running_stats,
            exponential_average_factor, self.eps)


class BatchNorm2d(nn.BatchNorm2d):
    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True,
                 track_running_stats=True):
        super(BatchNorm2d, self).__init__(
            num_features, eps, momentum, affine, track_running_stats)
        if self.affine:
            self.register_buffer('weight_mask', torch.ones(self.weight.shape))
            self.register_buffer('bias_mask', torch.ones(self.bias.shape))
            self.register_buffer('score', torch.zeros(self.weight.shape))

    def forward(self, input):
        # Identical to BatchNorm1d.forward apart from the base class.
        self._check_input_dim(input)

        # exponential_average_factor is set to self.momentum
        # (when it is available) only so that it gets updated
        # in the ONNX graph when this node is exported to ONNX.
        if self.momentum is None:
            exponential_average_factor = 0.0
        else:
            exponential_average_factor = self.momentum

        if self.training and self.track_running_stats:
            # TODO: if statement only here to tell the jit to skip emitting this when it is None
            if self.num_batches_tracked is not None:
                self.num_batches_tracked = self.num_batches_tracked + 1
                if self.momentum is None:  # use cumulative moving average
                    exponential_average_factor = 1.0 / float(self.num_batches_tracked)
                else:  # use exponential moving average
                    exponential_average_factor = self.momentum

        if self.affine:
            W = self.weight_mask * self.weight
            b = self.bias_mask * self.bias
        else:
            W = self.weight
            b = self.bias

        return F.batch_norm(
            input, self.running_mean, self.running_var, W, b,
            self.training or not self.track_running_stats,
            exponential_average_factor, self.eps)
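
# ---------------------------------------------------------------------------
# Sketch (assumption, not from the original file): zeroing a channel's affine
# masks suppresses that channel entirely, since only weight/bias are masked
# while the running statistics keep behaving as usual.
def _example_masked_batchnorm():
    bn = BatchNorm2d(4)
    bn.weight_mask[0] = 0.0            # pruned gamma for channel 0
    bn.bias_mask[0] = 0.0              # pruned beta, so channel 0 maps to exactly 0
    bn.eval()                          # use the (default) running statistics
    y = bn(torch.randn(2, 4, 8, 8))
    assert torch.all(y[:, 0] == 0)     # channel 0 fully silenced
    return y
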
class Identity1d(nn.Module):
    """Learnable elementwise scaling over features, initialized to identity."""

    def __init__(self, num_features):
        super(Identity1d, self).__init__()
        self.num_features = num_features
        self.weight = Parameter(torch.Tensor(num_features))
        self.bias = None
        self.register_buffer('weight_mask', torch.ones(self.weight.shape))
        self.reset_parameters()
        self.register_buffer('score', torch.zeros(self.weight.shape))

    def reset_parameters(self):
        init.ones_(self.weight)

    def forward(self, input):
        W = self.weight_mask * self.weight
        return input * W


class Identity2d(nn.Module):
    """Learnable per-channel scaling for NCHW inputs, initialized to identity."""

    def __init__(self, num_features):
        super(Identity2d, self).__init__()
        self.num_features = num_features
        # Trailing singleton dims let the weight broadcast over H and W.
        self.weight = Parameter(torch.Tensor(num_features, 1, 1))
        self.bias = None
        self.register_buffer('weight_mask', torch.ones(self.weight.shape))
        self.reset_parameters()
        self.register_buffer('score', torch.zeros(self.weight.shape))

    def reset_parameters(self):
        init.ones_(self.weight)

    def forward(self, input):
        W = self.weight_mask * self.weight
        return input * W
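
# ---------------------------------------------------------------------------
# Sketch (assumption): the Identity layers act as prunable pass-through gates,
# e.g. on skip connections. Masking an entry silences that channel while the
# rest of the signal flows through unchanged.
def _example_identity_gate():
    gate = Identity2d(3)
    gate.weight_mask[1] = 0.0          # silence channel 1
    y = gate(torch.randn(2, 3, 5, 5))
    assert torch.all(y[:, 1] == 0)
    return y


# A last sketch, loosely in the spirit of gradient-based saliency scores
# (e.g. |weight * grad|, as in SynFlow-style pruning); the scoring rule here
# is an assumption, not something this file defines. It fills every masked
# layer's `score` buffer from one forward/backward pass, e.g.
# _example_gradient_scores(nn.Sequential(Linear(8, 4)), torch.ones(2, 8)).
def _example_gradient_scores(model, x):
    model(x).sum().backward()
    with torch.no_grad():
        for m in model.modules():
            if hasattr(m, 'score') and getattr(m, 'weight', None) is not None \
                    and m.weight.grad is not None:
                m.score.copy_((m.weight * m.weight.grad).abs())
                m.weight.grad = None   # clear for the next scoring pass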