#######################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 #
#######################################################
# Use module in xlayers to construct different models #
#######################################################
from typing import List, Text, Dict, Any
import torch

__all__ = ["get_model"]

from xlayers.super_core import SuperSequential
from xlayers.super_core import SuperLinear
from xlayers.super_core import super_name2norm
from xlayers.super_core import super_name2activation


def _build_simple_mlp(**kwargs):
    """Build ``norm -> Linear -> act -> Linear -> act -> Linear``.

    Required kwargs: ``act_cls``, ``norm_cls`` (names looked up in
    ``super_name2activation`` / ``super_name2norm``), ``input_dim`` and
    ``output_dim``.
    Optional kwargs: ``mean`` / ``std`` (forwarded to the norm layer, default
    ``None``); either ``hidden_dim`` (used for both hidden layers) or
    ``hidden_dim1`` / ``hidden_dim2`` (defaults 200 / 100).
    """
    act_cls = super_name2activation[kwargs["act_cls"]]
    norm_cls = super_name2norm[kwargs["norm_cls"]]
    mean, std = kwargs.get("mean", None), kwargs.get("std", None)
    # A single ``hidden_dim`` key, when present, overrides both per-layer sizes.
    if "hidden_dim" in kwargs:
        hidden_dim1 = kwargs.get("hidden_dim")
        hidden_dim2 = kwargs.get("hidden_dim")
    else:
        hidden_dim1 = kwargs.get("hidden_dim1", 200)
        hidden_dim2 = kwargs.get("hidden_dim2", 100)
    return SuperSequential(
        norm_cls(mean=mean, std=std),
        SuperLinear(kwargs["input_dim"], hidden_dim1),
        act_cls(),
        SuperLinear(hidden_dim1, hidden_dim2),
        act_cls(),
        SuperLinear(hidden_dim2, kwargs["output_dim"]),
    )


def _build_norm_mlp(**kwargs):
    """Build a stack of ``[norm, Linear, act]`` groups followed by a final Linear.

    Required kwargs: ``act_cls``, ``norm_cls`` (names looked up in
    ``super_name2activation`` / ``super_name2norm``), ``input_dim``,
    ``hidden_dims`` (an iterable of layer widths) and ``output_dim``.
    With an empty ``hidden_dims`` the model is a single
    ``SuperLinear(input_dim, output_dim)``.
    """
    act_cls = super_name2activation[kwargs["act_cls"]]
    norm_cls = super_name2norm[kwargs["norm_cls"]]
    sub_layers, last_dim = [], kwargs["input_dim"]
    for hidden_dim in kwargs["hidden_dims"]:
        sub_layers.extend(
            [
                # NOTE(review): presumably a LayerNorm-like layer without a
                # learnable affine transform — confirm in xlayers.super_core.
                norm_cls(last_dim, elementwise_affine=False),
                SuperLinear(last_dim, hidden_dim),
                act_cls(),
            ]
        )
        last_dim = hidden_dim
    sub_layers.append(SuperLinear(last_dim, kwargs["output_dim"]))
    return SuperSequential(*sub_layers)


def get_model(config: Dict[Text, Any], **kwargs):
    """Construct a ``SuperSequential`` model described by ``config`` and ``kwargs``.

    Args:
        config: model configuration; only the ``"model_type"`` key is read
            (default ``"simple_mlp"``). Supported values are ``"simple_mlp"``
            and ``"norm_mlp"``.
        **kwargs: architecture hyper-parameters; see ``_build_simple_mlp`` and
            ``_build_norm_mlp`` for the keys each model type reads.

    Returns:
        The assembled ``SuperSequential`` model.

    Raises:
        TypeError: if ``model_type`` is not one of the supported values.
        KeyError: if a required kwarg is missing or an activation/norm name
            is not registered.
    """
    model_type = config.get("model_type", "simple_mlp")
    if model_type == "simple_mlp":
        return _build_simple_mlp(**kwargs)
    if model_type == "norm_mlp":
        return _build_norm_mlp(**kwargs)
    raise TypeError("Unknown model type: {:}".format(model_type))