Update GeMOSA v4
@@ -5,6 +5,7 @@
 # python exps/GeMOSA/main.py --env_version v1 --lr 0.002 --hidden_dim 16 --meta_batch 256 --device cuda
 # python exps/GeMOSA/main.py --env_version v2 --lr 0.002 --hidden_dim 16 --meta_batch 256 --device cuda
 # python exps/GeMOSA/main.py --env_version v3 --lr 0.002 --hidden_dim 32 --time_dim 32 --meta_batch 256 --device cuda
+# python exps/GeMOSA/main.py --env_version v4 --lr 0.002 --hidden_dim 32 --time_dim 32 --meta_batch 256 --device cuda
 #####################################################
 import sys, time, copy, torch, random, argparse
 from tqdm import tqdm
@@ -32,15 +33,24 @@ from xautodl.procedures.advanced_main import basic_train_fn, basic_eval_fn
 from xautodl.procedures.metric_utils import SaveMetric, MSEMetric, ComposeMetric
 from xautodl.datasets.synthetic_core import get_synthetic_env
 from xautodl.models.xcore import get_model
 from xautodl.xlayers import super_core, trunc_normal_
+from xautodl.procedures.metric_utils import MSEMetric, Top1AccMetric

 from meta_model import MetaModelV1


 def online_evaluate(
-    env, meta_model, base_model, criterion, args, logger, save=False, easy_adapt=False
+    env,
+    meta_model,
+    base_model,
+    criterion,
+    metric,
+    args,
+    logger,
+    save=False,
+    easy_adapt=False,
 ):
     logger.log("Online evaluate: {:}".format(env))
+    metric.reset()
     loss_meter = AverageMeter()
     w_containers = dict()
     for idx, (future_time, (future_x, future_y)) in enumerate(env):
@@ -57,6 +67,8 @@ def online_evaluate(
             future_y_hat = base_model.forward_with_container(future_x, future_container)
             future_loss = criterion(future_y_hat, future_y)
             loss_meter.update(future_loss.item())
+            # accumulate the metric scores
+            metric(future_y_hat, future_y)
         if easy_adapt:
             meta_model.easy_adapt(future_time.item(), future_time_embed)
             refine, post_refine_loss = False, -1
@@ -79,7 +91,7 @@ def online_evaluate(
         )
     meta_model.clear_fixed()
     meta_model.clear_learnt()
-    return w_containers, loss_meter
+    return w_containers, loss_meter.avg, metric.get_info()["score"]


 def meta_train_procedure(base_model, meta_model, criterion, xenv, args, logger):
@@ -203,7 +215,16 @@ def main(args):

     base_model = get_model(**model_kwargs)
     base_model = base_model.to(args.device)
-    criterion = torch.nn.MSELoss()
+    if all_env.meta_info["task"] == "regression":
+        criterion = torch.nn.MSELoss()
+        metric = MSEMetric(True)
+    elif all_env.meta_info["task"] == "classification":
+        criterion = torch.nn.CrossEntropyLoss()
+        metric = Top1AccMetric(True)
+    else:
+        raise ValueError(
+            "This task ({:}) is not supported.".format(all_env.meta_info["task"])
+        )

     shape_container = base_model.get_w_container().to_shape_container()

@@ -235,27 +256,29 @@
     )
     logger.log("In this enviornment, the total loss-meter is {:}".format(loss_meter))
     """
-    _, test_loss_meter_adapt_v1 = online_evaluate(
-        valid_env, meta_model, base_model, criterion, args, logger, False, False
+    _, loss_adapt_v1, metric_adapt_v1 = online_evaluate(
+        valid_env, meta_model, base_model, criterion, metric, args, logger, False, False
     )
-    _, test_loss_meter_adapt_v2 = online_evaluate(
-        valid_env, meta_model, base_model, criterion, args, logger, False, True
+    _, loss_adapt_v2, metric_adapt_v2 = online_evaluate(
+        valid_env, meta_model, base_model, criterion, metric, args, logger, False, True
     )
     logger.log(
-        "In the online test enviornment, the total loss for refine-adapt is {:}".format(
-            test_loss_meter_adapt_v1
+        "[Refine-Adapt] loss = {:.6f}, metric = {:.6f}".format(
+            loss_adapt_v1, metric_adapt_v1
         )
     )
     logger.log(
-        "In the online test enviornment, the total loss for easy-adapt is {:}".format(
-            test_loss_meter_adapt_v2
+        "[Easy-Adapt] loss = {:.6f}, metric = {:.6f}".format(
+            loss_adapt_v2, metric_adapt_v2
         )
     )

     save_checkpoint(
         {
-            "test_loss_adapt_v1": test_loss_meter_adapt_v1.avg,
-            "test_loss_adapt_v2": test_loss_meter_adapt_v2.avg,
+            "test_loss_adapt_v1": loss_adapt_v1,
+            "test_loss_adapt_v2": loss_adapt_v2,
+            "test_metric_adapt_v1": metric_adapt_v1,
+            "test_metric_adapt_v2": metric_adapt_v2,
         },
         logger.path(None) / "final-ckp-{:}.pth".format(args.rand_seed),
         logger,
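The metric object threaded through online_evaluate above follows a small reset/accumulate/report protocol. Below is a minimal, hypothetical usage sketch of that protocol with MSEMetric, not part of the commit: it assumes only that xautodl is importable, and the random tensors are illustrative stand-ins for the (future_x, future_y) pairs drawn from the environment. The MSEMetric(True) construction mirrors main(), where ignore_batch=True makes every call count equally rather than being weighted by batch size.

import torch
from xautodl.procedures.metric_utils import MSEMetric

metric = MSEMetric(True)  # ignore_batch=True, as in main() above
metric.reset()  # must precede accumulation; online_evaluate calls this first
for _ in range(10):  # stand-in for iterating over the online environment
    predictions = torch.randn(256, 1)  # illustrative model outputs
    targets = torch.randn(256, 1)  # illustrative ground truth
    metric(predictions, targets)  # accumulates the per-step MSE
print("score = {:.6f}".format(metric.get_info()["score"]))  # the value online_evaluate now returns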
@@ -33,7 +33,9 @@ from xautodl.procedures.metric_utils import MSEMetric

 def plot_scatter(cur_ax, xs, ys, color, alpha, linewidths, label=None):
     cur_ax.scatter([-100], [-100], color=color, linewidths=linewidths[0], label=label)
-    cur_ax.scatter(xs, ys, color=color, alpha=alpha, linewidths=linewidths[1], label=None)
+    cur_ax.scatter(
+        xs, ys, color=color, alpha=alpha, linewidths=linewidths[1], label=None
+    )


 def draw_multi_fig(save_dir, timestamp, scatter_list, wh, fig_title=None):
@@ -193,16 +195,28 @@ def visualize_env(save_dir, version):
     for idx, (timestamp, (allx, ally)) in enumerate(tqdm(dynamic_env, ncols=50)):
         allxs.append(allx)
         allys.append(ally)
-    if dynamic_env.meta_info['task'] == 'regression':
+    if dynamic_env.meta_info["task"] == "regression":
         allxs, allys = torch.cat(allxs).view(-1), torch.cat(allys).view(-1)
-        print("x - min={:.3f}, max={:.3f}".format(allxs.min().item(), allxs.max().item()))
-        print("y - min={:.3f}, max={:.3f}".format(allys.min().item(), allys.max().item()))
-    elif dynamic_env.meta_info['task'] == 'classification':
+        print(
+            "x - min={:.3f}, max={:.3f}".format(allxs.min().item(), allxs.max().item())
+        )
+        print(
+            "y - min={:.3f}, max={:.3f}".format(allys.min().item(), allys.max().item())
+        )
+    elif dynamic_env.meta_info["task"] == "classification":
         allxs = torch.cat(allxs)
-        print("x[0] - min={:.3f}, max={:.3f}".format(allxs[:,0].min().item(), allxs[:,0].max().item()))
-        print("x[1] - min={:.3f}, max={:.3f}".format(allxs[:,1].min().item(), allxs[:,1].max().item()))
+        print(
+            "x[0] - min={:.3f}, max={:.3f}".format(
+                allxs[:, 0].min().item(), allxs[:, 0].max().item()
+            )
+        )
+        print(
+            "x[1] - min={:.3f}, max={:.3f}".format(
+                allxs[:, 1].min().item(), allxs[:, 1].max().item()
+            )
+        )
     else:
-        raise ValueError("Unknown task".format(dynamic_env.meta_info['task']))
+        raise ValueError("Unknown task {:}".format(dynamic_env.meta_info["task"]))

     for idx, (timestamp, (allx, ally)) in enumerate(tqdm(dynamic_env, ncols=50)):
         dpi, width, height = 30, 1800, 1400
@@ -211,28 +225,50 @@
         fig = plt.figure(figsize=figsize)

         cur_ax = fig.add_subplot(1, 1, 1)
-        if dynamic_env.meta_info['task'] == 'regression':
+        if dynamic_env.meta_info["task"] == "regression":
             allx, ally = allx[:, 0].numpy(), ally[:, 0].numpy()
-            plot_scatter(cur_ax, allx, ally, "k", 0.99, (15, 1.5), "timestamp={:05d}".format(idx))
+            plot_scatter(
+                cur_ax, allx, ally, "k", 0.99, (15, 1.5), "timestamp={:05d}".format(idx)
+            )
             cur_ax.set_xlim(round(allxs.min().item(), 1), round(allxs.max().item(), 1))
             cur_ax.set_ylim(round(allys.min().item(), 1), round(allys.max().item(), 1))
-        elif dynamic_env.meta_info['task'] == 'classification':
+        elif dynamic_env.meta_info["task"] == "classification":
             positive, negative = ally == 1, ally == 0
             # plot_scatter(cur_ax, [1], [1], "k", 0.1, 1, "timestamp={:05d}".format(idx))
-            plot_scatter(cur_ax, allx[positive,0], allx[positive,1], "r", 0.99, (20, 10), "positive")
-            plot_scatter(cur_ax, allx[negative,0], allx[negative,1], "g", 0.99, (20, 10), "negative")
-            cur_ax.set_xlim(round(allxs[:,0].min().item(), 1), round(allxs[:,0].max().item(), 1))
-            cur_ax.set_ylim(round(allxs[:,1].min().item(), 1), round(allxs[:,1].max().item(), 1))
+            plot_scatter(
+                cur_ax,
+                allx[positive, 0],
+                allx[positive, 1],
+                "r",
+                0.99,
+                (20, 10),
+                "positive",
+            )
+            plot_scatter(
+                cur_ax,
+                allx[negative, 0],
+                allx[negative, 1],
+                "g",
+                0.99,
+                (20, 10),
+                "negative",
+            )
+            cur_ax.set_xlim(
+                round(allxs[:, 0].min().item(), 1), round(allxs[:, 0].max().item(), 1)
+            )
+            cur_ax.set_ylim(
+                round(allxs[:, 1].min().item(), 1), round(allxs[:, 1].max().item(), 1)
+            )
         else:
-            raise ValueError("Unknown task".format(dynamic_env.meta_info['task']))
+            raise ValueError("Unknown task {:}".format(dynamic_env.meta_info["task"]))

         cur_ax.set_xlabel("X", fontsize=LabelSize)
         cur_ax.set_ylabel("Y", rotation=0, fontsize=LabelSize)
         for tick in cur_ax.xaxis.get_major_ticks():
-                tick.label.set_fontsize(LabelSize - font_gap)
-                tick.label.set_rotation(10)
+            tick.label.set_fontsize(LabelSize - font_gap)
+            tick.label.set_rotation(10)
         for tick in cur_ax.yaxis.get_major_ticks():
-                tick.label.set_fontsize(LabelSize - font_gap)
+            tick.label.set_fontsize(LabelSize - font_gap)
         cur_ax.legend(loc=1, fontsize=LegendFontsize)
         pdf_save_path = (
             save_dir
@@ -98,21 +98,53 @@ class ComposeMetric(Metric):
 class MSEMetric(Metric):
     """The metric for mse."""

+    def __init__(self, ignore_batch):
+        super(MSEMetric, self).__init__()
+        self._ignore_batch = ignore_batch
+
     def reset(self):
         self._mse = AverageMeter()

     def __call__(self, predictions, targets):
         if isinstance(predictions, torch.Tensor) and isinstance(targets, torch.Tensor):
-            batch = predictions.shape[0]
-            loss = torch.nn.functional.mse_loss(predictions.data, targets.data)
-            loss = loss.item()
-            self._mse.update(loss, batch)
+            loss = torch.nn.functional.mse_loss(predictions.data, targets.data).item()
+            if self._ignore_batch:
+                self._mse.update(loss, 1)
+            else:
+                self._mse.update(loss, predictions.shape[0])
             return loss
         else:
             raise NotImplementedError

     def get_info(self):
-        return {"mse": self._mse.avg}
+        return {"mse": self._mse.avg, "score": self._mse.avg}


+class Top1AccMetric(Metric):
+    """The metric for the top-1 accuracy."""
+
+    def __init__(self, ignore_batch):
+        super(Top1AccMetric, self).__init__()
+        self._ignore_batch = ignore_batch
+
+    def reset(self):
+        self._accuracy = AverageMeter()
+
+    def __call__(self, predictions, targets):
+        if isinstance(predictions, torch.Tensor) and isinstance(targets, torch.Tensor):
+            max_prob_indexes = torch.argmax(predictions, dim=-1)
+            corrects = torch.eq(max_prob_indexes, targets)
+            accuracy = corrects.float().mean().item()  # a Python float, like MSEMetric's loss
+            if self._ignore_batch:
+                self._accuracy.update(accuracy, 1)
+            else:  # [TODO] for 3-d tensor
+                self._accuracy.update(accuracy, predictions.shape[0])
+            return accuracy
+        else:
+            raise NotImplementedError
+
+    def get_info(self):
+        return {"accuracy": self._accuracy.avg, "score": self._accuracy.avg * 100}
+
+
 class SaveMetric(Metric):
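For reference, a short, hypothetical usage sketch of the new Top1AccMetric (not part of the commit; it assumes xautodl is importable and uses random data purely for illustration). With ignore_batch=False the running average is weighted by batch size, while main() above passes True so that every timestamp contributes equally; note that get_info() reports "accuracy" in [0, 1] and "score" in percent.

import torch
from xautodl.procedures.metric_utils import Top1AccMetric

metric = Top1AccMetric(False)  # weight each update by its batch size
metric.reset()
for _ in range(5):
    logits = torch.randn(32, 10)  # 32 illustrative samples, 10 classes
    targets = torch.randint(0, 10, (32,))  # illustrative ground-truth labels
    metric(logits, targets)  # accumulates the batch top-1 accuracy
info = metric.get_info()
print("accuracy = {:.4f}, score = {:.2f}".format(info["accuracy"], info["score"]))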