Update codes
This commit is contained in:
parent
e1818694a4
commit
8358d71cdf
@ -24,10 +24,7 @@ if str(lib_dir) not in sys.path:
|
|||||||
sys.path.insert(0, str(lib_dir))
|
sys.path.insert(0, str(lib_dir))
|
||||||
|
|
||||||
|
|
||||||
from datasets import ConstantGenerator, SinGenerator, SyntheticDEnv
|
|
||||||
from datasets import DynamicQuadraticFunc
|
|
||||||
from datasets.synthetic_example import create_example_v1
|
from datasets.synthetic_example import create_example_v1
|
||||||
|
|
||||||
from utils.temp_sync import optimize_fn, evaluate_fn
|
from utils.temp_sync import optimize_fn, evaluate_fn
|
||||||
|
|
||||||
|
|
||||||
@ -61,43 +58,72 @@ def draw_multi_fig(save_dir, timestamp, scatter_list, wh, fig_title=None):
|
|||||||
tick.label.set_rotation(10)
|
tick.label.set_rotation(10)
|
||||||
for tick in cur_ax.yaxis.get_major_ticks():
|
for tick in cur_ax.yaxis.get_major_ticks():
|
||||||
tick.label.set_fontsize(LabelSize - font_gap)
|
tick.label.set_fontsize(LabelSize - font_gap)
|
||||||
plt.legend(loc=1, fontsize=LegendFontsize)
|
cur_ax.legend(loc=1, fontsize=LegendFontsize)
|
||||||
fig.savefig(str(save_path) + ".pdf", dpi=dpi, bbox_inches="tight", format="pdf")
|
fig.savefig(str(save_path) + ".pdf", dpi=dpi, bbox_inches="tight", format="pdf")
|
||||||
fig.savefig(str(save_path) + ".png", dpi=dpi, bbox_inches="tight", format="png")
|
fig.savefig(str(save_path) + ".png", dpi=dpi, bbox_inches="tight", format="png")
|
||||||
plt.close("all")
|
plt.close("all")
|
||||||
|
|
||||||
|
|
||||||
|
def find_min(cur, others):
|
||||||
|
if cur is None:
|
||||||
|
return float(others.min())
|
||||||
|
else:
|
||||||
|
return float(min(cur, others.min()))
|
||||||
|
|
||||||
|
|
||||||
|
def find_max(cur, others):
|
||||||
|
if cur is None:
|
||||||
|
return float(others.max())
|
||||||
|
else:
|
||||||
|
return float(max(cur, others.max()))
|
||||||
|
|
||||||
|
|
||||||
def compare_cl(save_dir):
|
def compare_cl(save_dir):
|
||||||
save_dir = Path(str(save_dir))
|
save_dir = Path(str(save_dir))
|
||||||
save_dir.mkdir(parents=True, exist_ok=True)
|
save_dir.mkdir(parents=True, exist_ok=True)
|
||||||
dynamic_env, function = create_example_v1(100, num_per_task=1000)
|
dynamic_env, function = create_example_v1(
|
||||||
|
timestamp_config=dict(num=200, min_timestamp=-1, max_timestamp=1.0),
|
||||||
|
num_per_task=1000,
|
||||||
|
)
|
||||||
|
|
||||||
additional_xaxis = np.arange(-6, 6, 0.2)
|
|
||||||
models = dict()
|
models = dict()
|
||||||
|
|
||||||
cl_function = copy.deepcopy(function)
|
cl_function = copy.deepcopy(function)
|
||||||
cl_function.set_timestamp(0)
|
cl_function.set_timestamp(0)
|
||||||
cl_xaxis_all = None
|
cl_xaxis_min = None
|
||||||
|
cl_xaxis_max = None
|
||||||
|
|
||||||
|
all_data = OrderedDict()
|
||||||
|
|
||||||
for idx, (timestamp, dataset) in enumerate(tqdm(dynamic_env, ncols=50)):
|
for idx, (timestamp, dataset) in enumerate(tqdm(dynamic_env, ncols=50)):
|
||||||
xaxis_all = dataset[:, 0].numpy()
|
xaxis_all = dataset[:, 0].numpy()
|
||||||
# xaxis_all = np.concatenate((additional_xaxis, xaxis_all))
|
current_data = dict()
|
||||||
# compute the ground truth
|
|
||||||
function.set_timestamp(timestamp)
|
function.set_timestamp(timestamp)
|
||||||
yaxis_all = function.noise_call(xaxis_all)
|
yaxis_all = function.noise_call(xaxis_all)
|
||||||
|
current_data["lfna_xaxis_all"] = xaxis_all
|
||||||
|
current_data["lfna_yaxis_all"] = yaxis_all
|
||||||
|
|
||||||
# create CL data
|
import pdb
|
||||||
if cl_xaxis_all is None:
|
|
||||||
cl_xaxis_all = xaxis_all
|
|
||||||
else:
|
|
||||||
cl_xaxis_all = np.concatenate((cl_xaxis_all, xaxis_all + timestamp * 0.2))
|
|
||||||
cl_yaxis_all = cl_function(cl_xaxis_all)
|
|
||||||
|
|
||||||
|
pdb.set_trace()
|
||||||
|
|
||||||
|
# compute cl-min
|
||||||
|
cl_xaxis_min = find_min(cl_xaxis_min, xaxis_all)
|
||||||
|
cl_xaxis_max = find_max(cl_xaxis_max, xaxis_all) + idx * 0.1
|
||||||
|
cl_xaxis_all = np.arange(cl_xaxis_min, cl_xaxis_max, step=0.05)
|
||||||
|
|
||||||
|
cl_yaxis_all = cl_function.noise_call(cl_xaxis_all)
|
||||||
|
current_data["cl_xaxis_all"] = cl_xaxis_all
|
||||||
|
current_data["cl_yaxis_all"] = cl_yaxis_all
|
||||||
|
all_data[timestamp] = current_data
|
||||||
|
|
||||||
|
for idx, (timestamp, xdata) in enumerate(tqdm(all_data.items(), ncols=50)):
|
||||||
scatter_list = []
|
scatter_list = []
|
||||||
scatter_list.append(
|
scatter_list.append(
|
||||||
{
|
{
|
||||||
"xaxis": xaxis_all,
|
"xaxis": xdata["lfna_xaxis_all"],
|
||||||
"yaxis": yaxis_all,
|
"yaxis": xdata["lfna_yaxis_all"],
|
||||||
"color": "k",
|
"color": "k",
|
||||||
"s": 10,
|
"s": 10,
|
||||||
"alpha": 0.99,
|
"alpha": 0.99,
|
||||||
@ -107,6 +133,9 @@ def compare_cl(save_dir):
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cl_xaxis_all = current_data["cl_xaxis_all"]
|
||||||
|
cl_yaxis_all = current_data["cl_yaxis_all"]
|
||||||
|
|
||||||
scatter_list.append(
|
scatter_list.append(
|
||||||
{
|
{
|
||||||
"xaxis": cl_xaxis_all,
|
"xaxis": cl_xaxis_all,
|
||||||
@ -121,15 +150,21 @@ def compare_cl(save_dir):
|
|||||||
)
|
)
|
||||||
|
|
||||||
draw_multi_fig(
|
draw_multi_fig(
|
||||||
save_dir, timestamp, scatter_list,
|
save_dir,
|
||||||
wh=(2000, 1300), fig_title="Timestamp={:03d}".format(timestamp)
|
timestamp,
|
||||||
|
scatter_list,
|
||||||
|
wh=(2000, 1300),
|
||||||
|
fig_title="Timestamp={:03d}".format(timestamp),
|
||||||
)
|
)
|
||||||
print("Save all figures into {:}".format(save_dir))
|
print("Save all figures into {:}".format(save_dir))
|
||||||
save_dir = save_dir.resolve()
|
save_dir = save_dir.resolve()
|
||||||
cmd = "ffmpeg -y -i {xdir}/%04d.png -pix_fmt yuv420p -vf fps=2 -vf scale=2000:1300 -vb 5000k {xdir}/vis.mp4".format(
|
base_cmd = (
|
||||||
xdir=save_dir
|
"ffmpeg -y -i {xdir}/%04d.png -vf fps=2 -vf scale=2000:1300 -vb 5000k".format(
|
||||||
|
xdir=save_dir
|
||||||
|
)
|
||||||
)
|
)
|
||||||
os.system(cmd)
|
os.system("{:} -pix_fmt yuv420p {xdir}/vis.mp4".format(base_cmd, xdir=save_dir))
|
||||||
|
os.system("{:} -c:a libvorbis {xdir}/vis.webm".format(base_cmd, xdir=save_dir))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
@ -19,7 +19,7 @@ class SyntheticDEnv(data.Dataset):
|
|||||||
mean_functors: List[data.Dataset],
|
mean_functors: List[data.Dataset],
|
||||||
cov_functors: List[List[data.Dataset]],
|
cov_functors: List[List[data.Dataset]],
|
||||||
num_per_task: int = 5000,
|
num_per_task: int = 5000,
|
||||||
time_stamp_config: Optional[Dict] = None,
|
timestamp_config: Optional[Dict] = None,
|
||||||
mode: Optional[str] = None,
|
mode: Optional[str] = None,
|
||||||
):
|
):
|
||||||
self._ndim = len(mean_functors)
|
self._ndim = len(mean_functors)
|
||||||
@ -31,12 +31,12 @@ class SyntheticDEnv(data.Dataset):
|
|||||||
cov_functor
|
cov_functor
|
||||||
), "length does not match {:} vs. {:}".format(self._ndim, len(cov_functor))
|
), "length does not match {:} vs. {:}".format(self._ndim, len(cov_functor))
|
||||||
self._num_per_task = num_per_task
|
self._num_per_task = num_per_task
|
||||||
if time_stamp_config is None:
|
if timestamp_config is None:
|
||||||
time_stamp_config = dict(mode=mode)
|
timestamp_config = dict(mode=mode)
|
||||||
else:
|
else:
|
||||||
time_stamp_config["mode"] = mode
|
timestamp_config["mode"] = mode
|
||||||
|
|
||||||
self._timestamp_generator = TimeStamp(**time_stamp_config)
|
self._timestamp_generator = TimeStamp(**timestamp_config)
|
||||||
|
|
||||||
self._mean_functors = mean_functors
|
self._mean_functors = mean_functors
|
||||||
self._cov_functors = cov_functors
|
self._cov_functors = cov_functors
|
||||||
|
@ -2,21 +2,23 @@
|
|||||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 #
|
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 #
|
||||||
#####################################################
|
#####################################################
|
||||||
|
|
||||||
from .math_base_funcs import DynamicQuadraticFunc
|
from .math_adv_funcs import DynamicQuadraticFunc
|
||||||
from .math_adv_funcs import ConstantFunc, ComposedSinFunc
|
from .math_adv_funcs import ConstantFunc, ComposedSinFunc
|
||||||
from .synthetic_env import SyntheticDEnv
|
from .synthetic_env import SyntheticDEnv
|
||||||
|
|
||||||
|
|
||||||
def create_example_v1(timestamps=50, num_per_task=5000):
|
def create_example_v1(
|
||||||
|
timestamp_config=dict(num=100, min_timestamp=0.0, max_timestamp=1.0),
|
||||||
|
num_per_task=5000,
|
||||||
|
):
|
||||||
mean_generator = ComposedSinFunc()
|
mean_generator = ComposedSinFunc()
|
||||||
std_generator = ComposedSinFunc(min_amplitude=0.5, max_amplitude=0.5)
|
std_generator = ComposedSinFunc(min_amplitude=0.5, max_amplitude=0.5)
|
||||||
std_generator.set_transform(lambda x: x + 1)
|
|
||||||
|
|
||||||
dynamic_env = SyntheticDEnv(
|
dynamic_env = SyntheticDEnv(
|
||||||
[mean_generator],
|
[mean_generator],
|
||||||
[[std_generator]],
|
[[std_generator]],
|
||||||
num_per_task=num_per_task,
|
num_per_task=num_per_task,
|
||||||
time_stamp_config=dict(num=timestamps),
|
timestamp_config=timestamp_config,
|
||||||
)
|
)
|
||||||
|
|
||||||
function = DynamicQuadraticFunc()
|
function = DynamicQuadraticFunc()
|
||||||
|
8
scripts/black.sh
Normal file
8
scripts/black.sh
Normal file
@ -0,0 +1,8 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
# bash ./scripts/black.sh
|
||||||
|
|
||||||
|
black ./tests/
|
||||||
|
black ./lib/datasets
|
||||||
|
black ./lib/xlayers
|
||||||
|
black ./exps/LFNA
|
||||||
|
black ./exps/trading
|
Loading…
Reference in New Issue
Block a user