{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import os, pickle, sys\n", "import matplotlib.pyplot as plt\n", "from scipy import stats\n", "import numpy as np\n", "import glob\n", "from prettytable import PrettyTable" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "ptcv_seed0\n", "+---------------------------------+-----------+-------+-------+--------+---------+-----------+---------+\n", "| exp | grad_norm | snip | grasp | fisher | synflow | jacob_cov | samples |\n", "+---------------------------------+-----------+-------+-------+--------+---------+-----------+---------+\n", "| pred_ptcv_svhn_pretrain.p | 0.707 | 0.576 | 0.382 | 0.168 | 0.747 | 0.34 | 49 |\n", "| pred_ptcv_cifar100.p | 0.385 | 0.509 | 0.105 | 0.469 | 0.428 | 0.145 | 54 |\n", "| pred_ptcv_svhn.p | 0.668 | 0.695 | 0.165 | 0.675 | 0.821 | 0.344 | 49 |\n", "| pred_ptcv_cifar100_pretrain.p | 0.763 | 0.813 | 0.832 | 0.595 | 0.424 | 0.595 | 54 |\n", "| pred_ptcv_cifar10.p | 0.409 | 0.521 | 0.127 | 0.471 | 0.456 | 0.046 | 56 |\n", "| pred_ptcv_cifar10_pretrain.p | 0.639 | 0.71 | 0.434 | 0.464 | 0.416 | 0.646 | 56 |\n", "| pred_ptcv_ImageNet1k.p | 0.563 | 0.644 | 0.025 | 0.675 | 0.652 | 0.343 | 191 |\n", "| pred_ptcv_ImageNet1k_pretrain.p | 0.692 | 0.67 | 0.493 | 0.725 | 0.691 | 0.141 | 191 |\n", "+---------------------------------+-----------+-------+-------+--------+---------+-----------+---------+\n", "ptcv_seed1\n", "+---------------------------------+-----------+-------+-------+--------+---------+-----------+---------+\n", "| exp | grad_norm | snip | grasp | fisher | synflow | jacob_cov | samples |\n", "+---------------------------------+-----------+-------+-------+--------+---------+-----------+---------+\n", "| pred_ptcv_svhn_pretrain.p | 0.681 | 0.541 | 0.449 | 0.148 | 0.747 | 0.324 | 49 |\n", "| pred_ptcv_cifar100.p | 0.384 | 0.501 | 0.051 | 0.483 | 0.429 | 
0.059 | 54 |\n", "| pred_ptcv_svhn.p | 0.642 | 0.666 | 0.077 | 0.633 | 0.83 | 0.224 | 49 |\n", "| pred_ptcv_cifar100_pretrain.p | 0.618 | 0.646 | 0.793 | 0.543 | 0.424 | 0.62 | 54 |\n", "| pred_ptcv_cifar10.p | 0.387 | 0.505 | 0.111 | 0.476 | 0.455 | 0.101 | 56 |\n", "| pred_ptcv_cifar10_pretrain.p | 0.689 | 0.733 | 0.376 | 0.476 | 0.416 | 0.646 | 56 |\n", "| pred_ptcv_ImageNet1k.p | 0.569 | 0.64 | 0.165 | 0.668 | 0.651 | 0.292 | 191 |\n", "| pred_ptcv_ImageNet1k_pretrain.p | 0.69 | 0.671 | 0.502 | 0.73 | 0.691 | 0.14 | 191 |\n", "+---------------------------------+-----------+-------+-------+--------+---------+-----------+---------+\n", "ptcv_seed2\n", "+---------------------------------+-----------+-------+-------+--------+---------+-----------+---------+\n", "| exp | grad_norm | snip | grasp | fisher | synflow | jacob_cov | samples |\n", "+---------------------------------+-----------+-------+-------+--------+---------+-----------+---------+\n", "| pred_ptcv_svhn_pretrain.p | 0.661 | 0.526 | 0.346 | 0.115 | 0.747 | 0.357 | 49 |\n", "| pred_ptcv_cifar100.p | 0.366 | 0.474 | 0.045 | 0.48 | 0.428 | 0.093 | 54 |\n", "| pred_ptcv_svhn.p | 0.668 | 0.685 | 0.244 | 0.661 | 0.823 | 0.28 | 49 |\n", "| pred_ptcv_cifar100_pretrain.p | 0.462 | 0.608 | 0.829 | 0.465 | 0.424 | 0.636 | 54 |\n", "| pred_ptcv_cifar10.p | 0.411 | 0.511 | 0.085 | 0.496 | 0.454 | 0.056 | 56 |\n", "| pred_ptcv_cifar10_pretrain.p | 0.715 | 0.761 | 0.478 | 0.515 | 0.416 | 0.641 | 56 |\n", "| pred_ptcv_ImageNet1k.p | 0.564 | 0.633 | 0.096 | 0.669 | 0.652 | 0.327 | 191 |\n", "| pred_ptcv_ImageNet1k_pretrain.p | 0.692 | 0.671 | 0.507 | 0.732 | 0.691 | 0.177 | 191 |\n", "+---------------------------------+-----------+-------+-------+--------+---------+-----------+---------+\n", "ptcv_seed3\n", "+---------------------------------+-----------+-------+-------+--------+---------+-----------+---------+\n", "| exp | grad_norm | snip | grasp | fisher | synflow | jacob_cov | samples |\n", 
"+---------------------------------+-----------+-------+-------+--------+---------+-----------+---------+\n", "| pred_ptcv_svhn_pretrain.p | 0.653 | 0.529 | 0.309 | 0.102 | 0.747 | 0.349 | 49 |\n", "| pred_ptcv_cifar100.p | 0.366 | 0.48 | 0.058 | 0.474 | 0.428 | 0.204 | 54 |\n", "| pred_ptcv_svhn.p | 0.661 | 0.678 | 0.128 | 0.661 | 0.833 | 0.256 | 49 |\n", "| pred_ptcv_cifar100_pretrain.p | 0.682 | 0.783 | 0.8 | 0.664 | 0.424 | 0.621 | 54 |\n", "| pred_ptcv_cifar10.p | 0.388 | 0.495 | 0.014 | 0.479 | 0.454 | 0.036 | 56 |\n", "| pred_ptcv_cifar10_pretrain.p | 0.63 | 0.718 | 0.222 | 0.478 | 0.416 | 0.666 | 56 |\n", "| pred_ptcv_ImageNet1k.p | 0.575 | 0.647 | 0.081 | 0.669 | 0.651 | 0.301 | 191 |\n", "| pred_ptcv_ImageNet1k_pretrain.p | 0.691 | 0.668 | 0.493 | 0.725 | 0.691 | 0.171 | 191 |\n", "+---------------------------------+-----------+-------+-------+--------+---------+-----------+---------+\n", "ptcv_seed4\n", "+---------------------------------+-----------+-------+-------+--------+---------+-----------+---------+\n", "| exp | grad_norm | snip | grasp | fisher | synflow | jacob_cov | samples |\n", "+---------------------------------+-----------+-------+-------+--------+---------+-----------+---------+\n", "| pred_ptcv_svhn_pretrain.p | 0.684 | 0.549 | 0.33 | 0.17 | 0.747 | 0.334 | 49 |\n", "| pred_ptcv_cifar100.p | 0.368 | 0.484 | 0.085 | 0.492 | 0.429 | 0.207 | 54 |\n", "| pred_ptcv_svhn.p | 0.659 | 0.671 | 0.082 | 0.641 | 0.824 | 0.252 | 49 |\n", "| pred_ptcv_cifar100_pretrain.p | 0.766 | 0.831 | 0.793 | 0.755 | 0.424 | 0.62 | 54 |\n", "| pred_ptcv_cifar10.p | 0.401 | 0.533 | 0.086 | 0.473 | 0.454 | 0.06 | 56 |\n", "| pred_ptcv_cifar10_pretrain.p | 0.536 | 0.614 | 0.273 | 0.412 | 0.416 | 0.657 | 56 |\n", "| pred_ptcv_ImageNet1k.p | 0.561 | 0.627 | 0.092 | 0.659 | 0.651 | 0.272 | 191 |\n", "| pred_ptcv_ImageNet1k_pretrain.p | 0.689 | 0.67 | 0.498 | 0.73 | 0.691 | 0.164 | 191 |\n", 
"+---------------------------------+-----------+-------+-------+--------+---------+-----------+---------+\n", "ptcv_seed5\n", "+---------------------------------+-----------+-------+-------+--------+---------+-----------+---------+\n", "| exp | grad_norm | snip | grasp | fisher | synflow | jacob_cov | samples |\n", "+---------------------------------+-----------+-------+-------+--------+---------+-----------+---------+\n", "| pred_ptcv_svhn_pretrain.p | 0.672 | 0.557 | 0.384 | 0.186 | 0.747 | 0.271 | 49 |\n", "| pred_ptcv_cifar100.p | 0.393 | 0.503 | 0.056 | 0.493 | 0.429 | 0.05 | 54 |\n", "| pred_ptcv_svhn.p | 0.643 | 0.675 | 0.237 | 0.641 | 0.826 | 0.25 | 49 |\n", "| pred_ptcv_cifar100_pretrain.p | 0.741 | 0.776 | 0.833 | 0.676 | 0.424 | 0.615 | 54 |\n", "| pred_ptcv_cifar10.p | 0.384 | 0.484 | 0.087 | 0.468 | 0.457 | 0.004 | 56 |\n", "| pred_ptcv_cifar10_pretrain.p | 0.692 | 0.767 | 0.303 | 0.533 | 0.416 | 0.66 | 56 |\n", "| pred_ptcv_ImageNet1k.p | 0.57 | 0.638 | 0.141 | 0.67 | 0.651 | 0.337 | 191 |\n", "| pred_ptcv_ImageNet1k_pretrain.p | 0.689 | 0.671 | 0.51 | 0.723 | 0.691 | 0.178 | 191 |\n", "+---------------------------------+-----------+-------+-------+--------+---------+-----------+---------+\n" ] } ], "source": [
    "# Each seed directory holds one pickled result file per dataset/setting\n",
    "# (e.g. pred_ptcv_cifar10.p). Each pickle is a sequence of records, each with a\n",
    "# 'valacc' entry and a 'logmeasures' dict of zero-cost proxy scores.\n",
    "root = '../results_release/ptcv'\n",
    "seed_dirs = [f'ptcv_seed{i}' for i in range(0, 6)]\n",
    "\n",
    "# allm[k] = {result file: {proxy: |Spearman rho| between proxy score and valacc}} for seed k\n",
    "allm = []\n",
    "for seed_dir in seed_dirs:\n",
    "    print(seed_dir)\n",
    "    seed_path = os.path.join(root, seed_dir)\n",
    "    res = {}\n",
    "    # Reset per seed: an empty directory must not re-print the previous seed's table.\n",
    "    t = None\n",
    "    measure_names = []\n",
    "    # os.listdir order is arbitrary; sort it so the row order is reproducible.\n",
    "    for fn in sorted(os.listdir(seed_path)):\n",
    "        # NOTE: pickle.load can execute arbitrary code -- only load trusted result files.\n",
    "        with open(os.path.join(seed_path, fn), 'rb') as f:\n",
    "            ptcv = pickle.load(f)\n",
    "        acc = [d['valacc'] for d in ptcv]\n",
    "        metrics = {}\n",
    "        for d in ptcv:\n",
    "            for m, score in d['logmeasures'].items():\n",
    "                metrics.setdefault(m, []).append(score)\n",
    "        # abs() keeps only the magnitude of the rank correlation, discarding its sign.\n",
    "        res[fn] = {m: abs(stats.spearmanr(acc, v, nan_policy='omit').correlation)\n",
    "                   for m, v in metrics.items()}\n",
    "        if t is None:\n",
    "            # The first file of each seed fixes the column order of the table.\n",
    "            measure_names = list(res[fn])\n",
    "            t = PrettyTable(['exp'] + measure_names + ['samples'])\n",
    "        # Fill cells by measure name so columns stay aligned even if a file\n",
    "        # stores its measures in a different order.\n",
    "        row = [fn] + [round(res[fn].get(m, float('nan')), 3) for m in measure_names]\n",
    "        t.add_row(row + [len(acc)])\n",
    "    allm.append(res)\n",
    "    if t is not None:\n",
    "        print(t)"
   ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [
    "# parsed[file][proxy] -> list of |rho| values, one entry per seed.\n",
    "parsed = {}\n",
    "for seed_res in allm:\n",
    "    for fn, corrs in seed_res.items():\n",
    "        file_corrs = parsed.setdefault(fn, {})\n",
    "        for m, c in corrs.items():\n",
    "            file_corrs.setdefault(m, []).append(c)\n",
    "\n",
    "# Legend label per result file; each _pretrain file shares its dataset's label.\n",
    "# 'CIFAR-100' is hyphenated to match 'CIFAR-10' and the dataset tick labels used later.\n",
    "labels = {\n",
    "    'pred_ptcv_ImageNet1k.p': 'ImageNet1k',\n",
    "    'pred_ptcv_cifar10.p': 'CIFAR-10',\n",
    "    'pred_ptcv_cifar100.p': 'CIFAR-100',\n",
    "    'pred_ptcv_svhn.p': 'SVHN',\n",
    "    'pred_ptcv_ImageNet1k_pretrain.p': 'ImageNet1k',\n",
    "    'pred_ptcv_cifar10_pretrain.p': 'CIFAR-10',\n",
    "    'pred_ptcv_cifar100_pretrain.p': 'CIFAR-100',\n",
    "    'pred_ptcv_svhn_pretrain.p': 'SVHN',\n",
    "}\n",
    "# Matplotlib hatch pattern per result file (one pattern per dataset).\n",
    "pattern = {\n",
    "    'pred_ptcv_ImageNet1k.p': '*',\n",
    "    'pred_ptcv_cifar10.p': '\\\\',\n",
    "    'pred_ptcv_cifar100.p': '/',\n",
    "    'pred_ptcv_svhn.p': 'o',\n",
    "    'pred_ptcv_ImageNet1k_pretrain.p': '*',\n",
    "    'pred_ptcv_cifar10_pretrain.p': '\\\\',\n",
    "    'pred_ptcv_cifar100_pretrain.p': '/',\n",
    "    'pred_ptcv_svhn_pretrain.p': 'o',\n",
    "}"
   ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [
    "# Per-proxy view: one group of bars per zero-cost proxy, one bar per dataset.\n",
    "# Bar height = mean |rho| over seeds, error bar = std over seeds.\n",
    "# Set PTCV_VARIANT = '_pretrain' to plot the pred_ptcv_<dataset>_pretrain.p results instead.\n",
    "PTCV_VARIANT = ''\n",
    "ptcv_exps = [f'pred_ptcv_{ds}{PTCV_VARIANT}.p' for ds in ['cifar10', 'cifar100', 'svhn', 'ImageNet1k']]\n",
    "\n",
    "width = 0.75  # total width of one group of bars\n",
    "bar_w = width / len(ptcv_exps)\n",
    "# One proxy order shared by every dataset, so each bar sits above its own tick label.\n",
    "metric_names = list(parsed[ptcv_exps[0]])\n",
    "xpos = np.arange(len(metric_names))\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(8, 4))\n",
    "for i, exp in enumerate(ptcv_exps):\n",
    "    per_seed = [np.array(parsed[exp][m]) for m in metric_names]\n",
    "    ys = [v.mean() for v in per_seed]\n",
    "    yerrs = [v.std() for v in per_seed]\n",
    "    # Offsets are symmetric around 0 so each group is centred on its tick.\n",
    "    offset = (i - (len(ptcv_exps) - 1) / 2) * bar_w\n",
    "    ax.bar(xpos + offset, ys, yerr=yerrs, width=bar_w, align='center', alpha=0.5, label=labels[exp],\n",
    "           hatch=pattern[exp], edgecolor='black', lw=1.,\n",
    "           error_kw=dict(ecolor='blue', lw=2, capsize=1, capthick=2))\n",
    "\n",
    "ax.set_ylabel('Spearman $\\\\rho$')\n",
    "ax.set_xticks(xpos)\n",
    "ax.set_xticklabels(metric_names)\n",
    "ax.yaxis.grid(True)\n",
    "ax.set_axisbelow(True)\n",
    "ax.legend(prop={'size': 10})\n",
    "fig.tight_layout()\n",
    "# With the default PTCV_VARIANT this writes 'ptcv.pdf', as before.\n",
    "fig.savefig(f'ptcv{PTCV_VARIANT}.pdf')\n",
    "plt.show()"
   ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [
    "# Hatch pattern per zero-cost proxy for the per-dataset plot. Named metric_pattern\n",
    "# so it does not overwrite `pattern` (per result file), which the per-proxy plot\n",
    "# still needs when it is re-run.\n",
    "metric_pattern = {\n",
    "    'grad_norm': '\\\\',\n",
    "    'snip': '/',\n",
    "    'grasp': '-',\n",
    "    'fisher': '*',\n",
    "    'synflow': 'o',\n",
    "    'jacob_cov': 'x'\n",
    "}"
   ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [
    "# Per-dataset view of the same data: one group per dataset, one bar per proxy.\n",
    "# Reuses PTCV_VARIANT / ptcv_exps from the per-proxy plot cell.\n",
    "width = 0.75  # total width of one group of bars\n",
    "flip_metrics = list(parsed[ptcv_exps[0]])\n",
    "bar_w = width / len(flip_metrics)\n",
    "xpos = np.arange(len(ptcv_exps))\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(9, 4))\n",
    "for i, k in enumerate(flip_metrics):\n",
    "    per_seed = [np.array(parsed[exp][k]) for exp in ptcv_exps]\n",
    "    ys = [v.mean() for v in per_seed]\n",
    "    yerrs = [v.std() for v in per_seed]\n",
    "    # Offsets are symmetric around 0 so each group is centred on its tick.\n",
    "    offset = (i - (len(flip_metrics) - 1) / 2) * bar_w\n",
    "    ax.bar(xpos + offset, ys, yerr=yerrs, width=bar_w, align='center', alpha=0.5, label=k,\n",
    "           hatch=metric_pattern[k], edgecolor='black', lw=1.,\n",
    "           error_kw=dict(ecolor='blue', lw=2, capsize=1, capthick=2))\n",
    "\n",
    "ax.set_ylabel('Spearman $\\\\rho$')\n",
    "ax.yaxis.grid(True)\n",
    "ax.set_axisbelow(True)\n",
    "ax.set_xticks(xpos)\n",
    "ax.set_xticklabels(['CIFAR-10', 'CIFAR-100', 'SVHN', 'ImageNet1k'])\n",
    "# The legend is anchored outside the axes; bbox_inches='tight' makes sure\n",
    "# the saved PDF includes it instead of clipping it at the figure edge.\n",
    "ax.legend(prop={'size': 11}, bbox_to_anchor=(1.25, 0.75))\n",
    "fig.tight_layout()\n",
    "# With the default PTCV_VARIANT this writes 'ptcv_flip.pdf', as before.\n",
    "fig.savefig(f'ptcv_flip{PTCV_VARIANT}.pdf', bbox_inches='tight')\n",
    "plt.show()"
   ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.6" } }, "nbformat": 4, "nbformat_minor": 4 }