diff --git a/exps/trading/baselines.py b/exps/trading/baselines.py index 182a7d3..8592375 100644 --- a/exps/trading/baselines.py +++ b/exps/trading/baselines.py @@ -59,7 +59,7 @@ def to_layer(config, embed_dim, depth): def extend_transformer_settings(alg2configs, name): config = copy.deepcopy(alg2configs[name]) for i in range(6): - for j in [24, 32, 48, 64]: + for j in [6, 12, 24, 32, 48, 64]: for k in [0, 0.1]: alg2configs[name + "-{:}x{:}-d{:}".format(i, j, k)] = to_layer( to_pos_drop(config, k), j, i @@ -104,7 +104,7 @@ def retrieve_configs(): idx, len(alg2configs), alg, path ) ) - alg2configs = extend_transformer_settings(alg2configs, "TSF-A") + alg2configs = extend_transformer_settings(alg2configs, "TSF") return alg2configs @@ -156,7 +156,7 @@ if __name__ == "__main__": parser.add_argument( "--alg", type=str, - choices=list(alg2paths.keys()), + choices=list(alg2configs.keys()), required=True, help="The algorithm name.", ) diff --git a/scripts/trade/tsf.sh b/scripts/trade/tsf.sh new file mode 100644 index 0000000..91a5233 --- /dev/null +++ b/scripts/trade/tsf.sh @@ -0,0 +1,27 @@ +#!/bin/bash +# +# bash scripts/trade/tsf.sh 0 csi300 3 +# bash scripts/trade/tsf.sh 1 csi100 3 +# bash scripts/trade/tsf.sh 1 all 3 +# +set -e +echo script name: $0 +echo $# arguments + +if [ "$#" -ne 3 ] ;then + echo "Input illegal number of parameters " $# + exit 1 +fi + +gpu=$1 +market=$2 +depth=$3 + +channels="6 12 24 32 48 64" + +for channel in ${channels} +do + + python exps/trading/baselines.py --alg TSF-${depth}x${channel}-d0 --gpu ${gpu} --market ${market} + +done