complete maml and remove vis_compare_algo
This commit is contained in:
		| @@ -40,7 +40,7 @@ class MAML: | ||||
|             self.network.parameters(), lr=meta_lr, amsgrad=True | ||||
|         ) | ||||
|         self.meta_lr_scheduler = torch.optim.lr_scheduler.MultiStepLR( | ||||
|             optimizer, | ||||
|             self.meta_optimizer, | ||||
|             milestones=[ | ||||
|                 int(epochs * 0.25), | ||||
|                 int(epochs * 0.5), | ||||
| @@ -50,8 +50,8 @@ class MAML: | ||||
|         ) | ||||
|         self.inner_lr = inner_lr | ||||
|         self.inner_step = inner_step | ||||
|         self._best_info = dict(state_dict=None, score=None) | ||||
|         print("There are {:} weights.".format(w_container.numel())) | ||||
|         self._best_info = dict(state_dict=None, iepoch=None, score=None) | ||||
|         print("There are {:} weights.".format(self.network.get_w_container().numel())) | ||||
|  | ||||
|     def adapt(self, dataset): | ||||
|         # create a container for the future timestamp | ||||
| @@ -61,7 +61,6 @@ class MAML: | ||||
|             y_hat = self.network.forward_with_container(dataset.x, container) | ||||
|             loss = self.criterion(y_hat, dataset.y) | ||||
|             grads = torch.autograd.grad(loss, container.parameters()) | ||||
|  | ||||
|             container = container.additive([-self.inner_lr * grad for grad in grads]) | ||||
|         return container | ||||
|  | ||||
| @@ -73,23 +72,34 @@ class MAML: | ||||
|         return y_hat | ||||
|  | ||||
|     def step(self): | ||||
|         torch.nn.utils.clip_grad_norm_(self.container.parameters(), 1.0) | ||||
|         torch.nn.utils.clip_grad_norm_(self.network.parameters(), 1.0) | ||||
|         self.meta_optimizer.step() | ||||
|         self.meta_lr_scheduler.step() | ||||
|  | ||||
|     def zero_grad(self): | ||||
|         self.meta_optimizer.zero_grad() | ||||
|  | ||||
|     def save_best(self, network, score): | ||||
|     def load_state_dict(self, state_dict): | ||||
|         self.criterion.load_state_dict(state_dict["criterion"]) | ||||
|         self.network.load_state_dict(state_dict["network"]) | ||||
|         self.meta_optimizer.load_state_dict(state_dict["meta_optimizer"]) | ||||
|         self.meta_lr_scheduler.load_state_dict(state_dict["meta_lr_scheduler"]) | ||||
|  | ||||
|     def save_best(self, iepoch, score): | ||||
|         if self._best_info["score"] is None or self._best_info["score"] < score: | ||||
|             state_dict = dict( | ||||
|                 criterion=criterion, | ||||
|                 network=network.state_dict(), | ||||
|                 criterion=self.criterion.state_dict(), | ||||
|                 network=self.network.state_dict(), | ||||
|                 meta_optimizer=self.meta_optimizer.state_dict(), | ||||
|                 meta_lr_scheduler=self.meta_lr_scheduler.state_dict(), | ||||
|             ) | ||||
|             self._best_info["state_dict"] = state_dict | ||||
|             self._best_info["score"] = score | ||||
|             self._best_info["iepoch"] = iepoch | ||||
|             is_best = True | ||||
|         else: | ||||
|             is_best = False | ||||
|         return self._best_info, is_best | ||||
|  | ||||
|  | ||||
| def main(args): | ||||
| @@ -111,8 +121,9 @@ def main(args): | ||||
|  | ||||
|     # meta-training | ||||
|     per_epoch_time, start_time = AverageMeter(), time.time() | ||||
|     for iepoch in range(args.epochs): | ||||
|  | ||||
|     # for iepoch in range(args.epochs): | ||||
|     iepoch = 0 | ||||
|     while iepoch < args.epochs: | ||||
|         need_time = "Time Left: {:}".format( | ||||
|             convert_secs2time(per_epoch_time.avg * (args.epochs - iepoch), True) | ||||
|         ) | ||||
| @@ -122,9 +133,10 @@ def main(args): | ||||
|         ) | ||||
|  | ||||
|         maml.zero_grad() | ||||
|         meta_losses = [] | ||||
|         batch_indexes, meta_losses = [], [] | ||||
|         for ibatch in range(args.meta_batch): | ||||
|             sampled_timestamp = random.randint(0, train_time_bar) | ||||
|             batch_indexes.append("{:5d}".format(sampled_timestamp)) | ||||
|             past_dataset = TimeData( | ||||
|                 sampled_timestamp, | ||||
|                 env_info["{:}-x".format(sampled_timestamp)], | ||||
| @@ -135,7 +147,7 @@ def main(args): | ||||
|                 env_info["{:}-x".format(sampled_timestamp + 1)], | ||||
|                 env_info["{:}-y".format(sampled_timestamp + 1)], | ||||
|             ) | ||||
|             future_container = maml.adapt(model, past_dataset) | ||||
|             future_container = maml.adapt(past_dataset) | ||||
|             future_y_hat = maml.predict(future_dataset.x, future_container) | ||||
|             future_loss = maml.criterion(future_y_hat, future_dataset.y) | ||||
|             meta_losses.append(future_loss) | ||||
| @@ -143,14 +155,53 @@ def main(args): | ||||
|         meta_loss.backward() | ||||
|         maml.step() | ||||
|  | ||||
|         logger.log("meta-loss: {:.4f}".format(meta_loss.item())) | ||||
|  | ||||
|         logger.log( | ||||
|             "meta-loss: {:.4f}  batch: {:}".format( | ||||
|                 meta_loss.item(), ",".join(batch_indexes) | ||||
|             ) | ||||
|         ) | ||||
|         best_info, is_best = maml.save_best(iepoch, -meta_loss.item()) | ||||
|         if is_best: | ||||
|             save_checkpoint(best_info, logger.path("best"), logger) | ||||
|             logger.log("Save the best into {:}".format(logger.path("best"))) | ||||
|         if iepoch >= 10 and ( | ||||
|             torch.isnan(meta_loss).item() or meta_loss.item() >= args.fail_thresh | ||||
|         ): | ||||
|             xdata = torch.load(logger.path("best")) | ||||
|             maml.load_state_dict(xdata["state_dict"]) | ||||
|             iepoch = xdata["iepoch"] | ||||
|             logger.log( | ||||
|                 "The training failed, re-use the previous best epoch [{:}]".format( | ||||
|                     iepoch | ||||
|                 ) | ||||
|             ) | ||||
|         else: | ||||
|             iepoch = iepoch + 1 | ||||
|         per_epoch_time.update(time.time() - start_time) | ||||
|         start_time = time.time() | ||||
|  | ||||
|     import pdb | ||||
|  | ||||
|     pdb.set_trace() | ||||
|     w_container_per_epoch = dict() | ||||
|     for idx in range(1, env_info["total"]): | ||||
|         past_dataset = TimeData( | ||||
|             idx - 1, | ||||
|             env_info["{:}-x".format(idx - 1)], | ||||
|             env_info["{:}-y".format(idx - 1)], | ||||
|         ) | ||||
|         current_container = maml.adapt(past_dataset) | ||||
|         w_container_per_epoch[idx] = current_container.no_grad_clone() | ||||
|         with torch.no_grad(): | ||||
|             current_x = env_info["{:}-x".format(idx)] | ||||
|             current_y = env_info["{:}-y".format(idx)] | ||||
|             current_y_hat = maml.predict(current_x, w_container_per_epoch[idx]) | ||||
|             current_loss = maml.criterion(current_y_hat, current_y) | ||||
|         logger.log( | ||||
|             "meta-test: [{:03d}] -> loss={:.4f}".format(idx, current_loss.item()) | ||||
|         ) | ||||
|     save_checkpoint( | ||||
|         {"w_container_per_epoch": w_container_per_epoch}, | ||||
|         logger.path(None) / "final-ckp.pth", | ||||
|         logger, | ||||
|     ) | ||||
|  | ||||
|     logger.log("-" * 200 + "\n") | ||||
|     logger.close() | ||||
| @@ -179,9 +230,15 @@ if __name__ == "__main__": | ||||
|     parser.add_argument( | ||||
|         "--meta_lr", | ||||
|         type=float, | ||||
|         default=0.1, | ||||
|         default=0.05, | ||||
|         help="The learning rate for the MAML optimizer (default is Adam)", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--fail_thresh", | ||||
|         type=float, | ||||
|         default=1000, | ||||
|         help="The threshold for the failure, which we reuse the previous best model", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--inner_lr", | ||||
|         type=float, | ||||
|   | ||||
| @@ -221,76 +221,7 @@ def visualize_env(save_dir, version): | ||||
|     os.system("{:} {xdir}/env-{ver}.webm".format(base_cmd, xdir=save_dir, ver=version)) | ||||
|  | ||||
|  | ||||
| def compare_algs(save_dir, alg_dir="./outputs/lfna-synthetic"): | ||||
|     save_dir = Path(str(save_dir)) | ||||
|     save_dir.mkdir(parents=True, exist_ok=True) | ||||
|  | ||||
|     dpi, width, height = 30, 1800, 1400 | ||||
|     figsize = width / float(dpi), height / float(dpi) | ||||
|     LabelSize, LegendFontsize, font_gap = 80, 80, 5 | ||||
|  | ||||
|     cache_path = Path(alg_dir) / "env-info.pth" | ||||
|     assert cache_path.exists(), "{:} does not exist".format(cache_path) | ||||
|     env_info = torch.load(cache_path) | ||||
|  | ||||
|     alg_name2dir = OrderedDict() | ||||
|     alg_name2dir["Optimal"] = "use-same-timestamp" | ||||
|     alg_name2dir["History SL"] = "use-all-past-data" | ||||
|     colors = ["r", "g"] | ||||
|  | ||||
|     dynamic_env = env_info["dynamic_env"] | ||||
|     min_t, max_t = dynamic_env.min_timestamp, dynamic_env.max_timestamp | ||||
|  | ||||
|     linewidths = 10 | ||||
|     for idx, (timestamp, (ori_allx, ori_ally)) in enumerate( | ||||
|         tqdm(dynamic_env, ncols=50) | ||||
|     ): | ||||
|         if idx == 0: | ||||
|             continue | ||||
|         fig = plt.figure(figsize=figsize) | ||||
|         cur_ax = fig.add_subplot(1, 1, 1) | ||||
|  | ||||
|         # the data | ||||
|         allx, ally = ori_allx[:, 0].numpy(), ori_ally[:, 0].numpy() | ||||
|         plot_scatter(cur_ax, allx, ally, "k", 0.99, linewidths, "Raw Data") | ||||
|  | ||||
|         for idx_alg, (alg, xdir) in enumerate(alg_name2dir.items()): | ||||
|             ckp_path = ( | ||||
|                 Path(alg_dir) | ||||
|                 / xdir | ||||
|                 / "{:04d}-{:04d}.pth".format(idx, env_info["total"]) | ||||
|             ) | ||||
|             assert ckp_path.exists() | ||||
|             ckp_data = torch.load(ckp_path) | ||||
|             with torch.no_grad(): | ||||
|                 predicts = ckp_data["model"](ori_allx) | ||||
|                 predicts = predicts.cpu().view(-1).numpy() | ||||
|             plot_scatter(cur_ax, allx, predicts, colors[idx_alg], 0.99, linewidths, alg) | ||||
|  | ||||
|         cur_ax.set_xlabel("X", fontsize=LabelSize) | ||||
|         cur_ax.set_ylabel("Y", rotation=0, fontsize=LabelSize) | ||||
|         for tick in cur_ax.xaxis.get_major_ticks(): | ||||
|             tick.label.set_fontsize(LabelSize - font_gap) | ||||
|             tick.label.set_rotation(10) | ||||
|         for tick in cur_ax.yaxis.get_major_ticks(): | ||||
|             tick.label.set_fontsize(LabelSize - font_gap) | ||||
|         cur_ax.set_xlim(-10, 10) | ||||
|         cur_ax.set_ylim(-60, 60) | ||||
|         cur_ax.legend(loc=1, fontsize=LegendFontsize) | ||||
|  | ||||
|         save_path = save_dir / "{:05d}".format(idx) | ||||
|         fig.savefig(str(save_path) + ".pdf", dpi=dpi, bbox_inches="tight", format="pdf") | ||||
|         fig.savefig(str(save_path) + ".png", dpi=dpi, bbox_inches="tight", format="png") | ||||
|         plt.close("all") | ||||
|     save_dir = save_dir.resolve() | ||||
|     base_cmd = "ffmpeg -y -i {xdir}/%05d.png -vf scale={w}:{h} -pix_fmt yuv420p -vb 5000k".format( | ||||
|         xdir=save_dir, w=width, h=height | ||||
|     ) | ||||
|     os.system("{:} {xdir}/compare-alg.mp4".format(base_cmd, xdir=save_dir)) | ||||
|     os.system("{:} {xdir}/compare-alg.webm".format(base_cmd, xdir=save_dir)) | ||||
|  | ||||
|  | ||||
| def compare_algs_v2(save_dir, alg_dir="./outputs/lfna-synthetic"): | ||||
| def compare_algs(save_dir, version, alg_dir="./outputs/lfna-synthetic"): | ||||
|     save_dir = Path(str(save_dir)) | ||||
|     save_dir.mkdir(parents=True, exist_ok=True) | ||||
|  | ||||
| @@ -298,16 +229,21 @@ def compare_algs_v2(save_dir, alg_dir="./outputs/lfna-synthetic"): | ||||
|     figsize = width / float(dpi), height / float(dpi) | ||||
|     LabelSize, LegendFontsize, font_gap = 80, 80, 5 | ||||
|  | ||||
|     cache_path = Path(alg_dir) / "env-info.pth" | ||||
|     cache_path = Path(alg_dir) / "env-{:}-info.pth".format(version) | ||||
|     assert cache_path.exists(), "{:} does not exist".format(cache_path) | ||||
|     env_info = torch.load(cache_path) | ||||
|  | ||||
|     alg_name2dir = OrderedDict() | ||||
|     alg_name2dir["Optimal"] = "use-same-timestamp" | ||||
|     # alg_name2dir["Supervised Learning (History Data)"] = "use-all-past-data" | ||||
|     alg_name2dir["Supervised Learning (History Data)"] = "use-all-past-data" | ||||
|     alg_name2dir["MAML"] = "use-maml-s1" | ||||
|     alg_name2all_containers = OrderedDict() | ||||
|     if version == "v1": | ||||
|         poststr = "v1-d16" | ||||
|     else: | ||||
|         raise ValueError("Invalid version: {:}".format(version)) | ||||
|     for idx_alg, (alg, xdir) in enumerate(alg_name2dir.items()): | ||||
|         ckp_path = Path(alg_dir) / xdir / "final-ckp.pth" | ||||
|         ckp_path = Path(alg_dir) / "{:}-{:}".format(xdir, poststr) / "final-ckp.pth" | ||||
|         xdata = torch.load(ckp_path) | ||||
|         alg_name2all_containers[alg] = xdata["w_container_per_epoch"] | ||||
|     # load the basic model | ||||
| @@ -320,7 +256,7 @@ def compare_algs_v2(save_dir, alg_dir="./outputs/lfna-synthetic"): | ||||
|     ) | ||||
|  | ||||
|     alg2xs, alg2ys = defaultdict(list), defaultdict(list) | ||||
|     colors = ["r", "g"] | ||||
|     colors = ["r", "g", "b"] | ||||
|  | ||||
|     dynamic_env = env_info["dynamic_env"] | ||||
|     min_t, max_t = dynamic_env.min_timestamp, dynamic_env.max_timestamp | ||||
| @@ -339,15 +275,6 @@ def compare_algs_v2(save_dir, alg_dir="./outputs/lfna-synthetic"): | ||||
|         plot_scatter(cur_ax, allx, ally, "k", 0.99, linewidths, "Raw Data") | ||||
|  | ||||
|         for idx_alg, (alg, xdir) in enumerate(alg_name2dir.items()): | ||||
|             """ | ||||
|             ckp_path = ( | ||||
|                 Path(alg_dir) | ||||
|                 / xdir | ||||
|                 / "{:04d}-{:04d}.pth".format(idx, env_info["total"]) | ||||
|             ) | ||||
|             assert ckp_path.exists() | ||||
|             ckp_data = torch.load(ckp_path) | ||||
|             """ | ||||
|             with torch.no_grad(): | ||||
|                 # predicts = ckp_data["model"](ori_allx) | ||||
|                 predicts = model.forward_with_container( | ||||
| @@ -369,8 +296,12 @@ def compare_algs_v2(save_dir, alg_dir="./outputs/lfna-synthetic"): | ||||
|             tick.label.set_rotation(10) | ||||
|         for tick in cur_ax.yaxis.get_major_ticks(): | ||||
|             tick.label.set_fontsize(LabelSize - font_gap) | ||||
|         cur_ax.set_xlim(-10, 10) | ||||
|         cur_ax.set_ylim(-60, 60) | ||||
|         if version == "v1": | ||||
|             cur_ax.set_xlim(-2, 2) | ||||
|             cur_ax.set_ylim(-8, 8) | ||||
|         elif version == "v2": | ||||
|             cur_ax.set_xlim(-10, 10) | ||||
|             cur_ax.set_ylim(-60, 60) | ||||
|         cur_ax.legend(loc=1, fontsize=LegendFontsize) | ||||
|  | ||||
|         # the trajectory data | ||||
| @@ -398,16 +329,20 @@ def compare_algs_v2(save_dir, alg_dir="./outputs/lfna-synthetic"): | ||||
|         cur_ax.set_ylim(0, 10) | ||||
|         cur_ax.legend(loc=1, fontsize=LegendFontsize) | ||||
|  | ||||
|         save_path = save_dir / "{:05d}".format(idx) | ||||
|         save_path = save_dir / "v{:}-{:05d}".format(version, idx) | ||||
|         fig.savefig(str(save_path) + ".pdf", dpi=dpi, bbox_inches="tight", format="pdf") | ||||
|         fig.savefig(str(save_path) + ".png", dpi=dpi, bbox_inches="tight", format="png") | ||||
|         plt.close("all") | ||||
|     save_dir = save_dir.resolve() | ||||
|     base_cmd = "ffmpeg -y -i {xdir}/%05d.png -vf scale={w}:{h} -pix_fmt yuv420p -vb 5000k".format( | ||||
|         xdir=save_dir, w=width, h=height | ||||
|     base_cmd = "ffmpeg -y -i {xdir}/v{ver}-%05d.png -vf scale={w}:{h} -pix_fmt yuv420p -vb 5000k".format( | ||||
|         xdir=save_dir, w=width, h=height, ver=version | ||||
|     ) | ||||
|     os.system( | ||||
|         "{:} {xdir}/com-alg-{ver}.mp4".format(base_cmd, xdir=save_dir, ver=version) | ||||
|     ) | ||||
|     os.system( | ||||
|         "{:} {xdir}/com-alg-{ver}.webm".format(base_cmd, xdir=save_dir, ver=version) | ||||
|     ) | ||||
|     os.system("{:} {xdir}/compare-alg.mp4".format(base_cmd, xdir=save_dir)) | ||||
|     os.system("{:} {xdir}/compare-alg.webm".format(base_cmd, xdir=save_dir)) | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
| @@ -421,8 +356,7 @@ if __name__ == "__main__": | ||||
|     ) | ||||
|     args = parser.parse_args() | ||||
|  | ||||
|     visualize_env(os.path.join(args.save_dir, "vis-env"), "v1") | ||||
|     visualize_env(os.path.join(args.save_dir, "vis-env"), "v2") | ||||
|     # compare_algs_v2(os.path.join(args.save_dir, "compare-alg-v2")) | ||||
|     # visualize_env(os.path.join(args.save_dir, "vis-env"), "v1") | ||||
|     # visualize_env(os.path.join(args.save_dir, "vis-env"), "v2") | ||||
|     compare_algs(os.path.join(args.save_dir, "compare-alg-v2"), "v1") | ||||
|     # compare_cl(os.path.join(args.save_dir, "compare-cl")) | ||||
|     # compare_algs(os.path.join(args.save_dir, "compare-alg")) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user