Update LFNA -- refine

This commit is contained in:
D-X-Y 2021-05-23 19:26:09 +00:00
parent b1064e5a60
commit da4b61f3ab

View File

@ -93,8 +93,10 @@ def epoch_evaluate(loader, meta_model, base_model, criterion, device, logger):
return loss_meter
def online_evaluate(env, meta_model, base_model, criterion, args, logger):
def online_evaluate(env, meta_model, base_model, criterion, args, logger, save=False):
logger.log("Online evaluate: {:}".format(env))
loss_meter = AverageMeter()
w_containers = dict()
for idx, (future_time, (future_x, future_y)) in enumerate(env):
with torch.no_grad():
meta_model.eval()
@ -102,9 +104,12 @@ def online_evaluate(env, meta_model, base_model, criterion, args, logger):
_, [future_container], time_embeds = meta_model(
future_time.to(args.device).view(1, 1), None, True
)
if save:
w_containers[idx] = future_container.no_grad_clone()
future_x, future_y = future_x.to(args.device), future_y.to(args.device)
future_y_hat = base_model.forward_with_container(future_x, future_container)
future_loss = criterion(future_y_hat, future_y)
loss_meter.update(future_loss.item())
refine, post_refine_loss = meta_model.adapt(
base_model,
criterion,
@ -123,6 +128,7 @@ def online_evaluate(env, meta_model, base_model, criterion, args, logger):
)
meta_model.clear_fixed()
meta_model.clear_learnt()
return w_containers, loss_meter
def pretrain_v2(base_model, meta_model, criterion, xenv, args, logger):
@ -219,8 +225,10 @@ def main(args):
logger, env_info, model_kwargs = lfna_setup(args)
train_env = get_synthetic_env(mode="train", version=args.env_version)
valid_env = get_synthetic_env(mode="valid", version=args.env_version)
logger.log("training enviornment: {:}".format(train_env))
logger.log("validation enviornment: {:}".format(valid_env))
all_env = get_synthetic_env(mode=None, version=args.env_version)
logger.log("The training enviornment: {:}".format(train_env))
logger.log("The validation enviornment: {:}".format(valid_env))
logger.log("The total enviornment: {:}".format(all_env))
base_model = get_model(**model_kwargs)
base_model = base_model.to(args.device)
@ -249,10 +257,20 @@ def main(args):
pretrain_v2(base_model, meta_model, criterion, train_env, args, logger)
# try to evaluate once
online_evaluate(train_env, meta_model, base_model, criterion, args, logger)
online_evaluate(valid_env, meta_model, base_model, criterion, args, logger)
# online_evaluate(train_env, meta_model, base_model, criterion, args, logger)
# online_evaluate(valid_env, meta_model, base_model, criterion, args, logger)
w_containers, loss_meter = online_evaluate(
all_env, meta_model, base_model, criterion, args, logger, True
)
logger.log("In this enviornment, the loss-meter is {:}".format(loss_meter))
pdb.set_trace()
save_checkpoint(
{"w_containers": w_containers},
logger.path(None) / "final-ckp.pth",
logger,
)
return
"""
optimizer = torch.optim.Adam(
meta_model.get_parameters(True, True, False), # fix hypernet
lr=args.lr,
@ -364,7 +382,6 @@ def main(args):
# meta-test
meta_model.load_best()
eval_env = env_info["dynamic_env"]
w_container_per_epoch = dict()
for idx in range(args.seq_length, len(eval_env)):
# build-timestamp
future_time = env_info["{:}-timestamp".format(idx)].item()
@ -424,6 +441,7 @@ def main(args):
logger.path(None) / "final-ckp.pth",
logger,
)
"""
logger.log("-" * 200 + "\n")
logger.close()
@ -494,7 +512,7 @@ if __name__ == "__main__":
help="The learning rate for the optimizer, during refine",
)
parser.add_argument(
"--refine_epochs", type=int, default=40, help="The final refine #epochs."
"--refine_epochs", type=int, default=50, help="The final refine #epochs."
)
parser.add_argument(
"--early_stop_thresh",