2021-04-29 11:17:44 +02:00
|
|
|
#######################################################
|
2021-04-29 08:28:37 +02:00
|
|
|
# Use module in xlayers to construct different models #
|
|
|
|
#######################################################
|
|
|
|
from typing import List, Text, Dict, Any
|
2021-04-29 10:30:47 +02:00
|
|
|
import torch
|
2021-04-29 08:28:37 +02:00
|
|
|
|
|
|
|
# Names exported by `from <this module> import *`: only the model factory.
__all__ = ["get_model"]
|
|
|
|
|
|
|
|
|
2021-05-19 09:19:20 +02:00
|
|
|
from xautodl.xlayers.super_core import SuperSequential
|
|
|
|
from xautodl.xlayers.super_core import SuperLinear
|
|
|
|
from xautodl.xlayers.super_core import SuperDropout
|
|
|
|
from xautodl.xlayers.super_core import super_name2norm
|
|
|
|
from xautodl.xlayers.super_core import super_name2activation
|
2021-04-29 08:28:37 +02:00
|
|
|
|
|
|
|
|
|
|
|
def _build_simple_mlp(kwargs: Dict[Text, Any]):
    """Build the two-hidden-layer MLP preceded by an input normalization layer."""
    act_cls = super_name2activation[kwargs["act_cls"]]
    norm_cls = super_name2norm[kwargs["norm_cls"]]
    mean, std = kwargs.get("mean", None), kwargs.get("std", None)
    if "hidden_dim" in kwargs:
        # A single `hidden_dim` sets the width of both hidden layers.
        hidden_dim1 = hidden_dim2 = kwargs["hidden_dim"]
    else:
        hidden_dim1 = kwargs.get("hidden_dim1", 200)
        hidden_dim2 = kwargs.get("hidden_dim2", 100)
    return SuperSequential(
        norm_cls(mean=mean, std=std),
        SuperLinear(kwargs["input_dim"], hidden_dim1),
        act_cls(),
        SuperLinear(hidden_dim1, hidden_dim2),
        act_cls(),
        SuperLinear(hidden_dim2, kwargs["output_dim"]),
    )


def _build_norm_mlp(kwargs: Dict[Text, Any]):
    """Build an MLP with (linear -> norm -> activation) per hidden width."""
    act_cls = super_name2activation[kwargs["act_cls"]]
    norm_cls = super_name2norm[kwargs["norm_cls"]]
    sub_layers, last_dim = [], kwargs["input_dim"]
    for hidden_dim in kwargs["hidden_dims"]:
        sub_layers.append(SuperLinear(last_dim, hidden_dim))
        # The norm layer is skipped for widths <= 1.
        # NOTE(review): presumably because normalizing over a single feature
        # is degenerate -- confirm against the norm layer's definition.
        if hidden_dim > 1:
            sub_layers.append(norm_cls(hidden_dim, elementwise_affine=False))
        sub_layers.append(act_cls())
        last_dim = hidden_dim
    sub_layers.append(SuperLinear(last_dim, kwargs["output_dim"]))
    return SuperSequential(*sub_layers)


def _build_dual_norm_mlp(kwargs: Dict[Text, Any]):
    """Build an MLP whose stages are (norm) -> linear -> dropout -> linear -> act."""
    act_cls = super_name2activation[kwargs["act_cls"]]
    norm_cls = super_name2norm[kwargs["norm_cls"]]
    sub_layers, last_dim = [], kwargs["input_dim"]
    for i, hidden_dim in enumerate(kwargs["hidden_dims"]):
        # Every stage except the first normalizes its input, i.e. the output
        # of the previous stage; the raw model input is not normalized.
        if i > 0:
            sub_layers.append(norm_cls(last_dim, elementwise_affine=False))
        sub_layers.append(SuperLinear(last_dim, hidden_dim))
        sub_layers.append(SuperDropout(kwargs["dropout"]))
        sub_layers.append(SuperLinear(hidden_dim, hidden_dim))
        sub_layers.append(act_cls())
        last_dim = hidden_dim
    sub_layers.append(SuperLinear(last_dim, kwargs["output_dim"]))
    return SuperSequential(*sub_layers)


def get_model(config: Dict[Text, Any], **kwargs):
    """Construct a model from the xlayers building blocks.

    Only the model type is read from ``config``; every other hyper-parameter
    is read from ``kwargs``.

    Args:
        config: mapping with an optional ``"model_type"`` entry
            (case-insensitive, defaults to ``"simple_mlp"``). Recognized
            values are ``"simple_mlp"``, ``"norm_mlp"``, ``"dual_norm_mlp"``
            and ``"quant_transformer"`` (the last one is not implemented).
        **kwargs: hyper-parameters of the selected model.

            * all implemented types: ``act_cls`` and ``norm_cls`` (keys into
              ``super_name2activation`` / ``super_name2norm``),
              ``input_dim`` and ``output_dim``.
            * ``simple_mlp``: optional ``mean`` and ``std`` (default
              ``None``) forwarded to the normalization layer; either
              ``hidden_dim`` (used for both hidden layers) or
              ``hidden_dim1`` / ``hidden_dim2`` (defaults 200 / 100).
            * ``norm_mlp``: ``hidden_dims``, an iterable of hidden widths.
            * ``dual_norm_mlp``: ``hidden_dims`` and ``dropout``.

    Returns:
        The ``SuperSequential`` instance assembled for the requested type.

    Raises:
        KeyError: if a required entry is missing from ``kwargs`` or the
            activation / normalization name is not registered.
        NotImplementedError: if the model type is ``"quant_transformer"``.
        TypeError: if the model type is not recognized.
    """
    model_type = config.get("model_type", "simple_mlp").lower()
    if model_type == "simple_mlp":
        model = _build_simple_mlp(kwargs)
    elif model_type == "norm_mlp":
        model = _build_norm_mlp(kwargs)
    elif model_type == "dual_norm_mlp":
        model = _build_dual_norm_mlp(kwargs)
    elif model_type == "quant_transformer":
        raise NotImplementedError(
            "The quant_transformer model type is not implemented"
        )
    else:
        # TypeError (not ValueError) is kept so existing callers that catch
        # it keep working.
        raise TypeError("Unknown model type: {:}".format(model_type))
    return model
|