Allow different timeframes

This commit is contained in:
D-X-Y 2021-04-07 22:42:00 -07:00
parent 2595b11a8c
commit 028bc88430
3 changed files with 62 additions and 6 deletions

@ -1 +1 @@
Subproject commit 391d4823b70898cdd3b70045519d9cde42979ada Subproject commit 3a8794322f0b990499a44db1b2cb05ef2bb33851

View File

@ -13,13 +13,15 @@
# python exps/trading/baselines.py --alg LightGBM # # python exps/trading/baselines.py --alg LightGBM #
# python exps/trading/baselines.py --alg DoubleE # # python exps/trading/baselines.py --alg DoubleE #
# python exps/trading/baselines.py --alg TabNet # # python exps/trading/baselines.py --alg TabNet #
# # # #############################
# python exps/trading/baselines.py --alg Transformer# # python exps/trading/baselines.py --alg Transformer
# python exps/trading/baselines.py --alg TSF # python exps/trading/baselines.py --alg TSF
# python exps/trading/baselines.py --alg TSF-4x64-drop0_0 # python exps/trading/baselines.py --alg TSF-2x24-drop0_0 --market csi300
##################################################### # python exps/trading/baselines.py --alg TSF-6x32-drop0_0 --market csi300
#################################################################################
import sys import sys
import copy import copy
from datetime import datetime
import argparse import argparse
from collections import OrderedDict from collections import OrderedDict
from pathlib import Path from pathlib import Path
@ -60,7 +62,7 @@ def to_layer(config, embed_dim, depth):
def extend_transformer_settings(alg2configs, name): def extend_transformer_settings(alg2configs, name):
config = copy.deepcopy(alg2configs[name]) config = copy.deepcopy(alg2configs[name])
for i in range(1, 8): for i in range(1, 9):
for j in (6, 12, 24, 32, 48, 64): for j in (6, 12, 24, 32, 48, 64):
for k1 in (0, 0.1, 0.2): for k1 in (0, 0.1, 0.2):
for k2 in (0, 0.1): for k2 in (0, 0.1):
@ -70,6 +72,31 @@ def extend_transformer_settings(alg2configs, name):
return alg2configs return alg2configs
def replace_start_time(config, start_time):
config = copy.deepcopy(config)
xtime = datetime.strptime(start_time, "%Y-%m-%d")
config["data_handler_config"]["start_time"] = xtime.date()
config["data_handler_config"]["fit_start_time"] = xtime.date()
config["task"]["dataset"]["kwargs"]["segments"]["train"][0] = xtime.date()
return config
def extend_train_data(alg2configs, name):
config = copy.deepcopy(alg2configs[name])
start_times = (
"2008-01-01",
"2009-01-01",
"2010-01-01",
"2011-01-01",
"2012-01-01",
"2013-01-01",
)
for start_time in start_times:
config = replace_start_time(config, start_time)
alg2configs[name + "s{:}".format(start_time)] = config
return alg2configs
def refresh_record(alg2configs): def refresh_record(alg2configs):
alg2configs = copy.deepcopy(alg2configs) alg2configs = copy.deepcopy(alg2configs)
for key, config in alg2configs.items(): for key, config in alg2configs.items():
@ -133,6 +160,9 @@ def retrieve_configs():
) )
alg2configs = extend_transformer_settings(alg2configs, "TSF") alg2configs = extend_transformer_settings(alg2configs, "TSF")
alg2configs = refresh_record(alg2configs) alg2configs = refresh_record(alg2configs)
# extend the algorithms by different train-data
for name in ("TSF-2x24-drop0_0", "TSF-6x32-drop0_0"):
alg2configs = extend_train_data(alg2configs, name)
print( print(
"There are {:} algorithms : {:}".format( "There are {:} algorithms : {:}".format(
len(alg2configs), list(alg2configs.keys()) len(alg2configs), list(alg2configs.keys())

26
scripts/trade/tsf-time.sh Normal file
View File

@ -0,0 +1,26 @@
#!/bin/bash
#
# bash scripts/trade/tsf-time.sh 0 csi300 TSF-2x24-drop0_0
# bash scripts/trade/tsf-time.sh 1 csi100
# bash scripts/trade/tsf-time.sh 1 all
#
set -e
echo script name: $0
echo $# arguments
if [ "$#" -ne 3 ] ;then
echo "Input illegal number of parameters " $#
exit 1
fi
gpu=$1
market=$2
base=$3
xtimes="2008-01-01 2009-01-01 2010-01-01 2011-01-01 2012-01-01 2013-01-01"
for xtime in ${xtimes}
do
python exps/trading/baselines.py --alg ${base}s${xtime} --gpu ${gpu} --market ${market} --shared_dataset False
done