Update to accomendate last updates of qlib
This commit is contained in:
		| @@ -3,15 +3,38 @@ | ||||
| ##################################################### | ||||
|  | ||||
| import inspect | ||||
| import os | ||||
| import logging | ||||
|  | ||||
| import qlib | ||||
| from qlib.utils import init_instance_by_config | ||||
| from qlib.workflow import R | ||||
| from qlib.utils import flatten_dict | ||||
| from qlib.log import set_log_basic_config | ||||
| from qlib.log import get_module_logger | ||||
|  | ||||
|  | ||||
| def set_log_basic_config(filename=None, format=None, level=None): | ||||
|     """ | ||||
|     Set the basic configuration for the logging system. | ||||
|     See details at https://docs.python.org/3/library/logging.html#logging.basicConfig | ||||
|     :param filename: str or None | ||||
|         The path to save the logs. | ||||
|     :param format: the logging format | ||||
|     :param level: int | ||||
|     :return: Logger | ||||
|         Logger object. | ||||
|     """ | ||||
|     from qlib.config import C | ||||
|  | ||||
|     if level is None: | ||||
|         level = C.logging_level | ||||
|  | ||||
|     if format is None: | ||||
|         format = C.logging_config["formatters"]["logger_format"]["format"] | ||||
|  | ||||
|     logging.basicConfig(filename=filename, format=format, level=level) | ||||
|  | ||||
|  | ||||
| def update_gpu(config, gpu): | ||||
|     config = config.copy() | ||||
|     if "task" in config and "model" in config["task"]: | ||||
| @@ -46,8 +69,8 @@ def run_exp(task_config, dataset, experiment_name, recorder_name, uri): | ||||
|     # Let's start the experiment. | ||||
|     with R.start(experiment_name=experiment_name, recorder_name=recorder_name, uri=uri): | ||||
|         # Setup log | ||||
|         recorder_root_dir = R.get_recorder().root_uri | ||||
|         log_file = recorder_root_dir / "{:}.log".format(experiment_name) | ||||
|         recorder_root_dir = R.get_recorder().get_local_dir() | ||||
|         log_file = os.path.join(recorder_root_dir, "{:}.log".format(experiment_name)) | ||||
|         set_log_basic_config(log_file) | ||||
|         logger = get_module_logger("q.run_exp") | ||||
|         logger.info("task_config={:}".format(task_config)) | ||||
| @@ -56,8 +79,8 @@ def run_exp(task_config, dataset, experiment_name, recorder_name, uri): | ||||
|  | ||||
|         # Train model | ||||
|         R.log_params(**flatten_dict(task_config)) | ||||
|         if 'save_path' in inspect.getfullargspec(model.fit).args: | ||||
|           model_fit_kwargs['save_path'] = str(recorder_root_dir / 'model-ckps') | ||||
|         if "save_path" in inspect.getfullargspec(model.fit).args: | ||||
|             model_fit_kwargs["save_path"] = os.path.join(recorder_root_dir, "model-ckps") | ||||
|         model.fit(**model_fit_kwargs) | ||||
|         # Get the recorder | ||||
|         recorder = R.get_recorder() | ||||
|   | ||||
		Reference in New Issue
	
	Block a user