Update scripts

This commit is contained in:
D-X-Y 2021-03-25 20:41:22 -07:00
parent feeac82cbb
commit 37797177f8
2 changed files with 30 additions and 3 deletions

View File

@ -59,7 +59,7 @@ def to_layer(config, embed_dim, depth):
def extend_transformer_settings(alg2configs, name):
config = copy.deepcopy(alg2configs[name])
for i in range(6):
for j in [24, 32, 48, 64]:
for j in [6, 12, 24, 32, 48, 64]:
for k in [0, 0.1]:
alg2configs[name + "-{:}x{:}-d{:}".format(i, j, k)] = to_layer(
to_pos_drop(config, k), j, i
@ -104,7 +104,7 @@ def retrieve_configs():
idx, len(alg2configs), alg, path
)
)
alg2configs = extend_transformer_settings(alg2configs, "TSF-A")
alg2configs = extend_transformer_settings(alg2configs, "TSF")
return alg2configs
@ -156,7 +156,7 @@ if __name__ == "__main__":
parser.add_argument(
"--alg",
type=str,
choices=list(alg2paths.keys()),
choices=list(alg2configs.keys()),
required=True,
help="The algorithm name.",
)

27
scripts/trade/tsf.sh Normal file
View File

@ -0,0 +1,27 @@
#!/bin/bash
#
# bash scripts/trade/tsf.sh 0 csi300 3
# bash scripts/trade/tsf.sh 1 csi100 3
# bash scripts/trade/tsf.sh 1 all 3
#
set -e
echo script name: $0
echo $# arguments
if [ "$#" -ne 3 ] ;then
echo "Input illegal number of parameters " $#
exit 1
fi
gpu=$1
market=$2
depth=$3
channels="6 12 24 32 48 64"
for channel in ${channels}
do
python exps/trading/baselines.py --alg TSF-${depth}x${channel}-d0 --gpu ${gpu} --market ${market}
done