diff --git a/.github/workflows/basic_test.yml b/.github/workflows/basic_test.yml index 5db1a13..af3d3e1 100644 --- a/.github/workflows/basic_test.yml +++ b/.github/workflows/basic_test.yml @@ -48,4 +48,5 @@ jobs: ls python --version python -m pytest ./tests/test_basic_space.py -s + python -m pytest ./tests/test_synthetic.py -s shell: bash diff --git a/lib/datasets/__init__.py b/lib/datasets/__init__.py index f96d0ef..5750ebd 100644 --- a/lib/datasets/__init__.py +++ b/lib/datasets/__init__.py @@ -3,3 +3,5 @@ ################################################## from .get_dataset_with_transform import get_datasets, get_nas_search_loaders from .SearchDatasetWrap import SearchDataset + +from .synthetic_adaptive_environment import SynAdaptiveEnv diff --git a/lib/datasets/synthetic_adaptive_environment.py b/lib/datasets/synthetic_adaptive_environment.py new file mode 100644 index 0000000..4166973 --- /dev/null +++ b/lib/datasets/synthetic_adaptive_environment.py @@ -0,0 +1,84 @@ +##################################################### +# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 # +##################################################### +import numpy as np +from typing import Optional +import torch.utils.data as data + + +class SynAdaptiveEnv(data.Dataset): + """The synethtic dataset for adaptive environment.""" + + def __init__( + self, + max_num_phase: int = 100, + interval: float = 0.1, + max_scale: float = 4, + offset_scale: float = 1.5, + mode: Optional[str] = None, + ): + + self._max_num_phase = max_num_phase + self._interval = interval + + self._times = np.arange(0, np.pi * self._max_num_phase, self._interval) + xmin, xmax = self._times.min(), self._times.max() + self._inputs = [] + self._total_num = len(self._times) + for i in range(self._total_num): + scale = (i + 1.0) / self._total_num * max_scale + sin_scale = (i + 1.0) / self._total_num * 0.7 + sin_scale = -4 * (sin_scale - 0.5) ** 2 + 1 + # scale = -(self._times[i] - (xmin - xmax) / 2) + max_scale + self._inputs.append( + np.sin(self._times[i] * sin_scale) * (offset_scale - scale) + ) + self._inputs = np.array(self._inputs) + # Training Set 60% + num_of_train = int(self._total_num * 0.6) + # Validation Set 20% + num_of_valid = int(self._total_num * 0.2) + # Test Set 20% + num_of_set = self._total_num - num_of_train - num_of_valid + all_indexes = list(range(self._total_num)) + if mode is None: + self._indexes = all_indexes + elif mode.lower() in ("train", "training"): + self._indexes = all_indexes[:num_of_train] + elif mode.lower() in ("valid", "validation"): + self._indexes = all_indexes[num_of_train : num_of_train + num_of_valid] + elif mode.lower() in ("test", "testing"): + self._indexes = all_indexes[num_of_train + num_of_valid :] + else: + raise ValueError("Unkonwn mode of {:}".format(mode)) + # transformation function + self._transform = None + + def set_transform(self, fn): + self._transform = fn + + def __iter__(self): + self._iter_num = 0 + return self + + def __next__(self): + if self._iter_num >= len(self): + raise StopIteration + self._iter_num += 1 + return self.__getitem__(self._iter_num - 1) + + def __getitem__(self, index): + assert 0 <= index < len(self), "{:} is not in [0, {:})".format(index, len(self)) + index = self._indexes[index] + value = float(self._inputs[index]) + if self._transform is not None: + value = self._transform(value) + return index, float(self._times[index]), value + + def __len__(self): + return len(self._indexes) + + def __repr__(self): + return "{name}({cur_num:}/{total} elements)".format( + name=self.__class__.__name__, cur_num=self._total_num, total=len(self) + ) diff --git a/notebooks/TOT/synthetic.ipynb b/notebooks/TOT/synthetic.ipynb index 6173183..10ed8f9 100644 --- a/notebooks/TOT/synthetic.ipynb +++ b/notebooks/TOT/synthetic.ipynb @@ -5,17 +5,39 @@ "execution_count": 1, "id": "filled-multiple", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "The root path: /Users/xuanyidong/Desktop/AutoDL-Projects\n", + "The library path: /Users/xuanyidong/Desktop/AutoDL-Projects/lib\n" + ] + } + ], "source": [ - "#\n", - "# %matplotlib notebook\n", + "import os, sys\n", + "import torch\n", "from pathlib import Path\n", "import numpy as np\n", "import matplotlib\n", "from matplotlib import cm\n", "matplotlib.use(\"agg\")\n", "import matplotlib.pyplot as plt\n", - "import matplotlib.ticker as ticker" + "import matplotlib.ticker as ticker\n", + "\n", + "\n", + "__file__ = os.path.dirname(os.path.realpath(\"__file__\"))\n", + "root_dir = (Path(__file__).parent / \"..\").resolve()\n", + "lib_dir = (root_dir / \"lib\").resolve()\n", + "print(\"The root path: {:}\".format(root_dir))\n", + "print(\"The library path: {:}\".format(lib_dir))\n", + "assert lib_dir.exists(), \"{:} does not exist\".format(lib_dir)\n", + "if str(lib_dir) not in sys.path:\n", + " sys.path.insert(0, str(lib_dir))\n", + "\n", + "from datasets import SynAdaptiveEnv\n", + "from xlayers.super_core import SuperMLPv1" ] }, { @@ -25,49 +47,97 @@ "metadata": {}, "outputs": [], "source": [ + "def optimize_fn(xs, ys, test_sets):\n", + " xs = torch.FloatTensor(xs).view(-1, 1)\n", + " ys = torch.FloatTensor(ys).view(-1, 1)\n", + " \n", + " model = SuperMLPv1(1, 10, 1, torch.nn.ReLU)\n", + " optimizer = torch.optim.Adam(\n", + " model.parameters(),\n", + " lr=0.01, weight_decay=1e-4, amsgrad=True\n", + " )\n", + " for _iter in range(100):\n", + " preds = model(ys)\n", + "\n", + " optimizer.zero_grad()\n", + " loss = torch.nn.functional.mse_loss(preds, ys)\n", + " loss.backward()\n", + " optimizer.step()\n", + " \n", + " with torch.no_grad():\n", + " answers = []\n", + " for test_set in test_sets:\n", + " test_set = torch.FloatTensor(test_set).view(-1, 1)\n", + " preds = model(test_set).view(-1).numpy()\n", + " answers.append(preds.tolist())\n", + " return answers\n", + "\n", + "def f(x):\n", + " return np.cos( 0.5 * x + 0.)\n", + "\n", + "def get_data(mode):\n", + " dataset = SynAdaptiveEnv(mode=mode)\n", + " times, xs, ys = [], [], []\n", + " for i, (_, t, x) in enumerate(dataset):\n", + " times.append(t)\n", + " xs.append(x)\n", + " dataset.set_transform(f)\n", + " for i, (_, _, y) in enumerate(dataset):\n", + " ys.append(y)\n", + " return times, xs, ys\n", + "\n", "def visualize_syn(save_path):\n", " save_dir = (save_path / '..').resolve()\n", " save_dir.mkdir(parents=True, exist_ok=True)\n", " \n", - " dpi, width, height = 50, 2000, 1000\n", + " dpi, width, height = 40, 2000, 900\n", " figsize = width / float(dpi), height / float(dpi)\n", - " LabelSize, font_gap = 30, 4\n", + " LabelSize, LegendFontsize, font_gap = 40, 40, 5\n", " \n", " fig = plt.figure(figsize=figsize)\n", " \n", - " times = np.arange(0, np.pi * 100, 0.1)\n", - " num = len(times)\n", - " x = []\n", - " for i in range(num):\n", - " scale = (i + 1.) / num * 4\n", - " value = times[i] * scale\n", - " x.append(np.sin(value) * (1.3 - scale))\n", - " x = np.array(x)\n", - " y = np.cos( x * x - 0.3 * x )\n", + " times, xs, ys = get_data(None)\n", + " \n", + " def draw_ax(cur_ax, xaxis, yaxis, xlabel, ylabel,\n", + " alpha=0.1, color='k', linestyle='-', legend=None, plot_only=False):\n", + " if legend is not None:\n", + " cur_ax.plot(xaxis[:1], yaxis[:1], color=color, label=legend)\n", + " cur_ax.plot(xaxis, yaxis, color=color, linestyle=linestyle, alpha=alpha, label=None)\n", + " if not plot_only:\n", + " cur_ax.set_xlabel(xlabel, fontsize=LabelSize)\n", + " cur_ax.set_ylabel(ylabel, rotation=0, fontsize=LabelSize)\n", + " for tick in cur_ax.xaxis.get_major_ticks():\n", + " tick.label.set_fontsize(LabelSize - font_gap)\n", + " tick.label.set_rotation(10)\n", + " for tick in cur_ax.yaxis.get_major_ticks():\n", + " tick.label.set_fontsize(LabelSize - font_gap)\n", " \n", " cur_ax = fig.add_subplot(2, 1, 1)\n", - " cur_ax.plot(times, x)\n", - " cur_ax.set_xlabel(\"time\", fontsize=LabelSize)\n", - " cur_ax.set_ylabel(\"x\", fontsize=LabelSize)\n", - " for tick in cur_ax.xaxis.get_major_ticks():\n", - " tick.label.set_fontsize(LabelSize - font_gap)\n", - " tick.label.set_rotation(30)\n", - " for tick in cur_ax.yaxis.get_major_ticks():\n", - " tick.label.set_fontsize(LabelSize - font_gap)\n", - " \n", - " \n", + " draw_ax(cur_ax, times, xs, \"time\", \"x\", alpha=1.0, legend=None)\n", + "\n", " cur_ax = fig.add_subplot(2, 1, 2)\n", - " cur_ax.plot(times, y)\n", - " cur_ax.set_xlabel(\"time\", fontsize=LabelSize)\n", - " cur_ax.set_ylabel(\"f(x)\", fontsize=LabelSize)\n", - " for tick in cur_ax.xaxis.get_major_ticks():\n", - " tick.label.set_fontsize(LabelSize - font_gap)\n", - " tick.label.set_rotation(30)\n", - " for tick in cur_ax.yaxis.get_major_ticks():\n", - " tick.label.set_fontsize(LabelSize - font_gap)\n", - " \n", - " # fig.tight_layout()\n", - " # plt.subplots_adjust(wspace=0.05)#, hspace=0.4)\n", + " draw_ax(cur_ax, times, ys, \"time\", \"y\", alpha=0.1, legend=\"ground truth\")\n", + " \n", + " train_times, train_xs, train_ys = get_data(\"train\")\n", + " draw_ax(cur_ax, train_times, train_ys, None, None, alpha=1.0, color='r', legend=None, plot_only=True)\n", + " \n", + " valid_times, valid_xs, valid_ys = get_data(\"valid\")\n", + " draw_ax(cur_ax, valid_times, valid_ys, None, None, alpha=1.0, color='g', legend=None, plot_only=True)\n", + " \n", + " test_times, test_xs, test_ys = get_data(\"test\")\n", + " draw_ax(cur_ax, test_times, test_ys, None, None, alpha=1.0, color='b', legend=None, plot_only=True)\n", + " \n", + " # optimize MLP models\n", + " [train_preds, valid_preds, test_preds] = optimize_fn(train_xs, train_ys, [train_xs, valid_xs, test_xs])\n", + " draw_ax(cur_ax, train_times, train_preds, None, None,\n", + " alpha=1.0, linestyle='--', color='r', legend=\"MLP\", plot_only=True)\n", + " draw_ax(cur_ax, valid_times, valid_preds, None, None,\n", + " alpha=1.0, linestyle='--', color='g', legend=None, plot_only=True)\n", + " draw_ax(cur_ax, test_times, test_preds, None, None,\n", + " alpha=1.0, linestyle='--', color='b', legend=None, plot_only=True)\n", + "\n", + " plt.legend(loc=1, fontsize=LegendFontsize)\n", + "\n", " fig.savefig(save_path, dpi=dpi, bbox_inches=\"tight\", format=\"pdf\")\n", " plt.close(\"all\")\n", " # plt.show()" @@ -94,14 +164,6 @@ "print('The Desktop is at: {:}'.format(desktop_dir))\n", "visualize_syn(desktop_dir / 'tot-synthetic-v0.pdf')" ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "romantic-ordinance", - "metadata": {}, - "outputs": [], - "source": [] } ], "metadata": { diff --git a/tests/test_synthetic.py b/tests/test_synthetic.py new file mode 100644 index 0000000..03bee19 --- /dev/null +++ b/tests/test_synthetic.py @@ -0,0 +1,27 @@ +##################################################### +# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 # +##################################################### +# pytest tests/test_synthetic.py -s # +##################################################### +import sys, random +import unittest +import pytest +from pathlib import Path + +lib_dir = (Path(__file__).parent / ".." / "lib").resolve() +print("library path: {:}".format(lib_dir)) +if str(lib_dir) not in sys.path: + sys.path.insert(0, str(lib_dir)) + +from datasets import SynAdaptiveEnv + + +class TestSynAdaptiveEnv(unittest.TestCase): + """Test the synethtic adaptive environment.""" + + def test_simple(self): + dataset = SynAdaptiveEnv() + for i, (idx, t, x) in enumerate(dataset): + assert i == idx, "First loop: {:} vs {:}".format(i, idx) + for i, (idx, t, x) in enumerate(dataset): + assert i == idx, "Second loop: {:} vs {:}".format(i, idx)