Update baselines
This commit is contained in:
		
							
								
								
									
										29
									
								
								lib/models/xcore.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										29
									
								
								lib/models/xcore.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,29 @@ | ||||
| ##################################################### | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 # | ||||
| ####################################################### | ||||
| # Use module in xlayers to construct different models # | ||||
| ####################################################### | ||||
| from typing import List, Text, Dict, Any | ||||
|  | ||||
| __all__ = ["get_model"] | ||||
|  | ||||
|  | ||||
| from xlayers.super_core import SuperSequential, SuperMLPv1 | ||||
| from xlayers.super_core import SuperSimpleNorm | ||||
| from xlayers.super_core import SuperLinear | ||||
|  | ||||
|  | ||||
| def get_model(config: Dict[Text, Any], **kwargs): | ||||
|     model_type = config.get("model_type", "simple_mlp") | ||||
|     if model_type == "simple_mlp": | ||||
|         model = SuperSequential( | ||||
|             SuperSimpleNorm(kwargs["mean"], kwargs["std"]), | ||||
|             SuperLinear(kwargs["input_dim"], 200), | ||||
|             torch.nn.LeakyReLU(), | ||||
|             SuperLinear(200, 100), | ||||
|             torch.nn.LeakyReLU(), | ||||
|             SuperLinear(100, kwargs["output_dim"]), | ||||
|         ) | ||||
|     else: | ||||
|         raise TypeError("Unkonwn model type: {:}".format(model_type)) | ||||
|     return model | ||||
		Reference in New Issue
	
	Block a user