Simplify Q-T
This commit is contained in:
		| @@ -37,7 +37,9 @@ from qlib.data.dataset.handler import DataHandlerLP | |||||||
|  |  | ||||||
| default_net_config = dict(d_feat=6, hidden_size=48, depth=5, pos_drop=0.1) | default_net_config = dict(d_feat=6, hidden_size=48, depth=5, pos_drop=0.1) | ||||||
|  |  | ||||||
| default_opt_config = dict(epochs=200, lr=0.001, batch_size=2000, early_stop=20, loss="mse", optimizer="adam") | default_opt_config = dict( | ||||||
|  |     epochs=200, lr=0.001, batch_size=2000, early_stop=20, loss="mse", optimizer="adam", num_workers=4 | ||||||
|  | ) | ||||||
|  |  | ||||||
|  |  | ||||||
| class QuantTransformer(Model): | class QuantTransformer(Model): | ||||||
| @@ -159,6 +161,16 @@ class QuantTransformer(Model): | |||||||
|                 torch.from_numpy(df_data["label"].values).squeeze().float(), |                 torch.from_numpy(df_data["label"].values).squeeze().float(), | ||||||
|             ) |             ) | ||||||
|  |  | ||||||
|  |         def _prepare_loader(dataset, shuffle): | ||||||
|  |             return th_data.DataLoader( | ||||||
|  |                 dataset, | ||||||
|  |                 batch_size=self.opt_config["batch_size"], | ||||||
|  |                 drop_last=False, | ||||||
|  |                 pin_memory=True, | ||||||
|  |                 num_workers=self.opt_config["num_workers"], | ||||||
|  |                 shuffle=shuffle, | ||||||
|  |             ) | ||||||
|  |  | ||||||
|         df_train, df_valid, df_test = dataset.prepare( |         df_train, df_valid, df_test = dataset.prepare( | ||||||
|             ["train", "valid", "test"], |             ["train", "valid", "test"], | ||||||
|             col_set=["feature", "label"], |             col_set=["feature", "label"], | ||||||
| @@ -169,15 +181,10 @@ class QuantTransformer(Model): | |||||||
|             _prepare_dataset(df_valid), |             _prepare_dataset(df_valid), | ||||||
|             _prepare_dataset(df_test), |             _prepare_dataset(df_test), | ||||||
|         ) |         ) | ||||||
|  |         train_loader, valid_loader, test_loader = ( | ||||||
|         train_loader = th_data.DataLoader( |             _prepare_loader(train_dataset, True), | ||||||
|             train_dataset, batch_size=self.opt_config["batch_size"], shuffle=True, drop_last=False, pin_memory=True |             _prepare_loader(valid_dataset, False), | ||||||
|         ) |             _prepare_loader(test_dataset, False), | ||||||
|         valid_loader = th_data.DataLoader( |  | ||||||
|             valid_dataset, batch_size=self.opt_config["batch_size"], shuffle=False, drop_last=False, pin_memory=True |  | ||||||
|         ) |  | ||||||
|         test_loader = th_data.DataLoader( |  | ||||||
|             test_dataset, batch_size=self.opt_config["batch_size"], shuffle=False, drop_last=False, pin_memory=True |  | ||||||
|         ) |         ) | ||||||
|  |  | ||||||
|         if save_path == None: |         if save_path == None: | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user