Update to accomendate last updates of qlib
This commit is contained in:
		| @@ -22,13 +22,13 @@ if str(lib_dir) not in sys.path: | ||||
|     sys.path.insert(0, str(lib_dir)) | ||||
|  | ||||
| from procedures.q_exps import update_gpu | ||||
| from procedures.q_exps import update_market | ||||
| from procedures.q_exps import run_exp | ||||
|  | ||||
| import qlib | ||||
| from qlib.utils import init_instance_by_config | ||||
| from qlib.workflow import R | ||||
| from qlib.utils import flatten_dict | ||||
| from qlib.log import set_log_basic_config | ||||
|  | ||||
|  | ||||
| def retrieve_configs(): | ||||
| @@ -49,6 +49,7 @@ def retrieve_configs(): | ||||
|     alg2names["SFM"] = "workflow_config_sfm_Alpha360.yaml" | ||||
|     # DoubleEnsemble: A New Ensemble Method Based on Sample Reweighting and Feature Selection for Financial Data Analysis, https://arxiv.org/pdf/2010.01265.pdf | ||||
|     alg2names["DoubleE"] = "workflow_config_doubleensemble_Alpha360.yaml" | ||||
|     alg2names["TabNet"] = "workflow_config_TabNet_Alpha360.yaml" | ||||
|  | ||||
|     # find the yaml paths | ||||
|     alg2paths = OrderedDict() | ||||
| @@ -66,6 +67,7 @@ def main(xargs, exp_yaml): | ||||
|  | ||||
|     with open(exp_yaml) as fp: | ||||
|         config = yaml.safe_load(fp) | ||||
|     config = update_market(config, xargs.market) | ||||
|     config = update_gpu(config, xargs.gpu) | ||||
|  | ||||
|     qlib.init(**config.get("qlib_init")) | ||||
| @@ -77,7 +79,7 @@ def main(xargs, exp_yaml): | ||||
|  | ||||
|     for irun in range(xargs.times): | ||||
|         run_exp( | ||||
|             config.get("task"), dataset, xargs.alg, "recorder-{:02d}-{:02d}".format(irun, xargs.times), xargs.save_dir | ||||
|             config.get("task"), dataset, xargs.alg, "recorder-{:02d}-{:02d}".format(irun, xargs.times), '{:}-{:}'.format(xargs.save_dir, xargs.market) | ||||
|         ) | ||||
|  | ||||
|  | ||||
| @@ -87,6 +89,7 @@ if __name__ == "__main__": | ||||
|  | ||||
|     parser = argparse.ArgumentParser("Baselines") | ||||
|     parser.add_argument("--save_dir", type=str, default="./outputs/qlib-baselines", help="The checkpoint directory.") | ||||
|     parser.add_argument("--market", type=str, default="all", choices=["csi100", "csi300", "all"], help="The market indicator.") | ||||
|     parser.add_argument("--times", type=int, default=10, help="The repeated run times.") | ||||
|     parser.add_argument("--gpu", type=int, default=0, help="The GPU ID used for train / test.") | ||||
|     parser.add_argument("--alg", type=str, choices=list(alg2paths.keys()), required=True, help="The algorithm name.") | ||||
|   | ||||
| @@ -105,7 +105,7 @@ def filter_finished(recorders): | ||||
|  | ||||
|  | ||||
| def query_info(save_dir, verbose): | ||||
|     R.reset_default_uri(save_dir) | ||||
|     R.set_uri(save_dir) | ||||
|     experiments = R.list_experiments() | ||||
|  | ||||
|     key_map = { | ||||
|   | ||||
		Reference in New Issue
	
	Block a user