Update LFNA -- refine
This commit is contained in:
parent
b1064e5a60
commit
da4b61f3ab
@ -93,8 +93,10 @@ def epoch_evaluate(loader, meta_model, base_model, criterion, device, logger):
|
|||||||
return loss_meter
|
return loss_meter
|
||||||
|
|
||||||
|
|
||||||
def online_evaluate(env, meta_model, base_model, criterion, args, logger):
|
def online_evaluate(env, meta_model, base_model, criterion, args, logger, save=False):
|
||||||
logger.log("Online evaluate: {:}".format(env))
|
logger.log("Online evaluate: {:}".format(env))
|
||||||
|
loss_meter = AverageMeter()
|
||||||
|
w_containers = dict()
|
||||||
for idx, (future_time, (future_x, future_y)) in enumerate(env):
|
for idx, (future_time, (future_x, future_y)) in enumerate(env):
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
meta_model.eval()
|
meta_model.eval()
|
||||||
@ -102,9 +104,12 @@ def online_evaluate(env, meta_model, base_model, criterion, args, logger):
|
|||||||
_, [future_container], time_embeds = meta_model(
|
_, [future_container], time_embeds = meta_model(
|
||||||
future_time.to(args.device).view(1, 1), None, True
|
future_time.to(args.device).view(1, 1), None, True
|
||||||
)
|
)
|
||||||
|
if save:
|
||||||
|
w_containers[idx] = future_container.no_grad_clone()
|
||||||
future_x, future_y = future_x.to(args.device), future_y.to(args.device)
|
future_x, future_y = future_x.to(args.device), future_y.to(args.device)
|
||||||
future_y_hat = base_model.forward_with_container(future_x, future_container)
|
future_y_hat = base_model.forward_with_container(future_x, future_container)
|
||||||
future_loss = criterion(future_y_hat, future_y)
|
future_loss = criterion(future_y_hat, future_y)
|
||||||
|
loss_meter.update(future_loss.item())
|
||||||
refine, post_refine_loss = meta_model.adapt(
|
refine, post_refine_loss = meta_model.adapt(
|
||||||
base_model,
|
base_model,
|
||||||
criterion,
|
criterion,
|
||||||
@ -123,6 +128,7 @@ def online_evaluate(env, meta_model, base_model, criterion, args, logger):
|
|||||||
)
|
)
|
||||||
meta_model.clear_fixed()
|
meta_model.clear_fixed()
|
||||||
meta_model.clear_learnt()
|
meta_model.clear_learnt()
|
||||||
|
return w_containers, loss_meter
|
||||||
|
|
||||||
|
|
||||||
def pretrain_v2(base_model, meta_model, criterion, xenv, args, logger):
|
def pretrain_v2(base_model, meta_model, criterion, xenv, args, logger):
|
||||||
@ -219,8 +225,10 @@ def main(args):
|
|||||||
logger, env_info, model_kwargs = lfna_setup(args)
|
logger, env_info, model_kwargs = lfna_setup(args)
|
||||||
train_env = get_synthetic_env(mode="train", version=args.env_version)
|
train_env = get_synthetic_env(mode="train", version=args.env_version)
|
||||||
valid_env = get_synthetic_env(mode="valid", version=args.env_version)
|
valid_env = get_synthetic_env(mode="valid", version=args.env_version)
|
||||||
logger.log("training enviornment: {:}".format(train_env))
|
all_env = get_synthetic_env(mode=None, version=args.env_version)
|
||||||
logger.log("validation enviornment: {:}".format(valid_env))
|
logger.log("The training enviornment: {:}".format(train_env))
|
||||||
|
logger.log("The validation enviornment: {:}".format(valid_env))
|
||||||
|
logger.log("The total enviornment: {:}".format(all_env))
|
||||||
|
|
||||||
base_model = get_model(**model_kwargs)
|
base_model = get_model(**model_kwargs)
|
||||||
base_model = base_model.to(args.device)
|
base_model = base_model.to(args.device)
|
||||||
@ -249,10 +257,20 @@ def main(args):
|
|||||||
pretrain_v2(base_model, meta_model, criterion, train_env, args, logger)
|
pretrain_v2(base_model, meta_model, criterion, train_env, args, logger)
|
||||||
|
|
||||||
# try to evaluate once
|
# try to evaluate once
|
||||||
online_evaluate(train_env, meta_model, base_model, criterion, args, logger)
|
# online_evaluate(train_env, meta_model, base_model, criterion, args, logger)
|
||||||
online_evaluate(valid_env, meta_model, base_model, criterion, args, logger)
|
# online_evaluate(valid_env, meta_model, base_model, criterion, args, logger)
|
||||||
|
w_containers, loss_meter = online_evaluate(
|
||||||
|
all_env, meta_model, base_model, criterion, args, logger, True
|
||||||
|
)
|
||||||
|
logger.log("In this enviornment, the loss-meter is {:}".format(loss_meter))
|
||||||
|
|
||||||
pdb.set_trace()
|
save_checkpoint(
|
||||||
|
{"w_containers": w_containers},
|
||||||
|
logger.path(None) / "final-ckp.pth",
|
||||||
|
logger,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
"""
|
||||||
optimizer = torch.optim.Adam(
|
optimizer = torch.optim.Adam(
|
||||||
meta_model.get_parameters(True, True, False), # fix hypernet
|
meta_model.get_parameters(True, True, False), # fix hypernet
|
||||||
lr=args.lr,
|
lr=args.lr,
|
||||||
@ -364,7 +382,6 @@ def main(args):
|
|||||||
# meta-test
|
# meta-test
|
||||||
meta_model.load_best()
|
meta_model.load_best()
|
||||||
eval_env = env_info["dynamic_env"]
|
eval_env = env_info["dynamic_env"]
|
||||||
w_container_per_epoch = dict()
|
|
||||||
for idx in range(args.seq_length, len(eval_env)):
|
for idx in range(args.seq_length, len(eval_env)):
|
||||||
# build-timestamp
|
# build-timestamp
|
||||||
future_time = env_info["{:}-timestamp".format(idx)].item()
|
future_time = env_info["{:}-timestamp".format(idx)].item()
|
||||||
@ -424,6 +441,7 @@ def main(args):
|
|||||||
logger.path(None) / "final-ckp.pth",
|
logger.path(None) / "final-ckp.pth",
|
||||||
logger,
|
logger,
|
||||||
)
|
)
|
||||||
|
"""
|
||||||
|
|
||||||
logger.log("-" * 200 + "\n")
|
logger.log("-" * 200 + "\n")
|
||||||
logger.close()
|
logger.close()
|
||||||
@ -494,7 +512,7 @@ if __name__ == "__main__":
|
|||||||
help="The learning rate for the optimizer, during refine",
|
help="The learning rate for the optimizer, during refine",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--refine_epochs", type=int, default=40, help="The final refine #epochs."
|
"--refine_epochs", type=int, default=50, help="The final refine #epochs."
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--early_stop_thresh",
|
"--early_stop_thresh",
|
||||||
|
Loading…
Reference in New Issue
Block a user