From e637cddc39223f7c96685baa9bce302214500044 Mon Sep 17 00:00:00 2001 From: D-X-Y <280835372@qq.com> Date: Mon, 29 Mar 2021 04:23:33 +0000 Subject: [PATCH] Update organize --- .latent-data/qlib | 2 +- exps/trading/organize_results.py | 4 +-- lib/trade_models/quant_transformer.py | 36 +++++++++++++-------------- 3 files changed, 20 insertions(+), 22 deletions(-) diff --git a/.latent-data/qlib b/.latent-data/qlib index f809f0a..253378a 160000 --- a/.latent-data/qlib +++ b/.latent-data/qlib @@ -1 +1 @@ -Subproject commit f809f0a0636ca7baeb8e7e98c5a8b387096e7217 +Subproject commit 253378a44e88a9fcff17d23b589e2d4832f587aa diff --git a/exps/trading/organize_results.py b/exps/trading/organize_results.py index 50017e7..1a66b8d 100644 --- a/exps/trading/organize_results.py +++ b/exps/trading/organize_results.py @@ -131,9 +131,9 @@ def query_info(save_dir, verbose): "ICIR": "ICIR", "Rank IC": "Rank_IC", "Rank ICIR": "Rank_ICIR", - "excess_return_with_cost.annualized_return": "Annualized_Return", + # "excess_return_with_cost.annualized_return": "Annualized_Return", # "excess_return_with_cost.information_ratio": "Information_Ratio", - "excess_return_with_cost.max_drawdown": "Max_Drawdown", + # "excess_return_with_cost.max_drawdown": "Max_Drawdown", } all_keys = list(key_map.values()) diff --git a/lib/trade_models/quant_transformer.py b/lib/trade_models/quant_transformer.py index 69cc8c5..993453b 100644 --- a/lib/trade_models/quant_transformer.py +++ b/lib/trade_models/quant_transformer.py @@ -307,25 +307,23 @@ class QuantTransformer(Model): def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"): if not self.fitted: raise ValueError("The model is not fitted yet!") - x_test = dataset.prepare(segment, col_set="feature", data_key=DataHandlerLP.DK_I) + x_test = dataset.prepare( + segment, col_set="feature", data_key=DataHandlerLP.DK_I + ) index = x_test.index - self.model.eval() - x_values = x_test.values - sample_num, batch_size = x_values.shape[0], self.opt_config["batch_size"] - preds = [] - - for begin in range(sample_num)[::batch_size]: - - if sample_num - begin < batch_size: - end = sample_num - else: - end = begin + batch_size - - x_batch = torch.from_numpy(x_values[begin:end]).float().to(self.device) - - with torch.no_grad(): - pred = self.model(x_batch).detach().cpu().numpy() - preds.append(pred) - + with torch.no_grad(): + self.model.eval() + x_values = x_test.values + sample_num, batch_size = x_values.shape[0], self.opt_config["batch_size"] + preds = [] + for begin in range(sample_num)[::batch_size]: + if sample_num - begin < batch_size: + end = sample_num + else: + end = begin + batch_size + x_batch = torch.from_numpy(x_values[begin:end]).float().to(self.device) + with torch.no_grad(): + pred = self.model(x_batch).detach().cpu().numpy() + preds.append(pred) return pd.Series(np.concatenate(preds), index=index)