Reformulate Q-Transformer
This commit is contained in:
		| @@ -4,7 +4,7 @@ | ||||
| # Refer to: | ||||
| # - https://github.com/microsoft/qlib/blob/main/examples/workflow_by_code.ipynb | ||||
| # - https://github.com/microsoft/qlib/blob/main/examples/workflow_by_code.py | ||||
| # python exps/trading/workflow_tt.py --market all --gpu 1 | ||||
| # python exps/trading/workflow_tt.py --gpu 1 --market csi300 | ||||
| ##################################################### | ||||
| import sys, argparse | ||||
| from pathlib import Path | ||||
| @@ -63,7 +63,8 @@ def main(xargs): | ||||
|         "class": "QuantTransformer", | ||||
|         "module_path": "trade_models", | ||||
|         "kwargs": { | ||||
|             "loss": "mse", | ||||
|             "net_config": None, | ||||
|             "opt_config": None, | ||||
|             "GPU": "0", | ||||
|             "metric": "loss", | ||||
|         }, | ||||
| @@ -107,20 +108,23 @@ def main(xargs): | ||||
|     provider_uri = "~/.qlib/qlib_data/cn_data" | ||||
|     qlib.init(provider_uri=provider_uri, region=REG_CN) | ||||
|  | ||||
|     save_dir = "{:}-{:}".format(xargs.save_dir, xargs.market) | ||||
|     dataset = init_instance_by_config(dataset_config) | ||||
|     for irun in range(xargs.times): | ||||
|         xmodel_config = model_config.copy() | ||||
|         xmodel_config = update_gpu(xmodel_config, xags.gpu) | ||||
|         task = dict(model=xmodel_config, dataset=dataset_config, record=record_config) | ||||
|         run_exp(task_config, dataset, "Transformer", "recorder-{:02d}-{:02d}".format(irun, xargs.times), xargs.save_dir) | ||||
|         xmodel_config = update_gpu(xmodel_config, xargs.gpu) | ||||
|         task_config = dict(model=xmodel_config, dataset=dataset_config, record=record_config) | ||||
|  | ||||
|         run_exp(task_config, dataset, xargs.name, "recorder-{:02d}-{:02d}".format(irun, xargs.times), save_dir) | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     parser = argparse.ArgumentParser("Vanilla Transformable Transformer") | ||||
|     parser.add_argument("--save_dir", type=str, default="./outputs/tt-ml-runs", help="The checkpoint directory.") | ||||
|     parser.add_argument("--save_dir", type=str, default="./outputs/vtt-runs", help="The checkpoint directory.") | ||||
|     parser.add_argument("--name", type=str, default="Transformer", help="The experiment name.") | ||||
|     parser.add_argument("--times", type=int, default=10, help="The repeated run times.") | ||||
|     parser.add_argument("--gpu", type=int, default=0, help="The GPU ID used for train / test.") | ||||
|     parser.add_argument("--market", type=str, default="csi300", help="The market indicator.") | ||||
|     parser.add_argument("--market", type=str, default="all", help="The market indicator.") | ||||
|     args = parser.parse_args() | ||||
|  | ||||
|     main(args) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user