import os, sys, time import numpy as np import random class AverageMeter(object): """Computes and stores the average and current value""" def __init__(self): self.reset() def reset(self): self.val = 0 self.avg = 0 self.sum = 0 self.count = 0 def update(self, val, n=1): self.val = val self.sum += val * n self.count += n self.avg = self.sum / self.count class RecorderMeter(object): """Computes and stores the minimum loss value and its epoch index""" def __init__(self, total_epoch): self.reset(total_epoch) def reset(self, total_epoch): assert total_epoch > 0 self.total_epoch = total_epoch self.current_epoch = 0 self.epoch_losses = np.zeros((self.total_epoch, 2), dtype=np.float32) # [epoch, train/val] self.epoch_losses = self.epoch_losses - 1 self.epoch_accuracy= np.zeros((self.total_epoch, 2), dtype=np.float32) # [epoch, train/val] self.epoch_accuracy= self.epoch_accuracy def update(self, idx, train_loss, train_acc, val_loss, val_acc): assert idx >= 0 and idx < self.total_epoch, 'total_epoch : {} , but update with the {} index'.format(self.total_epoch, idx) self.epoch_losses [idx, 0] = train_loss self.epoch_losses [idx, 1] = val_loss self.epoch_accuracy[idx, 0] = train_acc self.epoch_accuracy[idx, 1] = val_acc self.current_epoch = idx + 1 return self.max_accuracy(False) == self.epoch_accuracy[idx, 1] def max_accuracy(self, istrain): if self.current_epoch <= 0: return 0 if istrain: return self.epoch_accuracy[:self.current_epoch, 0].max() else: return self.epoch_accuracy[:self.current_epoch, 1].max() def plot_curve(self, save_path): import matplotlib matplotlib.use('agg') import matplotlib.pyplot as plt title = 'the accuracy/loss curve of train/val' dpi = 100 width, height = 1600, 1000 legend_fontsize = 10 figsize = width / float(dpi), height / float(dpi) fig = plt.figure(figsize=figsize) x_axis = np.array([i for i in range(self.total_epoch)]) # epochs y_axis = np.zeros(self.total_epoch) plt.xlim(0, self.total_epoch) plt.ylim(0, 100) interval_y = 5 interval_x = 5 plt.xticks(np.arange(0, self.total_epoch + interval_x, interval_x)) plt.yticks(np.arange(0, 100 + interval_y, interval_y)) plt.grid() plt.title(title, fontsize=20) plt.xlabel('the training epoch', fontsize=16) plt.ylabel('accuracy', fontsize=16) y_axis[:] = self.epoch_accuracy[:, 0] plt.plot(x_axis, y_axis, color='g', linestyle='-', label='train-accuracy', lw=2) plt.legend(loc=4, fontsize=legend_fontsize) y_axis[:] = self.epoch_accuracy[:, 1] plt.plot(x_axis, y_axis, color='y', linestyle='-', label='valid-accuracy', lw=2) plt.legend(loc=4, fontsize=legend_fontsize) y_axis[:] = self.epoch_losses[:, 0] plt.plot(x_axis, y_axis*50, color='g', linestyle=':', label='train-loss-x50', lw=2) plt.legend(loc=4, fontsize=legend_fontsize) y_axis[:] = self.epoch_losses[:, 1] plt.plot(x_axis, y_axis*50, color='y', linestyle=':', label='valid-loss-x50', lw=2) plt.legend(loc=4, fontsize=legend_fontsize) if save_path is not None: fig.savefig(save_path, dpi=dpi, bbox_inches='tight') print ('---- save figure {} into {}'.format(title, save_path)) plt.close(fig) def print_log(print_string, log): print ("{:}".format(print_string)) if log is not None: log.write('{}\n'.format(print_string)) log.flush() def time_file_str(): ISOTIMEFORMAT='%Y-%m-%d' string = '{}'.format(time.strftime( ISOTIMEFORMAT, time.gmtime(time.time()) )) return string + '-{}'.format(random.randint(1, 10000)) def time_string(): ISOTIMEFORMAT='%Y-%m-%d-%X' string = '[{}]'.format(time.strftime( ISOTIMEFORMAT, time.gmtime(time.time()) )) return string def convert_secs2time(epoch_time, return_str=False): need_hour = int(epoch_time / 3600) need_mins = int((epoch_time - 3600*need_hour) / 60) need_secs = int(epoch_time - 3600*need_hour - 60*need_mins) if return_str == False: return need_hour, need_mins, need_secs else: return '[Need: {:02d}:{:02d}:{:02d}]'.format(need_hour, need_mins, need_secs) def test_imagenet_data(imagenet): total_length = len(imagenet) assert total_length == 1281166 or total_length == 50000, 'The length of ImageNet is wrong : {}'.format(total_length) map_id = {} for index in range(total_length): path, target = imagenet.imgs[index] folder, image_name = os.path.split(path) _, folder = os.path.split(folder) if folder not in map_id: map_id[folder] = target else: assert map_id[folder] == target, 'Class : {} is not {}'.format(folder, target) assert image_name.find(folder) == 0, '{} is wrong.'.format(path) print ('Check ImageNet Dataset OK')