This commit is contained in:
D-X-Y 2021-05-10 01:05:00 +08:00
parent cbd2afb4ef
commit 5457dcf042
2 changed files with 18 additions and 10 deletions

View File

@ -1,7 +1,7 @@
##################################################### #####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 # # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 #
##################################################### #####################################################
# python exps/LFNA/basic-his.py --srange 1-999 --env_version v1 # python exps/LFNA/basic-his.py --srange 1-999 --env_version v1 --hidden_dim 16
##################################################### #####################################################
import sys, time, copy, torch, random, argparse import sys, time, copy, torch, random, argparse
from tqdm import tqdm from tqdm import tqdm
@ -36,7 +36,7 @@ def subsample(historical_x, historical_y, maxn=10000):
def main(args): def main(args):
logger, env_info = lfna_setup(args) logger, env_info, model_kwargs = lfna_setup(args)
# check indexes to be evaluated # check indexes to be evaluated
to_evaluate_indexes = split_str2indexes(args.srange, env_info["total"], None) to_evaluate_indexes = split_str2indexes(args.srange, env_info["total"], None)
@ -71,13 +71,6 @@ def main(args):
historical_x, historical_y = torch.cat(historical_x), torch.cat(historical_y) historical_x, historical_y = torch.cat(historical_x), torch.cat(historical_y)
historical_x, historical_y = subsample(historical_x, historical_y) historical_x, historical_y = subsample(historical_x, historical_y)
# build model # build model
mean, std = historical_x.mean().item(), historical_x.std().item()
model_kwargs = dict(
input_dim=1,
output_dim=1,
act_cls="leaky_relu",
norm_cls="identity",
)
model = get_model(dict(model_type="simple_mlp"), **model_kwargs) model = get_model(dict(model_type="simple_mlp"), **model_kwargs)
# build optimizer # build optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=args.init_lr, amsgrad=True) optimizer = torch.optim.Adam(model.parameters(), lr=args.init_lr, amsgrad=True)
@ -167,6 +160,12 @@ if __name__ == "__main__":
required=True, required=True,
help="The synthetic enviornment version.", help="The synthetic enviornment version.",
) )
parser.add_argument(
"--hidden_dim",
type=int,
required=True,
help="The hidden dimension.",
)
parser.add_argument( parser.add_argument(
"--init_lr", "--init_lr",
type=float, type=float,

View File

@ -169,7 +169,7 @@ if __name__ == "__main__":
parser.add_argument( parser.add_argument(
"--save_dir", "--save_dir",
type=str, type=str,
default="./outputs/lfna-synthetic/maml", default="./outputs/lfna-synthetic/use-maml",
help="The checkpoint directory.", help="The checkpoint directory.",
) )
parser.add_argument( parser.add_argument(
@ -178,6 +178,12 @@ if __name__ == "__main__":
required=True, required=True,
help="The synthetic enviornment version.", help="The synthetic enviornment version.",
) )
parser.add_argument(
"--hidden_dim",
type=int,
required=True,
help="The hidden dimension.",
)
parser.add_argument( parser.add_argument(
"--meta_lr", "--meta_lr",
type=float, type=float,
@ -217,4 +223,7 @@ if __name__ == "__main__":
if args.rand_seed is None or args.rand_seed < 0: if args.rand_seed is None or args.rand_seed < 0:
args.rand_seed = random.randint(1, 100000) args.rand_seed = random.randint(1, 100000)
assert args.save_dir is not None, "The save dir argument can not be None" assert args.save_dir is not None, "The save dir argument can not be None"
args.save_dir = "{:}-{:}-d{:}".format(
args.save_dir, args.env_version, args.hidden_dim
)
main(args) main(args)