Update to accommodate the latest updates of qlib

This commit is contained in:
D-X-Y 2021-03-11 03:09:55 +00:00
parent 731bda649f
commit 58907a2387
5 changed files with 109 additions and 9 deletions

@ -1 +1 @@
Subproject commit d13c9ae01869a31f123285a792b674694f844370
Subproject commit 0ef7c8e0e62f1b08f89f676eac97cbbd2bcc1657

View File

@ -0,0 +1,74 @@
qlib_init:
    provider_uri: "~/.qlib/qlib_data/cn_data"
    region: cn
market: &market all
benchmark: &benchmark SH000300
data_handler_config: &data_handler_config
    start_time: 2008-01-01
    end_time: 2020-08-01
    fit_start_time: 2008-01-01
    fit_end_time: 2014-12-31
    instruments: *market
    infer_processors:
        - class: RobustZScoreNorm
          kwargs:
              fields_group: feature
              clip_outlier: true
        - class: Fillna
          kwargs:
              fields_group: feature
    learn_processors:
        - class: DropnaLabel
        - class: CSRankNorm
          kwargs:
              fields_group: label
    label: ["Ref($close, -2) / Ref($close, -1) - 1"]
port_analysis_config: &port_analysis_config
    strategy:
        class: TopkDropoutStrategy
        module_path: qlib.contrib.strategy.strategy
        kwargs:
            topk: 50
            n_drop: 5
    backtest:
        verbose: False
        limit_threshold: 0.095
        account: 100000000
        benchmark: *benchmark
        deal_price: close
        open_cost: 0.0005
        close_cost: 0.0015
        min_cost: 5
task:
    model:
        class: TabnetModel
        module_path: qlib.contrib.model.pytorch_tabnet
        kwargs:
            pretrain: True
    dataset:
        class: DatasetH
        module_path: qlib.data.dataset
        kwargs:
            handler:
                class: Alpha360
                module_path: qlib.contrib.data.handler
                kwargs: *data_handler_config
            segments:
                pretrain: [2008-01-01, 2014-12-31]
                pretrain_validation: [2015-01-01, 2020-08-01]
                train: [2008-01-01, 2014-12-31]
                valid: [2015-01-01, 2016-12-31]
                test: [2017-01-01, 2020-08-01]
    record:
        - class: SignalRecord
          module_path: qlib.workflow.record_temp
          kwargs: {}
        - class: SigAnaRecord
          module_path: qlib.workflow.record_temp
          kwargs:
              ana_long_short: False
              ann_scaler: 252
        - class: PortAnaRecord
          module_path: qlib.workflow.record_temp
          kwargs:
              config: *port_analysis_config
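
For reference, this new workflow file follows the standard qlib config layout, so it can be consumed the same way the baselines script below does. A minimal sketch, assuming the file is saved as workflow_config_TabNet_Alpha360.yaml (the name registered in retrieve_configs below); variable names are illustrative:

import yaml
import qlib
from qlib.utils import init_instance_by_config

with open("workflow_config_TabNet_Alpha360.yaml") as fp:  # path is an assumption
    config = yaml.safe_load(fp)
qlib.init(**config.get("qlib_init"))  # provider_uri / region come from the qlib_init block above
dataset = init_instance_by_config(config["task"]["dataset"])  # DatasetH with the Alpha360 handler
model = init_instance_by_config(config["task"]["model"])      # TabnetModel with pretrain=True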

View File

@ -22,13 +22,13 @@ if str(lib_dir) not in sys.path:
    sys.path.insert(0, str(lib_dir))
from procedures.q_exps import update_gpu
from procedures.q_exps import update_market
from procedures.q_exps import run_exp
import qlib
from qlib.utils import init_instance_by_config
from qlib.workflow import R
from qlib.utils import flatten_dict
from qlib.log import set_log_basic_config
def retrieve_configs():
@ -49,6 +49,7 @@ def retrieve_configs():
alg2names["SFM"] = "workflow_config_sfm_Alpha360.yaml"
# DoubleEnsemble: A New Ensemble Method Based on Sample Reweighting and Feature Selection for Financial Data Analysis, https://arxiv.org/pdf/2010.01265.pdf
alg2names["DoubleE"] = "workflow_config_doubleensemble_Alpha360.yaml"
alg2names["TabNet"] = "workflow_config_TabNet_Alpha360.yaml"
# find the yaml paths
alg2paths = OrderedDict()
@ -66,6 +67,7 @@ def main(xargs, exp_yaml):
    with open(exp_yaml) as fp:
        config = yaml.safe_load(fp)
    config = update_market(config, xargs.market)
    config = update_gpu(config, xargs.gpu)
    qlib.init(**config.get("qlib_init"))
@ -77,7 +79,7 @@ def main(xargs, exp_yaml):
    for irun in range(xargs.times):
        run_exp(
            config.get("task"), dataset, xargs.alg, "recorder-{:02d}-{:02d}".format(irun, xargs.times), xargs.save_dir
            config.get("task"), dataset, xargs.alg, "recorder-{:02d}-{:02d}".format(irun, xargs.times), '{:}-{:}'.format(xargs.save_dir, xargs.market)
        )
@ -87,6 +89,7 @@ if __name__ == "__main__":
    parser = argparse.ArgumentParser("Baselines")
    parser.add_argument("--save_dir", type=str, default="./outputs/qlib-baselines", help="The checkpoint directory.")
    parser.add_argument("--market", type=str, default="all", choices=["csi100", "csi300", "all"], help="The market indicator.")
    parser.add_argument("--times", type=int, default=10, help="The repeated run times.")
    parser.add_argument("--gpu", type=int, default=0, help="The GPU ID used for train / test.")
    parser.add_argument("--alg", type=str, choices=list(alg2paths.keys()), required=True, help="The algorithm name.")

View File

@ -105,7 +105,7 @@ def filter_finished(recorders):
def query_info(save_dir, verbose):
    R.reset_default_uri(save_dir)
    R.set_uri(save_dir)
    experiments = R.list_experiments()
    key_map = {

View File

@ -3,15 +3,38 @@
#####################################################
import inspect
import os
import logging
import qlib
from qlib.utils import init_instance_by_config
from qlib.workflow import R
from qlib.utils import flatten_dict
from qlib.log import set_log_basic_config
from qlib.log import get_module_logger
def set_log_basic_config(filename=None, format=None, level=None):
"""
Set the basic configuration for the logging system.
See details at https://docs.python.org/3/library/logging.html#logging.basicConfig
:param filename: str or None
The path to save the logs.
:param format: the logging format
:param level: int
:return: Logger
Logger object.
"""
from qlib.config import C
if level is None:
level = C.logging_level
if format is None:
format = C.logging_config["formatters"]["logger_format"]["format"]
logging.basicConfig(filename=filename, format=format, level=level)
def update_gpu(config, gpu):
config = config.copy()
if "task" in config and "model" in config["task"]:
@ -46,8 +69,8 @@ def run_exp(task_config, dataset, experiment_name, recorder_name, uri):
    # Let's start the experiment.
    with R.start(experiment_name=experiment_name, recorder_name=recorder_name, uri=uri):
        # Setup log
        recorder_root_dir = R.get_recorder().root_uri
        log_file = recorder_root_dir / "{:}.log".format(experiment_name)
        recorder_root_dir = R.get_recorder().get_local_dir()
        log_file = os.path.join(recorder_root_dir, "{:}.log".format(experiment_name))
        set_log_basic_config(log_file)
        logger = get_module_logger("q.run_exp")
        logger.info("task_config={:}".format(task_config))
@ -56,8 +79,8 @@ def run_exp(task_config, dataset, experiment_name, recorder_name, uri):
        # Train model
        R.log_params(**flatten_dict(task_config))
        if 'save_path' in inspect.getfullargspec(model.fit).args:
            model_fit_kwargs['save_path'] = str(recorder_root_dir / 'model-ckps')
        if "save_path" in inspect.getfullargspec(model.fit).args:
            model_fit_kwargs["save_path"] = os.path.join(recorder_root_dir, "model-ckps")
        model.fit(**model_fit_kwargs)
        # Get the recorder
        recorder = R.get_recorder()
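
The path handling above switches from pathlib-style "/" joins on root_uri to os.path.join on get_local_dir(), which presumably returns a plain string. A minimal sketch of the resulting pattern, assuming it runs inside the with R.start(...) block as in run_exp (the log file name here is hypothetical):

import os
from qlib.workflow import R
from qlib.log import get_module_logger

recorder_root_dir = R.get_recorder().get_local_dir()       # working directory of the active recorder, as a str
log_file = os.path.join(recorder_root_dir, "example.log")  # hypothetical name
ckpt_dir = os.path.join(recorder_root_dir, "model-ckps")   # later passed to model.fit(save_path=...)
logger = get_module_logger("q.run_exp")
logger.info("logs -> {:}, checkpoints -> {:}".format(log_file, ckpt_dir))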