Update MLAML
This commit is contained in:
		| @@ -1,10 +1,10 @@ | |||||||
| ##################################################### | ##################################################### | ||||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 # | # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 # | ||||||
| ##################################################### | ##################################################### | ||||||
| # python exps/GeMOSA/baselines/maml-nof.py --env_version v1 --hidden_dim 16 --inner_step 5 | # python exps/GeMOSA/baselines/maml-ft.py --env_version v1 --hidden_dim 16 --inner_step 5 | ||||||
| # python exps/GeMOSA/baselines/maml-nof.py --env_version v2 --hidden_dim 16 | # python exps/GeMOSA/baselines/maml-ft.py --env_version v2 --hidden_dim 16 --inner_step 5 | ||||||
| # python exps/GeMOSA/baselines/maml-nof.py --env_version v3 --hidden_dim 32 | # python exps/GeMOSA/baselines/maml-ft.py --env_version v3 --hidden_dim 32 --inner_step 5 | ||||||
| # python exps/GeMOSA/baselines/maml-nof.py --env_version v4 --hidden_dim 32 | # python exps/GeMOSA/baselines/maml-ft.py --env_version v4 --hidden_dim 32 --inner_step 5 --device cuda | ||||||
| ##################################################### | ##################################################### | ||||||
| import sys, time, copy, torch, random, argparse | import sys, time, copy, torch, random, argparse | ||||||
| from tqdm import tqdm | from tqdm import tqdm | ||||||
| @@ -155,6 +155,8 @@ def main(args): | |||||||
|                 allys = allys.view(-1) |                 allys = allys.view(-1) | ||||||
|             historical_x, historical_y = allxs.to(args.device), allys.to(args.device) |             historical_x, historical_y = allxs.to(args.device), allys.to(args.device) | ||||||
|             future_container = maml.adapt(historical_x, historical_y) |             future_container = maml.adapt(historical_x, historical_y) | ||||||
|  |  | ||||||
|  |             future_x, future_y = future_x.to(args.device), future_y.to(args.device) | ||||||
|             future_y_hat = maml.predict(future_x, future_container) |             future_y_hat = maml.predict(future_x, future_container) | ||||||
|             future_loss = maml.criterion(future_y_hat, future_y) |             future_loss = maml.criterion(future_y_hat, future_y) | ||||||
|             meta_losses.append(future_loss) |             meta_losses.append(future_loss) | ||||||
| @@ -195,8 +197,6 @@ def main(args): | |||||||
|         train_results = train_metric.get_info() |         train_results = train_metric.get_info() | ||||||
|         return train_results, future_container |         return train_results, future_container | ||||||
|  |  | ||||||
|     train_results, future_container = finetune(0) |  | ||||||
|  |  | ||||||
|     metric = metric_cls(True) |     metric = metric_cls(True) | ||||||
|     per_timestamp_time, start_time = AverageMeter(), time.time() |     per_timestamp_time, start_time = AverageMeter(), time.time() | ||||||
|     for idx, (future_time, (future_x, future_y)) in enumerate(test_env): |     for idx, (future_time, (future_x, future_y)) in enumerate(test_env): | ||||||
| @@ -212,7 +212,9 @@ def main(args): | |||||||
|         ) |         ) | ||||||
|  |  | ||||||
|         # build optimizer |         # build optimizer | ||||||
|         future_x.to(args.device), future_y.to(args.device) |         train_results, future_container = finetune(idx) | ||||||
|  |  | ||||||
|  |         future_x, future_y = future_x.to(args.device), future_y.to(args.device) | ||||||
|         future_y_hat = maml.predict(future_x, future_container) |         future_y_hat = maml.predict(future_x, future_container) | ||||||
|         future_loss = criterion(future_y_hat, future_y) |         future_loss = criterion(future_y_hat, future_y) | ||||||
|         metric(future_y_hat, future_y) |         metric(future_y_hat, future_y) | ||||||
| @@ -237,7 +239,7 @@ if __name__ == "__main__": | |||||||
|     parser.add_argument( |     parser.add_argument( | ||||||
|         "--save_dir", |         "--save_dir", | ||||||
|         type=str, |         type=str, | ||||||
|         default="./outputs/lfna-synthetic/use-maml-nft", |         default="./outputs/GeMOSA-synthetic/use-maml-ft", | ||||||
|         help="The checkpoint directory.", |         help="The checkpoint directory.", | ||||||
|     ) |     ) | ||||||
|     parser.add_argument( |     parser.add_argument( | ||||||
|   | |||||||
| @@ -155,6 +155,8 @@ def main(args): | |||||||
|                 allys = allys.view(-1) |                 allys = allys.view(-1) | ||||||
|             historical_x, historical_y = allxs.to(args.device), allys.to(args.device) |             historical_x, historical_y = allxs.to(args.device), allys.to(args.device) | ||||||
|             future_container = maml.adapt(historical_x, historical_y) |             future_container = maml.adapt(historical_x, historical_y) | ||||||
|  |  | ||||||
|  |             future_x, future_y = future_x.to(args.device), future_y.to(args.device) | ||||||
|             future_y_hat = maml.predict(future_x, future_container) |             future_y_hat = maml.predict(future_x, future_container) | ||||||
|             future_loss = maml.criterion(future_y_hat, future_y) |             future_loss = maml.criterion(future_y_hat, future_y) | ||||||
|             meta_losses.append(future_loss) |             meta_losses.append(future_loss) | ||||||
| @@ -212,7 +214,7 @@ def main(args): | |||||||
|         ) |         ) | ||||||
|  |  | ||||||
|         # build optimizer |         # build optimizer | ||||||
|         future_x.to(args.device), future_y.to(args.device) |         future_x, future_y = future_x.to(args.device), future_y.to(args.device) | ||||||
|         future_y_hat = maml.predict(future_x, future_container) |         future_y_hat = maml.predict(future_x, future_container) | ||||||
|         future_loss = criterion(future_y_hat, future_y) |         future_loss = criterion(future_y_hat, future_y) | ||||||
|         metric(future_y_hat, future_y) |         metric(future_y_hat, future_y) | ||||||
| @@ -237,7 +239,7 @@ if __name__ == "__main__": | |||||||
|     parser.add_argument( |     parser.add_argument( | ||||||
|         "--save_dir", |         "--save_dir", | ||||||
|         type=str, |         type=str, | ||||||
|         default="./outputs/lfna-synthetic/use-maml-nft", |         default="./outputs/GeMOSA-synthetic/use-maml-nft", | ||||||
|         help="The checkpoint directory.", |         help="The checkpoint directory.", | ||||||
|     ) |     ) | ||||||
|     parser.add_argument( |     parser.add_argument( | ||||||
|   | |||||||
| @@ -5,26 +5,16 @@ | |||||||
| ##################################################### | ##################################################### | ||||||
| import unittest | import unittest | ||||||
|  |  | ||||||
| from xautodl.datasets.math_core import ConstantFunc, ComposedSinSFunc | from xautodl.datasets.synthetic_core import get_synthetic_env | ||||||
| from xautodl.datasets.synthetic_core import SyntheticDEnv |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class TestSynethicEnv(unittest.TestCase): | class TestSynethicEnv(unittest.TestCase): | ||||||
|     """Test the synethtic environment.""" |     """Test the synethtic environment.""" | ||||||
|  |  | ||||||
|     def test_simple(self): |     def test_simple(self): | ||||||
|         mean_generator = ConstantFunc(constant=0.1) |         versions = ["v1", "v2", "v3", "v4"] | ||||||
|         std_generator = ConstantFunc(constant=0.5) |         for version in versions: | ||||||
|         dataset = SyntheticDEnv([mean_generator], [[std_generator]], num_per_task=5000) |             env = get_synthetic_env(version=version) | ||||||
|         print(dataset) |         print(env) | ||||||
|         for timestamp, tau in dataset: |         for timestamp, tau in env: | ||||||
|             self.assertEqual(tau.shape, (5000, 1)) |             self.assertEqual(tau.shape, (1000, env.ndim)) | ||||||
|  |  | ||||||
|     def test_length(self): |  | ||||||
|         mean_generator = ComposedSinSFunc({0: 1, 1: 1, 2: 3}) |  | ||||||
|         std_generator = ConstantFunc(constant=0.5) |  | ||||||
|         dataset = SyntheticDEnv([mean_generator], [[std_generator]], num_per_task=5000) |  | ||||||
|         self.assertEqual(len(dataset), 100) |  | ||||||
|  |  | ||||||
|         dataset = SyntheticDEnv([mean_generator], [[std_generator]], mode="train") |  | ||||||
|         self.assertEqual(len(dataset), 60) |  | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user