Update baselines
This commit is contained in:
		 Submodule .latent-data/qlib updated: 9d04ae4676...f809f0a063
									
								
							| @@ -67,6 +67,18 @@ def extend_transformer_settings(alg2configs, name): | ||||
|     return alg2configs | ||||
|  | ||||
|  | ||||
| def remove_PortAnaRecord(alg2configs): | ||||
|     alg2configs = copy.deepcopy(alg2configs) | ||||
|     for key, config in alg2configs.items(): | ||||
|         xlist = config["task"]["record"] | ||||
|         new_list = [] | ||||
|         for x in xlist: | ||||
|             if x["class"] != "PortAnaRecord": | ||||
|                 new_list.append(x) | ||||
|         config["task"]["record"] = new_list | ||||
|     return alg2configs | ||||
|  | ||||
|  | ||||
| def retrieve_configs(): | ||||
|     # https://github.com/microsoft/qlib/blob/main/examples/benchmarks/ | ||||
|     config_dir = (lib_dir / ".." / "configs" / "qlib").resolve() | ||||
| @@ -105,6 +117,12 @@ def retrieve_configs(): | ||||
|             ) | ||||
|         ) | ||||
|     alg2configs = extend_transformer_settings(alg2configs, "TSF") | ||||
|     alg2configs = remove_PortAnaRecord(alg2configs) | ||||
|     print( | ||||
|         "There are {:} algorithms : {:}".format( | ||||
|             len(alg2configs), list(alg2configs.keys()) | ||||
|         ) | ||||
|     ) | ||||
|     return alg2configs | ||||
|  | ||||
|  | ||||
|   | ||||
| @@ -223,7 +223,7 @@ if __name__ == "__main__": | ||||
|         info_dict["heads"], | ||||
|         info_dict["values"], | ||||
|         info_dict["names"], | ||||
|         space=12, | ||||
|         space=14, | ||||
|         verbose=True, | ||||
|         sort_key=True, | ||||
|     ) | ||||
|   | ||||
| @@ -8,16 +8,12 @@ import os, math, random | ||||
| from collections import OrderedDict | ||||
| import numpy as np | ||||
| import pandas as pd | ||||
| from typing import Text, Union | ||||
| import copy | ||||
| from functools import partial | ||||
| from typing import Optional, Text | ||||
|  | ||||
| from qlib.utils import ( | ||||
|     unpack_archive_with_buffer, | ||||
|     save_multiple_parts_file, | ||||
|     get_or_create_path, | ||||
|     drop_nan_by_y_index, | ||||
| ) | ||||
| from qlib.utils import get_or_create_path | ||||
| from qlib.log import get_module_logger | ||||
|  | ||||
| import torch | ||||
| @@ -308,10 +304,10 @@ class QuantTransformer(Model): | ||||
|             torch.cuda.empty_cache() | ||||
|         self.fitted = True | ||||
|  | ||||
|     def predict(self, dataset, segment="test"): | ||||
|     def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"): | ||||
|         if not self.fitted: | ||||
|             raise ValueError("The model is not fitted yet!") | ||||
|         x_test = dataset.prepare(segment, col_set="feature") | ||||
|         x_test = dataset.prepare(segment, col_set="feature", data_key=DataHandlerLP.DK_I) | ||||
|         index = x_test.index | ||||
|  | ||||
|         self.model.eval() | ||||
|   | ||||
		Reference in New Issue
	
	Block a user