Update LFNA test

This commit is contained in:
D-X-Y 2021-05-13 03:40:04 +00:00
parent 0b1ca45c44
commit ee5d8a8e21
2 changed files with 26 additions and 4 deletions

View File

@ -1,7 +1,8 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 #
#####################################################
# python exps/LFNA/lfna-test-hpnet.py --env_version v1 --hidden_dim 16 --layer_dim 32 --epochs 50000
# python exps/LFNA/lfna-test-hpnet.py --env_version v1 --hidden_dim 16 --layer_dim 32 --epochs 500000 --init_lr 0.02
# python exps/LFNA/lfna-test-hpnet.py --env_version v1 --hidden_dim 16 --layer_dim 32 --epochs 500000 --init_lr 0.02 --device cuda
#####################################################
import sys, time, copy, torch, random, argparse
from tqdm import tqdm
@ -37,19 +38,31 @@ def main(args):
model = model.to(args.device)
criterion = torch.nn.MSELoss()
logger.log("There are {:} weights.".format(model.get_w_container().numel()))
shape_container = model.get_w_container().to_shape_container()
hypernet = HyperNet(shape_container, args.layer_dim, args.task_dim)
hypernet = hypernet.to(args.device)
logger.log(
"{:} There are {:} weights in the base-model.".format(
time_string(), model.numel()
)
)
logger.log(
"{:} There are {:} weights in the meta-model.".format(
time_string(), hypernet.numel()
)
)
# task_embed = torch.nn.Parameter(torch.Tensor(env_info["total"], args.task_dim))
total_bar = 16
total_bar = 100
task_embeds = []
for i in range(total_bar):
tensor = torch.Tensor(1, args.task_dim).to(args.device)
task_embeds.append(torch.nn.Parameter(tensor))
for task_embed in task_embeds:
trunc_normal_(task_embed, std=0.02)
for i in range(total_bar):
env_info["{:}-x".format(i)] = env_info["{:}-x".format(i)].to(args.device)
env_info["{:}-y".format(i)] = env_info["{:}-y".format(i)].to(args.device)
model.train()
hypernet.train()

View File

@ -75,6 +75,15 @@ class SuperModule(abc.ABC, nn.Module):
)
print(finalstr)
def numel(self, buffer=True):
    """Count the scalar elements held by this module.

    Args:
        buffer: when true (the default), the elements of everything
            yielded by ``named_buffers()`` are added to the count.

    Returns:
        The summed ``numel()`` of every entry yielded by
        ``named_parameters()``, plus the buffer elements when
        ``buffer`` is true.
    """
    count = sum(weight.numel() for _, weight in self.named_parameters())
    if not buffer:
        # Parameters only: skip the buffer walk entirely.
        return count
    return count + sum(tensor.numel() for _, tensor in self.named_buffers())
@property
def abstract_search_space(self):
    """Placeholder property: accessing it always raises.

    Raises:
        NotImplementedError: unconditionally, in this implementation.
    """
    raise NotImplementedError