diff --git a/exps/LFNA/lfna-test-hpnet.py b/exps/LFNA/lfna-test-hpnet.py index 76993de..6b19462 100644 --- a/exps/LFNA/lfna-test-hpnet.py +++ b/exps/LFNA/lfna-test-hpnet.py @@ -1,7 +1,8 @@ ##################################################### # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 # ##################################################### -# python exps/LFNA/lfna-test-hpnet.py --env_version v1 --hidden_dim 16 --layer_dim 32 --epochs 50000 +# python exps/LFNA/lfna-test-hpnet.py --env_version v1 --hidden_dim 16 --layer_dim 32 --epochs 500000 --init_lr 0.02 +# python exps/LFNA/lfna-test-hpnet.py --env_version v1 --hidden_dim 16 --layer_dim 32 --epochs 500000 --init_lr 0.02 --device cuda ##################################################### import sys, time, copy, torch, random, argparse from tqdm import tqdm @@ -37,19 +38,31 @@ def main(args): model = model.to(args.device) criterion = torch.nn.MSELoss() - logger.log("There are {:} weights.".format(model.get_w_container().numel())) - shape_container = model.get_w_container().to_shape_container() hypernet = HyperNet(shape_container, args.layer_dim, args.task_dim) hypernet = hypernet.to(args.device) + + logger.log( + "{:} There are {:} weights in the base-model.".format( + time_string(), model.numel() + ) + ) + logger.log( + "{:} There are {:} weights in the meta-model.".format( + time_string(), hypernet.numel() + ) + ) # task_embed = torch.nn.Parameter(torch.Tensor(env_info["total"], args.task_dim)) - total_bar = 16 + total_bar = 100 task_embeds = [] for i in range(total_bar): tensor = torch.Tensor(1, args.task_dim).to(args.device) task_embeds.append(torch.nn.Parameter(tensor)) for task_embed in task_embeds: trunc_normal_(task_embed, std=0.02) + for i in range(total_bar): + env_info["{:}-x".format(i)] = env_info["{:}-x".format(i)].to(args.device) + env_info["{:}-y".format(i)] = env_info["{:}-y".format(i)].to(args.device) model.train() hypernet.train() diff --git a/lib/xlayers/super_module.py b/lib/xlayers/super_module.py index c99be99..09b1e69 100644 --- a/lib/xlayers/super_module.py +++ b/lib/xlayers/super_module.py @@ -75,6 +75,15 @@ class SuperModule(abc.ABC, nn.Module): ) print(finalstr) + def numel(self, buffer=True): + total = 0 + for name, param in self.named_parameters(): + total += param.numel() + if buffer: + for name, buf in self.named_buffers(): + total += buf.numel() + return total + @property def abstract_search_space(self): raise NotImplementedError