From 89448e433f6d7070c02a8abfdf3dad53a0708b0c Mon Sep 17 00:00:00 2001 From: D-X-Y <280835372@qq.com> Date: Fri, 28 May 2021 02:22:59 +0800 Subject: [PATCH] XY --- exps/GeMOSA/baselines/maml-ft.py | 6 +++--- exps/GeMOSA/baselines/maml-nof.py | 8 ++++---- tests/test_synthetic_env.py | 4 ++-- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/exps/GeMOSA/baselines/maml-ft.py b/exps/GeMOSA/baselines/maml-ft.py index 1ea1a33..4dadb20 100644 --- a/exps/GeMOSA/baselines/maml-ft.py +++ b/exps/GeMOSA/baselines/maml-ft.py @@ -1,9 +1,9 @@ ##################################################### # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 # ##################################################### -# python exps/GeMOSA/baselines/maml-ft.py --env_version v1 --hidden_dim 16 --inner_step 5 -# python exps/GeMOSA/baselines/maml-ft.py --env_version v2 --hidden_dim 16 --inner_step 5 -# python exps/GeMOSA/baselines/maml-ft.py --env_version v3 --hidden_dim 32 --inner_step 5 +# python exps/GeMOSA/baselines/maml-ft.py --env_version v1 --hidden_dim 16 --inner_step 5 --device cuda +# python exps/GeMOSA/baselines/maml-ft.py --env_version v2 --hidden_dim 16 --inner_step 5 --device cuda +# python exps/GeMOSA/baselines/maml-ft.py --env_version v3 --hidden_dim 32 --inner_step 5 --device cuda # python exps/GeMOSA/baselines/maml-ft.py --env_version v4 --hidden_dim 32 --inner_step 5 --device cuda ##################################################### import sys, time, copy, torch, random, argparse diff --git a/exps/GeMOSA/baselines/maml-nof.py b/exps/GeMOSA/baselines/maml-nof.py index ce33c1c..88ed819 100644 --- a/exps/GeMOSA/baselines/maml-nof.py +++ b/exps/GeMOSA/baselines/maml-nof.py @@ -1,10 +1,10 @@ ##################################################### # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 # ##################################################### -# python exps/GeMOSA/baselines/maml-nof.py --env_version v1 --hidden_dim 16 --inner_step 5 -# python exps/GeMOSA/baselines/maml-nof.py --env_version v2 --hidden_dim 16 -# python exps/GeMOSA/baselines/maml-nof.py --env_version v3 --hidden_dim 32 -# python exps/GeMOSA/baselines/maml-nof.py --env_version v4 --hidden_dim 32 +# python exps/GeMOSA/baselines/maml-nof.py --env_version v1 --hidden_dim 16 --inner_step 5 --device cuda +# python exps/GeMOSA/baselines/maml-nof.py --env_version v2 --hidden_dim 16 --inner_step 5 --device cuda +# python exps/GeMOSA/baselines/maml-nof.py --env_version v3 --hidden_dim 32 --inner_step 5 --device cuda +# python exps/GeMOSA/baselines/maml-nof.py --env_version v4 --hidden_dim 32 --inner_step 5 --device cuda ##################################################### import sys, time, copy, torch, random, argparse from tqdm import tqdm diff --git a/tests/test_synthetic_env.py b/tests/test_synthetic_env.py index a34b968..f96d6fc 100644 --- a/tests/test_synthetic_env.py +++ b/tests/test_synthetic_env.py @@ -16,5 +16,5 @@ class TestSynethicEnv(unittest.TestCase): for version in versions: env = get_synthetic_env(version=version) print(env) - for timestamp, tau in env: - self.assertEqual(tau.shape, (1000, env.ndim)) + for timestamp, (x, y) in env: + self.assertEqual(x.shape, (1000, env.ndim))