Add more algorithms
This commit is contained in:
		| @@ -1,16 +1,6 @@ | ||||
| ################################################## | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 # | ||||
| ################################################## | ||||
| from .utils import AverageMeter, RecorderMeter, convert_secs2time | ||||
| from .utils import time_file_str, time_string | ||||
| from .utils import test_imagenet_data | ||||
| from .utils import print_log | ||||
| from .evaluation_utils import obtain_accuracy | ||||
| #from .draw_pts import draw_points | ||||
| from .gpu_manager import GPUManager | ||||
|  | ||||
| from .save_meta import Save_Meta | ||||
|  | ||||
| from .model_utils import count_parameters_in_MB | ||||
| from .model_utils import Cutout | ||||
| from .flop_benchmark import print_FLOPs | ||||
| from .gpu_manager      import GPUManager | ||||
| from .flop_benchmark   import get_model_infos | ||||
|   | ||||
| @@ -1,41 +0,0 @@ | ||||
| import os, sys, time | ||||
| import numpy as np | ||||
| import matplotlib | ||||
| import random | ||||
| matplotlib.use('agg') | ||||
| import matplotlib.pyplot as plt | ||||
| import matplotlib.cm as cm | ||||
|  | ||||
| def draw_points(points, labels, save_path): | ||||
|   title = 'the visualized features' | ||||
|   dpi = 100  | ||||
|   width, height = 1000, 1000 | ||||
|   legend_fontsize = 10 | ||||
|   figsize = width / float(dpi), height / float(dpi) | ||||
|   fig = plt.figure(figsize=figsize) | ||||
|  | ||||
|   classes = np.unique(labels).tolist() | ||||
|   colors = cm.rainbow(np.linspace(0, 1, len(classes))) | ||||
|  | ||||
|   legends = [] | ||||
|   legendnames = [] | ||||
|  | ||||
|   for cls, c in zip(classes, colors): | ||||
|      | ||||
|     indexes = labels == cls | ||||
|     ptss = points[indexes, :] | ||||
|     x = ptss[:,0] | ||||
|     y = ptss[:,1] | ||||
|     if cls % 2 == 0: marker = 'x' | ||||
|     else:            marker = 'o' | ||||
|     legend = plt.scatter(x, y, color=c, s=1, marker=marker) | ||||
|     legendname = '{:02d}'.format(cls+1) | ||||
|     legends.append( legend ) | ||||
|     legendnames.append( legendname ) | ||||
|  | ||||
|   plt.legend(legends, legendnames, scatterpoints=1, ncol=5, fontsize=8) | ||||
|  | ||||
|   if save_path is not None: | ||||
|     fig.savefig(save_path, dpi=dpi, bbox_inches='tight') | ||||
|     print ('---- save figure {} into {}'.format(title, save_path)) | ||||
|   plt.close(fig) | ||||
| @@ -3,21 +3,44 @@ | ||||
| ################################################## | ||||
| # modified from https://github.com/warmspringwinds/pytorch-segmentation-detection/blob/master/pytorch_segmentation_detection/utils/flops_benchmark.py | ||||
| import copy, torch | ||||
| import torch.nn as nn | ||||
| import numpy as np | ||||
|  | ||||
| def print_FLOPs(model, shape, logs): | ||||
|   print_log, log = logs | ||||
|   model = copy.deepcopy( model ) | ||||
|  | ||||
| def count_parameters_in_MB(model): | ||||
|   if isinstance(model, nn.Module): | ||||
|     return np.sum(np.prod(v.size()) for v in model.parameters())/1e6 | ||||
|   else: | ||||
|     return np.sum(np.prod(v.size()) for v in model)/1e6 | ||||
|  | ||||
|  | ||||
| def get_model_infos(model, shape): | ||||
|   #model = copy.deepcopy( model ) | ||||
|  | ||||
|   model = add_flops_counting_methods(model) | ||||
|   model = model.cuda() | ||||
|   #model = model.cuda() | ||||
|   model.eval() | ||||
|  | ||||
|   cache_inputs = torch.zeros(*shape).cuda() | ||||
|   #cache_inputs = torch.zeros(*shape).cuda() | ||||
|   #cache_inputs = torch.zeros(*shape) | ||||
|   cache_inputs = torch.rand(*shape) | ||||
|   if next(model.parameters()).is_cuda: cache_inputs = cache_inputs.cuda() | ||||
|   #print_log('In the calculating function : cache input size : {:}'.format(cache_inputs.size()), log) | ||||
|   _ = model(cache_inputs) | ||||
|   with torch.no_grad(): | ||||
|     _____ = model(cache_inputs) | ||||
|   FLOPs = compute_average_flops_cost( model ) / 1e6 | ||||
|   print_log('FLOPs : {:} MB'.format(FLOPs), log) | ||||
|   Param = count_parameters_in_MB(model) | ||||
|  | ||||
|   if hasattr(model, 'auxiliary_param'): | ||||
|     aux_params = count_parameters_in_MB(model.auxiliary_param())  | ||||
|     print ('The auxiliary params of this model is : {:}'.format(aux_params)) | ||||
|     print ('We remove the auxiliary params from the total params ({:}) when counting'.format(Param)) | ||||
|     Param = Param - aux_params | ||||
|    | ||||
|   #print_log('FLOPs : {:} MB'.format(FLOPs), log) | ||||
|   torch.cuda.empty_cache() | ||||
|   model.apply( remove_hook_function ) | ||||
|   return FLOPs, Param | ||||
|  | ||||
|  | ||||
| # ---- Public functions | ||||
| @@ -37,8 +60,11 @@ def compute_average_flops_cost(model): | ||||
|   """ | ||||
|   batches_count = model.__batch_counter__ | ||||
|   flops_sum = 0 | ||||
|   #or isinstance(module, torch.nn.AvgPool2d) or isinstance(module, torch.nn.MaxPool2d) \ | ||||
|   for module in model.modules(): | ||||
|     if isinstance(module, torch.nn.Conv2d) or isinstance(module, torch.nn.Linear): | ||||
|     if isinstance(module, torch.nn.Conv2d) or isinstance(module, torch.nn.Linear) \ | ||||
|       or isinstance(module, torch.nn.Conv1d) \ | ||||
|       or hasattr(module, 'calculate_flop_self'): | ||||
|       flops_sum += module.__flops__ | ||||
|   return flops_sum / batches_count | ||||
|  | ||||
| @@ -54,6 +80,11 @@ def pool_flops_counter_hook(pool_module, inputs, output): | ||||
|   pool_module.__flops__ += overall_flops | ||||
|  | ||||
|  | ||||
| def self_calculate_flops_counter_hook(self_module, inputs, output): | ||||
|   overall_flops = self_module.calculate_flop_self(inputs[0].shape, output.shape) | ||||
|   self_module.__flops__ += overall_flops | ||||
|  | ||||
|  | ||||
| def fc_flops_counter_hook(fc_module, inputs, output): | ||||
|   batch_size = inputs[0].size(0) | ||||
|   xin, xout = fc_module.in_features, fc_module.out_features | ||||
| @@ -64,7 +95,24 @@ def fc_flops_counter_hook(fc_module, inputs, output): | ||||
|   fc_module.__flops__ += overall_flops | ||||
|  | ||||
|  | ||||
| def conv_flops_counter_hook(conv_module, inputs, output): | ||||
| def conv1d_flops_counter_hook(conv_module, inputs, outputs): | ||||
|   batch_size   = inputs[0].size(0) | ||||
|   outL         = outputs.shape[-1] | ||||
|   [kernel]     = conv_module.kernel_size | ||||
|   in_channels  = conv_module.in_channels | ||||
|   out_channels = conv_module.out_channels | ||||
|   groups       = conv_module.groups | ||||
|   conv_per_position_flops = kernel * in_channels * out_channels / groups | ||||
|    | ||||
|   active_elements_count = batch_size * outL  | ||||
|   overall_flops = conv_per_position_flops * active_elements_count | ||||
|  | ||||
|   if conv_module.bias is not None: | ||||
|     overall_flops += out_channels * active_elements_count | ||||
|   conv_module.__flops__ += overall_flops | ||||
|  | ||||
|  | ||||
| def conv2d_flops_counter_hook(conv_module, inputs, output): | ||||
|   batch_size = inputs[0].size(0) | ||||
|   output_height, output_width = output.shape[2:] | ||||
|    | ||||
| @@ -97,14 +145,20 @@ def add_batch_counter_hook_function(module): | ||||
|    | ||||
| def add_flops_counter_variable_or_reset(module): | ||||
|   if isinstance(module, torch.nn.Conv2d) or isinstance(module, torch.nn.Linear) \ | ||||
|     or isinstance(module, torch.nn.AvgPool2d) or isinstance(module, torch.nn.MaxPool2d): | ||||
|     or isinstance(module, torch.nn.Conv1d) \ | ||||
|     or isinstance(module, torch.nn.AvgPool2d) or isinstance(module, torch.nn.MaxPool2d) \ | ||||
|     or hasattr(module, 'calculate_flop_self'): | ||||
|     module.__flops__ = 0 | ||||
|  | ||||
|  | ||||
| def add_flops_counter_hook_function(module): | ||||
|   if isinstance(module, torch.nn.Conv2d): | ||||
|     if not hasattr(module, '__flops_handle__'): | ||||
|       handle = module.register_forward_hook(conv_flops_counter_hook) | ||||
|       handle = module.register_forward_hook(conv2d_flops_counter_hook) | ||||
|       module.__flops_handle__ = handle | ||||
|   elif isinstance(module, torch.nn.Conv1d): | ||||
|     if not hasattr(module, '__flops_handle__'): | ||||
|       handle = module.register_forward_hook(conv1d_flops_counter_hook) | ||||
|       module.__flops_handle__ = handle | ||||
|   elif isinstance(module, torch.nn.Linear): | ||||
|     if not hasattr(module, '__flops_handle__'): | ||||
| @@ -114,3 +168,18 @@ def add_flops_counter_hook_function(module): | ||||
|     if not hasattr(module, '__flops_handle__'): | ||||
|       handle = module.register_forward_hook(pool_flops_counter_hook) | ||||
|       module.__flops_handle__ = handle | ||||
|   elif hasattr(module, 'calculate_flop_self'): # self-defined module | ||||
|     if not hasattr(module, '__flops_handle__'): | ||||
|       handle = module.register_forward_hook(self_calculate_flops_counter_hook) | ||||
|       module.__flops_handle__ = handle | ||||
|  | ||||
|  | ||||
| def remove_hook_function(module): | ||||
|   hookers = ['__batch_counter_handle__', '__flops_handle__'] | ||||
|   for hooker in hookers: | ||||
|     if hasattr(module, hooker): | ||||
|       handle = getattr(module, hooker) | ||||
|       handle.remove() | ||||
|   keys = ['__flops__', '__batch_counter__', '__flops__'] + hookers | ||||
|   for ckey in keys: | ||||
|     if hasattr(module, ckey): delattr(module, ckey) | ||||
|   | ||||
| @@ -1,35 +0,0 @@ | ||||
| import torch | ||||
| import torch.nn as nn | ||||
| import numpy as np | ||||
|  | ||||
|  | ||||
| def count_parameters_in_MB(model): | ||||
|   if isinstance(model, nn.Module): | ||||
|     return np.sum(np.prod(v.size()) for v in model.parameters())/1e6 | ||||
|   else: | ||||
|     return np.sum(np.prod(v.size()) for v in model)/1e6 | ||||
|  | ||||
|  | ||||
| class Cutout(object): | ||||
|   def __init__(self, length): | ||||
|     self.length = length | ||||
|  | ||||
|   def __repr__(self): | ||||
|     return ('{name}(length={length})'.format(name=self.__class__.__name__, **self.__dict__)) | ||||
|  | ||||
|   def __call__(self, img): | ||||
|     h, w = img.size(1), img.size(2) | ||||
|     mask = np.ones((h, w), np.float32) | ||||
|     y = np.random.randint(h) | ||||
|     x = np.random.randint(w) | ||||
|  | ||||
|     y1 = np.clip(y - self.length // 2, 0, h) | ||||
|     y2 = np.clip(y + self.length // 2, 0, h) | ||||
|     x1 = np.clip(x - self.length // 2, 0, w) | ||||
|     x2 = np.clip(x + self.length // 2, 0, w) | ||||
|  | ||||
|     mask[y1: y2, x1: x2] = 0. | ||||
|     mask = torch.from_numpy(mask) | ||||
|     mask = mask.expand_as(img) | ||||
|     img *= mask | ||||
|     return img | ||||
| @@ -1,53 +0,0 @@ | ||||
| ################################################## | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 # | ||||
| ################################################## | ||||
| import torch | ||||
| import os, sys | ||||
| import os.path as osp | ||||
| import numpy as np | ||||
|  | ||||
| def tensor2np(x): | ||||
|   if isinstance(x, np.ndarray): return x | ||||
|   if x.is_cuda: x = x.cpu() | ||||
|   return x.numpy() | ||||
|  | ||||
| class Save_Meta(): | ||||
|  | ||||
|   def __init__(self): | ||||
|     self.reset() | ||||
|  | ||||
|   def __repr__(self): | ||||
|     return ('{name}'.format(name=self.__class__.__name__)+'(number of data = {})'.format(len(self))) | ||||
|  | ||||
|   def reset(self): | ||||
|     self.predictions = [] | ||||
|     self.groundtruth = [] | ||||
|      | ||||
|   def __len__(self): | ||||
|     return len(self.predictions) | ||||
|  | ||||
|   def append(self, _pred, _ground): | ||||
|     _pred, _ground = tensor2np(_pred), tensor2np(_ground) | ||||
|     assert _ground.shape[0] == _pred.shape[0] and len(_pred.shape) == 2 and len(_ground.shape) == 1, 'The shapes are wrong : {} & {}'.format(_pred.shape, _ground.shape) | ||||
|     self.predictions.append(_pred) | ||||
|     self.groundtruth.append(_ground) | ||||
|  | ||||
|   def save(self, save_dir, filename, test=True): | ||||
|     meta = {'predictions': self.predictions,  | ||||
|             'groundtruth': self.groundtruth} | ||||
|     filename = osp.join(save_dir, filename) | ||||
|     torch.save(meta, filename) | ||||
|     if test: | ||||
|       predictions = np.concatenate(self.predictions) | ||||
|       groundtruth = np.concatenate(self.groundtruth) | ||||
|       predictions = np.argmax(predictions, axis=1) | ||||
|       accuracy = np.sum(groundtruth==predictions) * 100.0 / predictions.size | ||||
|     else: | ||||
|       accuracy = None | ||||
|     print ('save save_meta into {} with accuracy = {}'.format(filename, accuracy)) | ||||
|  | ||||
|   def load(self, filename): | ||||
|     assert os.path.isfile(filename), '{} is not a file'.format(filename) | ||||
|     checkpoint       = torch.load(filename) | ||||
|     self.predictions = checkpoint['predictions'] | ||||
|     self.groundtruth = checkpoint['groundtruth'] | ||||
| @@ -1,140 +0,0 @@ | ||||
| ################################################## | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 # | ||||
| ################################################## | ||||
| import os, sys, time | ||||
| import numpy as np | ||||
| import random | ||||
|  | ||||
| class AverageMeter(object): | ||||
|   """Computes and stores the average and current value""" | ||||
|   def __init__(self): | ||||
|     self.reset() | ||||
|  | ||||
|   def reset(self): | ||||
|     self.val = 0 | ||||
|     self.avg = 0 | ||||
|     self.sum = 0 | ||||
|     self.count = 0 | ||||
|  | ||||
|   def update(self, val, n=1): | ||||
|     self.val = val | ||||
|     self.sum += val * n | ||||
|     self.count += n | ||||
|     self.avg = self.sum / self.count | ||||
|  | ||||
|  | ||||
| class RecorderMeter(object): | ||||
|   """Computes and stores the minimum loss value and its epoch index""" | ||||
|   def __init__(self, total_epoch): | ||||
|     self.reset(total_epoch) | ||||
|  | ||||
|   def reset(self, total_epoch): | ||||
|     assert total_epoch > 0 | ||||
|     self.total_epoch   = total_epoch | ||||
|     self.current_epoch = 0 | ||||
|     self.epoch_losses  = np.zeros((self.total_epoch, 2), dtype=np.float32) # [epoch, train/val] | ||||
|     self.epoch_losses  = self.epoch_losses - 1 | ||||
|  | ||||
|     self.epoch_accuracy= np.zeros((self.total_epoch, 2), dtype=np.float32) # [epoch, train/val] | ||||
|     self.epoch_accuracy= self.epoch_accuracy | ||||
|  | ||||
|   def update(self, idx, train_loss, train_acc, val_loss, val_acc): | ||||
|     assert idx >= 0 and idx < self.total_epoch, 'total_epoch : {} , but update with the {} index'.format(self.total_epoch, idx) | ||||
|     self.epoch_losses  [idx, 0] = train_loss | ||||
|     self.epoch_losses  [idx, 1] = val_loss | ||||
|     self.epoch_accuracy[idx, 0] = train_acc | ||||
|     self.epoch_accuracy[idx, 1] = val_acc | ||||
|     self.current_epoch = idx + 1 | ||||
|     return self.max_accuracy(False) == self.epoch_accuracy[idx, 1] | ||||
|  | ||||
|   def max_accuracy(self, istrain): | ||||
|     if self.current_epoch <= 0: return 0 | ||||
|     if istrain: return self.epoch_accuracy[:self.current_epoch, 0].max() | ||||
|     else:       return self.epoch_accuracy[:self.current_epoch, 1].max() | ||||
|  | ||||
|   def plot_curve(self, save_path): | ||||
|     import matplotlib | ||||
|     matplotlib.use('agg') | ||||
|     import matplotlib.pyplot as plt | ||||
|     title = 'the accuracy/loss curve of train/val' | ||||
|     dpi = 100  | ||||
|     width, height = 1600, 1000 | ||||
|     legend_fontsize = 10 | ||||
|     figsize = width / float(dpi), height / float(dpi) | ||||
|  | ||||
|     fig = plt.figure(figsize=figsize) | ||||
|     x_axis = np.array([i for i in range(self.total_epoch)]) # epochs | ||||
|     y_axis = np.zeros(self.total_epoch) | ||||
|  | ||||
|     plt.xlim(0, self.total_epoch) | ||||
|     plt.ylim(0, 100) | ||||
|     interval_y = 5 | ||||
|     interval_x = 5 | ||||
|     plt.xticks(np.arange(0, self.total_epoch + interval_x, interval_x)) | ||||
|     plt.yticks(np.arange(0, 100 + interval_y, interval_y)) | ||||
|     plt.grid() | ||||
|     plt.title(title, fontsize=20) | ||||
|     plt.xlabel('the training epoch', fontsize=16) | ||||
|     plt.ylabel('accuracy', fontsize=16) | ||||
|    | ||||
|     y_axis[:] = self.epoch_accuracy[:, 0] | ||||
|     plt.plot(x_axis, y_axis, color='g', linestyle='-', label='train-accuracy', lw=2) | ||||
|     plt.legend(loc=4, fontsize=legend_fontsize) | ||||
|  | ||||
|     y_axis[:] = self.epoch_accuracy[:, 1] | ||||
|     plt.plot(x_axis, y_axis, color='y', linestyle='-', label='valid-accuracy', lw=2) | ||||
|     plt.legend(loc=4, fontsize=legend_fontsize) | ||||
|  | ||||
|      | ||||
|     y_axis[:] = self.epoch_losses[:, 0] | ||||
|     plt.plot(x_axis, y_axis*50, color='g', linestyle=':', label='train-loss-x50', lw=2) | ||||
|     plt.legend(loc=4, fontsize=legend_fontsize) | ||||
|  | ||||
|     y_axis[:] = self.epoch_losses[:, 1] | ||||
|     plt.plot(x_axis, y_axis*50, color='y', linestyle=':', label='valid-loss-x50', lw=2) | ||||
|     plt.legend(loc=4, fontsize=legend_fontsize) | ||||
|  | ||||
|     if save_path is not None: | ||||
|       fig.savefig(save_path, dpi=dpi, bbox_inches='tight') | ||||
|       print ('---- save figure {} into {}'.format(title, save_path)) | ||||
|     plt.close(fig) | ||||
|      | ||||
| def print_log(print_string, log): | ||||
|   print ("{:}".format(print_string)) | ||||
|   if log is not None: | ||||
|     log.write('{}\n'.format(print_string)) | ||||
|     log.flush() | ||||
|  | ||||
| def time_file_str(): | ||||
|   ISOTIMEFORMAT='%Y-%m-%d' | ||||
|   string = '{}'.format(time.strftime( ISOTIMEFORMAT, time.gmtime(time.time()) )) | ||||
|   return string + '-{}'.format(random.randint(1, 10000)) | ||||
|  | ||||
| def time_string(): | ||||
|   ISOTIMEFORMAT='%Y-%m-%d-%X' | ||||
|   string = '[{}]'.format(time.strftime( ISOTIMEFORMAT, time.gmtime(time.time()) )) | ||||
|   return string | ||||
|  | ||||
| def convert_secs2time(epoch_time, return_str=False): | ||||
|   need_hour = int(epoch_time / 3600) | ||||
|   need_mins = int((epoch_time - 3600*need_hour) / 60) | ||||
|   need_secs = int(epoch_time - 3600*need_hour - 60*need_mins) | ||||
|   if return_str == False: | ||||
|     return need_hour, need_mins, need_secs | ||||
|   else: | ||||
|     return '[Need: {:02d}:{:02d}:{:02d}]'.format(need_hour, need_mins, need_secs) | ||||
|  | ||||
| def test_imagenet_data(imagenet): | ||||
|   total_length = len(imagenet) | ||||
|   assert total_length == 1281166 or total_length == 50000, 'The length of ImageNet is wrong : {}'.format(total_length) | ||||
|   map_id = {} | ||||
|   for index in range(total_length): | ||||
|     path, target = imagenet.imgs[index] | ||||
|     folder, image_name = os.path.split(path) | ||||
|     _, folder = os.path.split(folder) | ||||
|     if folder not in map_id: | ||||
|       map_id[folder] = target | ||||
|     else: | ||||
|       assert map_id[folder] == target, 'Class : {} is not {}'.format(folder, target) | ||||
|     assert image_name.find(folder) == 0, '{} is wrong.'.format(path) | ||||
|   print ('Check ImageNet Dataset OK') | ||||
		Reference in New Issue
	
	Block a user