Refine TT workflow
This commit is contained in:
		| @@ -21,7 +21,6 @@ from qlib.workflow import R | |||||||
|  |  | ||||||
|  |  | ||||||
| class QResult: | class QResult: | ||||||
|  |  | ||||||
|     def __init__(self): |     def __init__(self): | ||||||
|         self._result = defaultdict(list) |         self._result = defaultdict(list) | ||||||
|  |  | ||||||
| @@ -42,14 +41,14 @@ class QResult: | |||||||
|  |  | ||||||
|     @staticmethod |     @staticmethod | ||||||
|     def full_str(xstr, space): |     def full_str(xstr, space): | ||||||
|     xformat = '{:' + str(space) + 's}' |         xformat = "{:" + str(space) + "s}" | ||||||
|         return xformat.format(str(xstr)) |         return xformat.format(str(xstr)) | ||||||
|  |  | ||||||
|   def info(self, keys: List[Text], separate: Text = '', space: int = 25, show=True): |     def info(self, keys: List[Text], separate: Text = "", space: int = 25, show=True): | ||||||
|         avaliable_keys = [] |         avaliable_keys = [] | ||||||
|         for key in keys: |         for key in keys: | ||||||
|             if key not in self.result: |             if key not in self.result: | ||||||
|         print('There are invalid key [{:}].'.format(key)) |                 print("There are invalid key [{:}].".format(key)) | ||||||
|             else: |             else: | ||||||
|                 avaliable_keys.append(key) |                 avaliable_keys.append(key) | ||||||
|         head_str = separate.join([self.full_str(x, space) for x in avaliable_keys]) |         head_str = separate.join([self.full_str(x, space) for x in avaliable_keys]) | ||||||
| @@ -58,7 +57,7 @@ class QResult: | |||||||
|             current_values = self._result[key] |             current_values = self._result[key] | ||||||
|             mean = np.mean(current_values) |             mean = np.mean(current_values) | ||||||
|             std = np.std(current_values) |             std = np.std(current_values) | ||||||
|       values.append('{:.4f} $\pm$ {:.4f}'.format(mean, std)) |             values.append("{:.4f} $\pm$ {:.4f}".format(mean, std)) | ||||||
|         value_str = separate.join([self.full_str(x, space) for x in values]) |         value_str = separate.join([self.full_str(x, space) for x in values]) | ||||||
|         if show: |         if show: | ||||||
|             print(head_str) |             print(head_str) | ||||||
| @@ -69,8 +68,8 @@ class QResult: | |||||||
|  |  | ||||||
| def compare_results(heads, values, names, space=10): | def compare_results(heads, values, names, space=10): | ||||||
|     for idx, x in enumerate(heads): |     for idx, x in enumerate(heads): | ||||||
|     assert x == heads[0], '[{:}] {:} vs {:}'.format(idx, x, heads[0]) |         assert x == heads[0], "[{:}] {:} vs {:}".format(idx, x, heads[0]) | ||||||
|   new_head = QResult.full_str('Name', space) + heads[0] |     new_head = QResult.full_str("Name", space) + heads[0] | ||||||
|     print(new_head) |     print(new_head) | ||||||
|     for name, value in zip(names, values): |     for name, value in zip(names, values): | ||||||
|         xline = QResult.full_str(name, space) + value |         xline = QResult.full_str(name, space) + value | ||||||
| @@ -92,19 +91,21 @@ def main(xargs): | |||||||
|     R.reset_default_uri(xargs.save_dir) |     R.reset_default_uri(xargs.save_dir) | ||||||
|     experiments = R.list_experiments() |     experiments = R.list_experiments() | ||||||
|  |  | ||||||
|     key_map = {"IC": "IC", |     key_map = { | ||||||
|  |         "IC": "IC", | ||||||
|         "ICIR": "ICIR", |         "ICIR": "ICIR", | ||||||
|         "Rank IC": "Rank_IC", |         "Rank IC": "Rank_IC", | ||||||
|         "Rank ICIR": "Rank_ICIR", |         "Rank ICIR": "Rank_ICIR", | ||||||
|         "excess_return_with_cost.annualized_return": "Annualized_Return", |         "excess_return_with_cost.annualized_return": "Annualized_Return", | ||||||
|         "excess_return_with_cost.information_ratio": "Information_Ratio", |         "excess_return_with_cost.information_ratio": "Information_Ratio", | ||||||
|                "excess_return_with_cost.max_drawdown": "Max_Drawdown"} |         "excess_return_with_cost.max_drawdown": "Max_Drawdown", | ||||||
|  |     } | ||||||
|     all_keys = list(key_map.values()) |     all_keys = list(key_map.values()) | ||||||
|  |  | ||||||
|     print("There are {:} experiments.".format(len(experiments))) |     print("There are {:} experiments.".format(len(experiments))) | ||||||
|     head_strs, value_strs, names = [], [], [] |     head_strs, value_strs, names = [], [], [] | ||||||
|     for idx, (key, experiment) in enumerate(experiments.items()): |     for idx, (key, experiment) in enumerate(experiments.items()): | ||||||
|         if experiment.id == '0': |         if experiment.id == "0": | ||||||
|             continue |             continue | ||||||
|         recorders = experiment.list_recorders() |         recorders = experiment.list_recorders() | ||||||
|         recorders, not_finished = filter_finished(recorders) |         recorders, not_finished = filter_finished(recorders) | ||||||
|   | |||||||
| @@ -4,7 +4,7 @@ | |||||||
| # Refer to: | # Refer to: | ||||||
| # - https://github.com/microsoft/qlib/blob/main/examples/workflow_by_code.ipynb | # - https://github.com/microsoft/qlib/blob/main/examples/workflow_by_code.ipynb | ||||||
| # - https://github.com/microsoft/qlib/blob/main/examples/workflow_by_code.py | # - https://github.com/microsoft/qlib/blob/main/examples/workflow_by_code.py | ||||||
| # python exps/trading/workflow_tt.py --market all | # python exps/trading/workflow_tt.py --market all --gpu 1 | ||||||
| ##################################################### | ##################################################### | ||||||
| import sys, argparse | import sys, argparse | ||||||
| from pathlib import Path | from pathlib import Path | ||||||
| @@ -13,6 +13,10 @@ lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve() | |||||||
| if str(lib_dir) not in sys.path: | if str(lib_dir) not in sys.path: | ||||||
|     sys.path.insert(0, str(lib_dir)) |     sys.path.insert(0, str(lib_dir)) | ||||||
|  |  | ||||||
|  | from procedures.q_exps import update_gpu | ||||||
|  | from procedures.q_exps import update_market | ||||||
|  | from procedures.q_exps import run_exp | ||||||
|  |  | ||||||
| import qlib | import qlib | ||||||
| from qlib.config import C | from qlib.config import C | ||||||
| from qlib.config import REG_CN | from qlib.config import REG_CN | ||||||
| @@ -100,44 +104,23 @@ def main(xargs): | |||||||
|         }, |         }, | ||||||
|     ] |     ] | ||||||
|  |  | ||||||
|     task = dict(model=model_config, dataset=dataset_config, record=record_config) |     provider_uri = "~/.qlib/qlib_data/cn_data" | ||||||
|  |     qlib.init(provider_uri=provider_uri, region=REG_CN) | ||||||
|  |  | ||||||
|     # start exp to train model |  | ||||||
|     with R.start(experiment_name="tt_model", uri=xargs.save_dir + "-" + xargs.market): |  | ||||||
|         set_log_basic_config(R.get_recorder().root_uri / "log.log") |  | ||||||
|  |  | ||||||
|         model = init_instance_by_config(model_config) |  | ||||||
|     dataset = init_instance_by_config(dataset_config) |     dataset = init_instance_by_config(dataset_config) | ||||||
|  |     for irun in range(xargs.times): | ||||||
|         R.log_params(**flatten_dict(task)) |         xmodel_config = model_config.copy() | ||||||
|         model.fit(dataset) |         xmodel_config = update_gpu(xmodel_config, xags.gpu) | ||||||
|         R.save_objects(trained_model=model) |         task = dict(model=xmodel_config, dataset=dataset_config, record=record_config) | ||||||
|  |         run_exp(task_config, dataset, "Transformer", "recorder-{:02d}-{:02d}".format(irun, xargs.times), xargs.save_dir) | ||||||
|         # prediction |  | ||||||
|         recorder = R.get_recorder() |  | ||||||
|         print(recorder) |  | ||||||
|  |  | ||||||
|         for record in task["record"]: |  | ||||||
|             record = record.copy() |  | ||||||
|             if record["class"] == "SignalRecord": |  | ||||||
|                 srconf = {"model": model, "dataset": dataset, "recorder": recorder} |  | ||||||
|                 record["kwargs"].update(srconf) |  | ||||||
|                 sr = init_instance_by_config(record) |  | ||||||
|                 sr.generate() |  | ||||||
|             else: |  | ||||||
|                 rconf = {"recorder": recorder} |  | ||||||
|                 record["kwargs"].update(rconf) |  | ||||||
|                 ar = init_instance_by_config(record) |  | ||||||
|                 ar.generate() |  | ||||||
|  |  | ||||||
|  |  | ||||||
| if __name__ == "__main__": | if __name__ == "__main__": | ||||||
|     parser = argparse.ArgumentParser("Vanilla Transformable Transformer") |     parser = argparse.ArgumentParser("Vanilla Transformable Transformer") | ||||||
|     parser.add_argument("--save_dir", type=str, default="./outputs/tt-ml-runs", help="The checkpoint directory.") |     parser.add_argument("--save_dir", type=str, default="./outputs/tt-ml-runs", help="The checkpoint directory.") | ||||||
|  |     parser.add_argument("--times", type=int, default=10, help="The repeated run times.") | ||||||
|  |     parser.add_argument("--gpu", type=int, default=0, help="The GPU ID used for train / test.") | ||||||
|     parser.add_argument("--market", type=str, default="csi300", help="The market indicator.") |     parser.add_argument("--market", type=str, default="csi300", help="The market indicator.") | ||||||
|     args = parser.parse_args() |     args = parser.parse_args() | ||||||
|  |  | ||||||
|     provider_uri = "~/.qlib/qlib_data/cn_data" |  | ||||||
|     qlib.init(provider_uri=provider_uri, region=REG_CN) |  | ||||||
|  |  | ||||||
|     main(args) |     main(args) | ||||||
|   | |||||||
							
								
								
									
										63
									
								
								lib/procedures/q_exps.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										63
									
								
								lib/procedures/q_exps.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,63 @@ | |||||||
|  | ##################################################### | ||||||
|  | # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.02 # | ||||||
|  | ##################################################### | ||||||
|  |  | ||||||
|  | import qlib | ||||||
|  | from qlib.utils import init_instance_by_config | ||||||
|  | from qlib.workflow import R | ||||||
|  | from qlib.utils import flatten_dict | ||||||
|  | from qlib.log import set_log_basic_config | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def update_gpu(config, gpu): | ||||||
|  |     config = config.copy() | ||||||
|  |     if "task" in config and "GPU" in config["task"]["model"]: | ||||||
|  |         config["task"]["model"]["GPU"] = gpu | ||||||
|  |     elif "model" in config and "GPU" in config["model"]: | ||||||
|  |         config["model"]["GPU"] = gpu | ||||||
|  |     elif "GPU" in config: | ||||||
|  |         config["GPU"] = gpu | ||||||
|  |     return config | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def update_market(config, market): | ||||||
|  |     config = config.copy() | ||||||
|  |     config["market"] = market | ||||||
|  |     config["data_handler_config"]["instruments"] = market | ||||||
|  |     return config | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def run_exp(task_config, dataset, experiment_name, recorder_name, uri): | ||||||
|  |  | ||||||
|  |     # model initiaiton | ||||||
|  |     print("") | ||||||
|  |     print("[{:}] - [{:}]: {:}".format(experiment_name, recorder_name, uri)) | ||||||
|  |     print("dataset={:}".format(dataset)) | ||||||
|  |  | ||||||
|  |     model = init_instance_by_config(task_config["model"]) | ||||||
|  |  | ||||||
|  |     # start exp | ||||||
|  |     with R.start(experiment_name=experiment_name, recorder_name=recorder_name, uri=uri): | ||||||
|  |  | ||||||
|  |         log_file = R.get_recorder().root_uri / "{:}.log".format(experiment_name) | ||||||
|  |         set_log_basic_config(log_file) | ||||||
|  |  | ||||||
|  |         # train model | ||||||
|  |         R.log_params(**flatten_dict(task_config)) | ||||||
|  |         model.fit(dataset) | ||||||
|  |         recorder = R.get_recorder() | ||||||
|  |         R.save_objects(**{"model.pkl": model}) | ||||||
|  |  | ||||||
|  |         # generate records: prediction, backtest, and analysis | ||||||
|  |         for record in task_config["record"]: | ||||||
|  |             record = record.copy() | ||||||
|  |             if record["class"] == "SignalRecord": | ||||||
|  |                 srconf = {"model": model, "dataset": dataset, "recorder": recorder} | ||||||
|  |                 record["kwargs"].update(srconf) | ||||||
|  |                 sr = init_instance_by_config(record) | ||||||
|  |                 sr.generate() | ||||||
|  |             else: | ||||||
|  |                 rconf = {"recorder": recorder} | ||||||
|  |                 record["kwargs"].update(rconf) | ||||||
|  |                 ar = init_instance_by_config(record) | ||||||
|  |                 ar.generate() | ||||||
		Reference in New Issue
	
	Block a user