Fix bugs
This commit is contained in:
		| @@ -76,7 +76,7 @@ def online_evaluate( | |||||||
|             future_loss = criterion(future_y_hat, future_y) |             future_loss = criterion(future_y_hat, future_y) | ||||||
|             loss_meter.update(future_loss.item()) |             loss_meter.update(future_loss.item()) | ||||||
|             # accumulate the metric scores |             # accumulate the metric scores | ||||||
|             metric(future_y_hat, future_y) |             score = metric(future_y_hat, future_y) | ||||||
|         if easy_adapt: |         if easy_adapt: | ||||||
|             meta_model.easy_adapt(future_time.item(), future_time_embed) |             meta_model.easy_adapt(future_time.item(), future_time_embed) | ||||||
|             refine, post_refine_loss = False, -1 |             refine, post_refine_loss = False, -1 | ||||||
| @@ -92,8 +92,8 @@ def online_evaluate( | |||||||
|                 {"param": future_time_embed, "loss": future_loss.item()}, |                 {"param": future_time_embed, "loss": future_loss.item()}, | ||||||
|             ) |             ) | ||||||
|         logger.log( |         logger.log( | ||||||
|             "[ONLINE] [{:03d}/{:03d}] loss={:.4f}".format( |             "[ONLINE] [{:03d}/{:03d}] loss={:.4f}, score={:.4f}".format( | ||||||
|                 idx, len(env), future_loss.item() |                 idx, len(env), future_loss.item(), score | ||||||
|             ) |             ) | ||||||
|             + ", post-loss={:.4f}".format(post_refine_loss if refine else -1) |             + ", post-loss={:.4f}".format(post_refine_loss if refine else -1) | ||||||
|         ) |         ) | ||||||
| @@ -406,7 +406,22 @@ if __name__ == "__main__": | |||||||
|     if args.rand_seed is None or args.rand_seed < 0: |     if args.rand_seed is None or args.rand_seed < 0: | ||||||
|         args.rand_seed = random.randint(1, 100000) |         args.rand_seed = random.randint(1, 100000) | ||||||
|     assert args.save_dir is not None, "The save dir argument can not be None" |     assert args.save_dir is not None, "The save dir argument can not be None" | ||||||
|     args.save_dir = "{:}-bs{:}-d{:}_{:}_{:}-s{:}-lr{:}-wd{:}-e{:}-ab{:}-env{:}".format( |     if args.ablation is None: | ||||||
|  |         args.save_dir = "{:}-bs{:}-d{:}_{:}_{:}-s{:}-lr{:}-wd{:}-e{:}-env{:}".format( | ||||||
|  |             args.save_dir, | ||||||
|  |             args.meta_batch, | ||||||
|  |             args.hidden_dim, | ||||||
|  |             args.layer_dim, | ||||||
|  |             args.time_dim, | ||||||
|  |             args.seq_length, | ||||||
|  |             args.lr, | ||||||
|  |             args.weight_decay, | ||||||
|  |             args.epochs, | ||||||
|  |             args.env_version, | ||||||
|  |         ) | ||||||
|  |     else: | ||||||
|  |         args.save_dir = ( | ||||||
|  |             "{:}-bs{:}-d{:}_{:}_{:}-s{:}-lr{:}-wd{:}-e{:}-ab{:}-env{:}".format( | ||||||
|                 args.save_dir, |                 args.save_dir, | ||||||
|                 args.meta_batch, |                 args.meta_batch, | ||||||
|                 args.hidden_dim, |                 args.hidden_dim, | ||||||
| @@ -419,4 +434,5 @@ if __name__ == "__main__": | |||||||
|                 args.ablation, |                 args.ablation, | ||||||
|                 args.env_version, |                 args.env_version, | ||||||
|             ) |             ) | ||||||
|  |         ) | ||||||
|     main(args) |     main(args) | ||||||
|   | |||||||
| @@ -234,7 +234,7 @@ class MetaModelV1(super_core.SuperModule): | |||||||
|             for iepoch in range(epochs): |             for iepoch in range(epochs): | ||||||
|                 optimizer.zero_grad() |                 optimizer.zero_grad() | ||||||
|                 time_embed = self.gen_time_embed(timestamp.view(1)) |                 time_embed = self.gen_time_embed(timestamp.view(1)) | ||||||
|                 match_loss = criterion(new_param, time_embed) |                 match_loss = F.l1_loss(new_param, time_embed) | ||||||
|  |  | ||||||
|                 [container] = self.gen_model(new_param.view(1, -1)) |                 [container] = self.gen_model(new_param.view(1, -1)) | ||||||
|                 y_hat = base_model.forward_with_container(x, container) |                 y_hat = base_model.forward_with_container(x, container) | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user