Update ablation for GeMOSA
This commit is contained in:
parent
ffc0d16d6c
commit
726dbef326
@ -11,3 +11,4 @@ from .affine_utils import normalize_points, denormalize_points
|
||||
from .affine_utils import identity2affine, solve2theta, affine2image
|
||||
from .hash_utils import get_md5_file
|
||||
from .str_utils import split_str2indexes
|
||||
from .str_utils import show_mean_var
|
||||
|
@ -1,3 +1,6 @@
|
||||
import numpy as np
|
||||
|
||||
|
||||
def split_str2indexes(string: str, max_check: int, length_limit=5):
|
||||
if not isinstance(string, str):
|
||||
raise ValueError("Invalid scheme for {:}".format(string))
|
||||
@ -19,3 +22,13 @@ def split_str2indexes(string: str, max_check: int, length_limit=5):
|
||||
for i in range(srange[0], srange[1] + 1):
|
||||
indexes.add(i)
|
||||
return indexes
|
||||
|
||||
|
||||
def show_mean_var(xlist):
|
||||
values = np.array(xlist)
|
||||
print(
|
||||
"{:.3f}".format(values.mean())
|
||||
+ "$_{{\pm}{"
|
||||
+ "{:.3f}".format(values.std())
|
||||
+ "}}$"
|
||||
)
|
||||
|
@ -20,9 +20,7 @@ def optimize_fn(xs, ys, device="cpu", max_iter=2000, max_lr=0.1):
|
||||
SuperLinear(100, 1),
|
||||
).to(device)
|
||||
model.train()
|
||||
optimizer = torch.optim.Adam(
|
||||
model.parameters(), lr=max_lr, amsgrad=True
|
||||
)
|
||||
optimizer = torch.optim.Adam(model.parameters(), lr=max_lr, amsgrad=True)
|
||||
loss_func = torch.nn.MSELoss()
|
||||
lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
|
||||
optimizer,
|
||||
@ -47,7 +45,7 @@ def optimize_fn(xs, ys, device="cpu", max_iter=2000, max_lr=0.1):
|
||||
if best_loss is None or best_loss > loss.item():
|
||||
best_loss = loss.item()
|
||||
best_param = copy.deepcopy(model.state_dict())
|
||||
|
||||
|
||||
# print('loss={:}, best-loss={:}'.format(loss.item(), best_loss))
|
||||
model.load_state_dict(best_param)
|
||||
return model, loss_func, best_loss
|
||||
|
Loading…
Reference in New Issue
Block a user