Update baselines
This commit is contained in:
		 Submodule .latent-data/qlib updated: 9d04ae4676...f809f0a063
									
								
							| @@ -67,6 +67,18 @@ def extend_transformer_settings(alg2configs, name): | |||||||
|     return alg2configs |     return alg2configs | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def remove_PortAnaRecord(alg2configs): | ||||||
|  |     alg2configs = copy.deepcopy(alg2configs) | ||||||
|  |     for key, config in alg2configs.items(): | ||||||
|  |         xlist = config["task"]["record"] | ||||||
|  |         new_list = [] | ||||||
|  |         for x in xlist: | ||||||
|  |             if x["class"] != "PortAnaRecord": | ||||||
|  |                 new_list.append(x) | ||||||
|  |         config["task"]["record"] = new_list | ||||||
|  |     return alg2configs | ||||||
|  |  | ||||||
|  |  | ||||||
| def retrieve_configs(): | def retrieve_configs(): | ||||||
|     # https://github.com/microsoft/qlib/blob/main/examples/benchmarks/ |     # https://github.com/microsoft/qlib/blob/main/examples/benchmarks/ | ||||||
|     config_dir = (lib_dir / ".." / "configs" / "qlib").resolve() |     config_dir = (lib_dir / ".." / "configs" / "qlib").resolve() | ||||||
| @@ -105,6 +117,12 @@ def retrieve_configs(): | |||||||
|             ) |             ) | ||||||
|         ) |         ) | ||||||
|     alg2configs = extend_transformer_settings(alg2configs, "TSF") |     alg2configs = extend_transformer_settings(alg2configs, "TSF") | ||||||
|  |     alg2configs = remove_PortAnaRecord(alg2configs) | ||||||
|  |     print( | ||||||
|  |         "There are {:} algorithms : {:}".format( | ||||||
|  |             len(alg2configs), list(alg2configs.keys()) | ||||||
|  |         ) | ||||||
|  |     ) | ||||||
|     return alg2configs |     return alg2configs | ||||||
|  |  | ||||||
|  |  | ||||||
|   | |||||||
| @@ -223,7 +223,7 @@ if __name__ == "__main__": | |||||||
|         info_dict["heads"], |         info_dict["heads"], | ||||||
|         info_dict["values"], |         info_dict["values"], | ||||||
|         info_dict["names"], |         info_dict["names"], | ||||||
|         space=12, |         space=14, | ||||||
|         verbose=True, |         verbose=True, | ||||||
|         sort_key=True, |         sort_key=True, | ||||||
|     ) |     ) | ||||||
|   | |||||||
| @@ -8,16 +8,12 @@ import os, math, random | |||||||
| from collections import OrderedDict | from collections import OrderedDict | ||||||
| import numpy as np | import numpy as np | ||||||
| import pandas as pd | import pandas as pd | ||||||
|  | from typing import Text, Union | ||||||
| import copy | import copy | ||||||
| from functools import partial | from functools import partial | ||||||
| from typing import Optional, Text | from typing import Optional, Text | ||||||
|  |  | ||||||
| from qlib.utils import ( | from qlib.utils import get_or_create_path | ||||||
|     unpack_archive_with_buffer, |  | ||||||
|     save_multiple_parts_file, |  | ||||||
|     get_or_create_path, |  | ||||||
|     drop_nan_by_y_index, |  | ||||||
| ) |  | ||||||
| from qlib.log import get_module_logger | from qlib.log import get_module_logger | ||||||
|  |  | ||||||
| import torch | import torch | ||||||
| @@ -308,10 +304,10 @@ class QuantTransformer(Model): | |||||||
|             torch.cuda.empty_cache() |             torch.cuda.empty_cache() | ||||||
|         self.fitted = True |         self.fitted = True | ||||||
|  |  | ||||||
|     def predict(self, dataset, segment="test"): |     def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"): | ||||||
|         if not self.fitted: |         if not self.fitted: | ||||||
|             raise ValueError("The model is not fitted yet!") |             raise ValueError("The model is not fitted yet!") | ||||||
|         x_test = dataset.prepare(segment, col_set="feature") |         x_test = dataset.prepare(segment, col_set="feature", data_key=DataHandlerLP.DK_I) | ||||||
|         index = x_test.index |         index = x_test.index | ||||||
|  |  | ||||||
|         self.model.eval() |         self.model.eval() | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user