Update GeMOSA v4
exps/GeMOSA/main.py
@@ -5,6 +5,7 @@
 # python exps/GeMOSA/main.py --env_version v1 --lr 0.002 --hidden_dim 16 --meta_batch 256 --device cuda
 # python exps/GeMOSA/main.py --env_version v2 --lr 0.002 --hidden_dim 16 --meta_batch 256 --device cuda
 # python exps/GeMOSA/main.py --env_version v3 --lr 0.002 --hidden_dim 32 --time_dim 32 --meta_batch 256 --device cuda
+# python exps/GeMOSA/main.py --env_version v4 --lr 0.002 --hidden_dim 32 --time_dim 32 --meta_batch 256 --device cuda
 #####################################################
 import sys, time, copy, torch, random, argparse
 from tqdm import tqdm
@@ -32,15 +33,24 @@ from xautodl.procedures.advanced_main import basic_train_fn, basic_eval_fn
 from xautodl.procedures.metric_utils import SaveMetric, MSEMetric, ComposeMetric
 from xautodl.datasets.synthetic_core import get_synthetic_env
 from xautodl.models.xcore import get_model
-from xautodl.xlayers import super_core, trunc_normal_
+from xautodl.procedures.metric_utils import MSEMetric, Top1AccMetric

 from meta_model import MetaModelV1


 def online_evaluate(
-    env, meta_model, base_model, criterion, args, logger, save=False, easy_adapt=False
+    env,
+    meta_model,
+    base_model,
+    criterion,
+    metric,
+    args,
+    logger,
+    save=False,
+    easy_adapt=False,
 ):
     logger.log("Online evaluate: {:}".format(env))
+    metric.reset()
     loss_meter = AverageMeter()
     w_containers = dict()
     for idx, (future_time, (future_x, future_y)) in enumerate(env):
@@ -57,6 +67,8 @@ def online_evaluate(
             future_y_hat = base_model.forward_with_container(future_x, future_container)
             future_loss = criterion(future_y_hat, future_y)
             loss_meter.update(future_loss.item())
+            # accumulate the metric scores
+            metric(future_y_hat, future_y)
         if easy_adapt:
             meta_model.easy_adapt(future_time.item(), future_time_embed)
             refine, post_refine_loss = False, -1
@@ -79,7 +91,7 @@ def online_evaluate(
         )
     meta_model.clear_fixed()
     meta_model.clear_learnt()
-    return w_containers, loss_meter
+    return w_containers, loss_meter.avg, metric.get_info()["score"]


 def meta_train_procedure(base_model, meta_model, criterion, xenv, args, logger):
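With the widened signature, online_evaluate treats metric as a small protocol: reset() once up front, __call__(predictions, targets) per time step, and get_info()["score"] for the final scalar it returns. A minimal sketch of that contract using the MSEMetric from this commit (the random tensors merely stand in for environment batches):

import torch

from xautodl.procedures.metric_utils import MSEMetric

metric = MSEMetric(True)  # ignore_batch=True: each call counts once in the average
metric.reset()
for _ in range(3):
    y_hat, y = torch.randn(8, 1), torch.randn(8, 1)  # stand-ins for (prediction, target)
    metric(y_hat, y)  # accumulates one MSE value per step
score = metric.get_info()["score"]  # the scalar online_evaluate now returns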
@@ -203,7 +215,16 @@ def main(args):

     base_model = get_model(**model_kwargs)
     base_model = base_model.to(args.device)
-    criterion = torch.nn.MSELoss()
+    if all_env.meta_info["task"] == "regression":
+        criterion = torch.nn.MSELoss()
+        metric = MSEMetric(True)
+    elif all_env.meta_info["task"] == "classification":
+        criterion = torch.nn.CrossEntropyLoss()
+        metric = Top1AccMetric(True)
+    else:
+        raise ValueError(
+            "This task ({:}) is not supported.".format(all_env.meta_info["task"])
+        )

     shape_container = base_model.get_w_container().to_shape_container()

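In the classification branch, the criterion consumes raw logits while Top1AccMetric argmaxes the same logits over the class dimension, so the two share one forward pass. A small sketch of the pairing (batch size and class count are illustrative):

import torch

from xautodl.procedures.metric_utils import Top1AccMetric

criterion = torch.nn.CrossEntropyLoss()
metric = Top1AccMetric(True)
metric.reset()

logits = torch.randn(32, 2)           # (batch, num_classes) raw scores
targets = torch.randint(0, 2, (32,))  # integer class labels
loss = criterion(logits, targets)     # cross-entropy directly on logits
metric(logits, targets)               # top-1 accuracy via argmax over classes
print(loss.item(), float(metric.get_info()["score"]))  # "score" is accuracy * 100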
@@ -235,27 +256,29 @@ def main(args):
     )
     logger.log("In this enviornment, the total loss-meter is {:}".format(loss_meter))
     """
-    _, test_loss_meter_adapt_v1 = online_evaluate(
-        valid_env, meta_model, base_model, criterion, args, logger, False, False
+    _, loss_adapt_v1, metric_adapt_v1 = online_evaluate(
+        valid_env, meta_model, base_model, criterion, metric, args, logger, False, False
     )
-    _, test_loss_meter_adapt_v2 = online_evaluate(
-        valid_env, meta_model, base_model, criterion, args, logger, False, True
+    _, loss_adapt_v2, metric_adapt_v2 = online_evaluate(
+        valid_env, meta_model, base_model, criterion, metric, args, logger, False, True
     )
     logger.log(
-        "In the online test enviornment, the total loss for refine-adapt is {:}".format(
-            test_loss_meter_adapt_v1
+        "[Refine-Adapt] loss = {:.6f}, metric = {:.6f}".format(
+            loss_adapt_v1, metric_adapt_v1
         )
     )
     logger.log(
-        "In the online test enviornment, the total loss for easy-adapt is {:}".format(
-            test_loss_meter_adapt_v2
+        "[Easy-Adapt] loss = {:.6f}, metric = {:.6f}".format(
+            loss_adapt_v2, metric_adapt_v2
         )
     )

     save_checkpoint(
         {
-            "test_loss_adapt_v1": test_loss_meter_adapt_v1.avg,
-            "test_loss_adapt_v2": test_loss_meter_adapt_v2.avg,
+            "test_loss_adapt_v1": loss_adapt_v1,
+            "test_loss_adapt_v2": loss_adapt_v2,
+            "test_metric_adapt_v1": metric_adapt_v1,
+            "test_metric_adapt_v2": metric_adapt_v2,
         },
         logger.path(None) / "final-ckp-{:}.pth".format(args.rand_seed),
         logger,
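The final checkpoint now stores plain floats for both losses and metric scores. A sketch for reading them back, assuming save_checkpoint writes the dict via torch.save unchanged (the seed in the file name is hypothetical):

import torch

ckp = torch.load("final-ckp-777.pth", map_location="cpu")  # hypothetical seed 777
for key in (
    "test_loss_adapt_v1",
    "test_loss_adapt_v2",
    "test_metric_adapt_v1",
    "test_metric_adapt_v2",
):
    print(key, ckp[key])

The remaining hunks below touch the synthetic-environment visualization script, which gains a classification branch next to the existing regression one.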
@@ -33,7 +33,9 @@ from xautodl.procedures.metric_utils import MSEMetric

 def plot_scatter(cur_ax, xs, ys, color, alpha, linewidths, label=None):
     cur_ax.scatter([-100], [-100], color=color, linewidths=linewidths[0], label=label)
-    cur_ax.scatter(xs, ys, color=color, alpha=alpha, linewidths=linewidths[1], label=None)
+    cur_ax.scatter(
+        xs, ys, color=color, alpha=alpha, linewidths=linewidths[1], label=None
+    )


 def draw_multi_fig(save_dir, timestamp, scatter_list, wh, fig_title=None):
@@ -193,16 +195,28 @@ def visualize_env(save_dir, version):
     for idx, (timestamp, (allx, ally)) in enumerate(tqdm(dynamic_env, ncols=50)):
         allxs.append(allx)
         allys.append(ally)
-    if dynamic_env.meta_info['task'] == 'regression':
+    if dynamic_env.meta_info["task"] == "regression":
         allxs, allys = torch.cat(allxs).view(-1), torch.cat(allys).view(-1)
-        print("x - min={:.3f}, max={:.3f}".format(allxs.min().item(), allxs.max().item()))
-        print("y - min={:.3f}, max={:.3f}".format(allys.min().item(), allys.max().item()))
-    elif dynamic_env.meta_info['task'] == 'classification':
+        print(
+            "x - min={:.3f}, max={:.3f}".format(allxs.min().item(), allxs.max().item())
+        )
+        print(
+            "y - min={:.3f}, max={:.3f}".format(allys.min().item(), allys.max().item())
+        )
+    elif dynamic_env.meta_info["task"] == "classification":
         allxs = torch.cat(allxs)
-        print("x[0] - min={:.3f}, max={:.3f}".format(allxs[:,0].min().item(), allxs[:,0].max().item()))
-        print("x[1] - min={:.3f}, max={:.3f}".format(allxs[:,1].min().item(), allxs[:,1].max().item()))
+        print(
+            "x[0] - min={:.3f}, max={:.3f}".format(
+                allxs[:, 0].min().item(), allxs[:, 0].max().item()
+            )
+        )
+        print(
+            "x[1] - min={:.3f}, max={:.3f}".format(
+                allxs[:, 1].min().item(), allxs[:, 1].max().item()
+            )
+        )
     else:
-        raise ValueError("Unknown task".format(dynamic_env.meta_info['task']))
+        raise ValueError("Unknown task".format(dynamic_env.meta_info["task"]))

     for idx, (timestamp, (allx, ally)) in enumerate(tqdm(dynamic_env, ncols=50)):
         dpi, width, height = 30, 1800, 1400
@@ -211,20 +225,42 @@ def visualize_env(save_dir, version):
         fig = plt.figure(figsize=figsize)

         cur_ax = fig.add_subplot(1, 1, 1)
-        if dynamic_env.meta_info['task'] == 'regression':
+        if dynamic_env.meta_info["task"] == "regression":
             allx, ally = allx[:, 0].numpy(), ally[:, 0].numpy()
-            plot_scatter(cur_ax, allx, ally, "k", 0.99, (15, 1.5), "timestamp={:05d}".format(idx))
+            plot_scatter(
+                cur_ax, allx, ally, "k", 0.99, (15, 1.5), "timestamp={:05d}".format(idx)
+            )
             cur_ax.set_xlim(round(allxs.min().item(), 1), round(allxs.max().item(), 1))
             cur_ax.set_ylim(round(allys.min().item(), 1), round(allys.max().item(), 1))
-        elif dynamic_env.meta_info['task'] == 'classification':
+        elif dynamic_env.meta_info["task"] == "classification":
             positive, negative = ally == 1, ally == 0
             # plot_scatter(cur_ax, [1], [1], "k", 0.1, 1, "timestamp={:05d}".format(idx))
-            plot_scatter(cur_ax, allx[positive,0], allx[positive,1], "r", 0.99, (20, 10), "positive")
-            plot_scatter(cur_ax, allx[negative,0], allx[negative,1], "g", 0.99, (20, 10), "negative")
-            cur_ax.set_xlim(round(allxs[:,0].min().item(), 1), round(allxs[:,0].max().item(), 1))
-            cur_ax.set_ylim(round(allxs[:,1].min().item(), 1), round(allxs[:,1].max().item(), 1))
+            plot_scatter(
+                cur_ax,
+                allx[positive, 0],
+                allx[positive, 1],
+                "r",
+                0.99,
+                (20, 10),
+                "positive",
+            )
+            plot_scatter(
+                cur_ax,
+                allx[negative, 0],
+                allx[negative, 1],
+                "g",
+                0.99,
+                (20, 10),
+                "negative",
+            )
+            cur_ax.set_xlim(
+                round(allxs[:, 0].min().item(), 1), round(allxs[:, 0].max().item(), 1)
+            )
+            cur_ax.set_ylim(
+                round(allxs[:, 1].min().item(), 1), round(allxs[:, 1].max().item(), 1)
+            )
         else:
-            raise ValueError("Unknown task".format(dynamic_env.meta_info['task']))
+            raise ValueError("Unknown task".format(dynamic_env.meta_info["task"]))

         cur_ax.set_xlabel("X", fontsize=LabelSize)
         cur_ax.set_ylabel("Y", rotation=0, fontsize=LabelSize)
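The classification branch above separates the two classes with boolean masks over the label tensor and scatters each class in its own color. A self-contained toy version of that pattern (the synthetic data and output file name are made up):

import matplotlib.pyplot as plt
import torch

allx = torch.randn(100, 2)                   # toy 2-D inputs
ally = (allx[:, 0] + allx[:, 1] > 0).long()  # toy binary labels
positive, negative = ally == 1, ally == 0    # boolean masks, as in the diff

fig = plt.figure(figsize=(6, 6))
cur_ax = fig.add_subplot(1, 1, 1)
cur_ax.scatter(allx[positive, 0].numpy(), allx[positive, 1].numpy(), color="r", label="positive")
cur_ax.scatter(allx[negative, 0].numpy(), allx[negative, 1].numpy(), color="g", label="negative")
cur_ax.legend()
fig.savefig("toy-classification.png")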
xautodl/procedures/metric_utils.py
@@ -98,21 +98,53 @@ class ComposeMetric(Metric):
 class MSEMetric(Metric):
     """The metric for mse."""

+    def __init__(self, ignore_batch):
+        super(MSEMetric, self).__init__()
+        self._ignore_batch = ignore_batch
+
     def reset(self):
         self._mse = AverageMeter()

     def __call__(self, predictions, targets):
         if isinstance(predictions, torch.Tensor) and isinstance(targets, torch.Tensor):
-            batch = predictions.shape[0]
-            loss = torch.nn.functional.mse_loss(predictions.data, targets.data)
-            loss = loss.item()
-            self._mse.update(loss, batch)
+            loss = torch.nn.functional.mse_loss(predictions.data, targets.data).item()
+            if self._ignore_batch:
+                self._mse.update(loss, 1)
+            else:
+                self._mse.update(loss, predictions.shape[0])
             return loss
         else:
             raise NotImplementedError

     def get_info(self):
-        return {"mse": self._mse.avg}
+        return {"mse": self._mse.avg, "score": self._mse.avg}
+
+
+class Top1AccMetric(Metric):
+    """The metric for the top-1 accuracy."""
+
+    def __init__(self, ignore_batch):
+        super(Top1AccMetric, self).__init__()
+        self._ignore_batch = ignore_batch
+
+    def reset(self):
+        self._accuracy = AverageMeter()
+
+    def __call__(self, predictions, targets):
+        if isinstance(predictions, torch.Tensor) and isinstance(targets, torch.Tensor):
+            max_prob_indexes = torch.argmax(predictions, dim=-1)
+            corrects = torch.eq(max_prob_indexes, targets)
+            accuracy = corrects.float().mean().float()
+            if self._ignore_batch:
+                self._accuracy.update(accuracy, 1)
+            else:  # [TODO] for 3-d tensor
+                self._accuracy.update(accuracy, predictions.shape[0])
+            return accuracy
+        else:
+            raise NotImplementedError
+
+    def get_info(self):
+        return {"accuracy": self._accuracy.avg, "score": self._accuracy.avg * 100}


 class SaveMetric(Metric):
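The new ignore_batch flag chooses the weighting of the running average: True counts every call once, False weights each call by its batch size, assuming the usual AverageMeter semantics where update(val, n) adds val * n to the sum and n to the count. A sketch of the difference with two unequal batches:

import torch

from xautodl.procedures.metric_utils import Top1AccMetric

m_flat = Top1AccMetric(True)       # every call weighs 1 in the average
m_weighted = Top1AccMetric(False)  # every call weighs its batch size
m_flat.reset()
m_weighted.reset()

preds_a = torch.tensor([[0.9, 0.1]] * 2)  # batch of 2, all argmax to class 0
preds_b = torch.tensor([[0.1, 0.9]] * 8)  # batch of 8, all argmax to class 1
for preds, tgts in (
    (preds_a, torch.zeros(2, dtype=torch.long)),  # 100% correct
    (preds_b, torch.zeros(8, dtype=torch.long)),  # 0% correct
):
    m_flat(preds, tgts)
    m_weighted(preds, tgts)

print(float(m_flat.get_info()["score"]))      # 50.0 = mean(1.0, 0.0) * 100
print(float(m_weighted.get_info()["score"]))  # 20.0 = (2 * 1.0 + 8 * 0.0) / 10 * 100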