Allow different timeframes

This commit is contained in:
D-X-Y 2021-04-07 22:42:00 -07:00
parent 2595b11a8c
commit 028bc88430
3 changed files with 62 additions and 6 deletions

@ -1 +1 @@
Subproject commit 391d4823b70898cdd3b70045519d9cde42979ada
Subproject commit 3a8794322f0b990499a44db1b2cb05ef2bb33851

View File

@ -13,13 +13,15 @@
# python exps/trading/baselines.py --alg LightGBM #
# python exps/trading/baselines.py --alg DoubleE #
# python exps/trading/baselines.py --alg TabNet #
# #
# python exps/trading/baselines.py --alg Transformer#
# #############################
# python exps/trading/baselines.py --alg Transformer
# python exps/trading/baselines.py --alg TSF
# python exps/trading/baselines.py --alg TSF-4x64-drop0_0
#####################################################
# python exps/trading/baselines.py --alg TSF-2x24-drop0_0 --market csi300
# python exps/trading/baselines.py --alg TSF-6x32-drop0_0 --market csi300
#################################################################################
import sys
import copy
from datetime import datetime
import argparse
from collections import OrderedDict
from pathlib import Path
@ -60,7 +62,7 @@ def to_layer(config, embed_dim, depth):
def extend_transformer_settings(alg2configs, name):
config = copy.deepcopy(alg2configs[name])
for i in range(1, 8):
for i in range(1, 9):
for j in (6, 12, 24, 32, 48, 64):
for k1 in (0, 0.1, 0.2):
for k2 in (0, 0.1):
@ -70,6 +72,31 @@ def extend_transformer_settings(alg2configs, name):
return alg2configs
def replace_start_time(config, start_time):
config = copy.deepcopy(config)
xtime = datetime.strptime(start_time, "%Y-%m-%d")
config["data_handler_config"]["start_time"] = xtime.date()
config["data_handler_config"]["fit_start_time"] = xtime.date()
config["task"]["dataset"]["kwargs"]["segments"]["train"][0] = xtime.date()
return config
def extend_train_data(alg2configs, name):
config = copy.deepcopy(alg2configs[name])
start_times = (
"2008-01-01",
"2009-01-01",
"2010-01-01",
"2011-01-01",
"2012-01-01",
"2013-01-01",
)
for start_time in start_times:
config = replace_start_time(config, start_time)
alg2configs[name + "s{:}".format(start_time)] = config
return alg2configs
def refresh_record(alg2configs):
alg2configs = copy.deepcopy(alg2configs)
for key, config in alg2configs.items():
@ -133,6 +160,9 @@ def retrieve_configs():
)
alg2configs = extend_transformer_settings(alg2configs, "TSF")
alg2configs = refresh_record(alg2configs)
# extend the algorithms by different train-data
for name in ("TSF-2x24-drop0_0", "TSF-6x32-drop0_0"):
alg2configs = extend_train_data(alg2configs, name)
print(
"There are {:} algorithms : {:}".format(
len(alg2configs), list(alg2configs.keys())

26
scripts/trade/tsf-time.sh Normal file
View File

@ -0,0 +1,26 @@
#!/bin/bash
#
# bash scripts/trade/tsf-time.sh 0 csi300 TSF-2x24-drop0_0
# bash scripts/trade/tsf-time.sh 1 csi100
# bash scripts/trade/tsf-time.sh 1 all
#
set -e
echo script name: $0
echo $# arguments
if [ "$#" -ne 3 ] ;then
echo "Input illegal number of parameters " $#
exit 1
fi
gpu=$1
market=$2
base=$3
xtimes="2008-01-01 2009-01-01 2010-01-01 2011-01-01 2012-01-01 2013-01-01"
for xtime in ${xtimes}
do
python exps/trading/baselines.py --alg ${base}s${xtime} --gpu ${gpu} --market ${market} --shared_dataset False
done