From 198077905373ebd3ab19ae64c151b85d1540ebdd Mon Sep 17 00:00:00 2001 From: D-X-Y <280835372@qq.com> Date: Sun, 25 Apr 2021 23:06:51 -0700 Subject: [PATCH] Move to LFNA --- .../baseline.py => LFNA/vis-synthetic.py} | 87 ++++++++++--------- lib/xlayers/super_activations.py | 5 +- lib/xlayers/super_core.py | 1 - tests/test_super_container.py | 1 + 4 files changed, 47 insertions(+), 47 deletions(-) rename exps/{synthetic/baseline.py => LFNA/vis-synthetic.py} (60%) diff --git a/exps/synthetic/baseline.py b/exps/LFNA/vis-synthetic.py similarity index 60% rename from exps/synthetic/baseline.py rename to exps/LFNA/vis-synthetic.py index 3f4f46c..b9a5f18 100644 --- a/exps/synthetic/baseline.py +++ b/exps/LFNA/vis-synthetic.py @@ -1,7 +1,7 @@ ##################################################### # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.02 # ############################################################################ -# CUDA_VISIBLE_DEVICES=0 python exps/synthetic/baseline.py # +# CUDA_VISIBLE_DEVICES=0 python exps/LFNA/vis-synthetic.py # ############################################################################ import os, sys, copy, random import torch @@ -31,17 +31,19 @@ from datasets.synthetic_example import create_example_v1 from utils.temp_sync import optimize_fn, evaluate_fn -def draw_fig(save_dir, timestamp, scatter_list): +def draw_multi_fig(save_dir, timestamp, scatter_list, fig_title=None): save_path = save_dir / "{:04d}".format(timestamp) # print('Plot the figure at timestamp-{:} into {:}'.format(timestamp, save_path)) - dpi, width, height = 40, 1500, 1500 + dpi, width, height = 40, 2000, 1300 figsize = width / float(dpi), height / float(dpi) LabelSize, LegendFontsize, font_gap = 80, 80, 5 fig = plt.figure(figsize=figsize) + if fig_title is not None: + fig.suptitle(fig_title, fontsize=LegendFontsize) - cur_ax = fig.add_subplot(1, 1, 1) - for scatter_dict in scatter_list: + for idx, scatter_dict in enumerate(scatter_list): + cur_ax = fig.add_subplot(len(scatter_list), 1, idx + 1) cur_ax.scatter( scatter_dict["xaxis"], scatter_dict["yaxis"], @@ -50,15 +52,15 @@ def draw_fig(save_dir, timestamp, scatter_list): alpha=scatter_dict["alpha"], label=scatter_dict["label"], ) - cur_ax.set_xlabel("X", fontsize=LabelSize) - cur_ax.set_ylabel("f(X)", rotation=0, fontsize=LabelSize) - cur_ax.set_xlim(-6, 6) - cur_ax.set_ylim(-40, 40) - for tick in cur_ax.xaxis.get_major_ticks(): - tick.label.set_fontsize(LabelSize - font_gap) - tick.label.set_rotation(10) - for tick in cur_ax.yaxis.get_major_ticks(): - tick.label.set_fontsize(LabelSize - font_gap) + cur_ax.set_xlabel("X", fontsize=LabelSize) + cur_ax.set_ylabel("f(X)", rotation=0, fontsize=LabelSize) + cur_ax.set_xlim(scatter_dict["xlim"][0], scatter_dict["xlim"][1]) + cur_ax.set_ylim(scatter_dict["ylim"][0], scatter_dict["ylim"][1]) + for tick in cur_ax.xaxis.get_major_ticks(): + tick.label.set_fontsize(LabelSize - font_gap) + tick.label.set_rotation(10) + for tick in cur_ax.yaxis.get_major_ticks(): + tick.label.set_fontsize(LabelSize - font_gap) plt.legend(loc=1, fontsize=LegendFontsize) fig.savefig(str(save_path) + ".pdf", dpi=dpi, bbox_inches="tight", format="pdf") @@ -66,7 +68,7 @@ def draw_fig(save_dir, timestamp, scatter_list): plt.close("all") -def main(save_dir): +def compare_cl(save_dir): save_dir = Path(str(save_dir)) save_dir.mkdir(parents=True, exist_ok=True) dynamic_env, function = create_example_v1(100, num_per_task=1000) @@ -74,6 +76,10 @@ def main(save_dir): additional_xaxis = np.arange(-6, 6, 0.2) models = dict() + cl_function = copy.deepcopy(function) + cl_function.set_timestamp(0) + cl_xaxis_all = None + for idx, (timestamp, dataset) in enumerate(tqdm(dynamic_env, ncols=50)): xaxis_all = dataset[:, 0].numpy() # xaxis_all = np.concatenate((additional_xaxis, xaxis_all)) @@ -81,51 +87,46 @@ def main(save_dir): function.set_timestamp(timestamp) yaxis_all = function.noise_call(xaxis_all) - # split the dataset - indexes = list(range(xaxis_all.shape[0])) - random.shuffle(indexes) - train_indexes = indexes[: len(indexes) // 2] - valid_indexes = indexes[len(indexes) // 2 :] - train_xs, train_ys = xaxis_all[train_indexes], yaxis_all[train_indexes] - valid_xs, valid_ys = xaxis_all[valid_indexes], yaxis_all[valid_indexes] + # create CL data + if cl_xaxis_all is None: + cl_xaxis_all = xaxis_all + else: + cl_xaxis_all = np.concatenate((cl_xaxis_all, xaxis_all + timestamp * 0.2)) + cl_yaxis_all = cl_function(cl_xaxis_all) - model, loss_fn, train_loss = optimize_fn(train_xs, train_ys) - # model, loss_fn, train_loss = optimize_fn(xaxis_all, yaxis_all) - pred_valid_ys, valid_loss = evaluate_fn(model, valid_xs, valid_ys, loss_fn) - print( - "[{:03d}] T-{:03d}, train-loss={:.5f}, valid-loss={:.5f}".format( - idx, timestamp, train_loss, valid_loss - ) - ) - - # the first plot scatter_list = [] scatter_list.append( { - "xaxis": valid_xs, - "yaxis": valid_ys, + "xaxis": xaxis_all, + "yaxis": yaxis_all, "color": "k", "s": 10, "alpha": 0.99, - "label": "Timestamp={:02d}".format(timestamp), + "xlim": (-6, 6), + "ylim": (-40, 40), + "label": "LFNA", } ) scatter_list.append( { - "xaxis": valid_xs, - "yaxis": pred_valid_ys, + "xaxis": cl_xaxis_all, + "yaxis": cl_yaxis_all, "color": "r", "s": 10, - "alpha": 0.5, - "label": "MLP at now", + "xlim": (-6, 6 + timestamp * 0.2), + "ylim": (-200, 40), + "alpha": 0.99, + "label": "Continual Learning", } ) - draw_fig(save_dir, timestamp, scatter_list) + draw_multi_fig( + save_dir, timestamp, scatter_list, "Timestamp={:03d}".format(timestamp) + ) print("Save all figures into {:}".format(save_dir)) save_dir = save_dir.resolve() - cmd = "ffmpeg -y -i {xdir}/%04d.png -pix_fmt yuv420p -vf fps=2 -vf scale=1000:1000 -vb 5000k {xdir}/vis.mp4".format( + cmd = "ffmpeg -y -i {xdir}/%04d.png -pix_fmt yuv420p -vf fps=2 -vf scale=1500:1000 -vb 5000k {xdir}/vis.mp4".format( xdir=save_dir ) os.system(cmd) @@ -133,7 +134,7 @@ def main(save_dir): if __name__ == "__main__": - parser = argparse.ArgumentParser("Baseline") + parser = argparse.ArgumentParser("Visualize synthetic data.") parser.add_argument( "--save_dir", type=str, @@ -142,4 +143,4 @@ if __name__ == "__main__": ) args = parser.parse_args() - main(args.save_dir) + compare_cl(os.path.join(args.save_dir, "compare-cl")) diff --git a/lib/xlayers/super_activations.py b/lib/xlayers/super_activations.py index a7a1a0c..a0dac54 100644 --- a/lib/xlayers/super_activations.py +++ b/lib/xlayers/super_activations.py @@ -17,8 +17,7 @@ from .super_module import BoolSpaceType class SuperReLU(SuperModule): """Applies a the rectified linear unit function element-wise.""" - def __init__( - self, inplace=False) -> None: + def __init__(self, inplace=False) -> None: super(SuperReLU, self).__init__() self._inplace = inplace @@ -33,4 +32,4 @@ class SuperReLU(SuperModule): return F.relu(input, inplace=self._inplace) def extra_repr(self) -> str: - return 'inplace=True' if self._inplace else '' + return "inplace=True" if self._inplace else "" diff --git a/lib/xlayers/super_core.py b/lib/xlayers/super_core.py index e5d55d4..11d3fd2 100644 --- a/lib/xlayers/super_core.py +++ b/lib/xlayers/super_core.py @@ -18,4 +18,3 @@ from .super_activations import SuperReLU from .super_trade_stem import SuperAlphaEBDv1 from .super_positional_embedding import SuperPositionalEncoder - diff --git a/tests/test_super_container.py b/tests/test_super_container.py index 73fde6e..affa107 100644 --- a/tests/test_super_container.py +++ b/tests/test_super_container.py @@ -79,6 +79,7 @@ def test_super_sequential_v1(): super_core.SuperSimpleNorm(1, 1), torch.nn.ReLU(), super_core.SuperLinear(10, 10), + super_core.SuperReLU() ) inputs = torch.rand(10, 10) print(model)