Updates
This commit is contained in:
		
							
								
								
									
										2
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										2
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							| @@ -128,7 +128,7 @@ TEMP-L.sh | ||||
|  | ||||
| # Visual Studio Code | ||||
| .vscode | ||||
| mlruns | ||||
| mlruns* | ||||
| outputs | ||||
|  | ||||
| pytest_cache | ||||
|   | ||||
 Submodule .latent-data/qlib updated: ba56e4071e...2b74b4dfa4
									
								
							| @@ -98,6 +98,7 @@ def run_exp(task_config, dataset, experiment_name, recorder_name, uri): | ||||
|         R.save_objects(**{"model.pkl": model}) | ||||
|  | ||||
|         # Generate records: prediction, backtest, and analysis | ||||
|         import pdb; pdb.set_trace() | ||||
|         for record in task_config["record"]: | ||||
|             record = record.copy() | ||||
|             if record["class"] == "SignalRecord": | ||||
|   | ||||
| @@ -1,75 +0,0 @@ | ||||
| import os | ||||
| import sys | ||||
| import ruamel.yaml as yaml | ||||
| import pprint | ||||
| import numpy as np | ||||
| import pandas as pd | ||||
| from pathlib import Path | ||||
|  | ||||
| import qlib | ||||
| from qlib import config as qconfig | ||||
| from qlib.utils import init_instance_by_config | ||||
| from qlib.workflow import R | ||||
| from qlib.data.dataset import DatasetH | ||||
| from qlib.data.dataset.handler import DataHandlerLP | ||||
|  | ||||
| qlib.init(provider_uri="~/.qlib/qlib_data/cn_data", region=qconfig.REG_CN) | ||||
|  | ||||
| dataset_config = { | ||||
|     "class": "DatasetH", | ||||
|     "module_path": "qlib.data.dataset", | ||||
|     "kwargs": { | ||||
|         "handler": { | ||||
|             "class": "Alpha360", | ||||
|             "module_path": "qlib.contrib.data.handler", | ||||
|             "kwargs": { | ||||
|                 "start_time": "2008-01-01", | ||||
|                 "end_time": "2020-08-01", | ||||
|                 "fit_start_time": "2008-01-01", | ||||
|                 "fit_end_time": "2014-12-31", | ||||
|                 "instruments": "csi300", | ||||
|                 "infer_processors": [ | ||||
|                     {"class": "RobustZScoreNorm", "kwargs": {"fields_group": "feature", "clip_outlier": True}}, | ||||
|                     {"class": "Fillna", "kwargs": {"fields_group": "feature"}}, | ||||
|                 ], | ||||
|                 "learn_processors": [ | ||||
|                     {"class": "DropnaLabel"}, | ||||
|                     {"class": "CSRankNorm", "kwargs": {"fields_group": "label"}}, | ||||
|                 ], | ||||
|                "label": ["Ref($close, -2) / Ref($close, -1) - 1"] | ||||
|             }, | ||||
|         }, | ||||
|         "segments": { | ||||
|             "train": ("2008-01-01", "2014-12-31"), | ||||
|             "valid": ("2015-01-01", "2016-12-31"), | ||||
|             "test": ("2017-01-01", "2020-08-01"), | ||||
|         }, | ||||
|     }, | ||||
| } | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|  | ||||
|     qlib_root_dir = (Path(__file__).parent / '..' / '..' / '.latent-data' / 'qlib').resolve() | ||||
|     demo_yaml_path = qlib_root_dir / 'examples' / 'benchmarks' / 'GRU' / 'workflow_config_gru_Alpha360.yaml' | ||||
|     print('Demo-workflow-yaml: {:}'.format(demo_yaml_path)) | ||||
|     with open(demo_yaml_path, 'r') as fp: | ||||
|       config = yaml.safe_load(fp) | ||||
|     pprint.pprint(config['task']['dataset']) | ||||
|      | ||||
|     dataset = init_instance_by_config(dataset_config) | ||||
|     pprint.pprint(dataset_config) | ||||
|     pprint.pprint(dataset) | ||||
|  | ||||
|     df_train, df_valid, df_test = dataset.prepare( | ||||
|         ["train", "valid", "test"], | ||||
|         col_set=["feature", "label"], | ||||
|         data_key=DataHandlerLP.DK_L, | ||||
|     ) | ||||
|  | ||||
|     x_train, y_train = df_train["feature"], df_train["label"] | ||||
|  | ||||
|     import pdb | ||||
|  | ||||
|     pdb.set_trace() | ||||
|     print("Complete") | ||||
|  | ||||
							
								
								
									
										75
									
								
								notebooks/spaces/test.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										75
									
								
								notebooks/spaces/test.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,75 @@ | ||||
| import os | ||||
| import sys | ||||
| import qlib | ||||
| import pprint | ||||
| import numpy as np | ||||
| import pandas as pd | ||||
|  | ||||
| from pathlib import Path | ||||
| import torch | ||||
|  | ||||
| __file__ = os.path.dirname(os.path.realpath("__file__")) | ||||
|  | ||||
| lib_dir = (Path(__file__).parent / ".." / "lib").resolve() | ||||
| print("library path: {:}".format(lib_dir)) | ||||
| assert lib_dir.exists(), "{:} does not exist".format(lib_dir) | ||||
| if str(lib_dir) not in sys.path: | ||||
|     sys.path.insert(0, str(lib_dir)) | ||||
|  | ||||
| from trade_models import get_transformer | ||||
|  | ||||
| from qlib import config as qconfig | ||||
| from qlib.utils import init_instance_by_config | ||||
| from qlib.model.base import Model | ||||
| from qlib.data.dataset import DatasetH | ||||
| from qlib.data.dataset.handler import DataHandlerLP | ||||
|  | ||||
| qlib.init(provider_uri='~/.qlib/qlib_data/cn_data', region=qconfig.REG_CN) | ||||
|  | ||||
| dataset_config = { | ||||
|             "class": "DatasetH", | ||||
|             "module_path": "qlib.data.dataset", | ||||
|             "kwargs": { | ||||
|                 "handler": { | ||||
|                     "class": "Alpha360", | ||||
|                     "module_path": "qlib.contrib.data.handler", | ||||
|                     "kwargs": { | ||||
|                         "start_time": "2008-01-01", | ||||
|                         "end_time": "2020-08-01", | ||||
|                         "fit_start_time": "2008-01-01", | ||||
|                         "fit_end_time": "2014-12-31", | ||||
|                         "instruments": "csi100", | ||||
|                     }, | ||||
|                 }, | ||||
|                 "segments": { | ||||
|                     "train": ("2008-01-01", "2014-12-31"), | ||||
|                     "valid": ("2015-01-01", "2016-12-31"), | ||||
|                     "test": ("2017-01-01", "2020-08-01"), | ||||
|                 }, | ||||
|             }, | ||||
|         } | ||||
| pprint.pprint(dataset_config) | ||||
| dataset = init_instance_by_config(dataset_config) | ||||
|  | ||||
| df_train, df_valid, df_test = dataset.prepare( | ||||
|             ["train", "valid", "test"], | ||||
|             col_set=["feature", "label"], | ||||
|             data_key=DataHandlerLP.DK_L, | ||||
|         ) | ||||
| model = get_transformer(None) | ||||
| print(model) | ||||
|  | ||||
| features = torch.from_numpy(df_train["feature"].values).float() | ||||
| labels = torch.from_numpy(df_train["label"].values).squeeze().float() | ||||
|  | ||||
| batch = list(range(2000)) | ||||
| predicts = model(features[batch]) | ||||
| mask = ~torch.isnan(labels[batch]) | ||||
|  | ||||
| pred = predicts[mask] | ||||
| label = labels[batch][mask] | ||||
|  | ||||
| loss = torch.nn.functional.mse_loss(pred, label) | ||||
|  | ||||
| from sklearn.metrics import mean_squared_error | ||||
| mse_loss = mean_squared_error(pred.numpy(), label.numpy()) | ||||
		Reference in New Issue
	
	Block a user