DEBUG
This commit is contained in:
		| @@ -1,7 +1,7 @@ | |||||||
| ##################################################### | ##################################################### | ||||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 # | # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 # | ||||||
| ##################################################### | ##################################################### | ||||||
| # python exps/LFNA/basic-his.py --srange 1-999 --env_version v1 | # python exps/LFNA/basic-his.py --srange 1-999 --env_version v1 --hidden_dim 16 | ||||||
| ##################################################### | ##################################################### | ||||||
| import sys, time, copy, torch, random, argparse | import sys, time, copy, torch, random, argparse | ||||||
| from tqdm import tqdm | from tqdm import tqdm | ||||||
| @@ -36,7 +36,7 @@ def subsample(historical_x, historical_y, maxn=10000): | |||||||
|  |  | ||||||
|  |  | ||||||
| def main(args): | def main(args): | ||||||
|     logger, env_info = lfna_setup(args) |     logger, env_info, model_kwargs = lfna_setup(args) | ||||||
|  |  | ||||||
|     # check indexes to be evaluated |     # check indexes to be evaluated | ||||||
|     to_evaluate_indexes = split_str2indexes(args.srange, env_info["total"], None) |     to_evaluate_indexes = split_str2indexes(args.srange, env_info["total"], None) | ||||||
| @@ -71,13 +71,6 @@ def main(args): | |||||||
|         historical_x, historical_y = torch.cat(historical_x), torch.cat(historical_y) |         historical_x, historical_y = torch.cat(historical_x), torch.cat(historical_y) | ||||||
|         historical_x, historical_y = subsample(historical_x, historical_y) |         historical_x, historical_y = subsample(historical_x, historical_y) | ||||||
|         # build model |         # build model | ||||||
|         mean, std = historical_x.mean().item(), historical_x.std().item() |  | ||||||
|         model_kwargs = dict( |  | ||||||
|             input_dim=1, |  | ||||||
|             output_dim=1, |  | ||||||
|             act_cls="leaky_relu", |  | ||||||
|             norm_cls="identity", |  | ||||||
|         ) |  | ||||||
|         model = get_model(dict(model_type="simple_mlp"), **model_kwargs) |         model = get_model(dict(model_type="simple_mlp"), **model_kwargs) | ||||||
|         # build optimizer |         # build optimizer | ||||||
|         optimizer = torch.optim.Adam(model.parameters(), lr=args.init_lr, amsgrad=True) |         optimizer = torch.optim.Adam(model.parameters(), lr=args.init_lr, amsgrad=True) | ||||||
| @@ -167,6 +160,12 @@ if __name__ == "__main__": | |||||||
|         required=True, |         required=True, | ||||||
|         help="The synthetic enviornment version.", |         help="The synthetic enviornment version.", | ||||||
|     ) |     ) | ||||||
|  |     parser.add_argument( | ||||||
|  |         "--hidden_dim", | ||||||
|  |         type=int, | ||||||
|  |         required=True, | ||||||
|  |         help="The hidden dimension.", | ||||||
|  |     ) | ||||||
|     parser.add_argument( |     parser.add_argument( | ||||||
|         "--init_lr", |         "--init_lr", | ||||||
|         type=float, |         type=float, | ||||||
|   | |||||||
| @@ -169,7 +169,7 @@ if __name__ == "__main__": | |||||||
|     parser.add_argument( |     parser.add_argument( | ||||||
|         "--save_dir", |         "--save_dir", | ||||||
|         type=str, |         type=str, | ||||||
|         default="./outputs/lfna-synthetic/maml", |         default="./outputs/lfna-synthetic/use-maml", | ||||||
|         help="The checkpoint directory.", |         help="The checkpoint directory.", | ||||||
|     ) |     ) | ||||||
|     parser.add_argument( |     parser.add_argument( | ||||||
| @@ -178,6 +178,12 @@ if __name__ == "__main__": | |||||||
|         required=True, |         required=True, | ||||||
|         help="The synthetic enviornment version.", |         help="The synthetic enviornment version.", | ||||||
|     ) |     ) | ||||||
|  |     parser.add_argument( | ||||||
|  |         "--hidden_dim", | ||||||
|  |         type=int, | ||||||
|  |         required=True, | ||||||
|  |         help="The hidden dimension.", | ||||||
|  |     ) | ||||||
|     parser.add_argument( |     parser.add_argument( | ||||||
|         "--meta_lr", |         "--meta_lr", | ||||||
|         type=float, |         type=float, | ||||||
| @@ -217,4 +223,7 @@ if __name__ == "__main__": | |||||||
|     if args.rand_seed is None or args.rand_seed < 0: |     if args.rand_seed is None or args.rand_seed < 0: | ||||||
|         args.rand_seed = random.randint(1, 100000) |         args.rand_seed = random.randint(1, 100000) | ||||||
|     assert args.save_dir is not None, "The save dir argument can not be None" |     assert args.save_dir is not None, "The save dir argument can not be None" | ||||||
|  |     args.save_dir = "{:}-{:}-d{:}".format( | ||||||
|  |         args.save_dir, args.env_version, args.hidden_dim | ||||||
|  |     ) | ||||||
|     main(args) |     main(args) | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user