xautodl/lib/models/xcore.py

32 lines
1.1 KiB
Python
Raw Normal View History

2021-04-29 11:17:44 +02:00
#######################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 #
2021-04-29 08:28:37 +02:00
#######################################################
# Use modules in xlayers to construct different models #
#######################################################
from typing import List, Text, Dict, Any
2021-04-29 10:30:47 +02:00
import torch
2021-04-29 08:28:37 +02:00
# Public API of this module: only the model factory is exported.
__all__ = ["get_model"]
2021-04-29 10:30:47 +02:00
from xlayers.super_core import SuperSequential
2021-04-29 08:28:37 +02:00
from xlayers.super_core import SuperSimpleNorm
2021-04-29 10:30:47 +02:00
from xlayers.super_core import SuperLeakyReLU
2021-04-29 08:28:37 +02:00
from xlayers.super_core import SuperLinear
def get_model(config: Dict[Text, Any], **kwargs):
    """Build a model out of the xlayers super modules.

    Args:
      config: model configuration. ``config["model_type"]`` selects the
        architecture and defaults to ``"simple_mlp"`` when the key is absent.
      **kwargs: architecture-specific arguments. ``"simple_mlp"`` requires
        ``mean`` and ``std`` (passed to ``SuperSimpleNorm``), plus
        ``input_dim`` and ``output_dim`` (the in/out widths of the first and
        last ``SuperLinear`` layers).

    Returns:
      The constructed model (a ``SuperSequential`` for ``"simple_mlp"``).

    Raises:
      TypeError: if ``model_type`` is not a recognised model type.
      KeyError: if a keyword argument required by the chosen model is missing.
    """
    model_type = config.get("model_type", "simple_mlp")
    if model_type != "simple_mlp":
        # TypeError (rather than ValueError) is kept so existing callers that
        # catch it keep working.
        raise TypeError("Unknown model type: {:}".format(model_type))
    # simple_mlp: SuperSimpleNorm(mean, std) followed by a 3-layer MLP with
    # hidden widths 200 -> 100 and LeakyReLU activations between the layers.
    return SuperSequential(
        SuperSimpleNorm(kwargs["mean"], kwargs["std"]),
        SuperLinear(kwargs["input_dim"], 200),
        SuperLeakyReLU(),
        SuperLinear(200, 100),
        SuperLeakyReLU(),
        SuperLinear(100, kwargs["output_dim"]),
    )