MeCo/toy_model.ipynb

973 lines
305 KiB
Plaintext
Raw Normal View History

2023-05-04 07:09:03 +02:00
{
"cells": [
{
"cell_type": "code",
"execution_count": 240,
"id": "45e05b72",
"metadata": {},
"outputs": [],
"source": [
"from nasbench201.search_model_darts_proj import TinyNetworkDartsProj\n",
"import torch\n",
"import torch.nn as nn\n",
"from nasbench201.cell_operations import SearchSpaceNames\n",
"import nasbench201.utils as ig_utils\n",
"import torch.utils\n",
"import torchvision.datasets as dset\n",
"import numpy as np\n",
"import copy"
]
},
{
"cell_type": "code",
"execution_count": 241,
"id": "eaa02532",
"metadata": {},
"outputs": [],
"source": [
"# Seed NumPy, torch (CPU) and torch (CUDA) RNGs with the same value for repeatable runs.\n",
"for _seed_fn in (np.random.seed, torch.manual_seed, torch.cuda.manual_seed):\n",
"    _seed_fn(1)"
]
},
{
"cell_type": "code",
"execution_count": 242,
"id": "29976057",
"metadata": {},
"outputs": [],
"source": [
"class AGRS():\n",
"    \"\"\"Plain attribute container with the experiment settings.\n",
"\n",
"    An instance is passed as ``args`` to nasbench201 helpers such as\n",
"    ``ig_utils._data_transforms_cifar10`` and read via attribute access.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self):\n",
"        vars(self).update(\n",
"            # data\n",
"            data='../data',\n",
"            dataset='cifar10',\n",
"            train_portion=0.5,\n",
"            batch_size=64,\n",
"            # network size\n",
"            init_channels=16,\n",
"            layers=8,\n",
"            # optimiser\n",
"            learning_rate=0.025,\n",
"            learning_rate_min=0.001,\n",
"            momentum=0.9,\n",
"            nesterov=False,\n",
"            weight_decay=3e-4,\n",
"            grad_clip=5,\n",
"            # augmentation\n",
"            cutout=False,\n",
"        )\n",
"\n",
"\n",
"args = AGRS()"
]
},
{
"cell_type": "code",
"execution_count": 243,
"id": "3725b779",
"metadata": {},
"outputs": [],
"source": [
"# NOTE(review): Jocab_Score / hooklogdet are redefined in a later cell; after Run All\n",
"# the later definitions are the ones in effect.\n",
"def Jocab_Score(ori_model, input, target, weights=None):\n",
"    \"\"\"NASWOT-style score of ``ori_model`` on one batch (variant for models that expose\n",
"    ``num_edge`` / ``num_op`` / ``candidate_flags`` / ``proj_weights``).\n",
"\n",
"    Clears every entry of ``candidate_flags`` on a deep copy of the model, hooks every\n",
"    ``nn.ReLU`` and, for each ReLU input with at least one positive entry, accumulates\n",
"    ``x @ x.T + (1 - x) @ (1 - x).T`` of the binary codes ``x = (inp > 0)``.\n",
"\n",
"    Args:\n",
"        ori_model: network to score; deep-copied, never modified.\n",
"        input: input batch (first dimension is the batch); moved to the model's device.\n",
"        target: unused; kept for call-site compatibility.\n",
"        weights: stored as ``proj_weights`` on the copy before the forward pass.\n",
"\n",
"    Returns:\n",
"        (score, K, K_list): log|det K| (sign discarded), the accumulated kernel as a NumPy\n",
"        array, and a dict mapping ReLU module name -> its ``x @ x.T`` term.\n",
"    \"\"\"\n",
"    model = copy.deepcopy(ori_model)\n",
"    model.eval()\n",
"    model.proj_weights = weights\n",
"    num_edge, num_op = model.num_edge, model.num_op\n",
"    for i in range(num_edge):\n",
"        model.candidate_flags[i] = False\n",
"    # Follow the model's device instead of hard-coding CUDA.\n",
"    device = next(model.parameters()).device\n",
"    batch_size = input.shape[0]\n",
"    model.K = torch.zeros(batch_size, batch_size, device=device)\n",
"    model.K_list = {}\n",
"\n",
"    def counting_forward_hook(module, inp, out):\n",
"        if isinstance(inp, tuple):\n",
"            inp = inp[0]\n",
"        # reshape rather than view: the hooked input need not be contiguous\n",
"        active = inp.reshape(inp.size(0), -1) > 0\n",
"        if not bool(active.any()):\n",
"            return  # no positive entry: this ReLU contributes nothing\n",
"        x = active.float()\n",
"        K = x @ x.t()\n",
"        K2 = (1. - x) @ (1. - x.t())\n",
"        model.K = model.K + K + K2\n",
"        model.K_list[module.name] = K\n",
"\n",
"    for name, module in model.named_modules():\n",
"        if isinstance(module, nn.ReLU):\n",
"            module.name = name\n",
"            module.register_forward_hook(counting_forward_hook)\n",
"\n",
"    # Only activations are needed, so skip building the autograd graph.\n",
"    with torch.no_grad():\n",
"        model(input.to(device))\n",
"    K = model.K.cpu().numpy()\n",
"    score = hooklogdet(K)\n",
"    K_list = model.K_list\n",
"    del model\n",
"    return score, K, K_list\n",
"\n",
"\n",
"def hooklogdet(K, labels=None):\n",
"    \"\"\"Return log|det K| via ``np.linalg.slogdet`` (sign discarded; ``labels`` unused).\"\"\"\n",
"    s, ld = np.linalg.slogdet(K)\n",
"    return ld"
]
},
{
"cell_type": "code",
"execution_count": 244,
"id": "ae134d08",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Files already downloaded and verified\n",
"Files already downloaded and verified\n"
]
}
],
"source": [
"# CIFAR-10 train / test splits with the transforms provided by nasbench201.utils.\n",
"train_transform, valid_transform = ig_utils._data_transforms_cifar10(args)\n",
"train_data = dset.CIFAR10(\n",
"    root=args.data,\n",
"    train=True,\n",
"    transform=train_transform,\n",
"    download=True,\n",
")\n",
"valid_data = dset.CIFAR10(\n",
"    root=args.data,\n",
"    train=False,\n",
"    transform=valid_transform,\n",
"    download=True,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 245,
"id": "cd9923d8",
"metadata": {},
"outputs": [],
"source": [
"num_train = len(train_data)\n",
"indices = list(range(num_train))\n",
"split = 64\n",
"\n",
"# Restrict sampling to the first `split` training images, drawn in random order.\n",
"subset_sampler = torch.utils.data.sampler.SubsetRandomSampler(indices[:split])\n",
"train_queue = torch.utils.data.DataLoader(\n",
"    train_data,\n",
"    batch_size=args.batch_size,\n",
"    sampler=subset_sampler,\n",
"    pin_memory=True,\n",
")\n",
"input, target = next(iter(train_queue))"
]
},
{
"cell_type": "code",
"execution_count": 1529,
"id": "e08b4613",
"metadata": {},
"outputs": [],
"source": [
"# Re-seed, then draw the batch that every score below is computed on.\n",
"for _seed_fn in (np.random.seed, torch.manual_seed, torch.cuda.manual_seed):\n",
"    _seed_fn(2)\n",
"from scipy.stats import rankdata\n",
"input, target = next(iter(train_queue))\n",
"\n",
"LAYER = 8  # number of mixed edges in the toy network (TinyNetwork N)\n",
"OPN = 4    # candidate operations per edge"
]
},
{
"cell_type": "code",
"execution_count": 1530,
"id": "58c0ad9a",
"metadata": {},
"outputs": [],
"source": [
"from nasbench201.cell_operations import OPS\n",
"class TinyNetwork(nn.Module):\n",
"    \"\"\"Toy super-network: conv stem -> ``N`` sequential mixed edges -> BN/ReLU -> pool -> linear.\n",
"\n",
"    Edge ``i`` holds one module per name in ``op_names`` (built from ``OPS`` with ``C``\n",
"    in/out channels and stride 1) and outputs the weighted sum of their outputs.\n",
"    An op whose weight equals 0 is called with ``block_input=True``.\n",
"\n",
"    Args:\n",
"        C: channel width of the stem and of every edge.\n",
"        N: number of mixed edges.\n",
"        num_classes: size of the classifier output.\n",
"        criterion: accepted for interface compatibility; not used here.\n",
"        affine, track_running_stats: forwarded to every op constructor.\n",
"        stem_channels: number of input image channels.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, C, N, num_classes, criterion, affine=False, track_running_stats=True, stem_channels=3):\n",
"        super(TinyNetwork, self).__init__()\n",
"        self.stem = nn.Sequential(\n",
"            nn.Conv2d(stem_channels, C, kernel_size=3, padding=1, bias=False),\n",
"            nn.BatchNorm2d(C))\n",
"        op_names = ['skip_connect', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3']\n",
"        self.N = N\n",
"        self.edges = nn.ModuleDict()\n",
"        for i in range(N):\n",
"            self.edges[str(i)] = nn.ModuleList(\n",
"                [OPS[op_name](C, C, 1, affine, track_running_stats) for op_name in op_names])\n",
"\n",
"        self.lastact = nn.Sequential(nn.BatchNorm2d(C), nn.ReLU(inplace=True))\n",
"        self.global_pooling = nn.AdaptiveAvgPool2d(1)\n",
"        self.classifier = nn.Linear(C, num_classes)\n",
"        self.cos = nn.CosineSimilarity(dim=1, eps=1e-6)\n",
"        # Uniform default mixing weights (float64), one row per edge. Sized from N and\n",
"        # op_names instead of a fixed 8x4 table, which made forward() fail for N > 8.\n",
"        self.weights = torch.full((N, len(op_names)), 1.0 / len(op_names), dtype=torch.float64)\n",
"\n",
"    def forward(self, inputs, weights=None):\n",
"        \"\"\"Return logits for ``inputs``.\n",
"\n",
"        ``weights`` (indexable as ``weights[edge][op]``) overrides ``self.weights`` for this\n",
"        call, so both ``model(x, w)`` and ``model(x)`` work.\n",
"        \"\"\"\n",
"        if weights is None:\n",
"            weights = self.weights\n",
"        feature = self.stem(inputs)\n",
"        for i in range(self.N):\n",
"            feature = sum(op(feature, block_input=True) * w if w == 0 else op(feature) * w\n",
"                          for op, w in zip(self.edges[str(i)], weights[i]))\n",
"        out = self.lastact(feature)\n",
"        out = self.global_pooling(out)\n",
"        out = out.view(out.size(0), -1)\n",
"        logits = self.classifier(out)\n",
"        return logits\n",
"\n",
"    def calc_k(self, inp):\n",
"        \"\"\"log|det| of the binary-activation kernel of ``inp``; 0 if no entry is positive.\"\"\"\n",
"        inp = inp.view(inp.size(0), -1)\n",
"        x = (inp > 0).float()\n",
"        if x.cpu().numpy().sum() == 0:\n",
"            return 0\n",
"        K = x @ x.t() + (1. - x) @ (1. - x.t())\n",
"        return hooklogdet(K.cpu().numpy())\n"
]
},
{
"cell_type": "code",
"execution_count": 1533,
"id": "b84ded93",
"metadata": {},
"outputs": [],
"source": [
"# Freshly initialised toy network; it stays on the CPU (the .cuda() move is disabled).\n",
"model = TinyNetwork(\n",
"    C=16,\n",
"    N=LAYER,\n",
"    num_classes=10,\n",
"    criterion=nn.CrossEntropyLoss(),\n",
")\n",
"# model.cuda()"
]
},
{
"cell_type": "code",
"execution_count": 1534,
"id": "baddff20",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/hongkaiw/anaconda2/envs/mct/lib/python3.7/site-packages/torch/tensor.py:593: RuntimeWarning: Iterating over a tensor might cause the trace to be incorrect. Passing a tensor of different shape won't change the number of iterations executed (and might lead to errors or silently give incorrect results).\n",
" 'incorrect results).', category=RuntimeWarning)\n",
"/home/hongkaiw/.local/lib/python3.7/site-packages/ipykernel_launcher.py:27: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!\n"
]
}
],
"source": [
"import torch.onnx\n",
"\n",
"model.eval()\n",
"\n",
"# Export the toy network's graph to ONNX (architecture only, see export_params).\n",
"torch.onnx.export(model,                     # model being run\n",
"                  input,                     # example input used to trace the graph\n",
"                  \"toy.onnx\",                # where to save the model\n",
"                  export_params=False,       # do NOT store parameter values in the file\n",
"                  opset_version=10,          # the ONNX opset to export to\n",
"                  do_constant_folding=True,  # fold constant sub-graphs during export\n",
"                  input_names=['modelInput'],\n",
"                  output_names=['modelOutput'],\n",
"                  # the batch is dimension 0 of both the NCHW input and the logits\n",
"                  dynamic_axes={'modelInput': {0: 'batch_size'},\n",
"                                'modelOutput': {0: 'batch_size'}})"
]
},
{
"cell_type": "code",
"execution_count": 1522,
"id": "78f85e24",
"metadata": {},
"outputs": [],
"source": [
"import torch.nn as nn\n",
"\n",
"\n",
"def Jocab_Score(ori_model, input, target, weights=None):\n",
"    \"\"\"NASWOT-style score of ``ori_model`` on one batch under op-mixing ``weights``.\n",
"\n",
"    A forward hook on every ``nn.ReLU`` of a deep copy of the model accumulates, for each\n",
"    ReLU input with at least one positive entry, the kernel\n",
"    ``x @ x.T + (1 - x) @ (1 - x).T`` of the binary codes ``x = (inp > 0)``.\n",
"\n",
"    Args:\n",
"        ori_model: network to score; deep-copied, never modified.\n",
"        input: input batch (first dimension is the batch); moved to the model's device.\n",
"        target: unused; kept for call-site compatibility.\n",
"        weights: op-mixing weights; if not None they replace ``weights`` on the copy\n",
"            (the attribute TinyNetwork.forward reads).\n",
"\n",
"    Returns:\n",
"        (score, K, K_list): log|det K| (sign discarded), the accumulated kernel as a NumPy\n",
"        array, and a dict mapping ReLU module name -> its ``x @ x.T`` term.\n",
"    \"\"\"\n",
"    model = copy.deepcopy(ori_model)\n",
"    model.eval()\n",
"    model.proj_weights = weights\n",
"    if weights is not None:\n",
"        # TinyNetwork.forward reads self.weights; setting it on the copy works whether or\n",
"        # not forward() also accepts a weights argument.\n",
"        model.weights = weights\n",
"    # Follow the model's device instead of hard-coding CUDA (the model may live on the CPU).\n",
"    device = next(model.parameters()).device\n",
"    batch_size = input.shape[0]\n",
"    model.K = torch.zeros(batch_size, batch_size, device=device)\n",
"    model.K_list = {}\n",
"    model.count = 0\n",
"\n",
"    def counting_forward_hook(module, inp, out):\n",
"        if isinstance(inp, tuple):\n",
"            inp = inp[0]\n",
"        # reshape rather than view: the hooked input need not be contiguous\n",
"        active = inp.reshape(inp.size(0), -1) > 0\n",
"        if not bool(active.any()):\n",
"            return  # no positive entry: this ReLU contributes nothing\n",
"        x = active.float()\n",
"        K = x @ x.t()\n",
"        K2 = (1. - x) @ (1. - x.t())\n",
"        model.K = model.K + K + K2\n",
"        model.K_list[module.name] = K\n",
"\n",
"    for name, module in model.named_modules():\n",
"        if isinstance(module, nn.ReLU):\n",
"            module.name = name\n",
"            model.count += 1\n",
"            module.register_forward_hook(counting_forward_hook)\n",
"\n",
"    # Only activations are needed, so skip building the autograd graph.\n",
"    with torch.no_grad():\n",
"        model(input.to(device))\n",
"    K = model.K.cpu().numpy()\n",
"    score = hooklogdet(K)\n",
"    K_list = model.K_list\n",
"    del model\n",
"    return score, K, K_list\n",
"\n",
"\n",
"def hooklogdet(K, labels=None):\n",
"    \"\"\"Return log|det K| via ``np.linalg.slogdet`` (sign discarded; ``labels`` unused).\"\"\"\n",
"    s, ld = np.linalg.slogdet(K)\n",
"    return ld"
]
},
{
"cell_type": "code",
"execution_count": 1523,
"id": "42ae5a33",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"2\n",
"0\n",
"1\n",
"1\n",
"1\n",
"1\n",
"1\n",
"2\n",
"[724.5167, 724.62964, 721.42676, 726.57513, 721.924, 724.0039, 724.2308, 724.328, 724.6446, 723.4378, 723.59174, 726.2936, 726.9928, 722.4523, 723.66644, 727.96545, 727.3341, 722.5211, 723.89703, 727.9818, 726.8876, 723.4647, 724.414, 727.43134, 727.0108, 724.00287, 724.3993, 727.34503, 727.5785, 724.18317, 724.01965, 727.66486]\n",
"[2.625 1.5 1.875 4. ]\n",
"1\n"
]
}
],
"source": [
"# Uniform op-mixing weights: one row per edge, OPN equal entries per row.\n",
"weights = [[1 / OPN] * OPN for _ in range(LAYER)]\n",
"\n",
"pt_score = []\n",
"avg_skip_rank = np.zeros(OPN)\n",
"count_skip = 0\n",
"# Perturbation-based selection: zero out one op of edge l at a time; the op whose\n",
"# removal gives the lowest score (argmin) is the one selected for that edge.\n",
"for l in range(LAYER):\n",
"    op_s = []\n",
"    for o in range(OPN):\n",
"        w = copy.deepcopy(weights)\n",
"        w[l][o] = 0\n",
"        crit, K, K_list = Jocab_Score(model, input, target, w)\n",
"        pt_score.append(crit)\n",
"        op_s.append(crit)\n",
"    avg_skip_rank += rankdata(op_s)\n",
"    select = np.argmin(op_s)\n",
"    print(select)\n",
"    if select == 0:  # op 0 is skip_connect\n",
"        count_skip += 1\n",
"print(pt_score)\n",
"print(avg_skip_rank / LAYER)\n",
"print(count_skip)"
]
},
{
"cell_type": "code",
"execution_count": 1467,
"id": "3133fd42",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"2\n",
"[557.9228, 597.0493, 597.9796, 550.9689]\n",
"[3. 2. 1. 4.]\n",
"0\n"
]
}
],
"source": [
"disc_score = []\n",
"avg_skip_rank = np.zeros(OPN)\n",
"count_skip = 0\n",
"# Discretisation-based selection: make edge l one-hot on each op in turn; the op\n",
"# giving the highest score (argmax) is the one selected for that edge.\n",
"for l in range(LAYER):\n",
"    op_s = []\n",
"    for o in range(OPN):\n",
"        w = copy.deepcopy(weights)\n",
"        w[l] = np.zeros_like(w[l])\n",
"        w[l][o] = 1\n",
"        crit, K, K_list = Jocab_Score(model, input, target, w)\n",
"        op_s.append(crit)\n",
"        disc_score.append(crit)\n",
"    # descending ranks (1 = highest score), valid for any OPN\n",
"    avg_skip_rank += (OPN + 1) - rankdata(op_s)\n",
"    select = np.argmax(op_s)\n",
"    print(select)\n",
"    if select == 0:  # op 0 is skip_connect\n",
"        count_skip += 1\n",
"print(disc_score)\n",
"print(avg_skip_rank / LAYER)\n",
"print(count_skip)"
]
},
{
"cell_type": "code",
"execution_count": 1468,
"id": "e091015c",
"metadata": {},
"outputs": [],
"source": [
"# Score the single architecture that puts op 1 (nor_conv_1x1) on every edge.\n",
"w = copy.deepcopy(weights)\n",
"arch = [1] * LAYER\n",
"for i in range(len(arch)):\n",
"    w[i] = np.zeros_like(w[i])\n",
"    w[i][arch[i]] = 1\n",
"crit, K, K_list = Jocab_Score(model, input, target, w)\n"
]
},
{
"cell_type": "code",
"execution_count": 1469,
"id": "05ffd6a0",
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "7116abcb5edb4d99822e2a5f26b3cfb1",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[557.9228, 597.0493, 597.9796, 550.9689]\n"
]
}
],
"source": [
"from itertools import combinations_with_replacement,permutations,product\n",
"from tqdm.notebook import tqdm\n",
"\n",
"# Exhaustively score every discrete architecture (one op per edge). This is\n",
"# OPN ** LAYER forward passes (65,536 for LAYER=8, OPN=4), so the cell is slow.\n",
"archs = [list(x) for x in product(range(OPN), repeat=LAYER)]\n",
"final_score = []\n",
"for arch in tqdm(archs):\n",
"    w = copy.deepcopy(weights)\n",
"    for layer, op in enumerate(arch):  # distinct names: no shadowing of the outer loop\n",
"        w[layer] = np.zeros_like(w[layer])\n",
"        w[layer][op] = 1\n",
"    crit, K, K_list = Jocab_Score(model, input, target, w)\n",
"    final_score.append(crit)\n",
"print(final_score)"
]
},
{
"cell_type": "code",
"execution_count": 1470,
"id": "24aab655",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>arch</th>\n",
" <th>naswot</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>[0]</td>\n",
" <td>557.922791</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>[1]</td>\n",
" <td>597.049316</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>[2]</td>\n",
" <td>597.979614</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>[3]</td>\n",
" <td>550.968872</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" arch naswot\n",
"0 [0] 557.922791\n",
"1 [1] 597.049316\n",
"2 [2] 597.979614\n",
"3 [3] 550.968872"
]
},
"execution_count": 1470,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import pandas as pd\n",
"df = pd.DataFrame(list(zip(archs, final_score)),columns =['arch', 'naswot'])\n",
"df"
]
},
{
"cell_type": "code",
"execution_count": 1471,
"id": "2da3ceab",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[557.9227905273438, 597.04931640625, 597.9796142578125, 550.9688720703125]\n"
]
}
],
"source": [
"# Oracle per (edge l, op o): the best score over all enumerated architectures that put\n",
"# op o on edge l. Entries are ordered edge-major, OPN per edge.\n",
"best_nwot = []\n",
"for l in range(LAYER):\n",
"    for o in range(OPN):\n",
"        uses_op = df['arch'].map(lambda arch: arch[l] == o)\n",
"        best_nwot.append(max(df.loc[uses_op, 'naswot']))\n",
"print(best_nwot)"
]
},
{
"cell_type": "code",
"execution_count": 1472,
"id": "a03d51ce",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"2\n",
"[3. 2. 1. 4.]\n"
]
}
],
"source": [
"# Average descending rank (1 = best) of each op across edges under the oracle best_nwot.\n",
"avg_rank = np.zeros(OPN)\n",
"for i in range(LAYER):\n",
"    edge_scores = best_nwot[i * OPN:(i + 1) * OPN]\n",
"    avg_rank += (OPN + 1) - rankdata(edge_scores)\n",
"print(avg_rank / LAYER)"
]
},
{
"cell_type": "code",
"execution_count": 1473,
"id": "5357b16a",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"SpearmanrResult(correlation=1.0, pvalue=0.0)"
]
},
"execution_count": 1473,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from scipy import stats\n",
"# Lower pt_score = lower score after removing the op (the perturbation cell selects\n",
"# the argmin), so negate it before rank-correlating with the oracle best_nwot.\n",
"stats.spearmanr(-np.asarray(pt_score), best_nwot)"
]
},
{
"cell_type": "code",
"execution_count": 1474,
"id": "4ebc7c45",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[557.9228, 597.0493, 597.9796, 550.9689]\n"
]
},
{
"data": {
"text/plain": [
"SpearmanrResult(correlation=1.0, pvalue=0.0)"
]
},
"execution_count": 1474,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"print(disc_score)\n",
"# Rank agreement between discretisation scores and the exhaustive-search oracle.\n",
"stats.spearmanr(a=disc_score, b=best_nwot)"
]
},
{
"cell_type": "code",
"execution_count": 1485,
"id": "02a823f1",
"metadata": {},
"outputs": [],
"source": [
"# Spearman-rho against the oracle for each selection method, toy networks with 1..8\n",
"# layers (hard-coded; presumably recorded from earlier runs of the cells above).\n",
"cor_dic = {\n",
"    'zc_pt(nwot)': [1.0, 0.80, 0.82, 0.83, 0.79, 0.69, 0.66, 0.65],\n",
"    'disc_zc(nwot)': [1.0, 0.85, 0.89, 0.71, 0.63, 0.35, 0.28, 0.07],\n",
"    # 'zc_pt(nwot)_w/o_skip': [1.0, 0.5217, 0.5533, 0.6655, 0.7019, 0.5058, 0.5801, 0.637],\n",
"    # 'disc_zc(nwot)_w/o_skip': [1.0, 0.4638, 0.5788, 0.7332, 0.7419, 0.5711, 0.6141, 0.6845],\n",
"}"
]
},
{
"cell_type": "code",
"execution_count": 1486,
"id": "66c58b61",
"metadata": {},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"import random\n",
"import statistics as stat\n",
"import itertools\n",
"# Endless iterators over matplotlib marker and colour codes.\n",
"marker = itertools.cycle('^xos*+1')\n",
"color = itertools.cycle('bcrgymk')"
]
},
{
"cell_type": "code",
"execution_count": 1488,
"id": "89d82376",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA2EAAADxCAYAAABCgr/hAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8qNh9FAAAACXBIWXMAAAsTAAALEwEAmpwYAACPBklEQVR4nOzdd3hURdvA4d+m9wbpPZQQIISWhKYUERCkq4AoCLZXhc/XgooVFV/EXhAVFAUUKQrSQYqgtAABQighQBICIQkhvdf9/tjswrIbsgtpwHNf117KKXPmTLKbfc7MPKNQKpVKhBBCCCGEEEI0CJPGroAQQgghhBBC3EkkCBNCCCGEEEKIBiRBmBBCCCGEEEI0IAnChBBCCCGEEKIBSRAmhBBCCCGEEA1IgjAhhBBCCCGEaEAShAkhhBBCCCFEAzJr7AoIIYQQQgghhDFWr17NoUOHOHnyJKdOnaKkpIQpU6YwdepUo8vauHEjCxcu5NSpU5iYmNC+fXv+85//0L17d73Hnz9/ni+++II9e/ZQUFCAj48PI0aMYPLkyZibmxt0TQnChBBCCCGEELeUL7/8kpSUFBwdHXFzcyM5OfmGypk/fz6ffPIJLi4ujBo1CoANGzYwadIkPvvsMwYPHqx1fEJCAuPGjSM3N5d7770Xf39/Dhw4wGeffcbhw4eZO3cuJia1DzZUKJVK5Q3VWAghhBBCCCEawZ49e/Dz88PHx4eVK1cyffp0o3vCzp07x5AhQ7Czs+PPP//Ew8MDgLS0NEaMGIFSqWTr1q3Y29trzpkwYQJRUVHMmDGDcePGAaBUKnnppZdYv349s2fPZsSIEbVeW+aECSGEEEIIIW4pPXr0wMfH56bKWLlyJeXl5TzyyCOaAAzAw8ODRx55hJycHDZt2qTZnpSURFRUFL6+vowdO1azXaFQ8NJLLwGwbNkyg64tQZgQQgghhBDijhMVFQVAr169dPapt6mPAdi/fz+gCgAVCoXW8d7e3gQGBhITE0NZWVmt15Y5YUIIIYQQQohGlZeXR15ens52BwcHHBwc6uWaSUlJAPj7++vsU287d+6cZltiYiIAAQEBessLCAggMTGR8+fP06JFi+teW4IwIYQQQgghRL2Z82KPWo9RBo5jzpw5OttvNOOhIQoKCgC05nyp2dnZAWgFhtc7HsDW1lbnnJpIECaEEEIIIYRoVBMnTmTkyJE62+urF6yxSRAmhBBCCCGEaFT1OeywJnZ2dmRnZ5Ofn4+zs7PWPnWv19V1UveO5efn6y2vsLBQ55yaSGIOIYQQQgghRL1RKBS1vhqDem7X1fO+1NTbrp4vFhgYCFyZS3atpKQkTE1N8fX1rfXaEoQJIYQQQggh6o1CYVLrqzFERkYCsGvXLp196m3qYwAiIiIA1Rpl1y61nJKSQmJiImFhYVhYWNR6bQnChBBCCCGEELet/Px8zp49y6VLl7S2jxo1CnNzc3755RfS0tI029PS0vjll19wcnJi4MCBmu0BAQFERERw/vx5li5dqtmuVCr57LPPABgzZoxBdZI5YUIIIYQQQoh6Ux/DDVesWEF0dDRwZejg1q1bSUlJAaBLly48+OCDAGzZsoXp06czcuRIPvzwQ00Z/v7+PP/883zyySeMHDmSwYMHA7BhwwZycnL47LPPdOZ3zZgxg3HjxvHuu++yd+9e/Pz8OHDgAEeOHKFv374MGzbMoPpLECaEEEIIIYSoP/Uw3DA6OppVq1ZpbYuLiyMuLk7zb3UQdj1PPvkk3t7e/Pzzz/zxxx+YmJjQvn17nnnmGbp3765zfIsWLfj999/54osv2LNnD9u3b8fb25sXXniBxx9/HBMTw+5Vobx2QKMQQgghhBBC1JFvX+lT6zHPfLSj3uvRlEhPmBBCCCGEEKLeNFb2w6ZMgjAhhBBCCCFEvWms7IdNmQRhQgghhBBCiHojPWG6JCwVQgghhBBCiAYkQZgQQgghhBBCNCAZji
iEEEIIIYSoPzIcUYcEYUIIIYQQQoh6I4k5dEmLCCGEEEIIIUQDkp4wIYQQQgghRL2R7Ii6JAgTQgghhBBC1BsZjqhLWkQIIYQQQgghGpD0hAkhhBBCCCHqjwxH1CFBmBBCCCGEEKLeyHBEXdIiQgghhBBCCNGApCdM3FbmvNijsavQpE35bA+/fTyhsavRZI2btohfPxzX2NVo0sa/9hu/zBrb2NVosh6ZvpSln0xs7Go0aWNfXsia719s7Go0WcOe/oz1P77a2NVo0oY8Ppv1C6Y3djWarCGTZzV2FXRIdkRd0hMmhBBCCCGEEA1IesKEEEIIIYQQ9UZ6wnRJECaEEEIIIYSoP5KYQ4e0iBBCCCGEEEI0IOkJE0IIIYQQQtQbGY6oS4IwIYQQQgghRL2RdcJ03VSL5ObmkpaWRnl5eV3VRwghhBBCCCFuazfUE7Zp0ya++uorEhMTATAxMaFVq1b079+fhx56CDc3tzqtpBBCCCGEEOLWJMMRdRndE7Z+/XpeeOEFEhISUCgUODs7Y2JiQlxcHHPmzGHgwIH8+uuv9VFXIYQQQgghxK1GYVL76w5j9B3Pnz8fgMcff5z9+/ezZ88eYmJiWLp0KY899hgAM2fO5PPPP6/TigohhBBCCCHE7cDoIOzs2bOEhIQwbdo07OzsVIWYmNCxY0dee+01Vq1aRcuWLZk3bx579+6t8woLIYQQQgghbh0KhaLW153G6CBMoVAQFBRU4/6AgADmzp2LmZkZCxYsuKnKCSGEEEIIIcTtxujEHO7u7iQnJ1/3GF9fXyIiIjhy5MiN1ksIIYQQQghxG7gTe7pqY3RP2NChQ4mNja11qKGTkxMVFRU3XDEhhBBCCCHErU+hMKn1dacx+o6ffvppvLy8+L//+z/WrFmj95iysjIOHTpEcHDwTVdQCCGEEEIIIW4nRgdhlpaWfP/995iYmPDqq68yatQovv32W/bv38/p06fZuXMnjz/+OBkZGTz//PP1UWchhBBCCCHErUKhqP11h7mhxZpbtmzJmjVreO+999i2bRsnT57UOWbcuHHY2dlRXl6Oubn5TVdUCCGEEEIIceu5E4cb1uaGgjBQJej45ptvOHfuHFu2bGHv3r0cO3aM3NxcAH777Td+++03zMzMaNWqFe3btyc0NJQHH3ywziovhBBCCCGEELeaGw7C1Pz9/XniiSd44oknADh//jzHjh3TvE6cOKF5rVixQoIwIYQQQggh7iCSHVHXTQdh1/L19cXX15f77rtPsy0pKYnY2FiOHz9e15cTQgghhBBCNGEyHFFXnQdh+gQEBBAQEMDQoUMb4nJCCCGEEEII0WTVSRA2b948du3axaJFi+qiOCGEEEIIIcTtop6GI544cYKvvvqKQ4cOUVJSQlBQEGPHjmXMmDEGDYF89NFH2b9//3WP+b//+z+ee+45zb9fe+01Vq1aVePxR48exdLSstZr10kQlpiYyIEDB+qiKCGEEEIIIYS4roMHDzJp0iSUSiX33Xcfbm5u7Ny5k3feeYdTp07xzjvv1FrGyJEjiYiI0Ltv3rx5lJWVcdddd+ndP2HCBBwcHHS2m5qaGlT/BhmOKIQQQgghhLgz1XVijoqKCt544w3KysqYN28evXv3BuD5559n0qRJLFmyhCFDhtC1a9frljNq1Ci92w8ePMicOXNo06YNHTp00HvMxIkT8fHxueF7kFlyQgghhBBCiHqjUJjU+jJGVFQUSUlJREZGagIwAAsLC55//nkAli5desP1XbFiBQAPPfTQDZdRG+kJuwWdPn2a4cOH88YbbzB+/PjGrk6dmT17NsuXL2fLli24uLg0dnWM5urdGp/W4bj5tsHdNwSHZl4ALPnoEbLSEhq5dvVLoTAhuOtAAtv1xM7JncqKcjJTz3Iiah0ZF04ZWxpB7e8isH0vnFx9MTO3pKykiKz0RM4c2U7K2cP1cg91QaEwoU34fQS2vxt7Z3cqK8q4fPEMJ/at4dL5OG
NLI6hDb1qE3o2jqy/m5laUlRSSlZZI/OGtpJyJ1jnD2T0A7xad8AzsgGNzH8wtrCgtKSTz4hniD/1FauLRurnRG6R
"text/plain": [
"<Figure size 936x936 with 2 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"import pathlib\n",
"import seaborn as sns\n",
"\n",
"# rows: selection methods, columns: number of layers\n",
"cor_df = pd.DataFrame.from_dict(cor_dic, orient='index')\n",
"x = np.arange(1, cor_df.shape[1] + 1)  # tick labels 1..n_columns\n",
"# Set up the matplotlib figure\n",
"f, ax = plt.subplots(figsize=(13, 13))\n",
"sns.set(font_scale=1.8)\n",
"# Generate a custom diverging colormap\n",
"cmap = sns.diverging_palette(200, 55, as_cmap=True)\n",
"g = sns.heatmap(cor_df, ax=ax, cmap=cmap, center=0,\n",
"                square=True, linewidths=.5, cbar_kws={\"shrink\": .267}, annot=True, xticklabels=x)\n",
"ax.set_xlabel('#N Layers')\n",
"ax.set_ylabel('Spearman-$\\\\rho$')\n",
"plt.savefig(pathlib.Path('op_correl_layer_increase_toy').with_suffix('.pdf'), bbox_inches='tight')\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 999,
"id": "7876c4d4",
"metadata": {},
"outputs": [],
"source": [
"# Average rank of skip_connect per selection method, toy networks with 1..8 layers\n",
"# (hard-coded; presumably recorded from earlier runs).\n",
"rank_dic = {\n",
"    'zc_pt(nwot)': [4, 4, 4, 4, 2.6, 2.67, 2.71, 2.75],\n",
"    'disc_zc(nwot)': [4, 4, 4, 2.5, 2.2, 2.33, 2.28, 2.25],\n",
"}"
]
},
{
"cell_type": "code",
"execution_count": 1087,
"id": "afca5839",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA7YAAAEDCAYAAAABYph3AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8qNh9FAAAACXBIWXMAAAsTAAALEwEAmpwYAACA8UlEQVR4nO3dd1yVZf8H8M9hKVNAGTIERAVUQGWjprlzgrtMNMejj6MnLXPUz6wsR5o7zcoUTXOLe4cLGSKioqiJgIIie2/O7w86J45w4GZ5QD/v14tXct33fd3f++Kc4HuuJRKLxWIQERERERERNVJKig6AiIiIiIiIqDaY2BIREREREVGjxsSWiIiIiIiIGjUmtkRERERERNSoMbElIiIiIiKiRo2JLRERERERETVqTGyJiIiIiIioUWNiS0RERERERI2aSnUviIyMxK5duxASEoKEhAQAgJGREVxcXPDBBx+gffv2dR4kERERERERkTwisVgsFnry1q1bsW7dOpSUlKCiy5SVlTF79mxMnz69ToMkIiIiIiIikkdwYnvq1CnMmTMHampqGD58OIYOHQpTU1MAQFxcHI4dO4ZDhw6hoKAAP/74I9577716DZyIiIiIiIgIqEZiO2bMGNy9exe//fYb3N3dKzwnKCgIH330Eezt7bF37946DZSIiIiIiIioIoIXj3r48CG6dOkiN6kFADc3Nzg7O+Phw4d1EhwRERERERFRVQQntmpqajA0NKzyPAMDA6ipqdUqKCIiIiIiIiKhBCe2nTt3RkRERJXnRUREoFOnTrWJiYiIiIiIiEgwwYntxx9/jPj4eKxcuRJFRUXljhcXF+OHH35AfHw8Pv744zoNkoiIiIiIiEgewYtHHTlyBGFhYdi3bx+MjY3Rv39/6arI8fHxOHPmDJ4/f44xY8ZU2GPr5eVVl3ETERERERERAahGYmtrawuRSCTdv1YkEskcl1cucf/+/drESURERERERFQhFaEnenl5yU1aiYiIiIiIiBRFcI8tERERERERUUMkePEoIiIiIiIiooaIiS0RERERERE1anLn2B45cgQA0KdPH2hpaUm/F4qrIBMREREREdHrIHeOrWQV5JMnT8LKykr6vVBcBZmIiIiIiIheB7k9tpJVkLW1tWW+JyIiIiIiImpIuCoyERERERER1Yqfnx8+//xzAMCyZcswfPjwal1//fp1bNmyBXfv3kVJSQlsbGwwYcIEvPfee4KuF7yPbXXk5+ejSZMm9VE1ERERERERNSAJCQlYunQpNDQ0kJOTU+3rT548iblz50JDQwODBg2CpqYmzp49i08++QRxcXGYMmVKlXUIXhV569atgs7Lz8/H9OnThVZLREREREREjdiiRYugo6ODsWPHVvvajIwMLFmyBKqqqvjjjz/w7bffYsGCBfDz84OlpSXWrl2L2NjYKusRnNj++OOPOHr0aKXnFBYWYubMmQgMDBRaLRERERERETVSu3fvxrVr1/Ddd99BQ0Oj2tefPn0a6enpGDx4MOzs7KTl2tramD59OgoLC3Hw4MEq6xGc2JqZmeGLL75AQEBAhccLCwsxe/ZsXL16Fd27dxdaLRERERERETVCsbGx+OGHH/DBBx/A3d29RnUEBQUBALp161bumKQsODi4ynoEJ7a//vorNDU1MXv27HJb+RQXF2Pu3Lnw9/eHp6cnNm7cKLRaIiIiIiIiagAyMjLw7Nmzcl8ZGRnlzi0pKcH8+fOhr6+Pzz77rMb3jI6OBgBYWFiUO2ZgYAANDQ3ExMRUWY/gxaMsLS2xdetWTJgwAVOnTsXevXthamqKkpISzJs3D+fOnYOLiws2b94MNTU14U9CRERERERE9WbjXE9B54mt3q+wk3LWrFmYPXu2TNlvv/2GsLAw7Nixo0ZDkCWysrIAQLrN7Ku0tLSQmppaZT3VWhXZwcEBa9aswcyZMzFlyhTs2rULy5cvx8mTJ9GlSxds3bqVqy
ETERERERE1QhMmTIC3t3e5ch0dHZnvHz58iPXr12PcuHFwc3N7XeFVqtrb/fTs2RNLlizB//3f/2HAgAHIzMyEvb09tm7dCnV19fqIkYiIiIiIiOqZjo5OuSS2IvPnz4eRkRE+/fTTWt9TS0sLAJCZmVnh8aysLEExCZ5jW9aoUaMwa9YsZGZmon379ti2bZs0ICIiIiIiInpz3bt3D0+fPkXnzp1hY2Mj/ZIMY164cCFsbGywYcOGKuuytLQEgArn0SYmJiInJ6fC+bevkttj6+PjU/XFKioQi8WYOXOmTLlIJMKOHTuqvJ6IiIiIiIjql0gkqtP6Ro4cWWH5vXv3cO/ePbi4uMDCwgLt27evsi43NzccP34cV69exaBBg2SOXb16FQDg6upaZT0isVgsruiAra1tlRfLrVQkKrdyMhEREREREb1+mz7tKui8mauv1eo+GzZswMaNG7Fs2TIMHz5c5lhubi7i4+Ohrq4OExMTaXlGRgb69OmD3Nxc7Nu3T7qXbWZmJkaOHIm4uDicPHkSrVq1qvTecntsfX19a/NMRERERERERACA27dvw8fHB66urti5c6e0XEdHB1999RU+/fRTjBs3DoMGDYKmpibOnj2LuLg4fPbZZ1UmtUAlia2Q7l4iIiIiIiJq2ESiGi2t9NoMGjQI+vr62Lx5M06cOIGSkhK0a9cO8+bNw3vvvSeoDrlDkYmIiIiIiKjx++mz7oLOm7HqSj1HUn+qvd1PWQkJCbh8+TJSU1NhZGSEd955B3p6enUVGxEREREREVGV5Ca2jx49wuHDh9G+fXsMHjy43PF9+/Zh6dKlKCwslJZpaGhgxYoV6NOnT/1ES0RERERERPQKuYOtz5w5g99//x1qamrljt26dQtLlixBQUEBmjZtCjs7O+jo6CA7Oxuffvopnj59Wq9BExERERERkTAikUjQV2MmN7G9efMmmjZtip49e5Y79vPPP6OkpAStW7fGmTNncOjQIVy/fh3e3t7Iz8/H7t276zNmIiIiIiIiIim5iW1sbCw6dOhQrse2oKAAV69ehUgkwmeffQZDQ8PSipSU8OWXX0JLSwuBgYH1GzURERERERHRP+QmtikpKTA2Ni5XfvfuXRQWFqJp06bo1q2bzDFNTU107NiRQ5GJiIiIiIgaCpGSsK9GTG70hYWFyMnJKVceEREBALCzs6tw/m2LFi2Ql5dXhyESERERERERySc3sTUwMMDjx4/LlYeGhkIkEsHBwaHC67Kzs6Grq1tnARIRERERERFVRm5i27lzZ8TGxuLs2bPSssTERFy6dAkA0LVr1wqve/jwoXTeLRERERERESnW27Aqstx9bMePH4+TJ0/i008/xYABA9C8eXOcOXMGubm5sLCwqDCxffz4MeLi4uDu7l6vQRMRERERERFJyE1sO3XqhE8//RSrV6/GsWPHIBKJIBaLoa6uju+//x5KSuU7e/fv3w9Afm8uERERERERUV2Tm9gCwJQpU9CtWzecPn0aKSkpaNmyJYYNGwYTE5MKz2/SpAl8fHyY2BIRERERETUQoka+4rEQlSa2AGBrawtbW1tBlc2ZM6fWARERERERERFVx5ufuhMREREREdEbrcoeWyKi6to411PRITRYs34MYPtUgW1UObZP1Wb9GIAT2xYqOowGa9CkZTj+6+eKDqNBGzxlJY79Mk/RYTRYQ6b+gKM/z1V0GA3a0Gk/KjoEGY19xWMh2GNLREREREREjRoTWyIiIiIiImrUmNgSERERERFRo8Y5tkRERERERG8yzrH9V1ZWFrKysuozFiIiIiIiIqJqE5zYOjs746OPPqrPWIiIiIiIiIiqTfBQZC0tLVhYWNRnLERERERERFTHRKI3f2klwU/Yrl07xMXF1WcsRERERERERNUmOLH98MMPERYWhuvXr9dnPERERERERETVIngocqdOnfD+++/jP//5D0aOHInevXvDxMQETZs2rfB8ExOTOguSiIiIiI
iISB7BiW3v3r0BAGKxGH/++Sf+/PNPueeKRCLcu3ev9tERERERERFRrYjegu1+BCe2LVu2rM84iIiIiIiIiGpEcGJ
"text/plain": [
"<Figure size 936x936 with 2 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"import pathlib\n",
"\n",
"# rows: selection methods, columns: number of layers\n",
"rank_df = pd.DataFrame.from_dict(rank_dic, orient='index')\n",
"x = np.arange(1, rank_df.shape[1] + 1)  # tick labels 1..n_columns\n",
"# Set up the matplotlib figure\n",
"f, ax = plt.subplots(figsize=(13, 13))\n",
"# Generate a custom diverging colormap\n",
"cmap = sns.diverging_palette(200, 55, as_cmap=True)\n",
"g = sns.heatmap(rank_df, ax=ax, cmap=cmap, center=0,\n",
"                square=True, linewidths=.5, cbar_kws={\"shrink\": .267}, annot=True, xticklabels=x)\n",
"ax.set_xlabel('#N Layers in Toy model')\n",
"ax.set_ylabel('Average Rank for Skip')\n",
"plt.savefig(pathlib.Path('skip_layer_increase_toy').with_suffix('.pdf'), bbox_inches='tight')\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 1489,
"id": "bcc59139",
"metadata": {},
"outputs": [],
"source": [
"# Perturbation-based selection (zc_pt), toy networks with 1..8 layers (hard-coded;\n",
"# presumably recorded from earlier runs):\n",
"#   '<op>'        -> presumably the average rank of the op across edges\n",
"#   '<op>_select' -> selection counts for the op (TODO confirm the first column)\n",
"# [3. 1.66666667 1.33333333 4. ]\n",
"pt_dic = {\n",
"    'skip': [3, 3.00, 3.00, 3.00, 3.00, 2.83, 2.71, 2.75],  # 00000011\n",
"    'skip_select': [0, 0, 0, 0, 0, 0, 1, 1],\n",
"    'conv_1x1': [2, 1.50, 1.67, 1.75, 1.60, 1.83, 1.71, 1.75],  # 01112222\n",
"    'conv_1x1_select': [0, 1, 1, 1, 2, 2, 2, 2],\n",
"    'conv_3x3': [1, 1.50, 1.33, 1.25, 1.40, 1.33, 1.57, 1.5],  # 21233445\n",
"    'conv_3x3_select': [2, 1, 2, 3, 3, 4, 4, 5],\n",
"    'avg_pooling': [4, 4.00, 4.00, 4.00, 4.00, 4.00, 4.00, 4.00],  # 00000000\n",
"    'avg_pooling_select': [0, 0, 0, 0, 0, 0, 0, 0],\n",
"}"
]
},
{
"cell_type": "code",
"execution_count": 1490,
"id": "30f0c50c",
"metadata": {},
"outputs": [],
"source": [
"# Discretisation-based selection (disc_zc), toy networks with 1..8 layers (hard-coded;\n",
"# presumably recorded from earlier runs); same key layout as pt_dic.\n",
"# [3. 1.33333333 1.66666667 4. ]\n",
"disc_dic = {\n",
"    'skip': [3, 3.00, 3.00, 2.75, 2.40, 2.17, 2.14, 1.75],  # 00001225\n",
"    'skip_select': [0, 0, 0, 0, 1, 2, 2, 5],\n",
"    'conv_1x1': [2, 2.00, 1.33, 1.50, 2.00, 2.00, 2.57, 2.75],  # 00221321\n",
"    'conv_1x1_select': [0, 0, 2, 2, 1, 3, 2, 1],\n",
"    'conv_3x3': [1, 1.00, 1.67, 1.75, 1.80, 2.17, 2.00, 2.25],  # 22123132\n",
"    'conv_3x3_select': [2, 2, 1, 2, 3, 1, 3, 2],\n",
"    'avg_pooling': [4, 4.00, 4.00, 4.00, 3.80, 3.67, 3.29, 3.25],  # 00000000\n",
"    'avg_pooling_select': [0, 0, 0, 0, 0, 0, 0, 0],\n",
"}"
]
},
{
"cell_type": "code",
"execution_count": 1491,
"id": "acbdfcc6",
"metadata": {},
"outputs": [],
"source": [
"# Exhaustive-search oracle, toy networks with 1..8 layers (hard-coded; presumably\n",
"# recorded from earlier runs); same key layout as pt_dic.\n",
"best_dic = {\n",
"    'skip': [3, 3.00, 3.00, 3.00, 3.00, 3.00, 3.00, 3.00],  # 00000000\n",
"    'skip_select': [0, 0, 0, 0, 0, 0, 0, 0],\n",
"    'conv_1x1': [2, 1.50, 1.67, 2.00, 1.80, 1.50, 1.42, 1.5],  # 01101344\n",
"    'conv_1x1_select': [0, 1, 1, 0, 1, 3, 4, 4],\n",
"    'conv_3x3': [1, 1.50, 1.33, 1.00, 1.20, 1.50, 1.57, 1.5],  # 21244334\n",
"    'conv_3x3_select': [2, 1, 2, 4, 4, 3, 3, 4],\n",
"    'avg_pooling': [4, 4.00, 4.00, 4.00, 4.00, 4.00, 4.00, 4.00],  # 00000000\n",
"    'avg_pooling_select': [0, 0, 0, 0, 0, 0, 0, 0],\n",
"}"
]
},
{
"cell_type": "code",
"execution_count": 1496,
"id": "5c0200f1",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA1sAAAKHCAYAAAB3t1LlAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8qNh9FAAAACXBIWXMAAAsTAAALEwEAmpwYAAC1iklEQVR4nOzdeVxU1fsH8A87IrvsO7K7Icjimopb5gYuKZq4lZZp5tLPLCtLM0szTdPUr+aS4q6oqbnimoCCKIKACCiiCLKJyD6/P4jRCZwBZJgZ+bxfL17JvWfOfebpzgzPnHPPVRIIBAIQERERERFRg1KWdQBERERERERvIhZbREREREREUsBii4iIiIiISApYbBEREREREUkBiy0iIiIiIiIpYLFFREREREQkBSy2iIiIiIiIpEBV1gEQKZqT2xfKOgS51mfMV8yRGMyPZMyReMyPZH3GfIUzu5bIOgy55jfyc4Tu/VnWYcitHsNn4+yepbIOQ671HPGZrENQCBzZIiIiIiIikgIWW0RERERERFLAYouIiIiIiEgKWGwRERERERFJAYstIiIiIiIiKWCxRUREREREJAUstoiIiIiIiKSAxRYREREREZEUsNgiIiIiIiKSAhZbREREREREUsBii4iIiIiISApYbBEREREREUkBiy0iIiIiIiIpYLFFREREREQkBSy2iIiIiIiIpIDFFhERERERkRSw2CIiIiIiIpICFltERERERERSwGKLiIiIiIhIClhsERERERERSQGLLSIiIiIiIilgsUVERERERCQFLLaIiIiIiIikgMUWERERERGRFKjKOgAiqj1rZ28YmNpCW98E6hpaUFFTR2nxc+RnpyMtMRJZDxJlHaLMMUfiMT+SMUeSNcUcKSkpQ9/EBkYWDtA3tkEzbQMoKyuj+HkBch6n4t7tcBQ+fVLnfo0snWHp0B46BmZQVddAeWkJCvIy8TD5Bh6lxEjhmUhHeXkF4pPTcTP+HhLuPkRmdh7Kyyugp9scri0t0KdrO5ibGNS6v817Q/FPVILEdoN6dcBAvw6vE3qjqcpRTPx9JCQ/RGZ2fmWOdLTg0tICfbq2rVOOqkTFpuBCeBzupT9BYVExNDXUYWlqgM6ezujo4QQlJSUpPBuqLRZbJJf8/PwAAGfOnJHYdv/+/Zg3bx5++OEHDB06VNqhyZRd685Q19BCQe5j5OY/QXl5GbS09WFs6QxjS2ekxl1BQuRJWYcpU8yReMyPZMyRZE0xR/om1vDoMQoAUFSYj5yMFAgEAugYmMKipTtMbVvj1uUQZKXXvtB0bN8LNi7eEAgEyMtKQ/Hzp9DQ1IG+kRUMTGzQwrwlbv1zSFpPqUElJKdj5R9HAQAGes3h6mAJZWUl3Et/gkvX4hEWfQcfjOyF9q3satWfo63ZK/cVl5Qi8lYyAMDZ3uK1Y28sCckP8evm4wAAA91/c6SkhHvpWbgcmYDwG0l4f6Qf2rvZ1rrP3Uev4MzlGCgpAS2tTWGg1xy5+YW4k5qBxJRHuJWYhvdH+knrKVEtsNgiUiA3Lu7D0+xHqCgvE9mub2wNj56jYevWERn3YpGX9UBGEcoecyQe8yMZcyRZk8yRAMi4F4f7CRHIf5L+YruSEhzavgVbt05w8x2Af/76HWUlRRK70zEwhY2LN8rLShB5NhhPsx8K92nrm8KzZyBMbVrhUUoMnjy8K41n1KCUlJTQoU1L9O7SFi1tTIXbKyoqcPBkBP4+H43N+0KxyG4UtLU0JfbX1dsVXb1da9x3MeI2Im8lw8hQB052ry7K5I2SkhI829hX5sjaRLi9oqICISev4u8LN7Bl3zk4znq3VjlKfZCFM5djoK6mipmT3oG91Ys+76VnYfnGv3D15l34tndCWxdrqTwnkozXbJHC69OnD44ePYo+ffrIOhSpy8tMq/bHDQDkZt5HRuotAEALs5aNHZZcYY7EY34kY44ka4o5ynmcilv/hIgWWgAgEC
Dpxjk8y38CNXVNGFk41qo/fZPK0YvMB4kihRYAFORmION+HABAt4Xl6wffCFwdLDE5sLdIoQUAysrKCOjrA1MjPTwvKsHN+HuvfazLUfEAgM4eLgo1Rc7VwQKTR/USKbSAyhz59/UW5igm/n6t+ktIrjwX3d1sRQotALCxMIJ3WwcAQPL9xw0QPdUXiy1SeDo6OnBwcICOjo6sQ5EpgaACAFBRUS7jSOQXcyQe8yMZcyRZU81RQW7lH7QazbRr1b6mYrUmpcXP6x2TvFBSUoKVWQsAQG7+s9fq6/GTPCSlZkBJCejk6dwQ4ckFJSUlWJoZAqh9jlRVVWrVrrmWRr3jotfHYosa3alTpxAUFIQuXbqgTZs26NatG8aOHYvg4GCJj921axdatWqFYcOG4cmTyguR9+/fDxcXF+zfv1+krYuLC8aOHYtHjx5h1qxZ8PX1hbu7O0aOHIlz585J5bnJiraBKUxtWkFQUYGs9DuyDkcuMUfiMT+SMUeSNeUcaelULmxQUlS7P5RzMlJQUVEBY0sn6Biai+zT1jeFqbUbykqL8fjfES5F9/hJHgBAV1vrtfr5J7Jy0QyXlpYw1K9dYasoMp/kA6h9jtz+vS4uOi4VyWmio1f30rMQcTMJmhpq8Gr7Zo0yKxpes0WNKjg4GAsWLICxsTH8/PxgYGCAJ0+e4Pbt2zh48CACAwNf+dhVq1Zh9erV6NatG3799VdoaUl+M8rLy0NgYCAMDQ0xYsQIZGdn49ixY5gyZQqWL1+Od955pyGfXqOxcuoAvRaWUFZRgWZzPegZWUFQUY64iGPCb1ebOuZIPOZHMuZIMuaokoGpLXQMzFBeXlbr66sKn2bjzvXTcGrfC169gyoXyCh8CvVm2tA3skJBfhZuRxyrdfEmz+LupOH+wydQVVVBG6f6XzskEAhw5XrlAiRv0qgWAMTdeSDMUWtnq1o9xsxYHyP6d8Tuo1fw07pDcLAxhb5uc+Q9rVwgw8LUAO/5d4WezusVuPR6WGxRo9qzZw/U1NQQEhKCFi1aiOzLzs6u8THl5eVYsGABdu/ejYCAACxatAiqqrU7dePj4zFw4EAsW7ZMOK87KCgIw4cPx4IFC9C9e3c0b9789Z6UDBiY2MLMrrXw9/KyEsRfO4EHSddlF5ScYY7EY34kY44kY44ANfVmcPOu/OLu/u2wOhVHaYnXUFSYDzefAdA3flGElJeVIvtRMp4X5DR4vI2t4FkRth44DwDo07Ud9HTr/4f/7bvpyM4tgKaGGjxb2TdUiDJX8KwI2w5cAAD06dK2TsVRz06tYaDXHFv3n8ed1AzhdjVVFbRytISJoW6Dx0t1w2KLGp2qqmqNxZKhoWG1bUVFRZg5cybOnDmDyZMnY/bs2XU6loqKCmbNmiVyAa2rqyuGDBmCvXv34vTp0xg8eHDdn4SM3by0Hzcv7Yeyihq0dA1h4+KDVr4DYWrjhuvndtf6WoA3GXMkHvMjGXMkWVPPkbKyCtp0CYBmcz3kPL6H5FuX6vR4x/Z+sHHxwcPkG0i9HYaiZ3nQbK4HW9eOsHX1hZG5A66d3oay0mIpPQPpKi0tw+87TiI7twBOduYY9Jr3w7ry7xRCr7YOUFd/M/6ELS0tw7rgU8jOK4CTnRkG+nnW+rECgQD7jofh1KUYdPJwQp9u7WCkr4Os3Kc4cf4GTl68iZvx9/F/kwdBqxmv25IVXrNFjWrQoEF4/vw5BgwYgMWLF+PUqVOvHNEqKirC+PHjERoaiq+++qrOhRYAmJubw9Ky+kpOXl5eAIC4OMWeC19RXoqCnAzEXjmMB3ei0MLcAXatOss6LLnCHInH/EjGHEnWFHOkpKSE1p39YWBig/zsR7hxcZ9wcZDaMLNrCxsXH2Q9uIO48KMozH+CivIyFOY/QVz4X8hKT0JzPSPYuPhI8VlIT3l5BdbvPIXElIewsTDCx2P7QUWl/n92FhWXIiq28t5and+QKYTl5RXYsOsMElMewcaiBa
a+17dOOfonKhGnLsWgrYs1xg3rDgsTA6irq8LCxADjh3dHa2crPMrMxcmLN6X4LEgSFlvUqCZMmIAff/wRFhYW2LZ
"text/plain": [
"<Figure size 936x936 with 2 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"import pathlib\n",
"x = np.array([1, 2, 3, 4, 5, 6,7,8])\n",
"idx=0\n",
"pt_df = pd.DataFrame.from_dict(pt_dic, orient='index')\n",
"# Set up the matplotlib figure\n",
"f, ax = plt.subplots(figsize=(13, 13))\n",
"#sns.set(font_scale=1.8)\n",
"# Generate a custom diverging colormap\n",
"cmap = sns.diverging_palette(200, 55, as_cmap=True)\n",
"g = sns.heatmap(pt_df, cmap=cmap, center=0,\n",
" square=True, linewidths=.5, cbar_kws={\"shrink\": .267}, annot=True,xticklabels=x)\n",
"plt.xlabel('#N Layers')\n",
"plt.ylabel('')\n",
"#plt.legend(bbox_to_anchor=(0.55, 0.65), prop={'size': 13})\n",
"#plt.grid()\n",
"plt.savefig(pathlib.Path('op_select_pt').with_suffix('.pdf'), bbox_inches='tight')\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 1497,
"id": "c96d7468",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA1sAAAKHCAYAAAB3t1LlAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8qNh9FAAAACXBIWXMAAAsTAAALEwEAmpwYAADKsUlEQVR4nOzdd1hU19YG8HeoinTpiIh0RBEEUcTYNbGCHY0l6jU3Rk005jPFVBNjEuPVYDTq1cSKXbH3LqEoiIUqVVQQpCigAjPz/TGXiQSkDzMD7+95eG44Z88+66x7ZmTN3mcfgVgsFoOIiIiIiIgalYq8AyAiIiIiImqOWGwRERERERHJAIstIiIiIiIiGWCxRUREREREJAMstoiIiIiIiGSAxRYREREREZEMsNgiIiIiIiKSATV5B0CkbM7sWCrvEBTaoMlfMEfVYH5qNmjyF7iw92d5h6Gw+o37GCHH1ss7DIXmM+xdRJzZIu8wFJrXoGmIOLtN3mEoLK+BU3A7JFjeYSi0zj6j5B2CUuDIFhERERERkQyw2CIiIiIiIpIBFltEREREREQywGKLiIiIiIhIBlhsERERERERyQCLLSIiIiIiIhlgsUVERERERCQDLLaIiIiIiIhkgMUWERERERGRDLDYIiIiIiIikgEWW0RERERERDLAYouIiIiIiEgGWGwRERERERHJAIstIiIiIiIiGWCxRUREREREJAMstoiIiIiIiGSAxRYREREREZEMsNgiIiIiIiKSARZbREREREREMsBii4iIiIiISAZYbBEREREREckAiy0iIiIiIiIZYLFFREREREQkAyy2iIiIiIiIZEBN3gEQUe1ZOXjBwNQa2vom0NDUgqq6BkpfPsfT3IfISIxEzoNEeYcod8xR9VpqfoRCEeJTHuJO/H0kpDxCdu5TCIUi6OlowbGjBQb5doa5iUGd+42KScWV8FikP3yC4hcv0UpTA5amBvDxcEAPd3sIBAIZnE3jKxMKEX8vA9ExyYhLysDjnHwIhSLo67WBs50V3uznCQvTtg06xoPMHHz9yw6UCYXo0M4UXy2c3EjRN40yoRCxCem4efceYhPTkJWdB6FIBAM9bbg4dMDQAd6wNDOqdX8FTwtx824Sbt69h+S0R8h/Wgh1NTWYm7ZFd3cnDOnjCQ0NdRmeUeOS5CdNkp+ENGRl51bMz8CedcqPSCRCxM04pKQ/QnLaQ6SkZ+L5i5ewaW+OpYtnyvBMZCs5NQO3Yu4hKTUDSSn38TgnDwCwculCtG9nVq8+w27cwekLoUhJe4Ci58/RupUm2rczR39fT/Tp1U1pPoeaKxZbpJD69+8PADh//nyNbQ8cOIBPP/0UP/zwA0aPHi3r0OSqQycfaGhqoTD/MfKfPoFQWAYtbX0YWzrA2NIBabGhSIg8I+8w5Yo5ql5LzU9CyiP8+udJAICBbhs42VpCRSBA+sMchEQmIPxWEmZN6I+uzta17nPP8VCcD7kDgQDoaGUKA702yH9ajHtpWUhMzcTdxAzMmtBfVqfUqOKTMrBi/X4AgIG+Nlwc2kNFIEDag8e4En4Xf0XG4b2pw+Dhalev/kUiETbvOg2hSNiYYTepuMR0/PhbEADAUF8Hrk42UBEIkJqRiUt/RSMk4g7mzvBHty4Otepvx8FzCIm4CxUVAawtTWFvY4nCoudITHmA3cEXcC38Dj7/YDJ0tLVkeVqNJi4xDT+u2QnglfyoqCD1/qNX8jMa3dwca9Xf8xclCNx0QJYhy8Xew+cQEXW30fr7I+gwjp2+CoFAAEc7axga6CEv/yniElIQE5+MyNvxWPiecn2x0dyw2CJSIreu7sez3EyIhGUVtusbW8G93yRYO/dAVnoMCnIeyClC+WOOqtdS8yMQCODhaoOBvTqjo5WJdLtIJELwmes4deUWtuy/BLuF46Gt1arG/tIe5OB8yB1oqKthwcyhsGn3d5/pD3OwctMxXL+dDO+u9ujsaC
WTc2pMAoEAXm4OGNLHA7YdLKTbRSIR9h+/huPnI/Dfnafw0+eW0G7Tus79n7x4A8npmejn44YLIdGNGXqTEQgE6O7uhKH9vWFnYyndLhKJsOfIJRw98xfWbzuCX756r1YFknab1hg3vA/6+rhBT1dbuj03/xlWrNuN9AePsW3/GcyZNkom59PYJPlxxtAB3rCzaSfdLhKJsOfwRRw9E4L12w7jF9v3a5UfVVUV9PJyhU17c9i0N0fxi5f4Zd1uWZ5Ck3CwbY/27cxg28ESth3aYcmydch+klevvpJSM3Ds9FVoaqjj68Xvwr5je+m+5LQH+PrH9QgJj0ZfHw94uDk31ilQHfGeLVJ6gwYNwvHjxzFo0CB5hyJzBdkZlf5IBoD87PvISpN8U9bWrGNTh6VQmKPqtdT8ONlaYPbEARUKLQBQUVGB32AvmBrp4fmLEtyJv1+r/hJSHgIA3JytKxRaANDewghenW0BACn3HzdC9LLnYt8ec6YNr1BoAZL8jB3mCzMTAzx/8RLRMSl17vvR41wcOhmCrp1s0b1r7UZ9FFEnxw6YP3N0hUILkORowsi+MDc1RPHzl7h5N6lW/U0dOxij3uxVodACJKNC0ye8CQCIuBmPsjLlGA3s5GiD+bPGVCi0gP/lZ1Q/mJu2/V9+7tWqv1aaGnhvuh/e7O8NR7v20FSiKZXV8R/WDwGjh6C7hyvaGuo3qK+7cZJrzcujU4VCCwA6WluiV3c3AEBCcnqDjkMNw2KLlJ6Ojg5sbW2ho6Mj71DkSiwWAQBESjxNR9aYo+q11PwIBAJYmhkCAPKfFtXqNWpqqrVq10ZLs95xKQqBQAArc2MAQF5BYZ1eKxKJ8cfu01BTU8XUMcoxpbI+BAIBrCwkRXde/rMG92fdzhQAUFpahmdFzxvcn7w1dn5IQl2tdhPUdNq0kXEkVB0WW9Tkzp49i6lTp6JXr15wdXVF7969MWXKFAQFBdX42t27d8PFxQVjxozBkydPAEju2XJ0dMSBAxXndjs6OmLKlCnIzMzEwoUL4e3tDTc3N0yYMAGXLl2SybnJi7aBKUzbu0AsEiHnYe2+NWxpmKPqtfT8ZD95CgDQreX9Mc62llBRESA6Ng0pGRVHr9If5iDidhJaaarDs3PzGCXMyskHAOjp1u3+obNXo5CY8hDjhveGgX7z/kIsK1syFUxPt+F/2Jb3paqqUqtprcogK6fx8kMSXTrZQ0VFBRGRd5H4j9Gr5LQHuBYejdatNOHTvYucIiSA92xREwsKCsLXX38NY2Nj9O/fHwYGBnjy5Ani4uJw6NAhBAQEvPa1gYGBWLNmDXr37o1ff/0VWlo1/6NfUFCAgIAAGBoaYty4ccjNzcWJEyfw7rvvYuXKlRg6dGhjnl6TaWffDXptLaGiqopWbfSgZ9QOYpEQsREnUJivHNOWZI05qh7z87fYew9w/9ETqKmpopNDu5pfAMDMWB/j3uqBPcdD8dP6w7Btbwp93TYoeCZZIMPC1ABv+/lCT0c5Fjeozt2ENKQ/eAw1NVV0dupQ69c9fpKP/cevwqGjJfr2bN5/7N2JS0FaRhbU1VTRxcW2wf0dPfsXAKCLsy3U1ZX/T7U7cclIu5/ZaPkhCUtzE0wPGIE/dx7GZ9/9Bkc7a7Q1LF8gIxVW7czw7+ljYKCvK+9QWzTlfweTUtm7dy/U1dURHByMtm0rLiOcm5tb5WuEQiG+/vpr7NmzB/7+/vjuu++gVsuh8/j4eAwfPhwrVqyQLn06depUjB07Fl9//TX69OmDNko4vG5gYg2zDp2kvwvLShB/4zQeJN2UX1AKhjmqHvMjUVj0AtsOXgEADOrVuU7FUb+enWCg1wZbD1zGvbQs6XZ1NVW42FnCxFD5/8B5Vvgcf+w+DQB4s68n9P9xf9HriMVi/LH7DEQiMaaPH9Ssl55+VliMjTuOAQCGDvCGgV7tcvQ6oZExCIm4C3U1VYwf2acxQpSrZ4XF2Lj9KABg6MAeMN
Br3iOcTW3owF4wMtTDb5v2Ii4xVbpdQ10Nbp3sYWZS++X2STZYbFGTU1NTq7JYMjQ0rLTtxYsXWLBgAc6fP4/Zs2f
"text/plain": [
"<Figure size 936x936 with 2 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"import pathlib\n",
"x = np.array([1, 2, 3, 4, 5, 6,7,8])\n",
"idx=0\n",
"disc_df = pd.DataFrame.from_dict(disc_dic, orient='index')\n",
"# Set up the matplotlib figure\n",
"f, ax = plt.subplots(figsize=(13, 13))\n",
"#sns.set(font_scale=1.8)\n",
"# Generate a custom diverging colormap\n",
"cmap = sns.diverging_palette(200, 55, as_cmap=True)\n",
"g = sns.heatmap(disc_df, cmap=cmap, center=0,\n",
" square=True, linewidths=.5, cbar_kws={\"shrink\": .267}, annot=True,xticklabels=x)\n",
"plt.xlabel('#N Layers')\n",
"plt.ylabel('')\n",
"#plt.legend(bbox_to_anchor=(0.55, 0.65), prop={'size': 13})\n",
"#plt.grid()\n",
"plt.savefig(pathlib.Path('op_select_disc').with_suffix('.pdf'), bbox_inches='tight')\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 1498,
"id": "b9a92916",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA1sAAAKHCAYAAAB3t1LlAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8qNh9FAAAACXBIWXMAAAsTAAALEwEAmpwYAACo5klEQVR4nOzdeXxM9/7H8XckYk0QhCyWWBJBEULU2lqrtUUtpRWK6q2lvVW96te9XFRbl1JaSktLKEpQVNVWlMRSW0JsQWwNIcSW9fdHbqbNTcwkkcnMJK/n4+FROec73/OZT88Z+cz3e77HLjU1NVUAAAAAgDxVxNIBAAAAAEBBRLEFAAAAAGZAsQUAAAAAZkCxBQAAAABmQLEFAAAAAGZAsQUAAAAAZkCxBQAAAABm4GDpAABbs2LGcEuHYNV6vzaXHBlBfkwjR8aRH9PIkWnkyDjyY1rv1+ZaOgSbwMgWAAAAAJgBxRYAAAAAmAHFFgAAAACYAcUWAAAAAJgBxRYAAAAAmAHFFgAAAACYAcUWAAAAAJgBxRYAAAAAmAHFFgAAAACYAcUWAAAAAJgBxRYAAAAAmAHFFgAAAACYAcUWAAAAAJgBxRYAAAAAmAHFFgAAAACYAcUWAAAAAJgBxRYAAAAAmAHFFgAAAACYAcUWAAAAAJgBxRYAAAAAmAHFFgAAAACYAcUWAAAAAJgBxRYAAAAAmAHFFgAAAACYgYOlAwCQfTUbPqmKHt4qU8FDxUo4yaFocSXcj9eNP8/pzJHtunz2iKVDtDhyZBz5MY0cmUaOjCM/ppEj48hPwUGxBavUrl07SdKWLVtMtv3xxx81fvx4TZ48Wb169TJ3aBbl0+QpFStRWnHXL+r2jatKTkpUqTIV5ObVQG5eDRR54Bcd/m25pcO0KHJkHPkxjRyZRo6MIz+mkSPjyE/BQbEF2JC9G+bqZswFJSclZNhe3r2WWvd4Vd6NOyr65H7FXjljoQgtjxwZR35MI0emkSPjyI9p5Mg48lNwcM8WbF7Hjh21fv16dezY0dKhmN31y6czffBK0vVLp3QhMkySVKmqb36HZVXIkXHkxzRyZBo5Mo78mEaOjCM/BQcjW7B5Tk5OcnJysnQYFpeSmpL23+QkC0divciRceTHNHJkGjkyjvyYRo6MIz+2hZEt5LvNmzcrKChILVu2VP369dW6dWsNHDhQwcHBJl+7bNky1a1bV88++6yuX78uKe2eLR8fH/34448Z2vr4+GjgwIG6cuWKxowZo4CAADVs2FD9+vXT9u3bzfLeLKVMBU9Vqe2v1JQUXTl31NLhWCVyZBz5MY0cmUaOjCM/ppEj48iP7WFkC/kqODhYH3zwgSpWrKh27dqpXLlyun79uo4fP67Vq1erf//+D33tzJkzNWvWLLVu3Vqff/65SpYsafJ4cXFx6t+/v1xcXNSnTx/FxsZqw4YNevnllzVt2jQ9/fTTefn28k2Nx9rKpbKXitg7qJRTebm4eSklOVkHti5R3LWLlg7PKpAj48iPaeTINHJkHPkxjRwZR35sH8UW8tXy5ctVtGhRhYSEqHz58hn2xcbGZvma5ORkffDBB/rhhx8UGBioiRMnysEhe6fuiRMn1LVrV3366aeys7OTJAUFBal379764IMP1LZtW5UqVerR3pQFVPTwVhWfpoafkxIf6NCOH3T22E4LRmVdyJFx5Mc0cmQaOTKO/JhGjowjP7aPYgv5zsHBIctiycXFJdO2+/fv6/XXX9eWLVs0fPhwvfHGGzk6lr29vcaMGWMotCSpTp066tGjh1asWKFff/1V3bt3z/mbsLC9G+dp78Z5sndwVOmyrqrt115N2g+UZ60m2r1udpY31RY25Mg48mMaOTKNHBlHfkwjR8aRH9vHPVvIV926ddO9e/f0zDPPaNKkSdq8efNDR7Tu37+vwYMHa9u2bXr33XdzXGhJkpubmzw8PDJt9/f3lyRFRETkuE9rkpyUoLhr0d
r3y0KdPbZTlarVlXeTTpYOy6qQI+PIj2nkyDRyZBz5MY0cGUd+bBfFFvLViy++qI8//lju7u767rvvNHLkSLVo0UKDBw/OVPjcuXNH4eHhcnZ21uOPP56r41WoUCHL7elTGG/fvp2rfq3RufDfJUnuNRpZNhArRo6MIz+mkSPTyJFx5Mc0cmQc+bEtFFvIdz179tQPP/ygvXv3au7cuerdu7f27t2rIUOG6MaNG4Z25cuX15w5c3T//n0NHDhQkZGROT7WtWvXstyevpJhQVoy/sG9tMKxWInSFo7EepEj48iPaeTINHJkHPkxjRwZR35sC8UWLMbZ2Vlt27bVxIkTFRgYqNjYWO3fvz9Dm5YtW+qrr77S3bt3FRQUpOPHj+foGJcvX9bFi5lX69m3b58kyde34DwQsKKnjyQp/maMhSOxXuTIOPJjGjkyjRwZR35MI0fGkR/bQrGFfLVnzx6lpqZm2p5+31aJEiUy7WvevLnmzZunBw8eaNCgQQoPD8/28ZKTkzVt2rQMxzx+/LhCQkLk7Oysdu3a5eJdWEZ5t5qqXL2+9LfFPtK5eTVQ/RY9JUlnj/2Wz5FZD3JkHPkxjRyZRo6MIz+mkSPjyE/BwmqEyFejRo1SyZIl1ahRI3l4eCg1NVX79u3TkSNH1KBBAwUEBGT5uqZNm2r+/Pl66aWXNHjwYH399ddq0KCByeP5+PjowIED6t27tx5//HHDc7aSkpL04YcfqnRp2xmCL13WVU07vagH9+J1M+a8Hty9raLFSsqpXCWVLusqSTp5cLMunAi1cKSWQ46MIz+mkSPTyJFx5Mc0cmQc+SlYKLaQr9544w399ttvOnbsmLZv365ixYrJw8NDb775pvr372/0+VmNGzfWggULNHToUA0ZMkRff/21GjVqZPR4ZcqU0dy5czV16lQtX75c9+7dU506dTRy5Ei1bds2j9+decVcjFTE3nWq4OEtp3JuquBeW6mpqbp/56bORfyus0d/07VLpywdpkWRI+PIj2nkyDRyZBz5MY0cGUd+ChaKLeSr/v37q3///ibbbdmyJcvtDRs2NNxvla5Xr17q1avXQ/uqXLmypk2blrNArdDdW9d1bM8aS4dh1ciRceTHNHJkGjkyjvyYRo6MIz8FC/dsAQAAAIAZUGwBAAAAgBlQbAEAAACAGXDPFgqsEydOWDoEAAAAFGKMbAEAAACAGVBsAQAAAIAZUGwBAAAAgBlQbAEAAACAGVBsAQAAAIAZUGwBAAAAgBlQbAEAAACAGVBsAQAAAIAZUGwBAAAAgBlQbAEAAACAGVBsAQAAAIAZUGwBAAAAgBlQbAEAAACAGVBsAQAAAIAZUGwBAAAAgBlQbAEAAACAGVBsAQAAAIAZUGwBAAAAgBlQbAEAAACAGVBsAQAAAIAZUGwBAAAAgBlQbAEAAACAGVBsAQAAAIAZUGwBAAAAgBlQbAEAAACAGVBsAQAAAIAZ2KWmpqZaOggAAAAAKGgcLB0AYGtu34qzdAhWzcm5DDkygvyYRo6MIz+mkSPTyJFx5Mc0J+cylg7BJjCNEAAAAADMgGILAAAAAMyAYgsAAAAAzIBiCwAAAADMgGILAAAAAMyAYgsAAAAAzIBiCwAAAADMgGILAAAAAMyAYgsAAAAAzIBiCwAAAADMgGILAAAAAMyAYgsAAAAAzIBiCwAAAADMgGILAAAAAMyAYgsAAAAAzIBiCwAAAADMgGILAAAAAMyAYgsAAAAAzIBiCwAAAADMgGILAAAAAMyAYgsAAAAAzIBiCwAAAADMgGILAAAAAMyAYgsAAAAAzMDB0gFAateunSRpy5YtJtv++OOPGj9+vCZPnqxevXqZOzSzysn7xl+SkpIUHLxUP61frwsXLqh48WKqV6++Bg8KUuPGjS0dnlUgR8aRH9PIkXHkxzRyZBz5MY0cFQyMbKHQ+PHHH+Xj46Mff/zR0qHkWlJSkl597Z+a8fnniomJUatWLVW7dm3t2bNH/3
hlhNb99JOlQ7Q4cmQc+TGNHBlHfkwjR8aRH9PIUcHByJaN6dixoxo2bChXV1dLhwILWPTddwoNDZWPj4/mzP5Czs7
"text/plain": [
"<Figure size 936x936 with 2 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"import pathlib\n",
"x = np.array([1, 2, 3, 4, 5, 6,7,8])\n",
"idx=0\n",
"best_df = pd.DataFrame.from_dict(best_dic, orient='index')\n",
"# Set up the matplotlib figure\n",
"f, ax = plt.subplots(figsize=(13, 13))\n",
"#sns.set(font_scale=1.8)\n",
"# Generate a custom diverging colormap\n",
"cmap = sns.diverging_palette(200, 55, as_cmap=True)\n",
"g = sns.heatmap(best_df, cmap=cmap, center=0,\n",
" square=True, linewidths=.5, cbar_kws={\"shrink\": .267}, annot=True,xticklabels=x)\n",
"plt.xlabel('#N Layers')\n",
"plt.ylabel('')\n",
"#plt.legend(bbox_to_anchor=(0.55, 0.65), prop={'size': 13})\n",
"#plt.grid()\n",
"plt.savefig(pathlib.Path('op_select_best').with_suffix('.pdf'), bbox_inches='tight')\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0ef09034",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "mct",
"language": "python",
"name": "mct"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.13"
}
},
"nbformat": 4,
"nbformat_minor": 5
}