Update LFNA -- refine

This commit is contained in:
D-X-Y 2021-05-23 19:26:09 +00:00
parent b1064e5a60
commit da4b61f3ab

View File

@ -93,8 +93,10 @@ def epoch_evaluate(loader, meta_model, base_model, criterion, device, logger):
return loss_meter return loss_meter
def online_evaluate(env, meta_model, base_model, criterion, args, logger): def online_evaluate(env, meta_model, base_model, criterion, args, logger, save=False):
logger.log("Online evaluate: {:}".format(env)) logger.log("Online evaluate: {:}".format(env))
loss_meter = AverageMeter()
w_containers = dict()
for idx, (future_time, (future_x, future_y)) in enumerate(env): for idx, (future_time, (future_x, future_y)) in enumerate(env):
with torch.no_grad(): with torch.no_grad():
meta_model.eval() meta_model.eval()
@ -102,9 +104,12 @@ def online_evaluate(env, meta_model, base_model, criterion, args, logger):
_, [future_container], time_embeds = meta_model( _, [future_container], time_embeds = meta_model(
future_time.to(args.device).view(1, 1), None, True future_time.to(args.device).view(1, 1), None, True
) )
if save:
w_containers[idx] = future_container.no_grad_clone()
future_x, future_y = future_x.to(args.device), future_y.to(args.device) future_x, future_y = future_x.to(args.device), future_y.to(args.device)
future_y_hat = base_model.forward_with_container(future_x, future_container) future_y_hat = base_model.forward_with_container(future_x, future_container)
future_loss = criterion(future_y_hat, future_y) future_loss = criterion(future_y_hat, future_y)
loss_meter.update(future_loss.item())
refine, post_refine_loss = meta_model.adapt( refine, post_refine_loss = meta_model.adapt(
base_model, base_model,
criterion, criterion,
@ -123,6 +128,7 @@ def online_evaluate(env, meta_model, base_model, criterion, args, logger):
) )
meta_model.clear_fixed() meta_model.clear_fixed()
meta_model.clear_learnt() meta_model.clear_learnt()
return w_containers, loss_meter
def pretrain_v2(base_model, meta_model, criterion, xenv, args, logger): def pretrain_v2(base_model, meta_model, criterion, xenv, args, logger):
@ -219,8 +225,10 @@ def main(args):
logger, env_info, model_kwargs = lfna_setup(args) logger, env_info, model_kwargs = lfna_setup(args)
train_env = get_synthetic_env(mode="train", version=args.env_version) train_env = get_synthetic_env(mode="train", version=args.env_version)
valid_env = get_synthetic_env(mode="valid", version=args.env_version) valid_env = get_synthetic_env(mode="valid", version=args.env_version)
logger.log("training environment: {:}".format(train_env)) all_env = get_synthetic_env(mode=None, version=args.env_version)
logger.log("validation environment: {:}".format(valid_env)) logger.log("The training environment: {:}".format(train_env))
logger.log("The validation environment: {:}".format(valid_env))
logger.log("The total environment: {:}".format(all_env))
base_model = get_model(**model_kwargs) base_model = get_model(**model_kwargs)
base_model = base_model.to(args.device) base_model = base_model.to(args.device)
@ -249,10 +257,20 @@ def main(args):
pretrain_v2(base_model, meta_model, criterion, train_env, args, logger) pretrain_v2(base_model, meta_model, criterion, train_env, args, logger)
# try to evaluate once # try to evaluate once
online_evaluate(train_env, meta_model, base_model, criterion, args, logger) # online_evaluate(train_env, meta_model, base_model, criterion, args, logger)
online_evaluate(valid_env, meta_model, base_model, criterion, args, logger) # online_evaluate(valid_env, meta_model, base_model, criterion, args, logger)
w_containers, loss_meter = online_evaluate(
all_env, meta_model, base_model, criterion, args, logger, True
)
logger.log("In this environment, the loss-meter is {:}".format(loss_meter))
pdb.set_trace() save_checkpoint(
{"w_containers": w_containers},
logger.path(None) / "final-ckp.pth",
logger,
)
return
"""
optimizer = torch.optim.Adam( optimizer = torch.optim.Adam(
meta_model.get_parameters(True, True, False), # fix hypernet meta_model.get_parameters(True, True, False), # fix hypernet
lr=args.lr, lr=args.lr,
@ -364,7 +382,6 @@ def main(args):
# meta-test # meta-test
meta_model.load_best() meta_model.load_best()
eval_env = env_info["dynamic_env"] eval_env = env_info["dynamic_env"]
w_container_per_epoch = dict()
for idx in range(args.seq_length, len(eval_env)): for idx in range(args.seq_length, len(eval_env)):
# build-timestamp # build-timestamp
future_time = env_info["{:}-timestamp".format(idx)].item() future_time = env_info["{:}-timestamp".format(idx)].item()
@ -424,6 +441,7 @@ def main(args):
logger.path(None) / "final-ckp.pth", logger.path(None) / "final-ckp.pth",
logger, logger,
) )
"""
logger.log("-" * 200 + "\n") logger.log("-" * 200 + "\n")
logger.close() logger.close()
@ -494,7 +512,7 @@ if __name__ == "__main__":
help="The learning rate for the optimizer, during refine", help="The learning rate for the optimizer, during refine",
) )
parser.add_argument( parser.add_argument(
"--refine_epochs", type=int, default=40, help="The final refine #epochs." "--refine_epochs", type=int, default=50, help="The final refine #epochs."
) )
parser.add_argument( parser.add_argument(
"--early_stop_thresh", "--early_stop_thresh",