Try a different model / LFNA v2
This commit is contained in:
		| @@ -3,7 +3,7 @@ | |||||||
| ##################################################### | ##################################################### | ||||||
| # python exps/LFNA/lfna.py --env_version v1 --workers 0 | # python exps/LFNA/lfna.py --env_version v1 --workers 0 | ||||||
| # python exps/LFNA/lfna.py --env_version v1 --device cuda --lr 0.001 | # python exps/LFNA/lfna.py --env_version v1 --device cuda --lr 0.001 | ||||||
| # python exps/LFNA/lfna.py --env_version v1 --device cuda --lr 0.002 | # python exps/LFNA/lfna.py --env_version v1 --device cuda --lr 0.002 --meta_batch 128 | ||||||
| ##################################################### | ##################################################### | ||||||
| import sys, time, copy, torch, random, argparse | import sys, time, copy, torch, random, argparse | ||||||
| from tqdm import tqdm | from tqdm import tqdm | ||||||
| @@ -164,7 +164,7 @@ def pretrain_v2(base_model, meta_model, criterion, xenv, args, logger): | |||||||
|             timestamp = meta_model.meta_timestamps[rand_index] |             timestamp = meta_model.meta_timestamps[rand_index] | ||||||
|             meta_embed = meta_model.super_meta_embed[rand_index] |             meta_embed = meta_model.super_meta_embed[rand_index] | ||||||
|  |  | ||||||
|             timestamps, [container], time_embeds = meta_model( |             _, [container], time_embed = meta_model( | ||||||
|                 torch.unsqueeze(timestamp, dim=0), None, True |                 torch.unsqueeze(timestamp, dim=0), None, True | ||||||
|             ) |             ) | ||||||
|             _, (inputs, targets) = xenv(timestamp.item()) |             _, (inputs, targets) = xenv(timestamp.item()) | ||||||
| @@ -173,30 +173,12 @@ def pretrain_v2(base_model, meta_model, criterion, xenv, args, logger): | |||||||
|             predictions = base_model.forward_with_container(inputs, container) |             predictions = base_model.forward_with_container(inputs, container) | ||||||
|             total_meta_v1_losses.append(criterion(predictions, targets)) |             total_meta_v1_losses.append(criterion(predictions, targets)) | ||||||
|             # the matching loss |             # the matching loss | ||||||
|             match_loss = criterion(torch.squeeze(time_embeds, dim=0), meta_embed) |             match_loss = criterion(torch.squeeze(time_embed, dim=0), meta_embed) | ||||||
|             total_match_losses.append(match_loss) |             total_match_losses.append(match_loss) | ||||||
|             # generate models via memory |             # generate models via memory | ||||||
|             rand_index = random.randint(0, meta_model.meta_length - xenv.seq_length - 1) |             _, [container], _ = meta_model(None, meta_embed.view(1, 1, -1), True) | ||||||
|             _, [seq_containers], _ = meta_model( |             predictions = base_model.forward_with_container(inputs, container) | ||||||
|                 None, |             total_meta_v2_losses.append(criterion(predictions, targets)) | ||||||
|                 torch.unsqueeze( |  | ||||||
|                     meta_model.super_meta_embed[ |  | ||||||
|                         rand_index : rand_index + xenv.seq_length |  | ||||||
|                     ], |  | ||||||
|                     dim=0, |  | ||||||
|                 ), |  | ||||||
|                 False, |  | ||||||
|             ) |  | ||||||
|             timestamps = meta_model.meta_timestamps[ |  | ||||||
|                 rand_index : rand_index + xenv.seq_length |  | ||||||
|             ] |  | ||||||
|             _, (seq_inputs, seq_targets) = xenv.seq_call(timestamps) |  | ||||||
|             seq_inputs, seq_targets = seq_inputs.to(device), seq_targets.to(device) |  | ||||||
|             for container, inputs, targets in zip( |  | ||||||
|                 seq_containers, seq_inputs, seq_targets |  | ||||||
|             ): |  | ||||||
|                 predictions = base_model.forward_with_container(inputs, container) |  | ||||||
|                 total_meta_v2_losses.append(criterion(predictions, targets)) |  | ||||||
|         with torch.no_grad(): |         with torch.no_grad(): | ||||||
|             meta_std = torch.stack(total_meta_v1_losses).std().item() |             meta_std = torch.stack(total_meta_v1_losses).std().item() | ||||||
|         meta_v1_loss = torch.stack(total_meta_v1_losses).mean() |         meta_v1_loss = torch.stack(total_meta_v1_losses).mean() | ||||||
| @@ -564,8 +546,9 @@ if __name__ == "__main__": | |||||||
|     if args.rand_seed is None or args.rand_seed < 0: |     if args.rand_seed is None or args.rand_seed < 0: | ||||||
|         args.rand_seed = random.randint(1, 100000) |         args.rand_seed = random.randint(1, 100000) | ||||||
|     assert args.save_dir is not None, "The save dir argument can not be None" |     assert args.save_dir is not None, "The save dir argument can not be None" | ||||||
|     args.save_dir = "{:}-d{:}_{:}_{:}-s{:}-lr{:}-wd{:}-e{:}-env{:}".format( |     args.save_dir = "{:}-bs{:}-d{:}_{:}_{:}-s{:}-lr{:}-wd{:}-e{:}-env{:}".format( | ||||||
|         args.save_dir, |         args.save_dir, | ||||||
|  |         args.meta_batch, | ||||||
|         args.hidden_dim, |         args.hidden_dim, | ||||||
|         args.layer_dim, |         args.layer_dim, | ||||||
|         args.time_dim, |         args.time_dim, | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user