Updates

parent dd27f15e27 · commit 01397660de
.gitignore (vendored) · 2 changes
@@ -128,7 +128,7 @@ TEMP-L.sh
 
 # Visual Studio Code
 .vscode
-mlruns
+mlruns*
 outputs
 
 pytest_cache
@@ -1 +1 @@
-Subproject commit ba56e4071efd1c08003eaf7e23978aaf81376dd1
+Subproject commit 2b74b4dfa4a6996ab6135873c0329022a1b9626b
@@ -98,6 +98,7 @@ def run_exp(task_config, dataset, experiment_name, recorder_name, uri):
         R.save_objects(**{"model.pkl": model})
 
         # Generate records: prediction, backtest, and analysis
+        import pdb; pdb.set_trace()
         for record in task_config["record"]:
             record = record.copy()
             if record["class"] == "SignalRecord":
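The only functional change in this hunk is the added import pdb; pdb.set_trace(), which suspends run_exp in Python's built-in debugger right before the record loop so the trained model and task_config can be inspected interactively. As an aside, on Python 3.7+ the same stop point can be written with the breakpoint() builtin; a minimal sketch (not part of the commit):

# Hypothetical equivalent of the inserted debugger line (Python 3.7+ only).
# breakpoint() dispatches to pdb.set_trace() by default, and exporting
# PYTHONBREAKPOINT=0 disables it for non-interactive runs.
breakpoint()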
@@ -1,75 +0,0 @@
-import os
-import sys
-import ruamel.yaml as yaml
-import pprint
-import numpy as np
-import pandas as pd
-from pathlib import Path
-
-import qlib
-from qlib import config as qconfig
-from qlib.utils import init_instance_by_config
-from qlib.workflow import R
-from qlib.data.dataset import DatasetH
-from qlib.data.dataset.handler import DataHandlerLP
-
-qlib.init(provider_uri="~/.qlib/qlib_data/cn_data", region=qconfig.REG_CN)
-
-dataset_config = {
-    "class": "DatasetH",
-    "module_path": "qlib.data.dataset",
-    "kwargs": {
-        "handler": {
-            "class": "Alpha360",
-            "module_path": "qlib.contrib.data.handler",
-            "kwargs": {
-                "start_time": "2008-01-01",
-                "end_time": "2020-08-01",
-                "fit_start_time": "2008-01-01",
-                "fit_end_time": "2014-12-31",
-                "instruments": "csi300",
-                "infer_processors": [
-                    {"class": "RobustZScoreNorm", "kwargs": {"fields_group": "feature", "clip_outlier": True}},
-                    {"class": "Fillna", "kwargs": {"fields_group": "feature"}},
-                ],
-                "learn_processors": [
-                    {"class": "DropnaLabel"},
-                    {"class": "CSRankNorm", "kwargs": {"fields_group": "label"}},
-                ],
-                "label": ["Ref($close, -2) / Ref($close, -1) - 1"]
-            },
-        },
-        "segments": {
-            "train": ("2008-01-01", "2014-12-31"),
-            "valid": ("2015-01-01", "2016-12-31"),
-            "test": ("2017-01-01", "2020-08-01"),
-        },
-    },
-}
-
-if __name__ == "__main__":
-
-    qlib_root_dir = (Path(__file__).parent / '..' / '..' / '.latent-data' / 'qlib').resolve()
-    demo_yaml_path = qlib_root_dir / 'examples' / 'benchmarks' / 'GRU' / 'workflow_config_gru_Alpha360.yaml'
-    print('Demo-workflow-yaml: {:}'.format(demo_yaml_path))
-    with open(demo_yaml_path, 'r') as fp:
-        config = yaml.safe_load(fp)
-    pprint.pprint(config['task']['dataset'])
-
-    dataset = init_instance_by_config(dataset_config)
-    pprint.pprint(dataset_config)
-    pprint.pprint(dataset)
-
-    df_train, df_valid, df_test = dataset.prepare(
-        ["train", "valid", "test"],
-        col_set=["feature", "label"],
-        data_key=DataHandlerLP.DK_L,
-    )
-
-    x_train, y_train = df_train["feature"], df_train["label"]
-
-    import pdb
-
-    pdb.set_trace()
-    print("Complete")
-
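For reference, the label in the removed script, Ref($close, -2) / Ref($close, -1) - 1, is qlib's expression for the return between the next two closing prices (Ref with a negative period looks forward in time). A minimal pandas sketch of the same quantity, assuming a toy per-instrument close series (the variable names here are illustrative, not from the commit):

import pandas as pd

close = pd.Series([10.0, 10.5, 10.4, 10.8, 11.0])  # toy closing prices
# Ref($close, -2) -> close two bars ahead; Ref($close, -1) -> close one bar ahead
label = close.shift(-2) / close.shift(-1) - 1
print(label)  # the last two entries are NaN because those future closes do not exist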
notebooks/spaces/test.py (new file) · 75 lines
@@ -0,0 +1,75 @@
+import os
+import sys
+import qlib
+import pprint
+import numpy as np
+import pandas as pd
+
+from pathlib import Path
+import torch
+
+__file__ = os.path.dirname(os.path.realpath("__file__"))
+
+lib_dir = (Path(__file__).parent / ".." / "lib").resolve()
+print("library path: {:}".format(lib_dir))
+assert lib_dir.exists(), "{:} does not exist".format(lib_dir)
+if str(lib_dir) not in sys.path:
+    sys.path.insert(0, str(lib_dir))
+
+from trade_models import get_transformer
+
+from qlib import config as qconfig
+from qlib.utils import init_instance_by_config
+from qlib.model.base import Model
+from qlib.data.dataset import DatasetH
+from qlib.data.dataset.handler import DataHandlerLP
+
+qlib.init(provider_uri='~/.qlib/qlib_data/cn_data', region=qconfig.REG_CN)
+
+dataset_config = {
+    "class": "DatasetH",
+    "module_path": "qlib.data.dataset",
+    "kwargs": {
+        "handler": {
+            "class": "Alpha360",
+            "module_path": "qlib.contrib.data.handler",
+            "kwargs": {
+                "start_time": "2008-01-01",
+                "end_time": "2020-08-01",
+                "fit_start_time": "2008-01-01",
+                "fit_end_time": "2014-12-31",
+                "instruments": "csi100",
+            },
+        },
+        "segments": {
+            "train": ("2008-01-01", "2014-12-31"),
+            "valid": ("2015-01-01", "2016-12-31"),
+            "test": ("2017-01-01", "2020-08-01"),
+        },
+    },
+}
+pprint.pprint(dataset_config)
+dataset = init_instance_by_config(dataset_config)
+
+df_train, df_valid, df_test = dataset.prepare(
+    ["train", "valid", "test"],
+    col_set=["feature", "label"],
+    data_key=DataHandlerLP.DK_L,
+)
+model = get_transformer(None)
+print(model)
+
+features = torch.from_numpy(df_train["feature"].values).float()
+labels = torch.from_numpy(df_train["label"].values).squeeze().float()
+
+batch = list(range(2000))
+predicts = model(features[batch])
+mask = ~torch.isnan(labels[batch])
+
+pred = predicts[mask]
+label = labels[batch][mask]
+
+loss = torch.nn.functional.mse_loss(pred, label)
+
+from sklearn.metrics import mean_squared_error
+mse_loss = mean_squared_error(pred.numpy(), label.numpy())
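One caveat about the last two lines of the new test.py: if get_transformer returns a torch module with gradients enabled, predicts (and therefore pred) will require grad, and calling .numpy() on such a tensor raises a RuntimeError; a .detach() (or wrapping the forward pass in torch.no_grad()) is needed first. A self-contained sketch of the intended cross-check, using random stand-in tensors rather than the real model output:

import torch
from sklearn.metrics import mean_squared_error

# stand-ins for pred/label from test.py; the real tensors come from the model
pred = torch.randn(8, requires_grad=True)
label = torch.randn(8)

loss = torch.nn.functional.mse_loss(pred, label)  # torch-side MSE
# detach() before numpy(): tensors that require grad cannot be converted directly
mse_loss = mean_squared_error(label.numpy(), pred.detach().numpy())
print(float(loss), mse_loss)  # the two values should agree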