Update Q workflow
This commit is contained in:
		| @@ -1,12 +1,12 @@ | ||||
| ##################################################### | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 # | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.02 # | ||||
| ##################################################### | ||||
| # Refer to: | ||||
| # - https://github.com/microsoft/qlib/blob/main/examples/workflow_by_code.ipynb | ||||
| # - https://github.com/microsoft/qlib/blob/main/examples/workflow_by_code.py | ||||
| # python exps/trading/workflow_tt.py | ||||
| ##################################################### | ||||
| import sys, site, argparse | ||||
| import sys, argparse | ||||
| from pathlib import Path | ||||
|  | ||||
| lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve() | ||||
| @@ -15,19 +15,11 @@ if str(lib_dir) not in sys.path: | ||||
|  | ||||
| import qlib | ||||
| from qlib.config import C | ||||
| import pandas as pd | ||||
| from qlib.config import REG_CN | ||||
| from qlib.contrib.model.gbdt import LGBModel | ||||
| from qlib.contrib.data.handler import Alpha158 | ||||
| from qlib.contrib.strategy.strategy import TopkDropoutStrategy | ||||
| from qlib.contrib.evaluate import ( | ||||
|     backtest as normal_backtest, | ||||
|     risk_analysis, | ||||
| ) | ||||
| from qlib.utils import exists_qlib_data, init_instance_by_config | ||||
| from qlib.utils import init_instance_by_config | ||||
| from qlib.workflow import R | ||||
| from qlib.workflow.record_temp import SignalRecord, PortAnaRecord | ||||
| from qlib.utils import flatten_dict | ||||
| from qlib.log import set_log_basic_config | ||||
|  | ||||
|  | ||||
| def main(xargs): | ||||
| @@ -73,13 +65,51 @@ def main(xargs): | ||||
|         }, | ||||
|     } | ||||
|  | ||||
|     task = {"model": model_config, "dataset": dataset_config} | ||||
|     port_analysis_config = { | ||||
|         "strategy": { | ||||
|             "class": "TopkDropoutStrategy", | ||||
|             "module_path": "qlib.contrib.strategy.strategy", | ||||
|             "kwargs": { | ||||
|                 "topk": 50, | ||||
|                 "n_drop": 5, | ||||
|             }, | ||||
|         }, | ||||
|         "backtest": { | ||||
|             "verbose": False, | ||||
|             "limit_threshold": 0.095, | ||||
|             "account": 100000000, | ||||
|             "benchmark": "SH000300", | ||||
|             "deal_price": "close", | ||||
|             "open_cost": 0.0005, | ||||
|             "close_cost": 0.0015, | ||||
|             "min_cost": 5, | ||||
|         }, | ||||
|     } | ||||
|  | ||||
|     record_config = [ | ||||
|         {"class": "SignalRecord", "module_path": "qlib.workflow.record_temp", "kwargs": dict()}, | ||||
|         { | ||||
|             "class": "SigAnaRecord", | ||||
|             "module_path": "qlib.workflow.record_temp", | ||||
|             "kwargs": dict(ana_long_short=False, ann_scaler=252), | ||||
|         }, | ||||
|         { | ||||
|             "class": "PortAnaRecord", | ||||
|             "module_path": "qlib.workflow.record_temp", | ||||
|             "kwargs": dict(config=port_analysis_config), | ||||
|         }, | ||||
|     ] | ||||
|  | ||||
|     task = dict(model=model_config, dataset=dataset_config, record=record_config) | ||||
|  | ||||
|  | ||||
|     # start exp to train model | ||||
|     with R.start(experiment_name="train_tt_model"): | ||||
|         set_log_basic_config(R.get_recorder().root_uri / 'log.log') | ||||
|  | ||||
|         model = init_instance_by_config(model_config) | ||||
|         dataset = init_instance_by_config(dataset_config) | ||||
|  | ||||
|     # start exp to train model | ||||
|     with R.start(experiment_name="train_tt_model"): | ||||
|         R.log_params(**flatten_dict(task)) | ||||
|         model.fit(dataset) | ||||
|         R.save_objects(trained_model=model) | ||||
| @@ -87,14 +117,19 @@ def main(xargs): | ||||
|         # prediction | ||||
|         recorder = R.get_recorder() | ||||
|         print(recorder) | ||||
|         sr = SignalRecord(model, dataset, recorder) | ||||
|  | ||||
|         for record in task["record"]: | ||||
|             record = record.copy() | ||||
|             if record["class"] == "SignalRecord": | ||||
|                 srconf = {"model": model, "dataset": dataset, "recorder": recorder} | ||||
|                 record["kwargs"].update(srconf) | ||||
|                 sr = init_instance_by_config(record) | ||||
|                 sr.generate() | ||||
|  | ||||
|         # backtest. If users want to use backtest based on their own prediction, | ||||
|         # please refer to https://qlib.readthedocs.io/en/latest/component/recorder.html#record-template. | ||||
|         par = PortAnaRecord(recorder, port_analysis_config) | ||||
|         par.generate() | ||||
|      | ||||
|             else: | ||||
|                 rconf = {"recorder": recorder} | ||||
|                 record["kwargs"].update(rconf) | ||||
|                 ar = init_instance_by_config(record) | ||||
|                 ar.generate() | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|   | ||||
| @@ -41,8 +41,8 @@ class QuantTransformer(Model): | ||||
|   def __init__( | ||||
|     self, | ||||
|     d_feat=6, | ||||
|     hidden_size=64, | ||||
|     num_layers=2, | ||||
|     hidden_size=48, | ||||
|     depth=5, | ||||
|     dropout=0.0, | ||||
|     n_epochs=200, | ||||
|     lr=0.001, | ||||
| @@ -62,7 +62,7 @@ class QuantTransformer(Model): | ||||
|     # set hyper-parameters. | ||||
|     self.d_feat = d_feat | ||||
|     self.hidden_size = hidden_size | ||||
|     self.num_layers = num_layers | ||||
|     self.depth = depth | ||||
|     self.dropout = dropout | ||||
|     self.n_epochs = n_epochs | ||||
|     self.lr = lr | ||||
| @@ -79,7 +79,7 @@ class QuantTransformer(Model): | ||||
|       "Transformer parameters setting:" | ||||
|       "\nd_feat : {}" | ||||
|       "\nhidden_size : {}" | ||||
|       "\nnum_layers : {}" | ||||
|       "\ndepth : {}" | ||||
|       "\ndropout : {}" | ||||
|       "\nn_epochs : {}" | ||||
|       "\nlr : {}" | ||||
| @@ -93,7 +93,7 @@ class QuantTransformer(Model): | ||||
|       "\nseed : {}".format( | ||||
|         d_feat, | ||||
|         hidden_size, | ||||
|         num_layers, | ||||
|         depth, | ||||
|         dropout, | ||||
|         n_epochs, | ||||
|         lr, | ||||
| @@ -112,7 +112,9 @@ class QuantTransformer(Model): | ||||
|       np.random.seed(self.seed) | ||||
|       torch.manual_seed(self.seed) | ||||
|  | ||||
|     self.model = TransformerModel(d_feat=self.d_feat) | ||||
|     self.model = TransformerModel(d_feat=self.d_feat, | ||||
|                                   embed_dim=self.hidden_size, | ||||
|                                   depth=self.depth) | ||||
|     self.logger.info('model: {:}'.format(self.model)) | ||||
|     self.logger.info('model size: {:.3f} MB'.format(count_parameters_in_MB(self.model))) | ||||
|    | ||||
|   | ||||
		Reference in New Issue
	
	Block a user