Update baselines
This commit is contained in:
parent
0055511829
commit
53cb5f1fdd
@ -1 +1 @@
|
||||
Subproject commit 9d04ae467618505d293df9bb0fa2f20004a6e00c
|
||||
Subproject commit f809f0a0636ca7baeb8e7e98c5a8b387096e7217
|
@ -67,6 +67,18 @@ def extend_transformer_settings(alg2configs, name):
|
||||
return alg2configs
|
||||
|
||||
|
||||
def remove_PortAnaRecord(alg2configs):
|
||||
alg2configs = copy.deepcopy(alg2configs)
|
||||
for key, config in alg2configs.items():
|
||||
xlist = config["task"]["record"]
|
||||
new_list = []
|
||||
for x in xlist:
|
||||
if x["class"] != "PortAnaRecord":
|
||||
new_list.append(x)
|
||||
config["task"]["record"] = new_list
|
||||
return alg2configs
|
||||
|
||||
|
||||
def retrieve_configs():
|
||||
# https://github.com/microsoft/qlib/blob/main/examples/benchmarks/
|
||||
config_dir = (lib_dir / ".." / "configs" / "qlib").resolve()
|
||||
@ -105,6 +117,12 @@ def retrieve_configs():
|
||||
)
|
||||
)
|
||||
alg2configs = extend_transformer_settings(alg2configs, "TSF")
|
||||
alg2configs = remove_PortAnaRecord(alg2configs)
|
||||
print(
|
||||
"There are {:} algorithms : {:}".format(
|
||||
len(alg2configs), list(alg2configs.keys())
|
||||
)
|
||||
)
|
||||
return alg2configs
|
||||
|
||||
|
||||
|
@ -223,7 +223,7 @@ if __name__ == "__main__":
|
||||
info_dict["heads"],
|
||||
info_dict["values"],
|
||||
info_dict["names"],
|
||||
space=12,
|
||||
space=14,
|
||||
verbose=True,
|
||||
sort_key=True,
|
||||
)
|
||||
|
@ -8,16 +8,12 @@ import os, math, random
|
||||
from collections import OrderedDict
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from typing import Text, Union
|
||||
import copy
|
||||
from functools import partial
|
||||
from typing import Optional, Text
|
||||
|
||||
from qlib.utils import (
|
||||
unpack_archive_with_buffer,
|
||||
save_multiple_parts_file,
|
||||
get_or_create_path,
|
||||
drop_nan_by_y_index,
|
||||
)
|
||||
from qlib.utils import get_or_create_path
|
||||
from qlib.log import get_module_logger
|
||||
|
||||
import torch
|
||||
@ -308,10 +304,10 @@ class QuantTransformer(Model):
|
||||
torch.cuda.empty_cache()
|
||||
self.fitted = True
|
||||
|
||||
def predict(self, dataset, segment="test"):
|
||||
def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"):
|
||||
if not self.fitted:
|
||||
raise ValueError("The model is not fitted yet!")
|
||||
x_test = dataset.prepare(segment, col_set="feature")
|
||||
x_test = dataset.prepare(segment, col_set="feature", data_key=DataHandlerLP.DK_I)
|
||||
index = x_test.index
|
||||
|
||||
self.model.eval()
|
||||
|
Loading…
Reference in New Issue
Block a user