# autodl-projects/xautodl/xlayers/super_positional_embedding.py
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.02 #
#####################################################
import torch
import torch.nn as nn
import math
from xautodl import spaces
from .super_module import SuperModule
from .super_module import IntSpaceType


class SuperDynamicPositionE(SuperModule):
    """Applies a sinusoidal positional encoding to the given input positions."""

    def __init__(self, dimension: int, scale: float = 1.0) -> None:
        super(SuperDynamicPositionE, self).__init__()
        self._scale = scale
        self._dimension = dimension
        # fixed (non-learnable) frequency terms, registered as a buffer
        self.register_buffer(
            "_div_term",
            torch.exp(
                torch.arange(0, dimension, 2).float() * (-math.log(10000.0) / dimension)
            ),
        )

    @property
    def abstract_search_space(self):
        root_node = spaces.VirtualNode(id(self))
        return root_node

    def forward_candidate(self, input: torch.Tensor) -> torch.Tensor:
        return self.forward_raw(input)

    def forward_raw(self, input: torch.Tensor) -> torch.Tensor:
        positions = torch.unsqueeze(input * self._scale, dim=-1)
        divisions = torch.reshape(
            self._div_term, [1] * input.ndim + [self._div_term.numel()]
        )
        values = positions / divisions
        embeds = torch.cat((torch.sin(values), torch.cos(values)), dim=-1)
        return embeds

    def extra_repr(self) -> str:
        return "scale={:}, dim={:}".format(self._scale, self._dimension)


class SuperPositionalEncoder(SuperModule):
    """Attention Is All You Need: https://arxiv.org/pdf/1706.03762.pdf
    https://github.com/pytorch/examples/blob/master/word_language_model/model.py#L65
    """

    def __init__(self, d_model: IntSpaceType, max_seq_len: int, dropout: float = 0.1):
        super(SuperPositionalEncoder, self).__init__()
        self._d_model = d_model
        self.dropout = nn.Dropout(p=dropout)
        # create the constant 'pe' matrix, whose values depend on pos and i
        self.register_buffer("pe", self.create_pos_embed(max_seq_len, self.d_model))

    @property
    def d_model(self):
        return spaces.get_max(self._d_model)

    @property
    def abstract_search_space(self):
        root_node = spaces.VirtualNode(id(self))
        if not spaces.is_determined(self._d_model):
            root_node.append("_d_model", self._d_model.abstract(reuse_last=True))
        return root_node

    def create_pos_embed(self, max_seq_len, d_model):
        pe = torch.zeros(max_seq_len, d_model)
        for pos in range(max_seq_len):
            for i in range(0, d_model):
                div = 10000 ** ((i // 2) * 2 / d_model)
                value = pos / div
                if i % 2 == 0:
                    pe[pos, i] = math.sin(value)
                else:
                    pe[pos, i] = math.cos(value)
        return pe.unsqueeze(0)
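
    # Editorial sketch (not in the original file): for an even d_model, the loop above
    # is equivalent to the usual vectorized construction
    #   pe[pos, 2k]   = sin(pos / 10000^(2k / d_model))
    #   pe[pos, 2k+1] = cos(pos / 10000^(2k / d_model))
    # e.g.
    #   position = torch.arange(max_seq_len).float().unsqueeze(1)
    #   div_term = torch.pow(10000.0, torch.arange(0, d_model, 2).float() / d_model)
    #   pe[:, 0::2] = torch.sin(position / div_term)
    #   pe[:, 1::2] = torch.cos(position / div_term)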

    def forward_candidate(self, input: torch.Tensor) -> torch.Tensor:
        batch, seq, fdim = input.shape[:3]
        embeddings = self.pe[:, :seq]
        if not spaces.is_determined(self._d_model):
            expected_d_model = self.abstract_child["_d_model"].value
        else:
            expected_d_model = spaces.get_determined_value(self._d_model)
        assert fdim == expected_d_model, "{:} vs {:}".format(fdim, expected_d_model)

        embeddings = torch.nn.functional.interpolate(
            embeddings, size=(expected_d_model), mode="linear", align_corners=True
        )
        outs = self.dropout(input + embeddings)
        return outs
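
    # Editorial note: F.interpolate with mode="linear" operates on 3-D tensors shaped
    # [N, C, L], so the [1, seq, d_model] table above is resampled along its last
    # dimension down to the sampled d_model before being added to the input.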

    def forward_raw(self, input: torch.Tensor) -> torch.Tensor:
        batch, seq, fdim = input.shape[:3]
        embeddings = self.pe[:, :seq]
        outs = self.dropout(input + embeddings)
        return outs
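

if __name__ == "__main__":
    # Minimal smoke test (editorial sketch, not in the original file); it exercises only
    # the raw forward paths and assumes the xautodl package is importable and that a
    # plain Python int is accepted as an IntSpaceType for d_model.
    encoder = SuperPositionalEncoder(d_model=32, max_seq_len=16, dropout=0.0)
    tokens = torch.randn(2, 10, 32)
    print(encoder.forward_raw(tokens).shape)  # torch.Size([2, 10, 32])

    dyn = SuperDynamicPositionE(dimension=8)
    positions = torch.arange(10).float().view(2, 5)
    print(dyn.forward_raw(positions).shape)  # torch.Size([2, 5, 8])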