Add SuperSimpleNorm and update synthetic env
This commit is contained in:
		@@ -1,9 +1,9 @@
 | 
			
		||||
#####################################################
 | 
			
		||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.02 #
 | 
			
		||||
#####################################################
 | 
			
		||||
# python exps/synthetic/baseline.py                 #
 | 
			
		||||
#####################################################
 | 
			
		||||
import os, sys, copy
 | 
			
		||||
############################################################################
 | 
			
		||||
# CUDA_VISIBLE_DEVICES=0 python exps/synthetic/baseline.py                 #
 | 
			
		||||
############################################################################
 | 
			
		||||
import os, sys, copy, random
 | 
			
		||||
import torch
 | 
			
		||||
import numpy as np
 | 
			
		||||
import argparse
 | 
			
		||||
@@ -28,6 +28,8 @@ from datasets import ConstantGenerator, SinGenerator, SyntheticDEnv
 | 
			
		||||
from datasets import DynamicQuadraticFunc
 | 
			
		||||
from datasets.synthetic_example import create_example_v1
 | 
			
		||||
 | 
			
		||||
from utils.temp_sync import optimize_fn, evaluate_fn
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def draw_fig(save_dir, timestamp, scatter_list):
 | 
			
		||||
    save_path = save_dir / "{:04d}".format(timestamp)
 | 
			
		||||
@@ -67,28 +69,55 @@ def draw_fig(save_dir, timestamp, scatter_list):
 | 
			
		||||
def main(save_dir):
 | 
			
		||||
    save_dir = Path(str(save_dir))
 | 
			
		||||
    save_dir.mkdir(parents=True, exist_ok=True)
 | 
			
		||||
    dynamic_env, function = create_example_v1(100, num_per_task=500)
 | 
			
		||||
    dynamic_env, function = create_example_v1(100, num_per_task=1000)
 | 
			
		||||
 | 
			
		||||
    additional_xaxis = np.arange(-6, 6, 0.1)
 | 
			
		||||
    for timestamp, dataset in tqdm(dynamic_env, ncols=50):
 | 
			
		||||
        num = dataset.shape[0]
 | 
			
		||||
        xaxis = dataset[:, 0].numpy()
 | 
			
		||||
    additional_xaxis = np.arange(-6, 6, 0.2)
 | 
			
		||||
    models = dict()
 | 
			
		||||
    
 | 
			
		||||
    for idx, (timestamp, dataset) in enumerate(tqdm(dynamic_env, ncols=50)):
 | 
			
		||||
        xaxis_all = dataset[:, 0].numpy()
 | 
			
		||||
        # xaxis_all = np.concatenate((additional_xaxis, xaxis_all))
 | 
			
		||||
        # compute the ground truth
 | 
			
		||||
        function.set_timestamp(timestamp)
 | 
			
		||||
        yaxis = function(xaxis)
 | 
			
		||||
        # xaxis = np.concatenate((additional_xaxis, xaxis))
 | 
			
		||||
        yaxis_all = function.noise_call(xaxis_all)
 | 
			
		||||
 | 
			
		||||
        # split the dataset
 | 
			
		||||
        indexes = list(range(xaxis_all.shape[0]))
 | 
			
		||||
        random.shuffle(indexes)
 | 
			
		||||
        train_indexes = indexes[:len(indexes)//2]
 | 
			
		||||
        valid_indexes = indexes[len(indexes)//2:]
 | 
			
		||||
        train_xs, train_ys = xaxis_all[train_indexes], yaxis_all[train_indexes]
 | 
			
		||||
        valid_xs, valid_ys = xaxis_all[valid_indexes], yaxis_all[valid_indexes]
 | 
			
		||||
        
 | 
			
		||||
        model, loss_fn, train_loss = optimize_fn(train_xs, train_ys)
 | 
			
		||||
        # model, loss_fn, train_loss = optimize_fn(xaxis_all, yaxis_all)
 | 
			
		||||
        pred_valid_ys, valid_loss = evaluate_fn(model, valid_xs, valid_ys, loss_fn)
 | 
			
		||||
        print("[{:03d}] T-{:03d}, train-loss={:.5f}, valid-loss={:.5f}".format(idx, timestamp, train_loss, valid_loss))
 | 
			
		||||
 | 
			
		||||
        # the first plot
 | 
			
		||||
        scatter_list = []
 | 
			
		||||
        scatter_list.append(
 | 
			
		||||
            {
 | 
			
		||||
                "xaxis": xaxis,
 | 
			
		||||
                "yaxis": yaxis,
 | 
			
		||||
                "xaxis": valid_xs,
 | 
			
		||||
                "yaxis": valid_ys,
 | 
			
		||||
                "color": "k",
 | 
			
		||||
                "s": 10,
 | 
			
		||||
                "alpha": 0.99,
 | 
			
		||||
                "label": "Timestamp={:02d}".format(timestamp),
 | 
			
		||||
            }
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        scatter_list.append(
 | 
			
		||||
            {
 | 
			
		||||
                "xaxis": valid_xs,
 | 
			
		||||
                "yaxis": pred_valid_ys,
 | 
			
		||||
                "color": "r",
 | 
			
		||||
                "s": 10,
 | 
			
		||||
                "alpha": 0.5,
 | 
			
		||||
                "label": "MLP at now"
 | 
			
		||||
            }
 | 
			
		||||
        )
 | 
			
		||||
        
 | 
			
		||||
        draw_fig(save_dir, timestamp, scatter_list)
 | 
			
		||||
    print("Save all figures into {:}".format(save_dir))
 | 
			
		||||
    save_dir = save_dir.resolve()
 | 
			
		||||
 
 | 
			
		||||
@@ -33,6 +33,14 @@ class FitFunc(abc.ABC):
 | 
			
		||||
    def __call__(self, x):
 | 
			
		||||
        raise NotImplementedError
 | 
			
		||||
 | 
			
		||||
    def noise_call(self, x, std=0.1):
 | 
			
		||||
        clean_y = self.__call__(x)
 | 
			
		||||
        if isinstance(clean_y, np.ndarray):
 | 
			
		||||
            noise_y = clean_y + np.random.normal(scale=std, size=clean_y.shape)
 | 
			
		||||
        else:
 | 
			
		||||
            raise ValueError("Unkonwn type: {:}".format(type(clean_y)))
 | 
			
		||||
        return noise_y
 | 
			
		||||
 | 
			
		||||
    @abc.abstractmethod
 | 
			
		||||
    def _getitem(self, x):
 | 
			
		||||
        raise NotImplementedError
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										63
									
								
								lib/utils/temp_sync.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										63
									
								
								lib/utils/temp_sync.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,63 @@
 | 
			
		||||
# To be deleted.
 | 
			
		||||
import copy
 | 
			
		||||
import torch
 | 
			
		||||
 | 
			
		||||
from xlayers.super_core import SuperSequential, SuperMLPv1
 | 
			
		||||
from xlayers.super_core import SuperSimpleNorm
 | 
			
		||||
from xlayers.super_core import SuperLinear
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def optimize_fn(xs, ys, device="cpu", max_iter=2000, max_lr=0.1):
 | 
			
		||||
    xs = torch.FloatTensor(xs).view(-1, 1).to(device)
 | 
			
		||||
    ys = torch.FloatTensor(ys).view(-1, 1).to(device)
 | 
			
		||||
 | 
			
		||||
    model = SuperSequential(
 | 
			
		||||
        SuperSimpleNorm(xs.mean().item(), xs.std().item()),
 | 
			
		||||
        SuperLinear(1, 200),
 | 
			
		||||
        torch.nn.LeakyReLU(),
 | 
			
		||||
        SuperLinear(200, 100),
 | 
			
		||||
        torch.nn.LeakyReLU(),
 | 
			
		||||
        SuperLinear(100, 1),
 | 
			
		||||
    ).to(device)
 | 
			
		||||
    model.train()
 | 
			
		||||
    optimizer = torch.optim.Adam(
 | 
			
		||||
        model.parameters(), lr=max_lr, amsgrad=True
 | 
			
		||||
    )
 | 
			
		||||
    loss_func = torch.nn.MSELoss()
 | 
			
		||||
    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
 | 
			
		||||
        optimizer,
 | 
			
		||||
        milestones=[
 | 
			
		||||
            int(max_iter * 0.25),
 | 
			
		||||
            int(max_iter * 0.5),
 | 
			
		||||
            int(max_iter * 0.75),
 | 
			
		||||
        ],
 | 
			
		||||
        gamma=0.3,
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    best_loss, best_param = None, None
 | 
			
		||||
    for _iter in range(max_iter):
 | 
			
		||||
        preds = model(xs)
 | 
			
		||||
 | 
			
		||||
        optimizer.zero_grad()
 | 
			
		||||
        loss = loss_func(preds, ys)
 | 
			
		||||
        loss.backward()
 | 
			
		||||
        optimizer.step()
 | 
			
		||||
        lr_scheduler.step()
 | 
			
		||||
 | 
			
		||||
        if best_loss is None or best_loss > loss.item():
 | 
			
		||||
            best_loss = loss.item()
 | 
			
		||||
            best_param = copy.deepcopy(model.state_dict())
 | 
			
		||||
        
 | 
			
		||||
        # print('loss={:}, best-loss={:}'.format(loss.item(), best_loss))
 | 
			
		||||
    model.load_state_dict(best_param)
 | 
			
		||||
    return model, loss_func, best_loss
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def evaluate_fn(model, xs, ys, loss_fn, device="cpu"):
 | 
			
		||||
    with torch.no_grad():
 | 
			
		||||
        inputs = torch.FloatTensor(xs).view(-1, 1).to(device)
 | 
			
		||||
        ys = torch.FloatTensor(ys).view(-1, 1).to(device)
 | 
			
		||||
        preds = model(inputs)
 | 
			
		||||
        loss = loss_fn(preds, ys)
 | 
			
		||||
        preds = preds.view(-1).cpu().numpy()
 | 
			
		||||
    return preds, loss.item()
 | 
			
		||||
@@ -91,6 +91,8 @@ class SuperSequential(SuperModule):
 | 
			
		||||
    def abstract_search_space(self):
 | 
			
		||||
        root_node = spaces.VirtualNode(id(self))
 | 
			
		||||
        for index, module in enumerate(self):
 | 
			
		||||
            if not isinstance(module, SuperModule):
 | 
			
		||||
                continue
 | 
			
		||||
            space = module.abstract_search_space
 | 
			
		||||
            if not spaces.is_determined(space):
 | 
			
		||||
                root_node.append(str(index), space)
 | 
			
		||||
@@ -98,9 +100,9 @@ class SuperSequential(SuperModule):
 | 
			
		||||
 | 
			
		||||
    def apply_candidate(self, abstract_child: spaces.VirtualNode):
 | 
			
		||||
        super(SuperSequential, self).apply_candidate(abstract_child)
 | 
			
		||||
        for index in range(len(self)):
 | 
			
		||||
        for index, module in enumerate(self):
 | 
			
		||||
            if str(index) in abstract_child:
 | 
			
		||||
                self.__getitem__(index).apply_candidate(abstract_child[str(index)])
 | 
			
		||||
                module.apply_candidate(abstract_child[str(index)])
 | 
			
		||||
 | 
			
		||||
    def forward_candidate(self, input):
 | 
			
		||||
        return self.forward_raw(input)
 | 
			
		||||
 
 | 
			
		||||
@@ -9,6 +9,7 @@ from .super_module import SuperModule
 | 
			
		||||
from .super_container import SuperSequential
 | 
			
		||||
from .super_linear import SuperLinear
 | 
			
		||||
from .super_linear import SuperMLPv1, SuperMLPv2
 | 
			
		||||
from .super_norm import SuperSimpleNorm
 | 
			
		||||
from .super_norm import SuperLayerNorm1D
 | 
			
		||||
from .super_attention import SuperAttention
 | 
			
		||||
from .super_transformer import SuperTransformerEncoderLayer
 | 
			
		||||
 
 | 
			
		||||
@@ -3,6 +3,7 @@
 | 
			
		||||
#####################################################
 | 
			
		||||
 | 
			
		||||
import abc
 | 
			
		||||
import warnings
 | 
			
		||||
from typing import Optional, Union, Callable
 | 
			
		||||
import torch
 | 
			
		||||
import torch.nn as nn
 | 
			
		||||
@@ -45,6 +46,17 @@ class SuperModule(abc.ABC, nn.Module):
 | 
			
		||||
 | 
			
		||||
        self.apply(_reset_super_run)
 | 
			
		||||
 | 
			
		||||
    def add_module(self, name: str, module: Optional[torch.nn.Module]) -> None:
 | 
			
		||||
        if not isinstance(module, SuperModule):
 | 
			
		||||
            warnings.warn(
 | 
			
		||||
                "Add {:} module, which is not SuperModule, into {:}".format(
 | 
			
		||||
                    name, self.__class__.__name__
 | 
			
		||||
                )
 | 
			
		||||
                + "\n"
 | 
			
		||||
                + "It may cause some functions invalid."
 | 
			
		||||
            )
 | 
			
		||||
        super(SuperModule, self).add_module(name, module)
 | 
			
		||||
 | 
			
		||||
    def apply_verbose(self, verbose):
 | 
			
		||||
        def _reset_verbose(m):
 | 
			
		||||
            if isinstance(m, SuperModule):
 | 
			
		||||
 
 | 
			
		||||
@@ -82,3 +82,43 @@ class SuperLayerNorm1D(SuperModule):
 | 
			
		||||
                elementwise_affine=self._elementwise_affine,
 | 
			
		||||
            )
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class SuperSimpleNorm(SuperModule):
 | 
			
		||||
    """Super simple normalization."""
 | 
			
		||||
 | 
			
		||||
    def __init__(self, mean, std, inplace=False) -> None:
 | 
			
		||||
        super(SuperSimpleNorm, self).__init__()
 | 
			
		||||
        self._mean = mean
 | 
			
		||||
        self._std = std
 | 
			
		||||
        self._inplace = inplace
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    def abstract_search_space(self):
 | 
			
		||||
        return spaces.VirtualNode(id(self))
 | 
			
		||||
 | 
			
		||||
    def forward_candidate(self, input: torch.Tensor) -> torch.Tensor:
 | 
			
		||||
        # check inputs ->
 | 
			
		||||
        return self.forward_raw(input)
 | 
			
		||||
 | 
			
		||||
    def forward_raw(self, input: torch.Tensor) -> torch.Tensor:
 | 
			
		||||
        if not self._inplace:
 | 
			
		||||
            tensor = input.clone()
 | 
			
		||||
        else:
 | 
			
		||||
            tensor = input
 | 
			
		||||
        mean = torch.as_tensor(self._mean, dtype=tensor.dtype, device=tensor.device)
 | 
			
		||||
        std = torch.as_tensor(self._std, dtype=tensor.dtype, device=tensor.device)
 | 
			
		||||
        if (std == 0).any():
 | 
			
		||||
            raise ValueError(
 | 
			
		||||
                "std evaluated to zero after conversion to {}, leading to division by zero.".format(
 | 
			
		||||
                    dtype
 | 
			
		||||
                )
 | 
			
		||||
            )
 | 
			
		||||
        while mean.ndim < tensor.ndim:
 | 
			
		||||
            mean, std = torch.unsqueeze(mean, dim=0), torch.unsqueeze(std, dim=0)
 | 
			
		||||
        return tensor.sub_(mean).div_(std)
 | 
			
		||||
 | 
			
		||||
    def extra_repr(self) -> str:
 | 
			
		||||
        return "mean={mean}, std={mean}, inplace={inplace}".format(
 | 
			
		||||
            mean=self._mean, std=self._std, inplace=self._inplace
 | 
			
		||||
        )
 | 
			
		||||
 
 | 
			
		||||
@@ -107,113 +107,6 @@
 | 
			
		||||
    "visualize_env()"
 | 
			
		||||
   ]
 | 
			
		||||
  },
 | 
			
		||||
  {
 | 
			
		||||
   "cell_type": "code",
 | 
			
		||||
   "execution_count": 3,
 | 
			
		||||
   "id": "supreme-basis",
 | 
			
		||||
   "metadata": {},
 | 
			
		||||
   "outputs": [],
 | 
			
		||||
   "source": [
 | 
			
		||||
    "# def optimize_fn(xs, ys, test_sets):\n",
 | 
			
		||||
    "#     xs = torch.FloatTensor(xs).view(-1, 1)\n",
 | 
			
		||||
    "#     ys = torch.FloatTensor(ys).view(-1, 1)\n",
 | 
			
		||||
    "    \n",
 | 
			
		||||
    "#     model = SuperSequential(\n",
 | 
			
		||||
    "#         SuperMLPv1(1, 10, 20, torch.nn.ReLU),\n",
 | 
			
		||||
    "#         SuperMLPv1(20, 10, 1, torch.nn.ReLU)\n",
 | 
			
		||||
    "#     )\n",
 | 
			
		||||
    "#     optimizer = torch.optim.Adam(\n",
 | 
			
		||||
    "#         model.parameters(),\n",
 | 
			
		||||
    "#         lr=0.01, weight_decay=1e-4, amsgrad=True\n",
 | 
			
		||||
    "#     )\n",
 | 
			
		||||
    "#     for _iter in range(100):\n",
 | 
			
		||||
    "#         preds = model(ys)\n",
 | 
			
		||||
    "\n",
 | 
			
		||||
    "#         optimizer.zero_grad()\n",
 | 
			
		||||
    "#         loss = torch.nn.functional.mse_loss(preds, ys)\n",
 | 
			
		||||
    "#         loss.backward()\n",
 | 
			
		||||
    "#         optimizer.step()\n",
 | 
			
		||||
    "        \n",
 | 
			
		||||
    "#     with torch.no_grad():\n",
 | 
			
		||||
    "#         answers = []\n",
 | 
			
		||||
    "#         for test_set in test_sets:\n",
 | 
			
		||||
    "#             test_set = torch.FloatTensor(test_set).view(-1, 1)\n",
 | 
			
		||||
    "#             preds = model(test_set).view(-1).numpy()\n",
 | 
			
		||||
    "#             answers.append(preds.tolist())\n",
 | 
			
		||||
    "#     return answers\n",
 | 
			
		||||
    "\n",
 | 
			
		||||
    "# def f(x):\n",
 | 
			
		||||
    "#     return np.cos( 0.5 * x + x * x)\n",
 | 
			
		||||
    "\n",
 | 
			
		||||
    "# def get_data(mode):\n",
 | 
			
		||||
    "#     dataset = SynAdaptiveEnv(mode=mode)\n",
 | 
			
		||||
    "#     times, xs, ys = [], [], []\n",
 | 
			
		||||
    "#     for i, (_, t, x) in enumerate(dataset):\n",
 | 
			
		||||
    "#         times.append(t)\n",
 | 
			
		||||
    "#         xs.append(x)\n",
 | 
			
		||||
    "#     dataset.set_transform(f)\n",
 | 
			
		||||
    "#     for i, (_, _, y) in enumerate(dataset):\n",
 | 
			
		||||
    "#         ys.append(y)\n",
 | 
			
		||||
    "#     return times, xs, ys\n",
 | 
			
		||||
    "\n",
 | 
			
		||||
    "# def visualize_syn(save_path):\n",
 | 
			
		||||
    "#     save_dir = (save_path / '..').resolve()\n",
 | 
			
		||||
    "#     save_dir.mkdir(parents=True, exist_ok=True)\n",
 | 
			
		||||
    "    \n",
 | 
			
		||||
    "#     dpi, width, height = 40, 2000, 900\n",
 | 
			
		||||
    "#     figsize = width / float(dpi), height / float(dpi)\n",
 | 
			
		||||
    "#     LabelSize, LegendFontsize, font_gap = 40, 40, 5\n",
 | 
			
		||||
    "    \n",
 | 
			
		||||
    "#     fig = plt.figure(figsize=figsize)\n",
 | 
			
		||||
    "    \n",
 | 
			
		||||
    "#     times, xs, ys = get_data(None)\n",
 | 
			
		||||
    "    \n",
 | 
			
		||||
    "#     def draw_ax(cur_ax, xaxis, yaxis, xlabel, ylabel,\n",
 | 
			
		||||
    "#                 alpha=0.1, color='k', linestyle='-', legend=None, plot_only=False):\n",
 | 
			
		||||
    "#         if legend is not None:\n",
 | 
			
		||||
    "#             cur_ax.plot(xaxis[:1], yaxis[:1], color=color, label=legend)\n",
 | 
			
		||||
    "#         cur_ax.plot(xaxis, yaxis, color=color, linestyle=linestyle, alpha=alpha, label=None)\n",
 | 
			
		||||
    "#         if not plot_only:\n",
 | 
			
		||||
    "#             cur_ax.set_xlabel(xlabel, fontsize=LabelSize)\n",
 | 
			
		||||
    "#             cur_ax.set_ylabel(ylabel, rotation=0, fontsize=LabelSize)\n",
 | 
			
		||||
    "#             for tick in cur_ax.xaxis.get_major_ticks():\n",
 | 
			
		||||
    "#                 tick.label.set_fontsize(LabelSize - font_gap)\n",
 | 
			
		||||
    "#                 tick.label.set_rotation(10)\n",
 | 
			
		||||
    "#             for tick in cur_ax.yaxis.get_major_ticks():\n",
 | 
			
		||||
    "#                 tick.label.set_fontsize(LabelSize - font_gap)\n",
 | 
			
		||||
    "    \n",
 | 
			
		||||
    "#     cur_ax = fig.add_subplot(2, 1, 1)\n",
 | 
			
		||||
    "#     draw_ax(cur_ax, times, xs, \"time\", \"x\", alpha=1.0, legend=None)\n",
 | 
			
		||||
    "\n",
 | 
			
		||||
    "#     cur_ax = fig.add_subplot(2, 1, 2)\n",
 | 
			
		||||
    "#     draw_ax(cur_ax, times, ys, \"time\", \"y\", alpha=0.1, legend=\"ground truth\")\n",
 | 
			
		||||
    "    \n",
 | 
			
		||||
    "#     train_times, train_xs, train_ys = get_data(\"train\")\n",
 | 
			
		||||
    "#     draw_ax(cur_ax, train_times, train_ys, None, None, alpha=1.0, color='r', legend=None, plot_only=True)\n",
 | 
			
		||||
    "    \n",
 | 
			
		||||
    "#     valid_times, valid_xs, valid_ys = get_data(\"valid\")\n",
 | 
			
		||||
    "#     draw_ax(cur_ax, valid_times, valid_ys, None, None, alpha=1.0, color='g', legend=None, plot_only=True)\n",
 | 
			
		||||
    "    \n",
 | 
			
		||||
    "#     test_times, test_xs, test_ys = get_data(\"test\")\n",
 | 
			
		||||
    "#     draw_ax(cur_ax, test_times, test_ys, None, None, alpha=1.0, color='b', legend=None, plot_only=True)\n",
 | 
			
		||||
    "    \n",
 | 
			
		||||
    "#     # optimize MLP models\n",
 | 
			
		||||
    "# #     [train_preds, valid_preds, test_preds] = optimize_fn(train_xs, train_ys, [train_xs, valid_xs, test_xs])\n",
 | 
			
		||||
    "# #     draw_ax(cur_ax, train_times, train_preds, None, None,\n",
 | 
			
		||||
    "# #             alpha=1.0, linestyle='--', color='r', legend=\"MLP\", plot_only=True)\n",
 | 
			
		||||
    "# #     import pdb; pdb.set_trace()\n",
 | 
			
		||||
    "# #     draw_ax(cur_ax, valid_times, valid_preds, None, None,\n",
 | 
			
		||||
    "# #             alpha=1.0, linestyle='--', color='g', legend=None, plot_only=True)\n",
 | 
			
		||||
    "# #     draw_ax(cur_ax, test_times, test_preds, None, None,\n",
 | 
			
		||||
    "# #             alpha=1.0, linestyle='--', color='b', legend=None, plot_only=True)\n",
 | 
			
		||||
    "\n",
 | 
			
		||||
    "#     plt.legend(loc=1, fontsize=LegendFontsize)\n",
 | 
			
		||||
    "\n",
 | 
			
		||||
    "#     fig.savefig(save_path, dpi=dpi, bbox_inches=\"tight\", format=\"pdf\")\n",
 | 
			
		||||
    "#     plt.close(\"all\")\n",
 | 
			
		||||
    "#     # plt.show()"
 | 
			
		||||
   ]
 | 
			
		||||
  }
 | 
			
		||||
 ],
 | 
			
		||||
 "metadata": {
 | 
			
		||||
  "kernelspec": {
 | 
			
		||||
 
 | 
			
		||||
@@ -17,7 +17,7 @@ gpu=$1
 | 
			
		||||
market=$2
 | 
			
		||||
 | 
			
		||||
# algorithms="NAIVE-V1 NAIVE-V2 MLP GRU LSTM ALSTM XGBoost LightGBM SFM TabNet DoubleE"
 | 
			
		||||
algorithms="MLP GRU LSTM ALSTM XGBoost LightGBM SFM TabNet DoubleE"
 | 
			
		||||
algorithms="XGBoost LightGBM SFM TabNet DoubleE"
 | 
			
		||||
 | 
			
		||||
for alg in ${algorithms}
 | 
			
		||||
do
 | 
			
		||||
 
 | 
			
		||||
@@ -1,251 +0,0 @@
 | 
			
		||||
{
 | 
			
		||||
 "cells": [
 | 
			
		||||
  {
 | 
			
		||||
   "cell_type": "code",
 | 
			
		||||
   "execution_count": 1,
 | 
			
		||||
   "id": "filled-multiple",
 | 
			
		||||
   "metadata": {},
 | 
			
		||||
   "outputs": [
 | 
			
		||||
    {
 | 
			
		||||
     "name": "stdout",
 | 
			
		||||
     "output_type": "stream",
 | 
			
		||||
     "text": [
 | 
			
		||||
      "The root path: /Users/xuanyidong\n",
 | 
			
		||||
      "The library path: /Users/xuanyidong/lib\n"
 | 
			
		||||
     ]
 | 
			
		||||
    },
 | 
			
		||||
    {
 | 
			
		||||
     "ename": "AssertionError",
 | 
			
		||||
     "evalue": "/Users/xuanyidong/lib does not exist",
 | 
			
		||||
     "output_type": "error",
 | 
			
		||||
     "traceback": [
 | 
			
		||||
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
 | 
			
		||||
      "\u001b[0;31mAssertionError\u001b[0m                            Traceback (most recent call last)",
 | 
			
		||||
      "\u001b[0;32m~/Desktop/AutoDL-Projects\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m     15\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"The root path: {:}\"\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mformat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mroot_dir\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     16\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"The library path: {:}\"\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mformat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlib_dir\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 17\u001b[0;31m \u001b[0;32massert\u001b[0m \u001b[0mlib_dir\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mexists\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"{:} does not exist\"\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mformat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlib_dir\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     18\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mstr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlib_dir\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32min\u001b[0m \u001b[0msys\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpath\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     19\u001b[0m     \u001b[0msys\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpath\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0minsert\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mstr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlib_dir\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
 | 
			
		||||
      "\u001b[0;31mAssertionError\u001b[0m: /Users/xuanyidong/lib does not exist"
 | 
			
		||||
     ]
 | 
			
		||||
    }
 | 
			
		||||
   ],
 | 
			
		||||
   "source": [
 | 
			
		||||
    "import os, sys\n",
 | 
			
		||||
    "import torch\n",
 | 
			
		||||
    "from pathlib import Path\n",
 | 
			
		||||
    "import numpy as np\n",
 | 
			
		||||
    "import matplotlib\n",
 | 
			
		||||
    "from matplotlib import cm\n",
 | 
			
		||||
    "# matplotlib.use(\"agg\")\n",
 | 
			
		||||
    "import matplotlib.pyplot as plt\n",
 | 
			
		||||
    "import matplotlib.ticker as ticker\n",
 | 
			
		||||
    "\n",
 | 
			
		||||
    "\n",
 | 
			
		||||
    "__file__ = os.path.dirname(os.path.realpath(\"__file__\"))\n",
 | 
			
		||||
    "root_dir = (Path(__file__).parent / \"..\").resolve()\n",
 | 
			
		||||
    "lib_dir = (root_dir / \"lib\").resolve()\n",
 | 
			
		||||
    "print(\"The root path: {:}\".format(root_dir))\n",
 | 
			
		||||
    "print(\"The library path: {:}\".format(lib_dir))\n",
 | 
			
		||||
    "assert lib_dir.exists(), \"{:} does not exist\".format(lib_dir)\n",
 | 
			
		||||
    "if str(lib_dir) not in sys.path:\n",
 | 
			
		||||
    "    sys.path.insert(0, str(lib_dir))\n",
 | 
			
		||||
    "\n",
 | 
			
		||||
    "from datasets import ConstantGenerator, SinGenerator, SyntheticDEnv\n",
 | 
			
		||||
    "from datasets import DynamicQuadraticFunc\n",
 | 
			
		||||
    "from datasets.synthetic_example import create_example_v1"
 | 
			
		||||
   ]
 | 
			
		||||
  },
 | 
			
		||||
  {
 | 
			
		||||
   "cell_type": "code",
 | 
			
		||||
   "execution_count": null,
 | 
			
		||||
   "id": "detected-second",
 | 
			
		||||
   "metadata": {},
 | 
			
		||||
   "outputs": [],
 | 
			
		||||
   "source": [
 | 
			
		||||
    "def visualize_env():\n",
 | 
			
		||||
    "    \n",
 | 
			
		||||
    "    dpi, width, height = 10, 800, 400\n",
 | 
			
		||||
    "    figsize = width / float(dpi), height / float(dpi)\n",
 | 
			
		||||
    "    LabelSize, LegendFontsize, font_gap = 40, 40, 5\n",
 | 
			
		||||
    "\n",
 | 
			
		||||
    "    fig = plt.figure(figsize=figsize)\n",
 | 
			
		||||
    "\n",
 | 
			
		||||
    "    dynamic_env, function = create_example_v1(100, num_per_task=250)\n",
 | 
			
		||||
    "    \n",
 | 
			
		||||
    "    timeaxis, xaxis, yaxis = [], [], []\n",
 | 
			
		||||
    "    for timestamp, dataset in dynamic_env:\n",
 | 
			
		||||
    "        num = dataset.shape[0]\n",
 | 
			
		||||
    "        timeaxis.append(torch.zeros(num) + timestamp)\n",
 | 
			
		||||
    "        xaxis.append(dataset[:,0])\n",
 | 
			
		||||
    "        # compute the ground truth\n",
 | 
			
		||||
    "        function.set_timestamp(timestamp)\n",
 | 
			
		||||
    "        yaxis.append(function(dataset[:,0]))\n",
 | 
			
		||||
    "    timeaxis = torch.cat(timeaxis).numpy()\n",
 | 
			
		||||
    "    xaxis = torch.cat(xaxis).numpy()\n",
 | 
			
		||||
    "    yaxis = torch.cat(yaxis).numpy()\n",
 | 
			
		||||
    "\n",
 | 
			
		||||
    "    cur_ax = fig.add_subplot(2, 1, 1)\n",
 | 
			
		||||
    "    cur_ax.scatter(timeaxis, xaxis, color=\"k\", linestyle=\"-\", alpha=0.9, label=None)\n",
 | 
			
		||||
    "    cur_ax.set_xlabel(\"Time\", fontsize=LabelSize)\n",
 | 
			
		||||
    "    cur_ax.set_ylabel(\"X\", rotation=0, fontsize=LabelSize)\n",
 | 
			
		||||
    "    for tick in cur_ax.xaxis.get_major_ticks():\n",
 | 
			
		||||
    "        tick.label.set_fontsize(LabelSize - font_gap)\n",
 | 
			
		||||
    "        tick.label.set_rotation(10)\n",
 | 
			
		||||
    "    for tick in cur_ax.yaxis.get_major_ticks():\n",
 | 
			
		||||
    "        tick.label.set_fontsize(LabelSize - font_gap)\n",
 | 
			
		||||
    "    \n",
 | 
			
		||||
    "    cur_ax = fig.add_subplot(2, 1, 2)\n",
 | 
			
		||||
    "    cur_ax.scatter(timeaxis, yaxis, color=\"k\", linestyle=\"-\", alpha=0.9, label=None)\n",
 | 
			
		||||
    "    cur_ax.set_xlabel(\"Time\", fontsize=LabelSize)\n",
 | 
			
		||||
    "    cur_ax.set_ylabel(\"Y\", rotation=0, fontsize=LabelSize)\n",
 | 
			
		||||
    "    for tick in cur_ax.xaxis.get_major_ticks():\n",
 | 
			
		||||
    "        tick.label.set_fontsize(LabelSize - font_gap)\n",
 | 
			
		||||
    "        tick.label.set_rotation(10)\n",
 | 
			
		||||
    "    for tick in cur_ax.yaxis.get_major_ticks():\n",
 | 
			
		||||
    "        tick.label.set_fontsize(LabelSize - font_gap)\n",
 | 
			
		||||
    "    plt.show()\n",
 | 
			
		||||
    "\n",
 | 
			
		||||
    "visualize_env()"
 | 
			
		||||
   ]
 | 
			
		||||
  },
 | 
			
		||||
  {
 | 
			
		||||
   "cell_type": "code",
 | 
			
		||||
   "execution_count": null,
 | 
			
		||||
   "id": "supreme-basis",
 | 
			
		||||
   "metadata": {},
 | 
			
		||||
   "outputs": [],
 | 
			
		||||
   "source": [
 | 
			
		||||
    "# def optimize_fn(xs, ys, test_sets):\n",
 | 
			
		||||
    "#     xs = torch.FloatTensor(xs).view(-1, 1)\n",
 | 
			
		||||
    "#     ys = torch.FloatTensor(ys).view(-1, 1)\n",
 | 
			
		||||
    "    \n",
 | 
			
		||||
    "#     model = SuperSequential(\n",
 | 
			
		||||
    "#         SuperMLPv1(1, 10, 20, torch.nn.ReLU),\n",
 | 
			
		||||
    "#         SuperMLPv1(20, 10, 1, torch.nn.ReLU)\n",
 | 
			
		||||
    "#     )\n",
 | 
			
		||||
    "#     optimizer = torch.optim.Adam(\n",
 | 
			
		||||
    "#         model.parameters(),\n",
 | 
			
		||||
    "#         lr=0.01, weight_decay=1e-4, amsgrad=True\n",
 | 
			
		||||
    "#     )\n",
 | 
			
		||||
    "#     for _iter in range(100):\n",
 | 
			
		||||
    "#         preds = model(ys)\n",
 | 
			
		||||
    "\n",
 | 
			
		||||
    "#         optimizer.zero_grad()\n",
 | 
			
		||||
    "#         loss = torch.nn.functional.mse_loss(preds, ys)\n",
 | 
			
		||||
    "#         loss.backward()\n",
 | 
			
		||||
    "#         optimizer.step()\n",
 | 
			
		||||
    "        \n",
 | 
			
		||||
    "#     with torch.no_grad():\n",
 | 
			
		||||
    "#         answers = []\n",
 | 
			
		||||
    "#         for test_set in test_sets:\n",
 | 
			
		||||
    "#             test_set = torch.FloatTensor(test_set).view(-1, 1)\n",
 | 
			
		||||
    "#             preds = model(test_set).view(-1).numpy()\n",
 | 
			
		||||
    "#             answers.append(preds.tolist())\n",
 | 
			
		||||
    "#     return answers\n",
 | 
			
		||||
    "\n",
 | 
			
		||||
    "# def f(x):\n",
 | 
			
		||||
    "#     return np.cos( 0.5 * x + x * x)\n",
 | 
			
		||||
    "\n",
 | 
			
		||||
    "# def get_data(mode):\n",
 | 
			
		||||
    "#     dataset = SynAdaptiveEnv(mode=mode)\n",
 | 
			
		||||
    "#     times, xs, ys = [], [], []\n",
 | 
			
		||||
    "#     for i, (_, t, x) in enumerate(dataset):\n",
 | 
			
		||||
    "#         times.append(t)\n",
 | 
			
		||||
    "#         xs.append(x)\n",
 | 
			
		||||
    "#     dataset.set_transform(f)\n",
 | 
			
		||||
    "#     for i, (_, _, y) in enumerate(dataset):\n",
 | 
			
		||||
    "#         ys.append(y)\n",
 | 
			
		||||
    "#     return times, xs, ys\n",
 | 
			
		||||
    "\n",
 | 
			
		||||
    "# def visualize_syn(save_path):\n",
 | 
			
		||||
    "#     save_dir = (save_path / '..').resolve()\n",
 | 
			
		||||
    "#     save_dir.mkdir(parents=True, exist_ok=True)\n",
 | 
			
		||||
    "    \n",
 | 
			
		||||
    "#     dpi, width, height = 40, 2000, 900\n",
 | 
			
		||||
    "#     figsize = width / float(dpi), height / float(dpi)\n",
 | 
			
		||||
    "#     LabelSize, LegendFontsize, font_gap = 40, 40, 5\n",
 | 
			
		||||
    "    \n",
 | 
			
		||||
    "#     fig = plt.figure(figsize=figsize)\n",
 | 
			
		||||
    "    \n",
 | 
			
		||||
    "#     times, xs, ys = get_data(None)\n",
 | 
			
		||||
    "    \n",
 | 
			
		||||
    "#     def draw_ax(cur_ax, xaxis, yaxis, xlabel, ylabel,\n",
 | 
			
		||||
    "#                 alpha=0.1, color='k', linestyle='-', legend=None, plot_only=False):\n",
 | 
			
		||||
    "#         if legend is not None:\n",
 | 
			
		||||
    "#             cur_ax.plot(xaxis[:1], yaxis[:1], color=color, label=legend)\n",
 | 
			
		||||
    "#         cur_ax.plot(xaxis, yaxis, color=color, linestyle=linestyle, alpha=alpha, label=None)\n",
 | 
			
		||||
    "#         if not plot_only:\n",
 | 
			
		||||
    "#             cur_ax.set_xlabel(xlabel, fontsize=LabelSize)\n",
 | 
			
		||||
    "#             cur_ax.set_ylabel(ylabel, rotation=0, fontsize=LabelSize)\n",
 | 
			
		||||
    "#             for tick in cur_ax.xaxis.get_major_ticks():\n",
 | 
			
		||||
    "#                 tick.label.set_fontsize(LabelSize - font_gap)\n",
 | 
			
		||||
    "#                 tick.label.set_rotation(10)\n",
 | 
			
		||||
    "#             for tick in cur_ax.yaxis.get_major_ticks():\n",
 | 
			
		||||
    "#                 tick.label.set_fontsize(LabelSize - font_gap)\n",
 | 
			
		||||
    "    \n",
 | 
			
		||||
    "#     cur_ax = fig.add_subplot(2, 1, 1)\n",
 | 
			
		||||
    "#     draw_ax(cur_ax, times, xs, \"time\", \"x\", alpha=1.0, legend=None)\n",
 | 
			
		||||
    "\n",
 | 
			
		||||
    "#     cur_ax = fig.add_subplot(2, 1, 2)\n",
 | 
			
		||||
    "#     draw_ax(cur_ax, times, ys, \"time\", \"y\", alpha=0.1, legend=\"ground truth\")\n",
 | 
			
		||||
    "    \n",
 | 
			
		||||
    "#     train_times, train_xs, train_ys = get_data(\"train\")\n",
 | 
			
		||||
    "#     draw_ax(cur_ax, train_times, train_ys, None, None, alpha=1.0, color='r', legend=None, plot_only=True)\n",
 | 
			
		||||
    "    \n",
 | 
			
		||||
    "#     valid_times, valid_xs, valid_ys = get_data(\"valid\")\n",
 | 
			
		||||
    "#     draw_ax(cur_ax, valid_times, valid_ys, None, None, alpha=1.0, color='g', legend=None, plot_only=True)\n",
 | 
			
		||||
    "    \n",
 | 
			
		||||
    "#     test_times, test_xs, test_ys = get_data(\"test\")\n",
 | 
			
		||||
    "#     draw_ax(cur_ax, test_times, test_ys, None, None, alpha=1.0, color='b', legend=None, plot_only=True)\n",
 | 
			
		||||
    "    \n",
 | 
			
		||||
    "#     # optimize MLP models\n",
 | 
			
		||||
    "# #     [train_preds, valid_preds, test_preds] = optimize_fn(train_xs, train_ys, [train_xs, valid_xs, test_xs])\n",
 | 
			
		||||
    "# #     draw_ax(cur_ax, train_times, train_preds, None, None,\n",
 | 
			
		||||
    "# #             alpha=1.0, linestyle='--', color='r', legend=\"MLP\", plot_only=True)\n",
 | 
			
		||||
    "# #     import pdb; pdb.set_trace()\n",
 | 
			
		||||
    "# #     draw_ax(cur_ax, valid_times, valid_preds, None, None,\n",
 | 
			
		||||
    "# #             alpha=1.0, linestyle='--', color='g', legend=None, plot_only=True)\n",
 | 
			
		||||
    "# #     draw_ax(cur_ax, test_times, test_preds, None, None,\n",
 | 
			
		||||
    "# #             alpha=1.0, linestyle='--', color='b', legend=None, plot_only=True)\n",
 | 
			
		||||
    "\n",
 | 
			
		||||
    "#     plt.legend(loc=1, fontsize=LegendFontsize)\n",
 | 
			
		||||
    "\n",
 | 
			
		||||
    "#     fig.savefig(save_path, dpi=dpi, bbox_inches=\"tight\", format=\"pdf\")\n",
 | 
			
		||||
    "#     plt.close(\"all\")\n",
 | 
			
		||||
    "#     # plt.show()"
 | 
			
		||||
   ]
 | 
			
		||||
  },
 | 
			
		||||
  {
 | 
			
		||||
   "cell_type": "code",
 | 
			
		||||
   "execution_count": null,
 | 
			
		||||
   "id": "shared-envelope",
 | 
			
		||||
   "metadata": {},
 | 
			
		||||
   "outputs": [],
 | 
			
		||||
   "source": [
 | 
			
		||||
    "# Visualization\n",
 | 
			
		||||
    "# home_dir = Path.home()\n",
 | 
			
		||||
    "# desktop_dir = home_dir / 'Desktop'\n",
 | 
			
		||||
    "# print('The Desktop is at: {:}'.format(desktop_dir))\n",
 | 
			
		||||
    "# visualize_syn(desktop_dir / 'tot-synthetic-v0.pdf')"
 | 
			
		||||
   ]
 | 
			
		||||
  }
 | 
			
		||||
 ],
 | 
			
		||||
 "metadata": {
 | 
			
		||||
  "kernelspec": {
 | 
			
		||||
   "display_name": "Python 3",
 | 
			
		||||
   "language": "python",
 | 
			
		||||
   "name": "python3"
 | 
			
		||||
  },
 | 
			
		||||
  "language_info": {
 | 
			
		||||
   "codemirror_mode": {
 | 
			
		||||
    "name": "ipython",
 | 
			
		||||
    "version": 3
 | 
			
		||||
   },
 | 
			
		||||
   "file_extension": ".py",
 | 
			
		||||
   "mimetype": "text/x-python",
 | 
			
		||||
   "name": "python",
 | 
			
		||||
   "nbconvert_exporter": "python",
 | 
			
		||||
   "pygments_lexer": "ipython3",
 | 
			
		||||
   "version": "3.8.8"
 | 
			
		||||
  }
 | 
			
		||||
 },
 | 
			
		||||
 "nbformat": 4,
 | 
			
		||||
 "nbformat_minor": 5
 | 
			
		||||
}
 | 
			
		||||
@@ -1,145 +0,0 @@
 | 
			
		||||
{
 | 
			
		||||
 "cells": [
 | 
			
		||||
  {
 | 
			
		||||
   "cell_type": "code",
 | 
			
		||||
   "execution_count": 1,
 | 
			
		||||
   "id": "filled-multiple",
 | 
			
		||||
   "metadata": {},
 | 
			
		||||
   "outputs": [
 | 
			
		||||
    {
 | 
			
		||||
     "name": "stdout",
 | 
			
		||||
     "output_type": "stream",
 | 
			
		||||
     "text": [
 | 
			
		||||
      "The root path: /Users/xuanyidong\n",
 | 
			
		||||
      "The library path: /Users/xuanyidong/lib\n"
 | 
			
		||||
     ]
 | 
			
		||||
    },
 | 
			
		||||
    {
 | 
			
		||||
     "ename": "AssertionError",
 | 
			
		||||
     "evalue": "/Users/xuanyidong/lib does not exist",
 | 
			
		||||
     "output_type": "error",
 | 
			
		||||
     "traceback": [
 | 
			
		||||
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
 | 
			
		||||
      "\u001b[0;31mAssertionError\u001b[0m                            Traceback (most recent call last)",
 | 
			
		||||
      "\u001b[0;32m~/Desktop/AutoDL-Projects\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m     15\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"The root path: {:}\"\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mformat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mroot_dir\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     16\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"The library path: {:}\"\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mformat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlib_dir\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 17\u001b[0;31m \u001b[0;32massert\u001b[0m \u001b[0mlib_dir\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mexists\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"{:} does not exist\"\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mformat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlib_dir\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     18\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mstr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlib_dir\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32min\u001b[0m \u001b[0msys\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpath\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     19\u001b[0m     \u001b[0msys\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpath\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0minsert\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mstr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlib_dir\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
 | 
			
		||||
      "\u001b[0;31mAssertionError\u001b[0m: /Users/xuanyidong/lib does not exist"
 | 
			
		||||
     ]
 | 
			
		||||
    }
 | 
			
		||||
   ],
 | 
			
		||||
   "source": [
 | 
			
		||||
    "import os, sys\n",
 | 
			
		||||
    "import torch\n",
 | 
			
		||||
    "from pathlib import Path\n",
 | 
			
		||||
    "import numpy as np\n",
 | 
			
		||||
    "import matplotlib\n",
 | 
			
		||||
    "from matplotlib import cm\n",
 | 
			
		||||
    "matplotlib.use(\"agg\")\n",
 | 
			
		||||
    "import matplotlib.pyplot as plt\n",
 | 
			
		||||
    "import matplotlib.ticker as ticker\n",
 | 
			
		||||
    "\n",
 | 
			
		||||
    "\n",
 | 
			
		||||
    "__file__ = os.path.dirname(os.path.realpath(\"__file__\"))\n",
 | 
			
		||||
    "root_dir = (Path(__file__).parent / \"..\").resolve()\n",
 | 
			
		||||
    "lib_dir = (root_dir / \"lib\").resolve()\n",
 | 
			
		||||
    "print(\"The root path: {:}\".format(root_dir))\n",
 | 
			
		||||
    "print(\"The library path: {:}\".format(lib_dir))\n",
 | 
			
		||||
    "assert lib_dir.exists(), \"{:} does not exist\".format(lib_dir)\n",
 | 
			
		||||
    "if str(lib_dir) not in sys.path:\n",
 | 
			
		||||
    "    sys.path.insert(0, str(lib_dir))\n",
 | 
			
		||||
    "\n",
 | 
			
		||||
    "from datasets import ConstantGenerator, SinGenerator, SyntheticDEnv\n",
 | 
			
		||||
    "from datasets import DynamicQuadraticFunc\n",
 | 
			
		||||
    "from datasets.synthetic_example import create_example_v1"
 | 
			
		||||
   ]
 | 
			
		||||
  },
 | 
			
		||||
  {
 | 
			
		||||
   "cell_type": "code",
 | 
			
		||||
   "execution_count": null,
 | 
			
		||||
   "id": "detected-second",
 | 
			
		||||
   "metadata": {},
 | 
			
		||||
   "outputs": [],
 | 
			
		||||
   "source": [
 | 
			
		||||
    "def draw_fig(save_dir, timestamp, xaxis, yaxis):\n",
 | 
			
		||||
    "    save_path = save_dir / '{:04d}'.format(timestamp)\n",
 | 
			
		||||
    "    # print('Plot the figure at timestamp-{:} into {:}'.format(timestamp, save_path))\n",
 | 
			
		||||
    "    dpi, width, height = 40, 1500, 1500\n",
 | 
			
		||||
    "    figsize = width / float(dpi), height / float(dpi)\n",
 | 
			
		||||
    "    LabelSize, LegendFontsize, font_gap = 80, 80, 5\n",
 | 
			
		||||
    "\n",
 | 
			
		||||
    "    fig = plt.figure(figsize=figsize)\n",
 | 
			
		||||
    "    \n",
 | 
			
		||||
    "    cur_ax = fig.add_subplot(1, 1, 1)\n",
 | 
			
		||||
    "    cur_ax.scatter(xaxis, yaxis, color=\"k\", s=10, alpha=0.9, label=\"Timestamp={:02d}\".format(timestamp))\n",
 | 
			
		||||
    "    cur_ax.set_xlabel(\"X\", fontsize=LabelSize)\n",
 | 
			
		||||
    "    cur_ax.set_ylabel(\"f(X)\", rotation=0, fontsize=LabelSize)\n",
 | 
			
		||||
    "    cur_ax.set_xlim(-6, 6)\n",
 | 
			
		||||
    "    cur_ax.set_ylim(-40, 40)\n",
 | 
			
		||||
    "    for tick in cur_ax.xaxis.get_major_ticks():\n",
 | 
			
		||||
    "        tick.label.set_fontsize(LabelSize - font_gap)\n",
 | 
			
		||||
    "        tick.label.set_rotation(10)\n",
 | 
			
		||||
    "    for tick in cur_ax.yaxis.get_major_ticks():\n",
 | 
			
		||||
    "        tick.label.set_fontsize(LabelSize - font_gap)\n",
 | 
			
		||||
    "        \n",
 | 
			
		||||
    "    plt.legend(loc=1, fontsize=LegendFontsize)\n",
 | 
			
		||||
    "    fig.savefig(str(save_path) + '.pdf', dpi=dpi, bbox_inches=\"tight\", format=\"pdf\")\n",
 | 
			
		||||
    "    fig.savefig(str(save_path) + '.png', dpi=dpi, bbox_inches=\"tight\", format=\"png\")\n",
 | 
			
		||||
    "    plt.close(\"all\")\n",
 | 
			
		||||
    "\n",
 | 
			
		||||
    "\n",
 | 
			
		||||
    "def visualize_env(save_dir):\n",
 | 
			
		||||
    "    save_dir.mkdir(parents=True, exist_ok=True)\n",
 | 
			
		||||
    "    dynamic_env, function = create_example_v1(100, num_per_task=500)\n",
 | 
			
		||||
    "    \n",
 | 
			
		||||
    "    additional_xaxis = np.arange(-6, 6, 0.1)\n",
 | 
			
		||||
    "    for timestamp, dataset in dynamic_env:\n",
 | 
			
		||||
    "        num = dataset.shape[0]\n",
 | 
			
		||||
    "        # timeaxis = (torch.zeros(num) + timestamp).numpy()\n",
 | 
			
		||||
    "        xaxis = dataset[:,0].numpy()\n",
 | 
			
		||||
    "        xaxis = np.concatenate((additional_xaxis, xaxis))\n",
 | 
			
		||||
    "        # compute the ground truth\n",
 | 
			
		||||
    "        function.set_timestamp(timestamp)\n",
 | 
			
		||||
    "        yaxis = function(xaxis)\n",
 | 
			
		||||
    "        draw_fig(save_dir, timestamp, xaxis, yaxis)\n",
 | 
			
		||||
    "\n",
 | 
			
		||||
    "home_dir = Path.home()\n",
 | 
			
		||||
    "desktop_dir = home_dir / 'Desktop'\n",
 | 
			
		||||
    "vis_save_dir = desktop_dir / 'vis-synthetic'\n",
 | 
			
		||||
    "visualize_env(vis_save_dir)"
 | 
			
		||||
   ]
 | 
			
		||||
  },
 | 
			
		||||
  {
 | 
			
		||||
   "cell_type": "code",
 | 
			
		||||
   "execution_count": null,
 | 
			
		||||
   "id": "greatest-pepper",
 | 
			
		||||
   "metadata": {},
 | 
			
		||||
   "outputs": [],
 | 
			
		||||
   "source": [
 | 
			
		||||
    "# Plot the data\n",
 | 
			
		||||
    "cmd = 'ffmpeg -y -i {:}/%04d.png -pix_fmt yuv420p -vf fps=2 -vf scale=1000:1000 -vb 5000k {:}/vis.mp4'.format(vis_save_dir, vis_save_dir)\n",
 | 
			
		||||
    "print(cmd)\n",
 | 
			
		||||
    "os.system(cmd)"
 | 
			
		||||
   ]
 | 
			
		||||
  }
 | 
			
		||||
 ],
 | 
			
		||||
 "metadata": {
 | 
			
		||||
  "kernelspec": {
 | 
			
		||||
   "display_name": "Python 3",
 | 
			
		||||
   "language": "python",
 | 
			
		||||
   "name": "python3"
 | 
			
		||||
  },
 | 
			
		||||
  "language_info": {
 | 
			
		||||
   "codemirror_mode": {
 | 
			
		||||
    "name": "ipython",
 | 
			
		||||
    "version": 3
 | 
			
		||||
   },
 | 
			
		||||
   "file_extension": ".py",
 | 
			
		||||
   "mimetype": "text/x-python",
 | 
			
		||||
   "name": "python",
 | 
			
		||||
   "nbconvert_exporter": "python",
 | 
			
		||||
   "pygments_lexer": "ipython3",
 | 
			
		||||
   "version": "3.8.8"
 | 
			
		||||
  }
 | 
			
		||||
 },
 | 
			
		||||
 "nbformat": 4,
 | 
			
		||||
 "nbformat_minor": 5
 | 
			
		||||
}
 | 
			
		||||
@@ -72,3 +72,17 @@ def test_super_sequential(batch, seq_dim, input_dim, order):
 | 
			
		||||
        out3_dim.abstract(reuse_last=True).random(reuse_last=True).value,
 | 
			
		||||
    )
 | 
			
		||||
    assert tuple(outputs.shape) == output_shape
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_super_sequential_v1():
 | 
			
		||||
    model = super_core.SuperSequential(
 | 
			
		||||
        super_core.SuperSimpleNorm(1, 1),
 | 
			
		||||
        torch.nn.ReLU(),
 | 
			
		||||
        super_core.SuperLinear(10, 10),
 | 
			
		||||
    )
 | 
			
		||||
    inputs = torch.rand(10, 10)
 | 
			
		||||
    print(model)
 | 
			
		||||
    outputs = model(inputs)
 | 
			
		||||
 | 
			
		||||
    abstract_search_space = model.abstract_search_space
 | 
			
		||||
    print(abstract_search_space)
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										53
									
								
								tests/test_super_norm.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										53
									
								
								tests/test_super_norm.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,53 @@
 | 
			
		||||
#####################################################
 | 
			
		||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 #
 | 
			
		||||
#####################################################
 | 
			
		||||
# pytest ./tests/test_super_norm.py -s              #
 | 
			
		||||
#####################################################
 | 
			
		||||
import sys, random
 | 
			
		||||
import unittest
 | 
			
		||||
import pytest
 | 
			
		||||
from pathlib import Path
 | 
			
		||||
 | 
			
		||||
lib_dir = (Path(__file__).parent / ".." / "lib").resolve()
 | 
			
		||||
print("library path: {:}".format(lib_dir))
 | 
			
		||||
if str(lib_dir) not in sys.path:
 | 
			
		||||
    sys.path.insert(0, str(lib_dir))
 | 
			
		||||
 | 
			
		||||
import torch
 | 
			
		||||
from xlayers import super_core
 | 
			
		||||
import spaces
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class TestSuperSimpleNorm(unittest.TestCase):
 | 
			
		||||
    """Test the super simple norm."""
 | 
			
		||||
 | 
			
		||||
    def test_super_simple_norm(self):
 | 
			
		||||
        out_features = spaces.Categorical(12, 24, 36)
 | 
			
		||||
        bias = spaces.Categorical(True, False)
 | 
			
		||||
        model = super_core.SuperSequential(
 | 
			
		||||
            super_core.SuperSimpleNorm(5, 0.5),
 | 
			
		||||
            super_core.SuperLinear(10, out_features, bias=bias),
 | 
			
		||||
        )
 | 
			
		||||
        print("The simple super module is:\n{:}".format(model))
 | 
			
		||||
        model.apply_verbose(True)
 | 
			
		||||
 | 
			
		||||
        print(model.super_run_type)
 | 
			
		||||
        self.assertTrue(model[1].bias)
 | 
			
		||||
 | 
			
		||||
        inputs = torch.rand(20, 10)
 | 
			
		||||
        print("Input shape: {:}".format(inputs.shape))
 | 
			
		||||
        outputs = model(inputs)
 | 
			
		||||
        self.assertEqual(tuple(outputs.shape), (20, 36))
 | 
			
		||||
 | 
			
		||||
        abstract_space = model.abstract_search_space
 | 
			
		||||
        abstract_space.clean_last()
 | 
			
		||||
        abstract_child = abstract_space.random()
 | 
			
		||||
        print("The abstract searc space:\n{:}".format(abstract_space))
 | 
			
		||||
        print("The abstract child program:\n{:}".format(abstract_child))
 | 
			
		||||
 | 
			
		||||
        model.set_super_run_type(super_core.SuperRunMode.Candidate)
 | 
			
		||||
        model.apply_candidate(abstract_child)
 | 
			
		||||
 | 
			
		||||
        output_shape = (20, abstract_child["1"]["_out_features"].value)
 | 
			
		||||
        outputs = model(inputs)
 | 
			
		||||
        self.assertEqual(tuple(outputs.shape), output_shape)
 | 
			
		||||
		Reference in New Issue
	
	Block a user