update README
This commit is contained in:
		| @@ -1,78 +0,0 @@ | ||||
| ################################################## | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 # | ||||
| ################################################## | ||||
| # python exps/compare.py --checkpoints basic.pth order.pth --names basic order --save ./output/vis/basic-vs-order.pdf | ||||
| import sys, time, torch, random, argparse | ||||
| from PIL     import ImageFile | ||||
| ImageFile.LOAD_TRUNCATED_IMAGES = True | ||||
| from copy    import deepcopy | ||||
| from pathlib import Path | ||||
| import numpy as np | ||||
| import matplotlib | ||||
| matplotlib.use('agg') | ||||
| import matplotlib.pyplot as plt | ||||
|  | ||||
| lib_dir = (Path(__file__).parent / '..' / 'lib').resolve() | ||||
| if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) | ||||
|  | ||||
| parser = argparse.ArgumentParser(description='Visualize the checkpoint and compare', formatter_class=argparse.ArgumentDefaultsHelpFormatter) | ||||
| parser.add_argument('--checkpoints', type=str,    nargs='+',     help='checkpoint paths.') | ||||
| parser.add_argument('--names',       type=str,    nargs='+',     help='names.') | ||||
| parser.add_argument('--save',        type=str,                   help='the save path.') | ||||
| args = parser.parse_args() | ||||
|  | ||||
|  | ||||
| def visualize_acc(epochs, accuracies, names, save_path): | ||||
|  | ||||
|   LabelSize = 24 | ||||
|   LegendFontsize = 22 | ||||
|   matplotlib.rcParams['xtick.labelsize'] = LabelSize  | ||||
|   matplotlib.rcParams['ytick.labelsize'] = LabelSize  | ||||
|   color_set = ['r', 'b', 'g', 'c', 'm', 'y', 'k'] | ||||
|   dpi = 300 | ||||
|   width, height = 3400, 3600 | ||||
|   figsize = width / float(dpi), height / float(dpi) | ||||
|  | ||||
|   fig = plt.figure(figsize=figsize) | ||||
|   plt.xlim(0, max(epochs)) | ||||
|   plt.ylim(0, 100) | ||||
|   interval_x, interval_y = 20, 10 | ||||
|   plt.xticks(np.arange(0, max(epochs) + interval_x, interval_x), fontsize=LegendFontsize) | ||||
|   plt.yticks(np.arange(0, 100 + interval_y, interval_y), fontsize=LegendFontsize) | ||||
|   plt.grid() | ||||
|    | ||||
|   plt.xlabel('epoch', fontsize=16) | ||||
|   plt.ylabel('accuracy (%)', fontsize=16) | ||||
|  | ||||
|   for idx, tag in enumerate(names): | ||||
|     xaccs = [accuracies[idx][x] for x in epochs] | ||||
|     plt.plot(epochs, xaccs, color=color_set[idx], linestyle='-', label='Test Accuracy : {:}'.format(tag), lw=3) | ||||
|     plt.legend(loc=4, fontsize=LegendFontsize) | ||||
|    | ||||
|   if save_path is not None: | ||||
|     fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='pdf') | ||||
|     print ('---- save figure into {:}.'.format(save_path)) | ||||
|   plt.close(fig) | ||||
|  | ||||
|  | ||||
| def main(): | ||||
|   checkpoints, names = args.checkpoints, args.names | ||||
|   assert len(checkpoints) == len(names), 'invalid length : {:} vs {:}'.format(len(checkpoints), len(names)) | ||||
|   for i, checkpoint in enumerate(checkpoints): | ||||
|     assert Path(checkpoint).exists(), 'The {:}-th checkpoint : {:} does not exist'.format( checkpoint ) | ||||
|  | ||||
|   save_path = Path(args.save) | ||||
|   save_dir  = save_path.parent | ||||
|   save_dir.mkdir(parents=True, exist_ok=True) | ||||
|   accuracies = [] | ||||
|   for checkpoint in checkpoints: | ||||
|     checkpoint = torch.load( checkpoint ) | ||||
|     accuracies.append( checkpoint['valid_accuracies'] ) | ||||
|   epochs = [x for x in accuracies[0].keys() if isinstance(x, int)] | ||||
|   epochs = sorted( epochs ) | ||||
|     | ||||
|   visualize_acc(epochs, accuracies, names, save_path) | ||||
|  | ||||
|  | ||||
| if __name__ == '__main__': | ||||
|   main() | ||||
		Reference in New Issue
	
	Block a user