Update to accommodate the latest updates of qlib
parent 731bda649f
commit 58907a2387

Changed paths:
.latent-data
configs/qlib
exps/trading
lib/procedures
@@ -1 +1 @@
-Subproject commit d13c9ae01869a31f123285a792b674694f844370
+Subproject commit 0ef7c8e0e62f1b08f89f676eac97cbbd2bcc1657
configs/qlib/workflow_config_TabNet_Alpha360.yaml (new file, 74 lines)
@@ -0,0 +1,74 @@
+qlib_init:
+    provider_uri: "~/.qlib/qlib_data/cn_data"
+    region: cn
+market: &market all
+benchmark: &benchmark SH000300
+data_handler_config: &data_handler_config
+    start_time: 2008-01-01
+    end_time: 2020-08-01
+    fit_start_time: 2008-01-01
+    fit_end_time: 2014-12-31
+    instruments: *market
+    infer_processors:
+        - class: RobustZScoreNorm
+          kwargs:
+              fields_group: feature
+              clip_outlier: true
+        - class: Fillna
+          kwargs:
+              fields_group: feature
+    learn_processors:
+        - class: DropnaLabel
+        - class: CSRankNorm
+          kwargs:
+              fields_group: label
+    label: ["Ref($close, -2) / Ref($close, -1) - 1"]
+port_analysis_config: &port_analysis_config
+    strategy:
+        class: TopkDropoutStrategy
+        module_path: qlib.contrib.strategy.strategy
+        kwargs:
+            topk: 50
+            n_drop: 5
+    backtest:
+        verbose: False
+        limit_threshold: 0.095
+        account: 100000000
+        benchmark: *benchmark
+        deal_price: close
+        open_cost: 0.0005
+        close_cost: 0.0015
+        min_cost: 5
+task:
+    model:
+        class: TabnetModel
+        module_path: qlib.contrib.model.pytorch_tabnet
+        kwargs:
+            pretrain: True
+    dataset:
+        class: DatasetH
+        module_path: qlib.data.dataset
+        kwargs:
+            handler:
+                class: Alpha360
+                module_path: qlib.contrib.data.handler
+                kwargs: *data_handler_config
+            segments:
+                pretrain: [2008-01-01, 2014-12-31]
+                pretrain_validation: [2015-01-01, 2020-08-01]
+                train: [2008-01-01, 2014-12-31]
+                valid: [2015-01-01, 2016-12-31]
+                test: [2017-01-01, 2020-08-01]
+    record:
+        - class: SignalRecord
+          module_path: qlib.workflow.record_temp
+          kwargs: {}
+        - class: SigAnaRecord
+          module_path: qlib.workflow.record_temp
+          kwargs:
+              ana_long_short: False
+              ann_scaler: 252
+        - class: PortAnaRecord
+          module_path: qlib.workflow.record_temp
+          kwargs:
+              config: *port_analysis_config
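For orientation (not part of the diff): a workflow YAML like the one above is typically consumed by reading the file, initializing qlib from its qlib_init block, and instantiating the dataset and model through init_instance_by_config, which mirrors what the baselines script changed below does. A minimal sketch, assuming the file is read from the repository root:

# Sketch only: how a qlib workflow YAML such as the new TabNet config is usually consumed.
import yaml
import qlib
from qlib.utils import init_instance_by_config

with open("configs/qlib/workflow_config_TabNet_Alpha360.yaml") as fp:
    config = yaml.safe_load(fp)

qlib.init(**config.get("qlib_init"))                           # provider_uri + region from the YAML
dataset = init_instance_by_config(config["task"]["dataset"])   # DatasetH over the Alpha360 handler
model = init_instance_by_config(config["task"]["model"])       # TabnetModel with pretrain=True
model.fit(dataset)                                             # trains on the train/valid segments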
@@ -22,13 +22,13 @@ if str(lib_dir) not in sys.path:
    sys.path.insert(0, str(lib_dir))

from procedures.q_exps import update_gpu
from procedures.q_exps import update_market
from procedures.q_exps import run_exp

import qlib
from qlib.utils import init_instance_by_config
from qlib.workflow import R
from qlib.utils import flatten_dict
from qlib.log import set_log_basic_config


def retrieve_configs():
@@ -49,6 +49,7 @@ def retrieve_configs():
    alg2names["SFM"] = "workflow_config_sfm_Alpha360.yaml"
    # DoubleEnsemble: A New Ensemble Method Based on Sample Reweighting and Feature Selection for Financial Data Analysis, https://arxiv.org/pdf/2010.01265.pdf
    alg2names["DoubleE"] = "workflow_config_doubleensemble_Alpha360.yaml"
+    alg2names["TabNet"] = "workflow_config_TabNet_Alpha360.yaml"

    # find the yaml paths
    alg2paths = OrderedDict()
@@ -66,6 +67,7 @@ def main(xargs, exp_yaml):

    with open(exp_yaml) as fp:
        config = yaml.safe_load(fp)
+    config = update_market(config, xargs.market)
    config = update_gpu(config, xargs.gpu)

    qlib.init(**config.get("qlib_init"))
@@ -77,7 +79,7 @@ def main(xargs, exp_yaml):

    for irun in range(xargs.times):
        run_exp(
-            config.get("task"), dataset, xargs.alg, "recorder-{:02d}-{:02d}".format(irun, xargs.times), xargs.save_dir
+            config.get("task"), dataset, xargs.alg, "recorder-{:02d}-{:02d}".format(irun, xargs.times), '{:}-{:}'.format(xargs.save_dir, xargs.market)
        )

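The effect of the change above, shown with the script's own default values (illustration only, not from the diff):

# Illustration only: the tracking URI passed to run_exp now encodes the market.
save_dir, market, times = "./outputs/qlib-baselines", "csi300", 10
uri = "{:}-{:}".format(save_dir, market)            # "./outputs/qlib-baselines-csi300"
names = ["recorder-{:02d}-{:02d}".format(i, times) for i in range(times)]   # recorder-00-10 ... recorder-09-10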
@@ -87,6 +89,7 @@ if __name__ == "__main__":

    parser = argparse.ArgumentParser("Baselines")
    parser.add_argument("--save_dir", type=str, default="./outputs/qlib-baselines", help="The checkpoint directory.")
+    parser.add_argument("--market", type=str, default="all", choices=["csi100", "csi300", "all"], help="The market indicator.")
    parser.add_argument("--times", type=int, default=10, help="The repeated run times.")
    parser.add_argument("--gpu", type=int, default=0, help="The GPU ID used for train / test.")
    parser.add_argument("--alg", type=str, choices=list(alg2paths.keys()), required=True, help="The algorithm name.")
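A typical invocation after this change would look like the line below; the exact script filename is not shown in this diff, so the path under exps/trading is an assumption:

# python exps/trading/baselines.py --alg TabNet --market csi300 --gpu 0 --times 10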
@@ -105,7 +105,7 @@ def filter_finished(recorders):


def query_info(save_dir, verbose):
-    R.reset_default_uri(save_dir)
+    R.set_uri(save_dir)
    experiments = R.list_experiments()

    key_map = {
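A minimal sketch of how such a query proceeds, assuming qlib has already been initialized and an outputs directory exists from the runs above:

# Sketch only: point the workflow recorder at a tracking directory and list its contents.
from qlib.workflow import R

R.set_uri("./outputs/qlib-baselines-csi300")      # same URI format the training loop now uses
experiments = R.list_experiments()                # {experiment_name: Experiment}
for name, exp in experiments.items():
    print(name, len(exp.list_recorders()))        # e.g. TabNet with one recorder per repetition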
@@ -3,15 +3,38 @@
#####################################################

import inspect
+import os
+import logging

import qlib
from qlib.utils import init_instance_by_config
from qlib.workflow import R
from qlib.utils import flatten_dict
-from qlib.log import set_log_basic_config
from qlib.log import get_module_logger


+def set_log_basic_config(filename=None, format=None, level=None):
+    """
+    Set the basic configuration for the logging system.
+    See details at https://docs.python.org/3/library/logging.html#logging.basicConfig
+    :param filename: str or None
+        The path to save the logs.
+    :param format: the logging format
+    :param level: int
+    :return: Logger
+        Logger object.
+    """
+    from qlib.config import C
+
+    if level is None:
+        level = C.logging_level
+
+    if format is None:
+        format = C.logging_config["formatters"]["logger_format"]["format"]
+
+    logging.basicConfig(filename=filename, format=format, level=level)


def update_gpu(config, gpu):
    config = config.copy()
    if "task" in config and "model" in config["task"]:
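The locally defined set_log_basic_config takes over from the import out of qlib.log, presumably because recent qlib no longer exposes that helper, which is what this commit accommodates; run_exp below calls it to route logging into a per-recorder file. A minimal usage sketch, with an assumed log path:

# Sketch only: configure file logging with qlib's format/level, then grab a module logger,
# exactly the pattern run_exp uses below. The path here is for illustration.
set_log_basic_config(filename="./outputs/q.run_exp.log")
logger = get_module_logger("q.run_exp")
logger.info("logging initialised")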
@@ -46,8 +69,8 @@ def run_exp(task_config, dataset, experiment_name, recorder_name, uri):
    # Let's start the experiment.
    with R.start(experiment_name=experiment_name, recorder_name=recorder_name, uri=uri):
        # Setup log
-        recorder_root_dir = R.get_recorder().root_uri
-        log_file = recorder_root_dir / "{:}.log".format(experiment_name)
+        recorder_root_dir = R.get_recorder().get_local_dir()
+        log_file = os.path.join(recorder_root_dir, "{:}.log".format(experiment_name))
        set_log_basic_config(log_file)
        logger = get_module_logger("q.run_exp")
        logger.info("task_config={:}".format(task_config))
@@ -56,8 +79,8 @@ def run_exp(task_config, dataset, experiment_name, recorder_name, uri):

        # Train model
        R.log_params(**flatten_dict(task_config))
-        if 'save_path' in inspect.getfullargspec(model.fit).args:
-            model_fit_kwargs['save_path'] = str(recorder_root_dir / 'model-ckps')
+        if "save_path" in inspect.getfullargspec(model.fit).args:
+            model_fit_kwargs["save_path"] = os.path.join(recorder_root_dir, "model-ckps")
        model.fit(**model_fit_kwargs)
        # Get the recorder
        recorder = R.get_recorder()
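The save_path guard exists because not every model's fit() accepts a checkpoint path; inspecting the signature keeps run_exp generic across algorithms. A minimal illustration of the check, using a hypothetical stand-in model (not a qlib class):

# Illustration only: pass save_path to fit() only when the signature accepts it.
import inspect
import os

class DemoModel:                                   # hypothetical stand-in for a model class
    def fit(self, dataset, save_path=None):
        pass

model = DemoModel()
recorder_root_dir = "./outputs/recorder-dir"       # assumed directory for illustration
model_fit_kwargs = {"dataset": None}
if "save_path" in inspect.getfullargspec(model.fit).args:
    model_fit_kwargs["save_path"] = os.path.join(recorder_root_dir, "model-ckps")
model.fit(**model_fit_kwargs)                      # DemoModel receives save_path; models without it would not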