Move to LFNA
This commit is contained in:
parent
89a5faabc3
commit
1980779053
@ -1,7 +1,7 @@
|
|||||||
#####################################################
|
#####################################################
|
||||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.02 #
|
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.02 #
|
||||||
############################################################################
|
############################################################################
|
||||||
# CUDA_VISIBLE_DEVICES=0 python exps/synthetic/baseline.py #
|
# CUDA_VISIBLE_DEVICES=0 python exps/LFNA/vis-synthetic.py #
|
||||||
############################################################################
|
############################################################################
|
||||||
import os, sys, copy, random
|
import os, sys, copy, random
|
||||||
import torch
|
import torch
|
||||||
@ -31,17 +31,19 @@ from datasets.synthetic_example import create_example_v1
|
|||||||
from utils.temp_sync import optimize_fn, evaluate_fn
|
from utils.temp_sync import optimize_fn, evaluate_fn
|
||||||
|
|
||||||
|
|
||||||
def draw_fig(save_dir, timestamp, scatter_list):
|
def draw_multi_fig(save_dir, timestamp, scatter_list, fig_title=None):
|
||||||
save_path = save_dir / "{:04d}".format(timestamp)
|
save_path = save_dir / "{:04d}".format(timestamp)
|
||||||
# print('Plot the figure at timestamp-{:} into {:}'.format(timestamp, save_path))
|
# print('Plot the figure at timestamp-{:} into {:}'.format(timestamp, save_path))
|
||||||
dpi, width, height = 40, 1500, 1500
|
dpi, width, height = 40, 2000, 1300
|
||||||
figsize = width / float(dpi), height / float(dpi)
|
figsize = width / float(dpi), height / float(dpi)
|
||||||
LabelSize, LegendFontsize, font_gap = 80, 80, 5
|
LabelSize, LegendFontsize, font_gap = 80, 80, 5
|
||||||
|
|
||||||
fig = plt.figure(figsize=figsize)
|
fig = plt.figure(figsize=figsize)
|
||||||
|
if fig_title is not None:
|
||||||
|
fig.suptitle(fig_title, fontsize=LegendFontsize)
|
||||||
|
|
||||||
cur_ax = fig.add_subplot(1, 1, 1)
|
for idx, scatter_dict in enumerate(scatter_list):
|
||||||
for scatter_dict in scatter_list:
|
cur_ax = fig.add_subplot(len(scatter_list), 1, idx + 1)
|
||||||
cur_ax.scatter(
|
cur_ax.scatter(
|
||||||
scatter_dict["xaxis"],
|
scatter_dict["xaxis"],
|
||||||
scatter_dict["yaxis"],
|
scatter_dict["yaxis"],
|
||||||
@ -50,15 +52,15 @@ def draw_fig(save_dir, timestamp, scatter_list):
|
|||||||
alpha=scatter_dict["alpha"],
|
alpha=scatter_dict["alpha"],
|
||||||
label=scatter_dict["label"],
|
label=scatter_dict["label"],
|
||||||
)
|
)
|
||||||
cur_ax.set_xlabel("X", fontsize=LabelSize)
|
cur_ax.set_xlabel("X", fontsize=LabelSize)
|
||||||
cur_ax.set_ylabel("f(X)", rotation=0, fontsize=LabelSize)
|
cur_ax.set_ylabel("f(X)", rotation=0, fontsize=LabelSize)
|
||||||
cur_ax.set_xlim(-6, 6)
|
cur_ax.set_xlim(scatter_dict["xlim"][0], scatter_dict["xlim"][1])
|
||||||
cur_ax.set_ylim(-40, 40)
|
cur_ax.set_ylim(scatter_dict["ylim"][0], scatter_dict["ylim"][1])
|
||||||
for tick in cur_ax.xaxis.get_major_ticks():
|
for tick in cur_ax.xaxis.get_major_ticks():
|
||||||
tick.label.set_fontsize(LabelSize - font_gap)
|
tick.label.set_fontsize(LabelSize - font_gap)
|
||||||
tick.label.set_rotation(10)
|
tick.label.set_rotation(10)
|
||||||
for tick in cur_ax.yaxis.get_major_ticks():
|
for tick in cur_ax.yaxis.get_major_ticks():
|
||||||
tick.label.set_fontsize(LabelSize - font_gap)
|
tick.label.set_fontsize(LabelSize - font_gap)
|
||||||
|
|
||||||
plt.legend(loc=1, fontsize=LegendFontsize)
|
plt.legend(loc=1, fontsize=LegendFontsize)
|
||||||
fig.savefig(str(save_path) + ".pdf", dpi=dpi, bbox_inches="tight", format="pdf")
|
fig.savefig(str(save_path) + ".pdf", dpi=dpi, bbox_inches="tight", format="pdf")
|
||||||
@ -66,7 +68,7 @@ def draw_fig(save_dir, timestamp, scatter_list):
|
|||||||
plt.close("all")
|
plt.close("all")
|
||||||
|
|
||||||
|
|
||||||
def main(save_dir):
|
def compare_cl(save_dir):
|
||||||
save_dir = Path(str(save_dir))
|
save_dir = Path(str(save_dir))
|
||||||
save_dir.mkdir(parents=True, exist_ok=True)
|
save_dir.mkdir(parents=True, exist_ok=True)
|
||||||
dynamic_env, function = create_example_v1(100, num_per_task=1000)
|
dynamic_env, function = create_example_v1(100, num_per_task=1000)
|
||||||
@ -74,6 +76,10 @@ def main(save_dir):
|
|||||||
additional_xaxis = np.arange(-6, 6, 0.2)
|
additional_xaxis = np.arange(-6, 6, 0.2)
|
||||||
models = dict()
|
models = dict()
|
||||||
|
|
||||||
|
cl_function = copy.deepcopy(function)
|
||||||
|
cl_function.set_timestamp(0)
|
||||||
|
cl_xaxis_all = None
|
||||||
|
|
||||||
for idx, (timestamp, dataset) in enumerate(tqdm(dynamic_env, ncols=50)):
|
for idx, (timestamp, dataset) in enumerate(tqdm(dynamic_env, ncols=50)):
|
||||||
xaxis_all = dataset[:, 0].numpy()
|
xaxis_all = dataset[:, 0].numpy()
|
||||||
# xaxis_all = np.concatenate((additional_xaxis, xaxis_all))
|
# xaxis_all = np.concatenate((additional_xaxis, xaxis_all))
|
||||||
@ -81,51 +87,46 @@ def main(save_dir):
|
|||||||
function.set_timestamp(timestamp)
|
function.set_timestamp(timestamp)
|
||||||
yaxis_all = function.noise_call(xaxis_all)
|
yaxis_all = function.noise_call(xaxis_all)
|
||||||
|
|
||||||
# split the dataset
|
# create CL data
|
||||||
indexes = list(range(xaxis_all.shape[0]))
|
if cl_xaxis_all is None:
|
||||||
random.shuffle(indexes)
|
cl_xaxis_all = xaxis_all
|
||||||
train_indexes = indexes[: len(indexes) // 2]
|
else:
|
||||||
valid_indexes = indexes[len(indexes) // 2 :]
|
cl_xaxis_all = np.concatenate((cl_xaxis_all, xaxis_all + timestamp * 0.2))
|
||||||
train_xs, train_ys = xaxis_all[train_indexes], yaxis_all[train_indexes]
|
cl_yaxis_all = cl_function(cl_xaxis_all)
|
||||||
valid_xs, valid_ys = xaxis_all[valid_indexes], yaxis_all[valid_indexes]
|
|
||||||
|
|
||||||
model, loss_fn, train_loss = optimize_fn(train_xs, train_ys)
|
|
||||||
# model, loss_fn, train_loss = optimize_fn(xaxis_all, yaxis_all)
|
|
||||||
pred_valid_ys, valid_loss = evaluate_fn(model, valid_xs, valid_ys, loss_fn)
|
|
||||||
print(
|
|
||||||
"[{:03d}] T-{:03d}, train-loss={:.5f}, valid-loss={:.5f}".format(
|
|
||||||
idx, timestamp, train_loss, valid_loss
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
# the first plot
|
|
||||||
scatter_list = []
|
scatter_list = []
|
||||||
scatter_list.append(
|
scatter_list.append(
|
||||||
{
|
{
|
||||||
"xaxis": valid_xs,
|
"xaxis": xaxis_all,
|
||||||
"yaxis": valid_ys,
|
"yaxis": yaxis_all,
|
||||||
"color": "k",
|
"color": "k",
|
||||||
"s": 10,
|
"s": 10,
|
||||||
"alpha": 0.99,
|
"alpha": 0.99,
|
||||||
"label": "Timestamp={:02d}".format(timestamp),
|
"xlim": (-6, 6),
|
||||||
|
"ylim": (-40, 40),
|
||||||
|
"label": "LFNA",
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
scatter_list.append(
|
scatter_list.append(
|
||||||
{
|
{
|
||||||
"xaxis": valid_xs,
|
"xaxis": cl_xaxis_all,
|
||||||
"yaxis": pred_valid_ys,
|
"yaxis": cl_yaxis_all,
|
||||||
"color": "r",
|
"color": "r",
|
||||||
"s": 10,
|
"s": 10,
|
||||||
"alpha": 0.5,
|
"xlim": (-6, 6 + timestamp * 0.2),
|
||||||
"label": "MLP at now",
|
"ylim": (-200, 40),
|
||||||
|
"alpha": 0.99,
|
||||||
|
"label": "Continual Learning",
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
draw_fig(save_dir, timestamp, scatter_list)
|
draw_multi_fig(
|
||||||
|
save_dir, timestamp, scatter_list, "Timestamp={:03d}".format(timestamp)
|
||||||
|
)
|
||||||
print("Save all figures into {:}".format(save_dir))
|
print("Save all figures into {:}".format(save_dir))
|
||||||
save_dir = save_dir.resolve()
|
save_dir = save_dir.resolve()
|
||||||
cmd = "ffmpeg -y -i {xdir}/%04d.png -pix_fmt yuv420p -vf fps=2 -vf scale=1000:1000 -vb 5000k {xdir}/vis.mp4".format(
|
cmd = "ffmpeg -y -i {xdir}/%04d.png -pix_fmt yuv420p -vf fps=2 -vf scale=1500:1000 -vb 5000k {xdir}/vis.mp4".format(
|
||||||
xdir=save_dir
|
xdir=save_dir
|
||||||
)
|
)
|
||||||
os.system(cmd)
|
os.system(cmd)
|
||||||
@ -133,7 +134,7 @@ def main(save_dir):
|
|||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
||||||
parser = argparse.ArgumentParser("Baseline")
|
parser = argparse.ArgumentParser("Visualize synthetic data.")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--save_dir",
|
"--save_dir",
|
||||||
type=str,
|
type=str,
|
||||||
@ -142,4 +143,4 @@ if __name__ == "__main__":
|
|||||||
)
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
main(args.save_dir)
|
compare_cl(os.path.join(args.save_dir, "compare-cl"))
|
@ -17,8 +17,7 @@ from .super_module import BoolSpaceType
|
|||||||
class SuperReLU(SuperModule):
|
class SuperReLU(SuperModule):
|
||||||
"""Applies a the rectified linear unit function element-wise."""
|
"""Applies a the rectified linear unit function element-wise."""
|
||||||
|
|
||||||
def __init__(
|
def __init__(self, inplace=False) -> None:
|
||||||
self, inplace=False) -> None:
|
|
||||||
super(SuperReLU, self).__init__()
|
super(SuperReLU, self).__init__()
|
||||||
self._inplace = inplace
|
self._inplace = inplace
|
||||||
|
|
||||||
@ -33,4 +32,4 @@ class SuperReLU(SuperModule):
|
|||||||
return F.relu(input, inplace=self._inplace)
|
return F.relu(input, inplace=self._inplace)
|
||||||
|
|
||||||
def extra_repr(self) -> str:
|
def extra_repr(self) -> str:
|
||||||
return 'inplace=True' if self._inplace else ''
|
return "inplace=True" if self._inplace else ""
|
||||||
|
@ -18,4 +18,3 @@ from .super_activations import SuperReLU
|
|||||||
|
|
||||||
from .super_trade_stem import SuperAlphaEBDv1
|
from .super_trade_stem import SuperAlphaEBDv1
|
||||||
from .super_positional_embedding import SuperPositionalEncoder
|
from .super_positional_embedding import SuperPositionalEncoder
|
||||||
|
|
||||||
|
@ -79,6 +79,7 @@ def test_super_sequential_v1():
|
|||||||
super_core.SuperSimpleNorm(1, 1),
|
super_core.SuperSimpleNorm(1, 1),
|
||||||
torch.nn.ReLU(),
|
torch.nn.ReLU(),
|
||||||
super_core.SuperLinear(10, 10),
|
super_core.SuperLinear(10, 10),
|
||||||
|
super_core.SuperReLU()
|
||||||
)
|
)
|
||||||
inputs = torch.rand(10, 10)
|
inputs = torch.rand(10, 10)
|
||||||
print(model)
|
print(model)
|
||||||
|
Loading…
Reference in New Issue
Block a user