Allow different timeframes
This commit is contained in:
		 Submodule .latent-data/NATS-Bench updated: 391d4823b7...3a8794322f
									
								
							| @@ -13,13 +13,15 @@ | ||||
| # python exps/trading/baselines.py --alg LightGBM   # | ||||
| # python exps/trading/baselines.py --alg DoubleE    # | ||||
| # python exps/trading/baselines.py --alg TabNet     # | ||||
| #                                                   # | ||||
| # python exps/trading/baselines.py --alg Transformer# | ||||
| #                                                   ############################# | ||||
| # python exps/trading/baselines.py --alg Transformer | ||||
| # python exps/trading/baselines.py --alg TSF | ||||
| # python exps/trading/baselines.py --alg TSF-4x64-drop0_0 | ||||
| ##################################################### | ||||
| # python exps/trading/baselines.py --alg TSF-2x24-drop0_0 --market csi300 | ||||
| # python exps/trading/baselines.py --alg TSF-6x32-drop0_0 --market csi300 | ||||
| ################################################################################# | ||||
| import sys | ||||
| import copy | ||||
| from datetime import datetime | ||||
| import argparse | ||||
| from collections import OrderedDict | ||||
| from pathlib import Path | ||||
| @@ -60,7 +62,7 @@ def to_layer(config, embed_dim, depth): | ||||
|  | ||||
| def extend_transformer_settings(alg2configs, name): | ||||
|     config = copy.deepcopy(alg2configs[name]) | ||||
|     for i in range(1, 8): | ||||
|     for i in range(1, 9): | ||||
|         for j in (6, 12, 24, 32, 48, 64): | ||||
|             for k1 in (0, 0.1, 0.2): | ||||
|                 for k2 in (0, 0.1): | ||||
| @@ -70,6 +72,31 @@ def extend_transformer_settings(alg2configs, name): | ||||
|     return alg2configs | ||||
|  | ||||
|  | ||||
| def replace_start_time(config, start_time): | ||||
|     config = copy.deepcopy(config) | ||||
|     xtime = datetime.strptime(start_time, "%Y-%m-%d") | ||||
|     config["data_handler_config"]["start_time"] = xtime.date() | ||||
|     config["data_handler_config"]["fit_start_time"] = xtime.date() | ||||
|     config["task"]["dataset"]["kwargs"]["segments"]["train"][0] = xtime.date() | ||||
|     return config | ||||
|  | ||||
|  | ||||
| def extend_train_data(alg2configs, name): | ||||
|     config = copy.deepcopy(alg2configs[name]) | ||||
|     start_times = ( | ||||
|         "2008-01-01", | ||||
|         "2009-01-01", | ||||
|         "2010-01-01", | ||||
|         "2011-01-01", | ||||
|         "2012-01-01", | ||||
|         "2013-01-01", | ||||
|     ) | ||||
|     for start_time in start_times: | ||||
|         config = replace_start_time(config, start_time) | ||||
|         alg2configs[name + "s{:}".format(start_time)] = config | ||||
|     return alg2configs | ||||
|  | ||||
|  | ||||
| def refresh_record(alg2configs): | ||||
|     alg2configs = copy.deepcopy(alg2configs) | ||||
|     for key, config in alg2configs.items(): | ||||
| @@ -133,6 +160,9 @@ def retrieve_configs(): | ||||
|         ) | ||||
|     alg2configs = extend_transformer_settings(alg2configs, "TSF") | ||||
|     alg2configs = refresh_record(alg2configs) | ||||
|     # extend the algorithms by different train-data | ||||
|     for name in ("TSF-2x24-drop0_0", "TSF-6x32-drop0_0"): | ||||
|         alg2configs = extend_train_data(alg2configs, name) | ||||
|     print( | ||||
|         "There are {:} algorithms : {:}".format( | ||||
|             len(alg2configs), list(alg2configs.keys()) | ||||
|   | ||||
							
								
								
									
										26
									
								
								scripts/trade/tsf-time.sh
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										26
									
								
								scripts/trade/tsf-time.sh
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,26 @@ | ||||
| #!/bin/bash | ||||
| # | ||||
| # bash scripts/trade/tsf-time.sh 0 csi300 TSF-2x24-drop0_0 | ||||
| # bash scripts/trade/tsf-time.sh 1 csi100 | ||||
| # bash scripts/trade/tsf-time.sh 1 all | ||||
| # | ||||
| set -e | ||||
| echo script name: $0 | ||||
| echo $# arguments | ||||
|  | ||||
| if [ "$#" -ne 3 ] ;then | ||||
|   echo "Input illegal number of parameters " $# | ||||
|   exit 1 | ||||
| fi | ||||
|  | ||||
| gpu=$1 | ||||
| market=$2 | ||||
| base=$3 | ||||
| xtimes="2008-01-01 2009-01-01 2010-01-01 2011-01-01 2012-01-01 2013-01-01" | ||||
|  | ||||
| for xtime in ${xtimes} | ||||
| do | ||||
|  | ||||
|   python exps/trading/baselines.py --alg ${base}s${xtime} --gpu ${gpu} --market ${market} --shared_dataset False | ||||
|  | ||||
| done | ||||
		Reference in New Issue
	
	Block a user