120 lines
25 KiB
Plaintext
120 lines
25 KiB
Plaintext
|
{
|
||
|
"cells": [
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 1,
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"The config dir path: /Users/xuanyidong/Desktop/AutoDL-Projects/configs\n"
|
||
|
]
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"#####################################################\n",
|
||
|
"# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.06 #\n",
|
||
|
"#####################################################\n",
|
||
|
"import os, sys, math\n",
|
||
|
"import numpy as np\n",
|
||
|
"from pathlib import Path\n",
|
||
|
"\n",
|
||
|
"import matplotlib.pyplot as plt\n",
|
||
|
"\n",
|
||
|
"import torch\n",
|
||
|
"from xautodl.xmisc.scheduler_utils import CosineParamScheduler, MultiStepParamScheduler\n",
|
||
|
"from xautodl.xmisc.scheduler_utils import LRMultiplier, WarmupParamScheduler\n",
|
||
|
"\n",
|
||
|
"# Jupyter does not define __file__, so resolve paths from the kernel's working directory\n",
|
||
|
"# (assumed to be this notebook's folder -- TODO confirm). Do not assign __file__ itself:\n",
|
||
|
"# tools such as multiprocessing's \"spawn\" start method read __main__.__file__.\n",
|
||
|
"notebook_dir = Path(os.path.realpath(os.getcwd()))\n",
|
||
|
"\n",
|
||
|
"# Same location as before: two directories above the notebook folder, then configs/.\n",
|
||
|
"config_dir = (notebook_dir / \"..\" / \"..\" / \"configs\").resolve()\n",
|
||
|
"print(\"The config dir path: {:}\".format(config_dir))\n",
|
||
|
"\n",
|
||
|
"def draw(steps, lrs, title=\"Plot Cosine Decayed LR with Warmup\"):\n",
|
||
|
"    \"\"\"Plot a learning-rate curve.\n",
|
||
|
"\n",
|
||
|
"    Args:\n",
|
||
|
"      steps: x-axis values (the iteration index of each recorded LR).\n",
|
||
|
"      lrs: learning rates recorded at the corresponding `steps`.\n",
|
||
|
"      title: figure title; defaults to the caption originally used here.\n",
|
||
|
"    \"\"\"\n",
|
||
|
"    plt.close()  # close the current figure before creating a new one\n",
|
||
|
"    dpi, width, height = 200, 1400, 800\n",
|
||
|
"    # NOTE(review): `dpi` only sets the size in inches (7x4); the figure is created with\n",
|
||
|
"    # matplotlib's default dpi. Pass dpi=dpi to plt.figure if 1400x800 px was intended.\n",
|
||
|
"    figsize = width / float(dpi), height / float(dpi)\n",
|
||
|
"    fig = plt.figure(figsize=figsize)\n",
|
||
|
"    ax = fig.add_subplot(111)\n",
|
||
|
"\n",
|
||
|
"    # Draw through the explicit Axes instead of the implicit pyplot \"current axes\".\n",
|
||
|
"    ax.plot(steps, lrs)\n",
|
||
|
"    ax.set_title(title)\n",
|
||
|
"    ax.set_xlabel(\"steps\")\n",
|
||
|
"    ax.set_ylabel(\"learning rate\")\n",
|
||
|
"    plt.show()"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 2,
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"data": {
|
||
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAbkAAAEWCAYAAAD7HukTAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/Z1A+gAAAACXBIWXMAAAsTAAALEwEAmpwYAABBaUlEQVR4nO3dd3hUZdrH8e+dShIg9N577yBgryCK2Luuuq6rq7vruq762nvX1bUhuuraxY4KKhaKYqH3FkILvQbIhExmcr9/nBMdYxIGyORMuT/XNVdmzpyZ+c3JZO485zzneURVMcYYY+JRktcBjDHGmEixImeMMSZuWZEzxhgTt6zIGWOMiVtW5IwxxsQtK3LGGGPilhU5UyERmSQil3udozIicrOIvOh1jmglIqtE5Divc5QSkT0i0q6S+6Mqr4l9VuQSnPulUuh++WwSkZdFpOZ+PkcbEVERSdnHep1E5F0R2Soi+SIyT0SuE5HkA82vqverapUXYhE5SkRK3O2yR0TyRGSsiAys6tfyioi8IiL3VnCfikiB+97XicjjB/N7KqWqNVU1d1+vvy8ikuJmGxSy7AI3d9llSw42t4ldVuQMwEhVrQn0AwYCt1b1C4hIe+AnYC3QU1WzgbOAAUCtqn69KrLe3S61gMHAEmCqiBzrbaxq09t9/0cC5wCXeZznF6oaAH7AyVbqCJzfUdllU/bnucVh341xwn6R5hequg6YAPQoe5+IJInIrSKyWkQ2i8irIpLt3l36JbLT/e96SDlPfxcwTVWvU9UN7ustVdXzVXWn+xqniMhCEdnp7irtGvL6N7otit0isrS00IjInSLyunu9tEX5BxFZ47YYbynzHm4SkRUiss1tmdULY7uoquap6u3Ai8BDIc/ZRUQmish2N9fZIfdliMhj7jbLF5HvRCTDve9dEdnoLp8iIt3d5QPdFnVKyPOcISJzwnkPInKR+3rbQt/7wVDVHOB7oE9594vIpSLyScjtHBEZG3J7rYj0ca+riHQQkSuAC4Ab3M/MJyFP2cdt5eeLyDsiUqOCaFNwilipw3F+N2WXTRGRuiLyqYhsEZEd7vUWIRknich9IvI94APauVn/IiLL3c/dPSLSXkR+EJFd7rZPcx9/iYh8V2a7qIh0cK+/IiKj3c/KbhGZLCKtK3hfpgpZkTO/EJGWwAhgdjl3X+JejgbaATWBp937Sr9U6ri7o34o5/HHAe9V8tqdgLeAa4GGwHjgExFJE5HOwDXAQFWtBQwDVlXyVg4DOgPHAreHFMu/Aafi/KffDNgBPFPJ85TnA6CfiGSJSBYwEXgTaAScBzxbWrCAR4H+wFCgHnADUOLeNwHo6D5uFvAGgKpOB7YBx4e85oXAa/t6DyLSDXgOuMi9rz7QgoMkIl1wikVOBatMBg53C3BTIBU41H1s6WdlXugDVHUMznt+2P3MjAy5+2xgONAW6IXzuSvPFOBQ93UbAFnAWGBQyLIu7npJwMtAa6AVUMivn99SFwFX4LTcV7vLhuP8Dgfj/P7G4BTnljj/DJ5XQbbyXADcAzQA5rjv30SaqtolgS84xWIPsBPnD/tZIMO9bxJwuXv9a+AvIY/rDBQDKUAbQIGUSl6nGBheyf23AWNDbicB64CjgA7AZpxCmVrmcXcCr7vXS3O0CLn/Z+Bc9/pi4NiQ+5qWvody8hwF5JWzvIv7Gs1xduFNLXP/88Adbv5CnF1++/od1HGfM9u9fSPwhnu9Hk7Loum+3gNwO/B2yH1ZgB84roLXfQW4t4L7FNgFFLjX3wLSK3kPa3F2d5+LUwh+drfVpcC4Ms/boaLXdz+PF4bcfhgYXcFr1gD2Ar2B00K22Y8hy1ZW8Ng+wI6Q25OAu8vZBoeG3J4J3Bhy+zHgCff6JcB35Tw+9L2G/m5qAkGgZVX/Tdvlt5dKOwqYhHGqqn61j3Wa8et/t7jXU4DGYb7GNpwv5LCeX1VLRGQt0FxVJ4nItTgFrb
uIfAFcp6rrK3iujSHXfThfKOD8F/+hiJSE3B9038O6MN9Hc5wvr53u8x0iIjtD7k/BaXU1wPkSXlH2CcTpwHEfzjHJhvzaumsA5AOvA4vF6QB0Nk4h3RDGe2iGU2wAUNUCEdkW5vsqTz83/1nAgzhFs6iCdSfz6z8kk3G2z5HAEPf2/ij7+2tW3kqquldEfsbZk9AOmOre9V3IsikAIpIJ/BunZVbXXa+WiCSratC9/cu2C7Ep5HphObebhPmefvP8qrpHRLZT5ndmqp7trjThWo/zBVuqFRDA+aMPZyqLr4Azwn1+ERGcXULrAFT1TVU9zF1HCTkuth/WAieqap2QSw11jkWG6zRglqoWuM83uczz1VTVq4CtOK2M9uU8x/nAKJyWaTZOCxRA4Jdjoz+4r3URv+6q3Nd72ICzzZwnc77Y6+/He/sddYx189xeyaqlRe5w9/pknCJ3JBUXuaqYAqX0uNzh/FrkpoYsKz1e/E+cvQ+HqGptft3FLlWUpwDILL0hIuUVv9DfTU2cVnpF/6iZKmJFzoTrLeAfItLW/QO9H3hHnV5uW3BaIxWe/4SzC2+oiDxS+gXgdkB4XUTq4BxLOUlEjhWRVJwvpSJgmoh0FpFjRCQdp3AU4rRe9tdo4L7SA/4i0lBERu3rQeJoLiJ3AJcDN7t3fQp0cjt7pLqXgSLSVVVLgJeAx0WkmYgki8gQ9z3Uct/bNpwvxvvLedlXcY4B9QQ+DPM9vAecLCKHuR0i7mbff+PJIlIj5JJWwXoPAldU8OUNTiE7GmdXdx5OoRmOU2TLO8YLzj9IlX1mwjHFfd2WwCJ32Xc4BbcPvxa5Wjifm53idNS54yBft6y5OHsZ+rgdZe4sZ50RIb+be4CfVNVacRFmRc6E6yWcFsUUYCVOsfkrgKr6cHa/fS9Oz8jBZR+sqitwdl21ARaKSD7wPjAD2K2qS3E6WDyF0woaiXNqgx9Ix/mS3YqzK6sRvxaa/fEkMA74UkR24xy7OaSS9ZuJyB6cY5bTcQrOUar6pfuedgMn4ByHWu9me8jNC3A9MN997Hb3viScArYap5W6yM1R1oe4uybdVuM+34OqLgSuxukIswGnU0rePrbJTThf/qWXb8pbSVXn4xSyf1Vw/zKc7TTVvb0LyAW+D9kdWNZ/gW7uZ+ajfeSsyDSc1vBPqu6BMNVtOP94bVbV5e56TwAZOJ+hH4HPD/D1yuW+/7tx9lgsxym0Zb2JU1y343RmuaAqM5jyifu5MMZEGRFZAfw5jOOlJsqJyCs4HZmq/BxUUzlryRkThUTkDJxjROW2rIwx4bHelcZEGRGZBHQDLnKP7RljDpDtrjTGGBO3bHelMcaYuBVzuysbNGigbdq08TqGMcaYKDJz5sytqtqw7PKYK3Jt2rRhxowZXscwxhgTRURkdXnLbXelMcaYuGVFzhhjTNyyImeMMSZuWZEzxhgTt6zIGWOMiVsRLXIiMlxElopIjojcVM79R4kzxf0c91LZVB7GGGPMfonYKQTuxJDPAMfjjIQ+XUTGqeqiMqtOVdWTI5XDGGNM4orkeXKDgBxVzQUQkbdxJoosW+SMMcYkAFVl3c5CNuTvZUP+XjbmF5KZlsKFg1vv+8EHKJJFrjm/ndY9j/Ln7hoiInNx5uO63p0T6zdE5ArgCoBWrVpFIOrv5W7ZQ3FQ6dS4Js4k1cYYY/ZHsESZsWo7P6/czuy1O5m9Zgc7fMW/Wad3i+yYLXLlVYayo0HPAlqr6h4RGQF8BHT83YNUxwBjAAYMGFAtI0pf/uoMcrcU0K5BFsN7NGFEz6b0aJ5dHS9tjDExK1iiTFuxlQkLNvLlwo1s3eNHBDo0rMnx3RrTq0UdWtXLpGl2DZpk16BWjdSI5olkkcvDmZK+VAuc1tov3NmDS6+PF5FnRaSBqm6NYK6wbC/w07tlHWrXSOH5Kbk8O2kFvVvW4dKhbRjRsylpKdYx1RhjSuUXFjN2+lr+98Mq8nYUkpmWzN
FdGjGiR1MO69iA7IzIFrOKRLLITQc6ikhbYB1wLnB+6Aoi0gTYpKoqIoNwentui2CmsPn8QYa0q89NJ3ZhR4GfT+a
|
||
|
"text/plain": [
|
||
|
"<Figure size 504x288 with 1 Axes>"
|
||
|
]
|
||
|
},
|
||
|
"metadata": {
|
||
|
"needs_background": "light"
|
||
|
},
|
||
|
"output_type": "display_data"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"# Dummy (empty) parameter: only the learning-rate trajectory matters in this demo.\n",
|
||
|
"p = torch.nn.Parameter(torch.zeros(0))\n",
|
||
|
"opt = torch.optim.SGD([p], lr=5)\n",
|
||
|
"multiplier = WarmupParamScheduler(\n",
|
||
|
"    CosineParamScheduler(0.1, 0.0001),\n",
|
||
|
"    warmup_factor=0.001,\n",
|
||
|
"    warmup_length=0.05,\n",
|
||
|
"    warmup_method=\"linear\",\n",
|
||
|
")\n",
|
||
|
"total = 100\n",
|
||
|
"scheduler = LRMultiplier(opt, multiplier, total)\n",
|
||
|
"steps, lrs = [], []\n",
|
||
|
"\n",
|
||
|
"# Runs for 2 * total iterations, presumably to show the schedule past its configured horizon.\n",
|
||
|
"for _iter in range(total * 2):\n",
|
||
|
"    opt.zero_grad()  # reset grads each step so they do not accumulate across iterations\n",
|
||
|
"    p.sum().backward()\n",
|
||
|
"    opt.step()\n",
|
||
|
"    lrs.append(opt.param_groups[0][\"lr\"])\n",
|
||
|
"    steps.append(_iter)\n",
|
||
|
"\n",
|
||
|
"    scheduler.step()\n",
|
||
|
"draw(steps, lrs)"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": null,
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": []
|
||
|
}
|
||
|
],
|
||
|
"metadata": {
|
||
|
"kernelspec": {
|
||
|
"display_name": "Python 3",
|
||
|
"language": "python",
|
||
|
"name": "python3"
|
||
|
},
|
||
|
"language_info": {
|
||
|
"codemirror_mode": {
|
||
|
"name": "ipython",
|
||
|
"version": 3
|
||
|
},
|
||
|
"file_extension": ".py",
|
||
|
"mimetype": "text/x-python",
|
||
|
"name": "python",
|
||
|
"nbconvert_exporter": "python",
|
||
|
"pygments_lexer": "ipython3",
|
||
|
"version": "3.8.8"
|
||
|
}
|
||
|
},
|
||
|
"nbformat": 4,
|
||
|
"nbformat_minor": 4
|
||
|
}
|