Fix bugs
This commit is contained in:
		| @@ -66,23 +66,24 @@ def draw_multi_fig(save_dir, timestamp, scatter_list, wh, fig_title=None): | ||||
|  | ||||
| def find_min(cur, others): | ||||
|     if cur is None: | ||||
|         return float(others.min()) | ||||
|         return float(others) | ||||
|     else: | ||||
|         return float(min(cur, others.min())) | ||||
|         return float(min(cur, others)) | ||||
|  | ||||
|  | ||||
| def find_max(cur, others): | ||||
|     if cur is None: | ||||
|         return float(others.max()) | ||||
|     else: | ||||
|         return float(max(cur, others.max())) | ||||
|         return float(max(cur, others)) | ||||
|  | ||||
|  | ||||
| def compare_cl(save_dir): | ||||
|     save_dir = Path(str(save_dir)) | ||||
|     save_dir.mkdir(parents=True, exist_ok=True) | ||||
|     dynamic_env, function = create_example_v1( | ||||
|         timestamp_config=dict(num=200, min_timestamp=-1, max_timestamp=1.0), | ||||
|         # timestamp_config=dict(num=200, min_timestamp=-1, max_timestamp=1.0), | ||||
|         timestamp_config=None, | ||||
|         num_per_task=1000, | ||||
|     ) | ||||
|  | ||||
| @@ -104,13 +105,11 @@ def compare_cl(save_dir): | ||||
|         current_data["lfna_xaxis_all"] = xaxis_all | ||||
|         current_data["lfna_yaxis_all"] = yaxis_all | ||||
|  | ||||
|         import pdb | ||||
|  | ||||
|         pdb.set_trace() | ||||
|  | ||||
|         # compute cl-min | ||||
|         cl_xaxis_min = find_min(cl_xaxis_min, xaxis_all) | ||||
|         cl_xaxis_max = find_max(cl_xaxis_max, xaxis_all) + idx * 0.1 | ||||
|         cl_xaxis_min = find_min(cl_xaxis_min, xaxis_all.mean() - xaxis_all.std()) | ||||
|         cl_xaxis_max = ( | ||||
|             find_max(cl_xaxis_max, xaxis_all.mean() + xaxis_all.std()) + idx * 0.1 | ||||
|         ) | ||||
|         cl_xaxis_all = np.arange(cl_xaxis_min, cl_xaxis_max, step=0.05) | ||||
|  | ||||
|         cl_yaxis_all = cl_function.noise_call(cl_xaxis_all) | ||||
| @@ -142,8 +141,8 @@ def compare_cl(save_dir): | ||||
|                 "yaxis": cl_yaxis_all, | ||||
|                 "color": "r", | ||||
|                 "s": 10, | ||||
|                 "xlim": (-6, 6 + timestamp * 0.2), | ||||
|                 "ylim": (-40, 40), | ||||
|                 "xlim": (round(cl_xaxis_all.min(), 1), round(cl_xaxis_all.max(), 1)), | ||||
|                 "ylim": (round(cl_xaxis_all.min(), 1), round(cl_yaxis_all.max(), 1)), | ||||
|                 "alpha": 0.99, | ||||
|                 "label": "Continual Learning", | ||||
|             } | ||||
| @@ -151,10 +150,10 @@ def compare_cl(save_dir): | ||||
|  | ||||
|         draw_multi_fig( | ||||
|             save_dir, | ||||
|             timestamp, | ||||
|             idx, | ||||
|             scatter_list, | ||||
|             wh=(2000, 1300), | ||||
|             fig_title="Timestamp={:03d}".format(timestamp), | ||||
|             fig_title="Timestamp={:03d}".format(idx), | ||||
|         ) | ||||
|     print("Save all figures into {:}".format(save_dir)) | ||||
|     save_dir = save_dir.resolve() | ||||
|   | ||||
		Reference in New Issue
	
	Block a user