Update Q workflow
@@ -1,12 +1,12 @@
 #####################################################
-# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
+# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.02 #
 #####################################################
 # Refer to:
 # - https://github.com/microsoft/qlib/blob/main/examples/workflow_by_code.ipynb
 # - https://github.com/microsoft/qlib/blob/main/examples/workflow_by_code.py
 # python exps/trading/workflow_tt.py
 #####################################################
-import sys, site, argparse
+import sys, argparse
 from pathlib import Path
 
 lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve()
@@ -15,19 +15,11 @@ if str(lib_dir) not in sys.path:
 
 import qlib
 from qlib.config import C
-import pandas as pd
 from qlib.config import REG_CN
-from qlib.contrib.model.gbdt import LGBModel
-from qlib.contrib.data.handler import Alpha158
-from qlib.contrib.strategy.strategy import TopkDropoutStrategy
-from qlib.contrib.evaluate import (
-    backtest as normal_backtest,
-    risk_analysis,
-)
-from qlib.utils import exists_qlib_data, init_instance_by_config
+from qlib.utils import init_instance_by_config
 from qlib.workflow import R
-from qlib.workflow.record_temp import SignalRecord, PortAnaRecord
 from qlib.utils import flatten_dict
+from qlib.log import set_log_basic_config
 
 
 def main(xargs):
@@ -73,13 +65,51 @@ def main(xargs):
         },
     }
 
-    task = {"model": model_config, "dataset": dataset_config}
-
-    model = init_instance_by_config(model_config)
-    dataset = init_instance_by_config(dataset_config)
+    port_analysis_config = {
+        "strategy": {
+            "class": "TopkDropoutStrategy",
+            "module_path": "qlib.contrib.strategy.strategy",
+            "kwargs": {
+                "topk": 50,
+                "n_drop": 5,
+            },
+        },
+        "backtest": {
+            "verbose": False,
+            "limit_threshold": 0.095,
+            "account": 100000000,
+            "benchmark": "SH000300",
+            "deal_price": "close",
+            "open_cost": 0.0005,
+            "close_cost": 0.0015,
+            "min_cost": 5,
+        },
+    }
+
+    record_config = [
+        {"class": "SignalRecord", "module_path": "qlib.workflow.record_temp", "kwargs": dict()},
+        {
+            "class": "SigAnaRecord",
+            "module_path": "qlib.workflow.record_temp",
+            "kwargs": dict(ana_long_short=False, ann_scaler=252),
+        },
+        {
+            "class": "PortAnaRecord",
+            "module_path": "qlib.workflow.record_temp",
+            "kwargs": dict(config=port_analysis_config),
+        },
+    ]
+
+    task = dict(model=model_config, dataset=dataset_config, record=record_config)
 
     # start exp to train model
     with R.start(experiment_name="train_tt_model"):
+        set_log_basic_config(R.get_recorder().root_uri / 'log.log')
+
+        model = init_instance_by_config(model_config)
+        dataset = init_instance_by_config(dataset_config)
+
         R.log_params(**flatten_dict(task))
         model.fit(dataset)
         R.save_objects(trained_model=model)
@@ -87,14 +117,19 @@ def main(xargs):
         # prediction
         recorder = R.get_recorder()
         print(recorder)
-        sr = SignalRecord(model, dataset, recorder)
-        sr.generate()
 
-        # backtest. If users want to use backtest based on their own prediction,
-        # please refer to https://qlib.readthedocs.io/en/latest/component/recorder.html#record-template.
-        par = PortAnaRecord(recorder, port_analysis_config)
-        par.generate()
-
+        for record in task["record"]:
+            record = record.copy()
+            if record["class"] == "SignalRecord":
+                srconf = {"model": model, "dataset": dataset, "recorder": recorder}
+                record["kwargs"].update(srconf)
+                sr = init_instance_by_config(record)
+                sr.generate()
+            else:
+                rconf = {"recorder": recorder}
+                record["kwargs"].update(rconf)
+                ar = init_instance_by_config(record)
+                ar.generate()
 
 
 if __name__ == "__main__":
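
The heart of this change is that records are no longer constructed by hand (SignalRecord, PortAnaRecord); each entry of record_config is passed through qlib's init_instance_by_config. Below is a minimal sketch of that dispatch idiom; init_instance_by_config_sketch is a hypothetical stand-in and deliberately simpler than qlib's real helper, which accepts more input forms.

import importlib

def init_instance_by_config_sketch(config):
    # Import the module named in "module_path", look up "class", and
    # instantiate it with "kwargs". Simplified illustration only; not
    # qlib's exact implementation.
    module = importlib.import_module(config["module_path"])
    cls = getattr(module, config["class"])
    return cls(**config.get("kwargs", {}))

# For example, the PortAnaRecord entry above would be built roughly as:
#   record = {"class": "PortAnaRecord",
#             "module_path": "qlib.workflow.record_temp",
#             "kwargs": dict(config=port_analysis_config, recorder=recorder)}
#   init_instance_by_config_sketch(record).generate()

This is why the loop needs only two branches: SignalRecord has the model and dataset injected into its kwargs, while the analysis records only need the recorder.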

@@ -41,8 +41,8 @@ class QuantTransformer(Model):
   def __init__(
     self,
     d_feat=6,
-    hidden_size=64,
-    num_layers=2,
+    hidden_size=48,
+    depth=5,
     dropout=0.0,
     n_epochs=200,
     lr=0.001,
@@ -62,7 +62,7 @@ class QuantTransformer(Model):
     # set hyper-parameters.
     self.d_feat = d_feat
     self.hidden_size = hidden_size
-    self.num_layers = num_layers
+    self.depth = depth
     self.dropout = dropout
     self.n_epochs = n_epochs
     self.lr = lr
@@ -79,7 +79,7 @@ class QuantTransformer(Model):
       "Transformer parameters setting:"
       "\nd_feat : {}"
       "\nhidden_size : {}"
-      "\nnum_layers : {}"
+      "\ndepth : {}"
       "\ndropout : {}"
       "\nn_epochs : {}"
       "\nlr : {}"
@@ -93,7 +93,7 @@ class QuantTransformer(Model):
       "\nseed : {}".format(
         d_feat,
         hidden_size,
-        num_layers,
+        depth,
         dropout,
         n_epochs,
         lr,
@@ -112,7 +112,9 @@ class QuantTransformer(Model):
       np.random.seed(self.seed)
       torch.manual_seed(self.seed)
 
-    self.model = TransformerModel(d_feat=self.d_feat)
+    self.model = TransformerModel(d_feat=self.d_feat,
+                                  embed_dim=self.hidden_size,
+                                  depth=self.depth)
     self.logger.info('model: {:}'.format(self.model))
     self.logger.info('model size: {:.3f} MB'.format(count_parameters_in_MB(self.model)))
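
In the model file, num_layers is renamed to depth and both hyper-parameters are now threaded into the network: hidden_size becomes the transformer's embedding width (embed_dim) and depth its number of encoder blocks. For reference, here is a minimal sketch of the count_parameters_in_MB utility logged above, under the assumption that it simply reports the parameter count in millions; it is not the repo's exact helper.

import torch.nn as nn

def count_parameters_in_MB(model: nn.Module) -> float:
    # Assumed behavior: total parameter count divided by 1e6,
    # i.e. the "model size ... MB" figure in the log message.
    return sum(p.numel() for p in model.parameters()) / 1e6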