Fix bugs
This commit is contained in:
		| @@ -1,14 +1,18 @@ | ||||
| ##################################################### | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 # | ||||
| ##################################################### | ||||
| # python exps/LFNA/basic-prev.py --env_version v1 --prev_time 5 --hidden_dim 16 --epochs 500 --init_lr 0.1 | ||||
| # python exps/LFNA/basic-prev.py --env_version v2 --hidden_dim 16 --epochs 1000 --init_lr 0.05 | ||||
| # python exps/GeMOSA/basic-prev.py --env_version v1 --prev_time 5 --hidden_dim 16 --epochs 500 --init_lr 0.1 | ||||
| # python exps/GeMOSA/basic-prev.py --env_version v2 --hidden_dim 16 --epochs 1000 --init_lr 0.05 | ||||
| ##################################################### | ||||
| import sys, time, copy, torch, random, argparse | ||||
| from tqdm import tqdm | ||||
| from copy import deepcopy | ||||
| from pathlib import Path | ||||
|  | ||||
| lib_dir = (Path(__file__).parent / ".." / "..").resolve() | ||||
| if str(lib_dir) not in sys.path: | ||||
|     sys.path.insert(0, str(lib_dir)) | ||||
|  | ||||
| from xautodl.procedures import ( | ||||
|     prepare_seed, | ||||
|     prepare_logger, | ||||
| @@ -38,9 +42,9 @@ def subsample(historical_x, historical_y, maxn=10000): | ||||
|  | ||||
|  | ||||
| def main(args): | ||||
|     logger, env_info, model_kwargs = lfna_setup(args) | ||||
|     logger, model_kwargs = lfna_setup(args) | ||||
|  | ||||
|     w_container_per_epoch = dict() | ||||
|     w_containers = dict() | ||||
|  | ||||
|     per_timestamp_time, start_time = AverageMeter(), time.time() | ||||
|     for idx in range(args.prev_time, env_info["total"]): | ||||
| @@ -111,7 +115,7 @@ def main(args): | ||||
|         save_path = logger.path(None) / "{:04d}-{:04d}.pth".format( | ||||
|             idx, env_info["total"] | ||||
|         ) | ||||
|         w_container_per_epoch[idx] = model.get_w_container().no_grad_clone() | ||||
|         w_containers[idx] = model.get_w_container().no_grad_clone() | ||||
|         save_checkpoint( | ||||
|             { | ||||
|                 "model_state_dict": model.state_dict(), | ||||
| @@ -127,7 +131,7 @@ def main(args): | ||||
|         start_time = time.time() | ||||
|  | ||||
|     save_checkpoint( | ||||
|         {"w_container_per_epoch": w_container_per_epoch}, | ||||
|         {"w_containers": w_containers}, | ||||
|         logger.path(None) / "final-ckp.pth", | ||||
|         logger, | ||||
|     ) | ||||
|   | ||||
| @@ -68,6 +68,8 @@ def main(args): | ||||
|         # build model | ||||
|         model = get_model(**model_kwargs) | ||||
|         model = model.to(args.device) | ||||
|         if idx == 0: | ||||
|             print(model) | ||||
|         # build optimizer | ||||
|         optimizer = torch.optim.Adam(model.parameters(), lr=args.init_lr, amsgrad=True) | ||||
|         criterion = torch.nn.MSELoss() | ||||
|   | ||||
| @@ -16,7 +16,7 @@ def lfna_setup(args): | ||||
|         input_dim=1, | ||||
|         output_dim=1, | ||||
|         hidden_dims=[args.hidden_dim] * 2, | ||||
|         act_cls="gelu", | ||||
|         act_cls="relu", | ||||
|         norm_cls="layer_norm_1d", | ||||
|     ) | ||||
|     return logger, model_kwargs | ||||
|   | ||||
		Reference in New Issue
	
	Block a user