diff --git a/exps/GeMOSA/baselines/maml-ft.py b/exps/GeMOSA/baselines/maml-ft.py index 17c4ef2..4d86990 100644 --- a/exps/GeMOSA/baselines/maml-ft.py +++ b/exps/GeMOSA/baselines/maml-ft.py @@ -1,10 +1,10 @@ ##################################################### # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 # ##################################################### -# python exps/GeMOSA/baselines/maml-nof.py --env_version v1 --hidden_dim 16 --inner_step 5 -# python exps/GeMOSA/baselines/maml-nof.py --env_version v2 --hidden_dim 16 -# python exps/GeMOSA/baselines/maml-nof.py --env_version v3 --hidden_dim 32 -# python exps/GeMOSA/baselines/maml-nof.py --env_version v4 --hidden_dim 32 +# python exps/GeMOSA/baselines/maml-ft.py --env_version v1 --hidden_dim 16 --inner_step 5 +# python exps/GeMOSA/baselines/maml-ft.py --env_version v2 --hidden_dim 16 --inner_step 5 +# python exps/GeMOSA/baselines/maml-ft.py --env_version v3 --hidden_dim 32 --inner_step 5 +# python exps/GeMOSA/baselines/maml-ft.py --env_version v4 --hidden_dim 32 --inner_step 5 --device cuda ##################################################### import sys, time, copy, torch, random, argparse from tqdm import tqdm @@ -155,6 +155,8 @@ def main(args): allys = allys.view(-1) historical_x, historical_y = allxs.to(args.device), allys.to(args.device) future_container = maml.adapt(historical_x, historical_y) + + future_x, future_y = future_x.to(args.device), future_y.to(args.device) future_y_hat = maml.predict(future_x, future_container) future_loss = maml.criterion(future_y_hat, future_y) meta_losses.append(future_loss) @@ -195,8 +197,6 @@ def main(args): train_results = train_metric.get_info() return train_results, future_container - train_results, future_container = finetune(0) - metric = metric_cls(True) per_timestamp_time, start_time = AverageMeter(), time.time() for idx, (future_time, (future_x, future_y)) in enumerate(test_env): @@ -212,7 +212,9 @@ def main(args): ) # build optimizer - future_x.to(args.device), future_y.to(args.device) + train_results, future_container = finetune(idx) + + future_x, future_y = future_x.to(args.device), future_y.to(args.device) future_y_hat = maml.predict(future_x, future_container) future_loss = criterion(future_y_hat, future_y) metric(future_y_hat, future_y) @@ -237,7 +239,7 @@ if __name__ == "__main__": parser.add_argument( "--save_dir", type=str, - default="./outputs/lfna-synthetic/use-maml-nft", + default="./outputs/GeMOSA-synthetic/use-maml-ft", help="The checkpoint directory.", ) parser.add_argument( diff --git a/exps/GeMOSA/baselines/maml-nof.py b/exps/GeMOSA/baselines/maml-nof.py index 17c4ef2..bacf849 100644 --- a/exps/GeMOSA/baselines/maml-nof.py +++ b/exps/GeMOSA/baselines/maml-nof.py @@ -155,6 +155,8 @@ def main(args): allys = allys.view(-1) historical_x, historical_y = allxs.to(args.device), allys.to(args.device) future_container = maml.adapt(historical_x, historical_y) + + future_x, future_y = future_x.to(args.device), future_y.to(args.device) future_y_hat = maml.predict(future_x, future_container) future_loss = maml.criterion(future_y_hat, future_y) meta_losses.append(future_loss) @@ -212,7 +214,7 @@ def main(args): ) # build optimizer - future_x.to(args.device), future_y.to(args.device) + future_x, future_y = future_x.to(args.device), future_y.to(args.device) future_y_hat = maml.predict(future_x, future_container) future_loss = criterion(future_y_hat, future_y) metric(future_y_hat, future_y) @@ -237,7 +239,7 @@ if __name__ == "__main__": parser.add_argument( "--save_dir", type=str, - default="./outputs/lfna-synthetic/use-maml-nft", + default="./outputs/GeMOSA-synthetic/use-maml-nft", help="The checkpoint directory.", ) parser.add_argument( diff --git a/tests/test_synthetic_env.py b/tests/test_synthetic_env.py index b1b2f7f..a34b968 100644 --- a/tests/test_synthetic_env.py +++ b/tests/test_synthetic_env.py @@ -5,26 +5,16 @@ ##################################################### import unittest -from xautodl.datasets.math_core import ConstantFunc, ComposedSinSFunc -from xautodl.datasets.synthetic_core import SyntheticDEnv +from xautodl.datasets.synthetic_core import get_synthetic_env class TestSynethicEnv(unittest.TestCase): """Test the synethtic environment.""" def test_simple(self): - mean_generator = ConstantFunc(constant=0.1) - std_generator = ConstantFunc(constant=0.5) - dataset = SyntheticDEnv([mean_generator], [[std_generator]], num_per_task=5000) - print(dataset) - for timestamp, tau in dataset: - self.assertEqual(tau.shape, (5000, 1)) - - def test_length(self): - mean_generator = ComposedSinSFunc({0: 1, 1: 1, 2: 3}) - std_generator = ConstantFunc(constant=0.5) - dataset = SyntheticDEnv([mean_generator], [[std_generator]], num_per_task=5000) - self.assertEqual(len(dataset), 100) - - dataset = SyntheticDEnv([mean_generator], [[std_generator]], mode="train") - self.assertEqual(len(dataset), 60) + versions = ["v1", "v2", "v3", "v4"] + for version in versions: + env = get_synthetic_env(version=version) + print(env) + for timestamp, tau in env: + self.assertEqual(tau.shape, (1000, env.ndim))