This commit is contained in:
D-X-Y 2021-05-27 06:18:28 -07:00
parent e3bd75f378
commit ffc0d16d6c
2 changed files with 33 additions and 17 deletions

View File

@ -76,7 +76,7 @@ def online_evaluate(
future_loss = criterion(future_y_hat, future_y)
loss_meter.update(future_loss.item())
# accumulate the metric scores
metric(future_y_hat, future_y)
score = metric(future_y_hat, future_y)
if easy_adapt:
meta_model.easy_adapt(future_time.item(), future_time_embed)
refine, post_refine_loss = False, -1
@ -92,8 +92,8 @@ def online_evaluate(
{"param": future_time_embed, "loss": future_loss.item()},
)
logger.log(
"[ONLINE] [{:03d}/{:03d}] loss={:.4f}".format(
idx, len(env), future_loss.item()
"[ONLINE] [{:03d}/{:03d}] loss={:.4f}, score={:.4f}".format(
idx, len(env), future_loss.item(), score
)
+ ", post-loss={:.4f}".format(post_refine_loss if refine else -1)
)
@ -406,17 +406,33 @@ if __name__ == "__main__":
if args.rand_seed is None or args.rand_seed < 0:
args.rand_seed = random.randint(1, 100000)
assert args.save_dir is not None, "The save dir argument can not be None"
args.save_dir = "{:}-bs{:}-d{:}_{:}_{:}-s{:}-lr{:}-wd{:}-e{:}-ab{:}-env{:}".format(
args.save_dir,
args.meta_batch,
args.hidden_dim,
args.layer_dim,
args.time_dim,
args.seq_length,
args.lr,
args.weight_decay,
args.epochs,
args.ablation,
args.env_version,
)
if args.ablation is None:
args.save_dir = "{:}-bs{:}-d{:}_{:}_{:}-s{:}-lr{:}-wd{:}-e{:}-env{:}".format(
args.save_dir,
args.meta_batch,
args.hidden_dim,
args.layer_dim,
args.time_dim,
args.seq_length,
args.lr,
args.weight_decay,
args.epochs,
args.env_version,
)
else:
args.save_dir = (
"{:}-bs{:}-d{:}_{:}_{:}-s{:}-lr{:}-wd{:}-e{:}-ab{:}-env{:}".format(
args.save_dir,
args.meta_batch,
args.hidden_dim,
args.layer_dim,
args.time_dim,
args.seq_length,
args.lr,
args.weight_decay,
args.epochs,
args.ablation,
args.env_version,
)
)
main(args)

View File

@ -234,7 +234,7 @@ class MetaModelV1(super_core.SuperModule):
for iepoch in range(epochs):
optimizer.zero_grad()
time_embed = self.gen_time_embed(timestamp.view(1))
match_loss = criterion(new_param, time_embed)
match_loss = F.l1_loss(new_param, time_embed)
[container] = self.gen_model(new_param.view(1, -1))
y_hat = base_model.forward_with_container(x, container)