Update sync codes

This commit is contained in:
D-X-Y 2021-04-14 01:04:46 +08:00
parent c82c7e9f3f
commit cd253112ee
5 changed files with 220 additions and 44 deletions

View File

@ -48,4 +48,5 @@ jobs:
ls
python --version
python -m pytest ./tests/test_basic_space.py -s
python -m pytest ./tests/test_synthetic.py -s
shell: bash

View File

@ -3,3 +3,5 @@
##################################################
from .get_dataset_with_transform import get_datasets, get_nas_search_loaders
from .SearchDatasetWrap import SearchDataset
from .synthetic_adaptive_environment import SynAdaptiveEnv

View File

@ -0,0 +1,84 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 #
#####################################################
import numpy as np
from typing import Optional
import torch.utils.data as data
class SynAdaptiveEnv(data.Dataset):
"""The synethtic dataset for adaptive environment."""
def __init__(
self,
max_num_phase: int = 100,
interval: float = 0.1,
max_scale: float = 4,
offset_scale: float = 1.5,
mode: Optional[str] = None,
):
self._max_num_phase = max_num_phase
self._interval = interval
self._times = np.arange(0, np.pi * self._max_num_phase, self._interval)
xmin, xmax = self._times.min(), self._times.max()
self._inputs = []
self._total_num = len(self._times)
for i in range(self._total_num):
scale = (i + 1.0) / self._total_num * max_scale
sin_scale = (i + 1.0) / self._total_num * 0.7
sin_scale = -4 * (sin_scale - 0.5) ** 2 + 1
# scale = -(self._times[i] - (xmin - xmax) / 2) + max_scale
self._inputs.append(
np.sin(self._times[i] * sin_scale) * (offset_scale - scale)
)
self._inputs = np.array(self._inputs)
# Training Set 60%
num_of_train = int(self._total_num * 0.6)
# Validation Set 20%
num_of_valid = int(self._total_num * 0.2)
# Test Set 20%
num_of_set = self._total_num - num_of_train - num_of_valid
all_indexes = list(range(self._total_num))
if mode is None:
self._indexes = all_indexes
elif mode.lower() in ("train", "training"):
self._indexes = all_indexes[:num_of_train]
elif mode.lower() in ("valid", "validation"):
self._indexes = all_indexes[num_of_train : num_of_train + num_of_valid]
elif mode.lower() in ("test", "testing"):
self._indexes = all_indexes[num_of_train + num_of_valid :]
else:
raise ValueError("Unkonwn mode of {:}".format(mode))
# transformation function
self._transform = None
def set_transform(self, fn):
self._transform = fn
def __iter__(self):
self._iter_num = 0
return self
def __next__(self):
if self._iter_num >= len(self):
raise StopIteration
self._iter_num += 1
return self.__getitem__(self._iter_num - 1)
def __getitem__(self, index):
assert 0 <= index < len(self), "{:} is not in [0, {:})".format(index, len(self))
index = self._indexes[index]
value = float(self._inputs[index])
if self._transform is not None:
value = self._transform(value)
return index, float(self._times[index]), value
def __len__(self):
return len(self._indexes)
def __repr__(self):
return "{name}({cur_num:}/{total} elements)".format(
name=self.__class__.__name__, cur_num=self._total_num, total=len(self)
)

View File

@ -5,17 +5,39 @@
"execution_count": 1,
"id": "filled-multiple",
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"The root path: /Users/xuanyidong/Desktop/AutoDL-Projects\n",
"The library path: /Users/xuanyidong/Desktop/AutoDL-Projects/lib\n"
]
}
],
"source": [
"#\n",
"# %matplotlib notebook\n",
"import os, sys\n",
"import torch\n",
"from pathlib import Path\n",
"import numpy as np\n",
"import matplotlib\n",
"from matplotlib import cm\n",
"matplotlib.use(\"agg\")\n",
"import matplotlib.pyplot as plt\n",
"import matplotlib.ticker as ticker"
"import matplotlib.ticker as ticker\n",
"\n",
"\n",
"__file__ = os.path.dirname(os.path.realpath(\"__file__\"))\n",
"root_dir = (Path(__file__).parent / \"..\").resolve()\n",
"lib_dir = (root_dir / \"lib\").resolve()\n",
"print(\"The root path: {:}\".format(root_dir))\n",
"print(\"The library path: {:}\".format(lib_dir))\n",
"assert lib_dir.exists(), \"{:} does not exist\".format(lib_dir)\n",
"if str(lib_dir) not in sys.path:\n",
" sys.path.insert(0, str(lib_dir))\n",
"\n",
"from datasets import SynAdaptiveEnv\n",
"from xlayers.super_core import SuperMLPv1"
]
},
{
@ -25,49 +47,97 @@
"metadata": {},
"outputs": [],
"source": [
"def optimize_fn(xs, ys, test_sets):\n",
" xs = torch.FloatTensor(xs).view(-1, 1)\n",
" ys = torch.FloatTensor(ys).view(-1, 1)\n",
" \n",
" model = SuperMLPv1(1, 10, 1, torch.nn.ReLU)\n",
" optimizer = torch.optim.Adam(\n",
" model.parameters(),\n",
" lr=0.01, weight_decay=1e-4, amsgrad=True\n",
" )\n",
" for _iter in range(100):\n",
" preds = model(ys)\n",
"\n",
" optimizer.zero_grad()\n",
" loss = torch.nn.functional.mse_loss(preds, ys)\n",
" loss.backward()\n",
" optimizer.step()\n",
" \n",
" with torch.no_grad():\n",
" answers = []\n",
" for test_set in test_sets:\n",
" test_set = torch.FloatTensor(test_set).view(-1, 1)\n",
" preds = model(test_set).view(-1).numpy()\n",
" answers.append(preds.tolist())\n",
" return answers\n",
"\n",
"def f(x):\n",
" return np.cos( 0.5 * x + 0.)\n",
"\n",
"def get_data(mode):\n",
" dataset = SynAdaptiveEnv(mode=mode)\n",
" times, xs, ys = [], [], []\n",
" for i, (_, t, x) in enumerate(dataset):\n",
" times.append(t)\n",
" xs.append(x)\n",
" dataset.set_transform(f)\n",
" for i, (_, _, y) in enumerate(dataset):\n",
" ys.append(y)\n",
" return times, xs, ys\n",
"\n",
"def visualize_syn(save_path):\n",
" save_dir = (save_path / '..').resolve()\n",
" save_dir.mkdir(parents=True, exist_ok=True)\n",
" \n",
" dpi, width, height = 50, 2000, 1000\n",
" dpi, width, height = 40, 2000, 900\n",
" figsize = width / float(dpi), height / float(dpi)\n",
" LabelSize, font_gap = 30, 4\n",
" LabelSize, LegendFontsize, font_gap = 40, 40, 5\n",
" \n",
" fig = plt.figure(figsize=figsize)\n",
" \n",
" times = np.arange(0, np.pi * 100, 0.1)\n",
" num = len(times)\n",
" x = []\n",
" for i in range(num):\n",
" scale = (i + 1.) / num * 4\n",
" value = times[i] * scale\n",
" x.append(np.sin(value) * (1.3 - scale))\n",
" x = np.array(x)\n",
" y = np.cos( x * x - 0.3 * x )\n",
" times, xs, ys = get_data(None)\n",
" \n",
" def draw_ax(cur_ax, xaxis, yaxis, xlabel, ylabel,\n",
" alpha=0.1, color='k', linestyle='-', legend=None, plot_only=False):\n",
" if legend is not None:\n",
" cur_ax.plot(xaxis[:1], yaxis[:1], color=color, label=legend)\n",
" cur_ax.plot(xaxis, yaxis, color=color, linestyle=linestyle, alpha=alpha, label=None)\n",
" if not plot_only:\n",
" cur_ax.set_xlabel(xlabel, fontsize=LabelSize)\n",
" cur_ax.set_ylabel(ylabel, rotation=0, fontsize=LabelSize)\n",
" for tick in cur_ax.xaxis.get_major_ticks():\n",
" tick.label.set_fontsize(LabelSize - font_gap)\n",
" tick.label.set_rotation(10)\n",
" for tick in cur_ax.yaxis.get_major_ticks():\n",
" tick.label.set_fontsize(LabelSize - font_gap)\n",
" \n",
" cur_ax = fig.add_subplot(2, 1, 1)\n",
" cur_ax.plot(times, x)\n",
" cur_ax.set_xlabel(\"time\", fontsize=LabelSize)\n",
" cur_ax.set_ylabel(\"x\", fontsize=LabelSize)\n",
" for tick in cur_ax.xaxis.get_major_ticks():\n",
" tick.label.set_fontsize(LabelSize - font_gap)\n",
" tick.label.set_rotation(30)\n",
" for tick in cur_ax.yaxis.get_major_ticks():\n",
" tick.label.set_fontsize(LabelSize - font_gap)\n",
" \n",
" \n",
" draw_ax(cur_ax, times, xs, \"time\", \"x\", alpha=1.0, legend=None)\n",
"\n",
" cur_ax = fig.add_subplot(2, 1, 2)\n",
" cur_ax.plot(times, y)\n",
" cur_ax.set_xlabel(\"time\", fontsize=LabelSize)\n",
" cur_ax.set_ylabel(\"f(x)\", fontsize=LabelSize)\n",
" for tick in cur_ax.xaxis.get_major_ticks():\n",
" tick.label.set_fontsize(LabelSize - font_gap)\n",
" tick.label.set_rotation(30)\n",
" for tick in cur_ax.yaxis.get_major_ticks():\n",
" tick.label.set_fontsize(LabelSize - font_gap)\n",
" \n",
" # fig.tight_layout()\n",
" # plt.subplots_adjust(wspace=0.05)#, hspace=0.4)\n",
" draw_ax(cur_ax, times, ys, \"time\", \"y\", alpha=0.1, legend=\"ground truth\")\n",
" \n",
" train_times, train_xs, train_ys = get_data(\"train\")\n",
" draw_ax(cur_ax, train_times, train_ys, None, None, alpha=1.0, color='r', legend=None, plot_only=True)\n",
" \n",
" valid_times, valid_xs, valid_ys = get_data(\"valid\")\n",
" draw_ax(cur_ax, valid_times, valid_ys, None, None, alpha=1.0, color='g', legend=None, plot_only=True)\n",
" \n",
" test_times, test_xs, test_ys = get_data(\"test\")\n",
" draw_ax(cur_ax, test_times, test_ys, None, None, alpha=1.0, color='b', legend=None, plot_only=True)\n",
" \n",
" # optimize MLP models\n",
" [train_preds, valid_preds, test_preds] = optimize_fn(train_xs, train_ys, [train_xs, valid_xs, test_xs])\n",
" draw_ax(cur_ax, train_times, train_preds, None, None,\n",
" alpha=1.0, linestyle='--', color='r', legend=\"MLP\", plot_only=True)\n",
" draw_ax(cur_ax, valid_times, valid_preds, None, None,\n",
" alpha=1.0, linestyle='--', color='g', legend=None, plot_only=True)\n",
" draw_ax(cur_ax, test_times, test_preds, None, None,\n",
" alpha=1.0, linestyle='--', color='b', legend=None, plot_only=True)\n",
"\n",
" plt.legend(loc=1, fontsize=LegendFontsize)\n",
"\n",
" fig.savefig(save_path, dpi=dpi, bbox_inches=\"tight\", format=\"pdf\")\n",
" plt.close(\"all\")\n",
" # plt.show()"
@ -94,14 +164,6 @@
"print('The Desktop is at: {:}'.format(desktop_dir))\n",
"visualize_syn(desktop_dir / 'tot-synthetic-v0.pdf')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "romantic-ordinance",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {

27
tests/test_synthetic.py Normal file
View File

@ -0,0 +1,27 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 #
#####################################################
# pytest tests/test_synthetic.py -s #
#####################################################
import sys, random
import unittest
import pytest
from pathlib import Path
lib_dir = (Path(__file__).parent / ".." / "lib").resolve()
print("library path: {:}".format(lib_dir))
if str(lib_dir) not in sys.path:
sys.path.insert(0, str(lib_dir))
from datasets import SynAdaptiveEnv
class TestSynAdaptiveEnv(unittest.TestCase):
"""Test the synethtic adaptive environment."""
def test_simple(self):
dataset = SynAdaptiveEnv()
for i, (idx, t, x) in enumerate(dataset):
assert i == idx, "First loop: {:} vs {:}".format(i, idx)
for i, (idx, t, x) in enumerate(dataset):
assert i == idx, "Second loop: {:} vs {:}".format(i, idx)