209 lines
6.3 KiB
Plaintext
209 lines
6.3 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 1,
|
|
"id": "afraid-minutes",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"The root path: /Users/xuanyidong/Desktop/AutoDL-Projects\n",
|
|
"The library path: /Users/xuanyidong/Desktop/AutoDL-Projects/lib\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"import os\n",
|
|
"import re\n",
|
|
"import sys\n",
|
|
"import torch\n",
|
|
"import pprint\n",
|
|
"import numpy as np\n",
|
|
"import pandas as pd\n",
|
|
"from pathlib import Path\n",
|
|
"from scipy.interpolate import make_interp_spline\n",
|
|
"\n",
|
"# A notebook kernel has no real `__file__`, so the working directory stands in for\n",
"# the notebook location (assumes the kernel was started next to this notebook).\n",
"notebook_dir = Path.cwd().resolve()\n",
"# root_dir is two levels above the working directory.\n",
"root_dir = notebook_dir.parent.parent\n",
"lib_dir = (root_dir / \"lib\").resolve()\n",
"print(\"The root path: {:}\".format(root_dir))\n",
"print(\"The library path: {:}\".format(lib_dir))\n",
"assert lib_dir.exists(), \"{:} does not exist\".format(lib_dir)\n",
"# Put the library directory first on sys.path so the project import below resolves.\n",
"if str(lib_dir) not in sys.path:\n",
"    sys.path.insert(0, str(lib_dir))\n",
"from utils.qlib_utils import QResult"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 2,
|
|
"id": "continental-drain",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"TSF-2x24-drop0_0s2013-01-01\n",
|
|
"TSF-2x24-drop0_0s2012-01-01\n",
|
|
"TSF-2x24-drop0_0s2008-01-01\n",
|
|
"TSF-2x24-drop0_0s2009-01-01\n",
|
|
"TSF-2x24-drop0_0s2010-01-01\n",
|
|
"TSF-2x24-drop0_0s2011-01-01\n",
|
|
"TSF-2x24-drop0_0s2008-07-01\n",
|
|
"TSF-2x24-drop0_0s2009-07-01\n",
|
|
"There are 3011 dates\n",
|
|
"Dates: 2008-01-02 2008-01-03\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
"# NOTE(review): torch.load unpickles arbitrary Python objects -- only open files you trust.\n",
"qresults = torch.load(os.path.join(root_dir, 'notebooks', 'TOT', 'temp-time-x.pth'))\n",
"for qresult in qresults:\n",
"    print(qresult.name)\n",
"\n",
"# Collect every date reported by any result, then keep them in sorted order.\n",
"date_set = set()\n",
"for qresult in qresults:\n",
"    date_set.update(qresult.find_all_dates())\n",
"all_dates = sorted(date_set)\n",
"print('There are {:} dates'.format(len(all_dates)))\n",
"# Shows the first two dates only; this line needs at least two dates to exist.\n",
"print('Dates: {:} {:}'.format(all_dates[0], all_dates[1]))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 3,
|
|
"id": "intimate-approval",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"import matplotlib\n",
|
|
"from matplotlib import cm\n",
|
|
"matplotlib.use(\"agg\")\n",
|
|
"import matplotlib.pyplot as plt\n",
|
|
"import matplotlib.ticker as ticker"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 6,
|
|
"id": "supreme-basis",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
"def vis_time_curve(qresults, dates, use_original, save_path):\n",
"    \"\"\"Plot the per-date IC of each result as one curve and save the figure as a PDF.\n",
"\n",
"    Args:\n",
"        qresults: sequence of result objects exposing `name`, `_date2ICs` and\n",
"            `get_IC_by_date` (presumably `QResult` instances -- TODO confirm).\n",
"        dates: ordered sequence of dates; the position of a date is its x value.\n",
"        use_original: if True plot the raw points, otherwise plot an interpolating\n",
"            spline evaluated at 200 evenly spaced positions.\n",
"        save_path: `pathlib.Path` of the output PDF; its directory is created if missing.\n",
"    \"\"\"\n",
"    save_dir = (save_path / '..').resolve()\n",
"    save_dir.mkdir(parents=True, exist_ok=True)\n",
"    print('There are {:} qlib-results'.format(len(qresults)))\n",
"\n",
"    dpi, width, height = 200, 5000, 2000\n",
"    figsize = width / float(dpi), height / float(dpi)\n",
"    LabelSize = 22\n",
"    font_gap = 5\n",
"    linestyles = ['-', '--']\n",
"    colors = ['k', 'r']\n",
"\n",
"    fig = plt.figure(figsize=figsize)\n",
"    cur_ax = fig.add_subplot(1, 1, 1)\n",
"    for idx, qresult in enumerate(qresults):\n",
"        print('Visualize [{:}] -- {:}'.format(idx, qresult.name))\n",
"        # Cycle through the styles so a third result does not raise IndexError.\n",
"        color = colors[idx % len(colors)]\n",
"        linestyle = linestyles[idx % len(linestyles)]\n",
"        x_axis, y_axis = [], []\n",
"        for idate, date in enumerate(dates):\n",
"            # Only dates present in the last entry of `_date2ICs` are considered.\n",
"            if date in qresult._date2ICs[-1]:\n",
"                # The second argument presumably scales the IC to percent (the axis\n",
"                # label is 'IC (%)') -- TODO confirm against get_IC_by_date.\n",
"                mean, _ = qresult.get_IC_by_date(date, 100)\n",
"                if not np.isnan(mean):\n",
"                    x_axis.append(idate)\n",
"                    y_axis.append(mean)\n",
"        if not x_axis:\n",
"            # No valid point for this result; min()/max() would fail on an empty array.\n",
"            continue\n",
"        x_axis, y_axis = np.array(x_axis), np.array(y_axis)\n",
"        # A degree-k interpolating spline needs at least k + 1 points, and the default\n",
"        # boundary handling expects an odd degree, so short series use 3 or 1 instead of 5.\n",
"        degree = min(5, len(x_axis) - 1)\n",
"        if degree % 2 == 0:\n",
"            degree -= 1\n",
"        if use_original or degree < 1:\n",
"            cur_ax.plot(x_axis, y_axis, linewidth=1, color=color, linestyle=linestyle)\n",
"        else:\n",
"            xnew = np.linspace(x_axis.min(), x_axis.max(), 200)\n",
"            spl = make_interp_spline(x_axis, y_axis, k=degree)\n",
"            ynew = spl(xnew)\n",
"            cur_ax.plot(xnew, ynew, linewidth=2, color=color, linestyle=linestyle)\n",
"\n",
"    # tick_params replaces the per-tick `tick.label` access, which is a deprecated\n",
"    # Tick attribute that newer matplotlib releases no longer provide.\n",
"    cur_ax.tick_params(axis='both', which='major', labelsize=LabelSize - font_gap)\n",
"    cur_ax.set_ylabel(\"IC (%)\", fontsize=LabelSize)\n",
"    fig.savefig(save_path, dpi=dpi, bbox_inches=\"tight\", format=\"pdf\")\n",
"    # Close only the figure created here so unrelated open figures are left alone.\n",
"    plt.close(fig)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 7,
|
|
"id": "shared-envelope",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"The Desktop is at: /Users/xuanyidong/Desktop\n",
|
|
"There are 2 qlib-results\n",
|
|
"Visualize [0] -- TSF-2x24-drop0_0s2008-01-01\n",
|
|
"Visualize [1] -- TSF-2x24-drop0_0s2009-07-01\n",
|
|
"There are 2 qlib-results\n",
|
|
"Visualize [0] -- TSF-2x24-drop0_0s2008-01-01\n",
|
|
"Visualize [1] -- TSF-2x24-drop0_0s2009-07-01\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
"# Visualization\n",
"home_dir = Path.home()\n",
"desktop_dir = home_dir / 'Desktop'\n",
"print('The Desktop is at: {:}'.format(desktop_dir))\n",
"\n",
"# The same pair of results is drawn twice: raw points first, then the interpolated curve.\n",
"for use_original, file_name in (\n",
"    (True, 'es_csi300_time_curve.pdf'),\n",
"    (False, 'es_csi300_time_curve-inter.pdf'),\n",
"):\n",
"    vis_time_curve(\n",
"        (qresults[2], qresults[-1]),\n",
"        all_dates,\n",
"        use_original,\n",
"        desktop_dir / file_name,\n",
"    )"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "exempt-stable",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": []
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "Python 3",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.8.8"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 5
|
|
}
|