Fix bugs
This commit is contained in:
		| @@ -76,7 +76,7 @@ def online_evaluate( | ||||
|             future_loss = criterion(future_y_hat, future_y) | ||||
|             loss_meter.update(future_loss.item()) | ||||
|             # accumulate the metric scores | ||||
|             metric(future_y_hat, future_y) | ||||
|             score = metric(future_y_hat, future_y) | ||||
|         if easy_adapt: | ||||
|             meta_model.easy_adapt(future_time.item(), future_time_embed) | ||||
|             refine, post_refine_loss = False, -1 | ||||
| @@ -92,8 +92,8 @@ def online_evaluate( | ||||
|                 {"param": future_time_embed, "loss": future_loss.item()}, | ||||
|             ) | ||||
|         logger.log( | ||||
|             "[ONLINE] [{:03d}/{:03d}] loss={:.4f}".format( | ||||
|                 idx, len(env), future_loss.item() | ||||
|             "[ONLINE] [{:03d}/{:03d}] loss={:.4f}, score={:.4f}".format( | ||||
|                 idx, len(env), future_loss.item(), score | ||||
|             ) | ||||
|             + ", post-loss={:.4f}".format(post_refine_loss if refine else -1) | ||||
|         ) | ||||
| @@ -406,7 +406,22 @@ if __name__ == "__main__": | ||||
|     if args.rand_seed is None or args.rand_seed < 0: | ||||
|         args.rand_seed = random.randint(1, 100000) | ||||
|     assert args.save_dir is not None, "The save dir argument can not be None" | ||||
|     args.save_dir = "{:}-bs{:}-d{:}_{:}_{:}-s{:}-lr{:}-wd{:}-e{:}-ab{:}-env{:}".format( | ||||
|     if args.ablation is None: | ||||
|         args.save_dir = "{:}-bs{:}-d{:}_{:}_{:}-s{:}-lr{:}-wd{:}-e{:}-env{:}".format( | ||||
|             args.save_dir, | ||||
|             args.meta_batch, | ||||
|             args.hidden_dim, | ||||
|             args.layer_dim, | ||||
|             args.time_dim, | ||||
|             args.seq_length, | ||||
|             args.lr, | ||||
|             args.weight_decay, | ||||
|             args.epochs, | ||||
|             args.env_version, | ||||
|         ) | ||||
|     else: | ||||
|         args.save_dir = ( | ||||
|             "{:}-bs{:}-d{:}_{:}_{:}-s{:}-lr{:}-wd{:}-e{:}-ab{:}-env{:}".format( | ||||
|                 args.save_dir, | ||||
|                 args.meta_batch, | ||||
|                 args.hidden_dim, | ||||
| @@ -419,4 +434,5 @@ if __name__ == "__main__": | ||||
|                 args.ablation, | ||||
|                 args.env_version, | ||||
|             ) | ||||
|         ) | ||||
|     main(args) | ||||
|   | ||||
| @@ -234,7 +234,7 @@ class MetaModelV1(super_core.SuperModule): | ||||
|             for iepoch in range(epochs): | ||||
|                 optimizer.zero_grad() | ||||
|                 time_embed = self.gen_time_embed(timestamp.view(1)) | ||||
|                 match_loss = criterion(new_param, time_embed) | ||||
|                 match_loss = F.l1_loss(new_param, time_embed) | ||||
|  | ||||
|                 [container] = self.gen_model(new_param.view(1, -1)) | ||||
|                 y_hat = base_model.forward_with_container(x, container) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user