Update GeMOSA v4
parent 16861f0f3d
commit 08337138f1
@@ -5,6 +5,7 @@
 # python exps/GeMOSA/main.py --env_version v1 --lr 0.002 --hidden_dim 16 --meta_batch 256 --device cuda
 # python exps/GeMOSA/main.py --env_version v2 --lr 0.002 --hidden_dim 16 --meta_batch 256 --device cuda
 # python exps/GeMOSA/main.py --env_version v3 --lr 0.002 --hidden_dim 32 --time_dim 32 --meta_batch 256 --device cuda
+# python exps/GeMOSA/main.py --env_version v4 --lr 0.002 --hidden_dim 32 --time_dim 32 --meta_batch 256 --device cuda
 #####################################################
 import sys, time, copy, torch, random, argparse
 from tqdm import tqdm
@@ -32,15 +33,24 @@ from xautodl.procedures.advanced_main import basic_train_fn, basic_eval_fn
 from xautodl.procedures.metric_utils import SaveMetric, MSEMetric, ComposeMetric
 from xautodl.datasets.synthetic_core import get_synthetic_env
 from xautodl.models.xcore import get_model
-from xautodl.xlayers import super_core, trunc_normal_
+from xautodl.procedures.metric_utils import MSEMetric, Top1AccMetric

 from meta_model import MetaModelV1


 def online_evaluate(
-    env, meta_model, base_model, criterion, args, logger, save=False, easy_adapt=False
+    env,
+    meta_model,
+    base_model,
+    criterion,
+    metric,
+    args,
+    logger,
+    save=False,
+    easy_adapt=False,
 ):
     logger.log("Online evaluate: {:}".format(env))
+    metric.reset()
     loss_meter = AverageMeter()
     w_containers = dict()
     for idx, (future_time, (future_x, future_y)) in enumerate(env):
@@ -57,6 +67,8 @@ def online_evaluate(
             future_y_hat = base_model.forward_with_container(future_x, future_container)
             future_loss = criterion(future_y_hat, future_y)
             loss_meter.update(future_loss.item())
+            # accumulate the metric scores
+            metric(future_y_hat, future_y)
         if easy_adapt:
             meta_model.easy_adapt(future_time.item(), future_time_embed)
             refine, post_refine_loss = False, -1
@@ -79,7 +91,7 @@ def online_evaluate(
         )
     meta_model.clear_fixed()
     meta_model.clear_learnt()
-    return w_containers, loss_meter
+    return w_containers, loss_meter.avg, metric.get_info()["score"]


 def meta_train_procedure(base_model, meta_model, criterion, xenv, args, logger):
@@ -203,7 +215,16 @@ def main(args):

     base_model = get_model(**model_kwargs)
     base_model = base_model.to(args.device)
-    criterion = torch.nn.MSELoss()
+    if all_env.meta_info["task"] == "regression":
+        criterion = torch.nn.MSELoss()
+        metric = MSEMetric(True)
+    elif all_env.meta_info["task"] == "classification":
+        criterion = torch.nn.CrossEntropyLoss()
+        metric = Top1AccMetric(True)
+    else:
+        raise ValueError(
+            "This task ({:}) is not supported.".format(all_env.meta_info["task"])
+        )

     shape_container = base_model.get_w_container().to_shape_container()

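Note (illustrative, not part of this commit): the hunk above keys the loss and the online metric off all_env.meta_info["task"]. A minimal standalone sketch of the same dispatch, using a hypothetical helper name build_objective and the metric classes from xautodl.procedures.metric_utils:

    # Hypothetical helper mirroring the dispatch above (build_objective is not in the repo).
    import torch
    from xautodl.procedures.metric_utils import MSEMetric, Top1AccMetric

    def build_objective(task):
        # Map the environment's task type to a loss function and an online metric.
        if task == "regression":
            return torch.nn.MSELoss(), MSEMetric(True)
        elif task == "classification":
            return torch.nn.CrossEntropyLoss(), Top1AccMetric(True)
        raise ValueError("This task ({:}) is not supported.".format(task))

    criterion, metric = build_objective("classification")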
@@ -235,27 +256,29 @@ def main(args):
     )
     logger.log("In this enviornment, the total loss-meter is {:}".format(loss_meter))
     """
-    _, test_loss_meter_adapt_v1 = online_evaluate(
-        valid_env, meta_model, base_model, criterion, args, logger, False, False
+    _, loss_adapt_v1, metric_adapt_v1 = online_evaluate(
+        valid_env, meta_model, base_model, criterion, metric, args, logger, False, False
     )
-    _, test_loss_meter_adapt_v2 = online_evaluate(
-        valid_env, meta_model, base_model, criterion, args, logger, False, True
+    _, loss_adapt_v2, metric_adapt_v2 = online_evaluate(
+        valid_env, meta_model, base_model, criterion, metric, args, logger, False, True
     )
     logger.log(
-        "In the online test enviornment, the total loss for refine-adapt is {:}".format(
-            test_loss_meter_adapt_v1
+        "[Refine-Adapt] loss = {:.6f}, metric = {:.6f}".format(
+            loss_adapt_v1, metric_adapt_v1
         )
     )
     logger.log(
-        "In the online test enviornment, the total loss for easy-adapt is {:}".format(
-            test_loss_meter_adapt_v2
+        "[Easy-Adapt] loss = {:.6f}, metric = {:.6f}".format(
+            loss_adapt_v2, metric_adapt_v2
         )
     )

     save_checkpoint(
         {
-            "test_loss_adapt_v1": test_loss_meter_adapt_v1.avg,
-            "test_loss_adapt_v2": test_loss_meter_adapt_v2.avg,
+            "test_loss_adapt_v1": loss_adapt_v1,
+            "test_loss_adapt_v2": loss_adapt_v2,
+            "test_metric_adapt_v1": metric_adapt_v1,
+            "test_metric_adapt_v2": metric_adapt_v2,
         },
         logger.path(None) / "final-ckp-{:}.pth".format(args.rand_seed),
         logger,
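Note (illustrative, not part of this commit): after these hunks, online_evaluate takes the metric as its fifth positional argument and returns three values — the weight containers, the average loss, and the metric's "score". A rough call-site sketch, where evaluate_easy_adapt is a hypothetical wrapper and online_evaluate plus the objects passed in are assumed to exist as in exps/GeMOSA/main.py:

    import torch
    from xautodl.procedures.metric_utils import MSEMetric

    def evaluate_easy_adapt(valid_env, meta_model, base_model, args, logger):
        # Sketch of a caller; online_evaluate is the function patched above
        # and the arguments are assumed to be built elsewhere in the script.
        criterion = torch.nn.MSELoss()
        metric = MSEMetric(True)
        w_containers, avg_loss, score = online_evaluate(
            valid_env, meta_model, base_model, criterion, metric, args, logger, False, True
        )
        logger.log("[Easy-Adapt] loss = {:.6f}, metric = {:.6f}".format(avg_loss, score))
        return avg_loss, score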
@@ -33,7 +33,9 @@ from xautodl.procedures.metric_utils import MSEMetric

 def plot_scatter(cur_ax, xs, ys, color, alpha, linewidths, label=None):
     cur_ax.scatter([-100], [-100], color=color, linewidths=linewidths[0], label=label)
-    cur_ax.scatter(xs, ys, color=color, alpha=alpha, linewidths=linewidths[1], label=None)
+    cur_ax.scatter(
+        xs, ys, color=color, alpha=alpha, linewidths=linewidths[1], label=None
+    )


 def draw_multi_fig(save_dir, timestamp, scatter_list, wh, fig_title=None):
@@ -193,16 +195,28 @@ def visualize_env(save_dir, version):
     for idx, (timestamp, (allx, ally)) in enumerate(tqdm(dynamic_env, ncols=50)):
         allxs.append(allx)
         allys.append(ally)
-    if dynamic_env.meta_info['task'] == 'regression':
+    if dynamic_env.meta_info["task"] == "regression":
         allxs, allys = torch.cat(allxs).view(-1), torch.cat(allys).view(-1)
-        print("x - min={:.3f}, max={:.3f}".format(allxs.min().item(), allxs.max().item()))
-        print("y - min={:.3f}, max={:.3f}".format(allys.min().item(), allys.max().item()))
-    elif dynamic_env.meta_info['task'] == 'classification':
+        print(
+            "x - min={:.3f}, max={:.3f}".format(allxs.min().item(), allxs.max().item())
+        )
+        print(
+            "y - min={:.3f}, max={:.3f}".format(allys.min().item(), allys.max().item())
+        )
+    elif dynamic_env.meta_info["task"] == "classification":
         allxs = torch.cat(allxs)
-        print("x[0] - min={:.3f}, max={:.3f}".format(allxs[:,0].min().item(), allxs[:,0].max().item()))
-        print("x[1] - min={:.3f}, max={:.3f}".format(allxs[:,1].min().item(), allxs[:,1].max().item()))
+        print(
+            "x[0] - min={:.3f}, max={:.3f}".format(
+                allxs[:, 0].min().item(), allxs[:, 0].max().item()
+            )
+        )
+        print(
+            "x[1] - min={:.3f}, max={:.3f}".format(
+                allxs[:, 1].min().item(), allxs[:, 1].max().item()
+            )
+        )
     else:
-        raise ValueError("Unknown task".format(dynamic_env.meta_info['task']))
+        raise ValueError("Unknown task".format(dynamic_env.meta_info["task"]))

     for idx, (timestamp, (allx, ally)) in enumerate(tqdm(dynamic_env, ncols=50)):
         dpi, width, height = 30, 1800, 1400
@@ -211,29 +225,51 @@ def visualize_env(save_dir, version):
         fig = plt.figure(figsize=figsize)

         cur_ax = fig.add_subplot(1, 1, 1)
-        if dynamic_env.meta_info['task'] == 'regression':
+        if dynamic_env.meta_info["task"] == "regression":
             allx, ally = allx[:, 0].numpy(), ally[:, 0].numpy()
-            plot_scatter(cur_ax, allx, ally, "k", 0.99, (15, 1.5), "timestamp={:05d}".format(idx))
+            plot_scatter(
+                cur_ax, allx, ally, "k", 0.99, (15, 1.5), "timestamp={:05d}".format(idx)
+            )
             cur_ax.set_xlim(round(allxs.min().item(), 1), round(allxs.max().item(), 1))
             cur_ax.set_ylim(round(allys.min().item(), 1), round(allys.max().item(), 1))
-        elif dynamic_env.meta_info['task'] == 'classification':
+        elif dynamic_env.meta_info["task"] == "classification":
             positive, negative = ally == 1, ally == 0
             # plot_scatter(cur_ax, [1], [1], "k", 0.1, 1, "timestamp={:05d}".format(idx))
-            plot_scatter(cur_ax, allx[positive,0], allx[positive,1], "r", 0.99, (20, 10), "positive")
-            plot_scatter(cur_ax, allx[negative,0], allx[negative,1], "g", 0.99, (20, 10), "negative")
-            cur_ax.set_xlim(round(allxs[:,0].min().item(), 1), round(allxs[:,0].max().item(), 1))
-            cur_ax.set_ylim(round(allxs[:,1].min().item(), 1), round(allxs[:,1].max().item(), 1))
+            plot_scatter(
+                cur_ax,
+                allx[positive, 0],
+                allx[positive, 1],
+                "r",
+                0.99,
+                (20, 10),
+                "positive",
+            )
+            plot_scatter(
+                cur_ax,
+                allx[negative, 0],
+                allx[negative, 1],
+                "g",
+                0.99,
+                (20, 10),
+                "negative",
+            )
+            cur_ax.set_xlim(
+                round(allxs[:, 0].min().item(), 1), round(allxs[:, 0].max().item(), 1)
+            )
+            cur_ax.set_ylim(
+                round(allxs[:, 1].min().item(), 1), round(allxs[:, 1].max().item(), 1)
+            )
         else:
-            raise ValueError("Unknown task".format(dynamic_env.meta_info['task']))
+            raise ValueError("Unknown task".format(dynamic_env.meta_info["task"]))

         cur_ax.set_xlabel("X", fontsize=LabelSize)
         cur_ax.set_ylabel("Y", rotation=0, fontsize=LabelSize)
         for tick in cur_ax.xaxis.get_major_ticks():
             tick.label.set_fontsize(LabelSize - font_gap)
             tick.label.set_rotation(10)
         for tick in cur_ax.yaxis.get_major_ticks():
             tick.label.set_fontsize(LabelSize - font_gap)
         cur_ax.legend(loc=1, fontsize=LegendFontsize)
         pdf_save_path = (
             save_dir
             / "pdf-{:}".format(version)
@@ -98,21 +98,53 @@ class ComposeMetric(Metric):
 class MSEMetric(Metric):
     """The metric for mse."""

+    def __init__(self, ignore_batch):
+        super(MSEMetric, self).__init__()
+        self._ignore_batch = ignore_batch
+
     def reset(self):
         self._mse = AverageMeter()

     def __call__(self, predictions, targets):
         if isinstance(predictions, torch.Tensor) and isinstance(targets, torch.Tensor):
-            batch = predictions.shape[0]
-            loss = torch.nn.functional.mse_loss(predictions.data, targets.data)
-            loss = loss.item()
-            self._mse.update(loss, batch)
+            loss = torch.nn.functional.mse_loss(predictions.data, targets.data).item()
+            if self._ignore_batch:
+                self._mse.update(loss, 1)
+            else:
+                self._mse.update(loss, predictions.shape[0])
             return loss
         else:
             raise NotImplementedError

     def get_info(self):
-        return {"mse": self._mse.avg}
+        return {"mse": self._mse.avg, "score": self._mse.avg}


+class Top1AccMetric(Metric):
+    """The metric for the top-1 accuracy."""
+
+    def __init__(self, ignore_batch):
+        super(Top1AccMetric, self).__init__()
+        self._ignore_batch = ignore_batch
+
+    def reset(self):
+        self._accuracy = AverageMeter()
+
+    def __call__(self, predictions, targets):
+        if isinstance(predictions, torch.Tensor) and isinstance(targets, torch.Tensor):
+            max_prob_indexes = torch.argmax(predictions, dim=-1)
+            corrects = torch.eq(max_prob_indexes, targets)
+            accuracy = corrects.float().mean().float()
+            if self._ignore_batch:
+                self._accuracy.update(accuracy, 1)
+            else:  # [TODO] for 3-d tensor
+                self._accuracy.update(accuracy, predictions.shape[0])
+            return accuracy
+        else:
+            raise NotImplementedError
+
+    def get_info(self):
+        return {"accuracy": self._accuracy.avg, "score": self._accuracy.avg * 100}
+
+
 class SaveMetric(Metric):
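Note (illustrative, not part of this commit): a small sanity check of the new Top1AccMetric defined above; the tensors are made up for the example.

    import torch
    from xautodl.procedures.metric_utils import Top1AccMetric

    metric = Top1AccMetric(True)  # ignore_batch=True: each update is weighted by 1
    metric.reset()
    logits = torch.tensor([[0.2, 0.8], [0.9, 0.1], [0.4, 0.6]])  # toy predictions
    labels = torch.tensor([1, 0, 0])                             # toy targets
    metric(logits, labels)  # argmax over the last dim gives [1, 0, 1]; 2 of 3 are correct
    info = metric.get_info()
    # info["accuracy"] is the running mean accuracy (about 0.667 here);
    # info["score"] is the same value scaled by 100, as consumed by online_evaluate.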