Update organize
parent 53cb5f1fdd
commit e637cddc39
@@ -1 +1 @@
-Subproject commit f809f0a0636ca7baeb8e7e98c5a8b387096e7217
+Subproject commit 253378a44e88a9fcff17d23b589e2d4832f587aa
@@ -131,9 +131,9 @@ def query_info(save_dir, verbose):
         "ICIR": "ICIR",
         "Rank IC": "Rank_IC",
         "Rank ICIR": "Rank_ICIR",
-        "excess_return_with_cost.annualized_return": "Annualized_Return",
+        # "excess_return_with_cost.annualized_return": "Annualized_Return",
         # "excess_return_with_cost.information_ratio": "Information_Ratio",
-        "excess_return_with_cost.max_drawdown": "Max_Drawdown",
+        # "excess_return_with_cost.max_drawdown": "Max_Drawdown",
     }
     all_keys = list(key_map.values())
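Since all_keys is built from key_map.values(), commenting entries out of the mapping also removes them from the columns that query_info collects. A minimal standalone Python sketch of that effect, trimmed to the entries visible in this hunk (the full dict in the file may have more entries above "ICIR"):

    # After the change, only the signal metrics remain as active values.
    key_map = {
        "ICIR": "ICIR",
        "Rank IC": "Rank_IC",
        "Rank ICIR": "Rank_ICIR",
        # "excess_return_with_cost.annualized_return": "Annualized_Return",
        # "excess_return_with_cost.information_ratio": "Information_Ratio",
        # "excess_return_with_cost.max_drawdown": "Max_Drawdown",
    }
    all_keys = list(key_map.values())
    print(all_keys)  # ['ICIR', 'Rank_IC', 'Rank_ICIR'] -- the commented metrics no longer appear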
@@ -307,25 +307,23 @@ class QuantTransformer(Model):
     def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"):
         if not self.fitted:
             raise ValueError("The model is not fitted yet!")
-        x_test = dataset.prepare(segment, col_set="feature", data_key=DataHandlerLP.DK_I)
+        x_test = dataset.prepare(
+            segment, col_set="feature", data_key=DataHandlerLP.DK_I
+        )
         index = x_test.index

-        self.model.eval()
-        x_values = x_test.values
-        sample_num, batch_size = x_values.shape[0], self.opt_config["batch_size"]
-        preds = []
-
-        for begin in range(sample_num)[::batch_size]:
-
-            if sample_num - begin < batch_size:
-                end = sample_num
-            else:
-                end = begin + batch_size
-
-            x_batch = torch.from_numpy(x_values[begin:end]).float().to(self.device)
-
-            with torch.no_grad():
-                pred = self.model(x_batch).detach().cpu().numpy()
-            preds.append(pred)
+        with torch.no_grad():
+            self.model.eval()
+            x_values = x_test.values
+            sample_num, batch_size = x_values.shape[0], self.opt_config["batch_size"]
+            preds = []
+            for begin in range(sample_num)[::batch_size]:
+                if sample_num - begin < batch_size:
+                    end = sample_num
+                else:
+                    end = begin + batch_size
+                x_batch = torch.from_numpy(x_values[begin:end]).float().to(self.device)
+                pred = self.model(x_batch).detach().cpu().numpy()
+                preds.append(pred)

         return pd.Series(np.concatenate(preds), index=index)
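Two details are easy to miss in the hunk above. First, the refactor moves the single torch.no_grad() context around the whole batching loop, so no autograd graph is built for any of the forward passes (the old code entered the context once per batch). Second, range(sample_num)[::batch_size] yields the start index of each batch and the if/else clamps the last batch to the remaining samples. A minimal standalone sketch of that slicing arithmetic (the counts below are made-up illustration values, not from the commit):

    sample_num, batch_size = 10, 4  # illustrative values only
    for begin in range(sample_num)[::batch_size]:  # begin takes 0, 4, 8
        if sample_num - begin < batch_size:
            end = sample_num  # final, shorter batch
        else:
            end = begin + batch_size
        print(begin, end)  # prints 0 4, then 4 8, then 8 10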