Update codes
This commit is contained in:
		| @@ -24,10 +24,7 @@ if str(lib_dir) not in sys.path: | |||||||
|     sys.path.insert(0, str(lib_dir)) |     sys.path.insert(0, str(lib_dir)) | ||||||
|  |  | ||||||
|  |  | ||||||
| from datasets import ConstantGenerator, SinGenerator, SyntheticDEnv |  | ||||||
| from datasets import DynamicQuadraticFunc |  | ||||||
| from datasets.synthetic_example import create_example_v1 | from datasets.synthetic_example import create_example_v1 | ||||||
|  |  | ||||||
| from utils.temp_sync import optimize_fn, evaluate_fn | from utils.temp_sync import optimize_fn, evaluate_fn | ||||||
|  |  | ||||||
|  |  | ||||||
| @@ -61,43 +58,72 @@ def draw_multi_fig(save_dir, timestamp, scatter_list, wh, fig_title=None): | |||||||
|             tick.label.set_rotation(10) |             tick.label.set_rotation(10) | ||||||
|         for tick in cur_ax.yaxis.get_major_ticks(): |         for tick in cur_ax.yaxis.get_major_ticks(): | ||||||
|             tick.label.set_fontsize(LabelSize - font_gap) |             tick.label.set_fontsize(LabelSize - font_gap) | ||||||
|         plt.legend(loc=1, fontsize=LegendFontsize) |         cur_ax.legend(loc=1, fontsize=LegendFontsize) | ||||||
|     fig.savefig(str(save_path) + ".pdf", dpi=dpi, bbox_inches="tight", format="pdf") |     fig.savefig(str(save_path) + ".pdf", dpi=dpi, bbox_inches="tight", format="pdf") | ||||||
|     fig.savefig(str(save_path) + ".png", dpi=dpi, bbox_inches="tight", format="png") |     fig.savefig(str(save_path) + ".png", dpi=dpi, bbox_inches="tight", format="png") | ||||||
|     plt.close("all") |     plt.close("all") | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def find_min(cur, others): | ||||||
|  |     if cur is None: | ||||||
|  |         return float(others.min()) | ||||||
|  |     else: | ||||||
|  |         return float(min(cur, others.min())) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def find_max(cur, others): | ||||||
|  |     if cur is None: | ||||||
|  |         return float(others.max()) | ||||||
|  |     else: | ||||||
|  |         return float(max(cur, others.max())) | ||||||
|  |  | ||||||
|  |  | ||||||
| def compare_cl(save_dir): | def compare_cl(save_dir): | ||||||
|     save_dir = Path(str(save_dir)) |     save_dir = Path(str(save_dir)) | ||||||
|     save_dir.mkdir(parents=True, exist_ok=True) |     save_dir.mkdir(parents=True, exist_ok=True) | ||||||
|     dynamic_env, function = create_example_v1(100, num_per_task=1000) |     dynamic_env, function = create_example_v1( | ||||||
|  |         timestamp_config=dict(num=200, min_timestamp=-1, max_timestamp=1.0), | ||||||
|  |         num_per_task=1000, | ||||||
|  |     ) | ||||||
|  |  | ||||||
|     additional_xaxis = np.arange(-6, 6, 0.2) |  | ||||||
|     models = dict() |     models = dict() | ||||||
|  |  | ||||||
|     cl_function = copy.deepcopy(function) |     cl_function = copy.deepcopy(function) | ||||||
|     cl_function.set_timestamp(0) |     cl_function.set_timestamp(0) | ||||||
|     cl_xaxis_all = None |     cl_xaxis_min = None | ||||||
|  |     cl_xaxis_max = None | ||||||
|  |  | ||||||
|  |     all_data = OrderedDict() | ||||||
|  |  | ||||||
|     for idx, (timestamp, dataset) in enumerate(tqdm(dynamic_env, ncols=50)): |     for idx, (timestamp, dataset) in enumerate(tqdm(dynamic_env, ncols=50)): | ||||||
|         xaxis_all = dataset[:, 0].numpy() |         xaxis_all = dataset[:, 0].numpy() | ||||||
|         # xaxis_all = np.concatenate((additional_xaxis, xaxis_all)) |         current_data = dict() | ||||||
|         # compute the ground truth |  | ||||||
|         function.set_timestamp(timestamp) |         function.set_timestamp(timestamp) | ||||||
|         yaxis_all = function.noise_call(xaxis_all) |         yaxis_all = function.noise_call(xaxis_all) | ||||||
|  |         current_data["lfna_xaxis_all"] = xaxis_all | ||||||
|  |         current_data["lfna_yaxis_all"] = yaxis_all | ||||||
|  |  | ||||||
|         # create CL data |         import pdb | ||||||
|         if cl_xaxis_all is None: |  | ||||||
|             cl_xaxis_all = xaxis_all |  | ||||||
|         else: |  | ||||||
|             cl_xaxis_all = np.concatenate((cl_xaxis_all, xaxis_all + timestamp * 0.2)) |  | ||||||
|         cl_yaxis_all = cl_function(cl_xaxis_all) |  | ||||||
|  |  | ||||||
|  |         pdb.set_trace() | ||||||
|  |  | ||||||
|  |         # compute cl-min | ||||||
|  |         cl_xaxis_min = find_min(cl_xaxis_min, xaxis_all) | ||||||
|  |         cl_xaxis_max = find_max(cl_xaxis_max, xaxis_all) + idx * 0.1 | ||||||
|  |         cl_xaxis_all = np.arange(cl_xaxis_min, cl_xaxis_max, step=0.05) | ||||||
|  |  | ||||||
|  |         cl_yaxis_all = cl_function.noise_call(cl_xaxis_all) | ||||||
|  |         current_data["cl_xaxis_all"] = cl_xaxis_all | ||||||
|  |         current_data["cl_yaxis_all"] = cl_yaxis_all | ||||||
|  |         all_data[timestamp] = current_data | ||||||
|  |  | ||||||
|  |     for idx, (timestamp, xdata) in enumerate(tqdm(all_data.items(), ncols=50)): | ||||||
|         scatter_list = [] |         scatter_list = [] | ||||||
|         scatter_list.append( |         scatter_list.append( | ||||||
|             { |             { | ||||||
|                 "xaxis": xaxis_all, |                 "xaxis": xdata["lfna_xaxis_all"], | ||||||
|                 "yaxis": yaxis_all, |                 "yaxis": xdata["lfna_yaxis_all"], | ||||||
|                 "color": "k", |                 "color": "k", | ||||||
|                 "s": 10, |                 "s": 10, | ||||||
|                 "alpha": 0.99, |                 "alpha": 0.99, | ||||||
| @@ -107,6 +133,9 @@ def compare_cl(save_dir): | |||||||
|             } |             } | ||||||
|         ) |         ) | ||||||
|  |  | ||||||
|  |         cl_xaxis_all = current_data["cl_xaxis_all"] | ||||||
|  |         cl_yaxis_all = current_data["cl_yaxis_all"] | ||||||
|  |  | ||||||
|         scatter_list.append( |         scatter_list.append( | ||||||
|             { |             { | ||||||
|                 "xaxis": cl_xaxis_all, |                 "xaxis": cl_xaxis_all, | ||||||
| @@ -121,15 +150,21 @@ def compare_cl(save_dir): | |||||||
|         ) |         ) | ||||||
|  |  | ||||||
|         draw_multi_fig( |         draw_multi_fig( | ||||||
|             save_dir, timestamp, scatter_list, |             save_dir, | ||||||
|             wh=(2000, 1300), fig_title="Timestamp={:03d}".format(timestamp) |             timestamp, | ||||||
|  |             scatter_list, | ||||||
|  |             wh=(2000, 1300), | ||||||
|  |             fig_title="Timestamp={:03d}".format(timestamp), | ||||||
|         ) |         ) | ||||||
|     print("Save all figures into {:}".format(save_dir)) |     print("Save all figures into {:}".format(save_dir)) | ||||||
|     save_dir = save_dir.resolve() |     save_dir = save_dir.resolve() | ||||||
|     cmd = "ffmpeg -y -i {xdir}/%04d.png -pix_fmt yuv420p -vf fps=2 -vf scale=2000:1300 -vb 5000k {xdir}/vis.mp4".format( |     base_cmd = ( | ||||||
|         xdir=save_dir |         "ffmpeg -y -i {xdir}/%04d.png -vf fps=2 -vf scale=2000:1300 -vb 5000k".format( | ||||||
|  |             xdir=save_dir | ||||||
|  |         ) | ||||||
|     ) |     ) | ||||||
|     os.system(cmd) |     os.system("{:} -pix_fmt yuv420p {xdir}/vis.mp4".format(base_cmd, xdir=save_dir)) | ||||||
|  |     os.system("{:} -c:a libvorbis {xdir}/vis.webm".format(base_cmd, xdir=save_dir)) | ||||||
|  |  | ||||||
|  |  | ||||||
| if __name__ == "__main__": | if __name__ == "__main__": | ||||||
|   | |||||||
| @@ -19,7 +19,7 @@ class SyntheticDEnv(data.Dataset): | |||||||
|         mean_functors: List[data.Dataset], |         mean_functors: List[data.Dataset], | ||||||
|         cov_functors: List[List[data.Dataset]], |         cov_functors: List[List[data.Dataset]], | ||||||
|         num_per_task: int = 5000, |         num_per_task: int = 5000, | ||||||
|         time_stamp_config: Optional[Dict] = None, |         timestamp_config: Optional[Dict] = None, | ||||||
|         mode: Optional[str] = None, |         mode: Optional[str] = None, | ||||||
|     ): |     ): | ||||||
|         self._ndim = len(mean_functors) |         self._ndim = len(mean_functors) | ||||||
| @@ -31,12 +31,12 @@ class SyntheticDEnv(data.Dataset): | |||||||
|                 cov_functor |                 cov_functor | ||||||
|             ), "length does not match {:} vs. {:}".format(self._ndim, len(cov_functor)) |             ), "length does not match {:} vs. {:}".format(self._ndim, len(cov_functor)) | ||||||
|         self._num_per_task = num_per_task |         self._num_per_task = num_per_task | ||||||
|         if time_stamp_config is None: |         if timestamp_config is None: | ||||||
|             time_stamp_config = dict(mode=mode) |             timestamp_config = dict(mode=mode) | ||||||
|         else: |         else: | ||||||
|             time_stamp_config["mode"] = mode |             timestamp_config["mode"] = mode | ||||||
|  |  | ||||||
|         self._timestamp_generator = TimeStamp(**time_stamp_config) |         self._timestamp_generator = TimeStamp(**timestamp_config) | ||||||
|  |  | ||||||
|         self._mean_functors = mean_functors |         self._mean_functors = mean_functors | ||||||
|         self._cov_functors = cov_functors |         self._cov_functors = cov_functors | ||||||
|   | |||||||
| @@ -2,21 +2,23 @@ | |||||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 # | # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 # | ||||||
| ##################################################### | ##################################################### | ||||||
|  |  | ||||||
| from .math_base_funcs import DynamicQuadraticFunc | from .math_adv_funcs import DynamicQuadraticFunc | ||||||
| from .math_adv_funcs import ConstantFunc, ComposedSinFunc | from .math_adv_funcs import ConstantFunc, ComposedSinFunc | ||||||
| from .synthetic_env import SyntheticDEnv | from .synthetic_env import SyntheticDEnv | ||||||
|  |  | ||||||
|  |  | ||||||
| def create_example_v1(timestamps=50, num_per_task=5000): | def create_example_v1( | ||||||
|  |     timestamp_config=dict(num=100, min_timestamp=0.0, max_timestamp=1.0), | ||||||
|  |     num_per_task=5000, | ||||||
|  | ): | ||||||
|     mean_generator = ComposedSinFunc() |     mean_generator = ComposedSinFunc() | ||||||
|     std_generator = ComposedSinFunc(min_amplitude=0.5, max_amplitude=0.5) |     std_generator = ComposedSinFunc(min_amplitude=0.5, max_amplitude=0.5) | ||||||
|     std_generator.set_transform(lambda x: x + 1) |  | ||||||
|  |  | ||||||
|     dynamic_env = SyntheticDEnv( |     dynamic_env = SyntheticDEnv( | ||||||
|         [mean_generator], |         [mean_generator], | ||||||
|         [[std_generator]], |         [[std_generator]], | ||||||
|         num_per_task=num_per_task, |         num_per_task=num_per_task, | ||||||
|         time_stamp_config=dict(num=timestamps), |         timestamp_config=timestamp_config, | ||||||
|     ) |     ) | ||||||
|  |  | ||||||
|     function = DynamicQuadraticFunc() |     function = DynamicQuadraticFunc() | ||||||
|   | |||||||
							
								
								
									
										8
									
								
								scripts/black.sh
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										8
									
								
								scripts/black.sh
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,8 @@ | |||||||
|  | #!/bin/bash | ||||||
|  | # bash ./scripts/black.sh | ||||||
|  |  | ||||||
|  | black ./tests/ | ||||||
|  | black ./lib/datasets | ||||||
|  | black ./lib/xlayers | ||||||
|  | black ./exps/LFNA | ||||||
|  | black ./exps/trading | ||||||
		Reference in New Issue
	
	Block a user