Fix small bugs
This commit is contained in:
		| @@ -6,6 +6,7 @@ | ||||
| # - https://github.com/microsoft/qlib/blob/main/examples/workflow_by_code.py | ||||
| # python exps/trading/workflow_tt.py --gpu 1 --market csi300 | ||||
| ##################################################### | ||||
| import yaml | ||||
| import argparse | ||||
|  | ||||
| from xautodl.procedures.q_exps import update_gpu | ||||
| @@ -57,7 +58,7 @@ def main(xargs): | ||||
|  | ||||
|     model_config = { | ||||
|         "class": "QuantTransformer", | ||||
|         "module_path": "trade_models", | ||||
|         "module_path": "xautodl.trade_models.quant_transformer", | ||||
|         "kwargs": { | ||||
|             "net_config": None, | ||||
|             "opt_config": None, | ||||
| @@ -108,6 +109,62 @@ def main(xargs): | ||||
|     provider_uri = "~/.qlib/qlib_data/cn_data" | ||||
|     qlib.init(provider_uri=provider_uri, region=REG_CN) | ||||
|  | ||||
|     from qlib.utils import init_instance_by_config | ||||
|  | ||||
|     xconfig = """ | ||||
| model: | ||||
|         class: SFM | ||||
|         module_path: qlib.contrib.model.pytorch_sfm | ||||
|         kwargs: | ||||
|             d_feat: 6 | ||||
|             hidden_size: 64 | ||||
|             output_dim: 32 | ||||
|             freq_dim: 25 | ||||
|             dropout_W: 0.5 | ||||
|             dropout_U: 0.5 | ||||
|             n_epochs: 20 | ||||
|             lr: 1e-3 | ||||
|             batch_size: 1600 | ||||
|             early_stop: 20 | ||||
|             eval_steps: 5 | ||||
|             loss: mse | ||||
|             optimizer: adam | ||||
|             GPU: 0 | ||||
| """ | ||||
|     xconfig = """ | ||||
| model: | ||||
|         class: TabnetModel | ||||
|         module_path: qlib.contrib.model.pytorch_tabnet | ||||
|         kwargs: | ||||
|             d_feat: 360 | ||||
|             pretrain: True | ||||
| """ | ||||
|     xconfig = """ | ||||
| model: | ||||
|         class: GRU | ||||
|         module_path: qlib.contrib.model.pytorch_gru | ||||
|         kwargs: | ||||
|             d_feat: 6 | ||||
|             hidden_size: 64 | ||||
|             num_layers: 4 | ||||
|             dropout: 0.0 | ||||
|             n_epochs: 200 | ||||
|             lr: 0.001 | ||||
|             early_stop: 20 | ||||
|             batch_size: 800 | ||||
|             metric: loss | ||||
|             loss: mse | ||||
|             GPU: 0 | ||||
| """ | ||||
|     xconfig = yaml.safe_load(xconfig) | ||||
|     model = init_instance_by_config(xconfig["model"]) | ||||
|     from xautodl.utils.flop_benchmark import count_parameters_in_MB | ||||
|  | ||||
|     # print(count_parameters_in_MB(model.tabnet_model)) | ||||
|     import pdb | ||||
|  | ||||
|     pdb.set_trace() | ||||
|  | ||||
|     save_dir = "{:}-{:}".format(xargs.save_dir, xargs.market) | ||||
|     dataset = init_instance_by_config(dataset_config) | ||||
|     for irun in range(xargs.times): | ||||
|   | ||||
		Reference in New Issue
	
	Block a user