Update baselines

This commit is contained in:
D-X-Y 2021-03-28 10:57:20 +00:00
parent 0055511829
commit 53cb5f1fdd
4 changed files with 24 additions and 10 deletions

@ -1 +1 @@
Subproject commit 9d04ae467618505d293df9bb0fa2f20004a6e00c Subproject commit f809f0a0636ca7baeb8e7e98c5a8b387096e7217

View File

@ -67,6 +67,18 @@ def extend_transformer_settings(alg2configs, name):
return alg2configs return alg2configs
def remove_PortAnaRecord(alg2configs):
alg2configs = copy.deepcopy(alg2configs)
for key, config in alg2configs.items():
xlist = config["task"]["record"]
new_list = []
for x in xlist:
if x["class"] != "PortAnaRecord":
new_list.append(x)
config["task"]["record"] = new_list
return alg2configs
def retrieve_configs(): def retrieve_configs():
# https://github.com/microsoft/qlib/blob/main/examples/benchmarks/ # https://github.com/microsoft/qlib/blob/main/examples/benchmarks/
config_dir = (lib_dir / ".." / "configs" / "qlib").resolve() config_dir = (lib_dir / ".." / "configs" / "qlib").resolve()
@ -105,6 +117,12 @@ def retrieve_configs():
) )
) )
alg2configs = extend_transformer_settings(alg2configs, "TSF") alg2configs = extend_transformer_settings(alg2configs, "TSF")
alg2configs = remove_PortAnaRecord(alg2configs)
print(
"There are {:} algorithms : {:}".format(
len(alg2configs), list(alg2configs.keys())
)
)
return alg2configs return alg2configs

View File

@ -223,7 +223,7 @@ if __name__ == "__main__":
info_dict["heads"], info_dict["heads"],
info_dict["values"], info_dict["values"],
info_dict["names"], info_dict["names"],
space=12, space=14,
verbose=True, verbose=True,
sort_key=True, sort_key=True,
) )

View File

@ -8,16 +8,12 @@ import os, math, random
from collections import OrderedDict from collections import OrderedDict
import numpy as np import numpy as np
import pandas as pd import pandas as pd
from typing import Text, Union
import copy import copy
from functools import partial from functools import partial
from typing import Optional, Text from typing import Optional, Text
from qlib.utils import ( from qlib.utils import get_or_create_path
unpack_archive_with_buffer,
save_multiple_parts_file,
get_or_create_path,
drop_nan_by_y_index,
)
from qlib.log import get_module_logger from qlib.log import get_module_logger
import torch import torch
@ -308,10 +304,10 @@ class QuantTransformer(Model):
torch.cuda.empty_cache() torch.cuda.empty_cache()
self.fitted = True self.fitted = True
def predict(self, dataset, segment="test"): def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"):
if not self.fitted: if not self.fitted:
raise ValueError("The model is not fitted yet!") raise ValueError("The model is not fitted yet!")
x_test = dataset.prepare(segment, col_set="feature") x_test = dataset.prepare(segment, col_set="feature", data_key=DataHandlerLP.DK_I)
index = x_test.index index = x_test.index
self.model.eval() self.model.eval()