Allow different timeframes
This commit is contained in:
parent
2595b11a8c
commit
028bc88430
@ -1 +1 @@
|
||||
Subproject commit 391d4823b70898cdd3b70045519d9cde42979ada
|
||||
Subproject commit 3a8794322f0b990499a44db1b2cb05ef2bb33851
|
@ -13,13 +13,15 @@
|
||||
# python exps/trading/baselines.py --alg LightGBM #
|
||||
# python exps/trading/baselines.py --alg DoubleE #
|
||||
# python exps/trading/baselines.py --alg TabNet #
|
||||
# #
|
||||
# python exps/trading/baselines.py --alg Transformer#
|
||||
# #############################
|
||||
# python exps/trading/baselines.py --alg Transformer
|
||||
# python exps/trading/baselines.py --alg TSF
|
||||
# python exps/trading/baselines.py --alg TSF-4x64-drop0_0
|
||||
#####################################################
|
||||
# python exps/trading/baselines.py --alg TSF-2x24-drop0_0 --market csi300
|
||||
# python exps/trading/baselines.py --alg TSF-6x32-drop0_0 --market csi300
|
||||
#################################################################################
|
||||
import sys
|
||||
import copy
|
||||
from datetime import datetime
|
||||
import argparse
|
||||
from collections import OrderedDict
|
||||
from pathlib import Path
|
||||
@ -60,7 +62,7 @@ def to_layer(config, embed_dim, depth):
|
||||
|
||||
def extend_transformer_settings(alg2configs, name):
|
||||
config = copy.deepcopy(alg2configs[name])
|
||||
for i in range(1, 8):
|
||||
for i in range(1, 9):
|
||||
for j in (6, 12, 24, 32, 48, 64):
|
||||
for k1 in (0, 0.1, 0.2):
|
||||
for k2 in (0, 0.1):
|
||||
@ -70,6 +72,31 @@ def extend_transformer_settings(alg2configs, name):
|
||||
return alg2configs
|
||||
|
||||
|
||||
def replace_start_time(config, start_time):
|
||||
config = copy.deepcopy(config)
|
||||
xtime = datetime.strptime(start_time, "%Y-%m-%d")
|
||||
config["data_handler_config"]["start_time"] = xtime.date()
|
||||
config["data_handler_config"]["fit_start_time"] = xtime.date()
|
||||
config["task"]["dataset"]["kwargs"]["segments"]["train"][0] = xtime.date()
|
||||
return config
|
||||
|
||||
|
||||
def extend_train_data(alg2configs, name):
|
||||
config = copy.deepcopy(alg2configs[name])
|
||||
start_times = (
|
||||
"2008-01-01",
|
||||
"2009-01-01",
|
||||
"2010-01-01",
|
||||
"2011-01-01",
|
||||
"2012-01-01",
|
||||
"2013-01-01",
|
||||
)
|
||||
for start_time in start_times:
|
||||
config = replace_start_time(config, start_time)
|
||||
alg2configs[name + "s{:}".format(start_time)] = config
|
||||
return alg2configs
|
||||
|
||||
|
||||
def refresh_record(alg2configs):
|
||||
alg2configs = copy.deepcopy(alg2configs)
|
||||
for key, config in alg2configs.items():
|
||||
@ -133,6 +160,9 @@ def retrieve_configs():
|
||||
)
|
||||
alg2configs = extend_transformer_settings(alg2configs, "TSF")
|
||||
alg2configs = refresh_record(alg2configs)
|
||||
# extend the algorithms by different train-data
|
||||
for name in ("TSF-2x24-drop0_0", "TSF-6x32-drop0_0"):
|
||||
alg2configs = extend_train_data(alg2configs, name)
|
||||
print(
|
||||
"There are {:} algorithms : {:}".format(
|
||||
len(alg2configs), list(alg2configs.keys())
|
||||
|
26
scripts/trade/tsf-time.sh
Normal file
26
scripts/trade/tsf-time.sh
Normal file
@ -0,0 +1,26 @@
|
||||
#!/bin/bash
|
||||
#
|
||||
# bash scripts/trade/tsf-time.sh 0 csi300 TSF-2x24-drop0_0
|
||||
# bash scripts/trade/tsf-time.sh 1 csi100
|
||||
# bash scripts/trade/tsf-time.sh 1 all
|
||||
#
|
||||
set -e
|
||||
echo script name: $0
|
||||
echo $# arguments
|
||||
|
||||
if [ "$#" -ne 3 ] ;then
|
||||
echo "Input illegal number of parameters " $#
|
||||
exit 1
|
||||
fi
|
||||
|
||||
gpu=$1
|
||||
market=$2
|
||||
base=$3
|
||||
xtimes="2008-01-01 2009-01-01 2010-01-01 2011-01-01 2012-01-01 2013-01-01"
|
||||
|
||||
for xtime in ${xtimes}
|
||||
do
|
||||
|
||||
python exps/trading/baselines.py --alg ${base}s${xtime} --gpu ${gpu} --market ${market} --shared_dataset False
|
||||
|
||||
done
|
Loading…
Reference in New Issue
Block a user