Update organize
Submodule .latent-data/qlib updated: f809f0a063...253378a44e
@@ -131,9 +131,9 @@ def query_info(save_dir, verbose):
         "ICIR": "ICIR",
         "Rank IC": "Rank_IC",
         "Rank ICIR": "Rank_ICIR",
-        "excess_return_with_cost.annualized_return": "Annualized_Return",
+        # "excess_return_with_cost.annualized_return": "Annualized_Return",
         # "excess_return_with_cost.information_ratio": "Information_Ratio",
-        "excess_return_with_cost.max_drawdown": "Max_Drawdown",
+        # "excess_return_with_cost.max_drawdown": "Max_Drawdown",
     }
     all_keys = list(key_map.values())
 
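The first hunk comments out the two remaining backtest entries of key_map, so query_info now reports only the signal metrics that stay uncommented (ICIR, Rank IC, Rank ICIR, plus any entries above this hunk). How the map is consumed is not part of this diff; the sketch below only illustrates the usual pattern of selecting and renaming metrics through such a map, and the results dict and extract_metrics helper are assumptions, not code from the repository.

# Illustrative sketch only; query_info's body is not shown in this diff.
# `results` is assumed to be a dict keyed by qlib metric names, e.g.
# {"IC": 0.04, "ICIR": 0.35, "Rank IC": 0.05, "Rank ICIR": 0.40}.
def extract_metrics(results: dict, key_map: dict) -> dict:
    # Keep only the metrics still present in key_map, renamed to their
    # column labels; commented-out entries (the backtest metrics) are dropped.
    return {column: results[metric] for metric, column in key_map.items() if metric in results}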
@@ -307,25 +307,23 @@ class QuantTransformer(Model):
     def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"):
         if not self.fitted:
             raise ValueError("The model is not fitted yet!")
-        x_test = dataset.prepare(segment, col_set="feature", data_key=DataHandlerLP.DK_I)
+        x_test = dataset.prepare(
+            segment, col_set="feature", data_key=DataHandlerLP.DK_I
+        )
         index = x_test.index
 
-        self.model.eval()
-        x_values = x_test.values
-        sample_num, batch_size = x_values.shape[0], self.opt_config["batch_size"]
-        preds = []
-
-        for begin in range(sample_num)[::batch_size]:
-
-            if sample_num - begin < batch_size:
-                end = sample_num
-            else:
-                end = begin + batch_size
-
-            x_batch = torch.from_numpy(x_values[begin:end]).float().to(self.device)
-
-            with torch.no_grad():
-                pred = self.model(x_batch).detach().cpu().numpy()
-            preds.append(pred)
-
+        with torch.no_grad():
+            self.model.eval()
+            x_values = x_test.values
+            sample_num, batch_size = x_values.shape[0], self.opt_config["batch_size"]
+            preds = []
+            for begin in range(sample_num)[::batch_size]:
+                if sample_num - begin < batch_size:
+                    end = sample_num
+                else:
+                    end = begin + batch_size
+                x_batch = torch.from_numpy(x_values[begin:end]).float().to(self.device)
+                with torch.no_grad():
+                    pred = self.model(x_batch).detach().cpu().numpy()
+                preds.append(pred)
         return pd.Series(np.concatenate(preds), index=index)
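The second hunk reformats the dataset.prepare call and wraps the whole batched prediction loop in a single torch.no_grad() block, so no autograd graph is built for any forward pass during inference; the per-batch no_grad() kept around the model call becomes redundant but harmless. Below is a minimal, self-contained sketch of the same pattern, not the repository's code; the batched_predict name, default batch size, and device argument are placeholders.

import numpy as np
import pandas as pd
import torch
import torch.nn as nn


def batched_predict(model: nn.Module, x_values: np.ndarray, index,
                    batch_size: int = 256, device: str = "cpu") -> pd.Series:
    """Run inference batch by batch under one no_grad() context."""
    model.eval()
    preds = []
    # A single outer no_grad() disables gradient tracking for every forward
    # pass inside the loop, which keeps inference memory low.
    with torch.no_grad():
        for begin in range(0, len(x_values), batch_size):
            end = min(begin + batch_size, len(x_values))
            x_batch = torch.from_numpy(x_values[begin:end]).float().to(device)
            # Assumes the model returns one score per sample.
            preds.append(model(x_batch).reshape(-1).cpu().numpy())
    return pd.Series(np.concatenate(preds), index=index)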