Allow different timeframes
This commit is contained in:
parent
2595b11a8c
commit
028bc88430
@ -1 +1 @@
|
|||||||
Subproject commit 391d4823b70898cdd3b70045519d9cde42979ada
|
Subproject commit 3a8794322f0b990499a44db1b2cb05ef2bb33851
|
@ -13,13 +13,15 @@
|
|||||||
# python exps/trading/baselines.py --alg LightGBM #
|
# python exps/trading/baselines.py --alg LightGBM #
|
||||||
# python exps/trading/baselines.py --alg DoubleE #
|
# python exps/trading/baselines.py --alg DoubleE #
|
||||||
# python exps/trading/baselines.py --alg TabNet #
|
# python exps/trading/baselines.py --alg TabNet #
|
||||||
# #
|
# #############################
|
||||||
# python exps/trading/baselines.py --alg Transformer#
|
# python exps/trading/baselines.py --alg Transformer
|
||||||
# python exps/trading/baselines.py --alg TSF
|
# python exps/trading/baselines.py --alg TSF
|
||||||
# python exps/trading/baselines.py --alg TSF-4x64-drop0_0
|
# python exps/trading/baselines.py --alg TSF-2x24-drop0_0 --market csi300
|
||||||
#####################################################
|
# python exps/trading/baselines.py --alg TSF-6x32-drop0_0 --market csi300
|
||||||
|
#################################################################################
|
||||||
import sys
|
import sys
|
||||||
import copy
|
import copy
|
||||||
|
from datetime import datetime
|
||||||
import argparse
|
import argparse
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
@ -60,7 +62,7 @@ def to_layer(config, embed_dim, depth):
|
|||||||
|
|
||||||
def extend_transformer_settings(alg2configs, name):
|
def extend_transformer_settings(alg2configs, name):
|
||||||
config = copy.deepcopy(alg2configs[name])
|
config = copy.deepcopy(alg2configs[name])
|
||||||
for i in range(1, 8):
|
for i in range(1, 9):
|
||||||
for j in (6, 12, 24, 32, 48, 64):
|
for j in (6, 12, 24, 32, 48, 64):
|
||||||
for k1 in (0, 0.1, 0.2):
|
for k1 in (0, 0.1, 0.2):
|
||||||
for k2 in (0, 0.1):
|
for k2 in (0, 0.1):
|
||||||
@ -70,6 +72,31 @@ def extend_transformer_settings(alg2configs, name):
|
|||||||
return alg2configs
|
return alg2configs
|
||||||
|
|
||||||
|
|
||||||
|
def replace_start_time(config, start_time):
|
||||||
|
config = copy.deepcopy(config)
|
||||||
|
xtime = datetime.strptime(start_time, "%Y-%m-%d")
|
||||||
|
config["data_handler_config"]["start_time"] = xtime.date()
|
||||||
|
config["data_handler_config"]["fit_start_time"] = xtime.date()
|
||||||
|
config["task"]["dataset"]["kwargs"]["segments"]["train"][0] = xtime.date()
|
||||||
|
return config
|
||||||
|
|
||||||
|
|
||||||
|
def extend_train_data(alg2configs, name):
|
||||||
|
config = copy.deepcopy(alg2configs[name])
|
||||||
|
start_times = (
|
||||||
|
"2008-01-01",
|
||||||
|
"2009-01-01",
|
||||||
|
"2010-01-01",
|
||||||
|
"2011-01-01",
|
||||||
|
"2012-01-01",
|
||||||
|
"2013-01-01",
|
||||||
|
)
|
||||||
|
for start_time in start_times:
|
||||||
|
config = replace_start_time(config, start_time)
|
||||||
|
alg2configs[name + "s{:}".format(start_time)] = config
|
||||||
|
return alg2configs
|
||||||
|
|
||||||
|
|
||||||
def refresh_record(alg2configs):
|
def refresh_record(alg2configs):
|
||||||
alg2configs = copy.deepcopy(alg2configs)
|
alg2configs = copy.deepcopy(alg2configs)
|
||||||
for key, config in alg2configs.items():
|
for key, config in alg2configs.items():
|
||||||
@ -133,6 +160,9 @@ def retrieve_configs():
|
|||||||
)
|
)
|
||||||
alg2configs = extend_transformer_settings(alg2configs, "TSF")
|
alg2configs = extend_transformer_settings(alg2configs, "TSF")
|
||||||
alg2configs = refresh_record(alg2configs)
|
alg2configs = refresh_record(alg2configs)
|
||||||
|
# extend the algorithms by different train-data
|
||||||
|
for name in ("TSF-2x24-drop0_0", "TSF-6x32-drop0_0"):
|
||||||
|
alg2configs = extend_train_data(alg2configs, name)
|
||||||
print(
|
print(
|
||||||
"There are {:} algorithms : {:}".format(
|
"There are {:} algorithms : {:}".format(
|
||||||
len(alg2configs), list(alg2configs.keys())
|
len(alg2configs), list(alg2configs.keys())
|
||||||
|
26
scripts/trade/tsf-time.sh
Normal file
26
scripts/trade/tsf-time.sh
Normal file
@ -0,0 +1,26 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
#
|
||||||
|
# bash scripts/trade/tsf-time.sh 0 csi300 TSF-2x24-drop0_0
|
||||||
|
# bash scripts/trade/tsf-time.sh 1 csi100
|
||||||
|
# bash scripts/trade/tsf-time.sh 1 all
|
||||||
|
#
|
||||||
|
set -e
|
||||||
|
echo script name: $0
|
||||||
|
echo $# arguments
|
||||||
|
|
||||||
|
if [ "$#" -ne 3 ] ;then
|
||||||
|
echo "Input illegal number of parameters " $#
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
gpu=$1
|
||||||
|
market=$2
|
||||||
|
base=$3
|
||||||
|
xtimes="2008-01-01 2009-01-01 2010-01-01 2011-01-01 2012-01-01 2013-01-01"
|
||||||
|
|
||||||
|
for xtime in ${xtimes}
|
||||||
|
do
|
||||||
|
|
||||||
|
python exps/trading/baselines.py --alg ${base}s${xtime} --gpu ${gpu} --market ${market} --shared_dataset False
|
||||||
|
|
||||||
|
done
|
Loading…
Reference in New Issue
Block a user