Updates
This commit is contained in:
		
							
								
								
									
										75
									
								
								notebooks/spaces/test.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										75
									
								
								notebooks/spaces/test.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,75 @@ | ||||
| import os | ||||
| import sys | ||||
| import qlib | ||||
| import pprint | ||||
| import numpy as np | ||||
| import pandas as pd | ||||
|  | ||||
| from pathlib import Path | ||||
| import torch | ||||
|  | ||||
| __file__ = os.path.dirname(os.path.realpath("__file__")) | ||||
|  | ||||
| lib_dir = (Path(__file__).parent / ".." / "lib").resolve() | ||||
| print("library path: {:}".format(lib_dir)) | ||||
| assert lib_dir.exists(), "{:} does not exist".format(lib_dir) | ||||
| if str(lib_dir) not in sys.path: | ||||
|     sys.path.insert(0, str(lib_dir)) | ||||
|  | ||||
| from trade_models import get_transformer | ||||
|  | ||||
| from qlib import config as qconfig | ||||
| from qlib.utils import init_instance_by_config | ||||
| from qlib.model.base import Model | ||||
| from qlib.data.dataset import DatasetH | ||||
| from qlib.data.dataset.handler import DataHandlerLP | ||||
|  | ||||
| qlib.init(provider_uri='~/.qlib/qlib_data/cn_data', region=qconfig.REG_CN) | ||||
|  | ||||
| dataset_config = { | ||||
|             "class": "DatasetH", | ||||
|             "module_path": "qlib.data.dataset", | ||||
|             "kwargs": { | ||||
|                 "handler": { | ||||
|                     "class": "Alpha360", | ||||
|                     "module_path": "qlib.contrib.data.handler", | ||||
|                     "kwargs": { | ||||
|                         "start_time": "2008-01-01", | ||||
|                         "end_time": "2020-08-01", | ||||
|                         "fit_start_time": "2008-01-01", | ||||
|                         "fit_end_time": "2014-12-31", | ||||
|                         "instruments": "csi100", | ||||
|                     }, | ||||
|                 }, | ||||
|                 "segments": { | ||||
|                     "train": ("2008-01-01", "2014-12-31"), | ||||
|                     "valid": ("2015-01-01", "2016-12-31"), | ||||
|                     "test": ("2017-01-01", "2020-08-01"), | ||||
|                 }, | ||||
|             }, | ||||
|         } | ||||
| pprint.pprint(dataset_config) | ||||
| dataset = init_instance_by_config(dataset_config) | ||||
|  | ||||
| df_train, df_valid, df_test = dataset.prepare( | ||||
|             ["train", "valid", "test"], | ||||
|             col_set=["feature", "label"], | ||||
|             data_key=DataHandlerLP.DK_L, | ||||
|         ) | ||||
| model = get_transformer(None) | ||||
| print(model) | ||||
|  | ||||
| features = torch.from_numpy(df_train["feature"].values).float() | ||||
| labels = torch.from_numpy(df_train["label"].values).squeeze().float() | ||||
|  | ||||
| batch = list(range(2000)) | ||||
| predicts = model(features[batch]) | ||||
| mask = ~torch.isnan(labels[batch]) | ||||
|  | ||||
| pred = predicts[mask] | ||||
| label = labels[batch][mask] | ||||
|  | ||||
| loss = torch.nn.functional.mse_loss(pred, label) | ||||
|  | ||||
| from sklearn.metrics import mean_squared_error | ||||
| mse_loss = mean_squared_error(pred.numpy(), label.numpy()) | ||||
		Reference in New Issue
	
	Block a user