diff --git a/exps/LFNA/lfna-tall-hpnet.py b/exps/LFNA/lfna-tall-hpnet.py index 0dccefd..7d2dbee 100644 --- a/exps/LFNA/lfna-tall-hpnet.py +++ b/exps/LFNA/lfna-tall-hpnet.py @@ -36,7 +36,7 @@ def main(args): model = get_model(**model_kwargs) criterion = torch.nn.MSELoss() - logger.log("There are {:} weights.".format(model.get_w_container().numel())) + logger.log("There are {:} weights.".format(model.numel())) shape_container = model.get_w_container().to_shape_container() hypernet = HyperNet(shape_container, args.hidden_dim, args.task_dim) diff --git a/exps/LFNA/lfna-test-hpnet.py b/exps/LFNA/lfna-test-hpnet.py index 2b815ed..1f14642 100644 --- a/exps/LFNA/lfna-test-hpnet.py +++ b/exps/LFNA/lfna-test-hpnet.py @@ -1,8 +1,8 @@ ##################################################### # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 # ##################################################### -# python exps/LFNA/lfna-test-hpnet.py --env_version v1 --hidden_dim 16 --layer_dim 32 --epochs 500000 --init_lr 0.01 -# python exps/LFNA/lfna-test-hpnet.py --env_version v1 --hidden_dim 16 --layer_dim 32 --epochs 500000 --init_lr 0.01 --device cuda +# python exps/LFNA/lfna-test-hpnet.py --env_version v1 --hidden_dim 16 --layer_dim 16 --epochs 20000 --init_lr 0.01 +# python exps/LFNA/lfna-test-hpnet.py --env_version v1 --hidden_dim 16 --layer_dim 16 --epochs 10000 --init_lr 0.01 --device cuda ##################################################### import sys, time, copy, torch, random, argparse from tqdm import tqdm @@ -39,7 +39,8 @@ def main(args): criterion = torch.nn.MSELoss() shape_container = model.get_w_container().to_shape_container() - hypernet = HyperNet(shape_container, args.layer_dim, args.task_dim) + total_bar = 100 + hypernet = HyperNet(shape_container, args.layer_dim, args.task_dim, total_bar) hypernet = hypernet.to(args.device) logger.log( @@ -52,14 +53,6 @@ def main(args): time_string(), hypernet.numel() ) ) - # task_embed = torch.nn.Parameter(torch.Tensor(env_info["total"], args.task_dim)) - total_bar = 100 - task_embeds = [] - for i in range(total_bar): - tensor = torch.Tensor(1, args.task_dim).to(args.device) - task_embeds.append(torch.nn.Parameter(tensor)) - for task_embed in task_embeds: - trunc_normal_(task_embed, std=0.02) for i in range(total_bar): env_info["{:}-x".format(i)] = env_info["{:}-x".format(i)].to(args.device) env_info["{:}-y".format(i)] = env_info["{:}-y".format(i)].to(args.device) @@ -67,9 +60,9 @@ def main(args): model.train() hypernet.train() - parameters = list(hypernet.parameters()) + task_embeds - # optimizer = torch.optim.Adam(parameters, lr=args.init_lr, amsgrad=True) - optimizer = torch.optim.Adam(parameters, lr=args.init_lr, weight_decay=1e-5) + optimizer = torch.optim.Adam( + hypernet.parameters(), lr=args.init_lr, weight_decay=1e-5, amsgrad=True + ) lr_scheduler = torch.optim.lr_scheduler.MultiStepLR( optimizer, milestones=[ @@ -97,10 +90,10 @@ def main(args): # for ibatch in range(args.meta_batch): for cur_time in range(total_bar): # cur_time = random.randint(0, total_bar) - cur_task_embed = task_embeds[cur_time] - cur_container = hypernet(cur_task_embed) - cur_x = env_info["{:}-x".format(cur_time)].to(args.device) - cur_y = env_info["{:}-y".format(cur_time)].to(args.device) + # cur_task_embed = task_embeds[cur_time] + cur_container = hypernet(cur_time) + cur_x = env_info["{:}-x".format(cur_time)] + cur_y = env_info["{:}-y".format(cur_time)] cur_dataset = TimeData(cur_time, cur_x, cur_y) preds = model.forward_with_container(cur_dataset.x, cur_container) @@ -126,10 +119,14 @@ def main(args): ) ) + success, best_score = hypernet.save_best(-loss_meter.avg) + if success: + logger.log( + "Achieve the best with best_score = {:.3f}".format(best_score) + ) save_checkpoint( { "hypernet": hypernet.state_dict(), - "task_embed": task_embed, "lr_scheduler": lr_scheduler.state_dict(), "iepoch": iepoch, }, @@ -142,13 +139,15 @@ def main(args): print(model) print(hypernet) + hypernet.load_best() w_container_per_epoch = dict() for idx in range(0, total_bar): future_time = env_info["{:}-timestamp".format(idx)] future_x = env_info["{:}-x".format(idx)] future_y = env_info["{:}-y".format(idx)] - future_container = hypernet(task_embeds[idx]) + # future_container = hypernet(task_embeds[idx]) + future_container = hypernet(idx) w_container_per_epoch[idx] = future_container.no_grad_clone() with torch.no_grad(): future_y_hat = model.forward_with_container( diff --git a/exps/LFNA/lfna_models.py b/exps/LFNA/lfna_models.py index f07ffb0..2133cb2 100644 --- a/exps/LFNA/lfna_models.py +++ b/exps/LFNA/lfna_models.py @@ -15,7 +15,12 @@ class HyperNet(super_core.SuperModule): """The hyper-network.""" def __init__( - self, shape_container, layer_embeding, task_embedding, return_container=True + self, + shape_container, + layer_embeding, + task_embedding, + num_tasks, + return_container=True, ): super(HyperNet, self).__init__() self._shape_container = shape_container @@ -28,36 +33,33 @@ class HyperNet(super_core.SuperModule): "_super_layer_embed", torch.nn.Parameter(torch.Tensor(self._num_layers, layer_embeding)), ) + self.register_parameter( + "_super_task_embed", + torch.nn.Parameter(torch.Tensor(num_tasks, task_embedding)), + ) trunc_normal_(self._super_layer_embed, std=0.02) + trunc_normal_(self._super_task_embed, std=0.02) model_kwargs = dict( config=dict(model_type="dual_norm_mlp"), input_dim=layer_embeding + task_embedding, output_dim=max(self._numel_per_layer), - hidden_dims=[layer_embeding * 2] * 3, + hidden_dims=[(layer_embeding + task_embedding) * 2] * 3, act_cls="gelu", norm_cls="layer_norm_1d", - dropout=0.1, + dropout=0.2, ) self._generator = get_model(**model_kwargs) - """ - model_kwargs = dict( - input_dim=layer_embeding + task_embedding, - output_dim=max(self._numel_per_layer), - hidden_dim=layer_embeding * 4, - act_cls="sigmoid", - norm_cls="identity", - ) - self._generator = get_model(dict(model_type="simple_mlp"), **model_kwargs) - """ self._return_container = return_container print("generator: {:}".format(self._generator)) - def forward_raw(self, task_embed): - # task_embed = F.normalize(task_embed, dim=-1, p=2) - # layer_embed = F.normalize(self._super_layer_embed, dim=-1, p=2) + def forward_raw(self, task_embed_id): layer_embed = self._super_layer_embed - task_embed = task_embed.view(1, -1).expand(self._num_layers, -1) + task_embed = ( + self._super_task_embed[task_embed_id] + .view(1, -1) + .expand(self._num_layers, -1) + ) joint_embed = torch.cat((task_embed, layer_embed), dim=-1) weights = self._generator(joint_embed) diff --git a/lib/xlayers/super_module.py b/lib/xlayers/super_module.py index 09b1e69..9beeb23 100644 --- a/lib/xlayers/super_module.py +++ b/lib/xlayers/super_module.py @@ -2,7 +2,9 @@ # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 # ##################################################### +import os import abc +import tempfile import warnings from typing import Optional, Union, Callable import torch @@ -16,6 +18,9 @@ from .super_utils import LayerOrder, SuperRunMode from .super_utils import TensorContainer from .super_utils import ShapeContainer +BEST_DIR_KEY = "best_model_dir" +BEST_SCORE_KEY = "best_model_score" + class SuperModule(abc.ABC, nn.Module): """This class equips the nn.Module class with the ability to apply AutoDL.""" @@ -25,6 +30,7 @@ class SuperModule(abc.ABC, nn.Module): self._super_run_type = SuperRunMode.Default self._abstract_child = None self._verbose = False + self._meta_info = {} def set_super_run_type(self, super_run_type): def _reset_super_run(m): @@ -84,6 +90,34 @@ class SuperModule(abc.ABC, nn.Module): total += buf.numel() return total + def save_best(self, score): + if BEST_DIR_KEY not in self._meta_info: + tempdir = tempfile.mkdtemp("-xlayers") + self._meta_info[BEST_DIR_KEY] = tempdir + if BEST_SCORE_KEY not in self._meta_info: + self._meta_info[BEST_SCORE_KEY] = None + best_score = self._meta_info[BEST_SCORE_KEY] + if best_score is None or best_score < score: + best_save_path = os.path.join( + self._meta_info[BEST_DIR_KEY], + "best-{:}.pth".format(self.__class__.__name__), + ) + self._meta_info[BEST_SCORE_KEY] = score + torch.save(self.state_dict(), best_save_path) + return True, self._meta_info[BEST_SCORE_KEY] + else: + return False, self._meta_info[BEST_SCORE_KEY] + + def load_best(self): + if BEST_DIR_KEY not in self._meta_info or BEST_SCORE_KEY not in self._meta_info: + raise ValueError("Please call save_best at first") + best_save_path = os.path.join( + self._meta_info[BEST_DIR_KEY], + "best-{:}.pth".format(self.__class__.__name__), + ) + state_dict = torch.load(best_save_path) + self.load_state_dict(state_dict) + @property def abstract_search_space(self): raise NotImplementedError