from typing import Optional

import torch.nn as nn


class MLP(nn.Module):
    # MLP: FC -> Activation -> Drop -> FC -> Drop

    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        act_layer=nn.GELU,
        drop: Optional[float] = None,
    ):
        super().__init__()
        # Hidden and output widths default to the input width when not given.
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        # `drop or 0.0` treats drop=None as "no dropout".
        self.drop = nn.Dropout(drop or 0.0)
    def forward(self, x):
        # (..., in_features) -> (..., out_features); the same dropout
        # module is applied after the activation and after fc2.
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x
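

# Minimal usage sketch. The feature sizes, dropout rate, and input shape below
# are illustrative assumptions, not values prescribed by the module above.
if __name__ == "__main__":
    import torch

    mlp = MLP(in_features=64, hidden_features=256, drop=0.1)
    x = torch.randn(8, 16, 64)  # (batch, tokens, features); nn.Linear maps the last dim
    y = mlp(x)
    print(y.shape)  # torch.Size([8, 16, 64]) -- out_features defaulted to in_features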