119 lines
2.8 KiB
Plaintext
119 lines
2.8 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 8,
|
|
"id": "german-madonna",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Implementation notes for \"A Tutorial on Bayesian Optimization\" (P. I. Frazier, 2018, arXiv:1807.02811)\n",
|
|
"import numpy as np\n",
|
|
"\n",
|
|
"def get_data():\n",
|
|
" return np.random.random(2) * 10\n",
|
|
"\n",
|
|
"def f(x):\n",
|
|
" return float(np.power((x[0] * 3 - x[1]), 3) - np.exp(x[1]) + np.power(x[0], 2))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 12,
|
|
"id": "broke-citizenship",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Kernels typically have the property that points closer in the input space are more strongly correlated\n",
|
|
"# i.e., if ||x1 - x2|| < ||x1 - x3|| for some norm, then sigma0(x1, x2) > sigma0(x1, x3).\n",
|
|
"# A simple and commonly used kernel is the power exponential (Gaussian) kernel:\n",
|
|
"def sigma0(x1, x2, alpha0=1, alpha=[1,1]):\n",
|
|
" \"\"\"alpha could be a vector\"\"\"\n",
|
|
" power = np.array(alpha, dtype=np.float32) * np.power(np.array(x1)-np.array(x2), 2)\n",
|
|
" return alpha0 * np.exp( -np.sum(power) )\n",
|
|
"\n",
|
|
"# The most common choice for the mean function is a constant value, i.e. mu0(x) = mu for every x.\n",
|
|
"def mu0(x, mu):\n",
|
|
" return mu"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 13,
|
|
"id": "aerial-carnival",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"K = 5\n",
|
|
"X = np.array([get_data() for i in range(K)])\n",
|
|
"mu = np.mean(X, axis=0)\n",
|
|
"mu0_over_K = [mu0(x, mu) for x in X]"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 14,
|
|
"id": "polished-discussion",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"sigma0_over_KK = []\n",
|
|
"for i in range(K):\n",
|
|
" sigma0_over_KK.append(np.array([sigma0(X[i], X[j]) for j in range(K)]))\n",
|
|
"sigma0_over_KK = np.array(sigma0_over_KK)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 16,
|
|
"id": "comic-jesus",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"(20, 20)\n",
|
|
"1.1038803861344952e-06\n",
|
|
"1.1038803861344952e-06\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"print(sigma0_over_KK.shape)\n",
|
|
"print(sigma0_over_KK[1][2])\n",
|
|
"print(sigma0_over_KK[2][1])"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "statistical-wrist",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": []
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "Python 3",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.8.8"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 5
|
|
}
|