From 726dbef3261cdb3a22d51237b7a7afff1ebda42b Mon Sep 17 00:00:00 2001 From: D-X-Y <280835372@qq.com> Date: Thu, 27 May 2021 21:51:42 +0800 Subject: [PATCH] Update ablation for GeMOSA --- xautodl/utils/__init__.py | 1 + xautodl/utils/str_utils.py | 13 +++++++++++++ xautodl/utils/temp_sync.py | 6 ++---- 3 files changed, 16 insertions(+), 4 deletions(-) diff --git a/xautodl/utils/__init__.py b/xautodl/utils/__init__.py index 3b6052f..4c39ad1 100644 --- a/xautodl/utils/__init__.py +++ b/xautodl/utils/__init__.py @@ -11,3 +11,4 @@ from .affine_utils import normalize_points, denormalize_points from .affine_utils import identity2affine, solve2theta, affine2image from .hash_utils import get_md5_file from .str_utils import split_str2indexes +from .str_utils import show_mean_var diff --git a/xautodl/utils/str_utils.py b/xautodl/utils/str_utils.py index 06bda3d..829379e 100644 --- a/xautodl/utils/str_utils.py +++ b/xautodl/utils/str_utils.py @@ -1,3 +1,6 @@ +import numpy as np + + def split_str2indexes(string: str, max_check: int, length_limit=5): if not isinstance(string, str): raise ValueError("Invalid scheme for {:}".format(string)) @@ -19,3 +22,13 @@ def split_str2indexes(string: str, max_check: int, length_limit=5): for i in range(srange[0], srange[1] + 1): indexes.add(i) return indexes + + +def show_mean_var(xlist): + values = np.array(xlist) + print( + "{:.3f}".format(values.mean()) + + "$_{{\pm}{" + + "{:.3f}".format(values.std()) + + "}}$" + ) diff --git a/xautodl/utils/temp_sync.py b/xautodl/utils/temp_sync.py index 8dd89a6..fe8526d 100644 --- a/xautodl/utils/temp_sync.py +++ b/xautodl/utils/temp_sync.py @@ -20,9 +20,7 @@ def optimize_fn(xs, ys, device="cpu", max_iter=2000, max_lr=0.1): SuperLinear(100, 1), ).to(device) model.train() - optimizer = torch.optim.Adam( - model.parameters(), lr=max_lr, amsgrad=True - ) + optimizer = torch.optim.Adam(model.parameters(), lr=max_lr, amsgrad=True) loss_func = torch.nn.MSELoss() lr_scheduler = torch.optim.lr_scheduler.MultiStepLR( optimizer, @@ -47,7 +45,7 @@ def optimize_fn(xs, ys, device="cpu", max_iter=2000, max_lr=0.1): if best_loss is None or best_loss > loss.item(): best_loss = loss.item() best_param = copy.deepcopy(model.state_dict()) - + # print('loss={:}, best-loss={:}'.format(loss.item(), best_loss)) model.load_state_dict(best_param) return model, loss_func, best_loss