diff --git a/configs/qlib/workflow_config_lightgbm_Alpha360.yaml b/configs/qlib/workflow_config_lightgbm_Alpha360.yaml index 04e9d1d..4816ae2 100644 --- a/configs/qlib/workflow_config_lightgbm_Alpha360.yaml +++ b/configs/qlib/workflow_config_lightgbm_Alpha360.yaml @@ -10,7 +10,7 @@ data_handler_config: &data_handler_config fit_end_time: 2014-12-31 instruments: *market infer_processors: - - class: RobustZScoreNorm + - class: RobustZScoreNorm kwargs: fields_group: feature clip_outlier: true diff --git a/configs/qlib/workflow_config_transformer_Alpha360.yaml b/configs/qlib/workflow_config_transformer_Alpha360.yaml new file mode 100644 index 0000000..36a3c28 --- /dev/null +++ b/configs/qlib/workflow_config_transformer_Alpha360.yaml @@ -0,0 +1,78 @@ +qlib_init: + provider_uri: "~/.qlib/qlib_data/cn_data" + region: cn +market: &market all +benchmark: &benchmark SH000300 +data_handler_config: &data_handler_config + start_time: 2008-01-01 + end_time: 2020-08-01 + fit_start_time: 2008-01-01 + fit_end_time: 2014-12-31 + instruments: *market + infer_processors: + - class: RobustZScoreNorm + kwargs: + fields_group: feature + clip_outlier: true + - class: Fillna + kwargs: + fields_group: feature + learn_processors: + - class: DropnaLabel + - class: CSRankNorm + kwargs: + fields_group: label + label: ["Ref($close, -2) / Ref($close, -1) - 1"] +port_analysis_config: &port_analysis_config + strategy: + class: TopkDropoutStrategy + module_path: qlib.contrib.strategy.strategy + kwargs: + topk: 50 + n_drop: 5 + backtest: + verbose: False + limit_threshold: 0.095 + account: 100000000 + benchmark: *benchmark + deal_price: close + open_cost: 0.0005 + close_cost: 0.0015 + min_cost: 5 +task: + model: + class: QuantTransformer + module_path: trade_models.quant_transformer + kwargs: + net_config: + opt_config: + loss: mse + GPU: 0 + dataset: + class: DatasetH + module_path: qlib.data.dataset + kwargs: + handler: + class: Alpha360 + module_path: qlib.contrib.data.handler + kwargs: *data_handler_config + segments: + train: [2008-01-01, 2014-12-31] + valid: [2015-01-01, 2016-12-31] + test: [2017-01-01, 2020-08-01] + record: + - class: SignalRecord + module_path: qlib.workflow.record_temp + kwargs: {} + - class: SignalMseRecord + module_path: qlib.contrib.workflow.record_temp + kwargs: {} + - class: SigAnaRecord + module_path: qlib.workflow.record_temp + kwargs: + ana_long_short: False + ann_scaler: 252 + - class: PortAnaRecord + module_path: qlib.workflow.record_temp + kwargs: + config: *port_analysis_config diff --git a/configs/qlib/workflow_config_xgboost_Alpha360.yaml b/configs/qlib/workflow_config_xgboost_Alpha360.yaml index 9e5ef7b..2fa21f3 100644 --- a/configs/qlib/workflow_config_xgboost_Alpha360.yaml +++ b/configs/qlib/workflow_config_xgboost_Alpha360.yaml @@ -10,7 +10,7 @@ data_handler_config: &data_handler_config fit_end_time: 2014-12-31 instruments: *market infer_processors: - - class: RobustZScoreNorm + - class: RobustZScoreNorm kwargs: fields_group: feature clip_outlier: true diff --git a/exps/trading/baselines.py b/exps/trading/baselines.py index f3501c3..027a691 100644 --- a/exps/trading/baselines.py +++ b/exps/trading/baselines.py @@ -56,6 +56,7 @@ def retrieve_configs(): alg2names["TabNet"] = "workflow_config_TabNet_Alpha360.yaml" alg2names["NAIVE-V1"] = "workflow_config_naive_v1_Alpha360.yaml" alg2names["NAIVE-V2"] = "workflow_config_naive_v2_Alpha360.yaml" + alg2names["Transformer"] = "workflow_config_transformer_Alpha360.yaml" # find the yaml paths alg2paths = OrderedDict() diff --git a/lib/trade_models/transformers.py b/lib/trade_models/transformers.py index 23ea865..1a03c60 100755 --- a/lib/trade_models/transformers.py +++ b/lib/trade_models/transformers.py @@ -17,15 +17,15 @@ import layers as xlayers DEFAULT_NET_CONFIG = dict( d_feat=6, - embed_dim=48, + embed_dim=64, depth=5, num_heads=4, mlp_ratio=4.0, qkv_bias=True, - pos_drop=0.1, - mlp_drop_rate=0.1, - attn_drop_rate=0.1, - drop_path_rate=0.1, + pos_drop=0.0, + mlp_drop_rate=0.0, + attn_drop_rate=0.0, + drop_path_rate=0.0, )