Mark as ddl passed

This commit is contained in:
D-X-Y 2021-06-01 20:19:48 +08:00
parent b8b94cc791
commit d3d950d310
3 changed files with 13 additions and 4 deletions

View File

@ -151,7 +151,9 @@ if __name__ == "__main__":
key_map = dict()
for xset in ("train", "valid", "test"):
key_map["{:}-mean-IC".format(xset)] = "IC ({:})".format(xset)
key_map["{:}-mean-ICIR".format(xset)] = "ICIR ({:})".format(xset)
# key_map["{:}-mean-ICIR".format(xset)] = "ICIR ({:})".format(xset)
key_map["{:}-mean-Rank-IC".format(xset)] = "Rank IC ({:})".format(xset)
# key_map["{:}-mean-Rank-ICIR".format(xset)] = "Rank ICIR ({:})".format(xset)
all_qresults = []
for save_dir in args.save_dir:

View File

@ -16,7 +16,7 @@
#
# TODO(xuanyidong): upload it to conda
#
# [2021.05.21] v0.9.9
# [2021.06.01] v0.9.9
import os
from setuptools import setup, find_packages

View File

@ -97,6 +97,7 @@ class QResult:
separate: Text = "& ",
space: int = 20,
verbose: bool = True,
version: str = "v1",
):
avaliable_keys = []
for key in keys:
@ -113,8 +114,14 @@ class QResult:
current_values = self._result[key]
mean = np.mean(current_values)
std = np.std(current_values)
# values.append("{:.4f} $\pm$ {:.4f}".format(mean, std))
values.append("{:.2f} $\pm$ {:.2f}".format(mean, std))
if version == "v0":
values.append("{:.2f} $\pm$ {:.2f}".format(mean, std))
elif version == "v1":
values.append(
"{:.2f}".format(mean) + " \\subs{" + "{:.2f}".format(std) + "}"
)
else:
raise ValueError("Unknown version")
value_str = separate.join([self.full_str(x, space) for x in values])
if verbose:
print(head_str)