Fix bugs
This commit is contained in:
		| @@ -23,10 +23,12 @@ if str(lib_dir) not in sys.path: | ||||
| import qlib | ||||
| from qlib import config as qconfig | ||||
| from qlib.workflow import R | ||||
| qlib.init(provider_uri='~/.qlib/qlib_data/cn_data', region=qconfig.REG_CN) | ||||
|  | ||||
| qlib.init(provider_uri="~/.qlib/qlib_data/cn_data", region=qconfig.REG_CN) | ||||
|  | ||||
| from utils.qlib_utils import QResult | ||||
|  | ||||
|  | ||||
| def filter_finished(recorders): | ||||
|     returned_recorders = dict() | ||||
|     not_finished = 0 | ||||
| @@ -41,9 +43,10 @@ def filter_finished(recorders): | ||||
| def add_to_dict(xdict, timestamp, value): | ||||
|     date = timestamp.date().strftime("%Y-%m-%d") | ||||
|     if date in xdict: | ||||
|       raise ValueError("This date [{:}] is already in the dict".format(date)) | ||||
|         raise ValueError("This date [{:}] is already in the dict".format(date)) | ||||
|     xdict[date] = value | ||||
|  | ||||
|  | ||||
| def query_info(save_dir, verbose, name_filter, key_map): | ||||
|     if isinstance(save_dir, list): | ||||
|         results = [] | ||||
| @@ -61,7 +64,10 @@ def query_info(save_dir, verbose, name_filter, key_map): | ||||
|     for idx, (key, experiment) in enumerate(experiments.items()): | ||||
|         if experiment.id == "0": | ||||
|             continue | ||||
|         if name_filter is not None and re.fullmatch(name_filter, experiment.name) is None: | ||||
|         if ( | ||||
|             name_filter is not None | ||||
|             and re.fullmatch(name_filter, experiment.name) is None | ||||
|         ): | ||||
|             continue | ||||
|         recorders = experiment.list_recorders() | ||||
|         recorders, not_finished = filter_finished(recorders) | ||||
| @@ -77,10 +83,10 @@ def query_info(save_dir, verbose, name_filter, key_map): | ||||
|             ) | ||||
|         result = QResult(experiment.name) | ||||
|         for recorder_id, recorder in recorders.items(): | ||||
|             file_names = ['results-train.pkl', 'results-valid.pkl', 'results-test.pkl'] | ||||
|             file_names = ["results-train.pkl", "results-valid.pkl", "results-test.pkl"] | ||||
|             date2IC = OrderedDict() | ||||
|             for file_name in file_names: | ||||
|                 xtemp = recorder.load_object(file_name)['all-IC'] | ||||
|                 xtemp = recorder.load_object(file_name)["all-IC"] | ||||
|                 timestamps, values = xtemp.index.tolist(), xtemp.tolist() | ||||
|                 for timestamp, value in zip(timestamps, values): | ||||
|                     add_to_dict(date2IC, timestamp, value) | ||||
| @@ -104,7 +110,7 @@ def query_info(save_dir, verbose, name_filter, key_map): | ||||
|  | ||||
|  | ||||
| ## | ||||
| paths = [root_dir / 'outputs' / 'qlib-baselines-csi300'] | ||||
| paths = [root_dir / "outputs" / "qlib-baselines-csi300"] | ||||
| paths = [path.resolve() for path in paths] | ||||
| print(paths) | ||||
|  | ||||
| @@ -112,12 +118,12 @@ key_map = dict() | ||||
| for xset in ("train", "valid", "test"): | ||||
|     key_map["{:}-mean-IC".format(xset)] = "IC ({:})".format(xset) | ||||
|     key_map["{:}-mean-ICIR".format(xset)] = "ICIR ({:})".format(xset) | ||||
| qresults = query_info(paths, False, 'TSF-2x24-drop0_0s.*-.*-01', key_map) | ||||
| print('Find {:} results'.format(len(qresults))) | ||||
| qresults = query_info(paths, False, "TSF-2x24-drop0_0s.*-.*-01", key_map) | ||||
| print("Find {:} results".format(len(qresults))) | ||||
| times = [] | ||||
| for qresult in qresults: | ||||
|     times.append(qresult.name.split('0_0s')[-1]) | ||||
|     times.append(qresult.name.split("0_0s")[-1]) | ||||
| print(times) | ||||
| save_path = os.path.join(note_dir, 'temp-time-x.pth') | ||||
| save_path = os.path.join(note_dir, "temp-time-x.pth") | ||||
| torch.save(qresults, save_path) | ||||
| print(save_path) | ||||
|   | ||||
| @@ -24,38 +24,38 @@ from qlib.model.base import Model | ||||
| from qlib.data.dataset import DatasetH | ||||
| from qlib.data.dataset.handler import DataHandlerLP | ||||
|  | ||||
| qlib.init(provider_uri='~/.qlib/qlib_data/cn_data', region=qconfig.REG_CN) | ||||
| qlib.init(provider_uri="~/.qlib/qlib_data/cn_data", region=qconfig.REG_CN) | ||||
|  | ||||
| dataset_config = { | ||||
|             "class": "DatasetH", | ||||
|             "module_path": "qlib.data.dataset", | ||||
|     "class": "DatasetH", | ||||
|     "module_path": "qlib.data.dataset", | ||||
|     "kwargs": { | ||||
|         "handler": { | ||||
|             "class": "Alpha360", | ||||
|             "module_path": "qlib.contrib.data.handler", | ||||
|             "kwargs": { | ||||
|                 "handler": { | ||||
|                     "class": "Alpha360", | ||||
|                     "module_path": "qlib.contrib.data.handler", | ||||
|                     "kwargs": { | ||||
|                         "start_time": "2008-01-01", | ||||
|                         "end_time": "2020-08-01", | ||||
|                         "fit_start_time": "2008-01-01", | ||||
|                         "fit_end_time": "2014-12-31", | ||||
|                         "instruments": "csi100", | ||||
|                     }, | ||||
|                 }, | ||||
|                 "segments": { | ||||
|                     "train": ("2008-01-01", "2014-12-31"), | ||||
|                     "valid": ("2015-01-01", "2016-12-31"), | ||||
|                     "test": ("2017-01-01", "2020-08-01"), | ||||
|                 }, | ||||
|                 "start_time": "2008-01-01", | ||||
|                 "end_time": "2020-08-01", | ||||
|                 "fit_start_time": "2008-01-01", | ||||
|                 "fit_end_time": "2014-12-31", | ||||
|                 "instruments": "csi100", | ||||
|             }, | ||||
|         } | ||||
|         }, | ||||
|         "segments": { | ||||
|             "train": ("2008-01-01", "2014-12-31"), | ||||
|             "valid": ("2015-01-01", "2016-12-31"), | ||||
|             "test": ("2017-01-01", "2020-08-01"), | ||||
|         }, | ||||
|     }, | ||||
| } | ||||
| pprint.pprint(dataset_config) | ||||
| dataset = init_instance_by_config(dataset_config) | ||||
|  | ||||
| df_train, df_valid, df_test = dataset.prepare( | ||||
|             ["train", "valid", "test"], | ||||
|             col_set=["feature", "label"], | ||||
|             data_key=DataHandlerLP.DK_L, | ||||
|         ) | ||||
|     ["train", "valid", "test"], | ||||
|     col_set=["feature", "label"], | ||||
|     data_key=DataHandlerLP.DK_L, | ||||
| ) | ||||
| model = get_transformer(None) | ||||
| print(model) | ||||
|  | ||||
| @@ -72,4 +72,5 @@ label = labels[batch][mask] | ||||
| loss = torch.nn.functional.mse_loss(pred, label) | ||||
|  | ||||
| from sklearn.metrics import mean_squared_error | ||||
|  | ||||
| mse_loss = mean_squared_error(pred.numpy(), label.numpy()) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user