Update visualization codes
This commit is contained in:
parent
5f2ba0a8e7
commit
c82c7e9f3f
1
.gitignore
vendored
1
.gitignore
vendored
@ -133,3 +133,4 @@ outputs
|
|||||||
|
|
||||||
pytest_cache
|
pytest_cache
|
||||||
*.pkl
|
*.pkl
|
||||||
|
*.pth
|
||||||
|
@ -64,7 +64,7 @@ def extend_transformer_settings(alg2configs, name):
|
|||||||
config = copy.deepcopy(alg2configs[name])
|
config = copy.deepcopy(alg2configs[name])
|
||||||
for i in range(1, 9):
|
for i in range(1, 9):
|
||||||
for j in (6, 12, 24, 32, 48, 64):
|
for j in (6, 12, 24, 32, 48, 64):
|
||||||
for k1 in (0, 0.1, 0.2, 0.3):
|
for k1 in (0, 0.05, 0.1, 0.2, 0.3):
|
||||||
for k2 in (0, 0.1):
|
for k2 in (0, 0.1):
|
||||||
alg2configs[
|
alg2configs[
|
||||||
name + "-{:}x{:}-drop{:}_{:}".format(i, j, k1, k2)
|
name + "-{:}x{:}-drop{:}_{:}".format(i, j, k1, k2)
|
||||||
|
@ -22,6 +22,7 @@ from qlib.workflow import R
|
|||||||
|
|
||||||
from utils.qlib_utils import QResult
|
from utils.qlib_utils import QResult
|
||||||
|
|
||||||
|
|
||||||
def compare_results(
|
def compare_results(
|
||||||
heads, values, names, space=10, separate="& ", verbose=True, sort_key=False
|
heads, values, names, space=10, separate="& ", verbose=True, sort_key=False
|
||||||
):
|
):
|
||||||
@ -69,7 +70,10 @@ def query_info(save_dir, verbose, name_filter, key_map):
|
|||||||
for idx, (key, experiment) in enumerate(experiments.items()):
|
for idx, (key, experiment) in enumerate(experiments.items()):
|
||||||
if experiment.id == "0":
|
if experiment.id == "0":
|
||||||
continue
|
continue
|
||||||
if name_filter is not None and re.fullmatch(name_filter, experiment.name) is None:
|
if (
|
||||||
|
name_filter is not None
|
||||||
|
and re.fullmatch(name_filter, experiment.name) is None
|
||||||
|
):
|
||||||
continue
|
continue
|
||||||
recorders = experiment.list_recorders()
|
recorders = experiment.list_recorders()
|
||||||
recorders, not_finished = filter_finished(recorders)
|
recorders, not_finished = filter_finished(recorders)
|
||||||
|
@ -1,3 +1,4 @@
|
|||||||
|
import os
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from typing import List, Text
|
from typing import List, Text
|
||||||
from collections import defaultdict, OrderedDict
|
from collections import defaultdict, OrderedDict
|
||||||
@ -10,6 +11,7 @@ class QResult:
|
|||||||
self._result = defaultdict(list)
|
self._result = defaultdict(list)
|
||||||
self._name = name
|
self._name = name
|
||||||
self._recorder_paths = []
|
self._recorder_paths = []
|
||||||
|
self._date2ICs = []
|
||||||
|
|
||||||
def append(self, key, value):
|
def append(self, key, value):
|
||||||
self._result[key].append(value)
|
self._result[key].append(value)
|
||||||
@ -17,6 +19,25 @@ class QResult:
|
|||||||
def append_path(self, xpath):
|
def append_path(self, xpath):
|
||||||
self._recorder_paths.append(xpath)
|
self._recorder_paths.append(xpath)
|
||||||
|
|
||||||
|
def append_date2ICs(self, date2IC):
|
||||||
|
if self._date2ICs: # not empty
|
||||||
|
keys = sorted(list(date2IC.keys()))
|
||||||
|
pre_keys = sorted(list(self._date2ICs[0].keys()))
|
||||||
|
assert len(keys) == len(pre_keys)
|
||||||
|
for i, (x, y) in enumerate(zip(keys, pre_keys)):
|
||||||
|
assert x == y, "[{:}] {:} vs {:}".format(i, x, y)
|
||||||
|
self._date2ICs.append(date2IC)
|
||||||
|
|
||||||
|
def find_all_dates(self):
|
||||||
|
dates = self._date2ICs[-1].keys()
|
||||||
|
return sorted(list(dates))
|
||||||
|
|
||||||
|
def get_IC_by_date(self, date, scale=1.0):
|
||||||
|
values = []
|
||||||
|
for date2IC in self._date2ICs:
|
||||||
|
values.append(date2IC[date] * scale)
|
||||||
|
return float(np.mean(values)), float(np.std(values))
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def name(self):
|
def name(self):
|
||||||
return self._name
|
return self._name
|
||||||
|
@ -18,10 +18,10 @@
|
|||||||
"name": "stderr",
|
"name": "stderr",
|
||||||
"output_type": "stream",
|
"output_type": "stream",
|
||||||
"text": [
|
"text": [
|
||||||
"[68147:MainThread](2021-04-12 13:09:24,409) INFO - qlib.Initialization - [config.py:276] - default_conf: client.\n",
|
"[70148:MainThread](2021-04-12 13:23:30,262) INFO - qlib.Initialization - [config.py:276] - default_conf: client.\n",
|
||||||
"[68147:MainThread](2021-04-12 13:09:24,411) WARNING - qlib.Initialization - [config.py:291] - redis connection failed(host=127.0.0.1 port=6379), cache will not be used!\n",
|
"[70148:MainThread](2021-04-12 13:23:30,266) WARNING - qlib.Initialization - [config.py:291] - redis connection failed(host=127.0.0.1 port=6379), cache will not be used!\n",
|
||||||
"[68147:MainThread](2021-04-12 13:09:24,414) INFO - qlib.Initialization - [__init__.py:46] - qlib successfully initialized based on client settings.\n",
|
"[70148:MainThread](2021-04-12 13:23:30,269) INFO - qlib.Initialization - [__init__.py:46] - qlib successfully initialized based on client settings.\n",
|
||||||
"[68147:MainThread](2021-04-12 13:09:24,417) INFO - qlib.Initialization - [__init__.py:47] - data_path=/Users/xuanyidong/.qlib/qlib_data/cn_data\n"
|
"[70148:MainThread](2021-04-12 13:23:30,271) INFO - qlib.Initialization - [__init__.py:47] - data_path=/Users/xuanyidong/.qlib/qlib_data/cn_data\n"
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
@ -142,7 +142,7 @@
|
|||||||
"name": "stderr",
|
"name": "stderr",
|
||||||
"output_type": "stream",
|
"output_type": "stream",
|
||||||
"text": [
|
"text": [
|
||||||
"[68147:MainThread](2021-04-12 13:09:25,066) INFO - qlib.workflow - [expm.py:290] - <mlflow.tracking.client.MlflowClient object at 0x7fd449277a30>\n"
|
"[70148:MainThread](2021-04-12 13:23:31,137) INFO - qlib.workflow - [expm.py:290] - <mlflow.tracking.client.MlflowClient object at 0x7f8c4a47efa0>\n"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -233,6 +233,7 @@
|
|||||||
" cmap=cm.Spectral, linewidth=0.2, antialiased=True)\n",
|
" cmap=cm.Spectral, linewidth=0.2, antialiased=True)\n",
|
||||||
" cur_ax.set_xticks(raw_depths)\n",
|
" cur_ax.set_xticks(raw_depths)\n",
|
||||||
" cur_ax.set_yticks(raw_channels)\n",
|
" cur_ax.set_yticks(raw_channels)\n",
|
||||||
|
" cur_ax.set_zticks(np.arange(4, 11, 2))\n",
|
||||||
" cur_ax.set_xlabel(\"#depth\", fontsize=LabelSize)\n",
|
" cur_ax.set_xlabel(\"#depth\", fontsize=LabelSize)\n",
|
||||||
" cur_ax.set_ylabel(\"#channels\", fontsize=LabelSize)\n",
|
" cur_ax.set_ylabel(\"#channels\", fontsize=LabelSize)\n",
|
||||||
" cur_ax.set_zlabel(\"{:} IC (%)\".format('training' if train_or_test else 'validation'), fontsize=LabelSize)\n",
|
" cur_ax.set_zlabel(\"{:} IC (%)\".format('training' if train_or_test else 'validation'), fontsize=LabelSize)\n",
|
||||||
|
@ -18,10 +18,10 @@
|
|||||||
"name": "stderr",
|
"name": "stderr",
|
||||||
"output_type": "stream",
|
"output_type": "stream",
|
||||||
"text": [
|
"text": [
|
||||||
"[64660:MainThread](2021-04-11 23:57:38,079) INFO - qlib.Initialization - [config.py:276] - default_conf: client.\n",
|
"[70363:MainThread](2021-04-12 13:25:01,065) INFO - qlib.Initialization - [config.py:276] - default_conf: client.\n",
|
||||||
"[64660:MainThread](2021-04-11 23:57:38,081) WARNING - qlib.Initialization - [config.py:291] - redis connection failed(host=127.0.0.1 port=6379), cache will not be used!\n",
|
"[70363:MainThread](2021-04-12 13:25:01,069) WARNING - qlib.Initialization - [config.py:291] - redis connection failed(host=127.0.0.1 port=6379), cache will not be used!\n",
|
||||||
"[64660:MainThread](2021-04-11 23:57:38,083) INFO - qlib.Initialization - [__init__.py:46] - qlib successfully initialized based on client settings.\n",
|
"[70363:MainThread](2021-04-12 13:25:01,085) INFO - qlib.Initialization - [__init__.py:46] - qlib successfully initialized based on client settings.\n",
|
||||||
"[64660:MainThread](2021-04-11 23:57:38,084) INFO - qlib.Initialization - [__init__.py:47] - data_path=/Users/xuanyidong/.qlib/qlib_data/cn_data\n"
|
"[70363:MainThread](2021-04-12 13:25:01,092) INFO - qlib.Initialization - [__init__.py:47] - data_path=/Users/xuanyidong/.qlib/qlib_data/cn_data\n"
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
@ -142,7 +142,7 @@
|
|||||||
"name": "stderr",
|
"name": "stderr",
|
||||||
"output_type": "stream",
|
"output_type": "stream",
|
||||||
"text": [
|
"text": [
|
||||||
"[64660:MainThread](2021-04-11 23:57:38,469) INFO - qlib.workflow - [expm.py:290] - <mlflow.tracking.client.MlflowClient object at 0x7fba2bc7df70>\n"
|
"[70363:MainThread](2021-04-12 13:25:01,647) INFO - qlib.workflow - [expm.py:290] - <mlflow.tracking.client.MlflowClient object at 0x7fa920e56820>\n"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -182,7 +182,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 27,
|
"execution_count": 8,
|
||||||
"id": "supreme-basis",
|
"id": "supreme-basis",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
@ -204,7 +204,7 @@
|
|||||||
" \n",
|
" \n",
|
||||||
" dpi, width, height = 200, 4000, 2000\n",
|
" dpi, width, height = 200, 4000, 2000\n",
|
||||||
" figsize = width / float(dpi), height / float(dpi)\n",
|
" figsize = width / float(dpi), height / float(dpi)\n",
|
||||||
" LabelSize, LegendFontsize = 22, 18\n",
|
" LabelSize, LegendFontsize = 22, 22\n",
|
||||||
" font_gap = 5\n",
|
" font_gap = 5\n",
|
||||||
" colors = ['k', 'r']\n",
|
" colors = ['k', 'r']\n",
|
||||||
" markers = ['*', 'o']\n",
|
" markers = ['*', 'o']\n",
|
||||||
@ -227,6 +227,7 @@
|
|||||||
" cur_ax.scatter(x_values, y_values,\n",
|
" cur_ax.scatter(x_values, y_values,\n",
|
||||||
" marker=markers[idx], s=3, c=colors[idx], alpha=0.9,\n",
|
" marker=markers[idx], s=3, c=colors[idx], alpha=0.9,\n",
|
||||||
" label=legend)\n",
|
" label=legend)\n",
|
||||||
|
" cur_ax.set_yticks(np.arange(4, 11, 2))\n",
|
||||||
" cur_ax.set_xlabel(\"sorted architectures\", fontsize=LabelSize)\n",
|
" cur_ax.set_xlabel(\"sorted architectures\", fontsize=LabelSize)\n",
|
||||||
" cur_ax.set_ylabel(\"{:} IC (%)\".format('training' if train_or_test else 'validation'), fontsize=LabelSize)\n",
|
" cur_ax.set_ylabel(\"{:} IC (%)\".format('training' if train_or_test else 'validation'), fontsize=LabelSize)\n",
|
||||||
" for tick in cur_ax.xaxis.get_major_ticks():\n",
|
" for tick in cur_ax.xaxis.get_major_ticks():\n",
|
||||||
@ -246,7 +247,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 28,
|
"execution_count": 9,
|
||||||
"id": "shared-envelope",
|
"id": "shared-envelope",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
@ -254,7 +255,7 @@
|
|||||||
"name": "stdout",
|
"name": "stdout",
|
||||||
"output_type": "stream",
|
"output_type": "stream",
|
||||||
"text": [
|
"text": [
|
||||||
"{'TSF-8x6', 'TSF-6x6', 'TSF-4x24', 'TSF-3x32', 'TSF-7x6', 'TSF-4x12', 'TSF-2x12', 'TSF-1x24', 'TSF-1x32', 'TSF-6x32', 'TSF-7x48', 'TSF-4x6', 'TSF-5x32', 'TSF-6x24', 'TSF-8x24', 'TSF-5x6', 'TSF-3x24', 'TSF-6x12', 'TSF-3x12', 'TSF-5x64', 'TSF-5x12', 'TSF-7x32', 'TSF-6x48', 'TSF-3x64', 'TSF-5x48', 'TSF-7x24', 'TSF-4x32', 'TSF-4x64', 'TSF-2x64', 'TSF-8x12', 'TSF-7x64', 'TSF-3x6', 'TSF-1x6', 'TSF-8x64', 'TSF-2x6', 'TSF-6x64', 'TSF-7x12', 'TSF-2x24', 'TSF-8x48', 'TSF-1x64', 'TSF-4x48', 'TSF-8x32', 'TSF-2x48', 'TSF-1x12', 'TSF-5x24', 'TSF-3x48', 'TSF-2x32', 'TSF-1x48'}\n",
|
"{'TSF-3x48', 'TSF-2x64', 'TSF-2x12', 'TSF-8x48', 'TSF-6x32', 'TSF-4x48', 'TSF-8x6', 'TSF-4x6', 'TSF-2x32', 'TSF-5x12', 'TSF-5x64', 'TSF-1x64', 'TSF-2x24', 'TSF-8x24', 'TSF-4x12', 'TSF-6x12', 'TSF-1x32', 'TSF-5x32', 'TSF-3x24', 'TSF-8x12', 'TSF-5x48', 'TSF-6x64', 'TSF-7x64', 'TSF-7x48', 'TSF-1x6', 'TSF-2x48', 'TSF-7x24', 'TSF-3x32', 'TSF-1x24', 'TSF-4x64', 'TSF-3x12', 'TSF-8x64', 'TSF-4x32', 'TSF-5x6', 'TSF-7x6', 'TSF-7x12', 'TSF-3x6', 'TSF-4x24', 'TSF-6x48', 'TSF-6x6', 'TSF-1x48', 'TSF-1x12', 'TSF-7x32', 'TSF-5x24', 'TSF-2x6', 'TSF-6x24', 'TSF-3x64', 'TSF-8x32'}\n",
|
||||||
"The Desktop is at: /Users/xuanyidong/Desktop\n",
|
"The Desktop is at: /Users/xuanyidong/Desktop\n",
|
||||||
"There are 104 qlib-results\n"
|
"There are 104 qlib-results\n"
|
||||||
]
|
]
|
||||||
|
208
notebooks/TOT/Time-Curve.ipynb
Normal file
208
notebooks/TOT/Time-Curve.ipynb
Normal file
@ -0,0 +1,208 @@
|
|||||||
|
{
|
||||||
|
"cells": [
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 1,
|
||||||
|
"id": "afraid-minutes",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"The root path: /Users/xuanyidong/Desktop/AutoDL-Projects\n",
|
||||||
|
"The library path: /Users/xuanyidong/Desktop/AutoDL-Projects/lib\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"import os\n",
|
||||||
|
"import re\n",
|
||||||
|
"import sys\n",
|
||||||
|
"import torch\n",
|
||||||
|
"import pprint\n",
|
||||||
|
"import numpy as np\n",
|
||||||
|
"import pandas as pd\n",
|
||||||
|
"from pathlib import Path\n",
|
||||||
|
"from scipy.interpolate import make_interp_spline\n",
|
||||||
|
"\n",
|
||||||
|
"__file__ = os.path.dirname(os.path.realpath(\"__file__\"))\n",
|
||||||
|
"root_dir = (Path(__file__).parent / \"..\").resolve()\n",
|
||||||
|
"lib_dir = (root_dir / \"lib\").resolve()\n",
|
||||||
|
"print(\"The root path: {:}\".format(root_dir))\n",
|
||||||
|
"print(\"The library path: {:}\".format(lib_dir))\n",
|
||||||
|
"assert lib_dir.exists(), \"{:} does not exist\".format(lib_dir)\n",
|
||||||
|
"if str(lib_dir) not in sys.path:\n",
|
||||||
|
" sys.path.insert(0, str(lib_dir))\n",
|
||||||
|
"from utils.qlib_utils import QResult"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 2,
|
||||||
|
"id": "continental-drain",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"TSF-2x24-drop0_0s2013-01-01\n",
|
||||||
|
"TSF-2x24-drop0_0s2012-01-01\n",
|
||||||
|
"TSF-2x24-drop0_0s2008-01-01\n",
|
||||||
|
"TSF-2x24-drop0_0s2009-01-01\n",
|
||||||
|
"TSF-2x24-drop0_0s2010-01-01\n",
|
||||||
|
"TSF-2x24-drop0_0s2011-01-01\n",
|
||||||
|
"TSF-2x24-drop0_0s2008-07-01\n",
|
||||||
|
"TSF-2x24-drop0_0s2009-07-01\n",
|
||||||
|
"There are 3011 dates\n",
|
||||||
|
"Dates: 2008-01-02 2008-01-03\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"qresults = torch.load(os.path.join(root_dir, 'notebooks', 'TOT', 'temp-time-x.pth'))\n",
|
||||||
|
"for qresult in qresults:\n",
|
||||||
|
" print(qresult.name)\n",
|
||||||
|
"all_dates = set()\n",
|
||||||
|
"for qresult in qresults:\n",
|
||||||
|
" dates = qresult.find_all_dates()\n",
|
||||||
|
" for date in dates:\n",
|
||||||
|
" all_dates.add(date)\n",
|
||||||
|
"all_dates = sorted(list(all_dates))\n",
|
||||||
|
"print('There are {:} dates'.format(len(all_dates)))\n",
|
||||||
|
"print('Dates: {:} {:}'.format(all_dates[0], all_dates[1]))"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 3,
|
||||||
|
"id": "intimate-approval",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"import matplotlib\n",
|
||||||
|
"from matplotlib import cm\n",
|
||||||
|
"matplotlib.use(\"agg\")\n",
|
||||||
|
"import matplotlib.pyplot as plt\n",
|
||||||
|
"import matplotlib.ticker as ticker"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 6,
|
||||||
|
"id": "supreme-basis",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"def vis_time_curve(qresults, dates, use_original, save_path):\n",
|
||||||
|
" save_dir = (save_path / '..').resolve()\n",
|
||||||
|
" save_dir.mkdir(parents=True, exist_ok=True)\n",
|
||||||
|
" print('There are {:} qlib-results'.format(len(qresults)))\n",
|
||||||
|
" \n",
|
||||||
|
" dpi, width, height = 200, 5000, 2000\n",
|
||||||
|
" figsize = width / float(dpi), height / float(dpi)\n",
|
||||||
|
" LabelSize, LegendFontsize = 22, 12\n",
|
||||||
|
" font_gap = 5\n",
|
||||||
|
" linestyles = ['-', '--']\n",
|
||||||
|
" colors = ['k', 'r']\n",
|
||||||
|
" \n",
|
||||||
|
" fig = plt.figure(figsize=figsize)\n",
|
||||||
|
" cur_ax = fig.add_subplot(1, 1, 1)\n",
|
||||||
|
" for idx, qresult in enumerate(qresults):\n",
|
||||||
|
" print('Visualize [{:}] -- {:}'.format(idx, qresult.name))\n",
|
||||||
|
" x_axis, y_axis = [], []\n",
|
||||||
|
" for idate, date in enumerate(dates):\n",
|
||||||
|
" if date in qresult._date2ICs[-1]:\n",
|
||||||
|
" mean, std = qresult.get_IC_by_date(date, 100)\n",
|
||||||
|
" if not np.isnan(mean):\n",
|
||||||
|
" x_axis.append(idate)\n",
|
||||||
|
" y_axis.append(mean)\n",
|
||||||
|
" x_axis, y_axis = np.array(x_axis), np.array(y_axis)\n",
|
||||||
|
" if use_original:\n",
|
||||||
|
" cur_ax.plot(x_axis, y_axis, linewidth=1, color=colors[idx], linestyle=linestyles[idx])\n",
|
||||||
|
" else:\n",
|
||||||
|
" xnew = np.linspace(x_axis.min(), x_axis.max(), 200)\n",
|
||||||
|
" spl = make_interp_spline(x_axis, y_axis, k=5)\n",
|
||||||
|
" ynew = spl(xnew)\n",
|
||||||
|
" cur_ax.plot(xnew, ynew, linewidth=2, color=colors[idx], linestyle=linestyles[idx])\n",
|
||||||
|
" \n",
|
||||||
|
" for tick in cur_ax.xaxis.get_major_ticks():\n",
|
||||||
|
" tick.label.set_fontsize(LabelSize - font_gap)\n",
|
||||||
|
" for tick in cur_ax.yaxis.get_major_ticks():\n",
|
||||||
|
" tick.label.set_fontsize(LabelSize - font_gap)\n",
|
||||||
|
" cur_ax.set_ylabel(\"IC (%)\", fontsize=LabelSize)\n",
|
||||||
|
" fig.savefig(save_path, dpi=dpi, bbox_inches=\"tight\", format=\"pdf\")\n",
|
||||||
|
" plt.close(\"all\")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 7,
|
||||||
|
"id": "shared-envelope",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"The Desktop is at: /Users/xuanyidong/Desktop\n",
|
||||||
|
"There are 2 qlib-results\n",
|
||||||
|
"Visualize [0] -- TSF-2x24-drop0_0s2008-01-01\n",
|
||||||
|
"Visualize [1] -- TSF-2x24-drop0_0s2009-07-01\n",
|
||||||
|
"There are 2 qlib-results\n",
|
||||||
|
"Visualize [0] -- TSF-2x24-drop0_0s2008-01-01\n",
|
||||||
|
"Visualize [1] -- TSF-2x24-drop0_0s2009-07-01\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"# Visualization\n",
|
||||||
|
"home_dir = Path.home()\n",
|
||||||
|
"desktop_dir = home_dir / 'Desktop'\n",
|
||||||
|
"print('The Desktop is at: {:}'.format(desktop_dir))\n",
|
||||||
|
"\n",
|
||||||
|
"vis_time_curve(\n",
|
||||||
|
" (qresults[2], qresults[-1]),\n",
|
||||||
|
" all_dates,\n",
|
||||||
|
" True,\n",
|
||||||
|
" desktop_dir / 'es_csi300_time_curve.pdf')\n",
|
||||||
|
"\n",
|
||||||
|
"vis_time_curve(\n",
|
||||||
|
" (qresults[2], qresults[-1]),\n",
|
||||||
|
" all_dates,\n",
|
||||||
|
" False,\n",
|
||||||
|
" desktop_dir / 'es_csi300_time_curve-inter.pdf')"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "exempt-stable",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": []
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"kernelspec": {
|
||||||
|
"display_name": "Python 3",
|
||||||
|
"language": "python",
|
||||||
|
"name": "python3"
|
||||||
|
},
|
||||||
|
"language_info": {
|
||||||
|
"codemirror_mode": {
|
||||||
|
"name": "ipython",
|
||||||
|
"version": 3
|
||||||
|
},
|
||||||
|
"file_extension": ".py",
|
||||||
|
"mimetype": "text/x-python",
|
||||||
|
"name": "python",
|
||||||
|
"nbconvert_exporter": "python",
|
||||||
|
"pygments_lexer": "ipython3",
|
||||||
|
"version": "3.8.8"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"nbformat": 4,
|
||||||
|
"nbformat_minor": 5
|
||||||
|
}
|
128
notebooks/TOT/synthetic.ipynb
Normal file
128
notebooks/TOT/synthetic.ipynb
Normal file
@ -0,0 +1,128 @@
|
|||||||
|
{
|
||||||
|
"cells": [
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 1,
|
||||||
|
"id": "filled-multiple",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"#\n",
|
||||||
|
"# %matplotlib notebook\n",
|
||||||
|
"from pathlib import Path\n",
|
||||||
|
"import numpy as np\n",
|
||||||
|
"import matplotlib\n",
|
||||||
|
"from matplotlib import cm\n",
|
||||||
|
"matplotlib.use(\"agg\")\n",
|
||||||
|
"import matplotlib.pyplot as plt\n",
|
||||||
|
"import matplotlib.ticker as ticker"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 2,
|
||||||
|
"id": "supreme-basis",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"def visualize_syn(save_path):\n",
|
||||||
|
" save_dir = (save_path / '..').resolve()\n",
|
||||||
|
" save_dir.mkdir(parents=True, exist_ok=True)\n",
|
||||||
|
" \n",
|
||||||
|
" dpi, width, height = 50, 2000, 1000\n",
|
||||||
|
" figsize = width / float(dpi), height / float(dpi)\n",
|
||||||
|
" LabelSize, font_gap = 30, 4\n",
|
||||||
|
" \n",
|
||||||
|
" fig = plt.figure(figsize=figsize)\n",
|
||||||
|
" \n",
|
||||||
|
" times = np.arange(0, np.pi * 100, 0.1)\n",
|
||||||
|
" num = len(times)\n",
|
||||||
|
" x = []\n",
|
||||||
|
" for i in range(num):\n",
|
||||||
|
" scale = (i + 1.) / num * 4\n",
|
||||||
|
" value = times[i] * scale\n",
|
||||||
|
" x.append(np.sin(value) * (1.3 - scale))\n",
|
||||||
|
" x = np.array(x)\n",
|
||||||
|
" y = np.cos( x * x - 0.3 * x )\n",
|
||||||
|
" \n",
|
||||||
|
" cur_ax = fig.add_subplot(2, 1, 1)\n",
|
||||||
|
" cur_ax.plot(times, x)\n",
|
||||||
|
" cur_ax.set_xlabel(\"time\", fontsize=LabelSize)\n",
|
||||||
|
" cur_ax.set_ylabel(\"x\", fontsize=LabelSize)\n",
|
||||||
|
" for tick in cur_ax.xaxis.get_major_ticks():\n",
|
||||||
|
" tick.label.set_fontsize(LabelSize - font_gap)\n",
|
||||||
|
" tick.label.set_rotation(30)\n",
|
||||||
|
" for tick in cur_ax.yaxis.get_major_ticks():\n",
|
||||||
|
" tick.label.set_fontsize(LabelSize - font_gap)\n",
|
||||||
|
" \n",
|
||||||
|
" \n",
|
||||||
|
" cur_ax = fig.add_subplot(2, 1, 2)\n",
|
||||||
|
" cur_ax.plot(times, y)\n",
|
||||||
|
" cur_ax.set_xlabel(\"time\", fontsize=LabelSize)\n",
|
||||||
|
" cur_ax.set_ylabel(\"f(x)\", fontsize=LabelSize)\n",
|
||||||
|
" for tick in cur_ax.xaxis.get_major_ticks():\n",
|
||||||
|
" tick.label.set_fontsize(LabelSize - font_gap)\n",
|
||||||
|
" tick.label.set_rotation(30)\n",
|
||||||
|
" for tick in cur_ax.yaxis.get_major_ticks():\n",
|
||||||
|
" tick.label.set_fontsize(LabelSize - font_gap)\n",
|
||||||
|
" \n",
|
||||||
|
" # fig.tight_layout()\n",
|
||||||
|
" # plt.subplots_adjust(wspace=0.05)#, hspace=0.4)\n",
|
||||||
|
" fig.savefig(save_path, dpi=dpi, bbox_inches=\"tight\", format=\"pdf\")\n",
|
||||||
|
" plt.close(\"all\")\n",
|
||||||
|
" # plt.show()"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 3,
|
||||||
|
"id": "shared-envelope",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"The Desktop is at: /Users/xuanyidong/Desktop\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"# Visualization\n",
|
||||||
|
"home_dir = Path.home()\n",
|
||||||
|
"desktop_dir = home_dir / 'Desktop'\n",
|
||||||
|
"print('The Desktop is at: {:}'.format(desktop_dir))\n",
|
||||||
|
"visualize_syn(desktop_dir / 'tot-synthetic-v0.pdf')"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "romantic-ordinance",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": []
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"kernelspec": {
|
||||||
|
"display_name": "Python 3",
|
||||||
|
"language": "python",
|
||||||
|
"name": "python3"
|
||||||
|
},
|
||||||
|
"language_info": {
|
||||||
|
"codemirror_mode": {
|
||||||
|
"name": "ipython",
|
||||||
|
"version": 3
|
||||||
|
},
|
||||||
|
"file_extension": ".py",
|
||||||
|
"mimetype": "text/x-python",
|
||||||
|
"name": "python",
|
||||||
|
"nbconvert_exporter": "python",
|
||||||
|
"pygments_lexer": "ipython3",
|
||||||
|
"version": "3.8.8"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"nbformat": 4,
|
||||||
|
"nbformat_minor": 5
|
||||||
|
}
|
123
notebooks/TOT/time-curve.py
Normal file
123
notebooks/TOT/time-curve.py
Normal file
@ -0,0 +1,123 @@
|
|||||||
|
import os
|
||||||
|
import re
|
||||||
|
import sys
|
||||||
|
import torch
|
||||||
|
import qlib
|
||||||
|
import pprint
|
||||||
|
from collections import OrderedDict
|
||||||
|
import numpy as np
|
||||||
|
import pandas as pd
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
# __file__ = os.path.dirname(os.path.realpath("__file__"))
|
||||||
|
note_dir = Path(__file__).parent.resolve()
|
||||||
|
root_dir = (Path(__file__).parent / ".." / "..").resolve()
|
||||||
|
lib_dir = (root_dir / "lib").resolve()
|
||||||
|
print("The root path: {:}".format(root_dir))
|
||||||
|
print("The library path: {:}".format(lib_dir))
|
||||||
|
assert lib_dir.exists(), "{:} does not exist".format(lib_dir)
|
||||||
|
if str(lib_dir) not in sys.path:
|
||||||
|
sys.path.insert(0, str(lib_dir))
|
||||||
|
|
||||||
|
import qlib
|
||||||
|
from qlib import config as qconfig
|
||||||
|
from qlib.workflow import R
|
||||||
|
qlib.init(provider_uri='~/.qlib/qlib_data/cn_data', region=qconfig.REG_CN)
|
||||||
|
|
||||||
|
from utils.qlib_utils import QResult
|
||||||
|
|
||||||
|
def filter_finished(recorders):
|
||||||
|
returned_recorders = dict()
|
||||||
|
not_finished = 0
|
||||||
|
for key, recorder in recorders.items():
|
||||||
|
if recorder.status == "FINISHED":
|
||||||
|
returned_recorders[key] = recorder
|
||||||
|
else:
|
||||||
|
not_finished += 1
|
||||||
|
return returned_recorders, not_finished
|
||||||
|
|
||||||
|
|
||||||
|
def add_to_dict(xdict, timestamp, value):
|
||||||
|
date = timestamp.date().strftime("%Y-%m-%d")
|
||||||
|
if date in xdict:
|
||||||
|
raise ValueError("This date [{:}] is already in the dict".format(date))
|
||||||
|
xdict[date] = value
|
||||||
|
|
||||||
|
def query_info(save_dir, verbose, name_filter, key_map):
|
||||||
|
if isinstance(save_dir, list):
|
||||||
|
results = []
|
||||||
|
for x in save_dir:
|
||||||
|
x = query_info(x, verbose, name_filter, key_map)
|
||||||
|
results.extend(x)
|
||||||
|
return results
|
||||||
|
# Here, the save_dir must be a string
|
||||||
|
R.set_uri(str(save_dir))
|
||||||
|
experiments = R.list_experiments()
|
||||||
|
|
||||||
|
if verbose:
|
||||||
|
print("There are {:} experiments.".format(len(experiments)))
|
||||||
|
qresults = []
|
||||||
|
for idx, (key, experiment) in enumerate(experiments.items()):
|
||||||
|
if experiment.id == "0":
|
||||||
|
continue
|
||||||
|
if name_filter is not None and re.fullmatch(name_filter, experiment.name) is None:
|
||||||
|
continue
|
||||||
|
recorders = experiment.list_recorders()
|
||||||
|
recorders, not_finished = filter_finished(recorders)
|
||||||
|
if verbose:
|
||||||
|
print(
|
||||||
|
"====>>>> {:02d}/{:02d}-th experiment {:9s} has {:02d}/{:02d} finished recorders.".format(
|
||||||
|
idx + 1,
|
||||||
|
len(experiments),
|
||||||
|
experiment.name,
|
||||||
|
len(recorders),
|
||||||
|
len(recorders) + not_finished,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
result = QResult(experiment.name)
|
||||||
|
for recorder_id, recorder in recorders.items():
|
||||||
|
file_names = ['results-train.pkl', 'results-valid.pkl', 'results-test.pkl']
|
||||||
|
date2IC = OrderedDict()
|
||||||
|
for file_name in file_names:
|
||||||
|
xtemp = recorder.load_object(file_name)['all-IC']
|
||||||
|
timestamps, values = xtemp.index.tolist(), xtemp.tolist()
|
||||||
|
for timestamp, value in zip(timestamps, values):
|
||||||
|
add_to_dict(date2IC, timestamp, value)
|
||||||
|
result.update(recorder.list_metrics(), key_map)
|
||||||
|
result.append_path(
|
||||||
|
os.path.join(recorder.uri, recorder.experiment_id, recorder.id)
|
||||||
|
)
|
||||||
|
result.append_date2ICs(date2IC)
|
||||||
|
if not len(result):
|
||||||
|
print("There are no valid recorders for {:}".format(experiment))
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
if verbose:
|
||||||
|
print(
|
||||||
|
"There are {:} valid recorders for {:}".format(
|
||||||
|
len(recorders), experiment.name
|
||||||
|
)
|
||||||
|
)
|
||||||
|
qresults.append(result)
|
||||||
|
return qresults
|
||||||
|
|
||||||
|
|
||||||
|
##
|
||||||
|
paths = [root_dir / 'outputs' / 'qlib-baselines-csi300']
|
||||||
|
paths = [path.resolve() for path in paths]
|
||||||
|
print(paths)
|
||||||
|
|
||||||
|
key_map = dict()
|
||||||
|
for xset in ("train", "valid", "test"):
|
||||||
|
key_map["{:}-mean-IC".format(xset)] = "IC ({:})".format(xset)
|
||||||
|
key_map["{:}-mean-ICIR".format(xset)] = "ICIR ({:})".format(xset)
|
||||||
|
qresults = query_info(paths, False, 'TSF-2x24-drop0_0s.*-.*-01', key_map)
|
||||||
|
print('Find {:} results'.format(len(qresults)))
|
||||||
|
times = []
|
||||||
|
for qresult in qresults:
|
||||||
|
times.append(qresult.name.split('0_0s')[-1])
|
||||||
|
print(times)
|
||||||
|
save_path = os.path.join(note_dir, 'temp-time-x.pth')
|
||||||
|
torch.save(qresults, save_path)
|
||||||
|
print(save_path)
|
Loading…
Reference in New Issue
Block a user