"""Layer variants that support pruning through per-parameter masks.

Each layer registers a ``weight_mask`` buffer (and a ``bias_mask`` where the
layer has a bias) initialized to ones, plus a ``score`` buffer that pruning
algorithms can use to rank parameters. The forward pass multiplies each
parameter by its mask, so entries whose mask is zero are effectively pruned.
"""

import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.nn import init
from torch.nn.parameter import Parameter
from torch.nn.modules.utils import _pair


class Linear(nn.Linear):

    def __init__(self, in_features, out_features, bias=True):
        super(Linear, self).__init__(in_features, out_features, bias)
        self.register_buffer('weight_mask', torch.ones(self.weight.shape))
        self.register_buffer('score', torch.zeros(self.weight.shape))
        if self.bias is not None:
            self.register_buffer('bias_mask', torch.ones(self.bias.shape))

    def forward(self, input):
        W = self.weight_mask * self.weight
        if self.bias is not None:
            b = self.bias_mask * self.bias
        else:
            b = self.bias
        return F.linear(input, W, b)


class Conv2d(nn.Conv2d):

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1,
                 bias=True, padding_mode='zeros'):
        super(Conv2d, self).__init__(
            in_channels, out_channels, kernel_size, stride, padding,
            dilation, groups, bias, padding_mode)
        self.register_buffer('weight_mask', torch.ones(self.weight.shape))
        self.register_buffer('score', torch.zeros(self.weight.shape))
        if self.bias is not None:
            self.register_buffer('bias_mask', torch.ones(self.bias.shape))

    def _conv_forward(self, input, weight, bias):
        # Replicates nn.Conv2d's handling of non-'zeros' padding modes so the
        # masked weight is used in both branches.
        if self.padding_mode != 'zeros':
            return F.conv2d(F.pad(input, self._padding_repeated_twice,
                                  mode=self.padding_mode),
                            weight, bias, self.stride,
                            _pair(0), self.dilation, self.groups)
        return F.conv2d(input, weight, bias, self.stride,
                        self.padding, self.dilation, self.groups)

    def forward(self, input):
        W = self.weight_mask * self.weight
        if self.bias is not None:
            b = self.bias_mask * self.bias
        else:
            b = self.bias
        return self._conv_forward(input, W, b)
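
# Illustrative sketch, not part of the original module: pruning a masked layer
# amounts to zeroing entries of `weight_mask`. The helper name `prune_smallest`
# and the magnitude criterion are assumptions chosen for demonstration; any
# scoring rule could populate the `score` buffer instead.
def prune_smallest(layer, sparsity):
    """Zero the mask for the `sparsity` fraction of smallest-magnitude weights."""
    k = int(sparsity * layer.weight.numel())
    if k == 0:
        return
    # Use weight magnitude as the score, then keep only entries above the
    # k-th smallest value.
    layer.score.copy_(layer.weight.detach().abs())
    threshold = layer.score.flatten().kthvalue(k).values
    layer.weight_mask.copy_((layer.score > threshold).float())
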
class BatchNorm1d(nn.BatchNorm1d):

    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True,
                 track_running_stats=True):
        super(BatchNorm1d, self).__init__(
            num_features, eps, momentum, affine, track_running_stats)
        if self.affine:
            self.register_buffer('weight_mask', torch.ones(self.weight.shape))
            self.register_buffer('bias_mask', torch.ones(self.bias.shape))
            self.register_buffer('score', torch.zeros(self.weight.shape))

    def forward(self, input):
        self._check_input_dim(input)

        # exponential_average_factor is set to self.momentum
        # (when it is available) only so that it gets updated
        # in ONNX graph when this node is exported to ONNX.
        if self.momentum is None:
            exponential_average_factor = 0.0
        else:
            exponential_average_factor = self.momentum

        if self.training and self.track_running_stats:
            # TODO: if statement only here to tell the jit to skip emitting
            # this when it is None
            if self.num_batches_tracked is not None:
                self.num_batches_tracked = self.num_batches_tracked + 1
                if self.momentum is None:  # use cumulative moving average
                    exponential_average_factor = 1.0 / float(self.num_batches_tracked)
                else:  # use exponential moving average
                    exponential_average_factor = self.momentum

        if self.affine:
            W = self.weight_mask * self.weight
            b = self.bias_mask * self.bias
        else:
            W = self.weight
            b = self.bias

        return F.batch_norm(
            input, self.running_mean, self.running_var, W, b,
            self.training or not self.track_running_stats,
            exponential_average_factor, self.eps)


class BatchNorm2d(nn.BatchNorm2d):

    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True,
                 track_running_stats=True):
        super(BatchNorm2d, self).__init__(
            num_features, eps, momentum, affine, track_running_stats)
        if self.affine:
            self.register_buffer('weight_mask', torch.ones(self.weight.shape))
            self.register_buffer('bias_mask', torch.ones(self.bias.shape))
            self.register_buffer('score', torch.zeros(self.weight.shape))

    def forward(self, input):
        self._check_input_dim(input)

        # exponential_average_factor is set to self.momentum
        # (when it is available) only so that it gets updated
        # in ONNX graph when this node is exported to ONNX.
        if self.momentum is None:
            exponential_average_factor = 0.0
        else:
            exponential_average_factor = self.momentum

        if self.training and self.track_running_stats:
            # TODO: if statement only here to tell the jit to skip emitting
            # this when it is None
            if self.num_batches_tracked is not None:
                self.num_batches_tracked = self.num_batches_tracked + 1
                if self.momentum is None:  # use cumulative moving average
                    exponential_average_factor = 1.0 / float(self.num_batches_tracked)
                else:  # use exponential moving average
                    exponential_average_factor = self.momentum

        if self.affine:
            W = self.weight_mask * self.weight
            b = self.bias_mask * self.bias
        else:
            W = self.weight
            b = self.bias

        return F.batch_norm(
            input, self.running_mean, self.running_var, W, b,
            self.training or not self.track_running_stats,
            exponential_average_factor, self.eps)


class Identity1d(nn.Module):
    """Learnable elementwise scaling (initialized to the identity) with a prunable mask."""

    def __init__(self, num_features):
        super(Identity1d, self).__init__()
        self.num_features = num_features
        self.weight = Parameter(torch.Tensor(num_features))
        self.bias = None
        self.register_buffer('weight_mask', torch.ones(self.weight.shape))
        self.register_buffer('score', torch.zeros(self.weight.shape))
        self.reset_parameters()

    def reset_parameters(self):
        init.ones_(self.weight)

    def forward(self, input):
        W = self.weight_mask * self.weight
        return input * W


class Identity2d(nn.Module):
    """Learnable per-channel scaling (initialized to the identity) with a prunable mask."""

    def __init__(self, num_features):
        super(Identity2d, self).__init__()
        self.num_features = num_features
        self.weight = Parameter(torch.Tensor(num_features, 1, 1))
        self.bias = None
        self.register_buffer('weight_mask', torch.ones(self.weight.shape))
        self.register_buffer('score', torch.zeros(self.weight.shape))
        self.reset_parameters()

    def reset_parameters(self):
        init.ones_(self.weight)

    def forward(self, input):
        W = self.weight_mask * self.weight
        return input * W
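

# Minimal usage sketch, an illustrative assumption rather than part of the
# original module: builds a small stack of the masked layers, runs a forward
# pass, then zeroes the linear layer's masks to confirm that masked-out
# parameters contribute nothing to the output.
if __name__ == '__main__':
    torch.manual_seed(0)
    x = torch.randn(4, 3, 8, 8)

    conv = Conv2d(3, 6, kernel_size=3, padding=1)
    bn = BatchNorm2d(6)
    scale = Identity2d(6)
    fc = Linear(6 * 8 * 8, 10)

    out = fc(scale(bn(conv(x))).flatten(1))
    print('output shape:', out.shape)  # torch.Size([4, 10])

    # Zero every mask entry: the linear layer is now fully pruned.
    fc.weight_mask.zero_()
    if fc.bias is not None:
        fc.bias_mask.zero_()
    assert fc(torch.randn(2, 6 * 8 * 8)).abs().sum() == 0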