Update baselines

parent 0055511829
commit 53cb5f1fdd
@@ -1 +1 @@
-Subproject commit 9d04ae467618505d293df9bb0fa2f20004a6e00c
+Subproject commit f809f0a0636ca7baeb8e7e98c5a8b387096e7217
@@ -67,6 +67,18 @@ def extend_transformer_settings(alg2configs, name):
     return alg2configs
 
 
+def remove_PortAnaRecord(alg2configs):
+    alg2configs = copy.deepcopy(alg2configs)
+    for key, config in alg2configs.items():
+        xlist = config["task"]["record"]
+        new_list = []
+        for x in xlist:
+            if x["class"] != "PortAnaRecord":
+                new_list.append(x)
+        config["task"]["record"] = new_list
+    return alg2configs
+
+
 def retrieve_configs():
     # https://github.com/microsoft/qlib/blob/main/examples/benchmarks/
     config_dir = (lib_dir / ".." / "configs" / "qlib").resolve()
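
Note: the helper added above strips qlib's PortAnaRecord (portfolio-analysis record) entries from each task config. Below is a minimal self-contained sketch of the same filtering, using a hypothetical toy config; the names drop_port_ana_records and toy are illustrative and not part of the commit.

import copy
from collections import OrderedDict

def drop_port_ana_records(alg2configs):
    # Mirror of the remove_PortAnaRecord helper above: keep only record
    # entries whose "class" is not PortAnaRecord.
    alg2configs = copy.deepcopy(alg2configs)
    for config in alg2configs.values():
        config["task"]["record"] = [
            x for x in config["task"]["record"] if x["class"] != "PortAnaRecord"
        ]
    return alg2configs

# Hypothetical toy config illustrating the effect.
toy = OrderedDict(
    demo={"task": {"record": [{"class": "SignalRecord"}, {"class": "PortAnaRecord"}]}}
)
print(drop_port_ana_records(toy)["demo"]["task"]["record"])
# -> [{'class': 'SignalRecord'}]
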
@@ -105,6 +117,12 @@ def retrieve_configs():
         )
     )
     alg2configs = extend_transformer_settings(alg2configs, "TSF")
+    alg2configs = remove_PortAnaRecord(alg2configs)
+    print(
+        "There are {:} algorithms : {:}".format(
+            len(alg2configs), list(alg2configs.keys())
+        )
+    )
     return alg2configs
 
 
@@ -223,7 +223,7 @@ if __name__ == "__main__":
             info_dict["heads"],
             info_dict["values"],
             info_dict["names"],
-            space=12,
+            space=14,
             verbose=True,
             sort_key=True,
         )
@@ -8,16 +8,12 @@ import os, math, random
 from collections import OrderedDict
 import numpy as np
 import pandas as pd
+from typing import Text, Union
 import copy
 from functools import partial
 from typing import Optional, Text
 
-from qlib.utils import (
-    unpack_archive_with_buffer,
-    save_multiple_parts_file,
-    get_or_create_path,
-    drop_nan_by_y_index,
-)
+from qlib.utils import get_or_create_path
 from qlib.log import get_module_logger
 
 import torch
@@ -308,10 +304,10 @@ class QuantTransformer(Model):
         torch.cuda.empty_cache()
         self.fitted = True
 
-    def predict(self, dataset, segment="test"):
+    def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"):
         if not self.fitted:
             raise ValueError("The model is not fitted yet!")
-        x_test = dataset.prepare(segment, col_set="feature")
+        x_test = dataset.prepare(segment, col_set="feature", data_key=DataHandlerLP.DK_I)
         index = x_test.index
 
         self.model.eval()