Fix bugs in LFNA
@@ -7,6 +7,7 @@ echo "script-directory: $script_dir"
 cp ${script_dir}/tmux.conf ~/.tmux.conf
 cp ${script_dir}/vimrc ~/.vimrc
 cp ${script_dir}/bashrc ~/.bashrc
+cp ${script_dir}/condarc ~/.condarc
 
 wget https://repo.anaconda.com/miniconda/Miniconda3-4.7.12.1-Linux-x86_64.sh
 wget https://repo.anaconda.com/archive/Anaconda3-2020.11-Linux-x86_64.sh
@@ -161,10 +161,7 @@ if __name__ == "__main__":
         help="The synthetic enviornment version.",
     )
     parser.add_argument(
-        "--hidden_dim",
-        type=int,
-        required=True,
-        help="The hidden dimension.",
+        "--hidden_dim", type=int, required=True, help="The hidden dimension.",
     )
     parser.add_argument(
         "--init_lr",
@@ -173,16 +170,10 @@ if __name__ == "__main__":
         help="The initial learning rate for the optimizer (default is Adam)",
     )
     parser.add_argument(
-        "--batch_size",
-        type=int,
-        default=512,
-        help="The batch size",
+        "--batch_size", type=int, default=512, help="The batch size",
     )
     parser.add_argument(
-        "--epochs",
-        type=int,
-        default=1000,
-        help="The total number of epochs.",
+        "--epochs", type=int, default=1000, help="The total number of epochs.",
     )
     parser.add_argument(
         "--srange", type=str, required=True, help="The range of models to be evaluated"
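Note: the two hunks above (and the similar ones in the files below) only collapse multi-line parser.add_argument(...) calls onto single lines; every flag keeps its type, default, required setting, and help string. As a reference point, a minimal self-contained sketch of this argparse style follows; the parser name and the example value are illustrative, not taken from the repository.

import argparse

# Minimal sketch of the argument style used in the hunks above; this is not
# the repository's actual entry point.
parser = argparse.ArgumentParser("lfna-args-sketch")
parser.add_argument("--hidden_dim", type=int, required=True, help="The hidden dimension.")
parser.add_argument("--batch_size", type=int, default=512, help="The batch size")
parser.add_argument("--epochs", type=int, default=1000, help="The total number of epochs.")

# Parsing an explicit list stands in for command-line usage such as
#   python some_script.py --hidden_dim 32
args = parser.parse_args(["--hidden_dim", "32"])
print(args.hidden_dim, args.batch_size, args.epochs)  # 32 512 1000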
@@ -41,10 +41,7 @@ class MAML:
         )
         self.meta_lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
             self.meta_optimizer,
-            milestones=[
-                int(epochs * 0.8),
-                int(epochs * 0.9),
-            ],
+            milestones=[int(epochs * 0.8), int(epochs * 0.9),],
             gamma=0.1,
         )
         self.inner_lr = inner_lr
@@ -197,10 +194,7 @@ if __name__ == "__main__":
         help="The synthetic enviornment version.",
     )
     parser.add_argument(
-        "--hidden_dim",
-        type=int,
-        default=16,
-        help="The hidden dimension.",
+        "--hidden_dim", type=int, default=16, help="The hidden dimension.",
     )
     parser.add_argument(
         "--meta_lr",
@@ -230,16 +224,10 @@ if __name__ == "__main__":
         help="The gap between prev_time and current_timestamp",
     )
     parser.add_argument(
-        "--meta_batch",
-        type=int,
-        default=64,
-        help="The batch size for the meta-model",
+        "--meta_batch", type=int, default=64, help="The batch size for the meta-model",
     )
     parser.add_argument(
-        "--epochs",
-        type=int,
-        default=2000,
-        help="The total number of epochs.",
+        "--epochs", type=int, default=2000, help="The total number of epochs.",
     )
     parser.add_argument(
         "--early_stop_thresh",
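Note: the milestones=[int(epochs * 0.8), int(epochs * 0.9),] rewrite is also purely cosmetic; MultiStepLR still multiplies the learning rate by gamma=0.1 at 80% and again at 90% of the epochs. A self-contained sketch of that schedule follows; the tiny linear model and the epoch count are made up for illustration and are not the MAML class from this diff.

import torch

# Toy model and optimizer, only used to drive the scheduler.
model = torch.nn.Linear(4, 1)
epochs = 10
optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
scheduler = torch.optim.lr_scheduler.MultiStepLR(
    optimizer, milestones=[int(epochs * 0.8), int(epochs * 0.9)], gamma=0.1
)

for epoch in range(epochs):
    optimizer.step()  # the real training step is omitted here
    scheduler.step()
    # The learning rate is multiplied by 0.1 after the 8th step and again after the 9th.
    print(epoch, scheduler.get_last_lr())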
@@ -149,10 +149,7 @@ if __name__ == "__main__":
         help="The synthetic enviornment version.",
     )
     parser.add_argument(
-        "--hidden_dim",
-        type=int,
-        required=True,
-        help="The hidden dimension.",
+        "--hidden_dim", type=int, required=True, help="The hidden dimension.",
     )
     parser.add_argument(
         "--init_lr",
@@ -167,16 +164,10 @@ if __name__ == "__main__":
         help="The gap between prev_time and current_timestamp",
     )
     parser.add_argument(
-        "--batch_size",
-        type=int,
-        default=512,
-        help="The batch size",
+        "--batch_size", type=int, default=512, help="The batch size",
     )
     parser.add_argument(
-        "--epochs",
-        type=int,
-        default=300,
-        help="The total number of epochs.",
+        "--epochs", type=int, default=300, help="The total number of epochs.",
     )
     parser.add_argument(
         "--workers",
@@ -149,10 +149,7 @@ if __name__ == "__main__":
         help="The synthetic enviornment version.",
     )
     parser.add_argument(
-        "--hidden_dim",
-        type=int,
-        required=True,
-        help="The hidden dimension.",
+        "--hidden_dim", type=int, required=True, help="The hidden dimension.",
    )
     parser.add_argument(
         "--init_lr",
@@ -161,16 +158,10 @@ if __name__ == "__main__":
         help="The initial learning rate for the optimizer (default is Adam)",
     )
     parser.add_argument(
-        "--batch_size",
-        type=int,
-        default=512,
-        help="The batch size",
+        "--batch_size", type=int, default=512, help="The batch size",
     )
     parser.add_argument(
-        "--epochs",
-        type=int,
-        default=300,
-        help="The total number of epochs.",
+        "--epochs", type=int, default=300, help="The total number of epochs.",
     )
     parser.add_argument(
         "--workers",
@@ -62,10 +62,7 @@ def main(args):
     )
     lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
         optimizer,
-        milestones=[
-            int(args.epochs * 0.8),
-            int(args.epochs * 0.9),
-        ],
+        milestones=[int(args.epochs * 0.8), int(args.epochs * 0.9),],
         gamma=0.1,
     )
 
@@ -173,10 +170,7 @@ if __name__ == "__main__":
         help="The synthetic enviornment version.",
     )
     parser.add_argument(
-        "--hidden_dim",
-        type=int,
-        required=True,
-        help="The hidden dimension.",
+        "--hidden_dim", type=int, required=True, help="The hidden dimension.",
     )
     #####
     parser.add_argument(
@@ -186,10 +180,7 @@ if __name__ == "__main__":
         help="The initial learning rate for the optimizer (default is Adam)",
     )
     parser.add_argument(
-        "--meta_batch",
-        type=int,
-        default=64,
-        help="The batch size for the meta-model",
+        "--meta_batch", type=int, default=64, help="The batch size for the meta-model",
     )
     parser.add_argument(
         "--early_stop_thresh",
@@ -198,22 +189,13 @@ if __name__ == "__main__":
         help="The maximum epochs for early stop.",
     )
     parser.add_argument(
-        "--epochs",
-        type=int,
-        default=2000,
-        help="The total number of epochs.",
+        "--epochs", type=int, default=2000, help="The total number of epochs.",
     )
     parser.add_argument(
-        "--per_epoch_step",
-        type=int,
-        default=20,
-        help="The total number of epochs.",
+        "--per_epoch_step", type=int, default=20, help="The total number of epochs.",
     )
     parser.add_argument(
-        "--device",
-        type=str,
-        default="cpu",
-        help="",
+        "--device", type=str, default="cpu", help="",
     )
     # Random Seed
     parser.add_argument("--rand_seed", type=int, default=-1, help="manual seed")
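Note: this file's options include --early_stop_thresh ("the maximum epochs for early stop"); how it is consumed is not visible in the diff. As a rough illustration only, a generic patience-based early-stopping counter of the kind such a flag usually controls is sketched below; every name in it is hypothetical and not taken from the LFNA code.

# Hypothetical patience-style early stopping; the real LFNA training loop is not shown in this diff.
def train_with_early_stop(val_losses, early_stop_thresh=20):
    """Stop once the validation loss has not improved for early_stop_thresh epochs."""
    best_loss, best_epoch = float("inf"), -1
    for epoch, loss in enumerate(val_losses):
        if loss < best_loss:
            best_loss, best_epoch = loss, epoch
        elif epoch - best_epoch >= early_stop_thresh:
            break  # patience exhausted
    return best_epoch, best_loss


print(train_with_early_stop([3.0, 2.0, 1.5, 1.6, 1.7, 1.8], early_stop_thresh=2))  # (2, 1.5)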
@@ -101,10 +101,7 @@ def main(args):
     )
     lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
         optimizer,
-        milestones=[
-            int(args.epochs * 0.8),
-            int(args.epochs * 0.9),
-        ],
+        milestones=[int(args.epochs * 0.8), int(args.epochs * 0.9),],
         gamma=0.1,
     )
     logger.log("The base-model is\n{:}".format(base_model))
@@ -166,7 +163,7 @@ def main(args):
     w_container_per_epoch = dict()
     for idx in range(args.seq_length, len(eval_env)):
         # build-timestamp
-        future_time = env_info["{:}-timestamp".format(idx)]
+        future_time = env_info["{:}-timestamp".format(idx)].item()
         time_seqs = []
         for iseq in range(args.seq_length):
             time_seqs.append(future_time - iseq * eval_env.timestamp_interval)
@@ -190,7 +187,7 @@ def main(args):
             )
 
         # creating the new meta-time-embedding
-        distance = meta_model.get_closest_meta_distance(future_time.item())
+        distance = meta_model.get_closest_meta_distance(future_time)
         if distance < eval_env.timestamp_interval:
             continue
         #
@@ -198,7 +195,9 @@ def main(args):
         optimizer = torch.optim.Adam(
             [new_param], lr=args.init_lr, weight_decay=1e-5, amsgrad=True
         )
-        meta_model.replace_append_learnt(torch.Tensor([future_time]).to(args.device), new_param)
+        meta_model.replace_append_learnt(
+            torch.Tensor([future_time], device=args.device), new_param
+        )
         meta_model.eval()
         base_model.train()
         for iepoch in range(args.epochs):
@@ -241,22 +240,13 @@ if __name__ == "__main__":
         help="The synthetic enviornment version.",
     )
     parser.add_argument(
-        "--hidden_dim",
-        type=int,
-        default=16,
-        help="The hidden dimension.",
+        "--hidden_dim", type=int, default=16, help="The hidden dimension.",
    )
     parser.add_argument(
-        "--layer_dim",
-        type=int,
-        default=16,
-        help="The layer chunk dimension.",
+        "--layer_dim", type=int, default=16, help="The layer chunk dimension.",
     )
     parser.add_argument(
-        "--time_dim",
-        type=int,
-        default=16,
-        help="The timestamp dimension.",
+        "--time_dim", type=int, default=16, help="The timestamp dimension.",
     )
     #####
     parser.add_argument(
@@ -272,10 +262,7 @@ if __name__ == "__main__":
         help="The weight decay for the optimizer (default is Adam)",
     )
     parser.add_argument(
-        "--meta_batch",
-        type=int,
-        default=64,
-        help="The batch size for the meta-model",
+        "--meta_batch", type=int, default=64, help="The batch size for the meta-model",
     )
     parser.add_argument(
         "--sampler_enlarge",
@@ -297,10 +284,7 @@ if __name__ == "__main__":
         "--workers", type=int, default=4, help="The number of workers in parallel."
     )
     parser.add_argument(
-        "--device",
-        type=str,
-        default="cpu",
-        help="",
+        "--device", type=str, default="cpu", help="",
     )
     # Random Seed
     parser.add_argument("--rand_seed", type=int, default=-1, help="manual seed")
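Note: besides the formatting changes, this file carries the actual bug fixes named in the commit title: the timestamp is converted to a Python float once via .item(), the already-scalar value is passed to get_closest_meta_distance, and a fresh tensor is built from it when appending the learnt time embedding. A minimal sketch of that tensor-to-scalar-to-tensor round trip follows; it uses the torch.tensor factory and a fabricated timestamp value, so it illustrates the idea rather than reproducing the repository's exact call.

import torch

device = "cpu"  # stand-in for args.device
env_info = {"3-timestamp": torch.tensor(1612137600.0)}  # fabricated example value

# .item() turns the 0-dim tensor into a plain Python float, so later arithmetic
# such as future_time - iseq * interval produces floats instead of tensors.
future_time = env_info["3-timestamp"].item()

# When a tensor is needed again (e.g. for the new time embedding), rebuild a
# 1-element tensor on the target device. The diff uses the legacy torch.Tensor
# constructor with device=...; torch.tensor is the more common factory.
future_time_tensor = torch.tensor([future_time], device=device)
print(type(future_time), future_time_tensor)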
@@ -75,8 +75,7 @@ class LFNA_Meta(super_core.SuperModule):
 
         # unknown token
         self.register_parameter(
-            "_unknown_token",
-            torch.nn.Parameter(torch.Tensor(1, time_embedding)),
+            "_unknown_token", torch.nn.Parameter(torch.Tensor(1, time_embedding)),
         )
 
         # initialization
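Note: the _unknown_token hunk is again formatting only: a (1, time_embedding) parameter is registered on the module under that name, and the "# initialization" block that follows sets its values. A small self-contained sketch of register_parameter follows; the module class, dimension, and the normal_ initializer are illustrative stand-ins, not the LFNA_Meta code.

import torch


class TinyMeta(torch.nn.Module):
    """Illustrative module, not the LFNA_Meta class from the diff."""

    def __init__(self, time_embedding: int = 16):
        super().__init__()
        # Registered parameters appear in .parameters(), are saved in state_dict,
        # and are moved by .to(device) together with the module.
        self.register_parameter(
            "_unknown_token", torch.nn.Parameter(torch.Tensor(1, time_embedding))
        )
        # torch.Tensor(1, time_embedding) is uninitialized memory, so some initializer
        # must run afterwards; normal_ here is only a stand-in for whatever the real code uses.
        torch.nn.init.normal_(self._unknown_token, std=0.02)


module = TinyMeta()
print(module._unknown_token.shape)  # torch.Size([1, 16])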
@@ -164,10 +164,8 @@ def compare_cl(save_dir):
         )
     print("Save all figures into {:}".format(save_dir))
     save_dir = save_dir.resolve()
-    base_cmd = (
-        "ffmpeg -y -i {xdir}/%04d.png -vf fps=1 -vf scale=2200:1800 -vb 5000k".format(
-            xdir=save_dir
-        )
+    base_cmd = "ffmpeg -y -i {xdir}/%04d.png -vf fps=1 -vf scale=2200:1800 -vb 5000k".format(
+        xdir=save_dir
     )
     video_cmd = "{:} -pix_fmt yuv420p {xdir}/compare-cl.mp4".format(
         base_cmd, xdir=save_dir
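Note: the last hunk only reflows how base_cmd is built; the resulting ffmpeg command string is unchanged. A short sketch of the same string composition follows, with a made-up output directory; how (or whether) the repository shells the command out is not shown in this diff. As an aside, ffmpeg normally honors only the last of several -vf options, so chaining the filters as -vf "fps=1,scale=2200:1800" may be what is actually intended, but that is an observation, not part of this commit.

import shlex

save_dir = "/tmp/compare-cl-figures"  # made-up path standing in for the resolved save_dir

# Same two-step string composition as the hunk above.
base_cmd = "ffmpeg -y -i {xdir}/%04d.png -vf fps=1 -vf scale=2200:1800 -vb 5000k".format(
    xdir=save_dir
)
video_cmd = "{:} -pix_fmt yuv420p {xdir}/compare-cl.mp4".format(base_cmd, xdir=save_dir)

print(shlex.split(video_cmd))
# One way to execute it would be: subprocess.run(shlex.split(video_cmd), check=True)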