Update MLAML
This commit is contained in:
parent
c6db1ef65a
commit
9af34ea94d
@ -1,10 +1,10 @@
|
||||
#####################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 #
|
||||
#####################################################
|
||||
# python exps/GeMOSA/baselines/maml-nof.py --env_version v1 --hidden_dim 16 --inner_step 5
|
||||
# python exps/GeMOSA/baselines/maml-nof.py --env_version v2 --hidden_dim 16
|
||||
# python exps/GeMOSA/baselines/maml-nof.py --env_version v3 --hidden_dim 32
|
||||
# python exps/GeMOSA/baselines/maml-nof.py --env_version v4 --hidden_dim 32
|
||||
# python exps/GeMOSA/baselines/maml-ft.py --env_version v1 --hidden_dim 16 --inner_step 5
|
||||
# python exps/GeMOSA/baselines/maml-ft.py --env_version v2 --hidden_dim 16 --inner_step 5
|
||||
# python exps/GeMOSA/baselines/maml-ft.py --env_version v3 --hidden_dim 32 --inner_step 5
|
||||
# python exps/GeMOSA/baselines/maml-ft.py --env_version v4 --hidden_dim 32 --inner_step 5 --device cuda
|
||||
#####################################################
|
||||
import sys, time, copy, torch, random, argparse
|
||||
from tqdm import tqdm
|
||||
@ -155,6 +155,8 @@ def main(args):
|
||||
allys = allys.view(-1)
|
||||
historical_x, historical_y = allxs.to(args.device), allys.to(args.device)
|
||||
future_container = maml.adapt(historical_x, historical_y)
|
||||
|
||||
future_x, future_y = future_x.to(args.device), future_y.to(args.device)
|
||||
future_y_hat = maml.predict(future_x, future_container)
|
||||
future_loss = maml.criterion(future_y_hat, future_y)
|
||||
meta_losses.append(future_loss)
|
||||
@ -195,8 +197,6 @@ def main(args):
|
||||
train_results = train_metric.get_info()
|
||||
return train_results, future_container
|
||||
|
||||
train_results, future_container = finetune(0)
|
||||
|
||||
metric = metric_cls(True)
|
||||
per_timestamp_time, start_time = AverageMeter(), time.time()
|
||||
for idx, (future_time, (future_x, future_y)) in enumerate(test_env):
|
||||
@ -212,7 +212,9 @@ def main(args):
|
||||
)
|
||||
|
||||
# build optimizer
|
||||
future_x.to(args.device), future_y.to(args.device)
|
||||
train_results, future_container = finetune(idx)
|
||||
|
||||
future_x, future_y = future_x.to(args.device), future_y.to(args.device)
|
||||
future_y_hat = maml.predict(future_x, future_container)
|
||||
future_loss = criterion(future_y_hat, future_y)
|
||||
metric(future_y_hat, future_y)
|
||||
@ -237,7 +239,7 @@ if __name__ == "__main__":
|
||||
parser.add_argument(
|
||||
"--save_dir",
|
||||
type=str,
|
||||
default="./outputs/lfna-synthetic/use-maml-nft",
|
||||
default="./outputs/GeMOSA-synthetic/use-maml-ft",
|
||||
help="The checkpoint directory.",
|
||||
)
|
||||
parser.add_argument(
|
||||
|
@ -155,6 +155,8 @@ def main(args):
|
||||
allys = allys.view(-1)
|
||||
historical_x, historical_y = allxs.to(args.device), allys.to(args.device)
|
||||
future_container = maml.adapt(historical_x, historical_y)
|
||||
|
||||
future_x, future_y = future_x.to(args.device), future_y.to(args.device)
|
||||
future_y_hat = maml.predict(future_x, future_container)
|
||||
future_loss = maml.criterion(future_y_hat, future_y)
|
||||
meta_losses.append(future_loss)
|
||||
@ -212,7 +214,7 @@ def main(args):
|
||||
)
|
||||
|
||||
# build optimizer
|
||||
future_x.to(args.device), future_y.to(args.device)
|
||||
future_x, future_y = future_x.to(args.device), future_y.to(args.device)
|
||||
future_y_hat = maml.predict(future_x, future_container)
|
||||
future_loss = criterion(future_y_hat, future_y)
|
||||
metric(future_y_hat, future_y)
|
||||
@ -237,7 +239,7 @@ if __name__ == "__main__":
|
||||
parser.add_argument(
|
||||
"--save_dir",
|
||||
type=str,
|
||||
default="./outputs/lfna-synthetic/use-maml-nft",
|
||||
default="./outputs/GeMOSA-synthetic/use-maml-nft",
|
||||
help="The checkpoint directory.",
|
||||
)
|
||||
parser.add_argument(
|
||||
|
@ -5,26 +5,16 @@
|
||||
#####################################################
|
||||
import unittest
|
||||
|
||||
from xautodl.datasets.math_core import ConstantFunc, ComposedSinSFunc
|
||||
from xautodl.datasets.synthetic_core import SyntheticDEnv
|
||||
from xautodl.datasets.synthetic_core import get_synthetic_env
|
||||
|
||||
|
||||
class TestSynethicEnv(unittest.TestCase):
|
||||
"""Test the synethtic environment."""
|
||||
|
||||
def test_simple(self):
|
||||
mean_generator = ConstantFunc(constant=0.1)
|
||||
std_generator = ConstantFunc(constant=0.5)
|
||||
dataset = SyntheticDEnv([mean_generator], [[std_generator]], num_per_task=5000)
|
||||
print(dataset)
|
||||
for timestamp, tau in dataset:
|
||||
self.assertEqual(tau.shape, (5000, 1))
|
||||
|
||||
def test_length(self):
|
||||
mean_generator = ComposedSinSFunc({0: 1, 1: 1, 2: 3})
|
||||
std_generator = ConstantFunc(constant=0.5)
|
||||
dataset = SyntheticDEnv([mean_generator], [[std_generator]], num_per_task=5000)
|
||||
self.assertEqual(len(dataset), 100)
|
||||
|
||||
dataset = SyntheticDEnv([mean_generator], [[std_generator]], mode="train")
|
||||
self.assertEqual(len(dataset), 60)
|
||||
versions = ["v1", "v2", "v3", "v4"]
|
||||
for version in versions:
|
||||
env = get_synthetic_env(version=version)
|
||||
print(env)
|
||||
for timestamp, tau in env:
|
||||
self.assertEqual(tau.shape, (1000, env.ndim))
|
||||
|
Loading…
Reference in New Issue
Block a user