Update organize

This commit is contained in:
D-X-Y 2021-03-29 04:23:33 +00:00
parent 53cb5f1fdd
commit e637cddc39
3 changed files with 20 additions and 22 deletions

@ -1 +1 @@
Subproject commit f809f0a0636ca7baeb8e7e98c5a8b387096e7217 Subproject commit 253378a44e88a9fcff17d23b589e2d4832f587aa

View File

@ -131,9 +131,9 @@ def query_info(save_dir, verbose):
"ICIR": "ICIR", "ICIR": "ICIR",
"Rank IC": "Rank_IC", "Rank IC": "Rank_IC",
"Rank ICIR": "Rank_ICIR", "Rank ICIR": "Rank_ICIR",
"excess_return_with_cost.annualized_return": "Annualized_Return", # "excess_return_with_cost.annualized_return": "Annualized_Return",
# "excess_return_with_cost.information_ratio": "Information_Ratio", # "excess_return_with_cost.information_ratio": "Information_Ratio",
"excess_return_with_cost.max_drawdown": "Max_Drawdown", # "excess_return_with_cost.max_drawdown": "Max_Drawdown",
} }
all_keys = list(key_map.values()) all_keys = list(key_map.values())

View File

@ -307,25 +307,23 @@ class QuantTransformer(Model):
def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"): def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"):
if not self.fitted: if not self.fitted:
raise ValueError("The model is not fitted yet!") raise ValueError("The model is not fitted yet!")
x_test = dataset.prepare(segment, col_set="feature", data_key=DataHandlerLP.DK_I) x_test = dataset.prepare(
segment, col_set="feature", data_key=DataHandlerLP.DK_I
)
index = x_test.index index = x_test.index
self.model.eval() with torch.no_grad():
x_values = x_test.values self.model.eval()
sample_num, batch_size = x_values.shape[0], self.opt_config["batch_size"] x_values = x_test.values
preds = [] sample_num, batch_size = x_values.shape[0], self.opt_config["batch_size"]
preds = []
for begin in range(sample_num)[::batch_size]: for begin in range(sample_num)[::batch_size]:
if sample_num - begin < batch_size:
if sample_num - begin < batch_size: end = sample_num
end = sample_num else:
else: end = begin + batch_size
end = begin + batch_size x_batch = torch.from_numpy(x_values[begin:end]).float().to(self.device)
with torch.no_grad():
x_batch = torch.from_numpy(x_values[begin:end]).float().to(self.device) pred = self.model(x_batch).detach().cpu().numpy()
preds.append(pred)
with torch.no_grad():
pred = self.model(x_batch).detach().cpu().numpy()
preds.append(pred)
return pd.Series(np.concatenate(preds), index=index) return pd.Series(np.concatenate(preds), index=index)