LFNA -> GMOA
This commit is contained in:
		| @@ -1,9 +1,10 @@ | |||||||
| ##################################################### | ##################################################### | ||||||
| # Learning to Generate Model One Step Ahead         # | # Learning to Generate Model One Step Ahead         # | ||||||
| ##################################################### | ##################################################### | ||||||
| # python exps/LFNA/lfna.py --env_version v1 --workers 0 | # python exps/GMOA/lfna.py --env_version v1 --workers 0 | ||||||
| # python exps/LFNA/lfna.py --env_version v1 --device cuda --lr 0.001 | # python exps/GMOA/lfna.py --env_version v1 --device cuda --lr 0.001 | ||||||
| # python exps/LFNA/lfna.py --env_version v1 --device cuda --lr 0.002 --meta_batch 128 | # python exps/GMOA/lfna.py --env_version v1 --device cuda --lr 0.002 --seq_length 16 --meta_batch 128 | ||||||
|  | # python exps/GMOA/lfna.py --env_version v1 --device cuda --lr 0.002 --seq_length 24 --time_dim 32 --meta_batch 128 | ||||||
| ##################################################### | ##################################################### | ||||||
| import pdb, sys, time, copy, torch, random, argparse | import pdb, sys, time, copy, torch, random, argparse | ||||||
| from tqdm import tqdm | from tqdm import tqdm | ||||||
| @@ -33,7 +34,7 @@ from xautodl.models.xcore import get_model | |||||||
| from xautodl.xlayers import super_core, trunc_normal_ | from xautodl.xlayers import super_core, trunc_normal_ | ||||||
| 
 | 
 | ||||||
| from lfna_utils import lfna_setup, train_model, TimeData | from lfna_utils import lfna_setup, train_model, TimeData | ||||||
| from lfna_meta_model import LFNA_Meta | from lfna_meta_model import MetaModelV1 | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def epoch_train(loader, meta_model, base_model, optimizer, criterion, device, logger): | def epoch_train(loader, meta_model, base_model, optimizer, criterion, device, logger): | ||||||
| @@ -240,7 +241,7 @@ def main(args): | |||||||
| 
 | 
 | ||||||
|     # pre-train the hypernetwork |     # pre-train the hypernetwork | ||||||
|     timestamps = trainval_env.get_timestamp(None) |     timestamps = trainval_env.get_timestamp(None) | ||||||
|     meta_model = LFNA_Meta( |     meta_model = MetaModelV1( | ||||||
|         shape_container, |         shape_container, | ||||||
|         args.layer_dim, |         args.layer_dim, | ||||||
|         args.time_dim, |         args.time_dim, | ||||||
| @@ -270,179 +271,6 @@ def main(args): | |||||||
|         logger.path(None) / "final-ckp.pth", |         logger.path(None) / "final-ckp.pth", | ||||||
|         logger, |         logger, | ||||||
|     ) |     ) | ||||||
|     return |  | ||||||
|     """ |  | ||||||
|     optimizer = torch.optim.Adam( |  | ||||||
|         meta_model.get_parameters(True, True, False),  # fix hypernet |  | ||||||
|         lr=args.lr, |  | ||||||
|         weight_decay=args.weight_decay, |  | ||||||
|         amsgrad=True, |  | ||||||
|     ) |  | ||||||
|     lr_scheduler = torch.optim.lr_scheduler.MultiStepLR( |  | ||||||
|         optimizer, |  | ||||||
|         milestones=[1, 2, 3, 4, 5], |  | ||||||
|         gamma=0.2, |  | ||||||
|     ) |  | ||||||
|     logger.log("The optimizer is\n{:}".format(optimizer)) |  | ||||||
|     logger.log("The scheduler is\n{:}".format(lr_scheduler)) |  | ||||||
|     logger.log("Per epoch iterations = {:}".format(len(train_env_loader))) |  | ||||||
| 
 |  | ||||||
|     if logger.path("model").exists(): |  | ||||||
|         ckp_data = torch.load(logger.path("model")) |  | ||||||
|         base_model.load_state_dict(ckp_data["base_model"]) |  | ||||||
|         meta_model.load_state_dict(ckp_data["meta_model"]) |  | ||||||
|         optimizer.load_state_dict(ckp_data["optimizer"]) |  | ||||||
|         lr_scheduler.load_state_dict(ckp_data["lr_scheduler"]) |  | ||||||
|         last_success_epoch = ckp_data["last_success_epoch"] |  | ||||||
|         start_epoch = ckp_data["iepoch"] + 1 |  | ||||||
|         check_strs = [ |  | ||||||
|             "epochs", |  | ||||||
|             "env_version", |  | ||||||
|             "hidden_dim", |  | ||||||
|             "lr", |  | ||||||
|             "layer_dim", |  | ||||||
|             "time_dim", |  | ||||||
|             "seq_length", |  | ||||||
|         ] |  | ||||||
|         for xstr in check_strs: |  | ||||||
|             cx = getattr(args, xstr) |  | ||||||
|             px = getattr(ckp_data["args"], xstr) |  | ||||||
|             assert cx == px, "[{:}] {:} vs {:}".format(xstr, cx, ps) |  | ||||||
|         success, _ = meta_model.save_best(ckp_data["cur_score"]) |  | ||||||
|         logger.log("Load ckp from {:}".format(logger.path("model"))) |  | ||||||
|         if success: |  | ||||||
|             logger.log( |  | ||||||
|                 "Re-save the best model with score={:}".format(ckp_data["cur_score"]) |  | ||||||
|             ) |  | ||||||
|     else: |  | ||||||
|         start_epoch, last_success_epoch = 0, 0 |  | ||||||
| 
 |  | ||||||
|     # LFNA meta-train |  | ||||||
|     meta_model.set_best_dir(logger.path(None) / "checkpoint") |  | ||||||
|     per_epoch_time, start_time = AverageMeter(), time.time() |  | ||||||
|     for iepoch in range(start_epoch, args.epochs): |  | ||||||
| 
 |  | ||||||
|         head_str = "[{:}] [{:04d}/{:04d}] ".format( |  | ||||||
|             time_string(), iepoch, args.epochs |  | ||||||
|         ) + "Time Left: {:}".format( |  | ||||||
|             convert_secs2time(per_epoch_time.avg * (args.epochs - iepoch), True) |  | ||||||
|         ) |  | ||||||
| 
 |  | ||||||
|         loss_meter = epoch_train( |  | ||||||
|             train_env_loader, |  | ||||||
|             meta_model, |  | ||||||
|             base_model, |  | ||||||
|             optimizer, |  | ||||||
|             criterion, |  | ||||||
|             args.device, |  | ||||||
|             logger, |  | ||||||
|         ) |  | ||||||
| 
 |  | ||||||
|         valid_loss_meter = epoch_evaluate( |  | ||||||
|             valid_env_loader, meta_model, base_model, criterion, args.device, logger |  | ||||||
|         ) |  | ||||||
|         logger.log( |  | ||||||
|             head_str |  | ||||||
|             + " meta-train-loss: {meter.avg:.4f} ({meter.count:.0f})".format( |  | ||||||
|                 meter=loss_meter |  | ||||||
|             ) |  | ||||||
|             + " meta-valid-loss: {meter.val:.4f}".format(meter=valid_loss_meter) |  | ||||||
|             + " :: lr={:.5f}".format(min(lr_scheduler.get_last_lr())) |  | ||||||
|             + "  :: last-success={:}".format(last_success_epoch) |  | ||||||
|         ) |  | ||||||
|         success, best_score = meta_model.save_best(-loss_meter.avg) |  | ||||||
|         if success: |  | ||||||
|             logger.log("Achieve the best with best-score = {:.5f}".format(best_score)) |  | ||||||
|             last_success_epoch = iepoch |  | ||||||
|             save_checkpoint( |  | ||||||
|                 { |  | ||||||
|                     "meta_model": meta_model.state_dict(), |  | ||||||
|                     "base_model": base_model.state_dict(), |  | ||||||
|                     "optimizer": optimizer.state_dict(), |  | ||||||
|                     "lr_scheduler": lr_scheduler.state_dict(), |  | ||||||
|                     "last_success_epoch": last_success_epoch, |  | ||||||
|                     "cur_score": -loss_meter.avg, |  | ||||||
|                     "iepoch": iepoch, |  | ||||||
|                     "args": args, |  | ||||||
|                 }, |  | ||||||
|                 logger.path("model"), |  | ||||||
|                 logger, |  | ||||||
|             ) |  | ||||||
|         if iepoch - last_success_epoch >= args.early_stop_thresh: |  | ||||||
|             if lr_scheduler.last_epoch > 4: |  | ||||||
|                 logger.log("Early stop at {:}".format(iepoch)) |  | ||||||
|                 break |  | ||||||
|             else: |  | ||||||
|                 last_success_epoch = iepoch |  | ||||||
|                 lr_scheduler.step() |  | ||||||
|                 logger.log("Decay the lr [{:}]".format(lr_scheduler.last_epoch)) |  | ||||||
| 
 |  | ||||||
|         per_epoch_time.update(time.time() - start_time) |  | ||||||
|         start_time = time.time() |  | ||||||
| 
 |  | ||||||
|     # meta-test |  | ||||||
|     meta_model.load_best() |  | ||||||
|     eval_env = env_info["dynamic_env"] |  | ||||||
|     for idx in range(args.seq_length, len(eval_env)): |  | ||||||
|         # build-timestamp |  | ||||||
|         future_time = env_info["{:}-timestamp".format(idx)].item() |  | ||||||
|         time_seqs = [] |  | ||||||
|         for iseq in range(args.seq_length): |  | ||||||
|             time_seqs.append(future_time - iseq * eval_env.time_interval) |  | ||||||
|         time_seqs.reverse() |  | ||||||
|         with torch.no_grad(): |  | ||||||
|             meta_model.eval() |  | ||||||
|             base_model.eval() |  | ||||||
|             time_seqs = torch.Tensor(time_seqs).view(1, -1).to(args.device) |  | ||||||
|             [seq_containers] = meta_model(time_seqs) |  | ||||||
|             future_container = seq_containers[-1] |  | ||||||
|             w_container_per_epoch[idx] = future_container.no_grad_clone() |  | ||||||
|             # evaluation |  | ||||||
|             future_x = env_info["{:}-x".format(idx)].to(args.device) |  | ||||||
|             future_y = env_info["{:}-y".format(idx)].to(args.device) |  | ||||||
|             future_y_hat = base_model.forward_with_container( |  | ||||||
|                 future_x, w_container_per_epoch[idx] |  | ||||||
|             ) |  | ||||||
|             future_loss = criterion(future_y_hat, future_y) |  | ||||||
|             logger.log( |  | ||||||
|                 "meta-test: [{:03d}] -> loss={:.4f}".format(idx, future_loss.item()) |  | ||||||
|             ) |  | ||||||
| 
 |  | ||||||
|         # creating the new meta-time-embedding |  | ||||||
|         distance = meta_model.get_closest_meta_distance(future_time) |  | ||||||
|         if distance < eval_env.time_interval: |  | ||||||
|             continue |  | ||||||
|         # |  | ||||||
|         new_param = meta_model.create_meta_embed() |  | ||||||
|         optimizer = torch.optim.Adam( |  | ||||||
|             [new_param], lr=args.refine_lr, weight_decay=1e-5, amsgrad=True |  | ||||||
|         ) |  | ||||||
|         meta_model.replace_append_learnt( |  | ||||||
|             torch.Tensor([future_time]).to(args.device), new_param |  | ||||||
|         ) |  | ||||||
|         meta_model.eval() |  | ||||||
|         base_model.train() |  | ||||||
|         for iepoch in range(args.refine_epochs): |  | ||||||
|             optimizer.zero_grad() |  | ||||||
|             [seq_containers] = meta_model(time_seqs) |  | ||||||
|             future_container = seq_containers[-1] |  | ||||||
|             future_y_hat = base_model.forward_with_container(future_x, future_container) |  | ||||||
|             future_loss = criterion(future_y_hat, future_y) |  | ||||||
|             future_loss.backward() |  | ||||||
|             optimizer.step() |  | ||||||
|         logger.log( |  | ||||||
|             "post-meta-test: [{:03d}] -> loss={:.4f}".format(idx, future_loss.item()) |  | ||||||
|         ) |  | ||||||
|         with torch.no_grad(): |  | ||||||
|             meta_model.replace_append_learnt(None, None) |  | ||||||
|             meta_model.append_fixed(torch.Tensor([future_time]), new_param) |  | ||||||
| 
 |  | ||||||
|     save_checkpoint( |  | ||||||
|         {"w_container_per_epoch": w_container_per_epoch}, |  | ||||||
|         logger.path(None) / "final-ckp.pth", |  | ||||||
|         logger, |  | ||||||
|     ) |  | ||||||
|     """ |  | ||||||
| 
 | 
 | ||||||
|     logger.log("-" * 200 + "\n") |     logger.log("-" * 200 + "\n") | ||||||
|     logger.close() |     logger.close() | ||||||
| @@ -513,7 +341,7 @@ if __name__ == "__main__": | |||||||
|         help="The learning rate for the optimizer, during refine", |         help="The learning rate for the optimizer, during refine", | ||||||
|     ) |     ) | ||||||
|     parser.add_argument( |     parser.add_argument( | ||||||
|         "--refine_epochs", type=int, default=50, help="The final refine #epochs." |         "--refine_epochs", type=int, default=100, help="The final refine #epochs." | ||||||
|     ) |     ) | ||||||
|     parser.add_argument( |     parser.add_argument( | ||||||
|         "--early_stop_thresh", |         "--early_stop_thresh", | ||||||
| @@ -10,8 +10,8 @@ from xautodl.xlayers import trunc_normal_ | |||||||
| from xautodl.models.xcore import get_model | from xautodl.models.xcore import get_model | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| class LFNA_Meta(super_core.SuperModule): | class MetaModelV1(super_core.SuperModule): | ||||||
|     """Learning to Forecast Neural Adaptation (Meta Model Design).""" |     """Learning to Generate Models One Step Ahead (Meta Model Design).""" | ||||||
| 
 | 
 | ||||||
|     def __init__( |     def __init__( | ||||||
|         self, |         self, | ||||||
		Reference in New Issue
	
	Block a user