Update LFNA with resume
This commit is contained in:
		| @@ -101,21 +101,49 @@ def main(args): | |||||||
|     ) |     ) | ||||||
|     lr_scheduler = torch.optim.lr_scheduler.MultiStepLR( |     lr_scheduler = torch.optim.lr_scheduler.MultiStepLR( | ||||||
|         optimizer, |         optimizer, | ||||||
|         milestones=[ |         milestones=[1, 2], | ||||||
|             int(args.epochs * 0.8), |  | ||||||
|             int(args.epochs * 0.9), |  | ||||||
|         ], |  | ||||||
|         gamma=0.1, |         gamma=0.1, | ||||||
|     ) |     ) | ||||||
|     logger.log("The base-model is\n{:}".format(base_model)) |     logger.log("The base-model is\n{:}".format(base_model)) | ||||||
|     logger.log("The meta-model is\n{:}".format(meta_model)) |     logger.log("The meta-model is\n{:}".format(meta_model)) | ||||||
|     logger.log("The optimizer is\n{:}".format(optimizer)) |     logger.log("The optimizer is\n{:}".format(optimizer)) | ||||||
|  |     logger.log("The scheduler is\n{:}".format(lr_scheduler)) | ||||||
|     logger.log("Per epoch iterations = {:}".format(len(env_loader))) |     logger.log("Per epoch iterations = {:}".format(len(env_loader))) | ||||||
|  |  | ||||||
|     # LFNA meta-training |     if logger.path("model").exists(): | ||||||
|  |         ckp_data = torch.load(logger.path("model")) | ||||||
|  |         base_model.load_state_dict(ckp_data["base_model"]) | ||||||
|  |         meta_model.load_state_dict(ckp_data["meta_model"]) | ||||||
|  |         optimizer.load_state_dict(ckp_data["optimizer"]) | ||||||
|  |         lr_scheduler.load_state_dict(ckp_data["lr_scheduler"]) | ||||||
|  |         last_success_epoch = ckp_data["last_success_epoch"] | ||||||
|  |         start_epoch = ckp_data["iepoch"] + 1 | ||||||
|  |         check_strs = [ | ||||||
|  |             "epochs", | ||||||
|  |             "env_version", | ||||||
|  |             "hidden_dim", | ||||||
|  |             "init_lr", | ||||||
|  |             "layer_dim", | ||||||
|  |             "time_dim", | ||||||
|  |             "seq_length", | ||||||
|  |         ] | ||||||
|  |         for xstr in check_strs: | ||||||
|  |             cx = getattr(args, xstr) | ||||||
|  |             px = getattr(ckp_data["args"], xstr) | ||||||
|  |             assert cx == px, "[{:}] {:} vs {:}".format(xstr, cx, ps) | ||||||
|  |         success, _ = meta_model.save_best(ckp_data["cur_score"]) | ||||||
|  |         logger.log("Load ckp from {:}".format(logger.path("model"))) | ||||||
|  |         if success: | ||||||
|  |             logger.log( | ||||||
|  |                 "Re-save the best model with score={:}".format(ckp_data["cur_score"]) | ||||||
|  |             ) | ||||||
|  |     else: | ||||||
|  |         start_epoch, last_success_epoch = 0, 0 | ||||||
|  |  | ||||||
|  |     # LFNA meta-train | ||||||
|  |     meta_model.set_best_dir(logger.path(None) / "checkpoint") | ||||||
|     per_epoch_time, start_time = AverageMeter(), time.time() |     per_epoch_time, start_time = AverageMeter(), time.time() | ||||||
|     last_success_epoch = 0 |     for iepoch in range(start_epoch, args.epochs): | ||||||
|     for iepoch in range(args.epochs): |  | ||||||
|  |  | ||||||
|         head_str = "[{:}] [{:04d}/{:04d}] ".format( |         head_str = "[{:}] [{:04d}/{:04d}] ".format( | ||||||
|             time_string(), iepoch, args.epochs |             time_string(), iepoch, args.epochs | ||||||
| @@ -132,11 +160,11 @@ def main(args): | |||||||
|             args.device, |             args.device, | ||||||
|             logger, |             logger, | ||||||
|         ) |         ) | ||||||
|         lr_scheduler.step() |  | ||||||
|         logger.log( |         logger.log( | ||||||
|             head_str |             head_str | ||||||
|             + " meta-loss: {meter.avg:.4f} ({meter.count:.0f})".format(meter=loss_meter) |             + " meta-loss: {meter.avg:.4f} ({meter.count:.0f})".format(meter=loss_meter) | ||||||
|             + " :: lr={:.5f}".format(min(lr_scheduler.get_last_lr())) |             + " :: lr={:.5f}".format(min(lr_scheduler.get_last_lr())) | ||||||
|  |             + "  :: last-success={:}".format(last_success_epoch) | ||||||
|         ) |         ) | ||||||
|         success, best_score = meta_model.save_best(-loss_meter.avg) |         success, best_score = meta_model.save_best(-loss_meter.avg) | ||||||
|         if success: |         if success: | ||||||
| @@ -145,8 +173,11 @@ def main(args): | |||||||
|             save_checkpoint( |             save_checkpoint( | ||||||
|                 { |                 { | ||||||
|                     "meta_model": meta_model.state_dict(), |                     "meta_model": meta_model.state_dict(), | ||||||
|  |                     "base_model": base_model.state_dict(), | ||||||
|                     "optimizer": optimizer.state_dict(), |                     "optimizer": optimizer.state_dict(), | ||||||
|                     "lr_scheduler": lr_scheduler.state_dict(), |                     "lr_scheduler": lr_scheduler.state_dict(), | ||||||
|  |                     "last_success_epoch": last_success_epoch, | ||||||
|  |                     "cur_score": -loss_meter.avg, | ||||||
|                     "iepoch": iepoch, |                     "iepoch": iepoch, | ||||||
|                     "args": args, |                     "args": args, | ||||||
|                 }, |                 }, | ||||||
| @@ -154,8 +185,12 @@ def main(args): | |||||||
|                 logger, |                 logger, | ||||||
|             ) |             ) | ||||||
|         if iepoch - last_success_epoch >= args.early_stop_thresh: |         if iepoch - last_success_epoch >= args.early_stop_thresh: | ||||||
|             logger.log("Early stop at {:}".format(iepoch)) |             if lr_scheduler.last_epoch > 2: | ||||||
|             break |                 logger.log("Early stop at {:}".format(iepoch)) | ||||||
|  |                 break | ||||||
|  |             else: | ||||||
|  |                 last_epoch.step() | ||||||
|  |                 logger.log("Decay the lr [{:}]".format(lr_scheduler.last_epoch)) | ||||||
|  |  | ||||||
|         per_epoch_time.update(time.time() - start_time) |         per_epoch_time.update(time.time() - start_time) | ||||||
|         start_time = time.time() |         start_time = time.time() | ||||||
| @@ -199,7 +234,7 @@ def main(args): | |||||||
|             [new_param], lr=args.init_lr, weight_decay=1e-5, amsgrad=True |             [new_param], lr=args.init_lr, weight_decay=1e-5, amsgrad=True | ||||||
|         ) |         ) | ||||||
|         meta_model.replace_append_learnt( |         meta_model.replace_append_learnt( | ||||||
|             torch.Tensor([future_time], device=args.device), new_param |             torch.Tensor([future_time]).to(args.device), new_param | ||||||
|         ) |         ) | ||||||
|         meta_model.eval() |         meta_model.eval() | ||||||
|         base_model.train() |         base_model.train() | ||||||
| @@ -289,8 +324,8 @@ if __name__ == "__main__": | |||||||
|     parser.add_argument( |     parser.add_argument( | ||||||
|         "--early_stop_thresh", |         "--early_stop_thresh", | ||||||
|         type=int, |         type=int, | ||||||
|         default=100, |         default=50, | ||||||
|         help="The maximum epochs for early stop.", |         help="The #epochs for early stop.", | ||||||
|     ) |     ) | ||||||
|     parser.add_argument( |     parser.add_argument( | ||||||
|         "--seq_length", type=int, default=5, help="The sequence length." |         "--seq_length", type=int, default=5, help="The sequence length." | ||||||
|   | |||||||
| @@ -102,9 +102,11 @@ class LFNA_Meta(super_core.SuperModule): | |||||||
|         return torch.cat(meta_embed) |         return torch.cat(meta_embed) | ||||||
|  |  | ||||||
|     def create_meta_embed(self): |     def create_meta_embed(self): | ||||||
|         param = torch.nn.Parameter(torch.Tensor(1, self._time_embed_dim)) |         param = torch.Tensor(1, self._time_embed_dim) | ||||||
|         trunc_normal_(param, std=0.02) |         trunc_normal_(param, std=0.02) | ||||||
|         return param.to(self._super_meta_embed.device) |         param = param.to(self._super_meta_embed.device) | ||||||
|  |         param = torch.nn.Parameter(param, True) | ||||||
|  |         return param | ||||||
|  |  | ||||||
|     def get_closest_meta_distance(self, timestamp): |     def get_closest_meta_distance(self, timestamp): | ||||||
|         with torch.no_grad(): |         with torch.no_grad(): | ||||||
| @@ -112,12 +114,14 @@ class LFNA_Meta(super_core.SuperModule): | |||||||
|             return torch.min(distances).item() |             return torch.min(distances).item() | ||||||
|  |  | ||||||
|     def replace_append_learnt(self, timestamp, meta_embed): |     def replace_append_learnt(self, timestamp, meta_embed): | ||||||
|         self._append_meta_embed["learnt"] = meta_embed |  | ||||||
|         self._append_meta_timestamps["learnt"] = timestamp |         self._append_meta_timestamps["learnt"] = timestamp | ||||||
|  |         self._append_meta_embed["learnt"] = meta_embed | ||||||
|  |  | ||||||
|     def append_fixed(self, timestamp, meta_embed): |     def append_fixed(self, timestamp, meta_embed): | ||||||
|         with torch.no_grad(): |         with torch.no_grad(): | ||||||
|             timestamp, meta_embed = timestamp.clone(), meta_embed.clone() |             device = self._super_meta_embed.device | ||||||
|  |             timestamp = timestamp.detach().clone().to(device) | ||||||
|  |             meta_embed = meta_embed.detach().clone().to(device) | ||||||
|             if self._append_meta_timestamps["fixed"] is None: |             if self._append_meta_timestamps["fixed"] is None: | ||||||
|                 self._append_meta_timestamps["fixed"] = timestamp |                 self._append_meta_timestamps["fixed"] = timestamp | ||||||
|             else: |             else: | ||||||
|   | |||||||
| @@ -3,6 +3,7 @@ | |||||||
| ##################################################### | ##################################################### | ||||||
|  |  | ||||||
| import os | import os | ||||||
|  | from pathlib import Path | ||||||
| import abc | import abc | ||||||
| import tempfile | import tempfile | ||||||
| import warnings | import warnings | ||||||
| @@ -90,6 +91,10 @@ class SuperModule(abc.ABC, nn.Module): | |||||||
|                 total += buf.numel() |                 total += buf.numel() | ||||||
|         return total |         return total | ||||||
|  |  | ||||||
|  |     def set_best_dir(self, xdir): | ||||||
|  |         self._meta_info[BEST_DIR_KEY] = str(xdir) | ||||||
|  |         Path(xdir).mkdir(parents=True, exist_ok=True) | ||||||
|  |  | ||||||
|     def save_best(self, score): |     def save_best(self, score): | ||||||
|         if BEST_DIR_KEY not in self._meta_info: |         if BEST_DIR_KEY not in self._meta_info: | ||||||
|             tempdir = tempfile.mkdtemp("-xlayers") |             tempdir = tempfile.mkdtemp("-xlayers") | ||||||
| @@ -97,7 +102,7 @@ class SuperModule(abc.ABC, nn.Module): | |||||||
|         if BEST_SCORE_KEY not in self._meta_info: |         if BEST_SCORE_KEY not in self._meta_info: | ||||||
|             self._meta_info[BEST_SCORE_KEY] = None |             self._meta_info[BEST_SCORE_KEY] = None | ||||||
|         best_score = self._meta_info[BEST_SCORE_KEY] |         best_score = self._meta_info[BEST_SCORE_KEY] | ||||||
|         if best_score is None or best_score < score: |         if best_score is None or best_score <= score: | ||||||
|             best_save_path = os.path.join( |             best_save_path = os.path.join( | ||||||
|                 self._meta_info[BEST_DIR_KEY], |                 self._meta_info[BEST_DIR_KEY], | ||||||
|                 "best-{:}.pth".format(self.__class__.__name__), |                 "best-{:}.pth".format(self.__class__.__name__), | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user