Update ablation for GeMOSA
parent ffc0d16d6c
commit 726dbef326
@@ -11,3 +11,4 @@ from .affine_utils import normalize_points, denormalize_points
 from .affine_utils import identity2affine, solve2theta, affine2image
 from .hash_utils import get_md5_file
 from .str_utils import split_str2indexes
+from .str_utils import show_mean_var
@@ -1,3 +1,6 @@
+import numpy as np
+
+
 def split_str2indexes(string: str, max_check: int, length_limit=5):
     if not isinstance(string, str):
         raise ValueError("Invalid scheme for {:}".format(string))
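For context, a hedged usage sketch of split_str2indexes (not part of the commit): judging from the validation above and the inclusive range-expansion loop in the next hunk, the expected scheme appears to be a comma-separated list of zero-padded "start-end" pairs, each endpoint length_limit digits long. The call below is hypothetical and the import path is assumed.

# Illustrative only; assumes the default length_limit=5 (5-digit endpoints)
# and that the helper is exposed via the package's utils module.
from xautodl.utils import split_str2indexes

indexes = split_str2indexes("00000-00003,00007-00008", max_check=10)
# expected result: {0, 1, 2, 3, 7, 8}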
@@ -19,3 +22,13 @@ def split_str2indexes(string: str, max_check: int, length_limit=5):
         for i in range(srange[0], srange[1] + 1):
             indexes.add(i)
     return indexes
+
+
+def show_mean_var(xlist):
+    values = np.array(xlist)
+    print(
+        "{:.3f}".format(values.mean())
+        + "$_{{\pm}{"
+        + "{:.3f}".format(values.std())
+        + "}}$"
+    )
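A quick illustrative sketch of what the new show_mean_var helper does (the accuracy values below are made up): it prints the mean of the list followed by the standard deviation as a LaTeX subscript, ready to paste into a results table.

import numpy as np

# Made-up run results; the print statement mirrors the body of show_mean_var above.
accs = [0.842, 0.851, 0.848]
values = np.array(accs)
print(
    "{:.3f}".format(values.mean())
    + "$_{{\pm}{"
    + "{:.3f}".format(values.std())
    + "}}$"
)
# prints something like: 0.847$_{{\pm}{0.004}}$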
@@ -20,9 +20,7 @@ def optimize_fn(xs, ys, device="cpu", max_iter=2000, max_lr=0.1):
         SuperLinear(100, 1),
     ).to(device)
     model.train()
-    optimizer = torch.optim.Adam(
-        model.parameters(), lr=max_lr, amsgrad=True
-    )
+    optimizer = torch.optim.Adam(model.parameters(), lr=max_lr, amsgrad=True)
     loss_func = torch.nn.MSELoss()
     lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
         optimizer,
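For readers skimming the hunk, a minimal self-contained sketch of the optimizer/scheduler setup that optimize_fn uses; plain torch.nn layers stand in for SuperSequential/SuperLinear, and the milestones/gamma values are illustrative assumptions, not taken from the file.

import torch

# Stand-in model; the real code builds SuperSequential(..., SuperLinear(100, 1)).
model = torch.nn.Sequential(
    torch.nn.Linear(1, 100), torch.nn.ReLU(), torch.nn.Linear(100, 1)
)
model.train()
# The reformatted line from the hunk: Adam with AMSGrad at the peak learning rate.
optimizer = torch.optim.Adam(model.parameters(), lr=0.1, amsgrad=True)
loss_func = torch.nn.MSELoss()
# Milestones and gamma here are assumptions for the sketch.
lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
    optimizer, milestones=[1000, 1500], gamma=0.1
)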
@@ -47,7 +45,7 @@ def optimize_fn(xs, ys, device="cpu", max_iter=2000, max_lr=0.1):
         if best_loss is None or best_loss > loss.item():
             best_loss = loss.item()
             best_param = copy.deepcopy(model.state_dict())

         # print('loss={:}, best-loss={:}'.format(loss.item(), best_loss))
     model.load_state_dict(best_param)
     return model, loss_func, best_loss
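The context lines above show optimize_fn's best-checkpoint bookkeeping; below is a small self-contained sketch of that pattern, with the data, model, and iteration count invented for illustration.

import copy
import torch

# Toy regression set-up; only the best-checkpoint logic mirrors the hunk above.
xs = torch.linspace(-1, 1, 64).view(-1, 1)
ys = xs ** 2
model = torch.nn.Linear(1, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=0.1, amsgrad=True)
loss_func = torch.nn.MSELoss()

best_loss, best_param = None, None
for _ in range(200):
    optimizer.zero_grad()
    loss = loss_func(model(xs), ys)
    loss.backward()
    optimizer.step()
    # Keep a deep copy of the weights at the lowest observed loss.
    if best_loss is None or best_loss > loss.item():
        best_loss = loss.item()
        best_param = copy.deepcopy(model.state_dict())
# Restore the best weights before returning, as optimize_fn does.
model.load_state_dict(best_param)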