Update models
This commit is contained in:
parent
d557c328a8
commit
30fb8fad67
@ -1,6 +1,4 @@
|
|||||||
#######################################################
|
#######################################################
|
||||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 #
|
|
||||||
#######################################################
|
|
||||||
# Use module in xlayers to construct different models #
|
# Use module in xlayers to construct different models #
|
||||||
#######################################################
|
#######################################################
|
||||||
from typing import List, Text, Dict, Any
|
from typing import List, Text, Dict, Any
|
||||||
@ -41,8 +39,8 @@ def get_model(config: Dict[Text, Any], **kwargs):
|
|||||||
norm_cls = super_name2norm[kwargs["norm_cls"]]
|
norm_cls = super_name2norm[kwargs["norm_cls"]]
|
||||||
sub_layers, last_dim = [], kwargs["input_dim"]
|
sub_layers, last_dim = [], kwargs["input_dim"]
|
||||||
for i, hidden_dim in enumerate(kwargs["hidden_dims"]):
|
for i, hidden_dim in enumerate(kwargs["hidden_dims"]):
|
||||||
if last_dim > 1:
|
if hidden_dim > 1:
|
||||||
sub_layers.append(norm_cls(last_dim, elementwise_affine=False))
|
sub_layers.append(norm_cls(hidden_dim, elementwise_affine=False))
|
||||||
sub_layers.append(SuperLinear(last_dim, hidden_dim))
|
sub_layers.append(SuperLinear(last_dim, hidden_dim))
|
||||||
sub_layers.append(act_cls())
|
sub_layers.append(act_cls())
|
||||||
last_dim = hidden_dim
|
last_dim = hidden_dim
|
||||||
|
Loading…
Reference in New Issue
Block a user