Allow different timeframes
This commit is contained in:
		 Submodule .latent-data/NATS-Bench updated: 391d4823b7...3a8794322f
									
								
							| @@ -13,13 +13,15 @@ | |||||||
| # python exps/trading/baselines.py --alg LightGBM   # | # python exps/trading/baselines.py --alg LightGBM   # | ||||||
| # python exps/trading/baselines.py --alg DoubleE    # | # python exps/trading/baselines.py --alg DoubleE    # | ||||||
| # python exps/trading/baselines.py --alg TabNet     # | # python exps/trading/baselines.py --alg TabNet     # | ||||||
| #                                                   # | #                                                   ############################# | ||||||
| # python exps/trading/baselines.py --alg Transformer# | # python exps/trading/baselines.py --alg Transformer | ||||||
| # python exps/trading/baselines.py --alg TSF | # python exps/trading/baselines.py --alg TSF | ||||||
| # python exps/trading/baselines.py --alg TSF-4x64-drop0_0 | # python exps/trading/baselines.py --alg TSF-2x24-drop0_0 --market csi300 | ||||||
| ##################################################### | # python exps/trading/baselines.py --alg TSF-6x32-drop0_0 --market csi300 | ||||||
|  | ################################################################################# | ||||||
| import sys | import sys | ||||||
| import copy | import copy | ||||||
|  | from datetime import datetime | ||||||
| import argparse | import argparse | ||||||
| from collections import OrderedDict | from collections import OrderedDict | ||||||
| from pathlib import Path | from pathlib import Path | ||||||
| @@ -60,7 +62,7 @@ def to_layer(config, embed_dim, depth): | |||||||
|  |  | ||||||
| def extend_transformer_settings(alg2configs, name): | def extend_transformer_settings(alg2configs, name): | ||||||
|     config = copy.deepcopy(alg2configs[name]) |     config = copy.deepcopy(alg2configs[name]) | ||||||
|     for i in range(1, 8): |     for i in range(1, 9): | ||||||
|         for j in (6, 12, 24, 32, 48, 64): |         for j in (6, 12, 24, 32, 48, 64): | ||||||
|             for k1 in (0, 0.1, 0.2): |             for k1 in (0, 0.1, 0.2): | ||||||
|                 for k2 in (0, 0.1): |                 for k2 in (0, 0.1): | ||||||
| @@ -70,6 +72,31 @@ def extend_transformer_settings(alg2configs, name): | |||||||
|     return alg2configs |     return alg2configs | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def replace_start_time(config, start_time): | ||||||
|  |     config = copy.deepcopy(config) | ||||||
|  |     xtime = datetime.strptime(start_time, "%Y-%m-%d") | ||||||
|  |     config["data_handler_config"]["start_time"] = xtime.date() | ||||||
|  |     config["data_handler_config"]["fit_start_time"] = xtime.date() | ||||||
|  |     config["task"]["dataset"]["kwargs"]["segments"]["train"][0] = xtime.date() | ||||||
|  |     return config | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def extend_train_data(alg2configs, name): | ||||||
|  |     config = copy.deepcopy(alg2configs[name]) | ||||||
|  |     start_times = ( | ||||||
|  |         "2008-01-01", | ||||||
|  |         "2009-01-01", | ||||||
|  |         "2010-01-01", | ||||||
|  |         "2011-01-01", | ||||||
|  |         "2012-01-01", | ||||||
|  |         "2013-01-01", | ||||||
|  |     ) | ||||||
|  |     for start_time in start_times: | ||||||
|  |         config = replace_start_time(config, start_time) | ||||||
|  |         alg2configs[name + "s{:}".format(start_time)] = config | ||||||
|  |     return alg2configs | ||||||
|  |  | ||||||
|  |  | ||||||
| def refresh_record(alg2configs): | def refresh_record(alg2configs): | ||||||
|     alg2configs = copy.deepcopy(alg2configs) |     alg2configs = copy.deepcopy(alg2configs) | ||||||
|     for key, config in alg2configs.items(): |     for key, config in alg2configs.items(): | ||||||
| @@ -133,6 +160,9 @@ def retrieve_configs(): | |||||||
|         ) |         ) | ||||||
|     alg2configs = extend_transformer_settings(alg2configs, "TSF") |     alg2configs = extend_transformer_settings(alg2configs, "TSF") | ||||||
|     alg2configs = refresh_record(alg2configs) |     alg2configs = refresh_record(alg2configs) | ||||||
|  |     # extend the algorithms by different train-data | ||||||
|  |     for name in ("TSF-2x24-drop0_0", "TSF-6x32-drop0_0"): | ||||||
|  |         alg2configs = extend_train_data(alg2configs, name) | ||||||
|     print( |     print( | ||||||
|         "There are {:} algorithms : {:}".format( |         "There are {:} algorithms : {:}".format( | ||||||
|             len(alg2configs), list(alg2configs.keys()) |             len(alg2configs), list(alg2configs.keys()) | ||||||
|   | |||||||
							
								
								
									
										26
									
								
								scripts/trade/tsf-time.sh
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										26
									
								
								scripts/trade/tsf-time.sh
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,26 @@ | |||||||
|  | #!/bin/bash | ||||||
|  | # | ||||||
|  | # bash scripts/trade/tsf-time.sh 0 csi300 TSF-2x24-drop0_0 | ||||||
|  | # bash scripts/trade/tsf-time.sh 1 csi100 | ||||||
|  | # bash scripts/trade/tsf-time.sh 1 all | ||||||
|  | # | ||||||
|  | set -e | ||||||
|  | echo script name: $0 | ||||||
|  | echo $# arguments | ||||||
|  |  | ||||||
|  | if [ "$#" -ne 3 ] ;then | ||||||
|  |   echo "Input illegal number of parameters " $# | ||||||
|  |   exit 1 | ||||||
|  | fi | ||||||
|  |  | ||||||
|  | gpu=$1 | ||||||
|  | market=$2 | ||||||
|  | base=$3 | ||||||
|  | xtimes="2008-01-01 2009-01-01 2010-01-01 2011-01-01 2012-01-01 2013-01-01" | ||||||
|  |  | ||||||
|  | for xtime in ${xtimes} | ||||||
|  | do | ||||||
|  |  | ||||||
|  |   python exps/trading/baselines.py --alg ${base}s${xtime} --gpu ${gpu} --market ${market} --shared_dataset False | ||||||
|  |  | ||||||
|  | done | ||||||
		Reference in New Issue
	
	Block a user