This commit is contained in:
D-X-Y 2021-03-17 03:51:48 +00:00
parent e04f17116d
commit 1ba1585f20
5 changed files with 86 additions and 7 deletions

View File

@ -10,7 +10,7 @@ data_handler_config: &data_handler_config
fit_end_time: 2014-12-31
instruments: *market
infer_processors:
- class: RobustZScoreNorm
- class: RobustZScoreNorm
kwargs:
fields_group: feature
clip_outlier: true

View File

@ -0,0 +1,78 @@
qlib_init:
provider_uri: "~/.qlib/qlib_data/cn_data"
region: cn
market: &market all
benchmark: &benchmark SH000300
data_handler_config: &data_handler_config
start_time: 2008-01-01
end_time: 2020-08-01
fit_start_time: 2008-01-01
fit_end_time: 2014-12-31
instruments: *market
infer_processors:
- class: RobustZScoreNorm
kwargs:
fields_group: feature
clip_outlier: true
- class: Fillna
kwargs:
fields_group: feature
learn_processors:
- class: DropnaLabel
- class: CSRankNorm
kwargs:
fields_group: label
label: ["Ref($close, -2) / Ref($close, -1) - 1"]
port_analysis_config: &port_analysis_config
strategy:
class: TopkDropoutStrategy
module_path: qlib.contrib.strategy.strategy
kwargs:
topk: 50
n_drop: 5
backtest:
verbose: False
limit_threshold: 0.095
account: 100000000
benchmark: *benchmark
deal_price: close
open_cost: 0.0005
close_cost: 0.0015
min_cost: 5
task:
model:
class: QuantTransformer
module_path: trade_models.quant_transformer
kwargs:
net_config:
opt_config:
loss: mse
GPU: 0
dataset:
class: DatasetH
module_path: qlib.data.dataset
kwargs:
handler:
class: Alpha360
module_path: qlib.contrib.data.handler
kwargs: *data_handler_config
segments:
train: [2008-01-01, 2014-12-31]
valid: [2015-01-01, 2016-12-31]
test: [2017-01-01, 2020-08-01]
record:
- class: SignalRecord
module_path: qlib.workflow.record_temp
kwargs: {}
- class: SignalMseRecord
module_path: qlib.contrib.workflow.record_temp
kwargs: {}
- class: SigAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
ana_long_short: False
ann_scaler: 252
- class: PortAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
config: *port_analysis_config

View File

@ -10,7 +10,7 @@ data_handler_config: &data_handler_config
fit_end_time: 2014-12-31
instruments: *market
infer_processors:
- class: RobustZScoreNorm
- class: RobustZScoreNorm
kwargs:
fields_group: feature
clip_outlier: true

View File

@ -56,6 +56,7 @@ def retrieve_configs():
alg2names["TabNet"] = "workflow_config_TabNet_Alpha360.yaml"
alg2names["NAIVE-V1"] = "workflow_config_naive_v1_Alpha360.yaml"
alg2names["NAIVE-V2"] = "workflow_config_naive_v2_Alpha360.yaml"
alg2names["Transformer"] = "workflow_config_transformer_Alpha360.yaml"
# find the yaml paths
alg2paths = OrderedDict()

View File

@ -17,15 +17,15 @@ import layers as xlayers
DEFAULT_NET_CONFIG = dict(
d_feat=6,
embed_dim=48,
embed_dim=64,
depth=5,
num_heads=4,
mlp_ratio=4.0,
qkv_bias=True,
pos_drop=0.1,
mlp_drop_rate=0.1,
attn_drop_rate=0.1,
drop_path_rate=0.1,
pos_drop=0.0,
mlp_drop_rate=0.0,
attn_drop_rate=0.0,
drop_path_rate=0.0,
)