{ "cells": [ { "cell_type": "code", "execution_count": 1, "id": "afraid-minutes", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "The root path: /Users/xuanyidong/Desktop/AutoDL-Projects\n", "The library path: /Users/xuanyidong/Desktop/AutoDL-Projects/lib\n" ] } ], "source": [ "import os\n", "import re\n", "import sys\n", "import torch\n", "import pprint\n", "import numpy as np\n", "import pandas as pd\n", "from pathlib import Path\n", "from scipy.interpolate import make_interp_spline\n", "\n", "__file__ = os.path.dirname(os.path.realpath(\"__file__\"))\n", "root_dir = (Path(__file__).parent / \"..\").resolve()\n", "lib_dir = (root_dir / \"lib\").resolve()\n", "print(\"The root path: {:}\".format(root_dir))\n", "print(\"The library path: {:}\".format(lib_dir))\n", "assert lib_dir.exists(), \"{:} does not exist\".format(lib_dir)\n", "if str(lib_dir) not in sys.path:\n", " sys.path.insert(0, str(lib_dir))\n", "from utils.qlib_utils import QResult" ] }, { "cell_type": "code", "execution_count": 2, "id": "continental-drain", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "TSF-2x24-drop0_0s2013-01-01\n", "TSF-2x24-drop0_0s2012-01-01\n", "TSF-2x24-drop0_0s2008-01-01\n", "TSF-2x24-drop0_0s2009-01-01\n", "TSF-2x24-drop0_0s2010-01-01\n", "TSF-2x24-drop0_0s2011-01-01\n", "TSF-2x24-drop0_0s2008-07-01\n", "TSF-2x24-drop0_0s2009-07-01\n", "There are 3011 dates\n", "Dates: 2008-01-02 2008-01-03\n" ] } ], "source": [ "qresults = torch.load(os.path.join(root_dir, 'notebooks', 'TOT', 'temp-time-x.pth'))\n", "for qresult in qresults:\n", " print(qresult.name)\n", "all_dates = set()\n", "for qresult in qresults:\n", " dates = qresult.find_all_dates()\n", " for date in dates:\n", " all_dates.add(date)\n", "all_dates = sorted(list(all_dates))\n", "print('There are {:} dates'.format(len(all_dates)))\n", "print('Dates: {:} {:}'.format(all_dates[0], all_dates[1]))" ] }, { "cell_type": "code", "execution_count": 3, "id": "intimate-approval", "metadata": {}, "outputs": [], "source": [ "import matplotlib\n", "from matplotlib import cm\n", "matplotlib.use(\"agg\")\n", "import matplotlib.pyplot as plt\n", "import matplotlib.ticker as ticker" ] }, { "cell_type": "code", "execution_count": 6, "id": "supreme-basis", "metadata": {}, "outputs": [], "source": [ "def vis_time_curve(qresults, dates, use_original, save_path):\n", " save_dir = (save_path / '..').resolve()\n", " save_dir.mkdir(parents=True, exist_ok=True)\n", " print('There are {:} qlib-results'.format(len(qresults)))\n", " \n", " dpi, width, height = 200, 5000, 2000\n", " figsize = width / float(dpi), height / float(dpi)\n", " LabelSize, LegendFontsize = 22, 12\n", " font_gap = 5\n", " linestyles = ['-', '--']\n", " colors = ['k', 'r']\n", " \n", " fig = plt.figure(figsize=figsize)\n", " cur_ax = fig.add_subplot(1, 1, 1)\n", " for idx, qresult in enumerate(qresults):\n", " print('Visualize [{:}] -- {:}'.format(idx, qresult.name))\n", " x_axis, y_axis = [], []\n", " for idate, date in enumerate(dates):\n", " if date in qresult._date2ICs[-1]:\n", " mean, std = qresult.get_IC_by_date(date, 100)\n", " if not np.isnan(mean):\n", " x_axis.append(idate)\n", " y_axis.append(mean)\n", " x_axis, y_axis = np.array(x_axis), np.array(y_axis)\n", " if use_original:\n", " cur_ax.plot(x_axis, y_axis, linewidth=1, color=colors[idx], linestyle=linestyles[idx])\n", " else:\n", " xnew = np.linspace(x_axis.min(), x_axis.max(), 200)\n", " spl = make_interp_spline(x_axis, y_axis, k=5)\n", " ynew = spl(xnew)\n", " cur_ax.plot(xnew, ynew, linewidth=2, color=colors[idx], linestyle=linestyles[idx])\n", " \n", " for tick in cur_ax.xaxis.get_major_ticks():\n", " tick.label.set_fontsize(LabelSize - font_gap)\n", " for tick in cur_ax.yaxis.get_major_ticks():\n", " tick.label.set_fontsize(LabelSize - font_gap)\n", " cur_ax.set_ylabel(\"IC (%)\", fontsize=LabelSize)\n", " fig.savefig(save_path, dpi=dpi, bbox_inches=\"tight\", format=\"pdf\")\n", " plt.close(\"all\")" ] }, { "cell_type": "code", "execution_count": 7, "id": "shared-envelope", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "The Desktop is at: /Users/xuanyidong/Desktop\n", "There are 2 qlib-results\n", "Visualize [0] -- TSF-2x24-drop0_0s2008-01-01\n", "Visualize [1] -- TSF-2x24-drop0_0s2009-07-01\n", "There are 2 qlib-results\n", "Visualize [0] -- TSF-2x24-drop0_0s2008-01-01\n", "Visualize [1] -- TSF-2x24-drop0_0s2009-07-01\n" ] } ], "source": [ "# Visualization\n", "home_dir = Path.home()\n", "desktop_dir = home_dir / 'Desktop'\n", "print('The Desktop is at: {:}'.format(desktop_dir))\n", "\n", "vis_time_curve(\n", " (qresults[2], qresults[-1]),\n", " all_dates,\n", " True,\n", " desktop_dir / 'es_csi300_time_curve.pdf')\n", "\n", "vis_time_curve(\n", " (qresults[2], qresults[-1]),\n", " all_dates,\n", " False,\n", " desktop_dir / 'es_csi300_time_curve-inter.pdf')" ] }, { "cell_type": "code", "execution_count": null, "id": "exempt-stable", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.8" } }, "nbformat": 4, "nbformat_minor": 5 }