Update LFNA test
This commit is contained in:
		| @@ -1,7 +1,8 @@ | |||||||
| ##################################################### | ##################################################### | ||||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 # | # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 # | ||||||
| ##################################################### | ##################################################### | ||||||
| # python exps/LFNA/lfna-test-hpnet.py --env_version v1 --hidden_dim 16 --layer_dim 32 --epochs 50000 | # python exps/LFNA/lfna-test-hpnet.py --env_version v1 --hidden_dim 16 --layer_dim 32 --epochs 500000 --init_lr 0.02 | ||||||
|  | # python exps/LFNA/lfna-test-hpnet.py --env_version v1 --hidden_dim 16 --layer_dim 32 --epochs 500000 --init_lr 0.02 --device cuda | ||||||
| ##################################################### | ##################################################### | ||||||
| import sys, time, copy, torch, random, argparse | import sys, time, copy, torch, random, argparse | ||||||
| from tqdm import tqdm | from tqdm import tqdm | ||||||
| @@ -37,19 +38,31 @@ def main(args): | |||||||
|     model = model.to(args.device) |     model = model.to(args.device) | ||||||
|     criterion = torch.nn.MSELoss() |     criterion = torch.nn.MSELoss() | ||||||
|  |  | ||||||
|     logger.log("There are {:} weights.".format(model.get_w_container().numel())) |  | ||||||
|  |  | ||||||
|     shape_container = model.get_w_container().to_shape_container() |     shape_container = model.get_w_container().to_shape_container() | ||||||
|     hypernet = HyperNet(shape_container, args.layer_dim, args.task_dim) |     hypernet = HyperNet(shape_container, args.layer_dim, args.task_dim) | ||||||
|     hypernet = hypernet.to(args.device) |     hypernet = hypernet.to(args.device) | ||||||
|  |  | ||||||
|  |     logger.log( | ||||||
|  |         "{:} There are {:} weights in the base-model.".format( | ||||||
|  |             time_string(), model.numel() | ||||||
|  |         ) | ||||||
|  |     ) | ||||||
|  |     logger.log( | ||||||
|  |         "{:} There are {:} weights in the meta-model.".format( | ||||||
|  |             time_string(), hypernet.numel() | ||||||
|  |         ) | ||||||
|  |     ) | ||||||
|     # task_embed = torch.nn.Parameter(torch.Tensor(env_info["total"], args.task_dim)) |     # task_embed = torch.nn.Parameter(torch.Tensor(env_info["total"], args.task_dim)) | ||||||
|     total_bar = 16 |     total_bar = 100 | ||||||
|     task_embeds = [] |     task_embeds = [] | ||||||
|     for i in range(total_bar): |     for i in range(total_bar): | ||||||
|         tensor = torch.Tensor(1, args.task_dim).to(args.device) |         tensor = torch.Tensor(1, args.task_dim).to(args.device) | ||||||
|         task_embeds.append(torch.nn.Parameter(tensor)) |         task_embeds.append(torch.nn.Parameter(tensor)) | ||||||
|     for task_embed in task_embeds: |     for task_embed in task_embeds: | ||||||
|         trunc_normal_(task_embed, std=0.02) |         trunc_normal_(task_embed, std=0.02) | ||||||
|  |     for i in range(total_bar): | ||||||
|  |         env_info["{:}-x".format(i)] = env_info["{:}-x".format(i)].to(args.device) | ||||||
|  |         env_info["{:}-y".format(i)] = env_info["{:}-y".format(i)].to(args.device) | ||||||
|  |  | ||||||
|     model.train() |     model.train() | ||||||
|     hypernet.train() |     hypernet.train() | ||||||
|   | |||||||
| @@ -75,6 +75,15 @@ class SuperModule(abc.ABC, nn.Module): | |||||||
|                 ) |                 ) | ||||||
|                 print(finalstr) |                 print(finalstr) | ||||||
|  |  | ||||||
|  |     def numel(self, buffer=True): | ||||||
|  |         total = 0 | ||||||
|  |         for name, param in self.named_parameters(): | ||||||
|  |             total += param.numel() | ||||||
|  |         if buffer: | ||||||
|  |             for name, buf in self.named_buffers(): | ||||||
|  |                 total += buf.numel() | ||||||
|  |         return total | ||||||
|  |  | ||||||
|     @property |     @property | ||||||
|     def abstract_search_space(self): |     def abstract_search_space(self): | ||||||
|         raise NotImplementedError |         raise NotImplementedError | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user