diff --git a/exps/LFNA/basic-his.py b/exps/LFNA/basic-his.py index 3352687..d571c1f 100644 --- a/exps/LFNA/basic-his.py +++ b/exps/LFNA/basic-his.py @@ -1,7 +1,7 @@ ##################################################### # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 # ##################################################### -# python exps/LFNA/basic-his.py --srange 1-999 --env_version v1 +# python exps/LFNA/basic-his.py --srange 1-999 --env_version v1 --hidden_dim 16 ##################################################### import sys, time, copy, torch, random, argparse from tqdm import tqdm @@ -36,7 +36,7 @@ def subsample(historical_x, historical_y, maxn=10000): def main(args): - logger, env_info = lfna_setup(args) + logger, env_info, model_kwargs = lfna_setup(args) # check indexes to be evaluated to_evaluate_indexes = split_str2indexes(args.srange, env_info["total"], None) @@ -71,13 +71,6 @@ def main(args): historical_x, historical_y = torch.cat(historical_x), torch.cat(historical_y) historical_x, historical_y = subsample(historical_x, historical_y) # build model - mean, std = historical_x.mean().item(), historical_x.std().item() - model_kwargs = dict( - input_dim=1, - output_dim=1, - act_cls="leaky_relu", - norm_cls="identity", - ) model = get_model(dict(model_type="simple_mlp"), **model_kwargs) # build optimizer optimizer = torch.optim.Adam(model.parameters(), lr=args.init_lr, amsgrad=True) @@ -167,6 +160,12 @@ if __name__ == "__main__": required=True, help="The synthetic enviornment version.", ) + parser.add_argument( + "--hidden_dim", + type=int, + required=True, + help="The hidden dimension.", + ) parser.add_argument( "--init_lr", type=float, diff --git a/exps/LFNA/basic-maml.py b/exps/LFNA/basic-maml.py index 970800c..e0ead12 100644 --- a/exps/LFNA/basic-maml.py +++ b/exps/LFNA/basic-maml.py @@ -169,7 +169,7 @@ if __name__ == "__main__": parser.add_argument( "--save_dir", type=str, - default="./outputs/lfna-synthetic/maml", + default="./outputs/lfna-synthetic/use-maml", help="The checkpoint directory.", ) parser.add_argument( @@ -178,6 +178,12 @@ if __name__ == "__main__": required=True, help="The synthetic enviornment version.", ) + parser.add_argument( + "--hidden_dim", + type=int, + required=True, + help="The hidden dimension.", + ) parser.add_argument( "--meta_lr", type=float, @@ -217,4 +223,7 @@ if __name__ == "__main__": if args.rand_seed is None or args.rand_seed < 0: args.rand_seed = random.randint(1, 100000) assert args.save_dir is not None, "The save dir argument can not be None" + args.save_dir = "{:}-{:}-d{:}".format( + args.save_dir, args.env_version, args.hidden_dim + ) main(args)