Upgrade lfna debug
This commit is contained in:
		| @@ -1,7 +1,7 @@ | |||||||
| ##################################################### | ##################################################### | ||||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 # | # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 # | ||||||
| ##################################################### | ##################################################### | ||||||
| # python exps/LFNA/lfna-tall-hpnet.py --env_version v1 --hidden_dim 16 | # python exps/LFNA/lfna-tall-hpnet.py --env_version v1 --hidden_dim 16 --epochs 100000 --meta_batch 16 | ||||||
| ##################################################### | ##################################################### | ||||||
| import sys, time, copy, torch, random, argparse | import sys, time, copy, torch, random, argparse | ||||||
| from tqdm import tqdm | from tqdm import tqdm | ||||||
| @@ -42,7 +42,7 @@ def main(args): | |||||||
|     hypernet = HyperNet(shape_container, args.hidden_dim, args.task_dim) |     hypernet = HyperNet(shape_container, args.hidden_dim, args.task_dim) | ||||||
|     total_bar = env_info["total"] - 1 |     total_bar = env_info["total"] - 1 | ||||||
|     task_embeds = [] |     task_embeds = [] | ||||||
|     for i in range(total_bar): |     for i in range(env_info["total"]): | ||||||
|         task_embeds.append(torch.nn.Parameter(torch.Tensor(1, args.task_dim))) |         task_embeds.append(torch.nn.Parameter(torch.Tensor(1, args.task_dim))) | ||||||
|     for task_embed in task_embeds: |     for task_embed in task_embeds: | ||||||
|         trunc_normal_(task_embed, std=0.02) |         trunc_normal_(task_embed, std=0.02) | ||||||
| @@ -109,7 +109,7 @@ def main(args): | |||||||
|             save_checkpoint( |             save_checkpoint( | ||||||
|                 { |                 { | ||||||
|                     "hypernet": hypernet.state_dict(), |                     "hypernet": hypernet.state_dict(), | ||||||
|                     "task_embed": task_embed, |                     "task_embeds": task_embeds, | ||||||
|                     "lr_scheduler": lr_scheduler.state_dict(), |                     "lr_scheduler": lr_scheduler.state_dict(), | ||||||
|                     "iepoch": iepoch, |                     "iepoch": iepoch, | ||||||
|                 }, |                 }, | ||||||
| @@ -122,6 +122,25 @@ def main(args): | |||||||
|  |  | ||||||
|     print(model) |     print(model) | ||||||
|     print(hypernet) |     print(hypernet) | ||||||
|  |     w_container_per_epoch = dict() | ||||||
|  |     for idx in range(0, env_info["total"]): | ||||||
|  |         future_time = env_info["{:}-timestamp".format(idx)] | ||||||
|  |         future_x = env_info["{:}-x".format(idx)] | ||||||
|  |         future_y = env_info["{:}-y".format(idx)] | ||||||
|  |         future_container = hypernet(task_embeds[idx]) | ||||||
|  |         w_container_per_epoch[idx] = future_container.no_grad_clone() | ||||||
|  |         with torch.no_grad(): | ||||||
|  |             future_y_hat = model.forward_with_container( | ||||||
|  |                 future_x, w_container_per_epoch[idx] | ||||||
|  |             ) | ||||||
|  |             future_loss = criterion(future_y_hat, future_y) | ||||||
|  |         logger.log("meta-test: [{:03d}] -> loss={:.4f}".format(idx, future_loss.item())) | ||||||
|  |  | ||||||
|  |     save_checkpoint( | ||||||
|  |         {"w_container_per_epoch": w_container_per_epoch}, | ||||||
|  |         logger.path(None) / "final-ckp.pth", | ||||||
|  |         logger, | ||||||
|  |     ) | ||||||
|  |  | ||||||
|     logger.log("-" * 200 + "\n") |     logger.log("-" * 200 + "\n") | ||||||
|     logger.close() |     logger.close() | ||||||
|   | |||||||
| @@ -34,17 +34,20 @@ def main(args): | |||||||
|     logger, env_info, model_kwargs = lfna_setup(args) |     logger, env_info, model_kwargs = lfna_setup(args) | ||||||
|     dynamic_env = env_info["dynamic_env"] |     dynamic_env = env_info["dynamic_env"] | ||||||
|     model = get_model(dict(model_type="simple_mlp"), **model_kwargs) |     model = get_model(dict(model_type="simple_mlp"), **model_kwargs) | ||||||
|  |     model = model.to(args.device) | ||||||
|     criterion = torch.nn.MSELoss() |     criterion = torch.nn.MSELoss() | ||||||
|  |  | ||||||
|     logger.log("There are {:} weights.".format(model.get_w_container().numel())) |     logger.log("There are {:} weights.".format(model.get_w_container().numel())) | ||||||
|  |  | ||||||
|     shape_container = model.get_w_container().to_shape_container() |     shape_container = model.get_w_container().to_shape_container() | ||||||
|     hypernet = HyperNet(shape_container, args.hidden_dim, args.task_dim) |     hypernet = HyperNet(shape_container, args.hidden_dim, args.task_dim) | ||||||
|  |     hypernet = hypernet.to(args.device) | ||||||
|     # task_embed = torch.nn.Parameter(torch.Tensor(env_info["total"], args.task_dim)) |     # task_embed = torch.nn.Parameter(torch.Tensor(env_info["total"], args.task_dim)) | ||||||
|     total_bar = 10 |     total_bar = 10 | ||||||
|     task_embeds = [] |     task_embeds = [] | ||||||
|     for i in range(total_bar): |     for i in range(total_bar): | ||||||
|         task_embeds.append(torch.nn.Parameter(torch.Tensor(1, args.task_dim))) |         tensor = torch.Tensor(1, args.task_dim).to(args.device) | ||||||
|  |         task_embeds.append(torch.nn.Parameter(tensor)) | ||||||
|     for task_embed in task_embeds: |     for task_embed in task_embeds: | ||||||
|         trunc_normal_(task_embed, std=0.02) |         trunc_normal_(task_embed, std=0.02) | ||||||
|  |  | ||||||
| @@ -79,8 +82,8 @@ def main(args): | |||||||
|             # cur_time = random.randint(0, total_bar) |             # cur_time = random.randint(0, total_bar) | ||||||
|             cur_task_embed = task_embeds[cur_time] |             cur_task_embed = task_embeds[cur_time] | ||||||
|             cur_container = hypernet(cur_task_embed) |             cur_container = hypernet(cur_task_embed) | ||||||
|             cur_x = env_info["{:}-x".format(cur_time)] |             cur_x = env_info["{:}-x".format(cur_time)].to(args.device) | ||||||
|             cur_y = env_info["{:}-y".format(cur_time)] |             cur_y = env_info["{:}-y".format(cur_time)].to(args.device) | ||||||
|             cur_dataset = TimeData(cur_time, cur_x, cur_y) |             cur_dataset = TimeData(cur_time, cur_x, cur_y) | ||||||
|  |  | ||||||
|             preds = model.forward_with_container(cur_dataset.x, cur_container) |             preds = model.forward_with_container(cur_dataset.x, cur_container) | ||||||
| @@ -166,6 +169,12 @@ if __name__ == "__main__": | |||||||
|         default=2000, |         default=2000, | ||||||
|         help="The total number of epochs.", |         help="The total number of epochs.", | ||||||
|     ) |     ) | ||||||
|  |     parser.add_argument( | ||||||
|  |         "--device", | ||||||
|  |         type=str, | ||||||
|  |         default="cpu", | ||||||
|  |         help="", | ||||||
|  |     ) | ||||||
|     # Random Seed |     # Random Seed | ||||||
|     parser.add_argument("--rand_seed", type=int, default=-1, help="manual seed") |     parser.add_argument("--rand_seed", type=int, default=-1, help="manual seed") | ||||||
|     args = parser.parse_args() |     args = parser.parse_args() | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user