diff --git a/exps/GeMOSA/main.py b/exps/GeMOSA/main.py
index e489e0d..52258ff 100644
--- a/exps/GeMOSA/main.py
+++ b/exps/GeMOSA/main.py
@@ -5,6 +5,7 @@
 # python exps/GeMOSA/main.py --env_version v1 --lr 0.002 --hidden_dim 16 --meta_batch 256 --device cuda
 # python exps/GeMOSA/main.py --env_version v2 --lr 0.002 --hidden_dim 16 --meta_batch 256 --device cuda
 # python exps/GeMOSA/main.py --env_version v3 --lr 0.002 --hidden_dim 32 --time_dim 32 --meta_batch 256 --device cuda
+# python exps/GeMOSA/main.py --env_version v4 --lr 0.002 --hidden_dim 32 --time_dim 32 --meta_batch 256 --device cuda
 #####################################################
 import sys, time, copy, torch, random, argparse
 from tqdm import tqdm
@@ -32,15 +33,24 @@ from xautodl.procedures.advanced_main import basic_train_fn, basic_eval_fn
 from xautodl.procedures.metric_utils import SaveMetric, MSEMetric, ComposeMetric
 from xautodl.datasets.synthetic_core import get_synthetic_env
 from xautodl.models.xcore import get_model
-from xautodl.xlayers import super_core, trunc_normal_
+from xautodl.procedures.metric_utils import MSEMetric, Top1AccMetric
 
 from meta_model import MetaModelV1
 
 
 def online_evaluate(
-    env, meta_model, base_model, criterion, args, logger, save=False, easy_adapt=False
+    env,
+    meta_model,
+    base_model,
+    criterion,
+    metric,
+    args,
+    logger,
+    save=False,
+    easy_adapt=False,
 ):
     logger.log("Online evaluate: {:}".format(env))
+    metric.reset()
     loss_meter = AverageMeter()
     w_containers = dict()
     for idx, (future_time, (future_x, future_y)) in enumerate(env):
@@ -57,6 +67,8 @@ def online_evaluate(
         future_y_hat = base_model.forward_with_container(future_x, future_container)
         future_loss = criterion(future_y_hat, future_y)
         loss_meter.update(future_loss.item())
+        # accumulate the metric scores
+        metric(future_y_hat, future_y)
         if easy_adapt:
             meta_model.easy_adapt(future_time.item(), future_time_embed)
             refine, post_refine_loss = False, -1
@@ -79,7 +91,7 @@ def online_evaluate(
             )
     meta_model.clear_fixed()
     meta_model.clear_learnt()
-    return w_containers, loss_meter
+    return w_containers, loss_meter.avg, metric.get_info()["score"]
 
 
 def meta_train_procedure(base_model, meta_model, criterion, xenv, args, logger):
@@ -203,7 +215,16 @@ def main(args):
 
     base_model = get_model(**model_kwargs)
     base_model = base_model.to(args.device)
-    criterion = torch.nn.MSELoss()
+    if all_env.meta_info["task"] == "regression":
+        criterion = torch.nn.MSELoss()
+        metric = MSEMetric(True)
+    elif all_env.meta_info["task"] == "classification":
+        criterion = torch.nn.CrossEntropyLoss()
+        metric = Top1AccMetric(True)
+    else:
+        raise ValueError(
+            "This task ({:}) is not supported.".format(all_env.meta_info["task"])
+        )
 
     shape_container = base_model.get_w_container().to_shape_container()
 
@@ -235,27 +256,29 @@ def main(args):
     )
     logger.log("In this enviornment, the total loss-meter is {:}".format(loss_meter))
     """
-    _, test_loss_meter_adapt_v1 = online_evaluate(
-        valid_env, meta_model, base_model, criterion, args, logger, False, False
+    _, loss_adapt_v1, metric_adapt_v1 = online_evaluate(
+        valid_env, meta_model, base_model, criterion, metric, args, logger, False, False
     )
-    _, test_loss_meter_adapt_v2 = online_evaluate(
-        valid_env, meta_model, base_model, criterion, args, logger, False, True
+    _, loss_adapt_v2, metric_adapt_v2 = online_evaluate(
+        valid_env, meta_model, base_model, criterion, metric, args, logger, False, True
     )
     logger.log(
-        "In the online test enviornment, the total loss for refine-adapt is {:}".format(
-            test_loss_meter_adapt_v1
+        "[Refine-Adapt] loss = {:.6f}, metric = {:.6f}".format(
+            loss_adapt_v1, metric_adapt_v1
         )
     )
     logger.log(
-        "In the online test enviornment, the total loss for easy-adapt is {:}".format(
-            test_loss_meter_adapt_v2
+        "[Easy-Adapt] loss = {:.6f}, metric = {:.6f}".format(
+            loss_adapt_v2, metric_adapt_v2
         )
     )
 
     save_checkpoint(
         {
-            "test_loss_adapt_v1": test_loss_meter_adapt_v1.avg,
-            "test_loss_adapt_v2": test_loss_meter_adapt_v2.avg,
+            "test_loss_adapt_v1": loss_adapt_v1,
+            "test_loss_adapt_v2": loss_adapt_v2,
+            "test_metric_adapt_v1": metric_adapt_v1,
+            "test_metric_adapt_v2": metric_adapt_v2,
         },
         logger.path(None) / "final-ckp-{:}.pth".format(args.rand_seed),
         logger,
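The main.py hunks above pick the loss and metric from all_env.meta_info["task"] and thread the metric through online_evaluate, which now returns (w_containers, average_loss, score). Below is a minimal sketch of that selection in isolation, assuming the xautodl package is importable; build_criterion_and_metric and the literal meta_info dict are illustrative stand-ins, not part of the patch.

import torch

from xautodl.procedures.metric_utils import MSEMetric, Top1AccMetric


def build_criterion_and_metric(meta_info):
    # Same branching as main(): MSE for regression, cross-entropy + top-1 accuracy for classification.
    if meta_info["task"] == "regression":
        return torch.nn.MSELoss(), MSEMetric(True)
    elif meta_info["task"] == "classification":
        return torch.nn.CrossEntropyLoss(), Top1AccMetric(True)
    else:
        raise ValueError("This task ({:}) is not supported.".format(meta_info["task"]))


criterion, metric = build_criterion_and_metric({"task": "classification"})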
"[Refine-Adapt] loss = {:.6f}, metric = {:.6f}".format( + loss_adapt_v1, metric_adapt_v1 ) ) logger.log( - "In the online test enviornment, the total loss for easy-adapt is {:}".format( - test_loss_meter_adapt_v2 + "[Easy-Adapt] loss = {:.6f}, metric = {:.6f}".format( + loss_adapt_v2, metric_adapt_v2 ) ) save_checkpoint( { - "test_loss_adapt_v1": test_loss_meter_adapt_v1.avg, - "test_loss_adapt_v2": test_loss_meter_adapt_v2.avg, + "test_loss_adapt_v1": loss_adapt_v1, + "test_loss_adapt_v2": loss_adapt_v2, + "test_metric_adapt_v1": metric_adapt_v1, + "test_metric_adapt_v2": metric_adapt_v2, }, logger.path(None) / "final-ckp-{:}.pth".format(args.rand_seed), logger, diff --git a/exps/GeMOSA/vis-synthetic.py b/exps/GeMOSA/vis-synthetic.py index 5148a5a..a6f61fe 100644 --- a/exps/GeMOSA/vis-synthetic.py +++ b/exps/GeMOSA/vis-synthetic.py @@ -33,7 +33,9 @@ from xautodl.procedures.metric_utils import MSEMetric def plot_scatter(cur_ax, xs, ys, color, alpha, linewidths, label=None): cur_ax.scatter([-100], [-100], color=color, linewidths=linewidths[0], label=label) - cur_ax.scatter(xs, ys, color=color, alpha=alpha, linewidths=linewidths[1], label=None) + cur_ax.scatter( + xs, ys, color=color, alpha=alpha, linewidths=linewidths[1], label=None + ) def draw_multi_fig(save_dir, timestamp, scatter_list, wh, fig_title=None): @@ -193,16 +195,28 @@ def visualize_env(save_dir, version): for idx, (timestamp, (allx, ally)) in enumerate(tqdm(dynamic_env, ncols=50)): allxs.append(allx) allys.append(ally) - if dynamic_env.meta_info['task'] == 'regression': + if dynamic_env.meta_info["task"] == "regression": allxs, allys = torch.cat(allxs).view(-1), torch.cat(allys).view(-1) - print("x - min={:.3f}, max={:.3f}".format(allxs.min().item(), allxs.max().item())) - print("y - min={:.3f}, max={:.3f}".format(allys.min().item(), allys.max().item())) - elif dynamic_env.meta_info['task'] == 'classification': + print( + "x - min={:.3f}, max={:.3f}".format(allxs.min().item(), allxs.max().item()) + ) + print( + "y - min={:.3f}, max={:.3f}".format(allys.min().item(), allys.max().item()) + ) + elif dynamic_env.meta_info["task"] == "classification": allxs = torch.cat(allxs) - print("x[0] - min={:.3f}, max={:.3f}".format(allxs[:,0].min().item(), allxs[:,0].max().item())) - print("x[1] - min={:.3f}, max={:.3f}".format(allxs[:,1].min().item(), allxs[:,1].max().item())) + print( + "x[0] - min={:.3f}, max={:.3f}".format( + allxs[:, 0].min().item(), allxs[:, 0].max().item() + ) + ) + print( + "x[1] - min={:.3f}, max={:.3f}".format( + allxs[:, 1].min().item(), allxs[:, 1].max().item() + ) + ) else: - raise ValueError("Unknown task".format(dynamic_env.meta_info['task'])) + raise ValueError("Unknown task".format(dynamic_env.meta_info["task"])) for idx, (timestamp, (allx, ally)) in enumerate(tqdm(dynamic_env, ncols=50)): dpi, width, height = 30, 1800, 1400 @@ -211,29 +225,51 @@ def visualize_env(save_dir, version): fig = plt.figure(figsize=figsize) cur_ax = fig.add_subplot(1, 1, 1) - if dynamic_env.meta_info['task'] == 'regression': + if dynamic_env.meta_info["task"] == "regression": allx, ally = allx[:, 0].numpy(), ally[:, 0].numpy() - plot_scatter(cur_ax, allx, ally, "k", 0.99, (15, 1.5), "timestamp={:05d}".format(idx)) + plot_scatter( + cur_ax, allx, ally, "k", 0.99, (15, 1.5), "timestamp={:05d}".format(idx) + ) cur_ax.set_xlim(round(allxs.min().item(), 1), round(allxs.max().item(), 1)) cur_ax.set_ylim(round(allys.min().item(), 1), round(allys.max().item(), 1)) - elif dynamic_env.meta_info['task'] == 'classification': + elif 
dynamic_env.meta_info["task"] == "classification": positive, negative = ally == 1, ally == 0 # plot_scatter(cur_ax, [1], [1], "k", 0.1, 1, "timestamp={:05d}".format(idx)) - plot_scatter(cur_ax, allx[positive,0], allx[positive,1], "r", 0.99, (20, 10), "positive") - plot_scatter(cur_ax, allx[negative,0], allx[negative,1], "g", 0.99, (20, 10), "negative") - cur_ax.set_xlim(round(allxs[:,0].min().item(), 1), round(allxs[:,0].max().item(), 1)) - cur_ax.set_ylim(round(allxs[:,1].min().item(), 1), round(allxs[:,1].max().item(), 1)) + plot_scatter( + cur_ax, + allx[positive, 0], + allx[positive, 1], + "r", + 0.99, + (20, 10), + "positive", + ) + plot_scatter( + cur_ax, + allx[negative, 0], + allx[negative, 1], + "g", + 0.99, + (20, 10), + "negative", + ) + cur_ax.set_xlim( + round(allxs[:, 0].min().item(), 1), round(allxs[:, 0].max().item(), 1) + ) + cur_ax.set_ylim( + round(allxs[:, 1].min().item(), 1), round(allxs[:, 1].max().item(), 1) + ) else: - raise ValueError("Unknown task".format(dynamic_env.meta_info['task'])) + raise ValueError("Unknown task".format(dynamic_env.meta_info["task"])) cur_ax.set_xlabel("X", fontsize=LabelSize) cur_ax.set_ylabel("Y", rotation=0, fontsize=LabelSize) for tick in cur_ax.xaxis.get_major_ticks(): - tick.label.set_fontsize(LabelSize - font_gap) - tick.label.set_rotation(10) + tick.label.set_fontsize(LabelSize - font_gap) + tick.label.set_rotation(10) for tick in cur_ax.yaxis.get_major_ticks(): - tick.label.set_fontsize(LabelSize - font_gap) - cur_ax.legend(loc=1, fontsize=LegendFontsize) + tick.label.set_fontsize(LabelSize - font_gap) + cur_ax.legend(loc=1, fontsize=LegendFontsize) pdf_save_path = ( save_dir / "pdf-{:}".format(version) diff --git a/xautodl/procedures/metric_utils.py b/xautodl/procedures/metric_utils.py index f88c587..28b2cff 100644 --- a/xautodl/procedures/metric_utils.py +++ b/xautodl/procedures/metric_utils.py @@ -98,21 +98,53 @@ class ComposeMetric(Metric): class MSEMetric(Metric): """The metric for mse.""" + def __init__(self, ignore_batch): + super(MSEMetric, self).__init__() + self._ignore_batch = ignore_batch + def reset(self): self._mse = AverageMeter() def __call__(self, predictions, targets): if isinstance(predictions, torch.Tensor) and isinstance(targets, torch.Tensor): - batch = predictions.shape[0] - loss = torch.nn.functional.mse_loss(predictions.data, targets.data) - loss = loss.item() - self._mse.update(loss, batch) + loss = torch.nn.functional.mse_loss(predictions.data, targets.data).item() + if self._ignore_batch: + self._mse.update(loss, 1) + else: + self._mse.update(loss, predictions.shape[0]) return loss else: raise NotImplementedError def get_info(self): - return {"mse": self._mse.avg} + return {"mse": self._mse.avg, "score": self._mse.avg} + + +class Top1AccMetric(Metric): + """The metric for the top-1 accuracy.""" + + def __init__(self, ignore_batch): + super(Top1AccMetric, self).__init__() + self._ignore_batch = ignore_batch + + def reset(self): + self._accuracy = AverageMeter() + + def __call__(self, predictions, targets): + if isinstance(predictions, torch.Tensor) and isinstance(targets, torch.Tensor): + max_prob_indexes = torch.argmax(predictions, dim=-1) + corrects = torch.eq(max_prob_indexes, targets) + accuracy = corrects.float().mean().float() + if self._ignore_batch: + self._accuracy.update(accuracy, 1) + else: # [TODO] for 3-d tensor + self._accuracy.update(accuracy, predictions.shape[0]) + return accuracy + else: + raise NotImplementedError + + def get_info(self): + return {"accuracy": self._accuracy.avg, 
"score": self._accuracy.avg * 100} class SaveMetric(Metric):