From 147da98f9434a3c9ca6a6e3f3072935cedb5d6ec Mon Sep 17 00:00:00 2001 From: D-X-Y <280835372@qq.com> Date: Mon, 10 May 2021 11:19:18 +0800 Subject: [PATCH] complete maml and remove vis_compare_algo --- exps/LFNA/basic-maml.py | 93 ++++++++++++++++++++++------ exps/LFNA/vis-synthetic.py | 122 +++++++++---------------------------- 2 files changed, 103 insertions(+), 112 deletions(-) diff --git a/exps/LFNA/basic-maml.py b/exps/LFNA/basic-maml.py index 3dcd7b2..b86adff 100644 --- a/exps/LFNA/basic-maml.py +++ b/exps/LFNA/basic-maml.py @@ -40,7 +40,7 @@ class MAML: self.network.parameters(), lr=meta_lr, amsgrad=True ) self.meta_lr_scheduler = torch.optim.lr_scheduler.MultiStepLR( - optimizer, + self.meta_optimizer, milestones=[ int(epochs * 0.25), int(epochs * 0.5), @@ -50,8 +50,8 @@ class MAML: ) self.inner_lr = inner_lr self.inner_step = inner_step - self._best_info = dict(state_dict=None, score=None) - print("There are {:} weights.".format(w_container.numel())) + self._best_info = dict(state_dict=None, iepoch=None, score=None) + print("There are {:} weights.".format(self.network.get_w_container().numel())) def adapt(self, dataset): # create a container for the future timestamp @@ -61,7 +61,6 @@ class MAML: y_hat = self.network.forward_with_container(dataset.x, container) loss = self.criterion(y_hat, dataset.y) grads = torch.autograd.grad(loss, container.parameters()) - container = container.additive([-self.inner_lr * grad for grad in grads]) return container @@ -73,23 +72,34 @@ class MAML: return y_hat def step(self): - torch.nn.utils.clip_grad_norm_(self.container.parameters(), 1.0) + torch.nn.utils.clip_grad_norm_(self.network.parameters(), 1.0) self.meta_optimizer.step() self.meta_lr_scheduler.step() def zero_grad(self): self.meta_optimizer.zero_grad() - def save_best(self, network, score): + def load_state_dict(self, state_dict): + self.criterion.load_state_dict(state_dict["criterion"]) + self.network.load_state_dict(state_dict["network"]) + self.meta_optimizer.load_state_dict(state_dict["meta_optimizer"]) + self.meta_lr_scheduler.load_state_dict(state_dict["meta_lr_scheduler"]) + + def save_best(self, iepoch, score): if self._best_info["score"] is None or self._best_info["score"] < score: state_dict = dict( - criterion=criterion, - network=network.state_dict(), + criterion=self.criterion.state_dict(), + network=self.network.state_dict(), meta_optimizer=self.meta_optimizer.state_dict(), meta_lr_scheduler=self.meta_lr_scheduler.state_dict(), ) self._best_info["state_dict"] = state_dict self._best_info["score"] = score + self._best_info["iepoch"] = iepoch + is_best = True + else: + is_best = False + return self._best_info, is_best def main(args): @@ -111,8 +121,9 @@ def main(args): # meta-training per_epoch_time, start_time = AverageMeter(), time.time() - for iepoch in range(args.epochs): - + # for iepoch in range(args.epochs): + iepoch = 0 + while iepoch < args.epochs: need_time = "Time Left: {:}".format( convert_secs2time(per_epoch_time.avg * (args.epochs - iepoch), True) ) @@ -122,9 +133,10 @@ def main(args): ) maml.zero_grad() - meta_losses = [] + batch_indexes, meta_losses = [], [] for ibatch in range(args.meta_batch): sampled_timestamp = random.randint(0, train_time_bar) + batch_indexes.append("{:5d}".format(sampled_timestamp)) past_dataset = TimeData( sampled_timestamp, env_info["{:}-x".format(sampled_timestamp)], @@ -135,7 +147,7 @@ def main(args): env_info["{:}-x".format(sampled_timestamp + 1)], env_info["{:}-y".format(sampled_timestamp + 1)], ) - future_container = maml.adapt(model, past_dataset) + future_container = maml.adapt(past_dataset) future_y_hat = maml.predict(future_dataset.x, future_container) future_loss = maml.criterion(future_y_hat, future_dataset.y) meta_losses.append(future_loss) @@ -143,14 +155,53 @@ def main(args): meta_loss.backward() maml.step() - logger.log("meta-loss: {:.4f}".format(meta_loss.item())) - + logger.log( + "meta-loss: {:.4f} batch: {:}".format( + meta_loss.item(), ",".join(batch_indexes) + ) + ) + best_info, is_best = maml.save_best(iepoch, -meta_loss.item()) + if is_best: + save_checkpoint(best_info, logger.path("best"), logger) + logger.log("Save the best into {:}".format(logger.path("best"))) + if iepoch >= 10 and ( + torch.isnan(meta_loss).item() or meta_loss.item() >= args.fail_thresh + ): + xdata = torch.load(logger.path("best")) + maml.load_state_dict(xdata["state_dict"]) + iepoch = xdata["iepoch"] + logger.log( + "The training failed, re-use the previous best epoch [{:}]".format( + iepoch + ) + ) + else: + iepoch = iepoch + 1 per_epoch_time.update(time.time() - start_time) start_time = time.time() - import pdb - - pdb.set_trace() + w_container_per_epoch = dict() + for idx in range(1, env_info["total"]): + past_dataset = TimeData( + idx - 1, + env_info["{:}-x".format(idx - 1)], + env_info["{:}-y".format(idx - 1)], + ) + current_container = maml.adapt(past_dataset) + w_container_per_epoch[idx] = current_container.no_grad_clone() + with torch.no_grad(): + current_x = env_info["{:}-x".format(idx)] + current_y = env_info["{:}-y".format(idx)] + current_y_hat = maml.predict(current_x, w_container_per_epoch[idx]) + current_loss = maml.criterion(current_y_hat, current_y) + logger.log( + "meta-test: [{:03d}] -> loss={:.4f}".format(idx, current_loss.item()) + ) + save_checkpoint( + {"w_container_per_epoch": w_container_per_epoch}, + logger.path(None) / "final-ckp.pth", + logger, + ) logger.log("-" * 200 + "\n") logger.close() @@ -179,9 +230,15 @@ if __name__ == "__main__": parser.add_argument( "--meta_lr", type=float, - default=0.1, + default=0.05, help="The learning rate for the MAML optimizer (default is Adam)", ) + parser.add_argument( + "--fail_thresh", + type=float, + default=1000, + help="The threshold for the failure, which we reuse the previous best model", + ) parser.add_argument( "--inner_lr", type=float, diff --git a/exps/LFNA/vis-synthetic.py b/exps/LFNA/vis-synthetic.py index 98be80f..432bb65 100644 --- a/exps/LFNA/vis-synthetic.py +++ b/exps/LFNA/vis-synthetic.py @@ -221,76 +221,7 @@ def visualize_env(save_dir, version): os.system("{:} {xdir}/env-{ver}.webm".format(base_cmd, xdir=save_dir, ver=version)) -def compare_algs(save_dir, alg_dir="./outputs/lfna-synthetic"): - save_dir = Path(str(save_dir)) - save_dir.mkdir(parents=True, exist_ok=True) - - dpi, width, height = 30, 1800, 1400 - figsize = width / float(dpi), height / float(dpi) - LabelSize, LegendFontsize, font_gap = 80, 80, 5 - - cache_path = Path(alg_dir) / "env-info.pth" - assert cache_path.exists(), "{:} does not exist".format(cache_path) - env_info = torch.load(cache_path) - - alg_name2dir = OrderedDict() - alg_name2dir["Optimal"] = "use-same-timestamp" - alg_name2dir["History SL"] = "use-all-past-data" - colors = ["r", "g"] - - dynamic_env = env_info["dynamic_env"] - min_t, max_t = dynamic_env.min_timestamp, dynamic_env.max_timestamp - - linewidths = 10 - for idx, (timestamp, (ori_allx, ori_ally)) in enumerate( - tqdm(dynamic_env, ncols=50) - ): - if idx == 0: - continue - fig = plt.figure(figsize=figsize) - cur_ax = fig.add_subplot(1, 1, 1) - - # the data - allx, ally = ori_allx[:, 0].numpy(), ori_ally[:, 0].numpy() - plot_scatter(cur_ax, allx, ally, "k", 0.99, linewidths, "Raw Data") - - for idx_alg, (alg, xdir) in enumerate(alg_name2dir.items()): - ckp_path = ( - Path(alg_dir) - / xdir - / "{:04d}-{:04d}.pth".format(idx, env_info["total"]) - ) - assert ckp_path.exists() - ckp_data = torch.load(ckp_path) - with torch.no_grad(): - predicts = ckp_data["model"](ori_allx) - predicts = predicts.cpu().view(-1).numpy() - plot_scatter(cur_ax, allx, predicts, colors[idx_alg], 0.99, linewidths, alg) - - cur_ax.set_xlabel("X", fontsize=LabelSize) - cur_ax.set_ylabel("Y", rotation=0, fontsize=LabelSize) - for tick in cur_ax.xaxis.get_major_ticks(): - tick.label.set_fontsize(LabelSize - font_gap) - tick.label.set_rotation(10) - for tick in cur_ax.yaxis.get_major_ticks(): - tick.label.set_fontsize(LabelSize - font_gap) - cur_ax.set_xlim(-10, 10) - cur_ax.set_ylim(-60, 60) - cur_ax.legend(loc=1, fontsize=LegendFontsize) - - save_path = save_dir / "{:05d}".format(idx) - fig.savefig(str(save_path) + ".pdf", dpi=dpi, bbox_inches="tight", format="pdf") - fig.savefig(str(save_path) + ".png", dpi=dpi, bbox_inches="tight", format="png") - plt.close("all") - save_dir = save_dir.resolve() - base_cmd = "ffmpeg -y -i {xdir}/%05d.png -vf scale={w}:{h} -pix_fmt yuv420p -vb 5000k".format( - xdir=save_dir, w=width, h=height - ) - os.system("{:} {xdir}/compare-alg.mp4".format(base_cmd, xdir=save_dir)) - os.system("{:} {xdir}/compare-alg.webm".format(base_cmd, xdir=save_dir)) - - -def compare_algs_v2(save_dir, alg_dir="./outputs/lfna-synthetic"): +def compare_algs(save_dir, version, alg_dir="./outputs/lfna-synthetic"): save_dir = Path(str(save_dir)) save_dir.mkdir(parents=True, exist_ok=True) @@ -298,16 +229,21 @@ def compare_algs_v2(save_dir, alg_dir="./outputs/lfna-synthetic"): figsize = width / float(dpi), height / float(dpi) LabelSize, LegendFontsize, font_gap = 80, 80, 5 - cache_path = Path(alg_dir) / "env-info.pth" + cache_path = Path(alg_dir) / "env-{:}-info.pth".format(version) assert cache_path.exists(), "{:} does not exist".format(cache_path) env_info = torch.load(cache_path) alg_name2dir = OrderedDict() alg_name2dir["Optimal"] = "use-same-timestamp" - # alg_name2dir["Supervised Learning (History Data)"] = "use-all-past-data" + alg_name2dir["Supervised Learning (History Data)"] = "use-all-past-data" + alg_name2dir["MAML"] = "use-maml-s1" alg_name2all_containers = OrderedDict() + if version == "v1": + poststr = "v1-d16" + else: + raise ValueError("Invalid version: {:}".format(version)) for idx_alg, (alg, xdir) in enumerate(alg_name2dir.items()): - ckp_path = Path(alg_dir) / xdir / "final-ckp.pth" + ckp_path = Path(alg_dir) / "{:}-{:}".format(xdir, poststr) / "final-ckp.pth" xdata = torch.load(ckp_path) alg_name2all_containers[alg] = xdata["w_container_per_epoch"] # load the basic model @@ -320,7 +256,7 @@ def compare_algs_v2(save_dir, alg_dir="./outputs/lfna-synthetic"): ) alg2xs, alg2ys = defaultdict(list), defaultdict(list) - colors = ["r", "g"] + colors = ["r", "g", "b"] dynamic_env = env_info["dynamic_env"] min_t, max_t = dynamic_env.min_timestamp, dynamic_env.max_timestamp @@ -339,15 +275,6 @@ def compare_algs_v2(save_dir, alg_dir="./outputs/lfna-synthetic"): plot_scatter(cur_ax, allx, ally, "k", 0.99, linewidths, "Raw Data") for idx_alg, (alg, xdir) in enumerate(alg_name2dir.items()): - """ - ckp_path = ( - Path(alg_dir) - / xdir - / "{:04d}-{:04d}.pth".format(idx, env_info["total"]) - ) - assert ckp_path.exists() - ckp_data = torch.load(ckp_path) - """ with torch.no_grad(): # predicts = ckp_data["model"](ori_allx) predicts = model.forward_with_container( @@ -369,8 +296,12 @@ def compare_algs_v2(save_dir, alg_dir="./outputs/lfna-synthetic"): tick.label.set_rotation(10) for tick in cur_ax.yaxis.get_major_ticks(): tick.label.set_fontsize(LabelSize - font_gap) - cur_ax.set_xlim(-10, 10) - cur_ax.set_ylim(-60, 60) + if version == "v1": + cur_ax.set_xlim(-2, 2) + cur_ax.set_ylim(-8, 8) + elif version == "v2": + cur_ax.set_xlim(-10, 10) + cur_ax.set_ylim(-60, 60) cur_ax.legend(loc=1, fontsize=LegendFontsize) # the trajectory data @@ -398,16 +329,20 @@ def compare_algs_v2(save_dir, alg_dir="./outputs/lfna-synthetic"): cur_ax.set_ylim(0, 10) cur_ax.legend(loc=1, fontsize=LegendFontsize) - save_path = save_dir / "{:05d}".format(idx) + save_path = save_dir / "v{:}-{:05d}".format(version, idx) fig.savefig(str(save_path) + ".pdf", dpi=dpi, bbox_inches="tight", format="pdf") fig.savefig(str(save_path) + ".png", dpi=dpi, bbox_inches="tight", format="png") plt.close("all") save_dir = save_dir.resolve() - base_cmd = "ffmpeg -y -i {xdir}/%05d.png -vf scale={w}:{h} -pix_fmt yuv420p -vb 5000k".format( - xdir=save_dir, w=width, h=height + base_cmd = "ffmpeg -y -i {xdir}/v{ver}-%05d.png -vf scale={w}:{h} -pix_fmt yuv420p -vb 5000k".format( + xdir=save_dir, w=width, h=height, ver=version + ) + os.system( + "{:} {xdir}/com-alg-{ver}.mp4".format(base_cmd, xdir=save_dir, ver=version) + ) + os.system( + "{:} {xdir}/com-alg-{ver}.webm".format(base_cmd, xdir=save_dir, ver=version) ) - os.system("{:} {xdir}/compare-alg.mp4".format(base_cmd, xdir=save_dir)) - os.system("{:} {xdir}/compare-alg.webm".format(base_cmd, xdir=save_dir)) if __name__ == "__main__": @@ -421,8 +356,7 @@ if __name__ == "__main__": ) args = parser.parse_args() - visualize_env(os.path.join(args.save_dir, "vis-env"), "v1") - visualize_env(os.path.join(args.save_dir, "vis-env"), "v2") - # compare_algs_v2(os.path.join(args.save_dir, "compare-alg-v2")) + # visualize_env(os.path.join(args.save_dir, "vis-env"), "v1") + # visualize_env(os.path.join(args.save_dir, "vis-env"), "v2") + compare_algs(os.path.join(args.save_dir, "compare-alg-v2"), "v1") # compare_cl(os.path.join(args.save_dir, "compare-cl")) - # compare_algs(os.path.join(args.save_dir, "compare-alg"))