from typing import Optional

import torch.nn as nn


class MLP(nn.Module):
    """Two-layer perceptron: Linear -> activation -> Dropout -> Linear -> Dropout.

    Args:
        in_features: Size of the last input dimension fed to the first Linear.
        hidden_features: Width of the hidden layer. Any falsy value (``None``
            or ``0``) falls back to ``in_features``.
        out_features: Size of the output. Any falsy value (``None`` or ``0``)
            falls back to ``in_features``.
        act_layer: Zero-argument callable returning the activation module;
            defaults to ``nn.GELU``.
        drop: Dropout probability. A falsy value (``None`` or ``0``) means no
            dropout (probability 0).
    """

    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        act_layer=nn.GELU,
        drop: Optional[float] = None,
    ):
        super().__init__()
        # Truthiness-based fallback: 0 is treated the same as "not given".
        hidden_dim = hidden_features if hidden_features else in_features
        out_dim = out_features if out_features else in_features
        # Registration order (fc1, act, fc2, drop) fixes both the state_dict
        # key order and the order in which parameters are randomly initialized.
        self.fc1 = nn.Linear(in_features, hidden_dim)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_dim, out_dim)
        # One Dropout instance, reused after the activation and after fc2.
        self.drop = nn.Dropout(drop if drop else 0)

    def forward(self, x):
        """Apply fc1, activation, dropout, fc2 and dropout to ``x`` in sequence."""
        pipeline = (self.fc1, self.act, self.drop, self.fc2, self.drop)
        for stage in pipeline:
            x = stage(x)
        return x