79 lines
2.9 KiB
79 lines
2.9 KiB
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
# python exps/compare.py --checkpoints basic.pth order.pth --names basic order --save ./output/vis/basic-vs-order.pdf
import sys, time, torch, random, argparse
from PIL import ImageFile
from copy import deepcopy
from pathlib import Path
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
# Resolve the sibling `lib` directory (one level above this script's folder)
# and put it at the front of the import search path if it is not there yet.
lib_dir = Path(__file__).parent.joinpath('..', 'lib').resolve()
if str(lib_dir) not in sys.path:
    sys.path.insert(0, str(lib_dir))
# Command-line interface: any number of checkpoint paths, one display name
# per checkpoint, and the path the comparison figure is written to.
parser = argparse.ArgumentParser(
    description='Visualize the checkpoint and compare',
    formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
    '--checkpoints',
    nargs='+',
    type=str,
    help='checkpoint paths.',
)
parser.add_argument(
    '--names',
    nargs='+',
    type=str,
    help='names.',
)
parser.add_argument(
    '--save',
    type=str,
    help='the save path.',
)
# Parsed at module level; main() reads the resulting global `args`.
args = parser.parse_args()
def visualize_acc(epochs, accuracies, names, save_path):
    """Plot one accuracy-vs-epoch curve per name and optionally save the figure.

    Args:
        epochs: non-empty sequence of epoch indices. They are used as the x
            values and as the keys looked up in every entry of ``accuracies``.
        accuracies: sequence aligned with ``names``; ``accuracies[i][e]`` is
            the value plotted for ``names[i]`` at epoch ``e``. The y axis is
            fixed to 0..100 and labelled as a percentage.
        names: legend labels, one per curve.
        save_path: where to write the figure, or ``None`` to skip saving.
            The output is always written in PDF format regardless of suffix.

    Raises:
        ValueError: if ``epochs`` is empty (there is no x range to draw).
    """
    if not epochs:
        # max(epochs) below would fail with an opaque message otherwise.
        raise ValueError('visualize_acc requires at least one epoch to plot')

    label_size = 24
    legend_fontsize = 22
    # NOTE: rcParams are process-global, so this affects later figures too.
    matplotlib.rcParams['xtick.labelsize'] = label_size
    matplotlib.rcParams['ytick.labelsize'] = label_size
    color_set = ['r', 'b', 'g', 'c', 'm', 'y', 'k']
    dpi = 300
    width, height = 3400, 3600  # target size in pixels at the given dpi
    figsize = width / float(dpi), height / float(dpi)

    fig = plt.figure(figsize=figsize)
    last_epoch = max(epochs)
    plt.xlim(0, last_epoch)
    plt.ylim(0, 100)
    interval_x, interval_y = 20, 10
    plt.xticks(np.arange(0, last_epoch + interval_x, interval_x), fontsize=legend_fontsize)
    plt.yticks(np.arange(0, 100 + interval_y, interval_y), fontsize=legend_fontsize)
    plt.xlabel('epoch', fontsize=16)
    plt.ylabel('accuracy (%)', fontsize=16)

    for idx, tag in enumerate(names):
        xaccs = [accuracies[idx][x] for x in epochs]
        # Cycle through the palette so more curves than colors does not
        # raise IndexError.
        color = color_set[idx % len(color_set)]
        plt.plot(epochs, xaccs, color=color, linestyle='-', label='Test Accuracy : {:}'.format(tag), lw=3)
    if names:
        # Build the legend once, after every curve has been registered.
        plt.legend(loc=4, fontsize=legend_fontsize)

    if save_path is not None:
        fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='pdf')
        print ('---- save figure into {:}.'.format(save_path))
def main():
    """Load the requested checkpoints and plot their validation accuracies.

    Reads ``checkpoints``, ``names`` and ``save`` from the module-level
    ``args``, checks that every checkpoint file exists, creates the output
    directory, collects each checkpoint's ``'valid_accuracies'`` entry and
    hands the curves to ``visualize_acc``.

    Raises:
        AssertionError: if the number of names differs from the number of
            checkpoints, or if a checkpoint path does not exist.
    """
    checkpoints, names = args.checkpoints, args.names
    assert len(checkpoints) == len(names), 'invalid length : {:} vs {:}'.format(len(checkpoints), len(names))
    for i, checkpoint in enumerate(checkpoints):
        # The message has two placeholders; supplying only the path made the
        # failure path raise IndexError instead of this assertion message.
        assert Path(checkpoint).exists(), 'The {:}-th checkpoint : {:} does not exist'.format(i, checkpoint)

    save_path = Path(args.save)
    save_dir = save_path.parent
    save_dir.mkdir(parents=True, exist_ok=True)

    accuracies = []
    for checkpoint_path in checkpoints:
        # NOTE: torch.load unpickles the file; only open trusted checkpoints.
        # map_location='cpu' lets GPU-saved checkpoints be read on machines
        # without CUDA; only the accuracy record is needed here.
        state = torch.load(checkpoint_path, map_location='cpu')
        accuracies.append(state['valid_accuracies'])

    # The x axis comes from the first checkpoint's integer keys; the other
    # checkpoints are indexed with those same epochs in visualize_acc.
    epochs = sorted(x for x in accuracies[0].keys() if isinstance(x, int))
    visualize_acc(epochs, accuracies, names, save_path)
if __name__ == '__main__':