diff --git a/notebooks/TOT/ES-Model.ipynb b/notebooks/TOT/ES-Model-DC.ipynb similarity index 95% rename from notebooks/TOT/ES-Model.ipynb rename to notebooks/TOT/ES-Model-DC.ipynb index e71b2a7..20e9d38 100644 --- a/notebooks/TOT/ES-Model.ipynb +++ b/notebooks/TOT/ES-Model-DC.ipynb @@ -18,10 +18,10 @@ "name": "stderr", "output_type": "stream", "text": [ - "[61765:MainThread](2021-04-11 21:23:06,638) INFO - qlib.Initialization - [config.py:276] - default_conf: client.\n", - "[61765:MainThread](2021-04-11 21:23:06,641) WARNING - qlib.Initialization - [config.py:291] - redis connection failed(host=127.0.0.1 port=6379), cache will not be used!\n", - "[61765:MainThread](2021-04-11 21:23:06,643) INFO - qlib.Initialization - [__init__.py:46] - qlib successfully initialized based on client settings.\n", - "[61765:MainThread](2021-04-11 21:23:06,644) INFO - qlib.Initialization - [__init__.py:47] - data_path=/Users/xuanyidong/.qlib/qlib_data/cn_data\n" + "[68147:MainThread](2021-04-12 13:09:24,409) INFO - qlib.Initialization - [config.py:276] - default_conf: client.\n", + "[68147:MainThread](2021-04-12 13:09:24,411) WARNING - qlib.Initialization - [config.py:291] - redis connection failed(host=127.0.0.1 port=6379), cache will not be used!\n", + "[68147:MainThread](2021-04-12 13:09:24,414) INFO - qlib.Initialization - [__init__.py:46] - qlib successfully initialized based on client settings.\n", + "[68147:MainThread](2021-04-12 13:09:24,417) INFO - qlib.Initialization - [__init__.py:47] - data_path=/Users/xuanyidong/.qlib/qlib_data/cn_data\n" ] } ], @@ -142,7 +142,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "[61765:MainThread](2021-04-11 21:23:07,182) INFO - qlib.workflow - [expm.py:290] - \n" + "[68147:MainThread](2021-04-12 13:09:25,066) INFO - qlib.workflow - [expm.py:290] - \n" ] }, { @@ -181,7 +181,7 @@ }, { "cell_type": "code", - "execution_count": 40, + "execution_count": 6, "id": "supreme-basis", "metadata": {}, "outputs": [], @@ -207,9 +207,9 @@ " depths.append(float(name.split('x')[0]))\n", " channels.append(float(name.split('x')[1]))\n", " if train_or_test:\n", - " ic_values.append(qresult['ICIR (train)'] * 100)\n", + " ic_values.append(qresult['IC (train)'])\n", " else:\n", - " ic_values.append(qresult['ICIR (valid)'] * 100)\n", + " ic_values.append(qresult['IC (valid)'])\n", " xmaps[(depths[-1], channels[-1])] = ic_values[-1]\n", " # cur_ax.scatter(depths, channels, ic_values, marker='o', c=\"tab:orange\")\n", " raw_depths = np.arange(1, 9, dtype=np.int32)\n", @@ -263,7 +263,7 @@ }, { "cell_type": "code", - "execution_count": 41, + "execution_count": 7, "id": "shared-envelope", "metadata": {}, "outputs": [ diff --git a/notebooks/TOT/ES-Model-Drop.ipynb b/notebooks/TOT/ES-Model-Drop.ipynb new file mode 100644 index 0000000..66d3f60 --- /dev/null +++ b/notebooks/TOT/ES-Model-Drop.ipynb @@ -0,0 +1,311 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "afraid-minutes", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "The root path: /Users/xuanyidong/Desktop/AutoDL-Projects\n", + "The library path: /Users/xuanyidong/Desktop/AutoDL-Projects/lib\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[64660:MainThread](2021-04-11 23:57:38,079) INFO - qlib.Initialization - [config.py:276] - default_conf: client.\n", + "[64660:MainThread](2021-04-11 23:57:38,081) WARNING - qlib.Initialization - [config.py:291] - redis connection failed(host=127.0.0.1 port=6379), cache will not be used!\n", + "[64660:MainThread](2021-04-11 23:57:38,083) INFO - qlib.Initialization - [__init__.py:46] - qlib successfully initialized based on client settings.\n", + "[64660:MainThread](2021-04-11 23:57:38,084) INFO - qlib.Initialization - [__init__.py:47] - data_path=/Users/xuanyidong/.qlib/qlib_data/cn_data\n" + ] + } + ], + "source": [ + "#\n", + "# Exhaustive Search Results\n", + "#\n", + "import os\n", + "import re\n", + "import sys\n", + "import qlib\n", + "import pprint\n", + "import numpy as np\n", + "import pandas as pd\n", + "\n", + "from pathlib import Path\n", + "\n", + "__file__ = os.path.dirname(os.path.realpath(\"__file__\"))\n", + "root_dir = (Path(__file__).parent / \"..\").resolve()\n", + "lib_dir = (root_dir / \"lib\").resolve()\n", + "print(\"The root path: {:}\".format(root_dir))\n", + "print(\"The library path: {:}\".format(lib_dir))\n", + "assert lib_dir.exists(), \"{:} does not exist\".format(lib_dir)\n", + "if str(lib_dir) not in sys.path:\n", + " sys.path.insert(0, str(lib_dir))\n", + "\n", + "import qlib\n", + "from qlib import config as qconfig\n", + "from qlib.workflow import R\n", + "qlib.init(provider_uri='~/.qlib/qlib_data/cn_data', region=qconfig.REG_CN)" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "hidden-exemption", + "metadata": {}, + "outputs": [], + "source": [ + "from utils.qlib_utils import QResult" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "continental-drain", + "metadata": {}, + "outputs": [], + "source": [ + "def filter_finished(recorders):\n", + " returned_recorders = dict()\n", + " not_finished = 0\n", + " for key, recorder in recorders.items():\n", + " if recorder.status == \"FINISHED\":\n", + " returned_recorders[key] = recorder\n", + " else:\n", + " not_finished += 1\n", + " return returned_recorders, not_finished\n", + "\n", + "def query_info(save_dir, verbose, name_filter, key_map):\n", + " if isinstance(save_dir, list):\n", + " results = []\n", + " for x in save_dir:\n", + " x = query_info(x, verbose, name_filter, key_map)\n", + " results.extend(x)\n", + " return results\n", + " # Here, the save_dir must be a string\n", + " R.set_uri(str(save_dir))\n", + " experiments = R.list_experiments()\n", + "\n", + " if verbose:\n", + " print(\"There are {:} experiments.\".format(len(experiments)))\n", + " qresults = []\n", + " for idx, (key, experiment) in enumerate(experiments.items()):\n", + " if experiment.id == \"0\":\n", + " continue\n", + " if name_filter is not None and re.fullmatch(name_filter, experiment.name) is None:\n", + " continue\n", + " recorders = experiment.list_recorders()\n", + " recorders, not_finished = filter_finished(recorders)\n", + " if verbose:\n", + " print(\n", + " \"====>>>> {:02d}/{:02d}-th experiment {:9s} has {:02d}/{:02d} finished recorders.\".format(\n", + " idx + 1,\n", + " len(experiments),\n", + " experiment.name,\n", + " len(recorders),\n", + " len(recorders) + not_finished,\n", + " )\n", + " )\n", + " result = QResult(experiment.name)\n", + " for recorder_id, recorder in recorders.items():\n", + " result.update(recorder.list_metrics(), key_map)\n", + " result.append_path(\n", + " os.path.join(recorder.uri, recorder.experiment_id, recorder.id)\n", + " )\n", + " if not len(result):\n", + " print(\"There are no valid recorders for {:}\".format(experiment))\n", + " continue\n", + " else:\n", + " if verbose:\n", + " print(\n", + " \"There are {:} valid recorders for {:}\".format(\n", + " len(recorders), experiment.name\n", + " )\n", + " )\n", + " qresults.append(result)\n", + " return qresults" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "filled-multiple", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[64660:MainThread](2021-04-11 23:57:38,469) INFO - qlib.workflow - [expm.py:290] - \n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[PosixPath('/Users/xuanyidong/Desktop/AutoDL-Projects/outputs/qlib-baselines-csi300')]\n" + ] + } + ], + "source": [ + "paths = [root_dir / 'outputs' / 'qlib-baselines-csi300']\n", + "paths = [path.resolve() for path in paths]\n", + "print(paths)\n", + "\n", + "key_map = dict()\n", + "for xset in (\"train\", \"valid\", \"test\"):\n", + " key_map[\"{:}-mean-IC\".format(xset)] = \"IC ({:})\".format(xset)\n", + " key_map[\"{:}-mean-ICIR\".format(xset)] = \"ICIR ({:})\".format(xset)\n", + "\n", + "qresults = query_info(paths, False, 'TSF-.*', key_map)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "intimate-approval", + "metadata": {}, + "outputs": [], + "source": [ + "import matplotlib\n", + "from matplotlib import cm\n", + "matplotlib.use(\"agg\")\n", + "import matplotlib.pyplot as plt\n", + "import matplotlib.ticker as ticker" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "id": "supreme-basis", + "metadata": {}, + "outputs": [], + "source": [ + "def vis_dropouts(qresults, basenames, name2suffix, save_path):\n", + " save_dir = (save_path / '..').resolve()\n", + " save_dir.mkdir(parents=True, exist_ok=True)\n", + " print('There are {:} qlib-results'.format(len(qresults)))\n", + " \n", + " name2qresult = dict()\n", + " for qresult in qresults:\n", + " name2qresult[qresult.name] = qresult\n", + " # sort architectures\n", + " accuracies = []\n", + " for basename in basenames:\n", + " qresult = name2qresult[basename + '-drop0_0']\n", + " accuracies.append(qresult['ICIR (train)'])\n", + " sorted_basenames = sorted(basenames, key=lambda x: accuracies[basenames.index(x)])\n", + " \n", + " dpi, width, height = 200, 4000, 2000\n", + " figsize = width / float(dpi), height / float(dpi)\n", + " LabelSize, LegendFontsize = 22, 18\n", + " font_gap = 5\n", + " colors = ['k', 'r']\n", + " markers = ['*', 'o']\n", + " \n", + " fig = plt.figure(figsize=figsize)\n", + " \n", + " def plot_ax(cur_ax, train_or_test):\n", + " for idx, (legend, suffix) in enumerate(name2suffix.items()):\n", + " x_values = list(range(len(sorted_basenames)))\n", + " y_values = []\n", + " for i, name in enumerate(sorted_basenames):\n", + " name = '{:}{:}'.format(name, suffix)\n", + " qresult = name2qresult[name]\n", + " if train_or_test:\n", + " value = qresult['IC (train)']\n", + " else:\n", + " value = qresult['IC (valid)']\n", + " y_values.append(value)\n", + " cur_ax.plot(x_values, y_values, c=colors[idx])\n", + " cur_ax.scatter(x_values, y_values,\n", + " marker=markers[idx], s=3, c=colors[idx], alpha=0.9,\n", + " label=legend)\n", + " cur_ax.set_xlabel(\"sorted architectures\", fontsize=LabelSize)\n", + " cur_ax.set_ylabel(\"{:} IC (%)\".format('training' if train_or_test else 'validation'), fontsize=LabelSize)\n", + " for tick in cur_ax.xaxis.get_major_ticks():\n", + " tick.label.set_fontsize(LabelSize - font_gap)\n", + " for tick in cur_ax.yaxis.get_major_ticks():\n", + " tick.label.set_fontsize(LabelSize - font_gap)\n", + " cur_ax.legend(loc=4, fontsize=LegendFontsize)\n", + " ax = fig.add_subplot(1, 2, 1)\n", + " plot_ax(ax, True)\n", + " ax = fig.add_subplot(1, 2, 2)\n", + " plot_ax(ax, False)\n", + " # fig.tight_layout()\n", + " # plt.subplots_adjust(wspace=0.05)#, hspace=0.4)\n", + " fig.savefig(save_path, dpi=dpi, bbox_inches=\"tight\", format=\"pdf\")\n", + " plt.close(\"all\")" + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "id": "shared-envelope", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{'TSF-8x6', 'TSF-6x6', 'TSF-4x24', 'TSF-3x32', 'TSF-7x6', 'TSF-4x12', 'TSF-2x12', 'TSF-1x24', 'TSF-1x32', 'TSF-6x32', 'TSF-7x48', 'TSF-4x6', 'TSF-5x32', 'TSF-6x24', 'TSF-8x24', 'TSF-5x6', 'TSF-3x24', 'TSF-6x12', 'TSF-3x12', 'TSF-5x64', 'TSF-5x12', 'TSF-7x32', 'TSF-6x48', 'TSF-3x64', 'TSF-5x48', 'TSF-7x24', 'TSF-4x32', 'TSF-4x64', 'TSF-2x64', 'TSF-8x12', 'TSF-7x64', 'TSF-3x6', 'TSF-1x6', 'TSF-8x64', 'TSF-2x6', 'TSF-6x64', 'TSF-7x12', 'TSF-2x24', 'TSF-8x48', 'TSF-1x64', 'TSF-4x48', 'TSF-8x32', 'TSF-2x48', 'TSF-1x12', 'TSF-5x24', 'TSF-3x48', 'TSF-2x32', 'TSF-1x48'}\n", + "The Desktop is at: /Users/xuanyidong/Desktop\n", + "There are 104 qlib-results\n" + ] + } + ], + "source": [ + "# Visualization\n", + "names = [qresult.name for qresult in qresults]\n", + "base_names = set()\n", + "for name in names:\n", + " base_name = name.split('-drop')[0]\n", + " base_names.add(base_name)\n", + "print(base_names)\n", + "# filter\n", + "filtered_base_names = set()\n", + "for base_name in base_names:\n", + " if (base_name + '-drop0_0') in names and (base_name + '-drop0.1_0') in names:\n", + " filtered_base_names.add(base_name)\n", + " else:\n", + " print('Cannot find all names for {:}'.format(base_name))\n", + "# print(filtered_base_names)\n", + "home_dir = Path.home()\n", + "desktop_dir = home_dir / 'Desktop'\n", + "print('The Desktop is at: {:}'.format(desktop_dir))\n", + "\n", + "vis_dropouts(qresults, list(filtered_base_names),\n", + " {'No-dropout': '-drop0_0',\n", + " 'Ratio=0.1' : '-drop0.1_0'},\n", + " desktop_dir / 'es_csi300_drop.pdf')" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.8" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}