Update ablation for GeMOSA

This commit is contained in:
D-X-Y 2021-05-27 21:51:42 +08:00
parent ffc0d16d6c
commit 726dbef326
3 changed files with 16 additions and 4 deletions

View File

@ -11,3 +11,4 @@ from .affine_utils import normalize_points, denormalize_points
from .affine_utils import identity2affine, solve2theta, affine2image
from .hash_utils import get_md5_file
from .str_utils import split_str2indexes
from .str_utils import show_mean_var

View File

@ -1,3 +1,6 @@
import numpy as np
def split_str2indexes(string: str, max_check: int, length_limit=5):
if not isinstance(string, str):
raise ValueError("Invalid scheme for {:}".format(string))
@ -19,3 +22,13 @@ def split_str2indexes(string: str, max_check: int, length_limit=5):
for i in range(srange[0], srange[1] + 1):
indexes.add(i)
return indexes
def show_mean_var(xlist):
values = np.array(xlist)
print(
"{:.3f}".format(values.mean())
+ "$_{{\pm}{"
+ "{:.3f}".format(values.std())
+ "}}$"
)

View File

@ -20,9 +20,7 @@ def optimize_fn(xs, ys, device="cpu", max_iter=2000, max_lr=0.1):
SuperLinear(100, 1),
).to(device)
model.train()
optimizer = torch.optim.Adam(
model.parameters(), lr=max_lr, amsgrad=True
)
optimizer = torch.optim.Adam(model.parameters(), lr=max_lr, amsgrad=True)
loss_func = torch.nn.MSELoss()
lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
optimizer,
@ -47,7 +45,7 @@ def optimize_fn(xs, ys, device="cpu", max_iter=2000, max_lr=0.1):
if best_loss is None or best_loss > loss.item():
best_loss = loss.item()
best_param = copy.deepcopy(model.state_dict())
# print('loss={:}, best-loss={:}'.format(loss.item(), best_loss))
model.load_state_dict(best_param)
return model, loss_func, best_loss