Update ablation for GeMOSA
This commit is contained in:
		| @@ -11,3 +11,4 @@ from .affine_utils import normalize_points, denormalize_points | ||||
| from .affine_utils import identity2affine, solve2theta, affine2image | ||||
| from .hash_utils import get_md5_file | ||||
| from .str_utils import split_str2indexes | ||||
| from .str_utils import show_mean_var | ||||
|   | ||||
| @@ -1,3 +1,6 @@ | ||||
| import numpy as np | ||||
|  | ||||
|  | ||||
| def split_str2indexes(string: str, max_check: int, length_limit=5): | ||||
|     if not isinstance(string, str): | ||||
|         raise ValueError("Invalid scheme for {:}".format(string)) | ||||
| @@ -19,3 +22,13 @@ def split_str2indexes(string: str, max_check: int, length_limit=5): | ||||
|         for i in range(srange[0], srange[1] + 1): | ||||
|             indexes.add(i) | ||||
|     return indexes | ||||
|  | ||||
|  | ||||
| def show_mean_var(xlist): | ||||
|     values = np.array(xlist) | ||||
|     print( | ||||
|         "{:.3f}".format(values.mean()) | ||||
|         + "$_{{\pm}{" | ||||
|         + "{:.3f}".format(values.std()) | ||||
|         + "}}$" | ||||
|     ) | ||||
|   | ||||
| @@ -20,9 +20,7 @@ def optimize_fn(xs, ys, device="cpu", max_iter=2000, max_lr=0.1): | ||||
|         SuperLinear(100, 1), | ||||
|     ).to(device) | ||||
|     model.train() | ||||
|     optimizer = torch.optim.Adam( | ||||
|         model.parameters(), lr=max_lr, amsgrad=True | ||||
|     ) | ||||
|     optimizer = torch.optim.Adam(model.parameters(), lr=max_lr, amsgrad=True) | ||||
|     loss_func = torch.nn.MSELoss() | ||||
|     lr_scheduler = torch.optim.lr_scheduler.MultiStepLR( | ||||
|         optimizer, | ||||
| @@ -47,7 +45,7 @@ def optimize_fn(xs, ys, device="cpu", max_iter=2000, max_lr=0.1): | ||||
|         if best_loss is None or best_loss > loss.item(): | ||||
|             best_loss = loss.item() | ||||
|             best_param = copy.deepcopy(model.state_dict()) | ||||
|          | ||||
|  | ||||
|         # print('loss={:}, best-loss={:}'.format(loss.item(), best_loss)) | ||||
|     model.load_state_dict(best_param) | ||||
|     return model, loss_func, best_loss | ||||
|   | ||||
		Reference in New Issue
	
	Block a user