Fix bugs
This commit is contained in:
parent
e3bd75f378
commit
ffc0d16d6c
@ -76,7 +76,7 @@ def online_evaluate(
|
|||||||
future_loss = criterion(future_y_hat, future_y)
|
future_loss = criterion(future_y_hat, future_y)
|
||||||
loss_meter.update(future_loss.item())
|
loss_meter.update(future_loss.item())
|
||||||
# accumulate the metric scores
|
# accumulate the metric scores
|
||||||
metric(future_y_hat, future_y)
|
score = metric(future_y_hat, future_y)
|
||||||
if easy_adapt:
|
if easy_adapt:
|
||||||
meta_model.easy_adapt(future_time.item(), future_time_embed)
|
meta_model.easy_adapt(future_time.item(), future_time_embed)
|
||||||
refine, post_refine_loss = False, -1
|
refine, post_refine_loss = False, -1
|
||||||
@ -92,8 +92,8 @@ def online_evaluate(
|
|||||||
{"param": future_time_embed, "loss": future_loss.item()},
|
{"param": future_time_embed, "loss": future_loss.item()},
|
||||||
)
|
)
|
||||||
logger.log(
|
logger.log(
|
||||||
"[ONLINE] [{:03d}/{:03d}] loss={:.4f}".format(
|
"[ONLINE] [{:03d}/{:03d}] loss={:.4f}, score={:.4f}".format(
|
||||||
idx, len(env), future_loss.item()
|
idx, len(env), future_loss.item(), score
|
||||||
)
|
)
|
||||||
+ ", post-loss={:.4f}".format(post_refine_loss if refine else -1)
|
+ ", post-loss={:.4f}".format(post_refine_loss if refine else -1)
|
||||||
)
|
)
|
||||||
@ -406,17 +406,33 @@ if __name__ == "__main__":
|
|||||||
if args.rand_seed is None or args.rand_seed < 0:
|
if args.rand_seed is None or args.rand_seed < 0:
|
||||||
args.rand_seed = random.randint(1, 100000)
|
args.rand_seed = random.randint(1, 100000)
|
||||||
assert args.save_dir is not None, "The save dir argument can not be None"
|
assert args.save_dir is not None, "The save dir argument can not be None"
|
||||||
args.save_dir = "{:}-bs{:}-d{:}_{:}_{:}-s{:}-lr{:}-wd{:}-e{:}-ab{:}-env{:}".format(
|
if args.ablation is None:
|
||||||
args.save_dir,
|
args.save_dir = "{:}-bs{:}-d{:}_{:}_{:}-s{:}-lr{:}-wd{:}-e{:}-env{:}".format(
|
||||||
args.meta_batch,
|
args.save_dir,
|
||||||
args.hidden_dim,
|
args.meta_batch,
|
||||||
args.layer_dim,
|
args.hidden_dim,
|
||||||
args.time_dim,
|
args.layer_dim,
|
||||||
args.seq_length,
|
args.time_dim,
|
||||||
args.lr,
|
args.seq_length,
|
||||||
args.weight_decay,
|
args.lr,
|
||||||
args.epochs,
|
args.weight_decay,
|
||||||
args.ablation,
|
args.epochs,
|
||||||
args.env_version,
|
args.env_version,
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
args.save_dir = (
|
||||||
|
"{:}-bs{:}-d{:}_{:}_{:}-s{:}-lr{:}-wd{:}-e{:}-ab{:}-env{:}".format(
|
||||||
|
args.save_dir,
|
||||||
|
args.meta_batch,
|
||||||
|
args.hidden_dim,
|
||||||
|
args.layer_dim,
|
||||||
|
args.time_dim,
|
||||||
|
args.seq_length,
|
||||||
|
args.lr,
|
||||||
|
args.weight_decay,
|
||||||
|
args.epochs,
|
||||||
|
args.ablation,
|
||||||
|
args.env_version,
|
||||||
|
)
|
||||||
|
)
|
||||||
main(args)
|
main(args)
|
||||||
|
@ -234,7 +234,7 @@ class MetaModelV1(super_core.SuperModule):
|
|||||||
for iepoch in range(epochs):
|
for iepoch in range(epochs):
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
time_embed = self.gen_time_embed(timestamp.view(1))
|
time_embed = self.gen_time_embed(timestamp.view(1))
|
||||||
match_loss = criterion(new_param, time_embed)
|
match_loss = F.l1_loss(new_param, time_embed)
|
||||||
|
|
||||||
[container] = self.gen_model(new_param.view(1, -1))
|
[container] = self.gen_model(new_param.view(1, -1))
|
||||||
y_hat = base_model.forward_with_container(x, container)
|
y_hat = base_model.forward_with_container(x, container)
|
||||||
|
Loading…
Reference in New Issue
Block a user