diff --git a/exps/LFNA/lfna.py b/exps/LFNA/lfna.py index 4a1ecfa..bba6025 100644 --- a/exps/LFNA/lfna.py +++ b/exps/LFNA/lfna.py @@ -3,7 +3,7 @@ ##################################################### # python exps/LFNA/lfna.py --env_version v1 --workers 0 # python exps/LFNA/lfna.py --env_version v1 --device cuda --lr 0.001 -# python exps/LFNA/lfna.py --env_version v1 --device cuda --lr 0.002 +# python exps/LFNA/lfna.py --env_version v1 --device cuda --lr 0.002 --meta_batch 128 ##################################################### import sys, time, copy, torch, random, argparse from tqdm import tqdm @@ -164,7 +164,7 @@ def pretrain_v2(base_model, meta_model, criterion, xenv, args, logger): timestamp = meta_model.meta_timestamps[rand_index] meta_embed = meta_model.super_meta_embed[rand_index] - timestamps, [container], time_embeds = meta_model( + _, [container], time_embed = meta_model( torch.unsqueeze(timestamp, dim=0), None, True ) _, (inputs, targets) = xenv(timestamp.item()) @@ -173,30 +173,12 @@ def pretrain_v2(base_model, meta_model, criterion, xenv, args, logger): predictions = base_model.forward_with_container(inputs, container) total_meta_v1_losses.append(criterion(predictions, targets)) # the matching loss - match_loss = criterion(torch.squeeze(time_embeds, dim=0), meta_embed) + match_loss = criterion(torch.squeeze(time_embed, dim=0), meta_embed) total_match_losses.append(match_loss) # generate models via memory - rand_index = random.randint(0, meta_model.meta_length - xenv.seq_length - 1) - _, [seq_containers], _ = meta_model( - None, - torch.unsqueeze( - meta_model.super_meta_embed[ - rand_index : rand_index + xenv.seq_length - ], - dim=0, - ), - False, - ) - timestamps = meta_model.meta_timestamps[ - rand_index : rand_index + xenv.seq_length - ] - _, (seq_inputs, seq_targets) = xenv.seq_call(timestamps) - seq_inputs, seq_targets = seq_inputs.to(device), seq_targets.to(device) - for container, inputs, targets in zip( - seq_containers, seq_inputs, seq_targets - ): - predictions = base_model.forward_with_container(inputs, container) - total_meta_v2_losses.append(criterion(predictions, targets)) + _, [container], _ = meta_model(None, meta_embed.view(1, 1, -1), True) + predictions = base_model.forward_with_container(inputs, container) + total_meta_v2_losses.append(criterion(predictions, targets)) with torch.no_grad(): meta_std = torch.stack(total_meta_v1_losses).std().item() meta_v1_loss = torch.stack(total_meta_v1_losses).mean() @@ -564,8 +546,9 @@ if __name__ == "__main__": if args.rand_seed is None or args.rand_seed < 0: args.rand_seed = random.randint(1, 100000) assert args.save_dir is not None, "The save dir argument can not be None" - args.save_dir = "{:}-d{:}_{:}_{:}-s{:}-lr{:}-wd{:}-e{:}-env{:}".format( + args.save_dir = "{:}-bs{:}-d{:}_{:}_{:}-s{:}-lr{:}-wd{:}-e{:}-env{:}".format( args.save_dir, + args.meta_batch, args.hidden_dim, args.layer_dim, args.time_dim,