311 lines
11 KiB
Plaintext
311 lines
11 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 1,
|
|
"id": "afraid-minutes",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"The root path: /Users/xuanyidong/Desktop/AutoDL-Projects\n",
|
|
"The library path: /Users/xuanyidong/Desktop/AutoDL-Projects/lib\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"[68147:MainThread](2021-04-12 13:09:24,409) INFO - qlib.Initialization - [config.py:276] - default_conf: client.\n",
|
|
"[68147:MainThread](2021-04-12 13:09:24,411) WARNING - qlib.Initialization - [config.py:291] - redis connection failed(host=127.0.0.1 port=6379), cache will not be used!\n",
|
|
"[68147:MainThread](2021-04-12 13:09:24,414) INFO - qlib.Initialization - [__init__.py:46] - qlib successfully initialized based on client settings.\n",
|
|
"[68147:MainThread](2021-04-12 13:09:24,417) INFO - qlib.Initialization - [__init__.py:47] - data_path=/Users/xuanyidong/.qlib/qlib_data/cn_data\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"#\n",
|
|
"# Exhaustive Search Results\n",
|
|
"#\n",
|
|
"import os\n",
|
|
"import re\n",
|
|
"import sys\n",
|
|
"import qlib\n",
|
|
"import pprint\n",
|
|
"import numpy as np\n",
|
|
"import pandas as pd\n",
|
|
"\n",
|
|
"from pathlib import Path\n",
|
|
"\n",
|
"# realpath of the literal string \"__file__\" resolves against the working directory,\n",
"# so this stores the directory the kernel is running in.\n",
"__file__ = os.path.dirname(os.path.realpath(\"__file__\"))\n",
"# The project root is taken to be two levels above that directory.\n",
"root_dir = Path(__file__).parent.joinpath(\"..\").resolve()\n",
"lib_dir = root_dir.joinpath(\"lib\").resolve()\n",
"print(\"The root path: {:}\".format(root_dir))\n",
"print(\"The library path: {:}\".format(lib_dir))\n",
"assert lib_dir.exists(), \"{:} does not exist\".format(lib_dir)\n",
"# Put the local library first on the import path, but only once per kernel.\n",
"if str(lib_dir) not in sys.path:\n",
"    sys.path.insert(0, str(lib_dir))\n",
"\n",
"import qlib\n",
"from qlib import config as qconfig\n",
"from qlib.workflow import R\n",
"\n",
"# Initialise qlib against the data stored under ~/.qlib/qlib_data/cn_data.\n",
"qlib.init(provider_uri='~/.qlib/qlib_data/cn_data', region=qconfig.REG_CN)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 2,
|
|
"id": "hidden-exemption",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"from utils.qlib_utils import QResult"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 3,
|
|
"id": "continental-drain",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
"def filter_finished(recorders):\n",
"    \"\"\"Separate the finished recorders from the rest.\n",
"\n",
"    Args:\n",
"        recorders: mapping from a key to an object exposing a ``status`` attribute.\n",
"\n",
"    Returns:\n",
"        A ``(finished, pending_count)`` tuple: a dict with only the entries whose\n",
"        status equals \"FINISHED\", and the number of entries that were left out.\n",
"    \"\"\"\n",
"    finished = {\n",
"        name: rec for name, rec in recorders.items() if rec.status == \"FINISHED\"\n",
"    }\n",
"    return finished, len(recorders) - len(finished)\n",
|
|
"\n",
|
"def query_info(save_dir, verbose, name_filter, key_map):\n",
"    \"\"\"Collect one QResult per experiment stored under ``save_dir``.\n",
"\n",
"    Args:\n",
"        save_dir: a single location handed to ``R.set_uri`` (after ``str``), or a list\n",
"            of such locations whose results are concatenated in order.\n",
"        verbose: when true, progress messages are printed.\n",
"        name_filter: regular expression an experiment name must fully match;\n",
"            ``None`` keeps every experiment.\n",
"        key_map: forwarded unchanged to ``QResult.update``.\n",
"\n",
"    Returns:\n",
"        A list with one QResult for each kept experiment whose result is non-empty.\n",
"    \"\"\"\n",
"    if isinstance(save_dir, list):\n",
"        merged = []\n",
"        for single_dir in save_dir:\n",
"            merged.extend(query_info(single_dir, verbose, name_filter, key_map))\n",
"        return merged\n",
"    # From here on save_dir is one location; it is stringified for R.set_uri.\n",
"    R.set_uri(str(save_dir))\n",
"    experiments = R.list_experiments()\n",
"    num_experiments = len(experiments)\n",
"\n",
"    if verbose:\n",
"        print(\"There are {:} experiments.\".format(num_experiments))\n",
"    collected = []\n",
"    for position, experiment in enumerate(experiments.values(), start=1):\n",
"        # The experiment whose id is \"0\" is never reported.\n",
"        if experiment.id == \"0\":\n",
"            continue\n",
"        if name_filter is not None:\n",
"            if re.fullmatch(name_filter, experiment.name) is None:\n",
"                continue\n",
"        recorders, not_finished = filter_finished(experiment.list_recorders())\n",
"        if verbose:\n",
"            print(\n",
"                \"====>>>> {:02d}/{:02d}-th experiment {:9s} has {:02d}/{:02d} finished recorders.\".format(\n",
"                    position,\n",
"                    num_experiments,\n",
"                    experiment.name,\n",
"                    len(recorders),\n",
"                    len(recorders) + not_finished,\n",
"                )\n",
"            )\n",
"        result = QResult(experiment.name)\n",
"        for recorder in recorders.values():\n",
"            result.update(recorder.list_metrics(), key_map)\n",
"            record_path = os.path.join(recorder.uri, recorder.experiment_id, recorder.id)\n",
"            result.append_path(record_path)\n",
"        if len(result) == 0:\n",
"            # Reported even when verbose is off.\n",
"            print(\"There are no valid recorders for {:}\".format(experiment))\n",
"            continue\n",
"        if verbose:\n",
"            print(\n",
"                \"There are {:} valid recorders for {:}\".format(\n",
"                    len(recorders), experiment.name\n",
"                )\n",
"            )\n",
"        collected.append(result)\n",
"    return collected"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 4,
|
|
"id": "filled-multiple",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"[68147:MainThread](2021-04-12 13:09:25,066) INFO - qlib.workflow - [expm.py:290] - <mlflow.tracking.client.MlflowClient object at 0x7fd449277a30>\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"[PosixPath('/Users/xuanyidong/Desktop/AutoDL-Projects/outputs/qlib-baselines-csi300')]\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
"# Location of the experiment records to analyse.\n",
"paths = [(root_dir / 'outputs' / 'qlib-baselines-csi300').resolve()]\n",
"print(paths)\n",
"\n",
"# key_map pairs each \"<split>-mean-IC\" / \"<split>-mean-ICIR\" name with the\n",
"# \"IC (<split>)\" / \"ICIR (<split>)\" label; it is passed through to query_info.\n",
"key_map = dict()\n",
"for xset in (\"train\", \"valid\", \"test\"):\n",
"    key_map.update(\n",
"        {\n",
"            \"{:}-mean-IC\".format(xset): \"IC ({:})\".format(xset),\n",
"            \"{:}-mean-ICIR\".format(xset): \"ICIR ({:})\".format(xset),\n",
"        }\n",
"    )\n",
"# Only experiments whose full name matches the pattern are collected.\n",
"qresults = query_info(paths, False, 'TSF-.*-drop0_0', key_map)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 5,
|
|
"id": "intimate-approval",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"import matplotlib\n",
|
|
"from matplotlib import cm\n",
|
|
"matplotlib.use(\"agg\")\n",
|
|
"import matplotlib.pyplot as plt\n",
|
|
"import matplotlib.ticker as ticker"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 6,
|
|
"id": "supreme-basis",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
"def vis_depth_channel(qresults, save_path):\n",
"    \"\"\"Plot training and validation IC as 3D surfaces over (depth, channels) and save a PDF.\n",
"\n",
"    Args:\n",
"        qresults: results whose ``name`` looks like ``<prefix>-<depth>x<channels>-...`` and\n",
"            which can be indexed with 'IC (train)' and 'IC (valid)'.\n",
"        save_path: ``pathlib.Path`` of the PDF to write; its parent directory is created\n",
"            when missing.\n",
"\n",
"    Raises:\n",
"        ValueError: if a (depth, channels) pair of the fixed grid has no matching result.\n",
"    \"\"\"\n",
"    save_dir = (save_path / '..').resolve()\n",
"    save_dir.mkdir(parents=True, exist_ok=True)\n",
"    print('There are {:} qlib-results'.format(len(qresults)))\n",
"\n",
"    dpi, width, height = 200, 4000, 2000\n",
"    figsize = width / float(dpi), height / float(dpi)\n",
"    LabelSize = 22\n",
"    font_gap = 5\n",
"\n",
"    fig = plt.figure(figsize=figsize)\n",
"\n",
"    def plot_ax(cur_ax, train_or_test):\n",
"        # train_or_test: True draws 'IC (train)', False draws 'IC (valid)'.\n",
"        ic_key = 'IC (train)' if train_or_test else 'IC (valid)'\n",
"        xmaps = dict()\n",
"        for qresult in qresults:\n",
"            size = qresult.name.split('-')[1].split('x')\n",
"            # Float keys compare and hash equal to the int keys used for the lookup below.\n",
"            xmaps[(float(size[0]), float(size[1]))] = qresult[ic_key]\n",
"        # Fixed grid of configurations expected to be present in qresults.\n",
"        raw_depths = np.arange(1, 9, dtype=np.int32)\n",
"        raw_channels = np.array([6, 12, 24, 32, 48, 64], dtype=np.int32)\n",
"        depths, channels = np.meshgrid(raw_depths, raw_channels)\n",
"        # Float buffer with the grid's shape; every entry is overwritten below.\n",
"        ic_values = np.zeros(depths.shape, dtype=np.float64)\n",
"        num_x, num_y = ic_values.shape\n",
"        for i in range(num_x):\n",
"            for j in range(num_y):\n",
"                xkey = (int(depths[i][j]), int(channels[i][j]))\n",
"                if xkey not in xmaps:\n",
"                    raise ValueError(\"Did not find {:}\".format(xkey))\n",
"                ic_values[i][j] = xmaps[xkey]\n",
"        cur_ax.plot_surface(\n",
"            depths, channels, ic_values,\n",
"            cmap=cm.Spectral, linewidth=0.2, antialiased=True)\n",
"        cur_ax.set_xticks(raw_depths)\n",
"        cur_ax.set_yticks(raw_channels)\n",
"        cur_ax.set_xlabel(\"#depth\", fontsize=LabelSize)\n",
"        cur_ax.set_ylabel(\"#channels\", fontsize=LabelSize)\n",
"        cur_ax.set_zlabel(\"{:} IC (%)\".format('training' if train_or_test else 'validation'), fontsize=LabelSize)\n",
"        for axis in (cur_ax.xaxis, cur_ax.yaxis, cur_ax.zaxis):\n",
"            for tick in axis.get_major_ticks():\n",
"                # Tick.label was removed in Matplotlib 3.8; label1 is its replacement.\n",
"                tick.label1.set_fontsize(LabelSize - font_gap)\n",
"\n",
"    # Left panel: training IC; right panel: validation IC.\n",
"    ax = fig.add_subplot(1, 2, 1, projection='3d')\n",
"    plot_ax(ax, True)\n",
"    ax = fig.add_subplot(1, 2, 2, projection='3d')\n",
"    plot_ax(ax, False)\n",
"    plt.subplots_adjust(wspace=0.05)\n",
"    fig.savefig(save_path, dpi=dpi, bbox_inches=\"tight\", format=\"pdf\")\n",
"    plt.close(\"all\")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 7,
|
|
"id": "shared-envelope",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"The Desktop is at: /Users/xuanyidong/Desktop\n",
|
|
"There are 48 qlib-results\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
"# Visualization: the figure is written into the Desktop folder of the current user.\n",
"home_dir = Path.home()\n",
"desktop_dir = home_dir.joinpath('Desktop')\n",
"print('The Desktop is at: {:}'.format(desktop_dir))\n",
"\n",
"vis_depth_channel(qresults, desktop_dir.joinpath('es_csi300_d_vs_c.pdf'))"
|
|
]
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "Python 3",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.8.8"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 5
|
|
}
|