From 9f7eca0e581e538fac9c843cd5b171139f70af7a Mon Sep 17 00:00:00 2001 From: D-X-Y <280835372@qq.com> Date: Thu, 11 Mar 2021 13:07:08 +0000 Subject: [PATCH] Add scripts for Q --- .latent-data/qlib | 2 +- exps/trading/baselines.py | 3 ++- lib/procedures/q_exps.py | 2 +- lib/trade_models/quant_transformer.py | 4 ++-- scripts/trade/baseline.sh | 22 ++++++++++++++++++++++ 5 files changed, 28 insertions(+), 5 deletions(-) create mode 100644 scripts/trade/baseline.sh diff --git a/.latent-data/qlib b/.latent-data/qlib index 0ef7c8e..e626264 160000 --- a/.latent-data/qlib +++ b/.latent-data/qlib @@ -1 +1 @@ -Subproject commit 0ef7c8e0e62f1b08f89f676eac97cbbd2bcc1657 +Subproject commit e626264d5aebfb28295abd00d44827621bec33a5 diff --git a/exps/trading/baselines.py b/exps/trading/baselines.py index 2cfb696..8646bf9 100644 --- a/exps/trading/baselines.py +++ b/exps/trading/baselines.py @@ -1,10 +1,11 @@ ##################################################### # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.02 # ##################################################### +# python exps/trading/baselines.py --alg MLP # # python exps/trading/baselines.py --alg GRU # # python exps/trading/baselines.py --alg LSTM # # python exps/trading/baselines.py --alg ALSTM # -# python exps/trading/baselines.py --alg MLP # +# # # python exps/trading/baselines.py --alg SFM # # python exps/trading/baselines.py --alg XGBoost # # python exps/trading/baselines.py --alg LightGBM # diff --git a/lib/procedures/q_exps.py b/lib/procedures/q_exps.py index bb26ff4..197b1d1 100644 --- a/lib/procedures/q_exps.py +++ b/lib/procedures/q_exps.py @@ -80,7 +80,7 @@ def run_exp(task_config, dataset, experiment_name, recorder_name, uri): # Train model R.log_params(**flatten_dict(task_config)) if "save_path" in inspect.getfullargspec(model.fit).args: - model_fit_kwargs["save_path"] = os.path.join(recorder_root_dir, "model-ckps") + model_fit_kwargs["save_path"] = os.path.join(recorder_root_dir, "model.ckps") model.fit(**model_fit_kwargs) # Get the recorder recorder = R.get_recorder() diff --git a/lib/trade_models/quant_transformer.py b/lib/trade_models/quant_transformer.py index ea236f3..7e94028 100755 --- a/lib/trade_models/quant_transformer.py +++ b/lib/trade_models/quant_transformer.py @@ -17,7 +17,7 @@ import logging from qlib.utils import ( unpack_archive_with_buffer, save_multiple_parts_file, - create_save_path, + get_or_create_path, drop_nan_by_y_index, ) from qlib.log import get_module_logger, TimeInspector @@ -176,7 +176,7 @@ class QuantTransformer(Model): _prepare_loader(test_dataset, False), ) - save_path = create_save_path(save_path) + save_path = get_or_create_path(save_path) self.logger.info("Fit procedure for [{:}] with save path={:}".format(self.__class__.__name__, save_path)) def _internal_test(ckp_epoch=None, results_dict=None): diff --git a/scripts/trade/baseline.sh b/scripts/trade/baseline.sh new file mode 100644 index 0000000..31565e1 --- /dev/null +++ b/scripts/trade/baseline.sh @@ -0,0 +1,22 @@ +#!/bin/bash +# bash scripts/trade/baseline.sh 0 csi300 +set -e +echo script name: $0 +echo $# arguments + +if [ "$#" -ne 2 ] ;then + echo "Input illegal number of parameters " $# + exit 1 +fi + +gpu=$1 +market=$2 + +algorithms="MLP GRU LSTM ALSTM XGBoost LightGBM" + +for alg in ${algorithms} +do + + python exps/trading/baselines.py --alg ${alg} --gpu ${gpu} --market ${market} + +done