Update the sync data v1
This commit is contained in:
		| @@ -222,7 +222,7 @@ def pretrain_v2(base_model, meta_model, criterion, xenv, args, logger): | ||||
|  | ||||
|  | ||||
| def main(args): | ||||
|     logger, env_info, model_kwargs = lfna_setup(args) | ||||
|     logger, model_kwargs = lfna_setup(args) | ||||
|     train_env = get_synthetic_env(mode="train", version=args.env_version) | ||||
|     valid_env = get_synthetic_env(mode="valid", version=args.env_version) | ||||
|     all_env = get_synthetic_env(mode=None, version=args.env_version) | ||||
|   | ||||
| @@ -11,33 +11,6 @@ from xautodl.datasets.synthetic_core import get_synthetic_env | ||||
| def lfna_setup(args): | ||||
|     prepare_seed(args.rand_seed) | ||||
|     logger = prepare_logger(args) | ||||
|  | ||||
|     cache_path = ( | ||||
|         logger.path(None) / ".." / "env-{:}-info.pth".format(args.env_version) | ||||
|     ).resolve() | ||||
|     if cache_path.exists(): | ||||
|         env_info = torch.load(cache_path) | ||||
|     else: | ||||
|         env_info = dict() | ||||
|         dynamic_env = get_synthetic_env(version=args.env_version) | ||||
|         env_info["total"] = len(dynamic_env) | ||||
|         for idx, (timestamp, (_allx, _ally)) in enumerate(tqdm(dynamic_env)): | ||||
|             env_info["{:}-timestamp".format(idx)] = timestamp | ||||
|             env_info["{:}-x".format(idx)] = _allx | ||||
|             env_info["{:}-y".format(idx)] = _ally | ||||
|         env_info["dynamic_env"] = dynamic_env | ||||
|         torch.save(env_info, cache_path) | ||||
|  | ||||
|     """ | ||||
|     model_kwargs = dict( | ||||
|         config=dict(model_type="simple_mlp"), | ||||
|         input_dim=1, | ||||
|         output_dim=1, | ||||
|         hidden_dim=args.hidden_dim, | ||||
|         act_cls="leaky_relu", | ||||
|         norm_cls="identity", | ||||
|     ) | ||||
|     """ | ||||
|     model_kwargs = dict( | ||||
|         config=dict(model_type="norm_mlp"), | ||||
|         input_dim=1, | ||||
| @@ -46,7 +19,7 @@ def lfna_setup(args): | ||||
|         act_cls="gelu", | ||||
|         norm_cls="layer_norm_1d", | ||||
|     ) | ||||
|     return logger, env_info, model_kwargs | ||||
|     return logger, model_kwargs | ||||
|  | ||||
|  | ||||
| def train_model(model, dataset, lr, epochs): | ||||
|   | ||||
| @@ -20,14 +20,13 @@ matplotlib.use("agg") | ||||
| import matplotlib.pyplot as plt | ||||
| import matplotlib.ticker as ticker | ||||
|  | ||||
| lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve() | ||||
| lib_dir = (Path(__file__).parent / ".." / "..").resolve() | ||||
| if str(lib_dir) not in sys.path: | ||||
|     sys.path.insert(0, str(lib_dir)) | ||||
|  | ||||
| from models.xcore import get_model | ||||
| from datasets.synthetic_core import get_synthetic_env | ||||
| from utils.temp_sync import optimize_fn, evaluate_fn | ||||
| from procedures.metric_utils import MSEMetric | ||||
| from xautodl.models.xcore import get_model | ||||
| from xautodl.datasets.synthetic_core import get_synthetic_env | ||||
| from xautodl.procedures.metric_utils import MSEMetric | ||||
|  | ||||
|  | ||||
| def plot_scatter(cur_ax, xs, ys, color, alpha, linewidths, label=None): | ||||
| @@ -181,10 +180,17 @@ def compare_cl(save_dir): | ||||
|  | ||||
| def visualize_env(save_dir, version): | ||||
|     save_dir = Path(str(save_dir)) | ||||
|     save_dir.mkdir(parents=True, exist_ok=True) | ||||
|     for substr in ("pdf", "png"): | ||||
|         sub_save_dir = save_dir / substr | ||||
|         sub_save_dir.mkdir(parents=True, exist_ok=True) | ||||
|  | ||||
|     dynamic_env = get_synthetic_env(version=version) | ||||
|     min_t, max_t = dynamic_env.min_timestamp, dynamic_env.max_timestamp | ||||
|     # min_t, max_t = dynamic_env.min_timestamp, dynamic_env.max_timestamp | ||||
|     allxs, allys = [], [] | ||||
|     for idx, (timestamp, (allx, ally)) in enumerate(tqdm(dynamic_env, ncols=50)): | ||||
|         allxs.append(allx) | ||||
|         allys.append(ally) | ||||
|     allxs, allys = torch.cat(allxs).view(-1), torch.cat(allys).view(-1) | ||||
|     for idx, (timestamp, (allx, ally)) in enumerate(tqdm(dynamic_env, ncols=50)): | ||||
|         dpi, width, height = 30, 1800, 1400 | ||||
|         figsize = width / float(dpi), height / float(dpi) | ||||
| @@ -201,21 +207,18 @@ def visualize_env(save_dir, version): | ||||
|             tick.label.set_rotation(10) | ||||
|         for tick in cur_ax.yaxis.get_major_ticks(): | ||||
|             tick.label.set_fontsize(LabelSize - font_gap) | ||||
|         if version == "v1": | ||||
|             cur_ax.set_xlim(-2, 2) | ||||
|             cur_ax.set_ylim(-8, 8) | ||||
|         elif version == "v2": | ||||
|             cur_ax.set_xlim(-10, 10) | ||||
|             cur_ax.set_ylim(-60, 60) | ||||
|         cur_ax.set_xlim(round(allxs.min().item(), 1), round(allxs.max().item(), 1)) | ||||
|         cur_ax.set_ylim(round(allys.min().item(), 1), round(allys.max().item(), 1)) | ||||
|         cur_ax.legend(loc=1, fontsize=LegendFontsize) | ||||
|  | ||||
|         save_path = save_dir / "v{:}-{:05d}".format(version, idx) | ||||
|         fig.savefig(str(save_path) + ".pdf", dpi=dpi, bbox_inches="tight", format="pdf") | ||||
|         fig.savefig(str(save_path) + ".png", dpi=dpi, bbox_inches="tight", format="png") | ||||
|         pdf_save_path = save_dir / "pdf" / "v{:}-{:05d}.pdf".format(version, idx) | ||||
|         fig.savefig(str(pdf_save_path), dpi=dpi, bbox_inches="tight", format="pdf") | ||||
|         png_save_path = save_dir / "png" / "v{:}-{:05d}.png".format(version, idx) | ||||
|         fig.savefig(str(png_save_path), dpi=dpi, bbox_inches="tight", format="png") | ||||
|         plt.close("all") | ||||
|     save_dir = save_dir.resolve() | ||||
|     base_cmd = "ffmpeg -y -i {xdir}/v{version}-%05d.png -vf scale=1800:1400 -pix_fmt yuv420p -vb 5000k".format( | ||||
|         xdir=save_dir, version=version | ||||
|         xdir=save_dir / "png", version=version | ||||
|     ) | ||||
|     print(base_cmd) | ||||
|     os.system("{:} {xdir}/env-{ver}.mp4".format(base_cmd, xdir=save_dir, ver=version)) | ||||
| @@ -371,7 +374,7 @@ if __name__ == "__main__": | ||||
|     ) | ||||
|     args = parser.parse_args() | ||||
|  | ||||
|     # visualize_env(os.path.join(args.save_dir, "vis-env"), "v1") | ||||
|     visualize_env(os.path.join(args.save_dir, "vis-env"), "v1") | ||||
|     # visualize_env(os.path.join(args.save_dir, "vis-env"), "v2") | ||||
|     compare_algs(os.path.join(args.save_dir, "compare-alg"), args.env_version) | ||||
|     # compare_algs(os.path.join(args.save_dir, "compare-alg"), args.env_version) | ||||
|     # compare_cl(os.path.join(args.save_dir, "compare-cl")) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user