From 8961215416cfac8c501015eab0f883fa192f8d1d Mon Sep 17 00:00:00 2001
From: D-X-Y <280835372@qq.com>
Date: Thu, 27 May 2021 11:17:57 +0800
Subject: [PATCH] Re-org GeMOSA codes

---
 exps/GeMOSA/lfna_models.py                | 117 ------------------
 exps/GeMOSA/main.py                       |  10 +-
 .../{lfna_meta_model.py => meta_model.py} |  19 ++-
 exps/GeMOSA/vis-synthetic.py              |  26 ++--
 xautodl/datasets/math_core.py             |   1 +
 xautodl/datasets/math_dynamic_funcs.py    |  38 +++++--
 xautodl/datasets/synthetic_core.py        |   6 +-
 xautodl/datasets/synthetic_env.py         |  22 ++--
 8 files changed, 77 insertions(+), 162 deletions(-)
 delete mode 100644 exps/GeMOSA/lfna_models.py
 rename exps/GeMOSA/{lfna_meta_model.py => meta_model.py} (96%)

diff --git a/exps/GeMOSA/lfna_models.py b/exps/GeMOSA/lfna_models.py
deleted file mode 100644
index 2133cb2..0000000
--- a/exps/GeMOSA/lfna_models.py
+++ /dev/null
@@ -1,117 +0,0 @@
-#####################################################
-# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 #
-#####################################################
-import copy
-import torch
-
-import torch.nn.functional as F
-
-from xlayers import super_core
-from xlayers import trunc_normal_
-from models.xcore import get_model
-
-
-class HyperNet(super_core.SuperModule):
-    """The hyper-network."""
-
-    def __init__(
-        self,
-        shape_container,
-        layer_embeding,
-        task_embedding,
-        num_tasks,
-        return_container=True,
-    ):
-        super(HyperNet, self).__init__()
-        self._shape_container = shape_container
-        self._num_layers = len(shape_container)
-        self._numel_per_layer = []
-        for ilayer in range(self._num_layers):
-            self._numel_per_layer.append(shape_container[ilayer].numel())
-
-        self.register_parameter(
-            "_super_layer_embed",
-            torch.nn.Parameter(torch.Tensor(self._num_layers, layer_embeding)),
-        )
-        self.register_parameter(
-            "_super_task_embed",
-            torch.nn.Parameter(torch.Tensor(num_tasks, task_embedding)),
-        )
-        trunc_normal_(self._super_layer_embed, std=0.02)
-        trunc_normal_(self._super_task_embed, std=0.02)
-
-        model_kwargs = dict(
-            config=dict(model_type="dual_norm_mlp"),
-            input_dim=layer_embeding + task_embedding,
-            output_dim=max(self._numel_per_layer),
-            hidden_dims=[(layer_embeding + task_embedding) * 2] * 3,
-            act_cls="gelu",
-            norm_cls="layer_norm_1d",
-            dropout=0.2,
-        )
-        self._generator = get_model(**model_kwargs)
-        self._return_container = return_container
-        print("generator: {:}".format(self._generator))
-
-    def forward_raw(self, task_embed_id):
-        layer_embed = self._super_layer_embed
-        task_embed = (
-            self._super_task_embed[task_embed_id]
-            .view(1, -1)
-            .expand(self._num_layers, -1)
-        )
-
-        joint_embed = torch.cat((task_embed, layer_embed), dim=-1)
-        weights = self._generator(joint_embed)
-        if self._return_container:
-            weights = torch.split(weights, 1)
-            return self._shape_container.translate(weights)
-        else:
-            return weights
-
-    def forward_candidate(self, input):
-        raise NotImplementedError
-
-    def extra_repr(self) -> str:
-        return "(_super_layer_embed): {:}".format(list(self._super_layer_embed.shape))
-
-
-class HyperNet_VX(super_core.SuperModule):
-    def __init__(self, shape_container, input_embeding, return_container=True):
-        super(HyperNet_VX, self).__init__()
-        self._shape_container = shape_container
-        self._num_layers = len(shape_container)
-        self._numel_per_layer = []
-        for ilayer in range(self._num_layers):
-            self._numel_per_layer.append(shape_container[ilayer].numel())
-
-        self.register_parameter(
-            "_super_layer_embed",
-            torch.nn.Parameter(torch.Tensor(self._num_layers, input_embeding)),
-        )
-        trunc_normal_(self._super_layer_embed, std=0.02)
-
-        model_kwargs = dict(
-            input_dim=input_embeding,
-            output_dim=max(self._numel_per_layer),
-            hidden_dim=input_embeding * 4,
-            act_cls="sigmoid",
-            norm_cls="identity",
-        )
-        self._generator = get_model(dict(model_type="simple_mlp"), **model_kwargs)
-        self._return_container = return_container
-        print("generator: {:}".format(self._generator))
-
-    def forward_raw(self, input):
-        weights = self._generator(self._super_layer_embed)
-        if self._return_container:
-            weights = torch.split(weights, 1)
-            return self._shape_container.translate(weights)
-        else:
-            return weights
-
-    def forward_candidate(self, input):
-        raise NotImplementedError
-
-    def extra_repr(self) -> str:
-        return "(_super_layer_embed): {:}".format(list(self._super_layer_embed.shape))
diff --git a/exps/GeMOSA/main.py b/exps/GeMOSA/main.py
index 2ba0117..5a0b893 100644
--- a/exps/GeMOSA/main.py
+++ b/exps/GeMOSA/main.py
@@ -35,7 +35,7 @@ from xautodl.models.xcore import get_model
 from xautodl.xlayers import super_core, trunc_normal_
 
 from lfna_utils import lfna_setup, train_model, TimeData
-from lfna_meta_model import MetaModelV1
+from meta_model import MetaModelV1
 
 
 def online_evaluate(env, meta_model, base_model, criterion, args, logger, save=False):
@@ -106,7 +106,7 @@ def meta_train_procedure(base_model, meta_model, criterion, xenv, args, logger):
         )
         optimizer.zero_grad()
 
-        generated_time_embeds = meta_model(meta_model.meta_timestamps, None, True)
+        generated_time_embeds = meta_model.gen_time_embed(meta_model.meta_timestamps)
 
         batch_indexes = random.choices(total_indexes, k=args.meta_batch)
 
@@ -219,11 +219,11 @@ def main(args):
     w_containers, loss_meter = online_evaluate(
         all_env, meta_model, base_model, criterion, args, logger, True
     )
-    logger.log("In this enviornment, the loss-meter is {:}".format(loss_meter))
+    logger.log("In this environment, the total loss-meter is {:}".format(loss_meter))
 
     save_checkpoint(
-        {"w_containers": w_containers},
-        logger.path(None) / "final-ckp.pth",
+        {"all_w_containers": w_containers},
+        logger.path(None) / "final-ckp-{:}.pth".format(args.rand_seed),
         logger,
     )
 
diff --git a/exps/GeMOSA/lfna_meta_model.py b/exps/GeMOSA/meta_model.py
similarity index 96%
rename from exps/GeMOSA/lfna_meta_model.py
rename to exps/GeMOSA/meta_model.py
index 0df20be..79f9fb9 100644
--- a/exps/GeMOSA/lfna_meta_model.py
+++ b/exps/GeMOSA/meta_model.py
@@ -154,8 +154,9 @@
             (self._append_meta_embed["fixed"], meta_embed), dim=0
         )
 
-    def _obtain_time_embed(self, timestamps):
-        # timestamps is a batch of sequence of timestamps
+    def gen_time_embed(self, timestamps):
+        # timestamps is a batch of timestamps
+        [B] = timestamps.shape
         # batch, seq = timestamps.shape
         timestamps = timestamps.view(-1, 1)
         meta_timestamps, meta_embeds = self.meta_timestamps, self.super_meta_embed
@@ -179,15 +180,8 @@
         )
         return timestamp_embeds[:, -1, :]
 
-    def forward_raw(self, timestamps, time_embeds, tembed_only=False):
-        if time_embeds is None:
-            [B] = timestamps.shape
-            time_embeds = self._obtain_time_embed(timestamps)
-        else:  # use the hyper-net only
-            time_seq = None
-            B, _ = time_embeds.shape
-        if tembed_only:
-            return time_embeds
+    def gen_model(self, time_embeds):
+        B, _ = time_embeds.shape
         # create joint embed
         num_layer, _ = self._super_layer_embed.shape
         # The shape of `joint_embed` is batch * num-layers * input-dim
@@ -206,6 +200,9 @@
         )
         return batch_containers, time_embeds
 
+    def forward_raw(self, timestamps, time_embeds, tembed_only=False):
+        raise NotImplementedError
+
     def forward_candidate(self, input):
         raise NotImplementedError
 
diff --git a/exps/GeMOSA/vis-synthetic.py b/exps/GeMOSA/vis-synthetic.py
index 59b7140..a3f2885 100644
--- a/exps/GeMOSA/vis-synthetic.py
+++ b/exps/GeMOSA/vis-synthetic.py
@@ -1,8 +1,8 @@
 #####################################################
 # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.02 #
 ############################################################################
-# python exps/GMOA/vis-synthetic.py --env_version v1                       #
-# python exps/GMOA/vis-synthetic.py --env_version v2                       #
+# python exps/GeMOSA/vis-synthetic.py --env_version v1                     #
+# python exps/GeMOSA/vis-synthetic.py --env_version v2                     #
 ############################################################################
 import os, sys, copy, random
 import torch
@@ -181,7 +181,7 @@ def compare_cl(save_dir):
 def visualize_env(save_dir, version):
     save_dir = Path(str(save_dir))
     for substr in ("pdf", "png"):
-        sub_save_dir = save_dir / substr
+        sub_save_dir = save_dir / "{:}-{:}".format(substr, version)
         sub_save_dir.mkdir(parents=True, exist_ok=True)
 
     dynamic_env = get_synthetic_env(version=version)
@@ -190,6 +190,8 @@ def visualize_env(save_dir, version):
         allxs.append(allx)
         allys.append(ally)
     allxs, allys = torch.cat(allxs).view(-1), torch.cat(allys).view(-1)
+    print("x - min={:.3f}, max={:.3f}".format(allxs.min().item(), allxs.max().item()))
+    print("y - min={:.3f}, max={:.3f}".format(allys.min().item(), allys.max().item()))
     for idx, (timestamp, (allx, ally)) in enumerate(tqdm(dynamic_env, ncols=50)):
         dpi, width, height = 30, 1800, 1400
         figsize = width / float(dpi), height / float(dpi)
@@ -210,14 +212,22 @@ def visualize_env(save_dir, version):
         cur_ax.set_ylim(round(allys.min().item(), 1), round(allys.max().item(), 1))
         cur_ax.legend(loc=1, fontsize=LegendFontsize)
 
-        pdf_save_path = save_dir / "pdf" / "v{:}-{:05d}.pdf".format(version, idx)
+        pdf_save_path = (
+            save_dir
+            / "pdf-{:}".format(version)
+            / "v{:}-{:05d}.pdf".format(version, idx)
+        )
         fig.savefig(str(pdf_save_path), dpi=dpi, bbox_inches="tight", format="pdf")
-        png_save_path = save_dir / "png" / "v{:}-{:05d}.png".format(version, idx)
+        png_save_path = (
+            save_dir
+            / "png-{:}".format(version)
+            / "v{:}-{:05d}.png".format(version, idx)
+        )
         fig.savefig(str(png_save_path), dpi=dpi, bbox_inches="tight", format="png")
         plt.close("all")
     save_dir = save_dir.resolve()
     base_cmd = "ffmpeg -y -i {xdir}/v{version}-%05d.png -vf scale=1800:1400 -pix_fmt yuv420p -vb 5000k".format(
-        xdir=save_dir / "png", version=version
+        xdir=save_dir / "png-{:}".format(version), version=version
     )
     print(base_cmd)
     os.system("{:} {xdir}/env-{ver}.mp4".format(base_cmd, xdir=save_dir, ver=version))
@@ -367,7 +377,7 @@ if __name__ == "__main__":
     )
     args = parser.parse_args()
 
-    # visualize_env(os.path.join(args.save_dir, "vis-env"), "v1")
+    visualize_env(os.path.join(args.save_dir, "vis-env"), args.env_version)
     # visualize_env(os.path.join(args.save_dir, "vis-env"), "v2")
-    compare_algs(os.path.join(args.save_dir, "compare-alg"), args.env_version)
+    # compare_algs(os.path.join(args.save_dir, "compare-alg"), args.env_version)
     # compare_cl(os.path.join(args.save_dir, "compare-cl"))
diff --git a/xautodl/datasets/math_core.py b/xautodl/datasets/math_core.py
index 0de2f9e..9e8929c 100644
--- a/xautodl/datasets/math_core.py
+++ b/xautodl/datasets/math_core.py
@@ -4,6 +4,7 @@
 from .math_base_funcs import LinearFunc, QuadraticFunc, CubicFunc, QuarticFunc
 from .math_dynamic_funcs import DynamicLinearFunc
 from .math_dynamic_funcs import DynamicQuadraticFunc
+from .math_dynamic_funcs import DynamicSinQuadraticFunc
 from .math_adv_funcs import ConstantFunc
 from .math_adv_funcs import ComposedSinFunc, ComposedCosFunc
 from .math_dynamic_generator import GaussianDGenerator
diff --git a/xautodl/datasets/math_dynamic_funcs.py b/xautodl/datasets/math_dynamic_funcs.py
index e83d8db..f475ff8 100644
--- a/xautodl/datasets/math_dynamic_funcs.py
+++ b/xautodl/datasets/math_dynamic_funcs.py
@@ -5,9 +5,6 @@
 import math
 import abc
 import copy
 import numpy as np
-from typing import Optional
-import torch
-import torch.utils.data as data
 
 from .math_base_funcs import FitFunc
@@ -68,10 +65,9 @@ class DynamicQuadraticFunc(DynamicFunc):
     def __init__(self, params=None):
         super(DynamicQuadraticFunc, self).__init__(3, params)
 
-    def __call__(self, x, timestamp=None):
+    def __call__(self, x):
         self.check_valid()
-        if timestamp is None:
-            timestamp = self._timestamp
+        timestamp = self._timestamp
         a = self._params[0](timestamp)
         b = self._params[1](timestamp)
         c = self._params[2](timestamp)
@@ -80,10 +76,36 @@
         return a * x * x + b * x + c
 
     def __repr__(self):
-        return "{name}({a} * x^2 + {b} * x + {c}, timestamp={timestamp})".format(
+        return "{name}({a} * x^2 + {b} * x + {c})".format(
+            name=self.__class__.__name__,
+            a=self._params[0],
+            b=self._params[1],
+            c=self._params[2],
+        )
+
+
+class DynamicSinQuadraticFunc(DynamicFunc):
+    """The dynamic quadratic function that outputs f(x) = sin(a * x^2 + b * x + c),
+    where a, b, and c are functions of the timestamp.
+    """
+
+    def __init__(self, params=None):
+        super(DynamicSinQuadraticFunc, self).__init__(3, params)
+
+    def __call__(self, x):
+        self.check_valid()
+        timestamp = self._timestamp
+        a = self._params[0](timestamp)
+        b = self._params[1](timestamp)
+        c = self._params[2](timestamp)
+        convert_fn = lambda x: x[-1] if isinstance(x, (tuple, list)) else x
+        a, b, c = convert_fn(a), convert_fn(b), convert_fn(c)
+        return np.sin(a * x * x + b * x + c)
+
+    def __repr__(self):
+        return "{name}(sin({a} * x^2 + {b} * x + {c}))".format(
             name=self.__class__.__name__,
             a=self._params[0],
             b=self._params[1],
             c=self._params[2],
-            timestamp=self._timestamp,
         )
diff --git a/xautodl/datasets/synthetic_core.py b/xautodl/datasets/synthetic_core.py
index 28d4bcc..c5b2da0 100644
--- a/xautodl/datasets/synthetic_core.py
+++ b/xautodl/datasets/synthetic_core.py
@@ -3,7 +3,7 @@ from .synthetic_utils import TimeStamp
 from .synthetic_env import SyntheticDEnv
 from .math_core import LinearFunc
 from .math_core import DynamicLinearFunc
-from .math_core import DynamicQuadraticFunc
+from .math_core import DynamicQuadraticFunc, DynamicSinQuadraticFunc
 from .math_core import (
     ConstantFunc,
     ComposedSinFunc as SinFunc,
@@ -63,9 +63,9 @@ def get_synthetic_env(total_timestamp=1600, num_per_task=1000, mode=None, versio
     time_generator = TimeStamp(
         min_timestamp=0, max_timestamp=max_time, num=total_timestamp, mode=mode
     )
-    oracle_map = DynamicQuadraticFunc(
+    oracle_map = DynamicSinQuadraticFunc(
         params={
-            0: LinearFunc(params={0: 0.1, 1: 0}),  # 0.1 * t
+            0: CosFunc(params={0: 0.5, 1: 1, 2: 1}),  # 0.5 cos(t) + 1
             1: SinFunc(params={0: 1, 1: 1, 2: 0}),  # sin(t)
             2: ConstantFunc(0),
         }
diff --git a/xautodl/datasets/synthetic_env.py b/xautodl/datasets/synthetic_env.py
index 4b94e9d..a434018 100644
--- a/xautodl/datasets/synthetic_env.py
+++ b/xautodl/datasets/synthetic_env.py
@@ -1,6 +1,3 @@
-import math
-import random
-from typing import List, Optional, Dict
 import torch
 import torch.utils.data as data
 
@@ -43,6 +40,18 @@ class SyntheticDEnv(data.Dataset):
         self._oracle_map = oracle_map
         self._num_per_task = num_per_task
         self._noise = noise
+        self._meta_info = dict()
+
+    def set_regression(self):
+        self._meta_info["task"] = "regression"
+
+    def set_classification(self, num_classes):
+        self._meta_info["task"] = "classification"
+        self._meta_info["num_classes"] = int(num_classes)
+
+    @property
+    def meta_info(self):
+        return self._meta_info
 
     @property
     def min_timestamp(self):
@@ -60,13 +69,6 @@
     def mode(self):
         return self._time_generator.mode
 
-    def random_timestamp(self, min_timestamp=None, max_timestamp=None):
-        if min_timestamp is None:
-            min_timestamp = self.min_timestamp
-        if max_timestamp is None:
-            max_timestamp = self.max_timestamp
-        return random.random() * (max_timestamp - min_timestamp) + min_timestamp
-
    def get_timestamp(self, index):
        if index is None:
            timestamps = []
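-- 
Usage sketch (an illustration, not part of the commit): the re-org replaces the
single `forward_raw(timestamps, time_embeds, tembed_only)` entry point of
`MetaModelV1` with two explicit methods. Assuming a constructed `meta_model`
and a 1-D tensor of query `timestamps`, as used in `meta_train_procedure`:

    # step 1: embed a batch of raw timestamps, shape [B] -> [B, embed_dim]
    time_embeds = meta_model.gen_time_embed(timestamps)
    # step 2: run the hyper-network to produce per-timestamp weight containers
    batch_containers, _ = meta_model.gen_model(time_embeds)
    # each container parameterizes the base model for its timestamp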