Add SuperSimpleNorm and update synthetic env
This commit is contained in:
		| @@ -1,9 +1,9 @@ | |||||||
| ##################################################### | ##################################################### | ||||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.02 # | # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.02 # | ||||||
| ##################################################### | ############################################################################ | ||||||
| # python exps/synthetic/baseline.py                 # | # CUDA_VISIBLE_DEVICES=0 python exps/synthetic/baseline.py                 # | ||||||
| ##################################################### | ############################################################################ | ||||||
| import os, sys, copy | import os, sys, copy, random | ||||||
| import torch | import torch | ||||||
| import numpy as np | import numpy as np | ||||||
| import argparse | import argparse | ||||||
| @@ -28,6 +28,8 @@ from datasets import ConstantGenerator, SinGenerator, SyntheticDEnv | |||||||
| from datasets import DynamicQuadraticFunc | from datasets import DynamicQuadraticFunc | ||||||
| from datasets.synthetic_example import create_example_v1 | from datasets.synthetic_example import create_example_v1 | ||||||
|  |  | ||||||
|  | from utils.temp_sync import optimize_fn, evaluate_fn | ||||||
|  |  | ||||||
|  |  | ||||||
| def draw_fig(save_dir, timestamp, scatter_list): | def draw_fig(save_dir, timestamp, scatter_list): | ||||||
|     save_path = save_dir / "{:04d}".format(timestamp) |     save_path = save_dir / "{:04d}".format(timestamp) | ||||||
| @@ -67,28 +69,55 @@ def draw_fig(save_dir, timestamp, scatter_list): | |||||||
| def main(save_dir): | def main(save_dir): | ||||||
|     save_dir = Path(str(save_dir)) |     save_dir = Path(str(save_dir)) | ||||||
|     save_dir.mkdir(parents=True, exist_ok=True) |     save_dir.mkdir(parents=True, exist_ok=True) | ||||||
|     dynamic_env, function = create_example_v1(100, num_per_task=500) |     dynamic_env, function = create_example_v1(100, num_per_task=1000) | ||||||
|  |  | ||||||
|     additional_xaxis = np.arange(-6, 6, 0.1) |     additional_xaxis = np.arange(-6, 6, 0.2) | ||||||
|     for timestamp, dataset in tqdm(dynamic_env, ncols=50): |     models = dict() | ||||||
|         num = dataset.shape[0] |      | ||||||
|         xaxis = dataset[:, 0].numpy() |     for idx, (timestamp, dataset) in enumerate(tqdm(dynamic_env, ncols=50)): | ||||||
|  |         xaxis_all = dataset[:, 0].numpy() | ||||||
|  |         # xaxis_all = np.concatenate((additional_xaxis, xaxis_all)) | ||||||
|         # compute the ground truth |         # compute the ground truth | ||||||
|         function.set_timestamp(timestamp) |         function.set_timestamp(timestamp) | ||||||
|         yaxis = function(xaxis) |         yaxis_all = function.noise_call(xaxis_all) | ||||||
|         # xaxis = np.concatenate((additional_xaxis, xaxis)) |  | ||||||
|  |         # split the dataset | ||||||
|  |         indexes = list(range(xaxis_all.shape[0])) | ||||||
|  |         random.shuffle(indexes) | ||||||
|  |         train_indexes = indexes[:len(indexes)//2] | ||||||
|  |         valid_indexes = indexes[len(indexes)//2:] | ||||||
|  |         train_xs, train_ys = xaxis_all[train_indexes], yaxis_all[train_indexes] | ||||||
|  |         valid_xs, valid_ys = xaxis_all[valid_indexes], yaxis_all[valid_indexes] | ||||||
|  |          | ||||||
|  |         model, loss_fn, train_loss = optimize_fn(train_xs, train_ys) | ||||||
|  |         # model, loss_fn, train_loss = optimize_fn(xaxis_all, yaxis_all) | ||||||
|  |         pred_valid_ys, valid_loss = evaluate_fn(model, valid_xs, valid_ys, loss_fn) | ||||||
|  |         print("[{:03d}] T-{:03d}, train-loss={:.5f}, valid-loss={:.5f}".format(idx, timestamp, train_loss, valid_loss)) | ||||||
|  |  | ||||||
|         # the first plot |         # the first plot | ||||||
|         scatter_list = [] |         scatter_list = [] | ||||||
|         scatter_list.append( |         scatter_list.append( | ||||||
|             { |             { | ||||||
|                 "xaxis": xaxis, |                 "xaxis": valid_xs, | ||||||
|                 "yaxis": yaxis, |                 "yaxis": valid_ys, | ||||||
|                 "color": "k", |                 "color": "k", | ||||||
|                 "s": 10, |                 "s": 10, | ||||||
|                 "alpha": 0.99, |                 "alpha": 0.99, | ||||||
|                 "label": "Timestamp={:02d}".format(timestamp), |                 "label": "Timestamp={:02d}".format(timestamp), | ||||||
|             } |             } | ||||||
|         ) |         ) | ||||||
|  |  | ||||||
|  |         scatter_list.append( | ||||||
|  |             { | ||||||
|  |                 "xaxis": valid_xs, | ||||||
|  |                 "yaxis": pred_valid_ys, | ||||||
|  |                 "color": "r", | ||||||
|  |                 "s": 10, | ||||||
|  |                 "alpha": 0.5, | ||||||
|  |                 "label": "MLP at now" | ||||||
|  |             } | ||||||
|  |         ) | ||||||
|  |          | ||||||
|         draw_fig(save_dir, timestamp, scatter_list) |         draw_fig(save_dir, timestamp, scatter_list) | ||||||
|     print("Save all figures into {:}".format(save_dir)) |     print("Save all figures into {:}".format(save_dir)) | ||||||
|     save_dir = save_dir.resolve() |     save_dir = save_dir.resolve() | ||||||
|   | |||||||
| @@ -33,6 +33,14 @@ class FitFunc(abc.ABC): | |||||||
|     def __call__(self, x): |     def __call__(self, x): | ||||||
|         raise NotImplementedError |         raise NotImplementedError | ||||||
|  |  | ||||||
|  |     def noise_call(self, x, std=0.1): | ||||||
|  |         clean_y = self.__call__(x) | ||||||
|  |         if isinstance(clean_y, np.ndarray): | ||||||
|  |             noise_y = clean_y + np.random.normal(scale=std, size=clean_y.shape) | ||||||
|  |         else: | ||||||
|  |             raise ValueError("Unkonwn type: {:}".format(type(clean_y))) | ||||||
|  |         return noise_y | ||||||
|  |  | ||||||
|     @abc.abstractmethod |     @abc.abstractmethod | ||||||
|     def _getitem(self, x): |     def _getitem(self, x): | ||||||
|         raise NotImplementedError |         raise NotImplementedError | ||||||
|   | |||||||
							
								
								
									
										63
									
								
								lib/utils/temp_sync.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										63
									
								
								lib/utils/temp_sync.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,63 @@ | |||||||
|  | # To be deleted. | ||||||
|  | import copy | ||||||
|  | import torch | ||||||
|  |  | ||||||
|  | from xlayers.super_core import SuperSequential, SuperMLPv1 | ||||||
|  | from xlayers.super_core import SuperSimpleNorm | ||||||
|  | from xlayers.super_core import SuperLinear | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def optimize_fn(xs, ys, device="cpu", max_iter=2000, max_lr=0.1): | ||||||
|  |     xs = torch.FloatTensor(xs).view(-1, 1).to(device) | ||||||
|  |     ys = torch.FloatTensor(ys).view(-1, 1).to(device) | ||||||
|  |  | ||||||
|  |     model = SuperSequential( | ||||||
|  |         SuperSimpleNorm(xs.mean().item(), xs.std().item()), | ||||||
|  |         SuperLinear(1, 200), | ||||||
|  |         torch.nn.LeakyReLU(), | ||||||
|  |         SuperLinear(200, 100), | ||||||
|  |         torch.nn.LeakyReLU(), | ||||||
|  |         SuperLinear(100, 1), | ||||||
|  |     ).to(device) | ||||||
|  |     model.train() | ||||||
|  |     optimizer = torch.optim.Adam( | ||||||
|  |         model.parameters(), lr=max_lr, amsgrad=True | ||||||
|  |     ) | ||||||
|  |     loss_func = torch.nn.MSELoss() | ||||||
|  |     lr_scheduler = torch.optim.lr_scheduler.MultiStepLR( | ||||||
|  |         optimizer, | ||||||
|  |         milestones=[ | ||||||
|  |             int(max_iter * 0.25), | ||||||
|  |             int(max_iter * 0.5), | ||||||
|  |             int(max_iter * 0.75), | ||||||
|  |         ], | ||||||
|  |         gamma=0.3, | ||||||
|  |     ) | ||||||
|  |  | ||||||
|  |     best_loss, best_param = None, None | ||||||
|  |     for _iter in range(max_iter): | ||||||
|  |         preds = model(xs) | ||||||
|  |  | ||||||
|  |         optimizer.zero_grad() | ||||||
|  |         loss = loss_func(preds, ys) | ||||||
|  |         loss.backward() | ||||||
|  |         optimizer.step() | ||||||
|  |         lr_scheduler.step() | ||||||
|  |  | ||||||
|  |         if best_loss is None or best_loss > loss.item(): | ||||||
|  |             best_loss = loss.item() | ||||||
|  |             best_param = copy.deepcopy(model.state_dict()) | ||||||
|  |          | ||||||
|  |         # print('loss={:}, best-loss={:}'.format(loss.item(), best_loss)) | ||||||
|  |     model.load_state_dict(best_param) | ||||||
|  |     return model, loss_func, best_loss | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def evaluate_fn(model, xs, ys, loss_fn, device="cpu"): | ||||||
|  |     with torch.no_grad(): | ||||||
|  |         inputs = torch.FloatTensor(xs).view(-1, 1).to(device) | ||||||
|  |         ys = torch.FloatTensor(ys).view(-1, 1).to(device) | ||||||
|  |         preds = model(inputs) | ||||||
|  |         loss = loss_fn(preds, ys) | ||||||
|  |         preds = preds.view(-1).cpu().numpy() | ||||||
|  |     return preds, loss.item() | ||||||
| @@ -91,6 +91,8 @@ class SuperSequential(SuperModule): | |||||||
|     def abstract_search_space(self): |     def abstract_search_space(self): | ||||||
|         root_node = spaces.VirtualNode(id(self)) |         root_node = spaces.VirtualNode(id(self)) | ||||||
|         for index, module in enumerate(self): |         for index, module in enumerate(self): | ||||||
|  |             if not isinstance(module, SuperModule): | ||||||
|  |                 continue | ||||||
|             space = module.abstract_search_space |             space = module.abstract_search_space | ||||||
|             if not spaces.is_determined(space): |             if not spaces.is_determined(space): | ||||||
|                 root_node.append(str(index), space) |                 root_node.append(str(index), space) | ||||||
| @@ -98,9 +100,9 @@ class SuperSequential(SuperModule): | |||||||
|  |  | ||||||
|     def apply_candidate(self, abstract_child: spaces.VirtualNode): |     def apply_candidate(self, abstract_child: spaces.VirtualNode): | ||||||
|         super(SuperSequential, self).apply_candidate(abstract_child) |         super(SuperSequential, self).apply_candidate(abstract_child) | ||||||
|         for index in range(len(self)): |         for index, module in enumerate(self): | ||||||
|             if str(index) in abstract_child: |             if str(index) in abstract_child: | ||||||
|                 self.__getitem__(index).apply_candidate(abstract_child[str(index)]) |                 module.apply_candidate(abstract_child[str(index)]) | ||||||
|  |  | ||||||
|     def forward_candidate(self, input): |     def forward_candidate(self, input): | ||||||
|         return self.forward_raw(input) |         return self.forward_raw(input) | ||||||
|   | |||||||
| @@ -9,6 +9,7 @@ from .super_module import SuperModule | |||||||
| from .super_container import SuperSequential | from .super_container import SuperSequential | ||||||
| from .super_linear import SuperLinear | from .super_linear import SuperLinear | ||||||
| from .super_linear import SuperMLPv1, SuperMLPv2 | from .super_linear import SuperMLPv1, SuperMLPv2 | ||||||
|  | from .super_norm import SuperSimpleNorm | ||||||
| from .super_norm import SuperLayerNorm1D | from .super_norm import SuperLayerNorm1D | ||||||
| from .super_attention import SuperAttention | from .super_attention import SuperAttention | ||||||
| from .super_transformer import SuperTransformerEncoderLayer | from .super_transformer import SuperTransformerEncoderLayer | ||||||
|   | |||||||
| @@ -3,6 +3,7 @@ | |||||||
| ##################################################### | ##################################################### | ||||||
|  |  | ||||||
| import abc | import abc | ||||||
|  | import warnings | ||||||
| from typing import Optional, Union, Callable | from typing import Optional, Union, Callable | ||||||
| import torch | import torch | ||||||
| import torch.nn as nn | import torch.nn as nn | ||||||
| @@ -45,6 +46,17 @@ class SuperModule(abc.ABC, nn.Module): | |||||||
|  |  | ||||||
|         self.apply(_reset_super_run) |         self.apply(_reset_super_run) | ||||||
|  |  | ||||||
|  |     def add_module(self, name: str, module: Optional[torch.nn.Module]) -> None: | ||||||
|  |         if not isinstance(module, SuperModule): | ||||||
|  |             warnings.warn( | ||||||
|  |                 "Add {:} module, which is not SuperModule, into {:}".format( | ||||||
|  |                     name, self.__class__.__name__ | ||||||
|  |                 ) | ||||||
|  |                 + "\n" | ||||||
|  |                 + "It may cause some functions invalid." | ||||||
|  |             ) | ||||||
|  |         super(SuperModule, self).add_module(name, module) | ||||||
|  |  | ||||||
|     def apply_verbose(self, verbose): |     def apply_verbose(self, verbose): | ||||||
|         def _reset_verbose(m): |         def _reset_verbose(m): | ||||||
|             if isinstance(m, SuperModule): |             if isinstance(m, SuperModule): | ||||||
|   | |||||||
| @@ -82,3 +82,43 @@ class SuperLayerNorm1D(SuperModule): | |||||||
|                 elementwise_affine=self._elementwise_affine, |                 elementwise_affine=self._elementwise_affine, | ||||||
|             ) |             ) | ||||||
|         ) |         ) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class SuperSimpleNorm(SuperModule): | ||||||
|  |     """Super simple normalization.""" | ||||||
|  |  | ||||||
|  |     def __init__(self, mean, std, inplace=False) -> None: | ||||||
|  |         super(SuperSimpleNorm, self).__init__() | ||||||
|  |         self._mean = mean | ||||||
|  |         self._std = std | ||||||
|  |         self._inplace = inplace | ||||||
|  |  | ||||||
|  |     @property | ||||||
|  |     def abstract_search_space(self): | ||||||
|  |         return spaces.VirtualNode(id(self)) | ||||||
|  |  | ||||||
|  |     def forward_candidate(self, input: torch.Tensor) -> torch.Tensor: | ||||||
|  |         # check inputs -> | ||||||
|  |         return self.forward_raw(input) | ||||||
|  |  | ||||||
|  |     def forward_raw(self, input: torch.Tensor) -> torch.Tensor: | ||||||
|  |         if not self._inplace: | ||||||
|  |             tensor = input.clone() | ||||||
|  |         else: | ||||||
|  |             tensor = input | ||||||
|  |         mean = torch.as_tensor(self._mean, dtype=tensor.dtype, device=tensor.device) | ||||||
|  |         std = torch.as_tensor(self._std, dtype=tensor.dtype, device=tensor.device) | ||||||
|  |         if (std == 0).any(): | ||||||
|  |             raise ValueError( | ||||||
|  |                 "std evaluated to zero after conversion to {}, leading to division by zero.".format( | ||||||
|  |                     dtype | ||||||
|  |                 ) | ||||||
|  |             ) | ||||||
|  |         while mean.ndim < tensor.ndim: | ||||||
|  |             mean, std = torch.unsqueeze(mean, dim=0), torch.unsqueeze(std, dim=0) | ||||||
|  |         return tensor.sub_(mean).div_(std) | ||||||
|  |  | ||||||
|  |     def extra_repr(self) -> str: | ||||||
|  |         return "mean={mean}, std={mean}, inplace={inplace}".format( | ||||||
|  |             mean=self._mean, std=self._std, inplace=self._inplace | ||||||
|  |         ) | ||||||
|   | |||||||
| @@ -107,113 +107,6 @@ | |||||||
|     "visualize_env()" |     "visualize_env()" | ||||||
|    ] |    ] | ||||||
|   }, |   }, | ||||||
|   { |  | ||||||
|    "cell_type": "code", |  | ||||||
|    "execution_count": 3, |  | ||||||
|    "id": "supreme-basis", |  | ||||||
|    "metadata": {}, |  | ||||||
|    "outputs": [], |  | ||||||
|    "source": [ |  | ||||||
|     "# def optimize_fn(xs, ys, test_sets):\n", |  | ||||||
|     "#     xs = torch.FloatTensor(xs).view(-1, 1)\n", |  | ||||||
|     "#     ys = torch.FloatTensor(ys).view(-1, 1)\n", |  | ||||||
|     "    \n", |  | ||||||
|     "#     model = SuperSequential(\n", |  | ||||||
|     "#         SuperMLPv1(1, 10, 20, torch.nn.ReLU),\n", |  | ||||||
|     "#         SuperMLPv1(20, 10, 1, torch.nn.ReLU)\n", |  | ||||||
|     "#     )\n", |  | ||||||
|     "#     optimizer = torch.optim.Adam(\n", |  | ||||||
|     "#         model.parameters(),\n", |  | ||||||
|     "#         lr=0.01, weight_decay=1e-4, amsgrad=True\n", |  | ||||||
|     "#     )\n", |  | ||||||
|     "#     for _iter in range(100):\n", |  | ||||||
|     "#         preds = model(ys)\n", |  | ||||||
|     "\n", |  | ||||||
|     "#         optimizer.zero_grad()\n", |  | ||||||
|     "#         loss = torch.nn.functional.mse_loss(preds, ys)\n", |  | ||||||
|     "#         loss.backward()\n", |  | ||||||
|     "#         optimizer.step()\n", |  | ||||||
|     "        \n", |  | ||||||
|     "#     with torch.no_grad():\n", |  | ||||||
|     "#         answers = []\n", |  | ||||||
|     "#         for test_set in test_sets:\n", |  | ||||||
|     "#             test_set = torch.FloatTensor(test_set).view(-1, 1)\n", |  | ||||||
|     "#             preds = model(test_set).view(-1).numpy()\n", |  | ||||||
|     "#             answers.append(preds.tolist())\n", |  | ||||||
|     "#     return answers\n", |  | ||||||
|     "\n", |  | ||||||
|     "# def f(x):\n", |  | ||||||
|     "#     return np.cos( 0.5 * x + x * x)\n", |  | ||||||
|     "\n", |  | ||||||
|     "# def get_data(mode):\n", |  | ||||||
|     "#     dataset = SynAdaptiveEnv(mode=mode)\n", |  | ||||||
|     "#     times, xs, ys = [], [], []\n", |  | ||||||
|     "#     for i, (_, t, x) in enumerate(dataset):\n", |  | ||||||
|     "#         times.append(t)\n", |  | ||||||
|     "#         xs.append(x)\n", |  | ||||||
|     "#     dataset.set_transform(f)\n", |  | ||||||
|     "#     for i, (_, _, y) in enumerate(dataset):\n", |  | ||||||
|     "#         ys.append(y)\n", |  | ||||||
|     "#     return times, xs, ys\n", |  | ||||||
|     "\n", |  | ||||||
|     "# def visualize_syn(save_path):\n", |  | ||||||
|     "#     save_dir = (save_path / '..').resolve()\n", |  | ||||||
|     "#     save_dir.mkdir(parents=True, exist_ok=True)\n", |  | ||||||
|     "    \n", |  | ||||||
|     "#     dpi, width, height = 40, 2000, 900\n", |  | ||||||
|     "#     figsize = width / float(dpi), height / float(dpi)\n", |  | ||||||
|     "#     LabelSize, LegendFontsize, font_gap = 40, 40, 5\n", |  | ||||||
|     "    \n", |  | ||||||
|     "#     fig = plt.figure(figsize=figsize)\n", |  | ||||||
|     "    \n", |  | ||||||
|     "#     times, xs, ys = get_data(None)\n", |  | ||||||
|     "    \n", |  | ||||||
|     "#     def draw_ax(cur_ax, xaxis, yaxis, xlabel, ylabel,\n", |  | ||||||
|     "#                 alpha=0.1, color='k', linestyle='-', legend=None, plot_only=False):\n", |  | ||||||
|     "#         if legend is not None:\n", |  | ||||||
|     "#             cur_ax.plot(xaxis[:1], yaxis[:1], color=color, label=legend)\n", |  | ||||||
|     "#         cur_ax.plot(xaxis, yaxis, color=color, linestyle=linestyle, alpha=alpha, label=None)\n", |  | ||||||
|     "#         if not plot_only:\n", |  | ||||||
|     "#             cur_ax.set_xlabel(xlabel, fontsize=LabelSize)\n", |  | ||||||
|     "#             cur_ax.set_ylabel(ylabel, rotation=0, fontsize=LabelSize)\n", |  | ||||||
|     "#             for tick in cur_ax.xaxis.get_major_ticks():\n", |  | ||||||
|     "#                 tick.label.set_fontsize(LabelSize - font_gap)\n", |  | ||||||
|     "#                 tick.label.set_rotation(10)\n", |  | ||||||
|     "#             for tick in cur_ax.yaxis.get_major_ticks():\n", |  | ||||||
|     "#                 tick.label.set_fontsize(LabelSize - font_gap)\n", |  | ||||||
|     "    \n", |  | ||||||
|     "#     cur_ax = fig.add_subplot(2, 1, 1)\n", |  | ||||||
|     "#     draw_ax(cur_ax, times, xs, \"time\", \"x\", alpha=1.0, legend=None)\n", |  | ||||||
|     "\n", |  | ||||||
|     "#     cur_ax = fig.add_subplot(2, 1, 2)\n", |  | ||||||
|     "#     draw_ax(cur_ax, times, ys, \"time\", \"y\", alpha=0.1, legend=\"ground truth\")\n", |  | ||||||
|     "    \n", |  | ||||||
|     "#     train_times, train_xs, train_ys = get_data(\"train\")\n", |  | ||||||
|     "#     draw_ax(cur_ax, train_times, train_ys, None, None, alpha=1.0, color='r', legend=None, plot_only=True)\n", |  | ||||||
|     "    \n", |  | ||||||
|     "#     valid_times, valid_xs, valid_ys = get_data(\"valid\")\n", |  | ||||||
|     "#     draw_ax(cur_ax, valid_times, valid_ys, None, None, alpha=1.0, color='g', legend=None, plot_only=True)\n", |  | ||||||
|     "    \n", |  | ||||||
|     "#     test_times, test_xs, test_ys = get_data(\"test\")\n", |  | ||||||
|     "#     draw_ax(cur_ax, test_times, test_ys, None, None, alpha=1.0, color='b', legend=None, plot_only=True)\n", |  | ||||||
|     "    \n", |  | ||||||
|     "#     # optimize MLP models\n", |  | ||||||
|     "# #     [train_preds, valid_preds, test_preds] = optimize_fn(train_xs, train_ys, [train_xs, valid_xs, test_xs])\n", |  | ||||||
|     "# #     draw_ax(cur_ax, train_times, train_preds, None, None,\n", |  | ||||||
|     "# #             alpha=1.0, linestyle='--', color='r', legend=\"MLP\", plot_only=True)\n", |  | ||||||
|     "# #     import pdb; pdb.set_trace()\n", |  | ||||||
|     "# #     draw_ax(cur_ax, valid_times, valid_preds, None, None,\n", |  | ||||||
|     "# #             alpha=1.0, linestyle='--', color='g', legend=None, plot_only=True)\n", |  | ||||||
|     "# #     draw_ax(cur_ax, test_times, test_preds, None, None,\n", |  | ||||||
|     "# #             alpha=1.0, linestyle='--', color='b', legend=None, plot_only=True)\n", |  | ||||||
|     "\n", |  | ||||||
|     "#     plt.legend(loc=1, fontsize=LegendFontsize)\n", |  | ||||||
|     "\n", |  | ||||||
|     "#     fig.savefig(save_path, dpi=dpi, bbox_inches=\"tight\", format=\"pdf\")\n", |  | ||||||
|     "#     plt.close(\"all\")\n", |  | ||||||
|     "#     # plt.show()" |  | ||||||
|    ] |  | ||||||
|   } |  | ||||||
|  ], |  ], | ||||||
|  "metadata": { |  "metadata": { | ||||||
|   "kernelspec": { |   "kernelspec": { | ||||||
|   | |||||||
| @@ -17,7 +17,7 @@ gpu=$1 | |||||||
| market=$2 | market=$2 | ||||||
|  |  | ||||||
| # algorithms="NAIVE-V1 NAIVE-V2 MLP GRU LSTM ALSTM XGBoost LightGBM SFM TabNet DoubleE" | # algorithms="NAIVE-V1 NAIVE-V2 MLP GRU LSTM ALSTM XGBoost LightGBM SFM TabNet DoubleE" | ||||||
| algorithms="MLP GRU LSTM ALSTM XGBoost LightGBM SFM TabNet DoubleE" | algorithms="XGBoost LightGBM SFM TabNet DoubleE" | ||||||
|  |  | ||||||
| for alg in ${algorithms} | for alg in ${algorithms} | ||||||
| do | do | ||||||
|   | |||||||
| @@ -1,251 +0,0 @@ | |||||||
| { |  | ||||||
|  "cells": [ |  | ||||||
|   { |  | ||||||
|    "cell_type": "code", |  | ||||||
|    "execution_count": 1, |  | ||||||
|    "id": "filled-multiple", |  | ||||||
|    "metadata": {}, |  | ||||||
|    "outputs": [ |  | ||||||
|     { |  | ||||||
|      "name": "stdout", |  | ||||||
|      "output_type": "stream", |  | ||||||
|      "text": [ |  | ||||||
|       "The root path: /Users/xuanyidong\n", |  | ||||||
|       "The library path: /Users/xuanyidong/lib\n" |  | ||||||
|      ] |  | ||||||
|     }, |  | ||||||
|     { |  | ||||||
|      "ename": "AssertionError", |  | ||||||
|      "evalue": "/Users/xuanyidong/lib does not exist", |  | ||||||
|      "output_type": "error", |  | ||||||
|      "traceback": [ |  | ||||||
|       "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", |  | ||||||
|       "\u001b[0;31mAssertionError\u001b[0m                            Traceback (most recent call last)", |  | ||||||
|       "\u001b[0;32m~/Desktop/AutoDL-Projects\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m     15\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"The root path: {:}\"\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mformat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mroot_dir\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     16\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"The library path: {:}\"\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mformat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlib_dir\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 17\u001b[0;31m \u001b[0;32massert\u001b[0m \u001b[0mlib_dir\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mexists\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"{:} does not exist\"\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mformat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlib_dir\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     18\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mstr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlib_dir\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32min\u001b[0m \u001b[0msys\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpath\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     19\u001b[0m     \u001b[0msys\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpath\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0minsert\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mstr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlib_dir\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", |  | ||||||
|       "\u001b[0;31mAssertionError\u001b[0m: /Users/xuanyidong/lib does not exist" |  | ||||||
|      ] |  | ||||||
|     } |  | ||||||
|    ], |  | ||||||
|    "source": [ |  | ||||||
|     "import os, sys\n", |  | ||||||
|     "import torch\n", |  | ||||||
|     "from pathlib import Path\n", |  | ||||||
|     "import numpy as np\n", |  | ||||||
|     "import matplotlib\n", |  | ||||||
|     "from matplotlib import cm\n", |  | ||||||
|     "# matplotlib.use(\"agg\")\n", |  | ||||||
|     "import matplotlib.pyplot as plt\n", |  | ||||||
|     "import matplotlib.ticker as ticker\n", |  | ||||||
|     "\n", |  | ||||||
|     "\n", |  | ||||||
|     "__file__ = os.path.dirname(os.path.realpath(\"__file__\"))\n", |  | ||||||
|     "root_dir = (Path(__file__).parent / \"..\").resolve()\n", |  | ||||||
|     "lib_dir = (root_dir / \"lib\").resolve()\n", |  | ||||||
|     "print(\"The root path: {:}\".format(root_dir))\n", |  | ||||||
|     "print(\"The library path: {:}\".format(lib_dir))\n", |  | ||||||
|     "assert lib_dir.exists(), \"{:} does not exist\".format(lib_dir)\n", |  | ||||||
|     "if str(lib_dir) not in sys.path:\n", |  | ||||||
|     "    sys.path.insert(0, str(lib_dir))\n", |  | ||||||
|     "\n", |  | ||||||
|     "from datasets import ConstantGenerator, SinGenerator, SyntheticDEnv\n", |  | ||||||
|     "from datasets import DynamicQuadraticFunc\n", |  | ||||||
|     "from datasets.synthetic_example import create_example_v1" |  | ||||||
|    ] |  | ||||||
|   }, |  | ||||||
|   { |  | ||||||
|    "cell_type": "code", |  | ||||||
|    "execution_count": null, |  | ||||||
|    "id": "detected-second", |  | ||||||
|    "metadata": {}, |  | ||||||
|    "outputs": [], |  | ||||||
|    "source": [ |  | ||||||
|     "def visualize_env():\n", |  | ||||||
|     "    \n", |  | ||||||
|     "    dpi, width, height = 10, 800, 400\n", |  | ||||||
|     "    figsize = width / float(dpi), height / float(dpi)\n", |  | ||||||
|     "    LabelSize, LegendFontsize, font_gap = 40, 40, 5\n", |  | ||||||
|     "\n", |  | ||||||
|     "    fig = plt.figure(figsize=figsize)\n", |  | ||||||
|     "\n", |  | ||||||
|     "    dynamic_env, function = create_example_v1(100, num_per_task=250)\n", |  | ||||||
|     "    \n", |  | ||||||
|     "    timeaxis, xaxis, yaxis = [], [], []\n", |  | ||||||
|     "    for timestamp, dataset in dynamic_env:\n", |  | ||||||
|     "        num = dataset.shape[0]\n", |  | ||||||
|     "        timeaxis.append(torch.zeros(num) + timestamp)\n", |  | ||||||
|     "        xaxis.append(dataset[:,0])\n", |  | ||||||
|     "        # compute the ground truth\n", |  | ||||||
|     "        function.set_timestamp(timestamp)\n", |  | ||||||
|     "        yaxis.append(function(dataset[:,0]))\n", |  | ||||||
|     "    timeaxis = torch.cat(timeaxis).numpy()\n", |  | ||||||
|     "    xaxis = torch.cat(xaxis).numpy()\n", |  | ||||||
|     "    yaxis = torch.cat(yaxis).numpy()\n", |  | ||||||
|     "\n", |  | ||||||
|     "    cur_ax = fig.add_subplot(2, 1, 1)\n", |  | ||||||
|     "    cur_ax.scatter(timeaxis, xaxis, color=\"k\", linestyle=\"-\", alpha=0.9, label=None)\n", |  | ||||||
|     "    cur_ax.set_xlabel(\"Time\", fontsize=LabelSize)\n", |  | ||||||
|     "    cur_ax.set_ylabel(\"X\", rotation=0, fontsize=LabelSize)\n", |  | ||||||
|     "    for tick in cur_ax.xaxis.get_major_ticks():\n", |  | ||||||
|     "        tick.label.set_fontsize(LabelSize - font_gap)\n", |  | ||||||
|     "        tick.label.set_rotation(10)\n", |  | ||||||
|     "    for tick in cur_ax.yaxis.get_major_ticks():\n", |  | ||||||
|     "        tick.label.set_fontsize(LabelSize - font_gap)\n", |  | ||||||
|     "    \n", |  | ||||||
|     "    cur_ax = fig.add_subplot(2, 1, 2)\n", |  | ||||||
|     "    cur_ax.scatter(timeaxis, yaxis, color=\"k\", linestyle=\"-\", alpha=0.9, label=None)\n", |  | ||||||
|     "    cur_ax.set_xlabel(\"Time\", fontsize=LabelSize)\n", |  | ||||||
|     "    cur_ax.set_ylabel(\"Y\", rotation=0, fontsize=LabelSize)\n", |  | ||||||
|     "    for tick in cur_ax.xaxis.get_major_ticks():\n", |  | ||||||
|     "        tick.label.set_fontsize(LabelSize - font_gap)\n", |  | ||||||
|     "        tick.label.set_rotation(10)\n", |  | ||||||
|     "    for tick in cur_ax.yaxis.get_major_ticks():\n", |  | ||||||
|     "        tick.label.set_fontsize(LabelSize - font_gap)\n", |  | ||||||
|     "    plt.show()\n", |  | ||||||
|     "\n", |  | ||||||
|     "visualize_env()" |  | ||||||
|    ] |  | ||||||
|   }, |  | ||||||
|   { |  | ||||||
|    "cell_type": "code", |  | ||||||
|    "execution_count": null, |  | ||||||
|    "id": "supreme-basis", |  | ||||||
|    "metadata": {}, |  | ||||||
|    "outputs": [], |  | ||||||
|    "source": [ |  | ||||||
|     "# def optimize_fn(xs, ys, test_sets):\n", |  | ||||||
|     "#     xs = torch.FloatTensor(xs).view(-1, 1)\n", |  | ||||||
|     "#     ys = torch.FloatTensor(ys).view(-1, 1)\n", |  | ||||||
|     "    \n", |  | ||||||
|     "#     model = SuperSequential(\n", |  | ||||||
|     "#         SuperMLPv1(1, 10, 20, torch.nn.ReLU),\n", |  | ||||||
|     "#         SuperMLPv1(20, 10, 1, torch.nn.ReLU)\n", |  | ||||||
|     "#     )\n", |  | ||||||
|     "#     optimizer = torch.optim.Adam(\n", |  | ||||||
|     "#         model.parameters(),\n", |  | ||||||
|     "#         lr=0.01, weight_decay=1e-4, amsgrad=True\n", |  | ||||||
|     "#     )\n", |  | ||||||
|     "#     for _iter in range(100):\n", |  | ||||||
|     "#         preds = model(ys)\n", |  | ||||||
|     "\n", |  | ||||||
|     "#         optimizer.zero_grad()\n", |  | ||||||
|     "#         loss = torch.nn.functional.mse_loss(preds, ys)\n", |  | ||||||
|     "#         loss.backward()\n", |  | ||||||
|     "#         optimizer.step()\n", |  | ||||||
|     "        \n", |  | ||||||
|     "#     with torch.no_grad():\n", |  | ||||||
|     "#         answers = []\n", |  | ||||||
|     "#         for test_set in test_sets:\n", |  | ||||||
|     "#             test_set = torch.FloatTensor(test_set).view(-1, 1)\n", |  | ||||||
|     "#             preds = model(test_set).view(-1).numpy()\n", |  | ||||||
|     "#             answers.append(preds.tolist())\n", |  | ||||||
|     "#     return answers\n", |  | ||||||
|     "\n", |  | ||||||
|     "# def f(x):\n", |  | ||||||
|     "#     return np.cos( 0.5 * x + x * x)\n", |  | ||||||
|     "\n", |  | ||||||
|     "# def get_data(mode):\n", |  | ||||||
|     "#     dataset = SynAdaptiveEnv(mode=mode)\n", |  | ||||||
|     "#     times, xs, ys = [], [], []\n", |  | ||||||
|     "#     for i, (_, t, x) in enumerate(dataset):\n", |  | ||||||
|     "#         times.append(t)\n", |  | ||||||
|     "#         xs.append(x)\n", |  | ||||||
|     "#     dataset.set_transform(f)\n", |  | ||||||
|     "#     for i, (_, _, y) in enumerate(dataset):\n", |  | ||||||
|     "#         ys.append(y)\n", |  | ||||||
|     "#     return times, xs, ys\n", |  | ||||||
|     "\n", |  | ||||||
|     "# def visualize_syn(save_path):\n", |  | ||||||
|     "#     save_dir = (save_path / '..').resolve()\n", |  | ||||||
|     "#     save_dir.mkdir(parents=True, exist_ok=True)\n", |  | ||||||
|     "    \n", |  | ||||||
|     "#     dpi, width, height = 40, 2000, 900\n", |  | ||||||
|     "#     figsize = width / float(dpi), height / float(dpi)\n", |  | ||||||
|     "#     LabelSize, LegendFontsize, font_gap = 40, 40, 5\n", |  | ||||||
|     "    \n", |  | ||||||
|     "#     fig = plt.figure(figsize=figsize)\n", |  | ||||||
|     "    \n", |  | ||||||
|     "#     times, xs, ys = get_data(None)\n", |  | ||||||
|     "    \n", |  | ||||||
|     "#     def draw_ax(cur_ax, xaxis, yaxis, xlabel, ylabel,\n", |  | ||||||
|     "#                 alpha=0.1, color='k', linestyle='-', legend=None, plot_only=False):\n", |  | ||||||
|     "#         if legend is not None:\n", |  | ||||||
|     "#             cur_ax.plot(xaxis[:1], yaxis[:1], color=color, label=legend)\n", |  | ||||||
|     "#         cur_ax.plot(xaxis, yaxis, color=color, linestyle=linestyle, alpha=alpha, label=None)\n", |  | ||||||
|     "#         if not plot_only:\n", |  | ||||||
|     "#             cur_ax.set_xlabel(xlabel, fontsize=LabelSize)\n", |  | ||||||
|     "#             cur_ax.set_ylabel(ylabel, rotation=0, fontsize=LabelSize)\n", |  | ||||||
|     "#             for tick in cur_ax.xaxis.get_major_ticks():\n", |  | ||||||
|     "#                 tick.label.set_fontsize(LabelSize - font_gap)\n", |  | ||||||
|     "#                 tick.label.set_rotation(10)\n", |  | ||||||
|     "#             for tick in cur_ax.yaxis.get_major_ticks():\n", |  | ||||||
|     "#                 tick.label.set_fontsize(LabelSize - font_gap)\n", |  | ||||||
|     "    \n", |  | ||||||
|     "#     cur_ax = fig.add_subplot(2, 1, 1)\n", |  | ||||||
|     "#     draw_ax(cur_ax, times, xs, \"time\", \"x\", alpha=1.0, legend=None)\n", |  | ||||||
|     "\n", |  | ||||||
|     "#     cur_ax = fig.add_subplot(2, 1, 2)\n", |  | ||||||
|     "#     draw_ax(cur_ax, times, ys, \"time\", \"y\", alpha=0.1, legend=\"ground truth\")\n", |  | ||||||
|     "    \n", |  | ||||||
|     "#     train_times, train_xs, train_ys = get_data(\"train\")\n", |  | ||||||
|     "#     draw_ax(cur_ax, train_times, train_ys, None, None, alpha=1.0, color='r', legend=None, plot_only=True)\n", |  | ||||||
|     "    \n", |  | ||||||
|     "#     valid_times, valid_xs, valid_ys = get_data(\"valid\")\n", |  | ||||||
|     "#     draw_ax(cur_ax, valid_times, valid_ys, None, None, alpha=1.0, color='g', legend=None, plot_only=True)\n", |  | ||||||
|     "    \n", |  | ||||||
|     "#     test_times, test_xs, test_ys = get_data(\"test\")\n", |  | ||||||
|     "#     draw_ax(cur_ax, test_times, test_ys, None, None, alpha=1.0, color='b', legend=None, plot_only=True)\n", |  | ||||||
|     "    \n", |  | ||||||
|     "#     # optimize MLP models\n", |  | ||||||
|     "# #     [train_preds, valid_preds, test_preds] = optimize_fn(train_xs, train_ys, [train_xs, valid_xs, test_xs])\n", |  | ||||||
|     "# #     draw_ax(cur_ax, train_times, train_preds, None, None,\n", |  | ||||||
|     "# #             alpha=1.0, linestyle='--', color='r', legend=\"MLP\", plot_only=True)\n", |  | ||||||
|     "# #     import pdb; pdb.set_trace()\n", |  | ||||||
|     "# #     draw_ax(cur_ax, valid_times, valid_preds, None, None,\n", |  | ||||||
|     "# #             alpha=1.0, linestyle='--', color='g', legend=None, plot_only=True)\n", |  | ||||||
|     "# #     draw_ax(cur_ax, test_times, test_preds, None, None,\n", |  | ||||||
|     "# #             alpha=1.0, linestyle='--', color='b', legend=None, plot_only=True)\n", |  | ||||||
|     "\n", |  | ||||||
|     "#     plt.legend(loc=1, fontsize=LegendFontsize)\n", |  | ||||||
|     "\n", |  | ||||||
|     "#     fig.savefig(save_path, dpi=dpi, bbox_inches=\"tight\", format=\"pdf\")\n", |  | ||||||
|     "#     plt.close(\"all\")\n", |  | ||||||
|     "#     # plt.show()" |  | ||||||
|    ] |  | ||||||
|   }, |  | ||||||
|   { |  | ||||||
|    "cell_type": "code", |  | ||||||
|    "execution_count": null, |  | ||||||
|    "id": "shared-envelope", |  | ||||||
|    "metadata": {}, |  | ||||||
|    "outputs": [], |  | ||||||
|    "source": [ |  | ||||||
|     "# Visualization\n", |  | ||||||
|     "# home_dir = Path.home()\n", |  | ||||||
|     "# desktop_dir = home_dir / 'Desktop'\n", |  | ||||||
|     "# print('The Desktop is at: {:}'.format(desktop_dir))\n", |  | ||||||
|     "# visualize_syn(desktop_dir / 'tot-synthetic-v0.pdf')" |  | ||||||
|    ] |  | ||||||
|   } |  | ||||||
|  ], |  | ||||||
|  "metadata": { |  | ||||||
|   "kernelspec": { |  | ||||||
|    "display_name": "Python 3", |  | ||||||
|    "language": "python", |  | ||||||
|    "name": "python3" |  | ||||||
|   }, |  | ||||||
|   "language_info": { |  | ||||||
|    "codemirror_mode": { |  | ||||||
|     "name": "ipython", |  | ||||||
|     "version": 3 |  | ||||||
|    }, |  | ||||||
|    "file_extension": ".py", |  | ||||||
|    "mimetype": "text/x-python", |  | ||||||
|    "name": "python", |  | ||||||
|    "nbconvert_exporter": "python", |  | ||||||
|    "pygments_lexer": "ipython3", |  | ||||||
|    "version": "3.8.8" |  | ||||||
|   } |  | ||||||
|  }, |  | ||||||
|  "nbformat": 4, |  | ||||||
|  "nbformat_minor": 5 |  | ||||||
| } |  | ||||||
| @@ -1,145 +0,0 @@ | |||||||
| { |  | ||||||
|  "cells": [ |  | ||||||
|   { |  | ||||||
|    "cell_type": "code", |  | ||||||
|    "execution_count": 1, |  | ||||||
|    "id": "filled-multiple", |  | ||||||
|    "metadata": {}, |  | ||||||
|    "outputs": [ |  | ||||||
|     { |  | ||||||
|      "name": "stdout", |  | ||||||
|      "output_type": "stream", |  | ||||||
|      "text": [ |  | ||||||
|       "The root path: /Users/xuanyidong\n", |  | ||||||
|       "The library path: /Users/xuanyidong/lib\n" |  | ||||||
|      ] |  | ||||||
|     }, |  | ||||||
|     { |  | ||||||
|      "ename": "AssertionError", |  | ||||||
|      "evalue": "/Users/xuanyidong/lib does not exist", |  | ||||||
|      "output_type": "error", |  | ||||||
|      "traceback": [ |  | ||||||
|       "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", |  | ||||||
|       "\u001b[0;31mAssertionError\u001b[0m                            Traceback (most recent call last)", |  | ||||||
|       "\u001b[0;32m~/Desktop/AutoDL-Projects\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m     15\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"The root path: {:}\"\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mformat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mroot_dir\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     16\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"The library path: {:}\"\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mformat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlib_dir\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 17\u001b[0;31m \u001b[0;32massert\u001b[0m \u001b[0mlib_dir\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mexists\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"{:} does not exist\"\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mformat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlib_dir\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     18\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mstr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlib_dir\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32min\u001b[0m \u001b[0msys\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpath\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     19\u001b[0m     \u001b[0msys\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpath\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0minsert\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mstr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlib_dir\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", |  | ||||||
|       "\u001b[0;31mAssertionError\u001b[0m: /Users/xuanyidong/lib does not exist" |  | ||||||
|      ] |  | ||||||
|     } |  | ||||||
|    ], |  | ||||||
|    "source": [ |  | ||||||
|     "import os, sys\n", |  | ||||||
|     "import torch\n", |  | ||||||
|     "from pathlib import Path\n", |  | ||||||
|     "import numpy as np\n", |  | ||||||
|     "import matplotlib\n", |  | ||||||
|     "from matplotlib import cm\n", |  | ||||||
|     "matplotlib.use(\"agg\")\n", |  | ||||||
|     "import matplotlib.pyplot as plt\n", |  | ||||||
|     "import matplotlib.ticker as ticker\n", |  | ||||||
|     "\n", |  | ||||||
|     "\n", |  | ||||||
|     "__file__ = os.path.dirname(os.path.realpath(\"__file__\"))\n", |  | ||||||
|     "root_dir = (Path(__file__).parent / \"..\").resolve()\n", |  | ||||||
|     "lib_dir = (root_dir / \"lib\").resolve()\n", |  | ||||||
|     "print(\"The root path: {:}\".format(root_dir))\n", |  | ||||||
|     "print(\"The library path: {:}\".format(lib_dir))\n", |  | ||||||
|     "assert lib_dir.exists(), \"{:} does not exist\".format(lib_dir)\n", |  | ||||||
|     "if str(lib_dir) not in sys.path:\n", |  | ||||||
|     "    sys.path.insert(0, str(lib_dir))\n", |  | ||||||
|     "\n", |  | ||||||
|     "from datasets import ConstantGenerator, SinGenerator, SyntheticDEnv\n", |  | ||||||
|     "from datasets import DynamicQuadraticFunc\n", |  | ||||||
|     "from datasets.synthetic_example import create_example_v1" |  | ||||||
|    ] |  | ||||||
|   }, |  | ||||||
|   { |  | ||||||
|    "cell_type": "code", |  | ||||||
|    "execution_count": null, |  | ||||||
|    "id": "detected-second", |  | ||||||
|    "metadata": {}, |  | ||||||
|    "outputs": [], |  | ||||||
|    "source": [ |  | ||||||
|     "def draw_fig(save_dir, timestamp, xaxis, yaxis):\n", |  | ||||||
|     "    save_path = save_dir / '{:04d}'.format(timestamp)\n", |  | ||||||
|     "    # print('Plot the figure at timestamp-{:} into {:}'.format(timestamp, save_path))\n", |  | ||||||
|     "    dpi, width, height = 40, 1500, 1500\n", |  | ||||||
|     "    figsize = width / float(dpi), height / float(dpi)\n", |  | ||||||
|     "    LabelSize, LegendFontsize, font_gap = 80, 80, 5\n", |  | ||||||
|     "\n", |  | ||||||
|     "    fig = plt.figure(figsize=figsize)\n", |  | ||||||
|     "    \n", |  | ||||||
|     "    cur_ax = fig.add_subplot(1, 1, 1)\n", |  | ||||||
|     "    cur_ax.scatter(xaxis, yaxis, color=\"k\", s=10, alpha=0.9, label=\"Timestamp={:02d}\".format(timestamp))\n", |  | ||||||
|     "    cur_ax.set_xlabel(\"X\", fontsize=LabelSize)\n", |  | ||||||
|     "    cur_ax.set_ylabel(\"f(X)\", rotation=0, fontsize=LabelSize)\n", |  | ||||||
|     "    cur_ax.set_xlim(-6, 6)\n", |  | ||||||
|     "    cur_ax.set_ylim(-40, 40)\n", |  | ||||||
|     "    for tick in cur_ax.xaxis.get_major_ticks():\n", |  | ||||||
|     "        tick.label.set_fontsize(LabelSize - font_gap)\n", |  | ||||||
|     "        tick.label.set_rotation(10)\n", |  | ||||||
|     "    for tick in cur_ax.yaxis.get_major_ticks():\n", |  | ||||||
|     "        tick.label.set_fontsize(LabelSize - font_gap)\n", |  | ||||||
|     "        \n", |  | ||||||
|     "    plt.legend(loc=1, fontsize=LegendFontsize)\n", |  | ||||||
|     "    fig.savefig(str(save_path) + '.pdf', dpi=dpi, bbox_inches=\"tight\", format=\"pdf\")\n", |  | ||||||
|     "    fig.savefig(str(save_path) + '.png', dpi=dpi, bbox_inches=\"tight\", format=\"png\")\n", |  | ||||||
|     "    plt.close(\"all\")\n", |  | ||||||
|     "\n", |  | ||||||
|     "\n", |  | ||||||
|     "def visualize_env(save_dir):\n", |  | ||||||
|     "    save_dir.mkdir(parents=True, exist_ok=True)\n", |  | ||||||
|     "    dynamic_env, function = create_example_v1(100, num_per_task=500)\n", |  | ||||||
|     "    \n", |  | ||||||
|     "    additional_xaxis = np.arange(-6, 6, 0.1)\n", |  | ||||||
|     "    for timestamp, dataset in dynamic_env:\n", |  | ||||||
|     "        num = dataset.shape[0]\n", |  | ||||||
|     "        # timeaxis = (torch.zeros(num) + timestamp).numpy()\n", |  | ||||||
|     "        xaxis = dataset[:,0].numpy()\n", |  | ||||||
|     "        xaxis = np.concatenate((additional_xaxis, xaxis))\n", |  | ||||||
|     "        # compute the ground truth\n", |  | ||||||
|     "        function.set_timestamp(timestamp)\n", |  | ||||||
|     "        yaxis = function(xaxis)\n", |  | ||||||
|     "        draw_fig(save_dir, timestamp, xaxis, yaxis)\n", |  | ||||||
|     "\n", |  | ||||||
|     "home_dir = Path.home()\n", |  | ||||||
|     "desktop_dir = home_dir / 'Desktop'\n", |  | ||||||
|     "vis_save_dir = desktop_dir / 'vis-synthetic'\n", |  | ||||||
|     "visualize_env(vis_save_dir)" |  | ||||||
|    ] |  | ||||||
|   }, |  | ||||||
|   { |  | ||||||
|    "cell_type": "code", |  | ||||||
|    "execution_count": null, |  | ||||||
|    "id": "greatest-pepper", |  | ||||||
|    "metadata": {}, |  | ||||||
|    "outputs": [], |  | ||||||
|    "source": [ |  | ||||||
|     "# Plot the data\n", |  | ||||||
|     "cmd = 'ffmpeg -y -i {:}/%04d.png -pix_fmt yuv420p -vf fps=2 -vf scale=1000:1000 -vb 5000k {:}/vis.mp4'.format(vis_save_dir, vis_save_dir)\n", |  | ||||||
|     "print(cmd)\n", |  | ||||||
|     "os.system(cmd)" |  | ||||||
|    ] |  | ||||||
|   } |  | ||||||
|  ], |  | ||||||
|  "metadata": { |  | ||||||
|   "kernelspec": { |  | ||||||
|    "display_name": "Python 3", |  | ||||||
|    "language": "python", |  | ||||||
|    "name": "python3" |  | ||||||
|   }, |  | ||||||
|   "language_info": { |  | ||||||
|    "codemirror_mode": { |  | ||||||
|     "name": "ipython", |  | ||||||
|     "version": 3 |  | ||||||
|    }, |  | ||||||
|    "file_extension": ".py", |  | ||||||
|    "mimetype": "text/x-python", |  | ||||||
|    "name": "python", |  | ||||||
|    "nbconvert_exporter": "python", |  | ||||||
|    "pygments_lexer": "ipython3", |  | ||||||
|    "version": "3.8.8" |  | ||||||
|   } |  | ||||||
|  }, |  | ||||||
|  "nbformat": 4, |  | ||||||
|  "nbformat_minor": 5 |  | ||||||
| } |  | ||||||
| @@ -72,3 +72,17 @@ def test_super_sequential(batch, seq_dim, input_dim, order): | |||||||
|         out3_dim.abstract(reuse_last=True).random(reuse_last=True).value, |         out3_dim.abstract(reuse_last=True).random(reuse_last=True).value, | ||||||
|     ) |     ) | ||||||
|     assert tuple(outputs.shape) == output_shape |     assert tuple(outputs.shape) == output_shape | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def test_super_sequential_v1(): | ||||||
|  |     model = super_core.SuperSequential( | ||||||
|  |         super_core.SuperSimpleNorm(1, 1), | ||||||
|  |         torch.nn.ReLU(), | ||||||
|  |         super_core.SuperLinear(10, 10), | ||||||
|  |     ) | ||||||
|  |     inputs = torch.rand(10, 10) | ||||||
|  |     print(model) | ||||||
|  |     outputs = model(inputs) | ||||||
|  |  | ||||||
|  |     abstract_search_space = model.abstract_search_space | ||||||
|  |     print(abstract_search_space) | ||||||
|   | |||||||
							
								
								
									
										53
									
								
								tests/test_super_norm.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										53
									
								
								tests/test_super_norm.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,53 @@ | |||||||
|  | ##################################################### | ||||||
|  | # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 # | ||||||
|  | ##################################################### | ||||||
|  | # pytest ./tests/test_super_norm.py -s              # | ||||||
|  | ##################################################### | ||||||
|  | import sys, random | ||||||
|  | import unittest | ||||||
|  | import pytest | ||||||
|  | from pathlib import Path | ||||||
|  |  | ||||||
|  | lib_dir = (Path(__file__).parent / ".." / "lib").resolve() | ||||||
|  | print("library path: {:}".format(lib_dir)) | ||||||
|  | if str(lib_dir) not in sys.path: | ||||||
|  |     sys.path.insert(0, str(lib_dir)) | ||||||
|  |  | ||||||
|  | import torch | ||||||
|  | from xlayers import super_core | ||||||
|  | import spaces | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class TestSuperSimpleNorm(unittest.TestCase): | ||||||
|  |     """Test the super simple norm.""" | ||||||
|  |  | ||||||
|  |     def test_super_simple_norm(self): | ||||||
|  |         out_features = spaces.Categorical(12, 24, 36) | ||||||
|  |         bias = spaces.Categorical(True, False) | ||||||
|  |         model = super_core.SuperSequential( | ||||||
|  |             super_core.SuperSimpleNorm(5, 0.5), | ||||||
|  |             super_core.SuperLinear(10, out_features, bias=bias), | ||||||
|  |         ) | ||||||
|  |         print("The simple super module is:\n{:}".format(model)) | ||||||
|  |         model.apply_verbose(True) | ||||||
|  |  | ||||||
|  |         print(model.super_run_type) | ||||||
|  |         self.assertTrue(model[1].bias) | ||||||
|  |  | ||||||
|  |         inputs = torch.rand(20, 10) | ||||||
|  |         print("Input shape: {:}".format(inputs.shape)) | ||||||
|  |         outputs = model(inputs) | ||||||
|  |         self.assertEqual(tuple(outputs.shape), (20, 36)) | ||||||
|  |  | ||||||
|  |         abstract_space = model.abstract_search_space | ||||||
|  |         abstract_space.clean_last() | ||||||
|  |         abstract_child = abstract_space.random() | ||||||
|  |         print("The abstract searc space:\n{:}".format(abstract_space)) | ||||||
|  |         print("The abstract child program:\n{:}".format(abstract_child)) | ||||||
|  |  | ||||||
|  |         model.set_super_run_type(super_core.SuperRunMode.Candidate) | ||||||
|  |         model.apply_candidate(abstract_child) | ||||||
|  |  | ||||||
|  |         output_shape = (20, abstract_child["1"]["_out_features"].value) | ||||||
|  |         outputs = model(inputs) | ||||||
|  |         self.assertEqual(tuple(outputs.shape), output_shape) | ||||||
		Reference in New Issue
	
	Block a user