79 lines
2.9 KiB
Python
79 lines
2.9 KiB
Python
|
##################################################
|
||
|
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
|
||
|
##################################################
|
||
|
# python exps/compare.py --checkpoints basic.pth order.pth --names basic order --save ./output/vis/basic-vs-order.pdf
|
||
|
import sys, time, torch, random, argparse
|
||
|
from PIL import ImageFile
|
||
|
ImageFile.LOAD_TRUNCATED_IMAGES = True
|
||
|
from copy import deepcopy
|
||
|
from pathlib import Path
|
||
|
import numpy as np
|
||
|
import matplotlib
|
||
|
matplotlib.use('agg')
|
||
|
import matplotlib.pyplot as plt
|
||
|
|
||
|
# Make the project's ../lib directory importable when this script is run directly.
lib_dir = (Path(__file__).parent / '..' / 'lib').resolve()
if str(lib_dir) not in sys.path:
    sys.path.insert(0, str(lib_dir))

# Command-line interface. All three options are required: main() pairs every
# checkpoint with a name and writes the figure to --save, so omitting any of
# them would otherwise fail later with an unhelpful TypeError on None.
parser = argparse.ArgumentParser(description='Visualize the checkpoint and compare', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--checkpoints', type=str, nargs='+', required=True, help='checkpoint paths.')
parser.add_argument('--names', type=str, nargs='+', required=True, help='names.')
parser.add_argument('--save', type=str, required=True, help='the save path.')
args = parser.parse_args()
|
||
|
|
||
|
|
||
|
def visualize_acc(epochs, accuracies, names, save_path):
    """Plot one accuracy-vs-epoch curve per run on a shared figure.

    Args:
        epochs: epoch indices for the x-axis. The axis range and ticks are
            derived from ``max(epochs)``; main() passes them sorted.
        accuracies: one container per run, indexed by epoch
            (``accuracies[i][epoch]``). The y-axis is fixed to 0-100 and
            labelled as a percentage.
        names: run labels parallel to ``accuracies``, used in the legend.
        save_path: destination file, or ``None`` to skip saving. The figure is
            always written in PDF format regardless of the file extension.

    Raises:
        ValueError: if ``epochs`` is empty.

    Note:
        This mutates the global ``matplotlib.rcParams`` tick label sizes.
    """
    if not epochs:
        raise ValueError('visualize_acc requires at least one epoch to plot')

    LabelSize = 24
    LegendFontsize = 22
    matplotlib.rcParams['xtick.labelsize'] = LabelSize
    matplotlib.rcParams['ytick.labelsize'] = LabelSize
    color_set = ['r', 'b', 'g', 'c', 'm', 'y', 'k']
    dpi = 300
    width, height = 3400, 3600
    figsize = width / float(dpi), height / float(dpi)

    fig = plt.figure(figsize=figsize)
    max_epoch = max(epochs)
    plt.xlim(0, max_epoch)
    plt.ylim(0, 100)
    interval_x, interval_y = 20, 10
    plt.xticks(np.arange(0, max_epoch + interval_x, interval_x), fontsize=LegendFontsize)
    plt.yticks(np.arange(0, 100 + interval_y, interval_y), fontsize=LegendFontsize)
    plt.grid()

    plt.xlabel('epoch', fontsize=16)
    plt.ylabel('accuracy (%)', fontsize=16)

    for idx, (tag, run_accs) in enumerate(zip(names, accuracies)):
        # Runs may have recorded different epochs; plot only the points this
        # run actually has instead of crashing on the first missing epoch.
        xepochs, xaccs = [], []
        for x in epochs:
            try:
                acc = run_accs[x]
            except (KeyError, IndexError):
                continue
            xepochs.append(x)
            xaccs.append(acc)
        # Cycle the palette so more runs than colors still plot.
        color = color_set[idx % len(color_set)]
        plt.plot(xepochs, xaccs, color=color, linestyle='-', label='Test Accuracy : {:}'.format(tag), lw=3)
    # Build the legend once, after every curve has been added.
    plt.legend(loc=4, fontsize=LegendFontsize)

    if save_path is not None:
        fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='pdf')
        print('---- save figure into {:}.'.format(save_path))
    # Always release the figure, even when it is not saved, so it is not leaked.
    plt.close(fig)
|
||
|
|
||
|
|
||
|
def main():
    """Load each checkpoint's validation accuracies and plot them together.

    Reads the module-level ``args``: ``--checkpoints`` (files loaded with
    ``torch.load``), ``--names`` (one legend label per checkpoint) and
    ``--save`` (output figure path; its parent directory is created).

    Raises:
        AssertionError: if no checkpoints are given, the number of names
            differs from the number of checkpoints, or a checkpoint path
            does not exist.
        KeyError: if a loaded checkpoint has no ``'valid_accuracies'`` entry.
    """
    # argparse leaves unsupplied nargs='+' options as None; normalize so the
    # checks below produce readable errors instead of failing on len(None).
    checkpoints = args.checkpoints or []
    names = args.names or []
    assert len(checkpoints) == len(names), 'invalid length : {:} vs {:}'.format(len(checkpoints), len(names))
    assert len(checkpoints) > 0, 'at least one checkpoint is required'
    for i, checkpoint in enumerate(checkpoints):
        assert Path(checkpoint).exists(), 'The {:}-th checkpoint : {:} does not exist'.format(i, checkpoint)

    save_path = None
    if args.save is not None:
        save_path = Path(args.save)
        save_path.parent.mkdir(parents=True, exist_ok=True)

    accuracies = []
    for checkpoint_path in checkpoints:
        # NOTE: torch.load unpickles the file; only load checkpoints you trust.
        # map_location='cpu' lets checkpoints saved on a GPU load on CPU-only hosts.
        checkpoint = torch.load(checkpoint_path, map_location='cpu')
        accuracies.append(checkpoint['valid_accuracies'])
    # The x-axis uses the integer keys of the first checkpoint; non-int keys are ignored.
    epochs = sorted(x for x in accuracies[0].keys() if isinstance(x, int))

    visualize_acc(epochs, accuracies, names, save_path)
|
||
|
|
||
|
|
||
|
if __name__ == '__main__':
    # Entry point: run the comparison only when executed as a script.
    main()
|