141 lines
		
	
	
		
			4.9 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			141 lines
		
	
	
		
			4.9 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| ##################################################
 | |
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
 | |
| ##################################################
 | |
| import os, sys, time
 | |
| import numpy as np
 | |
| import random
 | |
| 
 | |
class AverageMeter(object):
  """Track the most recent value and the running weighted average of a metric.

  Attributes:
    val:   value passed to the most recent `update` call.
    sum:   accumulated `val * n` over all updates since the last reset.
    count: accumulated weight `n` over all updates since the last reset.
    avg:   `sum / count`, or 0 while `count` is still zero.
  """

  def __init__(self):
    self.reset()

  def reset(self):
    """Set every statistic back to zero."""
    self.val = 0
    self.avg = 0
    self.sum = 0
    self.count = 0

  def update(self, val, n=1):
    """Record `val` with weight `n` and refresh the running average.

    Args:
      val: the newly observed value.
      n:   weight of the observation (how many samples `val` stands for).
    """
    self.val = val
    self.sum += val * n
    self.count += n
    # A zero-weight update on a fresh meter used to raise ZeroDivisionError;
    # keep the average at 0 until some weight has actually been accumulated.
    self.avg = self.sum / self.count if self.count != 0 else 0

  def __repr__(self):
    return '{}(val={}, avg={}, count={})'.format(
      type(self).__name__, self.val, self.avg, self.count)
 | |
| 
 | |
| 
 | |
class RecorderMeter(object):
  """Store per-epoch train/val loss and accuracy, and report the best accuracy.

  Attributes:
    total_epoch:    number of epochs the history arrays can hold.
    current_epoch:  index of the last updated epoch plus one (0 before any update).
    epoch_losses:   float32 array of shape (total_epoch, 2); column 0 is the
                    training loss, column 1 the validation loss. Initialised to -1.
    epoch_accuracy: float32 array of shape (total_epoch, 2); column 0 is the
                    training accuracy, column 1 the validation accuracy.
                    Initialised to 0.
  """

  def __init__(self, total_epoch):
    self.reset(total_epoch)

  def reset(self, total_epoch):
    """Re-allocate the history for `total_epoch` epochs and drop all records."""
    assert total_epoch > 0
    self.total_epoch   = total_epoch
    self.current_epoch = 0
    # [epoch, train/val]; -1 is the initial value of every loss entry
    # (apparently a "not recorded yet" marker), 0 that of every accuracy entry.
    self.epoch_losses   = np.full((self.total_epoch, 2), -1, dtype=np.float32)
    self.epoch_accuracy = np.zeros((self.total_epoch, 2), dtype=np.float32)

  def update(self, idx, train_loss, train_acc, val_loss, val_acc):
    """Store the results of epoch `idx`.

    Returns:
      Whether the validation accuracy stored for `idx` equals the maximum
      validation accuracy over epochs `0..idx`.
    """
    assert idx >= 0 and idx < self.total_epoch, 'total_epoch : {} , but update with the {} index'.format(self.total_epoch, idx)
    self.epoch_losses  [idx, 0] = train_loss
    self.epoch_losses  [idx, 1] = val_loss
    self.epoch_accuracy[idx, 0] = train_acc
    self.epoch_accuracy[idx, 1] = val_acc
    self.current_epoch = idx + 1
    return self.max_accuracy(False) == self.epoch_accuracy[idx, 1]

  def max_accuracy(self, istrain):
    """Return the best accuracy over the first `current_epoch` epochs.

    Args:
      istrain: True selects the training column, False the validation column.

    Returns:
      The maximum accuracy, or 0 when no epoch has been recorded.
    """
    if self.current_epoch <= 0: return 0
    column = 0 if istrain else 1
    return self.epoch_accuracy[:self.current_epoch, column].max()

  def plot_curve(self, save_path):
    """Draw the accuracy curves and the x50-scaled loss curves of all epochs.

    Args:
      save_path: file the figure is written to; when None the figure is
        built and closed without being saved.
    """
    # Imported here so the module can be loaded without matplotlib; the
    # 'agg' backend is selected before pyplot is imported.
    import matplotlib
    matplotlib.use('agg')
    import matplotlib.pyplot as plt
    title = 'the accuracy/loss curve of train/val'
    dpi = 100
    width, height = 1600, 1000
    legend_fontsize = 10
    figsize = width / float(dpi), height / float(dpi)
    interval_x, interval_y = 5, 5

    fig = plt.figure(figsize=figsize)
    try:
      x_axis = np.arange(self.total_epoch) # epochs

      plt.xlim(0, self.total_epoch)
      plt.ylim(0, 100)
      plt.xticks(np.arange(0, self.total_epoch + interval_x, interval_x))
      plt.yticks(np.arange(0, 100 + interval_y, interval_y))
      plt.grid()
      plt.title(title, fontsize=20)
      plt.xlabel('the training epoch', fontsize=16)
      plt.ylabel('accuracy', fontsize=16)

      # Independent float64 copies per curve instead of one buffer that is
      # overwritten between plot calls; losses are scaled by 50 as the
      # legend labels state.
      accuracy = self.epoch_accuracy.astype(np.float64)
      losses   = self.epoch_losses.astype(np.float64) * 50

      plt.plot(x_axis, accuracy[:, 0], color='g', linestyle='-', label='train-accuracy', lw=2)
      plt.plot(x_axis, accuracy[:, 1], color='y', linestyle='-', label='valid-accuracy', lw=2)
      plt.plot(x_axis, losses[:, 0], color='g', linestyle=':', label='train-loss-x50', lw=2)
      plt.plot(x_axis, losses[:, 1], color='y', linestyle=':', label='valid-loss-x50', lw=2)
      # One legend call after all curves exist covers every labelled line.
      plt.legend(loc=4, fontsize=legend_fontsize)

      if save_path is not None:
        fig.savefig(save_path, dpi=dpi, bbox_inches='tight')
        print ('---- save figure {} into {}'.format(title, save_path))
    finally:
      # Release the figure even when plotting or saving fails.
      plt.close(fig)
 | |
|     
 | |
def print_log(print_string, log):
  """Print `print_string` to stdout and, when `log` is given, append it to `log`.

  Args:
    print_string: the object to report; it is rendered with `str.format`.
    log: an object with `write` and `flush` methods, or None to only print.
  """
  message = '{}'.format(print_string)
  print (message)
  if log is not None:
    log.write(message + '\n')
    # Flush right away so the line reaches the log without waiting for a
    # buffer to fill.
    log.flush()
 | |
| 
 | |
def time_file_str():
  """Return a tag of the form 'YYYY-MM-DD-<n>' for naming files.

  The date is the current UTC date and <n> is a random integer drawn from
  [1, 10000]; the suffix lowers the chance of name clashes but does not
  guarantee uniqueness.
  """
  date_part = time.strftime('%Y-%m-%d', time.gmtime())
  return '{}-{}'.format(date_part, random.randint(1, 10000))
 | |
| 
 | |
def time_string():
  """Return the current UTC time as a bracketed '[YYYY-MM-DD-<time>]' stamp.

  <time> comes from the '%X' directive, i.e. the locale's time
  representation (HH:MM:SS in the default C locale).
  """
  return '[{}]'.format(time.strftime('%Y-%m-%d-%X', time.gmtime()))
 | |
| 
 | |
def convert_secs2time(epoch_time, return_str=False):
  """Split a duration in seconds into whole hours, minutes and seconds.

  Args:
    epoch_time: duration in seconds (int or float); fractions are truncated.
    return_str: when true, return a formatted string instead of a tuple.

  Returns:
    `(hours, minutes, seconds)` as ints, or the string
    '[Need: HH:MM:SS]' when `return_str` is true. Hours are not wrapped at
    24, so long durations yield more than two hour digits.
  """
  need_hour = int(epoch_time / 3600)
  remaining = epoch_time - 3600 * need_hour
  need_mins = int(remaining / 60)
  need_secs = int(remaining - 60 * need_mins)
  if not return_str:
    return need_hour, need_mins, need_secs
  return '[Need: {:02d}:{:02d}:{:02d}]'.format(need_hour, need_mins, need_secs)
 | |
| 
 | |
def test_imagenet_data(imagenet):
  """Sanity-check the size and the folder/label layout of an ImageNet dataset.

  Args:
    imagenet: a dataset object supporting `len()` and exposing `imgs`, an
      indexable sequence of `(path, target)` pairs.

  Raises:
    AssertionError: if the dataset length is neither 1281166 nor 50000
      (presumably the ILSVRC-2012 train and validation sizes -- confirm),
      if one folder maps to more than one target, or if an image file name
      does not start with the name of its parent folder.
  """
  total_length = len(imagenet)
  assert total_length in (1281166, 50000), 'The length of ImageNet is wrong : {}'.format(total_length)
  map_id = {}
  for index in range(total_length):
    path, target = imagenet.imgs[index]
    folder, image_name = os.path.split(path)
    folder = os.path.basename(folder)
    # The first sighting of a folder records its target; every later
    # sample from that folder must carry the same target.
    recorded = map_id.setdefault(folder, target)
    assert recorded == target, 'Class : {} is not {}'.format(folder, target)
    assert image_name.startswith(folder), '{} is wrong.'.format(path)
  print ('Check ImageNet Dataset OK')
 |