Try a different model / LFNA v2
This commit is contained in:
parent
9135667cc1
commit
be274e0b6c
@ -3,7 +3,7 @@
|
|||||||
#####################################################
|
#####################################################
|
||||||
# python exps/LFNA/lfna.py --env_version v1 --workers 0
|
# python exps/LFNA/lfna.py --env_version v1 --workers 0
|
||||||
# python exps/LFNA/lfna.py --env_version v1 --device cuda --lr 0.001
|
# python exps/LFNA/lfna.py --env_version v1 --device cuda --lr 0.001
|
||||||
# python exps/LFNA/lfna.py --env_version v1 --device cuda --lr 0.002
|
# python exps/LFNA/lfna.py --env_version v1 --device cuda --lr 0.002 --meta_batch 128
|
||||||
#####################################################
|
#####################################################
|
||||||
import sys, time, copy, torch, random, argparse
|
import sys, time, copy, torch, random, argparse
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
@ -164,7 +164,7 @@ def pretrain_v2(base_model, meta_model, criterion, xenv, args, logger):
|
|||||||
timestamp = meta_model.meta_timestamps[rand_index]
|
timestamp = meta_model.meta_timestamps[rand_index]
|
||||||
meta_embed = meta_model.super_meta_embed[rand_index]
|
meta_embed = meta_model.super_meta_embed[rand_index]
|
||||||
|
|
||||||
timestamps, [container], time_embeds = meta_model(
|
_, [container], time_embed = meta_model(
|
||||||
torch.unsqueeze(timestamp, dim=0), None, True
|
torch.unsqueeze(timestamp, dim=0), None, True
|
||||||
)
|
)
|
||||||
_, (inputs, targets) = xenv(timestamp.item())
|
_, (inputs, targets) = xenv(timestamp.item())
|
||||||
@ -173,30 +173,12 @@ def pretrain_v2(base_model, meta_model, criterion, xenv, args, logger):
|
|||||||
predictions = base_model.forward_with_container(inputs, container)
|
predictions = base_model.forward_with_container(inputs, container)
|
||||||
total_meta_v1_losses.append(criterion(predictions, targets))
|
total_meta_v1_losses.append(criterion(predictions, targets))
|
||||||
# the matching loss
|
# the matching loss
|
||||||
match_loss = criterion(torch.squeeze(time_embeds, dim=0), meta_embed)
|
match_loss = criterion(torch.squeeze(time_embed, dim=0), meta_embed)
|
||||||
total_match_losses.append(match_loss)
|
total_match_losses.append(match_loss)
|
||||||
# generate models via memory
|
# generate models via memory
|
||||||
rand_index = random.randint(0, meta_model.meta_length - xenv.seq_length - 1)
|
_, [container], _ = meta_model(None, meta_embed.view(1, 1, -1), True)
|
||||||
_, [seq_containers], _ = meta_model(
|
predictions = base_model.forward_with_container(inputs, container)
|
||||||
None,
|
total_meta_v2_losses.append(criterion(predictions, targets))
|
||||||
torch.unsqueeze(
|
|
||||||
meta_model.super_meta_embed[
|
|
||||||
rand_index : rand_index + xenv.seq_length
|
|
||||||
],
|
|
||||||
dim=0,
|
|
||||||
),
|
|
||||||
False,
|
|
||||||
)
|
|
||||||
timestamps = meta_model.meta_timestamps[
|
|
||||||
rand_index : rand_index + xenv.seq_length
|
|
||||||
]
|
|
||||||
_, (seq_inputs, seq_targets) = xenv.seq_call(timestamps)
|
|
||||||
seq_inputs, seq_targets = seq_inputs.to(device), seq_targets.to(device)
|
|
||||||
for container, inputs, targets in zip(
|
|
||||||
seq_containers, seq_inputs, seq_targets
|
|
||||||
):
|
|
||||||
predictions = base_model.forward_with_container(inputs, container)
|
|
||||||
total_meta_v2_losses.append(criterion(predictions, targets))
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
meta_std = torch.stack(total_meta_v1_losses).std().item()
|
meta_std = torch.stack(total_meta_v1_losses).std().item()
|
||||||
meta_v1_loss = torch.stack(total_meta_v1_losses).mean()
|
meta_v1_loss = torch.stack(total_meta_v1_losses).mean()
|
||||||
@ -564,8 +546,9 @@ if __name__ == "__main__":
|
|||||||
if args.rand_seed is None or args.rand_seed < 0:
|
if args.rand_seed is None or args.rand_seed < 0:
|
||||||
args.rand_seed = random.randint(1, 100000)
|
args.rand_seed = random.randint(1, 100000)
|
||||||
assert args.save_dir is not None, "The save dir argument can not be None"
|
assert args.save_dir is not None, "The save dir argument can not be None"
|
||||||
args.save_dir = "{:}-d{:}_{:}_{:}-s{:}-lr{:}-wd{:}-e{:}-env{:}".format(
|
args.save_dir = "{:}-bs{:}-d{:}_{:}_{:}-s{:}-lr{:}-wd{:}-e{:}-env{:}".format(
|
||||||
args.save_dir,
|
args.save_dir,
|
||||||
|
args.meta_batch,
|
||||||
args.hidden_dim,
|
args.hidden_dim,
|
||||||
args.layer_dim,
|
args.layer_dim,
|
||||||
args.time_dim,
|
args.time_dim,
|
||||||
|
Loading…
Reference in New Issue
Block a user