Temp / 0.5
@@ -1,11 +1,15 @@
 #####################################################
 # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.02 #
 #####################################################
 import torch
 import torch.nn as nn
 import math
 
 class PositionalEncoder(nn.Module):
   # Attention Is All You Need: https://arxiv.org/pdf/1706.03762.pdf
   # https://github.com/pytorch/examples/blob/master/word_language_model/model.py#L65
 
-  def __init__(self, d_model, max_seq_len):
+  def __init__(self, d_model, max_seq_len, dropout=0.1):
     super(PositionalEncoder, self).__init__()
     self.d_model = d_model
     # create a constant 'pe' matrix with values dependent on
@@ -26,4 +30,5 @@ class PositionalEncoder(nn.Module):
   def forward(self, x):
     batch, seq, fdim = x.shape[:3]
     embeddings = self.pe[:, :seq, :fdim]
-    return x + embeddings
+    outs = self.dropout(x + embeddings)
+    return outs
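For context, a minimal sketch of what the full class plausibly looks like after this change. The diff elides the construction of the 'pe' buffer between the two hunks and the creation of the dropout layer; the version below assumes the standard sinusoidal encoding from the cited paper and a self.dropout = nn.Dropout(p=dropout) layer in __init__ (both are assumptions, not shown in the diff):

import math
import torch
import torch.nn as nn

class PositionalEncoder(nn.Module):
  # Attention Is All You Need: https://arxiv.org/pdf/1706.03762.pdf
  def __init__(self, d_model, max_seq_len, dropout=0.1):
    super(PositionalEncoder, self).__init__()
    self.d_model = d_model
    self.dropout = nn.Dropout(p=dropout)  # assumed: pairs with the new dropout argument
    # Assumed standard sinusoidal construction (elided by the diff above):
    #   pe[pos, 2i]   = sin(pos / 10000^(2i / d_model))
    #   pe[pos, 2i+1] = cos(pos / 10000^(2i / d_model))
    pe = torch.zeros(max_seq_len, d_model)
    position = torch.arange(0, max_seq_len, dtype=torch.float).unsqueeze(1)
    div_term = torch.exp(
        torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
    )
    pe[:, 0::2] = torch.sin(position * div_term)
    pe[:, 1::2] = torch.cos(position * div_term)
    # register as a buffer so it follows .to(device) but is not a trainable parameter
    self.register_buffer("pe", pe.unsqueeze(0))  # shape: (1, max_seq_len, d_model)

  def forward(self, x):
    batch, seq, fdim = x.shape[:3]
    embeddings = self.pe[:, :seq, :fdim]
    outs = self.dropout(x + embeddings)
    return outs

Quick usage check: enc = PositionalEncoder(d_model=512, max_seq_len=100) followed by y = enc(torch.randn(2, 10, 512)) yields a (2, 10, 512) tensor with positional offsets added and dropout applied in training mode.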