MeCo/zero-cost-nas/notebooks/nas_examples.ipynb

388 lines
63 KiB
Plaintext
Raw Permalink Normal View History

2023-05-04 07:09:03 +02:00
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Zero-Cost NAS Examples\n",
"\n",
"In this notebook, we provide examples of how to use zero-cost proxies within NAS algorithms. Specifically, we implement **random search** and **aging evolution search**, each with and without zero-cost warmup, plus aging evolution with zero-cost move proposal. All examples use precomputed SynFlow proxy scores and NAS-Bench-201 CIFAR-100 accuracies.\n",
"\n",
"_Note: the results in our ICLR paper were produced by a different, internal AutoML tool that we are not releasing publicly. While it should be possible to make this notebook match that tool exactly, we have not done so here; these examples only showcase the potential advantages of using zero-cost proxies._"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"try to create the NAS-Bench-201 api from ../data/NAS-Bench-201-v1_0-e61699.pth\n"
]
}
],
"source": [
"from nas_201_api import NASBench201API as API\n",
"\n",
"# NAS-Bench-201 benchmark file; the constructor prints the path it loads from.\n",
"NB201_PATH = '../data/NAS-Bench-201-v1_0-e61699.pth'\n",
"api = API(NB201_PATH)\n",
"# Turn off the API's verbose output (presumably per-query logging).\n",
"api.verbose = False"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"import pickle\n",
"\n",
"# Precomputed zero-cost proxy results ('cf100' in the name presumably means CIFAR-100).\n",
"# The search code indexes synflow_proxy with spec_to_idx, so the file is assumed to hold\n",
"# one record per NAS-Bench-201 architecture, in API index order.\n",
"# NOTE: only unpickle files you trust -- pickle.load can execute arbitrary code.\n",
"PROXY_RESULTS_PATH = '../results_release/nasbench2/nb2_cf100_seed42_dlrandom_dlinfo1_initwnone_initbnone.p'\n",
"\n",
"def load_proxy_scores(path, measure='synflow'):\n",
"    '''Return [record['logmeasures'][measure] for each record pickled back-to-back in path.'''\n",
"    scores = []\n",
"    with open(path, 'rb') as f:  # closed even if a record is malformed\n",
"        while True:\n",
"            try:\n",
"                record = pickle.load(f)\n",
"            except EOFError:  # no more records in the stream\n",
"                break\n",
"            scores.append(record['logmeasures'][measure])\n",
"    return scores\n",
"\n",
"synflow_proxy = load_proxy_scores(PROXY_RESULTS_PATH)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"# Maps between benchmark index and spec.\n",
"# A spec is our encoding of a NAS-Bench-201 cell: a length-6 list with one entry per edge,\n",
"# each entry an operation index in 0-4 (see _opname_to_index), e.g. [3, 3, 1, 0, 2, 4].\n",
"# This is sufficient to describe any NAS-Bench-201 cell.\n",
"\n",
"_opname_to_index = {\n",
"    'none': 0,\n",
"    'skip_connect': 1,\n",
"    'nor_conv_1x1': 2,\n",
"    'nor_conv_3x3': 3,\n",
"    'avg_pool_3x3': 4\n",
"}\n",
"\n",
"def get_spec_from_arch_str(arch_str):\n",
"    '''Convert a NAS-Bench-201 architecture string into a spec (list of op indices).\n",
"\n",
"    Nodes are separated by '+', edges within a node by '|', and each edge is written\n",
"    'op~input_node'; only the op is kept. Example:\n",
"    '|nor_conv_3x3~0|+|nor_conv_3x3~0|skip_connect~1|+|none~0|nor_conv_1x1~1|avg_pool_3x3~2|'\n",
"    -> [3, 3, 1, 0, 2, 4]\n",
"    '''\n",
"    nodes = arch_str.split('+')\n",
"    nodes = [node[1:-1].split('|') for node in nodes]  # strip the outer '|' delimiters\n",
"    nodes = [[op_and_input.split('~')[0] for op_and_input in node] for node in nodes]\n",
"    return [_opname_to_index[op] for node in nodes for op in node]\n",
"\n",
"idx_to_spec = {i: get_spec_from_arch_str(arch_str) for i, arch_str in enumerate(api)}\n",
"# Lists are unhashable, so the str() form of a spec is used as the key.\n",
"spec_to_idx = {str(spec): idx for idx, spec in idx_to_spec.items()}\n",
"assert len(spec_to_idx) == len(idx_to_spec), 'specs must map one-to-one onto benchmark indices'"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"import random\n",
"import copy\n",
"\n",
"# Seed Python's RNG so searches are reproducible. random.seed must be called:\n",
"# assigning to it would replace the function and leave the RNG unseeded.\n",
"random.seed(42)\n",
"\n",
"# Number of candidate operations per edge; spec entries range over 0..NUM_OPS-1.\n",
"NUM_OPS = len(_opname_to_index)\n",
"# All benchmark specs, built once so random_spec() does not rebuild the list per call.\n",
"_all_specs = list(idx_to_spec.values())\n",
"\n",
"def random_spec():\n",
"    '''Return a uniformly random benchmark spec (the stored list itself, not a copy).'''\n",
"    return random.choice(_all_specs)\n",
"\n",
"def mutate_spec(old_spec):\n",
"    '''Return a copy of old_spec with one random edge switched to a different random op.'''\n",
"    idx_to_change = random.randrange(len(old_spec))\n",
"    entry_to_change = old_spec[idx_to_change]\n",
"    possible_entries = [x for x in range(NUM_OPS) if x != entry_to_change]\n",
"    new_spec = copy.copy(old_spec)\n",
"    new_spec[idx_to_change] = random.choice(possible_entries)\n",
"    return new_spec\n",
"\n",
"def mutate_spec_zero_cost(old_spec):\n",
"    '''Zero-cost-guided mutation.\n",
"\n",
"    Scores every single-edge neighbour of old_spec with the SynFlow proxy. With\n",
"    probability 0.75 it returns the highest-scoring neighbour (ties go to the one\n",
"    enumerated last); otherwise it returns a uniformly random neighbour.\n",
"    '''\n",
"    possible_specs = []\n",
"    for idx_to_change, entry_to_change in enumerate(old_spec):\n",
"        for new_entry in range(NUM_OPS):\n",
"            if new_entry == entry_to_change:\n",
"                continue\n",
"            new_spec = copy.copy(old_spec)\n",
"            new_spec[idx_to_change] = new_entry\n",
"            possible_specs.append((synflow_proxy[spec_to_idx[str(new_spec)]], new_spec))\n",
"    # sorted() is stable, so [-1] keeps the last of any equally-scored neighbours.\n",
"    best_new_spec = sorted(possible_specs, key=lambda item: item[0])[-1][1]\n",
"    if random.random() > 0.75:  # explore: 25% of the time pick a random neighbour\n",
"        best_new_spec = random.choice(possible_specs)[1]\n",
"    return best_new_spec"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"def random_combination(iterable, sample_size):\n",
"    '''Pick sample_size items from distinct positions of iterable, uniformly at random.\n",
"\n",
"    The picked items keep their original relative order, like a single random draw\n",
"    from itertools.combinations. Raises ValueError if sample_size exceeds the item count.\n",
"    '''\n",
"    items = tuple(iterable)\n",
"    positions = random.sample(range(len(items)), sample_size)\n",
"    positions.sort()\n",
"    return tuple(items[p] for p in positions)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"def run_evolution_search(max_trained_models=1000,\n",
"                         pool_size=64,\n",
"                         tournament_size=10,\n",
"                         zero_cost_warmup=0,\n",
"                         zero_cost_move=False):\n",
"    '''Aging evolution over NAS-Bench-201 specs, scored on CIFAR-100.\n",
"\n",
"    Training a model is simulated by one benchmark lookup (api.get_more_info).\n",
"\n",
"    Args:\n",
"        max_trained_models: search budget, i.e. total number of benchmark lookups.\n",
"        pool_size: population size; after each mutation the oldest member is removed.\n",
"        tournament_size: members sampled per tournament; the one with the highest\n",
"            validation accuracy is mutated. Must not exceed pool_size.\n",
"        zero_cost_warmup: if > 0, rank this many random specs by the SynFlow proxy and\n",
"            seed the population with the top-ranked ones (random specs fill any slots\n",
"            left over when zero_cost_warmup < pool_size).\n",
"        zero_cost_move: if True, use mutate_spec_zero_cost instead of mutate_spec.\n",
"\n",
"    Returns:\n",
"        (best_valids, best_tests): best validation / test accuracy seen after each\n",
"        trained model; both lists have length max_trained_models.\n",
"    '''\n",
"    best_valids, best_tests = [0.0], [0.0]  # 0.0 sentinels, dropped before returning\n",
"\n",
"    def evaluate(spec):\n",
"        # Look up spec in the benchmark, extend both best-so-far curves and\n",
"        # return the spec's validation accuracy.\n",
"        info = api.get_more_info(spec_to_idx[str(spec)], 'cifar100', iepoch=None, hp='200', is_random=False)\n",
"        best_valids.append(max(best_valids[-1], info['valid-accuracy']))\n",
"        best_tests.append(max(best_tests[-1], info['test-accuracy']))\n",
"        return info['valid-accuracy']\n",
"\n",
"    # Optional zero-cost warmup: random specs ranked by proxy score, best first.\n",
"    zero_cost_pool = []\n",
"    if zero_cost_warmup > 0:\n",
"        for _ in range(zero_cost_warmup):\n",
"            spec = random_spec()\n",
"            zero_cost_pool.append((synflow_proxy[spec_to_idx[str(spec)]], spec))\n",
"        zero_cost_pool.sort(key=lambda item: item[0], reverse=True)\n",
"\n",
"    # Fill the initial population without exceeding the budget.\n",
"    pool = []  # (validation accuracy, spec) tuples, oldest first\n",
"    num_trained_models = 0\n",
"    for i in range(min(pool_size, max_trained_models)):\n",
"        spec = zero_cost_pool[i][1] if i < len(zero_cost_pool) else random_spec()\n",
"        pool.append((evaluate(spec), spec))\n",
"        num_trained_models += 1\n",
"\n",
"    # Evolve: mutate the tournament winner, add the child, kill the oldest individual.\n",
"    while num_trained_models < max_trained_models:\n",
"        sample = random_combination(pool, tournament_size)\n",
"        # sorted() is stable, so ties go to the member that appears last in the sample.\n",
"        best_spec = sorted(sample, key=lambda item: item[0])[-1][1]\n",
"        if zero_cost_move:\n",
"            new_spec = mutate_spec_zero_cost(best_spec)\n",
"        else:\n",
"            new_spec = mutate_spec(best_spec)\n",
"        pool.append((evaluate(new_spec), new_spec))\n",
"        pool.pop(0)\n",
"        num_trained_models += 1\n",
"\n",
"    return best_valids[1:], best_tests[1:]"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"def run_random_search(max_trained_models=1000,\n",
"                      zero_cost_warmup=0):\n",
"    '''Random search over NAS-Bench-201 specs, scored on CIFAR-100.\n",
"\n",
"    If zero_cost_warmup > 0, that many random specs are ranked by the SynFlow proxy\n",
"    and evaluated best-first; after they run out, specs are drawn uniformly at random.\n",
"\n",
"    Returns:\n",
"        (best_valids, best_tests): best validation / test accuracy seen after each\n",
"        trained model (benchmark lookup); both lists have length max_trained_models.\n",
"    '''\n",
"    # Optional zero-cost warmup: random specs ranked by proxy score, best first.\n",
"    zero_cost_pool = []\n",
"    if zero_cost_warmup > 0:\n",
"        for _ in range(zero_cost_warmup):\n",
"            spec = random_spec()\n",
"            zero_cost_pool.append((synflow_proxy[spec_to_idx[str(spec)]], spec))\n",
"        zero_cost_pool.sort(key=lambda item: item[0], reverse=True)\n",
"\n",
"    best_valids, best_tests = [0.0], [0.0]  # 0.0 sentinels, dropped before returning\n",
"    for i in range(max_trained_models):\n",
"        spec = zero_cost_pool[i][1] if i < len(zero_cost_pool) else random_spec()\n",
"        info = api.get_more_info(spec_to_idx[str(spec)], 'cifar100', iepoch=None, hp='200', is_random=False)\n",
"        best_valids.append(max(best_valids[-1], info['valid-accuracy']))\n",
"        best_tests.append(max(best_tests[-1], info['test-accuracy']))\n",
"\n",
"    return best_valids[1:], best_tests[1:]"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 10/10 [00:18<00:00, 1.80s/it]\n"
]
}
],
"source": [
"from tqdm import tqdm\n",
"\n",
"num_rounds = 10     # independent repetitions of every search variant\n",
"length = 300        # search budget: trained models (benchmark lookups) per run\n",
"warmup_size = 3000  # random specs ranked by the zero-cost proxy in the warmup variants\n",
"\n",
"# Best-so-far test-accuracy curves, one list per round, for each search variant.\n",
"ae, ae_warmup, ae_move, rand, rand_warmup = [], [], [], [], []\n",
"for _ in tqdm(range(num_rounds)):\n",
"    # Each search returns (best_valids, best_tests); only the test curves are plotted.\n",
"    ae.append(run_evolution_search(max_trained_models=length)[1])\n",
"    ae_warmup.append(run_evolution_search(max_trained_models=length, zero_cost_warmup=warmup_size)[1])\n",
"    ae_move.append(run_evolution_search(max_trained_models=length, zero_cost_move=True)[1])\n",
"    rand.append(run_random_search(max_trained_models=length)[1])\n",
"    rand_warmup.append(run_random_search(max_trained_models=length, zero_cost_warmup=warmup_size)[1])"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"\n",
"def plot_experiment(exp_list, title, ylim=(70, 73.6)):\n",
"    '''Plot the mean best-so-far test accuracy of each experiment with an IQR band.\n",
"\n",
"    Args:\n",
"        exp_list: list of (runs, label) pairs; runs is a list of equal-length\n",
"            best-so-far accuracy curves, one per search round.\n",
"        title: figure title.\n",
"        ylim: (bottom, top) y-axis limits, or None to autoscale. The default is\n",
"            the range used for the plots in this notebook.\n",
"    '''\n",
"    fig, ax = plt.subplots()\n",
"    for runs, label in exp_list:\n",
"        curves = np.array(runs)\n",
"        # Entry k of a curve is the best result after k + 1 trained models.\n",
"        x = np.arange(1, curves.shape[1] + 1)\n",
"        mean = np.mean(curves, axis=0)\n",
"        q_25 = np.quantile(curves, .25, axis=0)\n",
"        q_75 = np.quantile(curves, .75, axis=0)\n",
"        line, = ax.plot(x, mean, label=label)\n",
"        # Shade the interquartile range in the same colour as its mean curve.\n",
"        ax.fill_between(x, q_25, q_75, color=line.get_color(), alpha=0.1)\n",
"    ax.grid()\n",
"    ax.set_xlabel('Trained Models')\n",
"    ax.set_ylabel('Test Accuracy')\n",
"    if ylim is not None:\n",
"        ax.set_ylim(*ylim)\n",
"    ax.legend()\n",
"    ax.set_title(title)\n",
"    plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAYgAAAEWCAYAAAB8LwAVAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+j8jraAAAgAElEQVR4nOydeXhb1Zn/P68273YcJ3H2HbKHQBZ2GsoObShdmDBMobT8aClM6b4MM5QuDEz3ocAUWlpgoIEWytKyFKbFUHaSQAgkISGrndVbbEu29vf3x5UV2ZZkSZZk2T6f59Fj3XPvPfdcyTrf+77vOe8RVcVgMBgMht7YBrsBBoPBYChMjEAYDAaDIS5GIAwGg8EQFyMQBoPBYIiLEQiDwWAwxMUIhMFgMBjiYgTCUHCIiFtEZg52O1JFRO4RkR8O4Pz3RGRlFptUMIjIdBFREXEMdlsM6WMEwpAzRKRORFpFpCid81S1XFV35KA9N4pIICJA3a/D2b5OP23oIyaqukBV63JwrQUi8qyItIjIYRFZJyLnZ/s6huGLEQhDThCR6cCpgAKrBrUxPXkoIkDdr1GD3aAc8mfgOWA8MA74EtCe7YsY62D4YgTCkCsuA14D7gEuj90hIjUi8mcRaReRN0XkhyLyUsx+FZHZkff3iMjtIvKkiHSIyOsiMivm2LNF5H0RaRORO0TkBRG5Mt3Gisj/iMhPepU9LiJfjbyfF7GIDkdcQnFFT0Q+E3svsfcjIlcBlwLfjFgvf47s3yUiZ0beF4nIL0RkX+T1i24LTERWikiDiHxNRA6JyH4RuSJBO8YAM4Bfq6o/8npZVWM/54+IyNuRe3pFRBbH7Pu2iGyPfOabROSiXvf4soj8XESagRtFpEREfioiuyPfxUsiUhLTpEtFZI+INInI9f1/I4ZCwAiEIVdcBjwQeZ0jIrUx+24HPFhPtpfTS0DisBr4HlANfADcBNFO8GHgO0AN8D5wUobtXQP8k4hIpO5q4GzgQRFxYj2NP4v1JP6vwAMiMiedC6jqXVifx48i1stH4xx2PXACsAQ4BlgB/HvM/vFAFTAJ+Bxwe6StvWnG+qzuF5GP9fr8EZFjgd8Cn8f67O4EnohxB27HsgCrsD77+0VkQkwVxwM7gFqs7+MnwFKsz3808E0gHHP8KcAc4AzgBhGZF+8zMhQWRiAMWUdETgGmAX9Q1XVYnc0/R/bZgU8A31XVTlXdBNzbT5WPquobqhrE6mCXRMrPB95T1T9F9t0KHOinrosjT8zdr+cj5f/AcoedGtn+JPCqqu7D6rDLgVsiT+J/B/4CXJLCx5EulwLfV9VDqtqI1Tl/OmZ/ILI/oKpPAW6sjrcHaiVZOx3YBfwU2C8iL4rIUZFDrgLuVNXXVTWkqvcCvsi9oqp/VNV9qhpW1YeAbVhi1c0+Vf1l5HP3AZ8FrlPVvZH6XlFVX8zx31PVLlXdAGzAEj9DgWMEwpALLgeeVdWmyPbvOWIljAUcQH3M8bHv4xHb6XdiddYAE2PPjXSKDf3U9QdVHRXzOj3m3Ac50un/M5YYRa+jqrFPxLuxnuKzzcRI3bHXmRiz3RzplLuJ/Tx6oKoNqnqtqs7CEmwPcF9k9zTga7FiCUzpvpaIXBbjfjoMLATGxFQf+52NAYqxHgQSkeg7NBQwRiAMWSXid74Y+JCIHBCRA8BXgGNE5BigEQgCk2NOm5Lh5fbH1hNxD01OfHi/rAE+KSLTsFwoj0TK9wFTRCT29zIV2BunDg9QGtOm8b3295c+eR9W5x17nX39Nz05qlqP5dpbGCmqB27qJZalqromcv+/Bq4FaiKB/HcBSXAfTYAXmIVhWGEEwpBtPgaEgPlYrqAlwDwsF85lqhoC/oQV2CwVkblY8YpMeBJYFPGxO4BrsHz0GaGqb2F1dr8B/qqq3UNgX8d66v2miDjFmrPwUSyLozcbgAUiskREioEbe+0/CC
Sb47EG+HcRGRuJsdwA3J/uvYhItYh8LxIct0Xq+izWwAGwBOALInK8WJSJyAUiUgGUYQlAY6SuKzgiLH2IWFa/BX4mIhNFxC4iJ0qaw5sNhYcRCEO2uRz4naruUdUD3S/gNqyRLA6sJ9MqLLfD/2J1ir6ENSYg4sL6FPAjrKDsfGBtP3X9k/ScB+EWkXEx+38PnBn5230dP5YgnIclIHdgid2WOG3aCnwf+D8sv/1LvQ65G5gfcd08Fqd9P4zcwzvARmB9pCxd/MD0SDvasSwAH/CZSDvXAv8P63tpxQpod+/bhBW3eBVL0BYBL/dzva9H2vsm0AL8F6Z/GfKIWTDIMNiIyH8B41W1v9FM/dVjw4pBXKqqz/d3vMFgSI5ReEPeEZG5IrI44tpYgTVc89EM6zpHREZF3Bn/huUnf62f0wwGQwqYGZCGwaACy600EcuF8VPg8QzrOhHLHeQCNgEfU9WubDTSYBjpGBeTwWAwGOJiXEwGg8FgiMuwcjGNGTNGp0+fntG5Ho+HsrKy7DZokDD3UngMl/sAcy+FSqb3sm7duiZVHRtv37ASiOnTp7N27dqMzq2rq2PlypXZbdAgYe6l8Bgu9wHmXgqVTO9FRHYn2mdcTAaDwWCIS84siEimy4diimZizQqtAS7EyvR4CPhMJCFa7/NDWBNvAPaoaiGtKWAwGAzDnpwJhKq+TyTrZiSD516sse6tqvofkfIvYYnGF+JU0aWqS+KUGwzDluauZtr8bYPdjKQcCBxgR1vWF/wbFAZ8L55m8HVkr0EZYrc5c1JvvmIQZwDbVbW3r6s754uhkFCFQGfSQ7qCXjTBV1dkc2G32bPUljD4PdmpK9fYi8Ce+U+qw9/B2Q+fjT/sz2KjckS8JCFDlWFwLzUh5caZt2W93nwJxGqsiVEAiMhNWAna2rBy1sejWETWYmX+vEVV436NYq3SdRVAbW0tdXV1GTXQ7XZnfG6hMfB7UQgHk+4OaijhbrvYEUm4Oy3cHi91L72Sncpyjc1Bz4SnR0jlO2nwN+AP+/lw5YeZ6pqa/fYlIZTGfCi/14uruDiHrckfye7FEXAzrukVbAn+18s9uwnZijg05kRErOeqDw6H8Sf+aeQMp82Vkz4s5xPlRMSFla54gaoe7LXvO0Cxqn43znmTVHWviMwE/g6coarJ8s2zbNkyNaOYsnAvPndSszkQDtDqS7y08eiiUTiyZEHUrX2XlcsSJhItLMrGgD2+qZ/Kd/Jiw4tc87druP/8+zlmbP7W0/EGQrR1BVI+fuPaV1m07MQctih/JLuXiv/7JiUb7ydcMTHu/rCrnPZz/ptg7TFUFDv4y4b9fPORd5hQVYwtW09IKVJT7uKrC4OZjmJap6rL4u3LhwVxHrC+tzhEeAB4CugjEKq6N/J3h4jUAceSfEESQ7YIJXdxBJJZF5D3H0dB4G2Hdx6EYPzPbtqunVD3etIqDnVsA2DchkfA8UzWm5iQYIiyUOoPijP21VMWGCJWXT8kvBcNU/LeGroWXUrHmT9OWkcwFOb+13Zz7yu7WTipkj9fewoyCL+BXHhA8iEQl9DTvXSUqm6LbF4I9EmZHFljt1NVfZE89idjpXQ25INQ8qfJYDi5DW2TETh6+m/fg/efTLh7BliLfyahcVQlVI9izEv/nc2W9Uu6zqJy6Lnm3RAm2b2EXRV0Lrum3zoeWb+Xnz23lTKXnZsuWjgo4pArcioQIlIGnIW1MHo3t0SGwIaxvpovRI5dBnxBVa/EWmDmThEJY83VuCWSo96QS8Jh8ByynKm9iBWFgCYWkCH/2/jgb7BhTf/HxaJh2PMKnPYNWPlvPXa1dQXwBkMpuWV2v30zo/a/QMtXklsag00huJgeXtfAP7Y2Rrd7/Mf2+veNHUzR+1/b3d5GWWVVZF/PnYrAX5qB5h7n9f51bD3Ywcmzarj/yuOHlThAjgVCVT1Y8x5iyz6R4Ni1wJWR969gLVJiyCcajisOAK3+w4l29U
CG8txLTxM8801wVUB5bXrnzrkATv4y2HrefxgBsR15JaGpq4kxxeP6PW7QSeFe0uFQh5fDnanHQHY3d/LjZ7cxdXQ
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"aging_evolution_runs = [\n",
"    (ae, 'AE'),\n",
"    (ae_warmup, 'AE + warmup (3000)'),\n",
"    (ae_move, 'AE + move'),\n",
"]\n",
"plot_experiment(aging_evolution_runs, title='Aging Evolution Search')"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAYgAAAEWCAYAAAB8LwAVAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+j8jraAAAgAElEQVR4nO3de3xcdZ34/9d7cmnSJE2vBEjpvdTSFlIILcit/QFCEVoV0LIuUrwgLq4Lrl+F1QUUL9VFZF3cBQQEEVtQ6VYFZLESoFIoKRTS0nubQu9tesmluc3M+/fHOUkn6UwyM5kzt7yfj0cemXP/fDKT857P5Xw+oqoYY4wx3flSnQBjjDHpyQKEMcaYsCxAGGOMCcsChDHGmLAsQBhjjAnLAoQxxpiwLEAYE4GI3C0iv0l1OpJFRB4Xke+nOh0mfViAMBlFRGpFpFlEGkVkj3tTK051uvpKRP5NRLa5+dohIk+nOk3GWIAwmegqVS0GKoDpwB0pTk+fiMgNwPXAJW6+KoFlHlwnN9HnNNnNAoTJWKq6B3gRJ1AAICK3i8gWEWkQkfdF5JMh2xaIyHIRuVdEDrnf2OeEbB8rIq+4x74EDA+9nojMFZG1InJYRKpEZHLItloR+X8i8p6INInIoyJSJiIvuOf7q4gMiZCVs4EXVXVLR75U9eGQc5e659stIjtF5PsikuNuGy8ifxOROhE5ICJPicjgbun6loi8BzSJSK6InC8ir7v5+FBEFoSkZYiIPOem+U0RGR/Le2KyiwUIk7FEZCQwB9gcsnoLcAFQCnwX+I2InBSyfSawAefm/xPgURERd9tvgVXutnuAG0KudSqwCLgVGAE8D/xJRPJDzn01cClwKnAV8ALwb+7+PuBrEbLyBvA5N8BUdtz8QzwO+IEJOCWmjwFf7Ega8CPgZGAycApwd7fjrwM+DgwGyt10/Zebrgpgdci+83H+bkNw/q4/iJBm0x+oqv3YT8b8ALVAI9AAKE5VzOAe9l8NzHNfLwA2h2wb6J7jRGAUzk24KGT7b4HfuK//HXgmZJsP2AnMCknXZ0O2/wH4n5Dlfwb+t4d0fhb4K9AE1AHfcteXAa1AYci+1wEvRzjPJ4B3uv29Ph+yfAewJMKxjwOPhCxfAaxP9XtuP6n7sTpJk4k+oap/FZGLcG7iw4HDACLyOeDrwBh332K6VhXt6XihqkfdwkPHPodUtSlk3+0438jB+Ya+PeTYoIh8iPONvMPekNfNYZYjNqar6lPAUyKSh3OTf0pEVgOHgDxg97GCDj7gQze/ZcB/4pSaStxth7qd/sOQ16fglLIi2RPy+mhPaTbZz6qYTMZS1VdwvvXeCyAio4FfAl8FhqnqYGANTjVMb3bj1L8XhawbFfJ6FzC6Y8GtljoFpxSRMKrarqq/A94DpuLc3FuB4ao62P0ZpKpT3EN+iFMKmqaqg4B/5Pj8hg7Z/CFg7QomKhYgTKa7H7hURM4AinBuhvsBRORGnJtsr1R1O1ANfFdE8kXkfJx2hA7PAB8XkYvdb/n/inPjfr2vGXAbzz8uIiUi4nMbzqcAb6rqbuD/gJ+KyCB3+3i39AROqaEROCIi5cD/6+VyTwGXiMin3QbrYSJS0csxpp+yAGEymqruB34N3Kmq7wM/BVbgVO9MA/4ew+n+AacR+yBwl3vejutswPl2/l/AAZzgcZWqtiUgG/U4jdkf4FSV/QT4iqoud7d/DsgH3sepPvo90NHw/l3gTOAI8BzwbE8XUtUPcNoW/hUnn6uBMxKQB5OFRNUmDDLGGHM8K0EYY4wJywKEMcaYsCxAGGOMCcsChDHGmLCy6kG54cOH65gxY+I6tqmpiaKiot53zACWl/STLfkAy0u6ijcvq1atOqCqI8Jty6oAMWbMGKqrq+M6tqqqilmzZiU2QSlieUk/2ZIPsLykq3jzIiLbI22zKiZjjDFheVaCEJFJQOikJ+OAO4
FhwDwgCOwDFqjqrjDHB4Aad/EDVZ3rVVqNMcYcz7MA4T55WgHgDl+8E1iCMyDav7vrv4YTNG4Oc4pmVbUhAIwxJkWS1QZxMbDFHe8mVMfYOcYYY9JMUobaEJHHgLdV9QF3+Qc448scAWa74+l0P8aPM06MH1ioqv8b4dw3ATcBlJWVnbV48eK40tjY2EhxcXaMbGx5ST/Zkg+wvKSrePMye/bsVapaGXaj1xNO4AwydgAoC7PtDuC7EY4rd3+Pw5n0ZHxv1zrrrLM0Xi+//HLcx6Yby0v6yZZ8qFpe0lW8eQGqNcI9NRm9mObglB72htn2FM40jcdR1Z3u761AFc5Ui8YYY5IkGQHiOpy5fAEQkYkh2+YB67sfICJDRGSA+3o4cB7OUMfGGGOSxNNGand2rkuBL4esXuh2gQ3iTOF4s7tvJXCzqn4RZ/L1h0QkiBPEFqoz1r8xxpgk8TRAqDO/77Bu6yJVKVUDX3Rfv44z2YsxxpgUsSepjTHGhGUBwhhjTFgWIIwxxoRlAcIYY0xYFiCMMcaEZQHCGGNMWBYgjDHGhGUBwhhjTFgWIIwxxoRlAcIYY0xYFiCMMcaEZQHCGGNMWBYgjDHGhGUBwhhjTFieDvedNYLBFF1YIdAGsc4brkFoO+pNkpItk/Oiwa6vWxtTl5ZEsrykn7yBnpzWAkQkquBvcW5OgbZUpyY2GoCWI6lORWJkS140AK0NqU5FYlhe0k9ugTen9eSs2SAYgObDqU6FMcakjLVBRKKpqlYyxpj0YAEiEgsQxph+zgJEJBYgjDH9nLVBRBRjzyFj+qK1AV76d2jYk+qU9Gp601HY5E2vmWTLmrwUnwDltyT8tJ4FCBGZBDwdsmoccCcwDJgHBIF9wAJV3RXm+BuA77iL31fVJ7xKa1hWgjB9EQzAzlVOTzhg6MFa2HYw8v5rn4VN/wennAOSnCTGK9AchPwsuKmSRXnJtG6uqroBqAAQkRxgJ7AEOKSq/+6u/xpO0Lg59FgRGQrcBVTifJVfJSJ/VNVDXqX3+AxkcYDQIPgzo+uuL9AK7S0Rtgbh6MH07Ib81qOw9g+di6cDrOnlmJn/BOd9zctUJcR71WuYVTk11clIiKzJS9EIeG15wk+brCqmi4Etqrq92/oiwtflXAa8pKoHAUTkJeByYJGnqQwV68NpqbDtNWg5vivuCft2wLpt4Y9pa4I3H4TG9K/KALgQ4O+pTkWcKr8AEz8GwKp1Wzhr8vjI++YWwPBTk5Qwkw1aA4HO29QAj6rEkxUg5hNycxeRHwCfA44As8PsXw58GLK8w12XPOlegtj9Liz5UthNpwGs7+HYYROh4h9I+7oMYMvOPYwvPzHyDgOHevaQUJ8UDoZRHwVx/sYNO3PgpCz4pmrSRkOzvzMs5JZ4EyBEPf6mLCL5wC5giqru7bbtDqBAVe/qtv4b7vrvu8v/DjSr6r1hzn8TcBNAWVnZWYsXL44rnY2NjRQXFx9bEQzgNJOkp3Fbf83Inc+zavpCgr78LtuONrcysHBA+ANFaCk4AZWcJKSy7xqbWiguSsMAEKNsyQdYXtKFP3Ds3p2bm0djU1PXe1iUZs+evUpVK8NtS0YJYg7wdvfg4HoKeB6nvSHUTmBWyPJIoCrcyVX1YeBhgMrKSp01a1a43XpVVVVFl2ObDkCgPa5zeap+p1NN9O47MPpczp515XG7VFWvYWY21Kvi5CUb6oizJR9geUkHAVUONh1rexsy4mSW//114r3/RZKMAHEdXauXJqrqJndxHuErQ14EfigiQ9zljwF3eJrK7tKximn/Bnhy3rHlGTelLi3GmJQJBJPTRuppgBCRIuBS4Mshqxe6XWCDwHbcHkwiUgncrKpfVNWDInIP8JZ7zPc6GqyTJh0DxHa3tfayH0FBKYy5ILXpMcakRDBJnWg8DRCq2oTz3EPouqsj7FsNfDFk+THgMS/T16N07MW0cxWUjoIpn0x1SowxKZSs25MNtRFOyuZ/6I
EGnQAx8qxUp8QYk2Jedy7qYENtdBcMQntTqlPhOLQNnv0S+FudANFyGMrDdjYwxvQjWVHFlJHSaYapTX+FIztg6rV
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"random_search_runs = [\n",
"    (rand, 'RAND'),\n",
"    (rand_warmup, 'RAND + warmup (3000)'),\n",
"]\n",
"plot_experiment(random_search_runs, title='Random Search')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 4
}