Update MLAML
This commit is contained in:
		| @@ -1,10 +1,10 @@ | ||||
| ##################################################### | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 # | ||||
| ##################################################### | ||||
| # python exps/GeMOSA/baselines/maml-nof.py --env_version v1 --hidden_dim 16 --inner_step 5 | ||||
| # python exps/GeMOSA/baselines/maml-nof.py --env_version v2 --hidden_dim 16 | ||||
| # python exps/GeMOSA/baselines/maml-nof.py --env_version v3 --hidden_dim 32 | ||||
| # python exps/GeMOSA/baselines/maml-nof.py --env_version v4 --hidden_dim 32 | ||||
| # python exps/GeMOSA/baselines/maml-ft.py --env_version v1 --hidden_dim 16 --inner_step 5 | ||||
| # python exps/GeMOSA/baselines/maml-ft.py --env_version v2 --hidden_dim 16 --inner_step 5 | ||||
| # python exps/GeMOSA/baselines/maml-ft.py --env_version v3 --hidden_dim 32 --inner_step 5 | ||||
| # python exps/GeMOSA/baselines/maml-ft.py --env_version v4 --hidden_dim 32 --inner_step 5 --device cuda | ||||
| ##################################################### | ||||
| import sys, time, copy, torch, random, argparse | ||||
| from tqdm import tqdm | ||||
| @@ -155,6 +155,8 @@ def main(args): | ||||
|                 allys = allys.view(-1) | ||||
|             historical_x, historical_y = allxs.to(args.device), allys.to(args.device) | ||||
|             future_container = maml.adapt(historical_x, historical_y) | ||||
|  | ||||
|             future_x, future_y = future_x.to(args.device), future_y.to(args.device) | ||||
|             future_y_hat = maml.predict(future_x, future_container) | ||||
|             future_loss = maml.criterion(future_y_hat, future_y) | ||||
|             meta_losses.append(future_loss) | ||||
| @@ -195,8 +197,6 @@ def main(args): | ||||
|         train_results = train_metric.get_info() | ||||
|         return train_results, future_container | ||||
|  | ||||
|     train_results, future_container = finetune(0) | ||||
|  | ||||
|     metric = metric_cls(True) | ||||
|     per_timestamp_time, start_time = AverageMeter(), time.time() | ||||
|     for idx, (future_time, (future_x, future_y)) in enumerate(test_env): | ||||
| @@ -212,7 +212,9 @@ def main(args): | ||||
|         ) | ||||
|  | ||||
|         # build optimizer | ||||
|         future_x.to(args.device), future_y.to(args.device) | ||||
|         train_results, future_container = finetune(idx) | ||||
|  | ||||
|         future_x, future_y = future_x.to(args.device), future_y.to(args.device) | ||||
|         future_y_hat = maml.predict(future_x, future_container) | ||||
|         future_loss = criterion(future_y_hat, future_y) | ||||
|         metric(future_y_hat, future_y) | ||||
| @@ -237,7 +239,7 @@ if __name__ == "__main__": | ||||
|     parser.add_argument( | ||||
|         "--save_dir", | ||||
|         type=str, | ||||
|         default="./outputs/lfna-synthetic/use-maml-nft", | ||||
|         default="./outputs/GeMOSA-synthetic/use-maml-ft", | ||||
|         help="The checkpoint directory.", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|   | ||||
| @@ -155,6 +155,8 @@ def main(args): | ||||
|                 allys = allys.view(-1) | ||||
|             historical_x, historical_y = allxs.to(args.device), allys.to(args.device) | ||||
|             future_container = maml.adapt(historical_x, historical_y) | ||||
|  | ||||
|             future_x, future_y = future_x.to(args.device), future_y.to(args.device) | ||||
|             future_y_hat = maml.predict(future_x, future_container) | ||||
|             future_loss = maml.criterion(future_y_hat, future_y) | ||||
|             meta_losses.append(future_loss) | ||||
| @@ -212,7 +214,7 @@ def main(args): | ||||
|         ) | ||||
|  | ||||
|         # build optimizer | ||||
|         future_x.to(args.device), future_y.to(args.device) | ||||
|         future_x, future_y = future_x.to(args.device), future_y.to(args.device) | ||||
|         future_y_hat = maml.predict(future_x, future_container) | ||||
|         future_loss = criterion(future_y_hat, future_y) | ||||
|         metric(future_y_hat, future_y) | ||||
| @@ -237,7 +239,7 @@ if __name__ == "__main__": | ||||
|     parser.add_argument( | ||||
|         "--save_dir", | ||||
|         type=str, | ||||
|         default="./outputs/lfna-synthetic/use-maml-nft", | ||||
|         default="./outputs/GeMOSA-synthetic/use-maml-nft", | ||||
|         help="The checkpoint directory.", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|   | ||||
		Reference in New Issue
	
	Block a user