Re-org GeMOSA codes

@@ -1,117 +0,0 @@
-#####################################################
-# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 #
-#####################################################
-import copy
-import torch
-
-import torch.nn.functional as F
-
-from xlayers import super_core
-from xlayers import trunc_normal_
-from models.xcore import get_model
-
-
-class HyperNet(super_core.SuperModule):
-    """The hyper-network."""
-
-    def __init__(
-        self,
-        shape_container,
-        layer_embeding,
-        task_embedding,
-        num_tasks,
-        return_container=True,
-    ):
-        super(HyperNet, self).__init__()
-        self._shape_container = shape_container
-        self._num_layers = len(shape_container)
-        self._numel_per_layer = []
-        for ilayer in range(self._num_layers):
-            self._numel_per_layer.append(shape_container[ilayer].numel())
-
-        self.register_parameter(
-            "_super_layer_embed",
-            torch.nn.Parameter(torch.Tensor(self._num_layers, layer_embeding)),
-        )
-        self.register_parameter(
-            "_super_task_embed",
-            torch.nn.Parameter(torch.Tensor(num_tasks, task_embedding)),
-        )
-        trunc_normal_(self._super_layer_embed, std=0.02)
-        trunc_normal_(self._super_task_embed, std=0.02)
-
-        model_kwargs = dict(
-            config=dict(model_type="dual_norm_mlp"),
-            input_dim=layer_embeding + task_embedding,
-            output_dim=max(self._numel_per_layer),
-            hidden_dims=[(layer_embeding + task_embedding) * 2] * 3,
-            act_cls="gelu",
-            norm_cls="layer_norm_1d",
-            dropout=0.2,
-        )
-        self._generator = get_model(**model_kwargs)
-        self._return_container = return_container
-        print("generator: {:}".format(self._generator))
-
-    def forward_raw(self, task_embed_id):
-        layer_embed = self._super_layer_embed
-        task_embed = (
-            self._super_task_embed[task_embed_id]
-            .view(1, -1)
-            .expand(self._num_layers, -1)
-        )
-
-        joint_embed = torch.cat((task_embed, layer_embed), dim=-1)
-        weights = self._generator(joint_embed)
-        if self._return_container:
-            weights = torch.split(weights, 1)
-            return self._shape_container.translate(weights)
-        else:
-            return weights
-
-    def forward_candidate(self, input):
-        raise NotImplementedError
-
-    def extra_repr(self) -> str:
-        return "(_super_layer_embed): {:}".format(list(self._super_layer_embed.shape))
-
-
-class HyperNet_VX(super_core.SuperModule):
-    def __init__(self, shape_container, input_embeding, return_container=True):
-        super(HyperNet_VX, self).__init__()
-        self._shape_container = shape_container
-        self._num_layers = len(shape_container)
-        self._numel_per_layer = []
-        for ilayer in range(self._num_layers):
-            self._numel_per_layer.append(shape_container[ilayer].numel())
-
-        self.register_parameter(
-            "_super_layer_embed",
-            torch.nn.Parameter(torch.Tensor(self._num_layers, input_embeding)),
-        )
-        trunc_normal_(self._super_layer_embed, std=0.02)
-
-        model_kwargs = dict(
-            input_dim=input_embeding,
-            output_dim=max(self._numel_per_layer),
-            hidden_dim=input_embeding * 4,
-            act_cls="sigmoid",
-            norm_cls="identity",
-        )
-        self._generator = get_model(dict(model_type="simple_mlp"), **model_kwargs)
-        self._return_container = return_container
-        print("generator: {:}".format(self._generator))
-
-    def forward_raw(self, input):
-        weights = self._generator(self._super_layer_embed)
-        if self._return_container:
-            weights = torch.split(weights, 1)
-            return self._shape_container.translate(weights)
-        else:
-            return weights
-
-    def forward_candidate(self, input):
-        raise NotImplementedError
-
-    def extra_repr(self) -> str:
-        return "(_super_layer_embed): {:}".format(list(self._super_layer_embed.shape))

@@ -35,7 +35,7 @@ from xautodl.models.xcore import get_model
 from xautodl.xlayers import super_core, trunc_normal_
 
 from lfna_utils import lfna_setup, train_model, TimeData
-from lfna_meta_model import MetaModelV1
+from meta_model import MetaModelV1
 
 
 def online_evaluate(env, meta_model, base_model, criterion, args, logger, save=False):
@@ -106,7 +106,7 @@ def meta_train_procedure(base_model, meta_model, criterion, xenv, args, logger):
         )
         optimizer.zero_grad()
 
-        generated_time_embeds = meta_model(meta_model.meta_timestamps, None, True)
+        generated_time_embeds = gen_time_embed(meta_model.meta_timestamps)
 
         batch_indexes = random.choices(total_indexes, k=args.meta_batch)
 
@@ -219,11 +219,11 @@ def main(args):
     w_containers, loss_meter = online_evaluate(
         all_env, meta_model, base_model, criterion, args, logger, True
     )
-    logger.log("In this enviornment, the loss-meter is {:}".format(loss_meter))
+    logger.log("In this enviornment, the total loss-meter is {:}".format(loss_meter))
 
     save_checkpoint(
-        {"w_containers": w_containers},
-        logger.path(None) / "final-ckp.pth",
+        {"all_w_containers": w_containers},
+        logger.path(None) / "final-ckp-{:}.pth".format(args.rand_seed),
         logger,
     )
 
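With the checkpoint rename above, one final checkpoint is kept per run seed instead of a single overwritten file. A minimal sketch of the resulting filename (the seed value 777 is only an example):

    # hypothetical seed for illustration; mirrors the new format string
    rand_seed = 777
    print("final-ckp-{:}.pth".format(rand_seed))  # -> final-ckp-777.pth
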
@@ -154,8 +154,9 @@ class MetaModelV1(super_core.SuperModule):
                     (self._append_meta_embed["fixed"], meta_embed), dim=0
                 )
 
-    def _obtain_time_embed(self, timestamps):
-        # timestamps is a batch of sequence of timestamps
+    def gen_time_embed(self, timestamps):
+        # timestamps is a batch of timestamps
+        [B] = timestamps.shape
         # batch, seq = timestamps.shape
         timestamps = timestamps.view(-1, 1)
         meta_timestamps, meta_embeds = self.meta_timestamps, self.super_meta_embed
@@ -179,15 +180,8 @@ class MetaModelV1(super_core.SuperModule):
         )
         return timestamp_embeds[:, -1, :]
 
-    def forward_raw(self, timestamps, time_embeds, tembed_only=False):
-        if time_embeds is None:
-            [B] = timestamps.shape
-            time_embeds = self._obtain_time_embed(timestamps)
-        else:  # use the hyper-net only
-            time_seq = None
-            B, _ = time_embeds.shape
-        if tembed_only:
-            return time_embeds
+    def gen_model(self, time_embeds):
+        B, _ = time_embeds.shape
         # create joint embed
         num_layer, _ = self._super_layer_embed.shape
         # The shape of `joint_embed` is batch * num-layers * input-dim
@@ -206,6 +200,9 @@ class MetaModelV1(super_core.SuperModule):
             )
         return batch_containers, time_embeds
 
+    def forward_raw(self, timestamps, time_embeds, tembed_only=False):
+        raise NotImplementedError
+
     def forward_candidate(self, input):
         raise NotImplementedError
 
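Taken together, these hunks split MetaModelV1's old three-way forward_raw(timestamps, time_embeds, tembed_only) into two explicit methods. A minimal sketch of the new call sequence, assuming a MetaModelV1 instance `meta_model` and a 1-D tensor of query timestamps:

    # old: time_embeds = meta_model(timestamps, None, True)  # tembed_only path
    # new: two explicit steps
    time_embeds = meta_model.gen_time_embed(timestamps)  # timestamps: shape [B]
    batch_containers, time_embeds = meta_model.gen_model(time_embeds)  # embeds: [B, D]
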
@@ -1,8 +1,9 @@
 #####################################################
 # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.02 #
 ############################################################################
-# python exps/GMOA/vis-synthetic.py --env_version v1                       #
-# python exps/GMOA/vis-synthetic.py --env_version v2                       #
+# python exps/GeMOSA/vis-synthetic.py --env_version v1                     #
+# python exps/GeMOSA/vis-synthetic.py --env_version v2                     #
+# python exps/GeMOSA/vis-synthetic.py --env_version v2                     #
 ############################################################################
 import os, sys, copy, random
 import torch
@@ -181,7 +182,7 @@ def compare_cl(save_dir):
 def visualize_env(save_dir, version):
     save_dir = Path(str(save_dir))
     for substr in ("pdf", "png"):
-        sub_save_dir = save_dir / substr
+        sub_save_dir = save_dir / "{:}-{:}".format(substr, version)
         sub_save_dir.mkdir(parents=True, exist_ok=True)
 
     dynamic_env = get_synthetic_env(version=version)
@@ -190,6 +191,8 @@ def visualize_env(save_dir, version):
         allxs.append(allx)
         allys.append(ally)
     allxs, allys = torch.cat(allxs).view(-1), torch.cat(allys).view(-1)
+    print("x - min={:.3f}, max={:.3f}".format(allxs.min().item(), allxs.max().item()))
+    print("y - min={:.3f}, max={:.3f}".format(allys.min().item(), allys.max().item()))
     for idx, (timestamp, (allx, ally)) in enumerate(tqdm(dynamic_env, ncols=50)):
         dpi, width, height = 30, 1800, 1400
         figsize = width / float(dpi), height / float(dpi)
@@ -210,14 +213,22 @@
         cur_ax.set_ylim(round(allys.min().item(), 1), round(allys.max().item(), 1))
         cur_ax.legend(loc=1, fontsize=LegendFontsize)
 
-        pdf_save_path = save_dir / "pdf" / "v{:}-{:05d}.pdf".format(version, idx)
+        pdf_save_path = (
+            save_dir
+            / "pdf-{:}".format(version)
+            / "v{:}-{:05d}.pdf".format(version, idx)
+        )
         fig.savefig(str(pdf_save_path), dpi=dpi, bbox_inches="tight", format="pdf")
-        png_save_path = save_dir / "png" / "v{:}-{:05d}.png".format(version, idx)
+        png_save_path = (
+            save_dir
+            / "png-{:}".format(version)
+            / "v{:}-{:05d}.png".format(version, idx)
+        )
         fig.savefig(str(png_save_path), dpi=dpi, bbox_inches="tight", format="png")
         plt.close("all")
     save_dir = save_dir.resolve()
     base_cmd = "ffmpeg -y -i {xdir}/v{version}-%05d.png -vf scale=1800:1400 -pix_fmt yuv420p -vb 5000k".format(
-        xdir=save_dir / "png", version=version
+        xdir=save_dir / "png-{:}".format(version), version=version
     )
     print(base_cmd)
     os.system("{:} {xdir}/env-{ver}.mp4".format(base_cmd, xdir=save_dir, ver=version))
@@ -367,7 +378,7 @@ if __name__ == "__main__":
     )
     args = parser.parse_args()
 
-    # visualize_env(os.path.join(args.save_dir, "vis-env"), "v1")
+    visualize_env(os.path.join(args.save_dir, "vis-env"), args.env_version)
     # visualize_env(os.path.join(args.save_dir, "vis-env"), "v2")
-    compare_algs(os.path.join(args.save_dir, "compare-alg"), args.env_version)
+    # compare_algs(os.path.join(args.save_dir, "compare-alg"), args.env_version)
     # compare_cl(os.path.join(args.save_dir, "compare-cl"))
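With the per-version directories introduced above, the frame paths and the ffmpeg input pattern now resolve as follows (save_dir, version, and idx are illustrative values):

    from pathlib import Path

    save_dir, version, idx = Path("outputs/vis-env"), "v1", 3
    png_dir = save_dir / "png-{:}".format(version)  # was: save_dir / "png"
    frame = png_dir / "v{:}-{:05d}.png".format(version, idx)
    print(frame)  # outputs/vis-env/png-v1/v1-00003.png
    # ffmpeg then reads outputs/vis-env/png-v1/v1-%05d.png into env-v1.mp4
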
@@ -4,6 +4,7 @@
 from .math_base_funcs import LinearFunc, QuadraticFunc, CubicFunc, QuarticFunc
 from .math_dynamic_funcs import DynamicLinearFunc
 from .math_dynamic_funcs import DynamicQuadraticFunc
+from .math_dynamic_funcs import DynamicSinQuadraticFunc
 from .math_adv_funcs import ConstantFunc
 from .math_adv_funcs import ComposedSinFunc, ComposedCosFunc
 from .math_dynamic_generator import GaussianDGenerator

@@ -5,9 +5,6 @@ import math
 import abc
 import copy
 import numpy as np
-from typing import Optional
-import torch
-import torch.utils.data as data
 
 from .math_base_funcs import FitFunc
 
@@ -68,10 +65,11 @@ class DynamicQuadraticFunc(DynamicFunc):
     def __init__(self, params=None):
         super(DynamicQuadraticFunc, self).__init__(3, params)
 
-    def __call__(self, x, timestamp=None):
+    def __call__(
+        self,
+        x,
+    ):
         self.check_valid()
-        if timestamp is None:
-            timestamp = self._timestamp
         a = self._params[0](timestamp)
         b = self._params[1](timestamp)
         c = self._params[2](timestamp)
@@ -80,10 +78,38 @@ class DynamicQuadraticFunc(DynamicFunc):
         return a * x * x + b * x + c
 
     def __repr__(self):
-        return "{name}({a} * x^2 + {b} * x + {c}, timestamp={timestamp})".format(
+        return "{name}({a} * x^2 + {b} * x + {c})".format(
+            name=self.__class__.__name__,
+            a=self._params[0],
+            b=self._params[1],
+            c=self._params[2],
+        )
+
+
+class DynamicSinQuadraticFunc(DynamicFunc):
+    """The dynamic quadratic function that outputs f(x) = sin(a * x^2 + b * x + c).
+    The a, b, and c is a function of timestamp.
+    """
+
+    def __init__(self, params=None):
+        super(DynamicSinQuadraticFunc, self).__init__(3, params)
+
+    def __call__(
+        self,
+        x,
+    ):
+        self.check_valid()
+        a = self._params[0](timestamp)
+        b = self._params[1](timestamp)
+        c = self._params[2](timestamp)
+        convert_fn = lambda x: x[-1] if isinstance(x, (tuple, list)) else x
+        a, b, c = convert_fn(a), convert_fn(b), convert_fn(c)
+        return math.sin(a * x * x + b * x + c)
+
+    def __repr__(self):
+        return "{name}({a} * x^2 + {b} * x + {c})".format(
             name=self.__class__.__name__,
             a=self._params[0],
             b=self._params[1],
             c=self._params[2],
-            timestamp=self._timestamp,
         )
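Note that the re-organized __call__ bodies still read a `timestamp` name that is no longer a parameter; the pre-change code fell back to self._timestamp, so presumably that is where the value now comes from. A self-contained sketch of the new function family under that assumption (SinQuadSketch is a hypothetical stand-in, not the library class):

    import math

    class SinQuadSketch:
        def __init__(self, a_fn, b_fn, c_fn, timestamp=0.0):
            self._params = {0: a_fn, 1: b_fn, 2: c_fn}
            self._timestamp = timestamp

        def __call__(self, x):
            # assumed: the timestamp used is the instance's current one
            timestamp = self._timestamp
            a, b, c = (self._params[i](timestamp) for i in range(3))
            # keep only the last element if a param returns a tuple/list
            convert_fn = lambda v: v[-1] if isinstance(v, (tuple, list)) else v
            a, b, c = convert_fn(a), convert_fn(b), convert_fn(c)
            return math.sin(a * x * x + b * x + c)

    f = SinQuadSketch(lambda t: 1.0, lambda t: 0.0, lambda t: 0.0)
    print(f(1.0))  # sin(1) ~= 0.841
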
@@ -3,7 +3,7 @@ from .synthetic_utils import TimeStamp
 from .synthetic_env import SyntheticDEnv
 from .math_core import LinearFunc
 from .math_core import DynamicLinearFunc
-from .math_core import DynamicQuadraticFunc
+from .math_core import DynamicQuadraticFunc, DynamicSinQuadraticFunc
 from .math_core import (
     ConstantFunc,
     ComposedSinFunc as SinFunc,
@@ -63,9 +63,9 @@ def get_synthetic_env(total_timestamp=1600, num_per_task=1000, mode=None, versio
         time_generator = TimeStamp(
             min_timestamp=0, max_timestamp=max_time, num=total_timestamp, mode=mode
         )
-        oracle_map = DynamicQuadraticFunc(
+        oracle_map = DynamicSinQuadraticFunc(
             params={
-                0: LinearFunc(params={0: 0.1, 1: 0}),  # 0.1 * t
+                0: CosFunc(params={0: 0.5, 1: 1, 2: 1}),  # 0.5 cos(t) + 1
                 1: SinFunc(params={0: 1, 1: 1, 2: 0}),  # sin(t)
                 2: ConstantFunc(0),
             }
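Per the inline comments, the v1 oracle map therefore becomes f(x, t) = sin((0.5 * cos(t) + 1) * x^2 + sin(t) * x + 0). A plain-Python sanity check of that formula:

    import math

    def oracle(x, t):
        a = 0.5 * math.cos(t) + 1  # CosFunc(params={0: 0.5, 1: 1, 2: 1})
        b = math.sin(t)            # SinFunc(params={0: 1, 1: 1, 2: 0})
        c = 0.0                    # ConstantFunc(0)
        return math.sin(a * x * x + b * x + c)

    print(oracle(1.0, 0.0))  # sin(1.5) ~= 0.997
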
@@ -1,6 +1,3 @@
-import math
-import random
-from typing import List, Optional, Dict
 import torch
 import torch.utils.data as data
 
@@ -43,6 +40,18 @@ class SyntheticDEnv(data.Dataset):
         self._oracle_map = oracle_map
         self._num_per_task = num_per_task
         self._noise = noise
+        self._meta_info = dict()
+
+    def set_regression(self):
+        self._meta_info["task"] = "regression"
+
+    def set_classification(self, num_classes):
+        self._meta_info["task"] = "classification"
+        self._meta_info["num_classes"] = int(num_classes)
+
+    @property
+    def meta_info(self):
+        return self._meta_info
 
     @property
     def min_timestamp(self):
@@ -60,13 +69,6 @@ class SyntheticDEnv(data.Dataset):
     def mode(self):
         return self._time_generator.mode
 
-    def random_timestamp(self, min_timestamp=None, max_timestamp=None):
-        if min_timestamp is None:
-            min_timestamp = self.min_timestamp
-        if max_timestamp is None:
-            max_timestamp = self.max_timestamp
-        return random.random() * (max_timestamp - min_timestamp) + min_timestamp
-
     def get_timestamp(self, index):
         if index is None:
             timestamps = []
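The meta_info dictionary added to SyntheticDEnv gives downstream code a uniform way to query the task type. A usage sketch (constructing `env` elsewhere is assumed):

    env.set_classification(num_classes=10)
    print(env.meta_info)  # {'task': 'classification', 'num_classes': 10}

    env.set_regression()
    print(env.meta_info["task"])  # 'regression'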