diff --git a/exps/LFNA/lfna-tall-hpnet.py b/exps/LFNA/lfna-tall-hpnet.py
index f05821c..0dccefd 100644
--- a/exps/LFNA/lfna-tall-hpnet.py
+++ b/exps/LFNA/lfna-tall-hpnet.py
@@ -1,7 +1,7 @@
 #####################################################
 # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 #
 #####################################################
-# python exps/LFNA/lfna-tall-hpnet.py --env_version v1 --hidden_dim 16 --epochs 100000 --meta_batch 16
+# python exps/LFNA/lfna-tall-hpnet.py --env_version v1 --hidden_dim 16 --epochs 100000 --meta_batch 64
 #####################################################
 import sys, time, copy, torch, random, argparse
 from tqdm import tqdm
@@ -33,7 +33,7 @@ from lfna_models import HyperNet
 def main(args):
     logger, env_info, model_kwargs = lfna_setup(args)
     dynamic_env = env_info["dynamic_env"]
-    model = get_model(dict(model_type="simple_mlp"), **model_kwargs)
+    model = get_model(**model_kwargs)

     criterion = torch.nn.MSELoss()
     logger.log("There are {:} weights.".format(model.get_w_container().numel()))
@@ -72,7 +72,7 @@
         )

         limit_bar = float(iepoch + 1) / args.epochs * total_bar
-        limit_bar = min(max(0, int(limit_bar)), total_bar)
+        limit_bar = min(max(32, int(limit_bar)), total_bar)
         losses = []
         for ibatch in range(args.meta_batch):
             cur_time = random.randint(0, limit_bar)
diff --git a/exps/LFNA/lfna-test-hpnet.py b/exps/LFNA/lfna-test-hpnet.py
index b4aa9c1..76993de 100644
--- a/exps/LFNA/lfna-test-hpnet.py
+++ b/exps/LFNA/lfna-test-hpnet.py
@@ -1,7 +1,7 @@
 #####################################################
 # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 #
 #####################################################
-# python exps/LFNA/lfna-test-hpnet.py --env_version v1 --hidden_dim 16
+# python exps/LFNA/lfna-test-hpnet.py --env_version v1 --hidden_dim 16 --layer_dim 32 --epochs 50000
 #####################################################
 import sys, time, copy, torch, random, argparse
 from tqdm import tqdm
@@ -33,17 +33,17 @@ from lfna_models import HyperNet
 def main(args):
     logger, env_info, model_kwargs = lfna_setup(args)
     dynamic_env = env_info["dynamic_env"]
-    model = get_model(dict(model_type="simple_mlp"), **model_kwargs)
+    model = get_model(**model_kwargs)
     model = model.to(args.device)

     criterion = torch.nn.MSELoss()
     logger.log("There are {:} weights.".format(model.get_w_container().numel()))

     shape_container = model.get_w_container().to_shape_container()
-    hypernet = HyperNet(shape_container, args.hidden_dim, args.task_dim)
+    hypernet = HyperNet(shape_container, args.layer_dim, args.task_dim)
     hypernet = hypernet.to(args.device)
     # task_embed = torch.nn.Parameter(torch.Tensor(env_info["total"], args.task_dim))
-    total_bar = 10
+    total_bar = 16
     task_embeds = []
     for i in range(total_bar):
         tensor = torch.Tensor(1, args.task_dim).to(args.device)
@@ -51,8 +51,12 @@
     for task_embed in task_embeds:
         trunc_normal_(task_embed, std=0.02)

+    model.train()
+    hypernet.train()
+
     parameters = list(hypernet.parameters()) + task_embeds
-    optimizer = torch.optim.Adam(parameters, lr=args.init_lr, amsgrad=True)
+    # optimizer = torch.optim.Adam(parameters, lr=args.init_lr, amsgrad=True)
+    optimizer = torch.optim.Adam(parameters, lr=args.init_lr, weight_decay=1e-5)
     lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
         optimizer,
         milestones=[
@@ -98,7 +102,7 @@
         lr_scheduler.step()

         loss_meter.update(final_loss.item())
-        if iepoch % 200 == 0:
+        if iepoch % 100 == 0:
             logger.log(
                 head_str
                 + " meta-loss: {:.4f} ({:.4f}) :: lr={:.5f}, batch={:}".format(
@@ -126,6 +130,26 @@
     print(model)
     print(hypernet)

+    w_container_per_epoch = dict()
+    for idx in range(0, total_bar):
+        future_time = env_info["{:}-timestamp".format(idx)]
+        future_x = env_info["{:}-x".format(idx)]
+        future_y = env_info["{:}-y".format(idx)]
+        future_container = hypernet(task_embeds[idx])
+        w_container_per_epoch[idx] = future_container.no_grad_clone()
+        with torch.no_grad():
+            future_y_hat = model.forward_with_container(
+                future_x, w_container_per_epoch[idx]
+            )
+            future_loss = criterion(future_y_hat, future_y)
+        logger.log("meta-test: [{:03d}] -> loss={:.4f}".format(idx, future_loss.item()))
+
+    save_checkpoint(
+        {"w_container_per_epoch": w_container_per_epoch},
+        logger.path(None) / "final-ckp.pth",
+        logger,
+    )
+
     logger.log("-" * 200 + "\n")

     logger.close()
@@ -150,6 +174,12 @@ if __name__ == "__main__":
         required=True,
         help="The hidden dimension.",
     )
+    parser.add_argument(
+        "--layer_dim",
+        type=int,
+        required=True,
+        help="The layer embedding dimension of the hypernet.",
+    )
     #####
     parser.add_argument(
         "--init_lr",
@@ -181,7 +211,7 @@
     if args.rand_seed is None or args.rand_seed < 0:
         args.rand_seed = random.randint(1, 100000)
     assert args.save_dir is not None, "The save dir argument can not be None"
-    args.task_dim = args.hidden_dim
+    args.task_dim = args.layer_dim
     args.save_dir = "{:}-{:}-d{:}".format(
         args.save_dir, args.env_version, args.hidden_dim
     )
diff --git a/exps/LFNA/lfna-ttss-hpnet.py b/exps/LFNA/lfna-ttss-hpnet.py
index 1f3bbde..a3e85a7 100644
--- a/exps/LFNA/lfna-ttss-hpnet.py
+++ b/exps/LFNA/lfna-ttss-hpnet.py
@@ -31,7 +31,7 @@ from lfna_models import HyperNet_VX as HyperNet
 def main(args):
     logger, env_info, model_kwargs = lfna_setup(args)
     dynamic_env = env_info["dynamic_env"]
-    model = get_model(dict(model_type="simple_mlp"), **model_kwargs)
+    model = get_model(**model_kwargs)

     total_time = env_info["total"]
     for i in range(total_time):
diff --git a/exps/LFNA/lfna_models.py b/exps/LFNA/lfna_models.py
index 063d5b6..b4dbcbc 100644
--- a/exps/LFNA/lfna_models.py
+++ b/exps/LFNA/lfna_models.py
@@ -4,6 +4,8 @@
 import copy
 import torch
+import torch.nn.functional as F
+
 from xlayers import super_core
 from xlayers import trunc_normal_
 from models.xcore import get_model
@@ -29,13 +31,15 @@ class HyperNet(super_core.SuperModule):
         trunc_normal_(self._super_layer_embed, std=0.02)

         model_kwargs = dict(
+            config=dict(model_type="dual_norm_mlp"),
             input_dim=layer_embeding + task_embedding,
             output_dim=max(self._numel_per_layer),
-            hidden_dims=[layer_embeding * 4] * 4,
+            hidden_dims=[layer_embeding * 4] * 3,
             act_cls="gelu",
             norm_cls="layer_norm_1d",
+            dropout=0.1,
         )
-        self._generator = get_model(dict(model_type="norm_mlp"), **model_kwargs)
+        self._generator = get_model(**model_kwargs)
         """
         model_kwargs = dict(
             input_dim=layer_embeding + task_embedding,
@@ -50,8 +54,12 @@ class HyperNet(super_core.SuperModule):
         print("generator: {:}".format(self._generator))

     def forward_raw(self, task_embed):
+        # task_embed = F.normalize(task_embed, dim=-1, p=2)
+        # layer_embed = F.normalize(self._super_layer_embed, dim=-1, p=2)
+        layer_embed = self._super_layer_embed
         task_embed = task_embed.view(1, -1).expand(self._num_layers, -1)
-        joint_embed = torch.cat((task_embed, self._super_layer_embed), dim=-1)
+
+        joint_embed = torch.cat((task_embed, layer_embed), dim=-1)
         weights = self._generator(joint_embed)
         if self._return_container:
             weights = torch.split(weights, 1)
diff --git a/lib/models/xcore.py b/lib/models/xcore.py
index 143278c..819f272 100644
--- a/lib/models/xcore.py
+++ b/lib/models/xcore.py
@@ -11,6 +11,7 @@ __all__ = ["get_model"]

 from xlayers.super_core import SuperSequential
 from xlayers.super_core import SuperLinear
+from xlayers.super_core import SuperDropout
 from xlayers.super_core import super_name2norm
 from xlayers.super_core import super_name2activation
@@ -47,7 +48,20 @@ def get_model(config: Dict[Text, Any], **kwargs):
             last_dim = hidden_dim
         sub_layers.append(SuperLinear(last_dim, kwargs["output_dim"]))
         model = SuperSequential(*sub_layers)
-
+    elif model_type == "dual_norm_mlp":
+        act_cls = super_name2activation[kwargs["act_cls"]]
+        norm_cls = super_name2norm[kwargs["norm_cls"]]
+        sub_layers, last_dim = [], kwargs["input_dim"]
+        for i, hidden_dim in enumerate(kwargs["hidden_dims"]):
+            if i > 0:
+                sub_layers.append(norm_cls(last_dim, elementwise_affine=False))
+            sub_layers.append(SuperLinear(last_dim, hidden_dim))
+            sub_layers.append(SuperDropout(kwargs["dropout"]))
+            sub_layers.append(SuperLinear(hidden_dim, hidden_dim))
+            sub_layers.append(act_cls())
+            last_dim = hidden_dim
+        sub_layers.append(SuperLinear(last_dim, kwargs["output_dim"]))
+        model = SuperSequential(*sub_layers)
     else:
         raise TypeError("Unkonwn model type: {:}".format(model_type))
     return model
diff --git a/lib/xlayers/super_core.py b/lib/xlayers/super_core.py
index a6564d4..03aa6c0 100644
--- a/lib/xlayers/super_core.py
+++ b/lib/xlayers/super_core.py
@@ -14,6 +14,7 @@ from .super_norm import SuperSimpleNorm
 from .super_norm import SuperLayerNorm1D
 from .super_norm import SuperSimpleLearnableNorm
 from .super_norm import SuperIdentity
+from .super_dropout import SuperDropout

 super_name2norm = {
     "simple_norm": SuperSimpleNorm,
diff --git a/lib/xlayers/super_dropout.py b/lib/xlayers/super_dropout.py
new file mode 100644
index 0000000..124f2db
--- /dev/null
+++ b/lib/xlayers/super_dropout.py
@@ -0,0 +1,40 @@
+#####################################################
+# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 #
+#####################################################
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+import math
+from typing import Optional, Callable
+
+import spaces
+from .super_module import SuperModule
+from .super_module import IntSpaceType
+from .super_module import BoolSpaceType
+
+
+class SuperDropout(SuperModule):
+    """Applies the dropout function element-wise."""
+
+    def __init__(self, p: float = 0.5, inplace: bool = False) -> None:
+        super(SuperDropout, self).__init__()
+        self._p = p
+        self._inplace = inplace
+
+    @property
+    def abstract_search_space(self):
+        return spaces.VirtualNode(id(self))
+
+    def forward_candidate(self, input: torch.Tensor) -> torch.Tensor:
+        return self.forward_raw(input)
+
+    def forward_raw(self, input: torch.Tensor) -> torch.Tensor:
+        return F.dropout(input, self._p, self.training, self._inplace)
+
+    def forward_with_container(self, input, container, prefix=[]):
+        return self.forward_raw(input)
+
+    def extra_repr(self) -> str:
+        xstr = "inplace=True" if self._inplace else ""
+        return "p={:}".format(self._p) + ", " + xstr
diff --git a/lib/xlayers/super_norm.py b/lib/xlayers/super_norm.py
index b745e9d..00e6530 100644
--- a/lib/xlayers/super_norm.py
+++ b/lib/xlayers/super_norm.py
@@ -74,6 +74,19 @@ class SuperLayerNorm1D(SuperModule):
     def forward_raw(self, input: torch.Tensor) -> torch.Tensor:
         return F.layer_norm(input, (self.in_dim,), self.weight, self.bias, self.eps)

+    def forward_with_container(self, input, container, prefix=[]):
+        super_weight_name = ".".join(prefix + ["weight"])
+        if container.has(super_weight_name):
+            weight = container.query(super_weight_name)
+        else:
+            weight = None
+        super_bias_name = ".".join(prefix + ["bias"])
+        if container.has(super_bias_name):
+            bias = container.query(super_bias_name)
+        else:
+            bias = None
+        return F.layer_norm(input, (self.in_dim,), weight, bias, self.eps)
+
     def extra_repr(self) -> str:
         return (
             "shape={in_dim}, eps={eps}, elementwise_affine={elementwise_affine}".format(