complete maml and remove vis_compare_algo
This commit is contained in:
		| @@ -40,7 +40,7 @@ class MAML: | |||||||
|             self.network.parameters(), lr=meta_lr, amsgrad=True |             self.network.parameters(), lr=meta_lr, amsgrad=True | ||||||
|         ) |         ) | ||||||
|         self.meta_lr_scheduler = torch.optim.lr_scheduler.MultiStepLR( |         self.meta_lr_scheduler = torch.optim.lr_scheduler.MultiStepLR( | ||||||
|             optimizer, |             self.meta_optimizer, | ||||||
|             milestones=[ |             milestones=[ | ||||||
|                 int(epochs * 0.25), |                 int(epochs * 0.25), | ||||||
|                 int(epochs * 0.5), |                 int(epochs * 0.5), | ||||||
| @@ -50,8 +50,8 @@ class MAML: | |||||||
|         ) |         ) | ||||||
|         self.inner_lr = inner_lr |         self.inner_lr = inner_lr | ||||||
|         self.inner_step = inner_step |         self.inner_step = inner_step | ||||||
|         self._best_info = dict(state_dict=None, score=None) |         self._best_info = dict(state_dict=None, iepoch=None, score=None) | ||||||
|         print("There are {:} weights.".format(w_container.numel())) |         print("There are {:} weights.".format(self.network.get_w_container().numel())) | ||||||
|  |  | ||||||
|     def adapt(self, dataset): |     def adapt(self, dataset): | ||||||
|         # create a container for the future timestamp |         # create a container for the future timestamp | ||||||
| @@ -61,7 +61,6 @@ class MAML: | |||||||
|             y_hat = self.network.forward_with_container(dataset.x, container) |             y_hat = self.network.forward_with_container(dataset.x, container) | ||||||
|             loss = self.criterion(y_hat, dataset.y) |             loss = self.criterion(y_hat, dataset.y) | ||||||
|             grads = torch.autograd.grad(loss, container.parameters()) |             grads = torch.autograd.grad(loss, container.parameters()) | ||||||
|  |  | ||||||
|             container = container.additive([-self.inner_lr * grad for grad in grads]) |             container = container.additive([-self.inner_lr * grad for grad in grads]) | ||||||
|         return container |         return container | ||||||
|  |  | ||||||
| @@ -73,23 +72,34 @@ class MAML: | |||||||
|         return y_hat |         return y_hat | ||||||
|  |  | ||||||
|     def step(self): |     def step(self): | ||||||
|         torch.nn.utils.clip_grad_norm_(self.container.parameters(), 1.0) |         torch.nn.utils.clip_grad_norm_(self.network.parameters(), 1.0) | ||||||
|         self.meta_optimizer.step() |         self.meta_optimizer.step() | ||||||
|         self.meta_lr_scheduler.step() |         self.meta_lr_scheduler.step() | ||||||
|  |  | ||||||
|     def zero_grad(self): |     def zero_grad(self): | ||||||
|         self.meta_optimizer.zero_grad() |         self.meta_optimizer.zero_grad() | ||||||
|  |  | ||||||
|     def save_best(self, network, score): |     def load_state_dict(self, state_dict): | ||||||
|  |         self.criterion.load_state_dict(state_dict["criterion"]) | ||||||
|  |         self.network.load_state_dict(state_dict["network"]) | ||||||
|  |         self.meta_optimizer.load_state_dict(state_dict["meta_optimizer"]) | ||||||
|  |         self.meta_lr_scheduler.load_state_dict(state_dict["meta_lr_scheduler"]) | ||||||
|  |  | ||||||
|  |     def save_best(self, iepoch, score): | ||||||
|         if self._best_info["score"] is None or self._best_info["score"] < score: |         if self._best_info["score"] is None or self._best_info["score"] < score: | ||||||
|             state_dict = dict( |             state_dict = dict( | ||||||
|                 criterion=criterion, |                 criterion=self.criterion.state_dict(), | ||||||
|                 network=network.state_dict(), |                 network=self.network.state_dict(), | ||||||
|                 meta_optimizer=self.meta_optimizer.state_dict(), |                 meta_optimizer=self.meta_optimizer.state_dict(), | ||||||
|                 meta_lr_scheduler=self.meta_lr_scheduler.state_dict(), |                 meta_lr_scheduler=self.meta_lr_scheduler.state_dict(), | ||||||
|             ) |             ) | ||||||
|             self._best_info["state_dict"] = state_dict |             self._best_info["state_dict"] = state_dict | ||||||
|             self._best_info["score"] = score |             self._best_info["score"] = score | ||||||
|  |             self._best_info["iepoch"] = iepoch | ||||||
|  |             is_best = True | ||||||
|  |         else: | ||||||
|  |             is_best = False | ||||||
|  |         return self._best_info, is_best | ||||||
|  |  | ||||||
|  |  | ||||||
| def main(args): | def main(args): | ||||||
| @@ -111,8 +121,9 @@ def main(args): | |||||||
|  |  | ||||||
|     # meta-training |     # meta-training | ||||||
|     per_epoch_time, start_time = AverageMeter(), time.time() |     per_epoch_time, start_time = AverageMeter(), time.time() | ||||||
|     for iepoch in range(args.epochs): |     # for iepoch in range(args.epochs): | ||||||
|  |     iepoch = 0 | ||||||
|  |     while iepoch < args.epochs: | ||||||
|         need_time = "Time Left: {:}".format( |         need_time = "Time Left: {:}".format( | ||||||
|             convert_secs2time(per_epoch_time.avg * (args.epochs - iepoch), True) |             convert_secs2time(per_epoch_time.avg * (args.epochs - iepoch), True) | ||||||
|         ) |         ) | ||||||
| @@ -122,9 +133,10 @@ def main(args): | |||||||
|         ) |         ) | ||||||
|  |  | ||||||
|         maml.zero_grad() |         maml.zero_grad() | ||||||
|         meta_losses = [] |         batch_indexes, meta_losses = [], [] | ||||||
|         for ibatch in range(args.meta_batch): |         for ibatch in range(args.meta_batch): | ||||||
|             sampled_timestamp = random.randint(0, train_time_bar) |             sampled_timestamp = random.randint(0, train_time_bar) | ||||||
|  |             batch_indexes.append("{:5d}".format(sampled_timestamp)) | ||||||
|             past_dataset = TimeData( |             past_dataset = TimeData( | ||||||
|                 sampled_timestamp, |                 sampled_timestamp, | ||||||
|                 env_info["{:}-x".format(sampled_timestamp)], |                 env_info["{:}-x".format(sampled_timestamp)], | ||||||
| @@ -135,7 +147,7 @@ def main(args): | |||||||
|                 env_info["{:}-x".format(sampled_timestamp + 1)], |                 env_info["{:}-x".format(sampled_timestamp + 1)], | ||||||
|                 env_info["{:}-y".format(sampled_timestamp + 1)], |                 env_info["{:}-y".format(sampled_timestamp + 1)], | ||||||
|             ) |             ) | ||||||
|             future_container = maml.adapt(model, past_dataset) |             future_container = maml.adapt(past_dataset) | ||||||
|             future_y_hat = maml.predict(future_dataset.x, future_container) |             future_y_hat = maml.predict(future_dataset.x, future_container) | ||||||
|             future_loss = maml.criterion(future_y_hat, future_dataset.y) |             future_loss = maml.criterion(future_y_hat, future_dataset.y) | ||||||
|             meta_losses.append(future_loss) |             meta_losses.append(future_loss) | ||||||
| @@ -143,14 +155,53 @@ def main(args): | |||||||
|         meta_loss.backward() |         meta_loss.backward() | ||||||
|         maml.step() |         maml.step() | ||||||
|  |  | ||||||
|         logger.log("meta-loss: {:.4f}".format(meta_loss.item())) |         logger.log( | ||||||
|  |             "meta-loss: {:.4f}  batch: {:}".format( | ||||||
|  |                 meta_loss.item(), ",".join(batch_indexes) | ||||||
|  |             ) | ||||||
|  |         ) | ||||||
|  |         best_info, is_best = maml.save_best(iepoch, -meta_loss.item()) | ||||||
|  |         if is_best: | ||||||
|  |             save_checkpoint(best_info, logger.path("best"), logger) | ||||||
|  |             logger.log("Save the best into {:}".format(logger.path("best"))) | ||||||
|  |         if iepoch >= 10 and ( | ||||||
|  |             torch.isnan(meta_loss).item() or meta_loss.item() >= args.fail_thresh | ||||||
|  |         ): | ||||||
|  |             xdata = torch.load(logger.path("best")) | ||||||
|  |             maml.load_state_dict(xdata["state_dict"]) | ||||||
|  |             iepoch = xdata["iepoch"] | ||||||
|  |             logger.log( | ||||||
|  |                 "The training failed, re-use the previous best epoch [{:}]".format( | ||||||
|  |                     iepoch | ||||||
|  |                 ) | ||||||
|  |             ) | ||||||
|  |         else: | ||||||
|  |             iepoch = iepoch + 1 | ||||||
|         per_epoch_time.update(time.time() - start_time) |         per_epoch_time.update(time.time() - start_time) | ||||||
|         start_time = time.time() |         start_time = time.time() | ||||||
|  |  | ||||||
|     import pdb |     w_container_per_epoch = dict() | ||||||
|  |     for idx in range(1, env_info["total"]): | ||||||
|     pdb.set_trace() |         past_dataset = TimeData( | ||||||
|  |             idx - 1, | ||||||
|  |             env_info["{:}-x".format(idx - 1)], | ||||||
|  |             env_info["{:}-y".format(idx - 1)], | ||||||
|  |         ) | ||||||
|  |         current_container = maml.adapt(past_dataset) | ||||||
|  |         w_container_per_epoch[idx] = current_container.no_grad_clone() | ||||||
|  |         with torch.no_grad(): | ||||||
|  |             current_x = env_info["{:}-x".format(idx)] | ||||||
|  |             current_y = env_info["{:}-y".format(idx)] | ||||||
|  |             current_y_hat = maml.predict(current_x, w_container_per_epoch[idx]) | ||||||
|  |             current_loss = maml.criterion(current_y_hat, current_y) | ||||||
|  |         logger.log( | ||||||
|  |             "meta-test: [{:03d}] -> loss={:.4f}".format(idx, current_loss.item()) | ||||||
|  |         ) | ||||||
|  |     save_checkpoint( | ||||||
|  |         {"w_container_per_epoch": w_container_per_epoch}, | ||||||
|  |         logger.path(None) / "final-ckp.pth", | ||||||
|  |         logger, | ||||||
|  |     ) | ||||||
|  |  | ||||||
|     logger.log("-" * 200 + "\n") |     logger.log("-" * 200 + "\n") | ||||||
|     logger.close() |     logger.close() | ||||||
| @@ -179,9 +230,15 @@ if __name__ == "__main__": | |||||||
|     parser.add_argument( |     parser.add_argument( | ||||||
|         "--meta_lr", |         "--meta_lr", | ||||||
|         type=float, |         type=float, | ||||||
|         default=0.1, |         default=0.05, | ||||||
|         help="The learning rate for the MAML optimizer (default is Adam)", |         help="The learning rate for the MAML optimizer (default is Adam)", | ||||||
|     ) |     ) | ||||||
|  |     parser.add_argument( | ||||||
|  |         "--fail_thresh", | ||||||
|  |         type=float, | ||||||
|  |         default=1000, | ||||||
|  |         help="The threshold for the failure, which we reuse the previous best model", | ||||||
|  |     ) | ||||||
|     parser.add_argument( |     parser.add_argument( | ||||||
|         "--inner_lr", |         "--inner_lr", | ||||||
|         type=float, |         type=float, | ||||||
|   | |||||||
| @@ -221,76 +221,7 @@ def visualize_env(save_dir, version): | |||||||
|     os.system("{:} {xdir}/env-{ver}.webm".format(base_cmd, xdir=save_dir, ver=version)) |     os.system("{:} {xdir}/env-{ver}.webm".format(base_cmd, xdir=save_dir, ver=version)) | ||||||
|  |  | ||||||
|  |  | ||||||
| def compare_algs(save_dir, alg_dir="./outputs/lfna-synthetic"): | def compare_algs(save_dir, version, alg_dir="./outputs/lfna-synthetic"): | ||||||
|     save_dir = Path(str(save_dir)) |  | ||||||
|     save_dir.mkdir(parents=True, exist_ok=True) |  | ||||||
|  |  | ||||||
|     dpi, width, height = 30, 1800, 1400 |  | ||||||
|     figsize = width / float(dpi), height / float(dpi) |  | ||||||
|     LabelSize, LegendFontsize, font_gap = 80, 80, 5 |  | ||||||
|  |  | ||||||
|     cache_path = Path(alg_dir) / "env-info.pth" |  | ||||||
|     assert cache_path.exists(), "{:} does not exist".format(cache_path) |  | ||||||
|     env_info = torch.load(cache_path) |  | ||||||
|  |  | ||||||
|     alg_name2dir = OrderedDict() |  | ||||||
|     alg_name2dir["Optimal"] = "use-same-timestamp" |  | ||||||
|     alg_name2dir["History SL"] = "use-all-past-data" |  | ||||||
|     colors = ["r", "g"] |  | ||||||
|  |  | ||||||
|     dynamic_env = env_info["dynamic_env"] |  | ||||||
|     min_t, max_t = dynamic_env.min_timestamp, dynamic_env.max_timestamp |  | ||||||
|  |  | ||||||
|     linewidths = 10 |  | ||||||
|     for idx, (timestamp, (ori_allx, ori_ally)) in enumerate( |  | ||||||
|         tqdm(dynamic_env, ncols=50) |  | ||||||
|     ): |  | ||||||
|         if idx == 0: |  | ||||||
|             continue |  | ||||||
|         fig = plt.figure(figsize=figsize) |  | ||||||
|         cur_ax = fig.add_subplot(1, 1, 1) |  | ||||||
|  |  | ||||||
|         # the data |  | ||||||
|         allx, ally = ori_allx[:, 0].numpy(), ori_ally[:, 0].numpy() |  | ||||||
|         plot_scatter(cur_ax, allx, ally, "k", 0.99, linewidths, "Raw Data") |  | ||||||
|  |  | ||||||
|         for idx_alg, (alg, xdir) in enumerate(alg_name2dir.items()): |  | ||||||
|             ckp_path = ( |  | ||||||
|                 Path(alg_dir) |  | ||||||
|                 / xdir |  | ||||||
|                 / "{:04d}-{:04d}.pth".format(idx, env_info["total"]) |  | ||||||
|             ) |  | ||||||
|             assert ckp_path.exists() |  | ||||||
|             ckp_data = torch.load(ckp_path) |  | ||||||
|             with torch.no_grad(): |  | ||||||
|                 predicts = ckp_data["model"](ori_allx) |  | ||||||
|                 predicts = predicts.cpu().view(-1).numpy() |  | ||||||
|             plot_scatter(cur_ax, allx, predicts, colors[idx_alg], 0.99, linewidths, alg) |  | ||||||
|  |  | ||||||
|         cur_ax.set_xlabel("X", fontsize=LabelSize) |  | ||||||
|         cur_ax.set_ylabel("Y", rotation=0, fontsize=LabelSize) |  | ||||||
|         for tick in cur_ax.xaxis.get_major_ticks(): |  | ||||||
|             tick.label.set_fontsize(LabelSize - font_gap) |  | ||||||
|             tick.label.set_rotation(10) |  | ||||||
|         for tick in cur_ax.yaxis.get_major_ticks(): |  | ||||||
|             tick.label.set_fontsize(LabelSize - font_gap) |  | ||||||
|         cur_ax.set_xlim(-10, 10) |  | ||||||
|         cur_ax.set_ylim(-60, 60) |  | ||||||
|         cur_ax.legend(loc=1, fontsize=LegendFontsize) |  | ||||||
|  |  | ||||||
|         save_path = save_dir / "{:05d}".format(idx) |  | ||||||
|         fig.savefig(str(save_path) + ".pdf", dpi=dpi, bbox_inches="tight", format="pdf") |  | ||||||
|         fig.savefig(str(save_path) + ".png", dpi=dpi, bbox_inches="tight", format="png") |  | ||||||
|         plt.close("all") |  | ||||||
|     save_dir = save_dir.resolve() |  | ||||||
|     base_cmd = "ffmpeg -y -i {xdir}/%05d.png -vf scale={w}:{h} -pix_fmt yuv420p -vb 5000k".format( |  | ||||||
|         xdir=save_dir, w=width, h=height |  | ||||||
|     ) |  | ||||||
|     os.system("{:} {xdir}/compare-alg.mp4".format(base_cmd, xdir=save_dir)) |  | ||||||
|     os.system("{:} {xdir}/compare-alg.webm".format(base_cmd, xdir=save_dir)) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def compare_algs_v2(save_dir, alg_dir="./outputs/lfna-synthetic"): |  | ||||||
|     save_dir = Path(str(save_dir)) |     save_dir = Path(str(save_dir)) | ||||||
|     save_dir.mkdir(parents=True, exist_ok=True) |     save_dir.mkdir(parents=True, exist_ok=True) | ||||||
|  |  | ||||||
| @@ -298,16 +229,21 @@ def compare_algs_v2(save_dir, alg_dir="./outputs/lfna-synthetic"): | |||||||
|     figsize = width / float(dpi), height / float(dpi) |     figsize = width / float(dpi), height / float(dpi) | ||||||
|     LabelSize, LegendFontsize, font_gap = 80, 80, 5 |     LabelSize, LegendFontsize, font_gap = 80, 80, 5 | ||||||
|  |  | ||||||
|     cache_path = Path(alg_dir) / "env-info.pth" |     cache_path = Path(alg_dir) / "env-{:}-info.pth".format(version) | ||||||
|     assert cache_path.exists(), "{:} does not exist".format(cache_path) |     assert cache_path.exists(), "{:} does not exist".format(cache_path) | ||||||
|     env_info = torch.load(cache_path) |     env_info = torch.load(cache_path) | ||||||
|  |  | ||||||
|     alg_name2dir = OrderedDict() |     alg_name2dir = OrderedDict() | ||||||
|     alg_name2dir["Optimal"] = "use-same-timestamp" |     alg_name2dir["Optimal"] = "use-same-timestamp" | ||||||
|     # alg_name2dir["Supervised Learning (History Data)"] = "use-all-past-data" |     alg_name2dir["Supervised Learning (History Data)"] = "use-all-past-data" | ||||||
|  |     alg_name2dir["MAML"] = "use-maml-s1" | ||||||
|     alg_name2all_containers = OrderedDict() |     alg_name2all_containers = OrderedDict() | ||||||
|  |     if version == "v1": | ||||||
|  |         poststr = "v1-d16" | ||||||
|  |     else: | ||||||
|  |         raise ValueError("Invalid version: {:}".format(version)) | ||||||
|     for idx_alg, (alg, xdir) in enumerate(alg_name2dir.items()): |     for idx_alg, (alg, xdir) in enumerate(alg_name2dir.items()): | ||||||
|         ckp_path = Path(alg_dir) / xdir / "final-ckp.pth" |         ckp_path = Path(alg_dir) / "{:}-{:}".format(xdir, poststr) / "final-ckp.pth" | ||||||
|         xdata = torch.load(ckp_path) |         xdata = torch.load(ckp_path) | ||||||
|         alg_name2all_containers[alg] = xdata["w_container_per_epoch"] |         alg_name2all_containers[alg] = xdata["w_container_per_epoch"] | ||||||
|     # load the basic model |     # load the basic model | ||||||
| @@ -320,7 +256,7 @@ def compare_algs_v2(save_dir, alg_dir="./outputs/lfna-synthetic"): | |||||||
|     ) |     ) | ||||||
|  |  | ||||||
|     alg2xs, alg2ys = defaultdict(list), defaultdict(list) |     alg2xs, alg2ys = defaultdict(list), defaultdict(list) | ||||||
|     colors = ["r", "g"] |     colors = ["r", "g", "b"] | ||||||
|  |  | ||||||
|     dynamic_env = env_info["dynamic_env"] |     dynamic_env = env_info["dynamic_env"] | ||||||
|     min_t, max_t = dynamic_env.min_timestamp, dynamic_env.max_timestamp |     min_t, max_t = dynamic_env.min_timestamp, dynamic_env.max_timestamp | ||||||
| @@ -339,15 +275,6 @@ def compare_algs_v2(save_dir, alg_dir="./outputs/lfna-synthetic"): | |||||||
|         plot_scatter(cur_ax, allx, ally, "k", 0.99, linewidths, "Raw Data") |         plot_scatter(cur_ax, allx, ally, "k", 0.99, linewidths, "Raw Data") | ||||||
|  |  | ||||||
|         for idx_alg, (alg, xdir) in enumerate(alg_name2dir.items()): |         for idx_alg, (alg, xdir) in enumerate(alg_name2dir.items()): | ||||||
|             """ |  | ||||||
|             ckp_path = ( |  | ||||||
|                 Path(alg_dir) |  | ||||||
|                 / xdir |  | ||||||
|                 / "{:04d}-{:04d}.pth".format(idx, env_info["total"]) |  | ||||||
|             ) |  | ||||||
|             assert ckp_path.exists() |  | ||||||
|             ckp_data = torch.load(ckp_path) |  | ||||||
|             """ |  | ||||||
|             with torch.no_grad(): |             with torch.no_grad(): | ||||||
|                 # predicts = ckp_data["model"](ori_allx) |                 # predicts = ckp_data["model"](ori_allx) | ||||||
|                 predicts = model.forward_with_container( |                 predicts = model.forward_with_container( | ||||||
| @@ -369,8 +296,12 @@ def compare_algs_v2(save_dir, alg_dir="./outputs/lfna-synthetic"): | |||||||
|             tick.label.set_rotation(10) |             tick.label.set_rotation(10) | ||||||
|         for tick in cur_ax.yaxis.get_major_ticks(): |         for tick in cur_ax.yaxis.get_major_ticks(): | ||||||
|             tick.label.set_fontsize(LabelSize - font_gap) |             tick.label.set_fontsize(LabelSize - font_gap) | ||||||
|         cur_ax.set_xlim(-10, 10) |         if version == "v1": | ||||||
|         cur_ax.set_ylim(-60, 60) |             cur_ax.set_xlim(-2, 2) | ||||||
|  |             cur_ax.set_ylim(-8, 8) | ||||||
|  |         elif version == "v2": | ||||||
|  |             cur_ax.set_xlim(-10, 10) | ||||||
|  |             cur_ax.set_ylim(-60, 60) | ||||||
|         cur_ax.legend(loc=1, fontsize=LegendFontsize) |         cur_ax.legend(loc=1, fontsize=LegendFontsize) | ||||||
|  |  | ||||||
|         # the trajectory data |         # the trajectory data | ||||||
| @@ -398,16 +329,20 @@ def compare_algs_v2(save_dir, alg_dir="./outputs/lfna-synthetic"): | |||||||
|         cur_ax.set_ylim(0, 10) |         cur_ax.set_ylim(0, 10) | ||||||
|         cur_ax.legend(loc=1, fontsize=LegendFontsize) |         cur_ax.legend(loc=1, fontsize=LegendFontsize) | ||||||
|  |  | ||||||
|         save_path = save_dir / "{:05d}".format(idx) |         save_path = save_dir / "v{:}-{:05d}".format(version, idx) | ||||||
|         fig.savefig(str(save_path) + ".pdf", dpi=dpi, bbox_inches="tight", format="pdf") |         fig.savefig(str(save_path) + ".pdf", dpi=dpi, bbox_inches="tight", format="pdf") | ||||||
|         fig.savefig(str(save_path) + ".png", dpi=dpi, bbox_inches="tight", format="png") |         fig.savefig(str(save_path) + ".png", dpi=dpi, bbox_inches="tight", format="png") | ||||||
|         plt.close("all") |         plt.close("all") | ||||||
|     save_dir = save_dir.resolve() |     save_dir = save_dir.resolve() | ||||||
|     base_cmd = "ffmpeg -y -i {xdir}/%05d.png -vf scale={w}:{h} -pix_fmt yuv420p -vb 5000k".format( |     base_cmd = "ffmpeg -y -i {xdir}/v{ver}-%05d.png -vf scale={w}:{h} -pix_fmt yuv420p -vb 5000k".format( | ||||||
|         xdir=save_dir, w=width, h=height |         xdir=save_dir, w=width, h=height, ver=version | ||||||
|  |     ) | ||||||
|  |     os.system( | ||||||
|  |         "{:} {xdir}/com-alg-{ver}.mp4".format(base_cmd, xdir=save_dir, ver=version) | ||||||
|  |     ) | ||||||
|  |     os.system( | ||||||
|  |         "{:} {xdir}/com-alg-{ver}.webm".format(base_cmd, xdir=save_dir, ver=version) | ||||||
|     ) |     ) | ||||||
|     os.system("{:} {xdir}/compare-alg.mp4".format(base_cmd, xdir=save_dir)) |  | ||||||
|     os.system("{:} {xdir}/compare-alg.webm".format(base_cmd, xdir=save_dir)) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| if __name__ == "__main__": | if __name__ == "__main__": | ||||||
| @@ -421,8 +356,7 @@ if __name__ == "__main__": | |||||||
|     ) |     ) | ||||||
|     args = parser.parse_args() |     args = parser.parse_args() | ||||||
|  |  | ||||||
|     visualize_env(os.path.join(args.save_dir, "vis-env"), "v1") |     # visualize_env(os.path.join(args.save_dir, "vis-env"), "v1") | ||||||
|     visualize_env(os.path.join(args.save_dir, "vis-env"), "v2") |     # visualize_env(os.path.join(args.save_dir, "vis-env"), "v2") | ||||||
|     # compare_algs_v2(os.path.join(args.save_dir, "compare-alg-v2")) |     compare_algs(os.path.join(args.save_dir, "compare-alg-v2"), "v1") | ||||||
|     # compare_cl(os.path.join(args.save_dir, "compare-cl")) |     # compare_cl(os.path.join(args.save_dir, "compare-cl")) | ||||||
|     # compare_algs(os.path.join(args.save_dir, "compare-alg")) |  | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user