Update LFNA -- refine
This commit is contained in:
parent
b1064e5a60
commit
da4b61f3ab
@ -93,8 +93,10 @@ def epoch_evaluate(loader, meta_model, base_model, criterion, device, logger):
|
||||
return loss_meter
|
||||
|
||||
|
||||
def online_evaluate(env, meta_model, base_model, criterion, args, logger):
|
||||
def online_evaluate(env, meta_model, base_model, criterion, args, logger, save=False):
|
||||
logger.log("Online evaluate: {:}".format(env))
|
||||
loss_meter = AverageMeter()
|
||||
w_containers = dict()
|
||||
for idx, (future_time, (future_x, future_y)) in enumerate(env):
|
||||
with torch.no_grad():
|
||||
meta_model.eval()
|
||||
@ -102,9 +104,12 @@ def online_evaluate(env, meta_model, base_model, criterion, args, logger):
|
||||
_, [future_container], time_embeds = meta_model(
|
||||
future_time.to(args.device).view(1, 1), None, True
|
||||
)
|
||||
if save:
|
||||
w_containers[idx] = future_container.no_grad_clone()
|
||||
future_x, future_y = future_x.to(args.device), future_y.to(args.device)
|
||||
future_y_hat = base_model.forward_with_container(future_x, future_container)
|
||||
future_loss = criterion(future_y_hat, future_y)
|
||||
loss_meter.update(future_loss.item())
|
||||
refine, post_refine_loss = meta_model.adapt(
|
||||
base_model,
|
||||
criterion,
|
||||
@ -123,6 +128,7 @@ def online_evaluate(env, meta_model, base_model, criterion, args, logger):
|
||||
)
|
||||
meta_model.clear_fixed()
|
||||
meta_model.clear_learnt()
|
||||
return w_containers, loss_meter
|
||||
|
||||
|
||||
def pretrain_v2(base_model, meta_model, criterion, xenv, args, logger):
|
||||
@ -219,8 +225,10 @@ def main(args):
|
||||
logger, env_info, model_kwargs = lfna_setup(args)
|
||||
train_env = get_synthetic_env(mode="train", version=args.env_version)
|
||||
valid_env = get_synthetic_env(mode="valid", version=args.env_version)
|
||||
logger.log("training enviornment: {:}".format(train_env))
|
||||
logger.log("validation enviornment: {:}".format(valid_env))
|
||||
all_env = get_synthetic_env(mode=None, version=args.env_version)
|
||||
logger.log("The training enviornment: {:}".format(train_env))
|
||||
logger.log("The validation enviornment: {:}".format(valid_env))
|
||||
logger.log("The total enviornment: {:}".format(all_env))
|
||||
|
||||
base_model = get_model(**model_kwargs)
|
||||
base_model = base_model.to(args.device)
|
||||
@ -249,10 +257,20 @@ def main(args):
|
||||
pretrain_v2(base_model, meta_model, criterion, train_env, args, logger)
|
||||
|
||||
# try to evaluate once
|
||||
online_evaluate(train_env, meta_model, base_model, criterion, args, logger)
|
||||
online_evaluate(valid_env, meta_model, base_model, criterion, args, logger)
|
||||
# online_evaluate(train_env, meta_model, base_model, criterion, args, logger)
|
||||
# online_evaluate(valid_env, meta_model, base_model, criterion, args, logger)
|
||||
w_containers, loss_meter = online_evaluate(
|
||||
all_env, meta_model, base_model, criterion, args, logger, True
|
||||
)
|
||||
logger.log("In this enviornment, the loss-meter is {:}".format(loss_meter))
|
||||
|
||||
pdb.set_trace()
|
||||
save_checkpoint(
|
||||
{"w_containers": w_containers},
|
||||
logger.path(None) / "final-ckp.pth",
|
||||
logger,
|
||||
)
|
||||
return
|
||||
"""
|
||||
optimizer = torch.optim.Adam(
|
||||
meta_model.get_parameters(True, True, False), # fix hypernet
|
||||
lr=args.lr,
|
||||
@ -364,7 +382,6 @@ def main(args):
|
||||
# meta-test
|
||||
meta_model.load_best()
|
||||
eval_env = env_info["dynamic_env"]
|
||||
w_container_per_epoch = dict()
|
||||
for idx in range(args.seq_length, len(eval_env)):
|
||||
# build-timestamp
|
||||
future_time = env_info["{:}-timestamp".format(idx)].item()
|
||||
@ -424,6 +441,7 @@ def main(args):
|
||||
logger.path(None) / "final-ckp.pth",
|
||||
logger,
|
||||
)
|
||||
"""
|
||||
|
||||
logger.log("-" * 200 + "\n")
|
||||
logger.close()
|
||||
@ -494,7 +512,7 @@ if __name__ == "__main__":
|
||||
help="The learning rate for the optimizer, during refine",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--refine_epochs", type=int, default=40, help="The final refine #epochs."
|
||||
"--refine_epochs", type=int, default=50, help="The final refine #epochs."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--early_stop_thresh",
|
||||
|
Loading…
Reference in New Issue
Block a user