DEBUG
This commit is contained in:
parent
cbd2afb4ef
commit
5457dcf042
@ -1,7 +1,7 @@
|
|||||||
#####################################################
|
#####################################################
|
||||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 #
|
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 #
|
||||||
#####################################################
|
#####################################################
|
||||||
# python exps/LFNA/basic-his.py --srange 1-999 --env_version v1
|
# python exps/LFNA/basic-his.py --srange 1-999 --env_version v1 --hidden_dim 16
|
||||||
#####################################################
|
#####################################################
|
||||||
import sys, time, copy, torch, random, argparse
|
import sys, time, copy, torch, random, argparse
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
@ -36,7 +36,7 @@ def subsample(historical_x, historical_y, maxn=10000):
|
|||||||
|
|
||||||
|
|
||||||
def main(args):
|
def main(args):
|
||||||
logger, env_info = lfna_setup(args)
|
logger, env_info, model_kwargs = lfna_setup(args)
|
||||||
|
|
||||||
# check indexes to be evaluated
|
# check indexes to be evaluated
|
||||||
to_evaluate_indexes = split_str2indexes(args.srange, env_info["total"], None)
|
to_evaluate_indexes = split_str2indexes(args.srange, env_info["total"], None)
|
||||||
@ -71,13 +71,6 @@ def main(args):
|
|||||||
historical_x, historical_y = torch.cat(historical_x), torch.cat(historical_y)
|
historical_x, historical_y = torch.cat(historical_x), torch.cat(historical_y)
|
||||||
historical_x, historical_y = subsample(historical_x, historical_y)
|
historical_x, historical_y = subsample(historical_x, historical_y)
|
||||||
# build model
|
# build model
|
||||||
mean, std = historical_x.mean().item(), historical_x.std().item()
|
|
||||||
model_kwargs = dict(
|
|
||||||
input_dim=1,
|
|
||||||
output_dim=1,
|
|
||||||
act_cls="leaky_relu",
|
|
||||||
norm_cls="identity",
|
|
||||||
)
|
|
||||||
model = get_model(dict(model_type="simple_mlp"), **model_kwargs)
|
model = get_model(dict(model_type="simple_mlp"), **model_kwargs)
|
||||||
# build optimizer
|
# build optimizer
|
||||||
optimizer = torch.optim.Adam(model.parameters(), lr=args.init_lr, amsgrad=True)
|
optimizer = torch.optim.Adam(model.parameters(), lr=args.init_lr, amsgrad=True)
|
||||||
@ -167,6 +160,12 @@ if __name__ == "__main__":
|
|||||||
required=True,
|
required=True,
|
||||||
help="The synthetic enviornment version.",
|
help="The synthetic enviornment version.",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--hidden_dim",
|
||||||
|
type=int,
|
||||||
|
required=True,
|
||||||
|
help="The hidden dimension.",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--init_lr",
|
"--init_lr",
|
||||||
type=float,
|
type=float,
|
||||||
|
@ -169,7 +169,7 @@ if __name__ == "__main__":
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--save_dir",
|
"--save_dir",
|
||||||
type=str,
|
type=str,
|
||||||
default="./outputs/lfna-synthetic/maml",
|
default="./outputs/lfna-synthetic/use-maml",
|
||||||
help="The checkpoint directory.",
|
help="The checkpoint directory.",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -178,6 +178,12 @@ if __name__ == "__main__":
|
|||||||
required=True,
|
required=True,
|
||||||
help="The synthetic enviornment version.",
|
help="The synthetic enviornment version.",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--hidden_dim",
|
||||||
|
type=int,
|
||||||
|
required=True,
|
||||||
|
help="The hidden dimension.",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--meta_lr",
|
"--meta_lr",
|
||||||
type=float,
|
type=float,
|
||||||
@ -217,4 +223,7 @@ if __name__ == "__main__":
|
|||||||
if args.rand_seed is None or args.rand_seed < 0:
|
if args.rand_seed is None or args.rand_seed < 0:
|
||||||
args.rand_seed = random.randint(1, 100000)
|
args.rand_seed = random.randint(1, 100000)
|
||||||
assert args.save_dir is not None, "The save dir argument can not be None"
|
assert args.save_dir is not None, "The save dir argument can not be None"
|
||||||
|
args.save_dir = "{:}-{:}-d{:}".format(
|
||||||
|
args.save_dir, args.env_version, args.hidden_dim
|
||||||
|
)
|
||||||
main(args)
|
main(args)
|
||||||
|
Loading…
Reference in New Issue
Block a user