From 53cb5f1fddf06ac3581dd14acb92db7f27b60614 Mon Sep 17 00:00:00 2001 From: D-X-Y <280835372@qq.com> Date: Sun, 28 Mar 2021 10:57:20 +0000 Subject: [PATCH] Update baselines --- .latent-data/qlib | 2 +- exps/trading/baselines.py | 18 ++++++++++++++++++ exps/trading/organize_results.py | 2 +- lib/trade_models/quant_transformer.py | 12 ++++-------- 4 files changed, 24 insertions(+), 10 deletions(-) diff --git a/.latent-data/qlib b/.latent-data/qlib index 9d04ae4..f809f0a 160000 --- a/.latent-data/qlib +++ b/.latent-data/qlib @@ -1 +1 @@ -Subproject commit 9d04ae467618505d293df9bb0fa2f20004a6e00c +Subproject commit f809f0a0636ca7baeb8e7e98c5a8b387096e7217 diff --git a/exps/trading/baselines.py b/exps/trading/baselines.py index fa54c85..18d6df1 100644 --- a/exps/trading/baselines.py +++ b/exps/trading/baselines.py @@ -67,6 +67,18 @@ def extend_transformer_settings(alg2configs, name): return alg2configs +def remove_PortAnaRecord(alg2configs): + alg2configs = copy.deepcopy(alg2configs) + for key, config in alg2configs.items(): + xlist = config["task"]["record"] + new_list = [] + for x in xlist: + if x["class"] != "PortAnaRecord": + new_list.append(x) + config["task"]["record"] = new_list + return alg2configs + + def retrieve_configs(): # https://github.com/microsoft/qlib/blob/main/examples/benchmarks/ config_dir = (lib_dir / ".." / "configs" / "qlib").resolve() @@ -105,6 +117,12 @@ def retrieve_configs(): ) ) alg2configs = extend_transformer_settings(alg2configs, "TSF") + alg2configs = remove_PortAnaRecord(alg2configs) + print( + "There are {:} algorithms : {:}".format( + len(alg2configs), list(alg2configs.keys()) + ) + ) return alg2configs diff --git a/exps/trading/organize_results.py b/exps/trading/organize_results.py index c166bb4..50017e7 100644 --- a/exps/trading/organize_results.py +++ b/exps/trading/organize_results.py @@ -223,7 +223,7 @@ if __name__ == "__main__": info_dict["heads"], info_dict["values"], info_dict["names"], - space=12, + space=14, verbose=True, sort_key=True, ) diff --git a/lib/trade_models/quant_transformer.py b/lib/trade_models/quant_transformer.py index 624939e..69cc8c5 100644 --- a/lib/trade_models/quant_transformer.py +++ b/lib/trade_models/quant_transformer.py @@ -8,16 +8,12 @@ import os, math, random from collections import OrderedDict import numpy as np import pandas as pd +from typing import Text, Union import copy from functools import partial from typing import Optional, Text -from qlib.utils import ( - unpack_archive_with_buffer, - save_multiple_parts_file, - get_or_create_path, - drop_nan_by_y_index, -) +from qlib.utils import get_or_create_path from qlib.log import get_module_logger import torch @@ -308,10 +304,10 @@ class QuantTransformer(Model): torch.cuda.empty_cache() self.fitted = True - def predict(self, dataset, segment="test"): + def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"): if not self.fitted: raise ValueError("The model is not fitted yet!") - x_test = dataset.prepare(segment, col_set="feature") + x_test = dataset.prepare(segment, col_set="feature", data_key=DataHandlerLP.DK_I) index = x_test.index self.model.eval()