125 lines
3.8 KiB
Python
125 lines
3.8 KiB
Python
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 #
#####################################################
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
|
|
import math
|
|
from typing import Optional, Callable
|
|
|
|
from xautodl import spaces
|
|
from .super_module import SuperModule
|
|
from .super_module import IntSpaceType
|
|
from .super_module import BoolSpaceType
|
|
|
|
|
|
class SuperReLU(SuperModule):
    """Applies the rectified linear unit function element-wise.

    Args:
        inplace: forwarded to ``F.relu``; when True the input tensor is
            modified in place instead of allocating a new output.
    """

    def __init__(self, inplace: bool = False) -> None:
        super(SuperReLU, self).__init__()
        self._inplace = inplace

    @property
    def abstract_search_space(self):
        # This module declares no searchable hyper-parameters; it returns a
        # VirtualNode keyed by this module's id (presumably a placeholder in
        # the search-space tree -- confirm against xautodl.spaces).
        return spaces.VirtualNode(id(self))

    def forward_candidate(self, input: torch.Tensor) -> torch.Tensor:
        # Nothing is sampled for this module, so a candidate forward is
        # identical to the raw forward.
        return self.forward_raw(input)

    def forward_raw(self, input: torch.Tensor) -> torch.Tensor:
        """Return ``relu(input)``, in place if ``inplace`` was set."""
        return F.relu(input, inplace=self._inplace)

    def forward_with_container(self, input, container, prefix=None):
        # ``container`` and ``prefix`` are unused: this module holds no
        # parameters of its own. ``prefix`` defaults to None instead of a
        # shared mutable ``[]``.
        return self.forward_raw(input)

    def extra_repr(self) -> str:
        return "inplace=True" if self._inplace else ""
|
|
|
|
|
|
class SuperGELU(SuperModule):
    """Applies the Gaussian Error Linear Units function element-wise.

    Thin wrapper around ``F.gelu`` with no configurable options.
    """

    def __init__(self) -> None:
        super(SuperGELU, self).__init__()

    @property
    def abstract_search_space(self):
        # This module declares no searchable hyper-parameters; it returns a
        # VirtualNode keyed by this module's id (presumably a placeholder in
        # the search-space tree -- confirm against xautodl.spaces).
        return spaces.VirtualNode(id(self))

    def forward_candidate(self, input: torch.Tensor) -> torch.Tensor:
        # Nothing is sampled for this module, so a candidate forward is
        # identical to the raw forward.
        return self.forward_raw(input)

    def forward_raw(self, input: torch.Tensor) -> torch.Tensor:
        """Return ``gelu(input)`` using PyTorch's default formulation."""
        return F.gelu(input)

    def forward_with_container(self, input, container, prefix=None):
        # ``container`` and ``prefix`` are unused: this module holds no
        # parameters of its own. ``prefix`` defaults to None instead of a
        # shared mutable ``[]``.
        return self.forward_raw(input)
|
|
|
|
|
|
class SuperSigmoid(SuperModule):
    """Applies the Sigmoid function element-wise.

    Thin wrapper around ``torch.sigmoid`` with no configurable options.
    """

    def __init__(self) -> None:
        super(SuperSigmoid, self).__init__()

    @property
    def abstract_search_space(self):
        # This module declares no searchable hyper-parameters; it returns a
        # VirtualNode keyed by this module's id (presumably a placeholder in
        # the search-space tree -- confirm against xautodl.spaces).
        return spaces.VirtualNode(id(self))

    def forward_candidate(self, input: torch.Tensor) -> torch.Tensor:
        # Nothing is sampled for this module, so a candidate forward is
        # identical to the raw forward.
        return self.forward_raw(input)

    def forward_raw(self, input: torch.Tensor) -> torch.Tensor:
        """Return ``sigmoid(input)``."""
        return torch.sigmoid(input)

    def forward_with_container(self, input, container, prefix=None):
        # ``container`` and ``prefix`` are unused: this module holds no
        # parameters of its own. ``prefix`` defaults to None instead of a
        # shared mutable ``[]``.
        return self.forward_raw(input)
|
|
|
|
|
|
class SuperLeakyReLU(SuperModule):
    """Applies the leaky rectified linear unit function element-wise.

    Mirrors ``torch.nn.LeakyReLU``, see
    https://pytorch.org/docs/stable/_modules/torch/nn/modules/activation.html#LeakyReLU

    Args:
        negative_slope: slope applied to negative inputs (default 1e-2).
        inplace: forwarded to ``F.leaky_relu``; when True the input tensor is
            modified in place instead of allocating a new output.
    """

    def __init__(self, negative_slope: float = 1e-2, inplace: bool = False) -> None:
        super(SuperLeakyReLU, self).__init__()
        self._negative_slope = negative_slope
        self._inplace = inplace

    @property
    def abstract_search_space(self):
        # This module declares no searchable hyper-parameters; it returns a
        # VirtualNode keyed by this module's id (presumably a placeholder in
        # the search-space tree -- confirm against xautodl.spaces).
        return spaces.VirtualNode(id(self))

    def forward_candidate(self, input: torch.Tensor) -> torch.Tensor:
        # Nothing is sampled for this module, so a candidate forward is
        # identical to the raw forward.
        return self.forward_raw(input)

    def forward_raw(self, input: torch.Tensor) -> torch.Tensor:
        """Return ``leaky_relu(input)`` with the configured slope."""
        return F.leaky_relu(
            input, negative_slope=self._negative_slope, inplace=self._inplace
        )

    def forward_with_container(self, input, container, prefix=None):
        # ``container`` and ``prefix`` are unused: this module holds no
        # parameters of its own. ``prefix`` defaults to None instead of a
        # shared mutable ``[]``.
        return self.forward_raw(input)

    def extra_repr(self) -> str:
        # Fields are separated by ", " (as in torch.nn.LeakyReLU) so the repr
        # reads "negative_slope=0.01, inplace=True" rather than gluing the
        # two together.
        inplace_str = ", inplace=True" if self._inplace else ""
        return "negative_slope={}{}".format(self._negative_slope, inplace_str)
|
|
|
|
|
|
class SuperTanh(SuperModule):
    """Applies the Tanh function element-wise.

    Thin wrapper around ``torch.tanh`` with no configurable options.
    """

    def __init__(self) -> None:
        super(SuperTanh, self).__init__()

    @property
    def abstract_search_space(self):
        # This module declares no searchable hyper-parameters; it returns a
        # VirtualNode keyed by this module's id (presumably a placeholder in
        # the search-space tree -- confirm against xautodl.spaces).
        return spaces.VirtualNode(id(self))

    def forward_candidate(self, input: torch.Tensor) -> torch.Tensor:
        # Nothing is sampled for this module, so a candidate forward is
        # identical to the raw forward.
        return self.forward_raw(input)

    def forward_raw(self, input: torch.Tensor) -> torch.Tensor:
        """Return ``tanh(input)``."""
        return torch.tanh(input)

    def forward_with_container(self, input, container, prefix=None):
        # ``container`` and ``prefix`` are unused: this module holds no
        # parameters of its own. ``prefix`` defaults to None instead of a
        # shared mutable ``[]``.
        return self.forward_raw(input)
|