DEBUG
This commit is contained in:
parent
cbd2afb4ef
commit
5457dcf042
@ -1,7 +1,7 @@
|
||||
#####################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 #
|
||||
#####################################################
|
||||
# python exps/LFNA/basic-his.py --srange 1-999 --env_version v1
|
||||
# python exps/LFNA/basic-his.py --srange 1-999 --env_version v1 --hidden_dim 16
|
||||
#####################################################
|
||||
import sys, time, copy, torch, random, argparse
|
||||
from tqdm import tqdm
|
||||
@ -36,7 +36,7 @@ def subsample(historical_x, historical_y, maxn=10000):
|
||||
|
||||
|
||||
def main(args):
|
||||
logger, env_info = lfna_setup(args)
|
||||
logger, env_info, model_kwargs = lfna_setup(args)
|
||||
|
||||
# check indexes to be evaluated
|
||||
to_evaluate_indexes = split_str2indexes(args.srange, env_info["total"], None)
|
||||
@ -71,13 +71,6 @@ def main(args):
|
||||
historical_x, historical_y = torch.cat(historical_x), torch.cat(historical_y)
|
||||
historical_x, historical_y = subsample(historical_x, historical_y)
|
||||
# build model
|
||||
mean, std = historical_x.mean().item(), historical_x.std().item()
|
||||
model_kwargs = dict(
|
||||
input_dim=1,
|
||||
output_dim=1,
|
||||
act_cls="leaky_relu",
|
||||
norm_cls="identity",
|
||||
)
|
||||
model = get_model(dict(model_type="simple_mlp"), **model_kwargs)
|
||||
# build optimizer
|
||||
optimizer = torch.optim.Adam(model.parameters(), lr=args.init_lr, amsgrad=True)
|
||||
@ -167,6 +160,12 @@ if __name__ == "__main__":
|
||||
required=True,
|
||||
help="The synthetic enviornment version.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--hidden_dim",
|
||||
type=int,
|
||||
required=True,
|
||||
help="The hidden dimension.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--init_lr",
|
||||
type=float,
|
||||
|
@ -169,7 +169,7 @@ if __name__ == "__main__":
|
||||
parser.add_argument(
|
||||
"--save_dir",
|
||||
type=str,
|
||||
default="./outputs/lfna-synthetic/maml",
|
||||
default="./outputs/lfna-synthetic/use-maml",
|
||||
help="The checkpoint directory.",
|
||||
)
|
||||
parser.add_argument(
|
||||
@ -178,6 +178,12 @@ if __name__ == "__main__":
|
||||
required=True,
|
||||
help="The synthetic enviornment version.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--hidden_dim",
|
||||
type=int,
|
||||
required=True,
|
||||
help="The hidden dimension.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--meta_lr",
|
||||
type=float,
|
||||
@ -217,4 +223,7 @@ if __name__ == "__main__":
|
||||
if args.rand_seed is None or args.rand_seed < 0:
|
||||
args.rand_seed = random.randint(1, 100000)
|
||||
assert args.save_dir is not None, "The save dir argument can not be None"
|
||||
args.save_dir = "{:}-{:}-d{:}".format(
|
||||
args.save_dir, args.env_version, args.hidden_dim
|
||||
)
|
||||
main(args)
|
||||
|
Loading…
Reference in New Issue
Block a user