diff --git a/.gitignore b/.gitignore index be684fb..20a318b 100644 --- a/.gitignore +++ b/.gitignore @@ -133,3 +133,4 @@ outputs pytest_cache *.pkl +*.pth diff --git a/exps/trading/baselines.py b/exps/trading/baselines.py index 2c9343c..ecbe246 100644 --- a/exps/trading/baselines.py +++ b/exps/trading/baselines.py @@ -64,7 +64,7 @@ def extend_transformer_settings(alg2configs, name): config = copy.deepcopy(alg2configs[name]) for i in range(1, 9): for j in (6, 12, 24, 32, 48, 64): - for k1 in (0, 0.1, 0.2, 0.3): + for k1 in (0, 0.05, 0.1, 0.2, 0.3): for k2 in (0, 0.1): alg2configs[ name + "-{:}x{:}-drop{:}_{:}".format(i, j, k1, k2) diff --git a/exps/trading/organize_results.py b/exps/trading/organize_results.py index a291e00..1f88d30 100644 --- a/exps/trading/organize_results.py +++ b/exps/trading/organize_results.py @@ -22,6 +22,7 @@ from qlib.workflow import R from utils.qlib_utils import QResult + def compare_results( heads, values, names, space=10, separate="& ", verbose=True, sort_key=False ): @@ -69,7 +70,10 @@ def query_info(save_dir, verbose, name_filter, key_map): for idx, (key, experiment) in enumerate(experiments.items()): if experiment.id == "0": continue - if name_filter is not None and re.fullmatch(name_filter, experiment.name) is None: + if ( + name_filter is not None + and re.fullmatch(name_filter, experiment.name) is None + ): continue recorders = experiment.list_recorders() recorders, not_finished = filter_finished(recorders) diff --git a/lib/utils/qlib_utils.py b/lib/utils/qlib_utils.py index f110141..35ab154 100644 --- a/lib/utils/qlib_utils.py +++ b/lib/utils/qlib_utils.py @@ -1,3 +1,4 @@ +import os import numpy as np from typing import List, Text from collections import defaultdict, OrderedDict @@ -10,6 +11,7 @@ class QResult: self._result = defaultdict(list) self._name = name self._recorder_paths = [] + self._date2ICs = [] def append(self, key, value): self._result[key].append(value) @@ -17,6 +19,25 @@ class QResult: def append_path(self, xpath): self._recorder_paths.append(xpath) + def append_date2ICs(self, date2IC): + if self._date2ICs: # not empty + keys = sorted(list(date2IC.keys())) + pre_keys = sorted(list(self._date2ICs[0].keys())) + assert len(keys) == len(pre_keys) + for i, (x, y) in enumerate(zip(keys, pre_keys)): + assert x == y, "[{:}] {:} vs {:}".format(i, x, y) + self._date2ICs.append(date2IC) + + def find_all_dates(self): + dates = self._date2ICs[-1].keys() + return sorted(list(dates)) + + def get_IC_by_date(self, date, scale=1.0): + values = [] + for date2IC in self._date2ICs: + values.append(date2IC[date] * scale) + return float(np.mean(values)), float(np.std(values)) + @property def name(self): return self._name diff --git a/notebooks/TOT/ES-Model-DC.ipynb b/notebooks/TOT/ES-Model-DC.ipynb index 20e9d38..2a5a5bd 100644 --- a/notebooks/TOT/ES-Model-DC.ipynb +++ b/notebooks/TOT/ES-Model-DC.ipynb @@ -18,10 +18,10 @@ "name": "stderr", "output_type": "stream", "text": [ - "[68147:MainThread](2021-04-12 13:09:24,409) INFO - qlib.Initialization - [config.py:276] - default_conf: client.\n", - "[68147:MainThread](2021-04-12 13:09:24,411) WARNING - qlib.Initialization - [config.py:291] - redis connection failed(host=127.0.0.1 port=6379), cache will not be used!\n", - "[68147:MainThread](2021-04-12 13:09:24,414) INFO - qlib.Initialization - [__init__.py:46] - qlib successfully initialized based on client settings.\n", - "[68147:MainThread](2021-04-12 13:09:24,417) INFO - qlib.Initialization - [__init__.py:47] - data_path=/Users/xuanyidong/.qlib/qlib_data/cn_data\n" + "[70148:MainThread](2021-04-12 13:23:30,262) INFO - qlib.Initialization - [config.py:276] - default_conf: client.\n", + "[70148:MainThread](2021-04-12 13:23:30,266) WARNING - qlib.Initialization - [config.py:291] - redis connection failed(host=127.0.0.1 port=6379), cache will not be used!\n", + "[70148:MainThread](2021-04-12 13:23:30,269) INFO - qlib.Initialization - [__init__.py:46] - qlib successfully initialized based on client settings.\n", + "[70148:MainThread](2021-04-12 13:23:30,271) INFO - qlib.Initialization - [__init__.py:47] - data_path=/Users/xuanyidong/.qlib/qlib_data/cn_data\n" ] } ], @@ -142,7 +142,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "[68147:MainThread](2021-04-12 13:09:25,066) INFO - qlib.workflow - [expm.py:290] - \n" + "[70148:MainThread](2021-04-12 13:23:31,137) INFO - qlib.workflow - [expm.py:290] - \n" ] }, { @@ -233,6 +233,7 @@ " cmap=cm.Spectral, linewidth=0.2, antialiased=True)\n", " cur_ax.set_xticks(raw_depths)\n", " cur_ax.set_yticks(raw_channels)\n", + " cur_ax.set_zticks(np.arange(4, 11, 2))\n", " cur_ax.set_xlabel(\"#depth\", fontsize=LabelSize)\n", " cur_ax.set_ylabel(\"#channels\", fontsize=LabelSize)\n", " cur_ax.set_zlabel(\"{:} IC (%)\".format('training' if train_or_test else 'validation'), fontsize=LabelSize)\n", diff --git a/notebooks/TOT/ES-Model-Drop.ipynb b/notebooks/TOT/ES-Model-Drop.ipynb index 66d3f60..0544b0d 100644 --- a/notebooks/TOT/ES-Model-Drop.ipynb +++ b/notebooks/TOT/ES-Model-Drop.ipynb @@ -18,10 +18,10 @@ "name": "stderr", "output_type": "stream", "text": [ - "[64660:MainThread](2021-04-11 23:57:38,079) INFO - qlib.Initialization - [config.py:276] - default_conf: client.\n", - "[64660:MainThread](2021-04-11 23:57:38,081) WARNING - qlib.Initialization - [config.py:291] - redis connection failed(host=127.0.0.1 port=6379), cache will not be used!\n", - "[64660:MainThread](2021-04-11 23:57:38,083) INFO - qlib.Initialization - [__init__.py:46] - qlib successfully initialized based on client settings.\n", - "[64660:MainThread](2021-04-11 23:57:38,084) INFO - qlib.Initialization - [__init__.py:47] - data_path=/Users/xuanyidong/.qlib/qlib_data/cn_data\n" + "[70363:MainThread](2021-04-12 13:25:01,065) INFO - qlib.Initialization - [config.py:276] - default_conf: client.\n", + "[70363:MainThread](2021-04-12 13:25:01,069) WARNING - qlib.Initialization - [config.py:291] - redis connection failed(host=127.0.0.1 port=6379), cache will not be used!\n", + "[70363:MainThread](2021-04-12 13:25:01,085) INFO - qlib.Initialization - [__init__.py:46] - qlib successfully initialized based on client settings.\n", + "[70363:MainThread](2021-04-12 13:25:01,092) INFO - qlib.Initialization - [__init__.py:47] - data_path=/Users/xuanyidong/.qlib/qlib_data/cn_data\n" ] } ], @@ -142,7 +142,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "[64660:MainThread](2021-04-11 23:57:38,469) INFO - qlib.workflow - [expm.py:290] - \n" + "[70363:MainThread](2021-04-12 13:25:01,647) INFO - qlib.workflow - [expm.py:290] - \n" ] }, { @@ -182,7 +182,7 @@ }, { "cell_type": "code", - "execution_count": 27, + "execution_count": 8, "id": "supreme-basis", "metadata": {}, "outputs": [], @@ -204,7 +204,7 @@ " \n", " dpi, width, height = 200, 4000, 2000\n", " figsize = width / float(dpi), height / float(dpi)\n", - " LabelSize, LegendFontsize = 22, 18\n", + " LabelSize, LegendFontsize = 22, 22\n", " font_gap = 5\n", " colors = ['k', 'r']\n", " markers = ['*', 'o']\n", @@ -227,6 +227,7 @@ " cur_ax.scatter(x_values, y_values,\n", " marker=markers[idx], s=3, c=colors[idx], alpha=0.9,\n", " label=legend)\n", + " cur_ax.set_yticks(np.arange(4, 11, 2))\n", " cur_ax.set_xlabel(\"sorted architectures\", fontsize=LabelSize)\n", " cur_ax.set_ylabel(\"{:} IC (%)\".format('training' if train_or_test else 'validation'), fontsize=LabelSize)\n", " for tick in cur_ax.xaxis.get_major_ticks():\n", @@ -246,7 +247,7 @@ }, { "cell_type": "code", - "execution_count": 28, + "execution_count": 9, "id": "shared-envelope", "metadata": {}, "outputs": [ @@ -254,7 +255,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "{'TSF-8x6', 'TSF-6x6', 'TSF-4x24', 'TSF-3x32', 'TSF-7x6', 'TSF-4x12', 'TSF-2x12', 'TSF-1x24', 'TSF-1x32', 'TSF-6x32', 'TSF-7x48', 'TSF-4x6', 'TSF-5x32', 'TSF-6x24', 'TSF-8x24', 'TSF-5x6', 'TSF-3x24', 'TSF-6x12', 'TSF-3x12', 'TSF-5x64', 'TSF-5x12', 'TSF-7x32', 'TSF-6x48', 'TSF-3x64', 'TSF-5x48', 'TSF-7x24', 'TSF-4x32', 'TSF-4x64', 'TSF-2x64', 'TSF-8x12', 'TSF-7x64', 'TSF-3x6', 'TSF-1x6', 'TSF-8x64', 'TSF-2x6', 'TSF-6x64', 'TSF-7x12', 'TSF-2x24', 'TSF-8x48', 'TSF-1x64', 'TSF-4x48', 'TSF-8x32', 'TSF-2x48', 'TSF-1x12', 'TSF-5x24', 'TSF-3x48', 'TSF-2x32', 'TSF-1x48'}\n", + "{'TSF-3x48', 'TSF-2x64', 'TSF-2x12', 'TSF-8x48', 'TSF-6x32', 'TSF-4x48', 'TSF-8x6', 'TSF-4x6', 'TSF-2x32', 'TSF-5x12', 'TSF-5x64', 'TSF-1x64', 'TSF-2x24', 'TSF-8x24', 'TSF-4x12', 'TSF-6x12', 'TSF-1x32', 'TSF-5x32', 'TSF-3x24', 'TSF-8x12', 'TSF-5x48', 'TSF-6x64', 'TSF-7x64', 'TSF-7x48', 'TSF-1x6', 'TSF-2x48', 'TSF-7x24', 'TSF-3x32', 'TSF-1x24', 'TSF-4x64', 'TSF-3x12', 'TSF-8x64', 'TSF-4x32', 'TSF-5x6', 'TSF-7x6', 'TSF-7x12', 'TSF-3x6', 'TSF-4x24', 'TSF-6x48', 'TSF-6x6', 'TSF-1x48', 'TSF-1x12', 'TSF-7x32', 'TSF-5x24', 'TSF-2x6', 'TSF-6x24', 'TSF-3x64', 'TSF-8x32'}\n", "The Desktop is at: /Users/xuanyidong/Desktop\n", "There are 104 qlib-results\n" ] diff --git a/notebooks/TOT/Time-Curve.ipynb b/notebooks/TOT/Time-Curve.ipynb new file mode 100644 index 0000000..fa8911f --- /dev/null +++ b/notebooks/TOT/Time-Curve.ipynb @@ -0,0 +1,208 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "afraid-minutes", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "The root path: /Users/xuanyidong/Desktop/AutoDL-Projects\n", + "The library path: /Users/xuanyidong/Desktop/AutoDL-Projects/lib\n" + ] + } + ], + "source": [ + "import os\n", + "import re\n", + "import sys\n", + "import torch\n", + "import pprint\n", + "import numpy as np\n", + "import pandas as pd\n", + "from pathlib import Path\n", + "from scipy.interpolate import make_interp_spline\n", + "\n", + "__file__ = os.path.dirname(os.path.realpath(\"__file__\"))\n", + "root_dir = (Path(__file__).parent / \"..\").resolve()\n", + "lib_dir = (root_dir / \"lib\").resolve()\n", + "print(\"The root path: {:}\".format(root_dir))\n", + "print(\"The library path: {:}\".format(lib_dir))\n", + "assert lib_dir.exists(), \"{:} does not exist\".format(lib_dir)\n", + "if str(lib_dir) not in sys.path:\n", + " sys.path.insert(0, str(lib_dir))\n", + "from utils.qlib_utils import QResult" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "continental-drain", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "TSF-2x24-drop0_0s2013-01-01\n", + "TSF-2x24-drop0_0s2012-01-01\n", + "TSF-2x24-drop0_0s2008-01-01\n", + "TSF-2x24-drop0_0s2009-01-01\n", + "TSF-2x24-drop0_0s2010-01-01\n", + "TSF-2x24-drop0_0s2011-01-01\n", + "TSF-2x24-drop0_0s2008-07-01\n", + "TSF-2x24-drop0_0s2009-07-01\n", + "There are 3011 dates\n", + "Dates: 2008-01-02 2008-01-03\n" + ] + } + ], + "source": [ + "qresults = torch.load(os.path.join(root_dir, 'notebooks', 'TOT', 'temp-time-x.pth'))\n", + "for qresult in qresults:\n", + " print(qresult.name)\n", + "all_dates = set()\n", + "for qresult in qresults:\n", + " dates = qresult.find_all_dates()\n", + " for date in dates:\n", + " all_dates.add(date)\n", + "all_dates = sorted(list(all_dates))\n", + "print('There are {:} dates'.format(len(all_dates)))\n", + "print('Dates: {:} {:}'.format(all_dates[0], all_dates[1]))" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "intimate-approval", + "metadata": {}, + "outputs": [], + "source": [ + "import matplotlib\n", + "from matplotlib import cm\n", + "matplotlib.use(\"agg\")\n", + "import matplotlib.pyplot as plt\n", + "import matplotlib.ticker as ticker" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "supreme-basis", + "metadata": {}, + "outputs": [], + "source": [ + "def vis_time_curve(qresults, dates, use_original, save_path):\n", + " save_dir = (save_path / '..').resolve()\n", + " save_dir.mkdir(parents=True, exist_ok=True)\n", + " print('There are {:} qlib-results'.format(len(qresults)))\n", + " \n", + " dpi, width, height = 200, 5000, 2000\n", + " figsize = width / float(dpi), height / float(dpi)\n", + " LabelSize, LegendFontsize = 22, 12\n", + " font_gap = 5\n", + " linestyles = ['-', '--']\n", + " colors = ['k', 'r']\n", + " \n", + " fig = plt.figure(figsize=figsize)\n", + " cur_ax = fig.add_subplot(1, 1, 1)\n", + " for idx, qresult in enumerate(qresults):\n", + " print('Visualize [{:}] -- {:}'.format(idx, qresult.name))\n", + " x_axis, y_axis = [], []\n", + " for idate, date in enumerate(dates):\n", + " if date in qresult._date2ICs[-1]:\n", + " mean, std = qresult.get_IC_by_date(date, 100)\n", + " if not np.isnan(mean):\n", + " x_axis.append(idate)\n", + " y_axis.append(mean)\n", + " x_axis, y_axis = np.array(x_axis), np.array(y_axis)\n", + " if use_original:\n", + " cur_ax.plot(x_axis, y_axis, linewidth=1, color=colors[idx], linestyle=linestyles[idx])\n", + " else:\n", + " xnew = np.linspace(x_axis.min(), x_axis.max(), 200)\n", + " spl = make_interp_spline(x_axis, y_axis, k=5)\n", + " ynew = spl(xnew)\n", + " cur_ax.plot(xnew, ynew, linewidth=2, color=colors[idx], linestyle=linestyles[idx])\n", + " \n", + " for tick in cur_ax.xaxis.get_major_ticks():\n", + " tick.label.set_fontsize(LabelSize - font_gap)\n", + " for tick in cur_ax.yaxis.get_major_ticks():\n", + " tick.label.set_fontsize(LabelSize - font_gap)\n", + " cur_ax.set_ylabel(\"IC (%)\", fontsize=LabelSize)\n", + " fig.savefig(save_path, dpi=dpi, bbox_inches=\"tight\", format=\"pdf\")\n", + " plt.close(\"all\")" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "shared-envelope", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "The Desktop is at: /Users/xuanyidong/Desktop\n", + "There are 2 qlib-results\n", + "Visualize [0] -- TSF-2x24-drop0_0s2008-01-01\n", + "Visualize [1] -- TSF-2x24-drop0_0s2009-07-01\n", + "There are 2 qlib-results\n", + "Visualize [0] -- TSF-2x24-drop0_0s2008-01-01\n", + "Visualize [1] -- TSF-2x24-drop0_0s2009-07-01\n" + ] + } + ], + "source": [ + "# Visualization\n", + "home_dir = Path.home()\n", + "desktop_dir = home_dir / 'Desktop'\n", + "print('The Desktop is at: {:}'.format(desktop_dir))\n", + "\n", + "vis_time_curve(\n", + " (qresults[2], qresults[-1]),\n", + " all_dates,\n", + " True,\n", + " desktop_dir / 'es_csi300_time_curve.pdf')\n", + "\n", + "vis_time_curve(\n", + " (qresults[2], qresults[-1]),\n", + " all_dates,\n", + " False,\n", + " desktop_dir / 'es_csi300_time_curve-inter.pdf')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "exempt-stable", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.8" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/notebooks/TOT/synthetic.ipynb b/notebooks/TOT/synthetic.ipynb new file mode 100644 index 0000000..6173183 --- /dev/null +++ b/notebooks/TOT/synthetic.ipynb @@ -0,0 +1,128 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "filled-multiple", + "metadata": {}, + "outputs": [], + "source": [ + "#\n", + "# %matplotlib notebook\n", + "from pathlib import Path\n", + "import numpy as np\n", + "import matplotlib\n", + "from matplotlib import cm\n", + "matplotlib.use(\"agg\")\n", + "import matplotlib.pyplot as plt\n", + "import matplotlib.ticker as ticker" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "supreme-basis", + "metadata": {}, + "outputs": [], + "source": [ + "def visualize_syn(save_path):\n", + " save_dir = (save_path / '..').resolve()\n", + " save_dir.mkdir(parents=True, exist_ok=True)\n", + " \n", + " dpi, width, height = 50, 2000, 1000\n", + " figsize = width / float(dpi), height / float(dpi)\n", + " LabelSize, font_gap = 30, 4\n", + " \n", + " fig = plt.figure(figsize=figsize)\n", + " \n", + " times = np.arange(0, np.pi * 100, 0.1)\n", + " num = len(times)\n", + " x = []\n", + " for i in range(num):\n", + " scale = (i + 1.) / num * 4\n", + " value = times[i] * scale\n", + " x.append(np.sin(value) * (1.3 - scale))\n", + " x = np.array(x)\n", + " y = np.cos( x * x - 0.3 * x )\n", + " \n", + " cur_ax = fig.add_subplot(2, 1, 1)\n", + " cur_ax.plot(times, x)\n", + " cur_ax.set_xlabel(\"time\", fontsize=LabelSize)\n", + " cur_ax.set_ylabel(\"x\", fontsize=LabelSize)\n", + " for tick in cur_ax.xaxis.get_major_ticks():\n", + " tick.label.set_fontsize(LabelSize - font_gap)\n", + " tick.label.set_rotation(30)\n", + " for tick in cur_ax.yaxis.get_major_ticks():\n", + " tick.label.set_fontsize(LabelSize - font_gap)\n", + " \n", + " \n", + " cur_ax = fig.add_subplot(2, 1, 2)\n", + " cur_ax.plot(times, y)\n", + " cur_ax.set_xlabel(\"time\", fontsize=LabelSize)\n", + " cur_ax.set_ylabel(\"f(x)\", fontsize=LabelSize)\n", + " for tick in cur_ax.xaxis.get_major_ticks():\n", + " tick.label.set_fontsize(LabelSize - font_gap)\n", + " tick.label.set_rotation(30)\n", + " for tick in cur_ax.yaxis.get_major_ticks():\n", + " tick.label.set_fontsize(LabelSize - font_gap)\n", + " \n", + " # fig.tight_layout()\n", + " # plt.subplots_adjust(wspace=0.05)#, hspace=0.4)\n", + " fig.savefig(save_path, dpi=dpi, bbox_inches=\"tight\", format=\"pdf\")\n", + " plt.close(\"all\")\n", + " # plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "shared-envelope", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "The Desktop is at: /Users/xuanyidong/Desktop\n" + ] + } + ], + "source": [ + "# Visualization\n", + "home_dir = Path.home()\n", + "desktop_dir = home_dir / 'Desktop'\n", + "print('The Desktop is at: {:}'.format(desktop_dir))\n", + "visualize_syn(desktop_dir / 'tot-synthetic-v0.pdf')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "romantic-ordinance", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.8" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/notebooks/TOT/time-curve.py b/notebooks/TOT/time-curve.py new file mode 100644 index 0000000..9477ea9 --- /dev/null +++ b/notebooks/TOT/time-curve.py @@ -0,0 +1,123 @@ +import os +import re +import sys +import torch +import qlib +import pprint +from collections import OrderedDict +import numpy as np +import pandas as pd + +from pathlib import Path + +# __file__ = os.path.dirname(os.path.realpath("__file__")) +note_dir = Path(__file__).parent.resolve() +root_dir = (Path(__file__).parent / ".." / "..").resolve() +lib_dir = (root_dir / "lib").resolve() +print("The root path: {:}".format(root_dir)) +print("The library path: {:}".format(lib_dir)) +assert lib_dir.exists(), "{:} does not exist".format(lib_dir) +if str(lib_dir) not in sys.path: + sys.path.insert(0, str(lib_dir)) + +import qlib +from qlib import config as qconfig +from qlib.workflow import R +qlib.init(provider_uri='~/.qlib/qlib_data/cn_data', region=qconfig.REG_CN) + +from utils.qlib_utils import QResult + +def filter_finished(recorders): + returned_recorders = dict() + not_finished = 0 + for key, recorder in recorders.items(): + if recorder.status == "FINISHED": + returned_recorders[key] = recorder + else: + not_finished += 1 + return returned_recorders, not_finished + + +def add_to_dict(xdict, timestamp, value): + date = timestamp.date().strftime("%Y-%m-%d") + if date in xdict: + raise ValueError("This date [{:}] is already in the dict".format(date)) + xdict[date] = value + +def query_info(save_dir, verbose, name_filter, key_map): + if isinstance(save_dir, list): + results = [] + for x in save_dir: + x = query_info(x, verbose, name_filter, key_map) + results.extend(x) + return results + # Here, the save_dir must be a string + R.set_uri(str(save_dir)) + experiments = R.list_experiments() + + if verbose: + print("There are {:} experiments.".format(len(experiments))) + qresults = [] + for idx, (key, experiment) in enumerate(experiments.items()): + if experiment.id == "0": + continue + if name_filter is not None and re.fullmatch(name_filter, experiment.name) is None: + continue + recorders = experiment.list_recorders() + recorders, not_finished = filter_finished(recorders) + if verbose: + print( + "====>>>> {:02d}/{:02d}-th experiment {:9s} has {:02d}/{:02d} finished recorders.".format( + idx + 1, + len(experiments), + experiment.name, + len(recorders), + len(recorders) + not_finished, + ) + ) + result = QResult(experiment.name) + for recorder_id, recorder in recorders.items(): + file_names = ['results-train.pkl', 'results-valid.pkl', 'results-test.pkl'] + date2IC = OrderedDict() + for file_name in file_names: + xtemp = recorder.load_object(file_name)['all-IC'] + timestamps, values = xtemp.index.tolist(), xtemp.tolist() + for timestamp, value in zip(timestamps, values): + add_to_dict(date2IC, timestamp, value) + result.update(recorder.list_metrics(), key_map) + result.append_path( + os.path.join(recorder.uri, recorder.experiment_id, recorder.id) + ) + result.append_date2ICs(date2IC) + if not len(result): + print("There are no valid recorders for {:}".format(experiment)) + continue + else: + if verbose: + print( + "There are {:} valid recorders for {:}".format( + len(recorders), experiment.name + ) + ) + qresults.append(result) + return qresults + + +## +paths = [root_dir / 'outputs' / 'qlib-baselines-csi300'] +paths = [path.resolve() for path in paths] +print(paths) + +key_map = dict() +for xset in ("train", "valid", "test"): + key_map["{:}-mean-IC".format(xset)] = "IC ({:})".format(xset) + key_map["{:}-mean-ICIR".format(xset)] = "ICIR ({:})".format(xset) +qresults = query_info(paths, False, 'TSF-2x24-drop0_0s.*-.*-01', key_map) +print('Find {:} results'.format(len(qresults))) +times = [] +for qresult in qresults: + times.append(qresult.name.split('0_0s')[-1]) +print(times) +save_path = os.path.join(note_dir, 'temp-time-x.pth') +torch.save(qresults, save_path) +print(save_path)