Update scripts
This commit is contained in:
		| @@ -59,7 +59,7 @@ def to_layer(config, embed_dim, depth): | |||||||
| def extend_transformer_settings(alg2configs, name): | def extend_transformer_settings(alg2configs, name): | ||||||
|     config = copy.deepcopy(alg2configs[name]) |     config = copy.deepcopy(alg2configs[name]) | ||||||
|     for i in range(6): |     for i in range(6): | ||||||
|         for j in [24, 32, 48, 64]: |         for j in [6, 12, 24, 32, 48, 64]: | ||||||
|             for k in [0, 0.1]: |             for k in [0, 0.1]: | ||||||
|                 alg2configs[name + "-{:}x{:}-d{:}".format(i, j, k)] = to_layer( |                 alg2configs[name + "-{:}x{:}-d{:}".format(i, j, k)] = to_layer( | ||||||
|                     to_pos_drop(config, k), j, i |                     to_pos_drop(config, k), j, i | ||||||
| @@ -104,7 +104,7 @@ def retrieve_configs(): | |||||||
|                 idx, len(alg2configs), alg, path |                 idx, len(alg2configs), alg, path | ||||||
|             ) |             ) | ||||||
|         ) |         ) | ||||||
|     alg2configs = extend_transformer_settings(alg2configs, "TSF-A") |     alg2configs = extend_transformer_settings(alg2configs, "TSF") | ||||||
|     return alg2configs |     return alg2configs | ||||||
|  |  | ||||||
|  |  | ||||||
| @@ -156,7 +156,7 @@ if __name__ == "__main__": | |||||||
|     parser.add_argument( |     parser.add_argument( | ||||||
|         "--alg", |         "--alg", | ||||||
|         type=str, |         type=str, | ||||||
|         choices=list(alg2paths.keys()), |         choices=list(alg2configs.keys()), | ||||||
|         required=True, |         required=True, | ||||||
|         help="The algorithm name.", |         help="The algorithm name.", | ||||||
|     ) |     ) | ||||||
|   | |||||||
							
								
								
									
										27
									
								
								scripts/trade/tsf.sh
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										27
									
								
								scripts/trade/tsf.sh
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,27 @@ | |||||||
|  | #!/bin/bash | ||||||
|  | # | ||||||
|  | # bash scripts/trade/tsf.sh 0 csi300 3 | ||||||
|  | # bash scripts/trade/tsf.sh 1 csi100 3 | ||||||
|  | # bash scripts/trade/tsf.sh 1 all    3 | ||||||
|  | # | ||||||
|  | set -e | ||||||
|  | echo script name: $0 | ||||||
|  | echo $# arguments | ||||||
|  |  | ||||||
|  | if [ "$#" -ne 3 ] ;then | ||||||
|  |   echo "Input illegal number of parameters " $# | ||||||
|  |   exit 1 | ||||||
|  | fi | ||||||
|  |  | ||||||
|  | gpu=$1 | ||||||
|  | market=$2 | ||||||
|  | depth=$3 | ||||||
|  |  | ||||||
|  | channels="6 12 24 32 48 64" | ||||||
|  |  | ||||||
|  | for channel in ${channels} | ||||||
|  | do | ||||||
|  |  | ||||||
|  |   python exps/trading/baselines.py --alg TSF-${depth}x${channel}-d0 --gpu ${gpu} --market ${market} | ||||||
|  |  | ||||||
|  | done | ||||||
		Reference in New Issue
	
	Block a user