Update baselines
Submodule .latent-data/qlib updated: b14a559a52...49697b1f15

configs/qlib/workflow_config_alstm_Alpha360.yaml (Normal file, 83 lines)
@@ -0,0 +1,83 @@
qlib_init:
    provider_uri: "~/.qlib/qlib_data/cn_data"
    region: cn
market: &market all
benchmark: &benchmark SH000300
data_handler_config: &data_handler_config
    start_time: 2008-01-01
    end_time: 2020-08-01
    fit_start_time: 2008-01-01
    fit_end_time: 2014-12-31
    instruments: *market
    infer_processors:
        - class: RobustZScoreNorm
          kwargs:
              fields_group: feature
              clip_outlier: true
        - class: Fillna
          kwargs:
              fields_group: feature
    learn_processors:
        - class: DropnaLabel
        - class: CSRankNorm
          kwargs:
              fields_group: label
    label: ["Ref($close, -2) / Ref($close, -1) - 1"]
port_analysis_config: &port_analysis_config
    strategy:
        class: TopkDropoutStrategy
        module_path: qlib.contrib.strategy.strategy
        kwargs:
            topk: 50
            n_drop: 5
    backtest:
        verbose: False
        limit_threshold: 0.095
        account: 100000000
        benchmark: *benchmark
        deal_price: close
        open_cost: 0.0005
        close_cost: 0.0015
        min_cost: 5
task:
    model:
        class: ALSTM
        module_path: qlib.contrib.model.pytorch_alstm
        kwargs:
            d_feat: 6
            hidden_size: 64
            num_layers: 2
            dropout: 0.0
            n_epochs: 200
            lr: 1e-3
            early_stop: 20
            batch_size: 800
            metric: loss
            loss: mse
            GPU: 0
            rnn_type: GRU
    dataset:
        class: DatasetH
        module_path: qlib.data.dataset
        kwargs:
            handler:
                class: Alpha360
                module_path: qlib.contrib.data.handler
                kwargs: *data_handler_config
            segments:
                train: [2008-01-01, 2014-12-31]
                valid: [2015-01-01, 2016-12-31]
                test: [2017-01-01, 2020-08-01]
    record:
        - class: SignalRecord
          module_path: qlib.workflow.record_temp
          kwargs: {}
        - class: SigAnaRecord
          module_path: qlib.workflow.record_temp
          kwargs:
            ana_long_short: False
            ann_scaler: 252
        - class: PortAnaRecord
          module_path: qlib.workflow.record_temp
          kwargs:
            config: *port_analysis_config
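Note (not part of the commit): this config can be loaded with the same ruamel.yaml parser that exps/trading/baselines.py uses, and the YAML anchors (&market, &data_handler_config) are expanded at load time, so the Alpha360 handler receives the shared data_handler_config block. A minimal sketch, assuming the repository root as the working directory:

# Minimal sketch: load the ALSTM config and inspect the resolved anchors.
import ruamel.yaml as yaml

with open("configs/qlib/workflow_config_alstm_Alpha360.yaml") as fp:
    config = yaml.safe_load(fp)

# *data_handler_config is expanded on load, so the handler kwargs carry the
# fit/infer time ranges defined at the top of the file.
handler_kwargs = config["task"]["dataset"]["kwargs"]["handler"]["kwargs"]
print(handler_kwargs["fit_start_time"], handler_kwargs["fit_end_time"])
print(config["task"]["model"]["kwargs"]["rnn_type"])  # GRU backbone for ALSTM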

configs/qlib/workflow_config_lightgbm_Alpha360.yaml (Normal file, 73 lines)
@@ -0,0 +1,73 @@
qlib_init:
    provider_uri: "~/.qlib/qlib_data/cn_data"
    region: cn
market: &market all
benchmark: &benchmark SH000300
data_handler_config: &data_handler_config
    start_time: 2008-01-01
    end_time: 2020-08-01
    fit_start_time: 2008-01-01
    fit_end_time: 2014-12-31
    instruments: *market
    infer_processors: []
    learn_processors:
        - class: DropnaLabel
        - class: CSRankNorm
          kwargs:
              fields_group: label
    label: ["Ref($close, -2) / Ref($close, -1) - 1"]
port_analysis_config: &port_analysis_config
    strategy:
        class: TopkDropoutStrategy
        module_path: qlib.contrib.strategy.strategy
        kwargs:
            topk: 50
            n_drop: 5
    backtest:
        verbose: False
        limit_threshold: 0.095
        account: 100000000
        benchmark: *benchmark
        deal_price: close
        open_cost: 0.0005
        close_cost: 0.0015
        min_cost: 5
task:
    model:
        class: LGBModel
        module_path: qlib.contrib.model.gbdt
        kwargs:
            loss: mse
            colsample_bytree: 0.8879
            learning_rate: 0.0421
            subsample: 0.8789
            lambda_l1: 205.6999
            lambda_l2: 580.9768
            max_depth: 8
            num_leaves: 210
            num_threads: 20
    dataset:
        class: DatasetH
        module_path: qlib.data.dataset
        kwargs:
            handler:
                class: Alpha360
                module_path: qlib.contrib.data.handler
                kwargs: *data_handler_config
            segments:
                train: [2008-01-01, 2014-12-31]
                valid: [2015-01-01, 2016-12-31]
                test: [2017-01-01, 2020-08-01]
    record:
        - class: SignalRecord
          module_path: qlib.workflow.record_temp
          kwargs: {}
        - class: SigAnaRecord
          module_path: qlib.workflow.record_temp
          kwargs:
            ana_long_short: False
            ann_scaler: 252
        - class: PortAnaRecord
          module_path: qlib.workflow.record_temp
          kwargs:
            config: *port_analysis_config
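Note (not part of the commit): both configs use the label expression "Ref($close, -2) / Ref($close, -1) - 1". In qlib a negative Ref offset looks forward in time, so the label is the one-day return between t+1 and t+2. A rough pandas analogue, purely for illustration with made-up prices:

# Illustration only (not qlib code): negative shifts stand in for Ref($close, -k).
import pandas as pd

close = pd.Series([10.0, 10.5, 10.2, 10.8, 11.0])  # toy close prices
label = close.shift(-2) / close.shift(-1) - 1       # Ref($close, -2) / Ref($close, -1) - 1
print(label)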

exps/trading/baselines.py (Normal file, 127 lines)
@@ -0,0 +1,127 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.02 #
#####################################################
# python exps/trading/baselines.py --alg GRU
# python exps/trading/baselines.py --alg LSTM
# python exps/trading/baselines.py --alg ALSTM
# python exps/trading/baselines.py --alg XGBoost
# python exps/trading/baselines.py --alg LightGBM
#####################################################
import sys, argparse
from collections import OrderedDict
from pathlib import Path
from pprint import pprint
import ruamel.yaml as yaml

lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve()
if str(lib_dir) not in sys.path:
    sys.path.insert(0, str(lib_dir))

import qlib
from qlib.utils import init_instance_by_config
from qlib.workflow import R
from qlib.utils import flatten_dict
from qlib.log import set_log_basic_config


def retrieve_configs():
    # https://github.com/microsoft/qlib/blob/main/examples/benchmarks/
    config_dir = (lib_dir / ".." / "configs" / "qlib").resolve()
    # algorithm to file names
    alg2names = OrderedDict()
    alg2names["GRU"] = "workflow_config_gru_Alpha360.yaml"
    alg2names["LSTM"] = "workflow_config_lstm_Alpha360.yaml"
    # A dual-stage attention-based recurrent neural network for time series prediction, IJCAI-2017
    alg2names["ALSTM"] = "workflow_config_alstm_Alpha360.yaml"
    # XGBoost: A Scalable Tree Boosting System, KDD-2016
    alg2names["XGBoost"] = "workflow_config_xgboost_Alpha360.yaml"
    # LightGBM: A Highly Efficient Gradient Boosting Decision Tree, NeurIPS-2017
    alg2names["LightGBM"] = "workflow_config_lightgbm_Alpha360.yaml"

    # find the yaml paths
    alg2paths = OrderedDict()
    for idx, (alg, name) in enumerate(alg2names.items()):
        path = config_dir / name
        assert path.exists(), "{:} does not exist.".format(path)
        alg2paths[alg] = str(path)
        print("The {:02d}/{:02d}-th baseline algorithm is {:9s} ({:}).".format(idx, len(alg2names), alg, path))
    return alg2paths


def update_gpu(config, gpu):
    config = config.copy()
    if "GPU" in config["task"]["model"].get("kwargs", {}):
        config["task"]["model"]["kwargs"]["GPU"] = gpu
    return config


def update_market(config, market):
    config = config.copy()
    config["market"] = market
    config["data_handler_config"]["instruments"] = market
    return config


def run_exp(task_config, dataset, experiment_name, recorder_name, uri):

    # model initialization
    model = init_instance_by_config(task_config["model"])

    # start exp
    with R.start(experiment_name=experiment_name, recorder_name=recorder_name, uri=uri):

        log_file = R.get_recorder().root_uri / '{:}.log'.format(experiment_name)
        set_log_basic_config(log_file)

        # train model
        R.log_params(**flatten_dict(task_config))
        model.fit(dataset)
        recorder = R.get_recorder()
        R.save_objects(**{"model.pkl": model})

        # generate records: prediction, backtest, and analysis
        for record in task_config["record"]:
            record = record.copy()
            if record["class"] == "SignalRecord":
                srconf = {"model": model, "dataset": dataset, "recorder": recorder}
                record["kwargs"].update(srconf)
                sr = init_instance_by_config(record)
                sr.generate()
            else:
                rconf = {"recorder": recorder}
                record["kwargs"].update(rconf)
                ar = init_instance_by_config(record)
                ar.generate()


def main(xargs, exp_yaml):
    assert Path(exp_yaml).exists(), "{:} does not exist.".format(exp_yaml)

    with open(exp_yaml) as fp:
        config = yaml.safe_load(fp)
    config = update_gpu(config, xargs.gpu)
    # config = update_market(config, 'csi300')

    qlib.init(**config.get("qlib_init"))
    dataset_config = config.get("task").get("dataset")
    dataset = init_instance_by_config(dataset_config)
    pprint('args: {:}'.format(xargs))
    pprint(dataset_config)
    pprint(dataset)

    for irun in range(xargs.times):
        run_exp(config.get("task"), dataset, xargs.alg, "recorder-{:02d}-{:02d}".format(irun, xargs.times), xargs.save_dir)


if __name__ == "__main__":

    alg2paths = retrieve_configs()

    parser = argparse.ArgumentParser("Baselines")
    parser.add_argument("--save_dir", type=str, default="./outputs/qlib-baselines", help="The checkpoint directory.")
    parser.add_argument("--times", type=int, default=10, help="The repeated run times.")
    parser.add_argument("--gpu", type=int, default=0, help="The GPU ID used for train / test.")
    parser.add_argument("--alg", type=str, choices=list(alg2paths.keys()), required=True, help="The algorithm name.")
    args = parser.parse_args()

    main(args, alg2paths[args.alg])
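Note (not part of the commit): the script relies on qlib's init_instance_by_config for every class/module_path/kwargs block in the YAML files above. A simplified sketch of that pattern, with a hypothetical function name and not the actual qlib implementation:

# Simplified sketch of the class/module_path/kwargs instantiation pattern;
# the real qlib.utils.init_instance_by_config handles more cases.
import importlib

def init_instance_by_config_sketch(config):
    module = importlib.import_module(config["module_path"])
    cls = getattr(module, config["class"])
    return cls(**config.get("kwargs", {}))

# e.g. the ALSTM model block above would roughly resolve to
# qlib.contrib.model.pytorch_alstm.ALSTM(d_feat=6, hidden_size=64, ...).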
@@ -104,7 +104,7 @@ def main(xargs):
 
 
     # start exp to train model
-    with R.start(experiment_name="train_tt_model"):
+    with R.start(experiment_name="tt_model", uri=xargs.save_dir):
         set_log_basic_config(R.get_recorder().root_uri / 'log.log')
 
         model = init_instance_by_config(model_config)
@@ -139,8 +139,6 @@ if __name__ == "__main__":
     args = parser.parse_args()
 
     provider_uri = "~/.qlib/qlib_data/cn_data"  # target_dir
-    exp_manager = C.exp_manager
-    exp_manager["kwargs"]["uri"] = "file:{:}".format(Path(args.save_dir).resolve())
-    qlib.init(provider_uri=provider_uri, region=REG_CN, exp_manager=exp_manager)
+    qlib.init(provider_uri=provider_uri, region=REG_CN)
 
     main(args)
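Note (not part of the commit): the hunks above drop the manual C.exp_manager patching and instead pass the output location through R.start's uri argument, so each run chooses its own tracking directory. A minimal sketch of the updated flow; the output path and the qlib.config import location for REG_CN are assumptions:

# Minimal sketch of the updated flow: qlib.init no longer needs a patched
# exp_manager; the recorder URI is supplied per-run via R.start.
import qlib
from qlib.config import REG_CN
from qlib.workflow import R

qlib.init(provider_uri="~/.qlib/qlib_data/cn_data", region=REG_CN)
with R.start(experiment_name="tt_model", uri="./outputs/tt-model"):
    pass  # train the model and generate records as in the script above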