From 028bc88430974104b3386253c5953caf614f9c0c Mon Sep 17 00:00:00 2001 From: D-X-Y <280835372@qq.com> Date: Wed, 7 Apr 2021 22:42:00 -0700 Subject: [PATCH] Allow different timeframes --- .latent-data/NATS-Bench | 2 +- exps/trading/baselines.py | 40 ++++++++++++++++++++++++++++++++++----- scripts/trade/tsf-time.sh | 26 +++++++++++++++++++++++++ 3 files changed, 62 insertions(+), 6 deletions(-) create mode 100644 scripts/trade/tsf-time.sh diff --git a/.latent-data/NATS-Bench b/.latent-data/NATS-Bench index 391d482..3a87943 160000 --- a/.latent-data/NATS-Bench +++ b/.latent-data/NATS-Bench @@ -1 +1 @@ -Subproject commit 391d4823b70898cdd3b70045519d9cde42979ada +Subproject commit 3a8794322f0b990499a44db1b2cb05ef2bb33851 diff --git a/exps/trading/baselines.py b/exps/trading/baselines.py index 130792d..cf42cf7 100644 --- a/exps/trading/baselines.py +++ b/exps/trading/baselines.py @@ -13,13 +13,15 @@ # python exps/trading/baselines.py --alg LightGBM # # python exps/trading/baselines.py --alg DoubleE # # python exps/trading/baselines.py --alg TabNet # -# # -# python exps/trading/baselines.py --alg Transformer# +# ############################# +# python exps/trading/baselines.py --alg Transformer # python exps/trading/baselines.py --alg TSF -# python exps/trading/baselines.py --alg TSF-4x64-drop0_0 -##################################################### +# python exps/trading/baselines.py --alg TSF-2x24-drop0_0 --market csi300 +# python exps/trading/baselines.py --alg TSF-6x32-drop0_0 --market csi300 +################################################################################# import sys import copy +from datetime import datetime import argparse from collections import OrderedDict from pathlib import Path @@ -60,7 +62,7 @@ def to_layer(config, embed_dim, depth): def extend_transformer_settings(alg2configs, name): config = copy.deepcopy(alg2configs[name]) - for i in range(1, 8): + for i in range(1, 9): for j in (6, 12, 24, 32, 48, 64): for k1 in (0, 0.1, 0.2): for k2 in (0, 0.1): @@ -70,6 +72,31 @@ def extend_transformer_settings(alg2configs, name): return alg2configs +def replace_start_time(config, start_time): + config = copy.deepcopy(config) + xtime = datetime.strptime(start_time, "%Y-%m-%d") + config["data_handler_config"]["start_time"] = xtime.date() + config["data_handler_config"]["fit_start_time"] = xtime.date() + config["task"]["dataset"]["kwargs"]["segments"]["train"][0] = xtime.date() + return config + + +def extend_train_data(alg2configs, name): + config = copy.deepcopy(alg2configs[name]) + start_times = ( + "2008-01-01", + "2009-01-01", + "2010-01-01", + "2011-01-01", + "2012-01-01", + "2013-01-01", + ) + for start_time in start_times: + config = replace_start_time(config, start_time) + alg2configs[name + "s{:}".format(start_time)] = config + return alg2configs + + def refresh_record(alg2configs): alg2configs = copy.deepcopy(alg2configs) for key, config in alg2configs.items(): @@ -133,6 +160,9 @@ def retrieve_configs(): ) alg2configs = extend_transformer_settings(alg2configs, "TSF") alg2configs = refresh_record(alg2configs) + # extend the algorithms by different train-data + for name in ("TSF-2x24-drop0_0", "TSF-6x32-drop0_0"): + alg2configs = extend_train_data(alg2configs, name) print( "There are {:} algorithms : {:}".format( len(alg2configs), list(alg2configs.keys()) diff --git a/scripts/trade/tsf-time.sh b/scripts/trade/tsf-time.sh new file mode 100644 index 0000000..0687a63 --- /dev/null +++ b/scripts/trade/tsf-time.sh @@ -0,0 +1,26 @@ +#!/bin/bash +# +# bash scripts/trade/tsf-time.sh 0 csi300 TSF-2x24-drop0_0 +# bash scripts/trade/tsf-time.sh 1 csi100 +# bash scripts/trade/tsf-time.sh 1 all +# +set -e +echo script name: $0 +echo $# arguments + +if [ "$#" -ne 3 ] ;then + echo "Input illegal number of parameters " $# + exit 1 +fi + +gpu=$1 +market=$2 +base=$3 +xtimes="2008-01-01 2009-01-01 2010-01-01 2011-01-01 2012-01-01 2013-01-01" + +for xtime in ${xtimes} +do + + python exps/trading/baselines.py --alg ${base}s${xtime} --gpu ${gpu} --market ${market} --shared_dataset False + +done