naswot/pycls/core/plotting.py
Jack Turner b74255e1f3 v2
2021-02-26 16:12:51 +00:00

133 lines
4.6 KiB
Python

#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""Plotting functions."""
import colorlover as cl
import matplotlib.pyplot as plt
import plotly.graph_objs as go
import plotly.offline as offline
import pycls.core.logging as logging
def get_plot_colors(max_colors, color_format="pyplot"):
    """Build a color list for plotting from colorlover's "Paired" scale.

    Args:
        max_colors: Number of colors the caller needs. When this exceeds the
            size of the base scale, the scale is interpolated up to
            ``max_colors`` entries; otherwise the base scale is used unchanged
            (so the result may hold more than ``max_colors`` colors).
        color_format: ``"pyplot"`` returns numeric colors with every channel
            divided by 255; any other value returns the colorlover colors
            as-is.

    Returns:
        A list of colors in the requested format.
    """
    palette = cl.scales["11"]["qual"]["Paired"]
    if len(palette) < max_colors:
        # Not enough base colors: interpolate and convert back to rgb form.
        palette = cl.to_rgb(cl.interp(palette, max_colors))
    if color_format != "pyplot":
        return palette
    # pyplot path: scale each numeric channel by 1/255.
    return [[channel / 255.0 for channel in rgb] for rgb in cl.to_numeric(palette)]
def prepare_plot_data(log_files, names, metric="top1_err"):
    """Load logs and extract per-epoch curves for plotting.

    Args:
        log_files: Iterable of log files, each handed to
            ``logging.load_log_data``.
        names: Iterable of display names, paired positionally with
            ``log_files`` via ``zip`` (pairing stops at the shorter one).
        metric: Key of the per-epoch metric to extract.

    Returns:
        A list with one dict per (log file, name) pair. Each dict holds
        ``x_<phase>``, ``y_<phase>`` and ``<phase>_label`` entries for the
        ``"train"`` and ``"test"`` phases.

    Raises:
        AssertionError: If no (log file, name) pair was available.
    """
    plot_data = []
    for log_file, name in zip(log_files, names):
        data = logging.sort_log_data(logging.load_log_data(log_file))
        entry = {}
        for phase in ("train", "test"):
            epoch_data = data[phase + "_epoch"]
            x, y = epoch_data["epoch_ind"], epoch_data[metric]
            entry["x_" + phase], entry["y_" + phase] = x, y
            # Label is prefixed with the minimum metric value (0 if no values).
            entry[phase + "_label"] = "[{:5.2f}] ".format(min(y) if y else 0) + name
        plot_data.append(entry)
    # Raised explicitly (rather than via `assert`) so the check is not
    # stripped when Python runs with -O; the exception type is unchanged.
    if not plot_data:
        raise AssertionError("No data to plot")
    return plot_data
def plot_error_curves_plotly(log_files, names, filename, metric="top1_err"):
    """Plot train/test curves with plotly and write them via ``offline.plot``.

    Args:
        log_files: Log files forwarded to ``prepare_plot_data``.
        names: Display names forwarded to ``prepare_plot_data``.
        filename: Output file name passed to ``plotly.offline.plot``.
        metric: Metric key to plot; also used in the title and y-axis label.
    """
    curves = prepare_plot_data(log_files, names, metric)
    palette = get_plot_colors(len(curves), "plotly")

    # Three traces per curve, in this order: train (shown, not in legend),
    # test (shown, in legend), train again (hidden, in legend).
    traces = []
    for idx, curve in enumerate(curves):
        group = str(idx)
        train_line = {"color": palette[idx], "dash": "dashdot", "width": 1.5}
        test_line = {"color": palette[idx], "dash": "solid", "width": 1.5}
        # (phase, line style, visible, showlegend)
        variants = (
            ("train", train_line, True, False),
            ("test", test_line, True, True),
            ("train", train_line, False, True),
        )
        traces.extend(
            go.Scatter(
                x=curve["x_" + phase],
                y=curve["y_" + phase],
                mode="lines",
                name=curve[phase + "_label"],
                line=line,
                legendgroup=group,
                visible=is_visible,
                showlegend=in_legend,
            )
            for phase, line, is_visible, in_legend in variants
        )

    # Dropdown to toggle 'all', 'train', 'test'; each mask lines up with the
    # per-curve trace triple built above.
    # NOTE(review): masks have 3 entries while there are 3 traces per curve;
    # this presumably relies on plotly applying the mask cyclically - confirm.
    toggles = (
        ("all", [True, True, False]),
        ("train", [False, False, True]),
        ("test", [False, True, False]),
    )
    buttons = [
        {"label": label, "args": [{"visible": mask}], "method": "update"}
        for label, mask in toggles
    ]
    dropdown = {
        "buttons": buttons,
        "direction": "down",
        "showactive": True,
        "x": 1.02,
        "xanchor": "left",
        "y": 1.08,
        "yanchor": "top",
    }
    axis_title_font = {"size": 18, "color": "#7f7f7f"}
    layout = go.Layout(
        title=metric + " vs. epoch<br>[dash=train, solid=test]",
        xaxis={"title": "epoch", "titlefont": axis_title_font},
        yaxis={"title": metric, "titlefont": axis_title_font},
        showlegend=True,
        hoverlabel={"namelength": -1},
        updatemenus=[dropdown],
    )

    # Create plotly plot
    figure = {"data": traces, "layout": layout}
    offline.plot(figure, filename=filename)
def plot_error_curves_pyplot(log_files, names, filename=None, metric="top1_err"):
    """Plot error curves using matplotlib.pyplot, then save or show them.

    Train curves are drawn dashed and test curves solid, sharing one color per
    log; only the test curve carries a legend label.

    Args:
        log_files: Log files forwarded to ``prepare_plot_data``.
        names: Display names forwarded to ``prepare_plot_data``.
        filename: If truthy, the figure is saved there and the current figure
            is cleared; otherwise the figure is shown with ``plt.show()``.
        metric: Metric key to plot; also used in the title and y-axis label.
    """
    plot_data = prepare_plot_data(log_files, names, metric)
    # Size the palette by the curves actually loaded (as the plotly variant
    # does) rather than len(names): names may be longer than log_files or be
    # an iterable without len().
    colors = get_plot_colors(len(plot_data))
    for ind, d in enumerate(plot_data):
        c, lbl = colors[ind], d["test_label"]
        plt.plot(d["x_train"], d["y_train"], "--", c=c, alpha=0.8)
        plt.plot(d["x_test"], d["y_test"], "-", c=c, alpha=0.8, label=lbl)
    plt.title(metric + " vs. epoch\n[dash=train, solid=test]", fontsize=14)
    plt.xlabel("epoch", fontsize=14)
    plt.ylabel(metric, fontsize=14)
    plt.grid(alpha=0.4)
    plt.legend()
    if filename:
        plt.savefig(filename)
        # Clear the current figure so later plots do not draw on top of it.
        plt.clf()
    else:
        plt.show()