Re-org baselines
This commit is contained in:
parent
e7467fd474
commit
b85bfcbe41
@ -1 +1 @@
|
||||
Subproject commit ba56e4071efd1c08003eaf7e23978aaf81376dd1
|
||||
Subproject commit 3886022669912fbe875b71f652b439d3ab7f7ce2
|
@ -15,9 +15,11 @@
|
||||
# python exps/trading/baselines.py --alg TabNet #
|
||||
# #
|
||||
# python exps/trading/baselines.py --alg Transformer#
|
||||
# python exps/trading/baselines.py --alg TSF-A #
|
||||
# python exps/trading/baselines.py --alg TSF #
|
||||
# python exps/trading/baselines.py --alg TSF-4x64-d0
|
||||
#####################################################
|
||||
import sys
|
||||
import copy
|
||||
import argparse
|
||||
from collections import OrderedDict
|
||||
from pathlib import Path
|
||||
@ -38,6 +40,33 @@ from qlib.workflow import R
|
||||
from qlib.utils import flatten_dict
|
||||
|
||||
|
||||
def to_pos_drop(config, value):
|
||||
config = copy.deepcopy(config)
|
||||
net = config["task"]["model"]["kwargs"]["net_config"]
|
||||
net["pos_drop"] = value
|
||||
return config
|
||||
|
||||
|
||||
def to_layer(config, embed_dim, depth):
|
||||
config = copy.deepcopy(config)
|
||||
net = config["task"]["model"]["kwargs"]["net_config"]
|
||||
net["embed_dim"] = embed_dim
|
||||
net["num_heads"] = [4] * depth
|
||||
net["mlp_hidden_multipliers"] = [4] * depth
|
||||
return config
|
||||
|
||||
|
||||
def extend_transformer_settings(alg2configs, name):
|
||||
config = copy.deepcopy(alg2configs[name])
|
||||
for i in range(6):
|
||||
for j in [24, 32, 48, 64]:
|
||||
for k in [0, 0.1]:
|
||||
alg2configs[name + "-{:}x{:}-d{:}".format(i, j, k)] = to_layer(
|
||||
to_pos_drop(config, k), j, i
|
||||
)
|
||||
return alg2configs
|
||||
|
||||
|
||||
def retrieve_configs():
|
||||
# https://github.com/microsoft/qlib/blob/main/examples/benchmarks/
|
||||
config_dir = (lib_dir / ".." / "configs" / "qlib").resolve()
|
||||
@ -60,29 +89,28 @@ def retrieve_configs():
|
||||
alg2names["NAIVE-V1"] = "workflow_config_naive_v1_Alpha360.yaml"
|
||||
alg2names["NAIVE-V2"] = "workflow_config_naive_v2_Alpha360.yaml"
|
||||
alg2names["Transformer"] = "workflow_config_transformer_Alpha360.yaml"
|
||||
alg2names["TSF-A"] = "workflow_config_transformer_basic_Alpha360.yaml"
|
||||
alg2names["TSF"] = "workflow_config_transformer_basic_Alpha360.yaml"
|
||||
|
||||
# find the yaml paths
|
||||
alg2paths = OrderedDict()
|
||||
alg2configs = OrderedDict()
|
||||
print("Start retrieving the algorithm configurations")
|
||||
for idx, (alg, name) in enumerate(alg2names.items()):
|
||||
path = config_dir / name
|
||||
assert path.exists(), "{:} does not exist.".format(path)
|
||||
alg2paths[alg] = str(path)
|
||||
with open(path) as fp:
|
||||
alg2configs[alg] = yaml.safe_load(fp)
|
||||
print(
|
||||
"The {:02d}/{:02d}-th baseline algorithm is {:9s} ({:}).".format(
|
||||
idx, len(alg2names), alg, path
|
||||
idx, len(alg2configs), alg, path
|
||||
)
|
||||
)
|
||||
return alg2paths
|
||||
alg2configs = extend_transformer_settings(alg2configs, "TSF-A")
|
||||
return alg2configs
|
||||
|
||||
|
||||
def main(xargs, exp_yaml):
|
||||
assert Path(exp_yaml).exists(), "{:} does not exist.".format(exp_yaml)
|
||||
def main(xargs, config):
|
||||
|
||||
pprint("Run {:}".format(xargs.alg))
|
||||
with open(exp_yaml) as fp:
|
||||
config = yaml.safe_load(fp)
|
||||
config = update_market(config, xargs.market)
|
||||
config = update_gpu(config, xargs.gpu)
|
||||
|
||||
@ -105,7 +133,7 @@ def main(xargs, exp_yaml):
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
alg2paths = retrieve_configs()
|
||||
alg2configs = retrieve_configs()
|
||||
|
||||
parser = argparse.ArgumentParser("Baselines")
|
||||
parser.add_argument(
|
||||
@ -121,7 +149,7 @@ if __name__ == "__main__":
|
||||
choices=["csi100", "csi300", "all"],
|
||||
help="The market indicator.",
|
||||
)
|
||||
parser.add_argument("--times", type=int, default=10, help="The repeated run times.")
|
||||
parser.add_argument("--times", type=int, default=5, help="The repeated run times.")
|
||||
parser.add_argument(
|
||||
"--gpu", type=int, default=0, help="The GPU ID used for train / test."
|
||||
)
|
||||
@ -134,4 +162,4 @@ if __name__ == "__main__":
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
main(args, alg2paths[args.alg])
|
||||
main(args, alg2configs[args.alg])
|
||||
|
@ -126,7 +126,7 @@ def query_info(save_dir, verbose):
|
||||
experiments = R.list_experiments()
|
||||
|
||||
key_map = {
|
||||
"RMSE": "RMSE",
|
||||
# "RMSE": "RMSE",
|
||||
"IC": "IC",
|
||||
"ICIR": "ICIR",
|
||||
"Rank IC": "Rank_IC",
|
||||
|
@ -18,7 +18,7 @@ class SuperLayerNorm1D(SuperModule):
|
||||
"""Super Layer Norm."""
|
||||
|
||||
def __init__(
|
||||
self, dim: IntSpaceType, eps: float = 1e-5, elementwise_affine: bool = True
|
||||
self, dim: IntSpaceType, eps: float = 1e-6, elementwise_affine: bool = True
|
||||
) -> None:
|
||||
super(SuperLayerNorm1D, self).__init__()
|
||||
self._in_dim = dim
|
||||
|
Loading…
Reference in New Issue
Block a user