import torch
import torch.nn as nn
import numpy as np


def count_parameters_in_MB(model):
    return count_parameters(model, "mb")


def count_parameters(model_or_parameters, unit="mb"):
    # Accepts an nn.Module, an nn.Parameter, a list/tuple of either, or any
    # iterable of tensors, and returns the parameter count in the given unit.
    if isinstance(model_or_parameters, nn.Module):
        counts = sum(np.prod(v.size()) for v in model_or_parameters.parameters())
    elif isinstance(model_or_parameters, nn.Parameter):
        counts = model_or_parameters.numel()
    elif isinstance(model_or_parameters, (list, tuple)):
        # Recurse with unit=None so the scaling is applied only once, at the top level.
        counts = sum(count_parameters(x, None) for x in model_or_parameters)
    else:
        counts = sum(np.prod(v.size()) for v in model_or_parameters)
    if unit is not None:
        unit = unit.lower()
        if unit in ("kb", "k"):
            counts /= 2 ** 10  # changed from 1e3 to 2^10
        elif unit in ("mb", "m"):
            counts /= 2 ** 20  # changed from 1e6 to 2^20
        elif unit in ("gb", "g"):
            counts /= 2 ** 30  # changed from 1e9 to 2^30
        else:
            raise ValueError("Unknown unit: {:}".format(unit))
    return counts
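

# A minimal usage sketch (illustrative values, not part of the original file):
#   >>> layer = nn.Linear(128, 10)     # 128 * 10 + 10 = 1290 parameters
#   >>> count_parameters(layer, "kb")  # 1290 / 2**10 ~= 1.26
#   >>> count_parameters_in_MB(layer)  # 1290 / 2**20 ~= 0.00123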


def get_model_infos(model, shape):
    # model = copy.deepcopy( model )

    model = add_flops_counting_methods(model)
    # model = model.cuda()
    model.eval()

    # cache_inputs = torch.zeros(*shape).cuda()
    # cache_inputs = torch.zeros(*shape)
    cache_inputs = torch.rand(*shape)
    if next(model.parameters()).is_cuda:
        cache_inputs = cache_inputs.cuda()
    # print_log('In the calculating function : cache input size : {:}'.format(cache_inputs.size()), log)
    with torch.no_grad():
        _ = model(cache_inputs)
    FLOPs = compute_average_flops_cost(model) / 1e6  # report in MFLOPs
    Param = count_parameters_in_MB(model)

    if hasattr(model, "auxiliary_param"):
        aux_params = count_parameters_in_MB(model.auxiliary_param())
        print("The auxiliary params of this model are : {:}".format(aux_params))
        print(
            "We remove the auxiliary params from the total params ({:}) when counting".format(
                Param
            )
        )
        Param = Param - aux_params

    # print_log('FLOPs : {:} MB'.format(FLOPs), log)
    torch.cuda.empty_cache()
    model.apply(remove_hook_function)
    return FLOPs, Param
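

# A minimal usage sketch (hypothetical shape, not part of the original file):
#   flops, params = get_model_infos(model, shape=(1, 3, 32, 32))
# `flops` comes back in MFLOPs and `params` in the 2**20-scaled unit used by
# count_parameters_in_MB above.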


# ---- Public functions
def add_flops_counting_methods(model):
    model.__batch_counter__ = 0
    add_batch_counter_hook_function(model)
    model.apply(add_flops_counter_variable_or_reset)
    model.apply(add_flops_counter_hook_function)
    return model


def compute_average_flops_cost(model):
    """
    Available after add_flops_counting_methods() has been called on the model
    and at least one forward pass has been run.
    Returns the current mean FLOPs consumption per image.
    """
    batches_count = model.__batch_counter__  # total number of images, not batches
    flops_sum = 0
    # or isinstance(module, torch.nn.AvgPool2d) or isinstance(module, torch.nn.MaxPool2d) \
    for module in model.modules():
        if (
            isinstance(module, torch.nn.Conv2d)
            or isinstance(module, torch.nn.Linear)
            or isinstance(module, torch.nn.Conv1d)
            or hasattr(module, "calculate_flop_self")
        ):
            flops_sum += module.__flops__
    return flops_sum / batches_count
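

# Illustrative arithmetic (not from the original file): after two forward
# passes with batch sizes 8 and 24, __batch_counter__ == 32 while each
# module's __flops__ holds the cost over all 32 images, so the division
# above yields the mean FLOPs per single image.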


# ---- Internal functions
def pool_flops_counter_hook(pool_module, inputs, output):
    batch_size = inputs[0].size(0)
    kernel_size = pool_module.kernel_size  # assumes an int (square) kernel
    out_C, output_height, output_width = output.shape[1:]
    assert out_C == inputs[0].size(1), "{:} vs. {:}".format(out_C, inputs[0].size())

    overall_flops = (
        batch_size * out_C * output_height * output_width * kernel_size * kernel_size
    )
    pool_module.__flops__ += overall_flops
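

# Worked example (illustrative): AvgPool2d(2) on a (1, 8, 16, 16) input gives
# a (1, 8, 8, 8) output, so overall_flops = 1 * 8 * 8 * 8 * 2 * 2 = 2048.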


def self_calculate_flops_counter_hook(self_module, inputs, output):
    # Custom modules report their own cost via calculate_flop_self(in_shape, out_shape).
    overall_flops = self_module.calculate_flop_self(inputs[0].shape, output.shape)
    self_module.__flops__ += overall_flops
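

# Illustrative sketch (hypothetical module, not from the original file): any
# module exposing calculate_flop_self is picked up by the counters, e.g.
#
#   class RoughUpsample(nn.Upsample):
#       def calculate_flop_self(self, input_shape, output_shape):
#           # crude cost model: assume a few operations per output element
#           return int(np.prod(output_shape)) * 4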


def fc_flops_counter_hook(fc_module, inputs, output):
    batch_size = inputs[0].size(0)
    xin, xout = fc_module.in_features, fc_module.out_features
    assert xin == inputs[0].size(1) and xout == output.size(1), "IO=({:}, {:})".format(
        xin, xout
    )
    overall_flops = batch_size * xin * xout
    if fc_module.bias is not None:
        overall_flops += batch_size * xout
    fc_module.__flops__ += overall_flops
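

# Worked example (illustrative): nn.Linear(128, 10) on a batch of 4 costs
# 4 * 128 * 10 = 5120 multiply-accumulates, plus 4 * 10 = 40 for the bias.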


def conv1d_flops_counter_hook(conv_module, inputs, outputs):
    batch_size = inputs[0].size(0)
    outL = outputs.shape[-1]
    [kernel] = conv_module.kernel_size
    in_channels = conv_module.in_channels
    out_channels = conv_module.out_channels
    groups = conv_module.groups
    conv_per_position_flops = kernel * in_channels * out_channels / groups

    active_elements_count = batch_size * outL
    overall_flops = conv_per_position_flops * active_elements_count

    if conv_module.bias is not None:
        overall_flops += out_channels * active_elements_count
    conv_module.__flops__ += overall_flops
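

# Worked example (illustrative): Conv1d(64, 128, 3) on a length-100 input
# produces a length-98 output: 3 * 64 * 128 = 24576 MACs per position times
# 1 * 98 positions = 2408448, plus 128 * 98 for the bias.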


def conv2d_flops_counter_hook(conv_module, inputs, output):
    batch_size = inputs[0].size(0)
    output_height, output_width = output.shape[2:]

    kernel_height, kernel_width = conv_module.kernel_size
    in_channels = conv_module.in_channels
    out_channels = conv_module.out_channels
    groups = conv_module.groups
    conv_per_position_flops = (
        kernel_height * kernel_width * in_channels * out_channels / groups
    )

    active_elements_count = batch_size * output_height * output_width
    overall_flops = conv_per_position_flops * active_elements_count

    if conv_module.bias is not None:
        overall_flops += out_channels * active_elements_count
    conv_module.__flops__ += overall_flops
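

# Worked example (illustrative): Conv2d(3, 16, 3, padding=1) on a (1, 3, 32, 32)
# input keeps the spatial size, so per-position cost is 3 * 3 * 3 * 16 = 432,
# active positions are 1 * 32 * 32 = 1024, giving 432 * 1024 = 442368 MACs
# (plus 16 * 1024 for the bias).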


def batch_counter_hook(module, inputs, output):
    # Can have multiple inputs, getting the first one
    inputs = inputs[0]
    batch_size = inputs.shape[0]
    module.__batch_counter__ += batch_size


def add_batch_counter_hook_function(module):
    if not hasattr(module, "__batch_counter_handle__"):
        handle = module.register_forward_hook(batch_counter_hook)
        module.__batch_counter_handle__ = handle


def add_flops_counter_variable_or_reset(module):
    if (
        isinstance(module, torch.nn.Conv2d)
        or isinstance(module, torch.nn.Linear)
        or isinstance(module, torch.nn.Conv1d)
        or isinstance(module, torch.nn.AvgPool2d)
        or isinstance(module, torch.nn.MaxPool2d)
        or hasattr(module, "calculate_flop_self")
    ):
        module.__flops__ = 0


def add_flops_counter_hook_function(module):
    if isinstance(module, torch.nn.Conv2d):
        if not hasattr(module, "__flops_handle__"):
            handle = module.register_forward_hook(conv2d_flops_counter_hook)
            module.__flops_handle__ = handle
    elif isinstance(module, torch.nn.Conv1d):
        if not hasattr(module, "__flops_handle__"):
            handle = module.register_forward_hook(conv1d_flops_counter_hook)
            module.__flops_handle__ = handle
    elif isinstance(module, torch.nn.Linear):
        if not hasattr(module, "__flops_handle__"):
            handle = module.register_forward_hook(fc_flops_counter_hook)
            module.__flops_handle__ = handle
    elif isinstance(module, torch.nn.AvgPool2d) or isinstance(
        module, torch.nn.MaxPool2d
    ):
        if not hasattr(module, "__flops_handle__"):
            handle = module.register_forward_hook(pool_flops_counter_hook)
            module.__flops_handle__ = handle
    elif hasattr(module, "calculate_flop_self"):  # self-defined module
        if not hasattr(module, "__flops_handle__"):
            handle = module.register_forward_hook(self_calculate_flops_counter_hook)
            module.__flops_handle__ = handle


def remove_hook_function(module):
    hookers = ["__batch_counter_handle__", "__flops_handle__"]
    for hooker in hookers:
        if hasattr(module, hooker):
            handle = getattr(module, hooker)
            handle.remove()
    keys = ["__flops__", "__batch_counter__"] + hookers
    for ckey in keys:
        if hasattr(module, ckey):
            delattr(module, ckey)
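

if __name__ == "__main__":
    # Minimal smoke test (hypothetical, not part of the original file): count
    # FLOPs and parameters of a small CNN on a CIFAR-sized input.
    demo = nn.Sequential(
        nn.Conv2d(3, 16, 3, padding=1),
        nn.ReLU(),
        nn.AvgPool2d(2),
        nn.Flatten(),
        nn.Linear(16 * 16 * 16, 10),
    )
    flops, params = get_model_infos(demo, shape=(2, 3, 32, 32))
    print("FLOPs: {:.3f} MFLOPs, Params: {:.3f} MB".format(flops, params))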