Update organize
This commit is contained in:
		| @@ -307,25 +307,23 @@ class QuantTransformer(Model): | ||||
|     def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"): | ||||
|         if not self.fitted: | ||||
|             raise ValueError("The model is not fitted yet!") | ||||
|         x_test = dataset.prepare(segment, col_set="feature", data_key=DataHandlerLP.DK_I) | ||||
|         x_test = dataset.prepare( | ||||
|             segment, col_set="feature", data_key=DataHandlerLP.DK_I | ||||
|         ) | ||||
|         index = x_test.index | ||||
|  | ||||
|         self.model.eval() | ||||
|         x_values = x_test.values | ||||
|         sample_num, batch_size = x_values.shape[0], self.opt_config["batch_size"] | ||||
|         preds = [] | ||||
|  | ||||
|         for begin in range(sample_num)[::batch_size]: | ||||
|  | ||||
|             if sample_num - begin < batch_size: | ||||
|                 end = sample_num | ||||
|             else: | ||||
|                 end = begin + batch_size | ||||
|  | ||||
|             x_batch = torch.from_numpy(x_values[begin:end]).float().to(self.device) | ||||
|  | ||||
|             with torch.no_grad(): | ||||
|                 pred = self.model(x_batch).detach().cpu().numpy() | ||||
|             preds.append(pred) | ||||
|  | ||||
|         with torch.no_grad(): | ||||
|             self.model.eval() | ||||
|             x_values = x_test.values | ||||
|             sample_num, batch_size = x_values.shape[0], self.opt_config["batch_size"] | ||||
|             preds = [] | ||||
|             for begin in range(sample_num)[::batch_size]: | ||||
|                 if sample_num - begin < batch_size: | ||||
|                     end = sample_num | ||||
|                 else: | ||||
|                     end = begin + batch_size | ||||
|                 x_batch = torch.from_numpy(x_values[begin:end]).float().to(self.device) | ||||
|                 with torch.no_grad(): | ||||
|                     pred = self.model(x_batch).detach().cpu().numpy() | ||||
|                 preds.append(pred) | ||||
|         return pd.Series(np.concatenate(preds), index=index) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user