Update MLAML

This commit is contained in:
D-X-Y 2021-05-27 17:41:32 +00:00
parent c6db1ef65a
commit 9af34ea94d
3 changed files with 21 additions and 27 deletions

View File

@ -1,10 +1,10 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 #
#####################################################
# python exps/GeMOSA/baselines/maml-nof.py --env_version v1 --hidden_dim 16 --inner_step 5
# python exps/GeMOSA/baselines/maml-nof.py --env_version v2 --hidden_dim 16
# python exps/GeMOSA/baselines/maml-nof.py --env_version v3 --hidden_dim 32
# python exps/GeMOSA/baselines/maml-nof.py --env_version v4 --hidden_dim 32
# python exps/GeMOSA/baselines/maml-ft.py --env_version v1 --hidden_dim 16 --inner_step 5
# python exps/GeMOSA/baselines/maml-ft.py --env_version v2 --hidden_dim 16 --inner_step 5
# python exps/GeMOSA/baselines/maml-ft.py --env_version v3 --hidden_dim 32 --inner_step 5
# python exps/GeMOSA/baselines/maml-ft.py --env_version v4 --hidden_dim 32 --inner_step 5 --device cuda
#####################################################
import sys, time, copy, torch, random, argparse
from tqdm import tqdm
@ -155,6 +155,8 @@ def main(args):
allys = allys.view(-1)
historical_x, historical_y = allxs.to(args.device), allys.to(args.device)
future_container = maml.adapt(historical_x, historical_y)
future_x, future_y = future_x.to(args.device), future_y.to(args.device)
future_y_hat = maml.predict(future_x, future_container)
future_loss = maml.criterion(future_y_hat, future_y)
meta_losses.append(future_loss)
@ -195,8 +197,6 @@ def main(args):
train_results = train_metric.get_info()
return train_results, future_container
train_results, future_container = finetune(0)
metric = metric_cls(True)
per_timestamp_time, start_time = AverageMeter(), time.time()
for idx, (future_time, (future_x, future_y)) in enumerate(test_env):
@ -212,7 +212,9 @@ def main(args):
)
# build optimizer
future_x.to(args.device), future_y.to(args.device)
train_results, future_container = finetune(idx)
future_x, future_y = future_x.to(args.device), future_y.to(args.device)
future_y_hat = maml.predict(future_x, future_container)
future_loss = criterion(future_y_hat, future_y)
metric(future_y_hat, future_y)
@ -237,7 +239,7 @@ if __name__ == "__main__":
parser.add_argument(
"--save_dir",
type=str,
default="./outputs/lfna-synthetic/use-maml-nft",
default="./outputs/GeMOSA-synthetic/use-maml-ft",
help="The checkpoint directory.",
)
parser.add_argument(

View File

@ -155,6 +155,8 @@ def main(args):
allys = allys.view(-1)
historical_x, historical_y = allxs.to(args.device), allys.to(args.device)
future_container = maml.adapt(historical_x, historical_y)
future_x, future_y = future_x.to(args.device), future_y.to(args.device)
future_y_hat = maml.predict(future_x, future_container)
future_loss = maml.criterion(future_y_hat, future_y)
meta_losses.append(future_loss)
@ -212,7 +214,7 @@ def main(args):
)
# build optimizer
future_x.to(args.device), future_y.to(args.device)
future_x, future_y = future_x.to(args.device), future_y.to(args.device)
future_y_hat = maml.predict(future_x, future_container)
future_loss = criterion(future_y_hat, future_y)
metric(future_y_hat, future_y)
@ -237,7 +239,7 @@ if __name__ == "__main__":
parser.add_argument(
"--save_dir",
type=str,
default="./outputs/lfna-synthetic/use-maml-nft",
default="./outputs/GeMOSA-synthetic/use-maml-nft",
help="The checkpoint directory.",
)
parser.add_argument(

View File

@ -5,26 +5,16 @@
#####################################################
import unittest
from xautodl.datasets.math_core import ConstantFunc, ComposedSinSFunc
from xautodl.datasets.synthetic_core import SyntheticDEnv
from xautodl.datasets.synthetic_core import get_synthetic_env
class TestSynethicEnv(unittest.TestCase):
"""Test the synethtic environment."""
def test_simple(self):
mean_generator = ConstantFunc(constant=0.1)
std_generator = ConstantFunc(constant=0.5)
dataset = SyntheticDEnv([mean_generator], [[std_generator]], num_per_task=5000)
print(dataset)
for timestamp, tau in dataset:
self.assertEqual(tau.shape, (5000, 1))
def test_length(self):
mean_generator = ComposedSinSFunc({0: 1, 1: 1, 2: 3})
std_generator = ConstantFunc(constant=0.5)
dataset = SyntheticDEnv([mean_generator], [[std_generator]], num_per_task=5000)
self.assertEqual(len(dataset), 100)
dataset = SyntheticDEnv([mean_generator], [[std_generator]], mode="train")
self.assertEqual(len(dataset), 60)
versions = ["v1", "v2", "v3", "v4"]
for version in versions:
env = get_synthetic_env(version=version)
print(env)
for timestamp, tau in env:
self.assertEqual(tau.shape, (1000, env.ndim))