#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 #
#####################################################
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from typing import Optional, Callable, Tuple

from xautodl import spaces

from .super_module import SuperModule
from .super_module import IntSpaceType
from .super_module import BoolSpaceType


class SuperDropout(SuperModule):
    """Applies the dropout function element-wise."""

    def __init__(self, p: float = 0.5, inplace: bool = False) -> None:
        super(SuperDropout, self).__init__()
        self._p = p
        self._inplace = inplace

    @property
    def abstract_search_space(self):
        return spaces.VirtualNode(id(self))

    def forward_candidate(self, input: torch.Tensor) -> torch.Tensor:
        return self.forward_raw(input)

    def forward_raw(self, input: torch.Tensor) -> torch.Tensor:
        return F.dropout(input, self._p, self.training, self._inplace)

    def forward_with_container(self, input, container, prefix=[]):
        return self.forward_raw(input)

    def extra_repr(self) -> str:
        if self._inplace:
            return "p={:}, inplace=True".format(self._p)
        return "p={:}".format(self._p)


class SuperDrop(SuperModule):
    """Applies the drop-path function element-wise."""

    def __init__(self, p: float, dims: Tuple[int, ...], recover: bool = True) -> None:
        super(SuperDrop, self).__init__()
        self._p = p
        self._dims = dims
        self._recover = recover

    @property
    def abstract_search_space(self):
        return spaces.VirtualNode(id(self))

    def forward_candidate(self, input: torch.Tensor) -> torch.Tensor:
        return self.forward_raw(input)

    def forward_raw(self, input: torch.Tensor) -> torch.Tensor:
        if not self.training or self._p <= 0:
            return input
        keep_prob = 1 - self._p
        # Build the mask shape: keep the batch dimension as-is; for every other
        # dimension, use the input size when the matching entry of self._dims
        # is -1, otherwise use that entry (typically 1, so the mask broadcasts).
        shape = [input.shape[0]] + [
            x if y == -1 else y for x, y in zip(input.shape[1:], self._dims)
        ]
        random_tensor = keep_prob + torch.rand(
            shape, dtype=input.dtype, device=input.device
        )
        random_tensor.floor_()  # binarize: 1 with probability keep_prob, else 0
        if self._recover:
            # Rescale by 1/keep_prob so the expected activation magnitude is unchanged.
            return input.div(keep_prob) * random_tensor
        else:
            return input * random_tensor  # only apply the binary masks

    def forward_with_container(self, input, container, prefix=[]):
        return self.forward_raw(input)

    def extra_repr(self) -> str:
        return (
            "p={:}".format(self._p)
            + ", dims={:}".format(self._dims)
            + ", recover={:}".format(self._recover)
        )
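

if __name__ == "__main__":
    # ------------------------------------------------------------------
    # Minimal usage sketch (illustrative only; not part of the original
    # module -- the tensor sizes and hyper-parameters are assumptions).
    # With dims=(1, 1, 1) and a 4-D input of shape (N, C, H, W), SuperDrop
    # samples one Bernoulli mask per example and broadcasts it over
    # C, H, W, i.e. standard per-sample drop-path. forward_raw is called
    # directly here to stay independent of how SuperModule dispatches.
    # ------------------------------------------------------------------
    drop_path = SuperDrop(p=0.2, dims=(1, 1, 1), recover=True)
    dropout = SuperDropout(p=0.5)
    drop_path.train()
    dropout.train()

    x = torch.randn(8, 16, 4, 4)
    y = drop_path.forward_raw(x)       # same shape as x; whole samples zeroed
    z = dropout.forward_raw(torch.randn(8, 16))  # element-wise dropout
    print(drop_path, y.shape)
    print(dropout, z.shape)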