Update GeMOSA v4
This commit is contained in:
		| @@ -2,9 +2,9 @@ | ||||
| # Learning to Generate Model One Step Ahead         # | ||||
| ##################################################### | ||||
| # python exps/GeMOSA/main.py --env_version v1 --workers 0 | ||||
| # python exps/GeMOSA/main.py --env_version v1 --device cuda --lr 0.002 --hidden_dim 16 --meta_batch 256 | ||||
| # python exps/GeMOSA/main.py --env_version v2 --device cuda --lr 0.002 --hidden_dim 16 --meta_batch 256 | ||||
| # python exps/GeMOSA/main.py --env_version v3 --device cuda --lr 0.002 --hidden_dim 32 --meta_batch 256 | ||||
| # python exps/GeMOSA/main.py --env_version v1 --lr 0.002 --hidden_dim 16 --meta_batch 256 --device cuda | ||||
| # python exps/GeMOSA/main.py --env_version v2 --lr 0.002 --hidden_dim 16 --meta_batch 256 --device cuda | ||||
| # python exps/GeMOSA/main.py --env_version v3 --lr 0.002 --hidden_dim 32 --time_dim 32 --meta_batch 256 --device cuda | ||||
| ##################################################### | ||||
| import sys, time, copy, torch, random, argparse | ||||
| from tqdm import tqdm | ||||
|   | ||||
| @@ -3,7 +3,8 @@ | ||||
| ############################################################################ | ||||
| # python exps/GeMOSA/vis-synthetic.py --env_version v1                     # | ||||
| # python exps/GeMOSA/vis-synthetic.py --env_version v2                     # | ||||
| # python exps/GeMOSA/vis-synthetic.py --env_version v2                     # | ||||
| # python exps/GeMOSA/vis-synthetic.py --env_version v3                     # | ||||
| # python exps/GeMOSA/vis-synthetic.py --env_version v4                     # | ||||
| ############################################################################ | ||||
| import os, sys, copy, random | ||||
| import torch | ||||
| @@ -31,8 +32,8 @@ from xautodl.procedures.metric_utils import MSEMetric | ||||
|  | ||||
|  | ||||
| def plot_scatter(cur_ax, xs, ys, color, alpha, linewidths, label=None): | ||||
|     cur_ax.scatter([-100], [-100], color=color, linewidths=linewidths, label=label) | ||||
|     cur_ax.scatter(xs, ys, color=color, alpha=alpha, linewidths=1.5, label=None) | ||||
|     cur_ax.scatter([-100], [-100], color=color, linewidths=linewidths[0], label=label) | ||||
|     cur_ax.scatter(xs, ys, color=color, alpha=alpha, linewidths=linewidths[1], label=None) | ||||
|  | ||||
|  | ||||
| def draw_multi_fig(save_dir, timestamp, scatter_list, wh, fig_title=None): | ||||
| @@ -186,15 +187,23 @@ def visualize_env(save_dir, version): | ||||
|         sub_save_dir.mkdir(parents=True, exist_ok=True) | ||||
|  | ||||
|     dynamic_env = get_synthetic_env(version=version) | ||||
|     print("env: {:}".format(dynamic_env)) | ||||
|     print("oracle_map: {:}".format(dynamic_env.oracle_map)) | ||||
|     allxs, allys = [], [] | ||||
|     for idx, (timestamp, (allx, ally)) in enumerate(tqdm(dynamic_env, ncols=50)): | ||||
|         allxs.append(allx) | ||||
|         allys.append(ally) | ||||
|     allxs, allys = torch.cat(allxs).view(-1), torch.cat(allys).view(-1) | ||||
|     print("env: {:}".format(dynamic_env)) | ||||
|     print("oracle_map: {:}".format(dynamic_env.oracle_map)) | ||||
|     print("x - min={:.3f}, max={:.3f}".format(allxs.min().item(), allxs.max().item())) | ||||
|     print("y - min={:.3f}, max={:.3f}".format(allys.min().item(), allys.max().item())) | ||||
|     if dynamic_env.meta_info['task'] == 'regression': | ||||
|         allxs, allys = torch.cat(allxs).view(-1), torch.cat(allys).view(-1) | ||||
|         print("x - min={:.3f}, max={:.3f}".format(allxs.min().item(), allxs.max().item())) | ||||
|         print("y - min={:.3f}, max={:.3f}".format(allys.min().item(), allys.max().item())) | ||||
|     elif dynamic_env.meta_info['task'] == 'classification': | ||||
|         allxs = torch.cat(allxs) | ||||
|         print("x[0] - min={:.3f}, max={:.3f}".format(allxs[:,0].min().item(), allxs[:,0].max().item())) | ||||
|         print("x[1] - min={:.3f}, max={:.3f}".format(allxs[:,1].min().item(), allxs[:,1].max().item())) | ||||
|     else: | ||||
|         raise ValueError("Unknown task".format(dynamic_env.meta_info['task'])) | ||||
|  | ||||
|     for idx, (timestamp, (allx, ally)) in enumerate(tqdm(dynamic_env, ncols=50)): | ||||
|         dpi, width, height = 30, 1800, 1400 | ||||
|         figsize = width / float(dpi), height / float(dpi) | ||||
| @@ -202,19 +211,29 @@ def visualize_env(save_dir, version): | ||||
|         fig = plt.figure(figsize=figsize) | ||||
|  | ||||
|         cur_ax = fig.add_subplot(1, 1, 1) | ||||
|         allx, ally = allx[:, 0].numpy(), ally[:, 0].numpy() | ||||
|         plot_scatter(cur_ax, allx, ally, "k", 0.99, 15, "timestamp={:05d}".format(idx)) | ||||
|         if dynamic_env.meta_info['task'] == 'regression': | ||||
|             allx, ally = allx[:, 0].numpy(), ally[:, 0].numpy() | ||||
|             plot_scatter(cur_ax, allx, ally, "k", 0.99, (15, 1.5), "timestamp={:05d}".format(idx)) | ||||
|             cur_ax.set_xlim(round(allxs.min().item(), 1), round(allxs.max().item(), 1)) | ||||
|             cur_ax.set_ylim(round(allys.min().item(), 1), round(allys.max().item(), 1)) | ||||
|         elif dynamic_env.meta_info['task'] == 'classification': | ||||
|             positive, negative = ally == 1, ally == 0 | ||||
|             # plot_scatter(cur_ax, [1], [1], "k", 0.1, 1, "timestamp={:05d}".format(idx)) | ||||
|             plot_scatter(cur_ax, allx[positive,0], allx[positive,1], "r", 0.99, (20, 10), "positive") | ||||
|             plot_scatter(cur_ax, allx[negative,0], allx[negative,1], "g", 0.99, (20, 10), "negative") | ||||
|             cur_ax.set_xlim(round(allxs[:,0].min().item(), 1), round(allxs[:,0].max().item(), 1)) | ||||
|             cur_ax.set_ylim(round(allxs[:,1].min().item(), 1), round(allxs[:,1].max().item(), 1)) | ||||
|         else: | ||||
|             raise ValueError("Unknown task".format(dynamic_env.meta_info['task'])) | ||||
|  | ||||
|         cur_ax.set_xlabel("X", fontsize=LabelSize) | ||||
|         cur_ax.set_ylabel("Y", rotation=0, fontsize=LabelSize) | ||||
|         for tick in cur_ax.xaxis.get_major_ticks(): | ||||
|             tick.label.set_fontsize(LabelSize - font_gap) | ||||
|             tick.label.set_rotation(10) | ||||
|                 tick.label.set_fontsize(LabelSize - font_gap) | ||||
|                 tick.label.set_rotation(10) | ||||
|         for tick in cur_ax.yaxis.get_major_ticks(): | ||||
|             tick.label.set_fontsize(LabelSize - font_gap) | ||||
|         cur_ax.set_xlim(round(allxs.min().item(), 1), round(allxs.max().item(), 1)) | ||||
|         cur_ax.set_ylim(round(allys.min().item(), 1), round(allys.max().item(), 1)) | ||||
|         cur_ax.legend(loc=1, fontsize=LegendFontsize) | ||||
|  | ||||
|                 tick.label.set_fontsize(LabelSize - font_gap) | ||||
|         cur_ax.legend(loc=1, fontsize=LegendFontsize)    | ||||
|         pdf_save_path = ( | ||||
|             save_dir | ||||
|             / "pdf-{:}".format(version) | ||||
| @@ -237,7 +256,7 @@ def visualize_env(save_dir, version): | ||||
|     os.system("{:} {xdir}/env-{ver}.webm".format(base_cmd, xdir=save_dir, ver=version)) | ||||
|  | ||||
|  | ||||
| def compare_algs(save_dir, version, alg_dir="./outputs/lfna-synthetic"): | ||||
| def compare_algs(save_dir, version, alg_dir="./outputs/GeMOSA-synthetic"): | ||||
|     save_dir = Path(str(save_dir)) | ||||
|     for substr in ("pdf", "png"): | ||||
|         sub_save_dir = save_dir / substr | ||||
|   | ||||
		Reference in New Issue
	
	Block a user