import torch
import torch.nn as nn
import numpy as np


def count_parameters_in_MB(model):
  # Return the parameter count in millions ("MB" in this codebase's naming).
  if isinstance(model, nn.Module):
    return sum(np.prod(v.size()) for v in model.parameters()) / 1e6
  else:
    return sum(np.prod(v.size()) for v in model) / 1e6


def get_model_infos(model, shape):
  # Register the FLOP-counting hooks, run one dummy forward pass on an input
  # of `shape`, then return (FLOPs in millions, parameter count in millions).
  model = add_flops_counting_methods(model)
  model.eval()

  cache_inputs = torch.rand(*shape)
  if next(model.parameters()).is_cuda:
    cache_inputs = cache_inputs.cuda()
  with torch.no_grad():
    _ = model(cache_inputs)

  FLOPs = compute_average_flops_cost(model) / 1e6
  Param = count_parameters_in_MB(model)

  if hasattr(model, 'auxiliary_param'):
    aux_params = count_parameters_in_MB(model.auxiliary_param())
    print('The auxiliary params of this model are : {:}'.format(aux_params))
    print('We remove the auxiliary params from the total params ({:}) when counting'.format(Param))
    Param = Param - aux_params

  torch.cuda.empty_cache()
  model.apply(remove_hook_function)
  return FLOPs, Param


# ---- Public functions
def add_flops_counting_methods(model):
  model.__batch_counter__ = 0
  add_batch_counter_hook_function(model)
  model.apply(add_flops_counter_variable_or_reset)
  model.apply(add_flops_counter_hook_function)
  return model


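# A typical lifecycle (for reference; this mirrors what get_model_infos does above):
#   model = add_flops_counting_methods(model)  # 1. attach counters and hooks
#   model(dummy_input)                         # 2. run one forward pass
#   flops = compute_average_flops_cost(model)  # 3. read the per-image average
#   model.apply(remove_hook_function)          # 4. detach hooks and clean up

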
def compute_average_flops_cost(model):
  """
  A method that becomes available after add_flops_counting_methods() is called on a net object.
  Returns the current mean FLOPs consumption per image.
  """
  batches_count = model.__batch_counter__
  flops_sum = 0
  # NOTE: AvgPool2d / MaxPool2d modules accumulate __flops__ via their hooks,
  # but are not included in this sum.
  for module in model.modules():
    if isinstance(module, torch.nn.Conv2d) or isinstance(module, torch.nn.Linear) \
       or isinstance(module, torch.nn.Conv1d) \
       or hasattr(module, 'calculate_flop_self'):
      flops_sum += module.__flops__
  return flops_sum / batches_count


# ---- Internal functions
def pool_flops_counter_hook(pool_module, inputs, output):
  batch_size = inputs[0].size(0)
  kernel_size = pool_module.kernel_size
  # kernel_size may be an int or a (height, width) pair
  if isinstance(kernel_size, (tuple, list)):
    kernel_ops = kernel_size[0] * kernel_size[1]
  else:
    kernel_ops = kernel_size * kernel_size
  out_C, output_height, output_width = output.shape[1:]
  assert out_C == inputs[0].size(1), '{:} vs. {:}'.format(out_C, inputs[0].size())

  overall_flops = batch_size * out_C * output_height * output_width * kernel_ops
  pool_module.__flops__ += overall_flops


def self_calculate_flops_counter_hook(self_module, inputs, output):
  # Modules that define calculate_flop_self(input_shape, output_shape) report their own FLOPs.
  overall_flops = self_module.calculate_flop_self(inputs[0].shape, output.shape)
  self_module.__flops__ += overall_flops


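# A minimal sketch (hypothetical module, not part of this file) of the protocol
# the hook above relies on: any module exposing
# calculate_flop_self(input_shape, output_shape) is counted through that method.
#
#   class ScaleByTwo(nn.Module):
#     def forward(self, x):
#       return x * 2
#     def calculate_flop_self(self, input_shape, output_shape):
#       # one multiplication per output element; output_shape includes the
#       # batch dimension, matching the built-in hooks in this file
#       return int(np.prod(output_shape))

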
def fc_flops_counter_hook(fc_module, inputs, output):
  batch_size = inputs[0].size(0)
  xin, xout = fc_module.in_features, fc_module.out_features
  assert xin == inputs[0].size(1) and xout == output.size(1), 'IO=({:}, {:})'.format(xin, xout)
  overall_flops = batch_size * xin * xout
  if fc_module.bias is not None:
    overall_flops += batch_size * xout
  fc_module.__flops__ += overall_flops


def conv1d_flops_counter_hook(conv_module, inputs, outputs):
  batch_size = inputs[0].size(0)
  outL = outputs.shape[-1]
  [kernel] = conv_module.kernel_size
  in_channels = conv_module.in_channels
  out_channels = conv_module.out_channels
  groups = conv_module.groups
  conv_per_position_flops = kernel * in_channels * out_channels / groups

  active_elements_count = batch_size * outL
  overall_flops = conv_per_position_flops * active_elements_count

  if conv_module.bias is not None:
    overall_flops += out_channels * active_elements_count
  conv_module.__flops__ += overall_flops


def conv2d_flops_counter_hook(conv_module, inputs, output):
  batch_size = inputs[0].size(0)
  output_height, output_width = output.shape[2:]

  kernel_height, kernel_width = conv_module.kernel_size
  in_channels = conv_module.in_channels
  out_channels = conv_module.out_channels
  groups = conv_module.groups
  # one multiply-accumulate is counted as one FLOP here
  conv_per_position_flops = kernel_height * kernel_width * in_channels * out_channels / groups

  active_elements_count = batch_size * output_height * output_width
  overall_flops = conv_per_position_flops * active_elements_count

  if conv_module.bias is not None:
    overall_flops += out_channels * active_elements_count
  conv_module.__flops__ += overall_flops


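# Worked example of the conv2d formula above (illustrative numbers): a Conv2d
# with a 3x3 kernel, in_channels=64, out_channels=64, groups=1, no bias, and a
# 32x32 output on a batch of one records
#   (3 * 3 * 64 * 64 / 1) * (1 * 32 * 32) = 37,748,736
# multiply-accumulates, i.e. ~37.75 MFLOPs under this counter's convention.

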
def batch_counter_hook(module, inputs, output):
  # There can be multiple inputs; take the first one.
  inputs = inputs[0]
  batch_size = inputs.shape[0]
  module.__batch_counter__ += batch_size


def add_batch_counter_hook_function(module):
  if not hasattr(module, '__batch_counter_handle__'):
    handle = module.register_forward_hook(batch_counter_hook)
    module.__batch_counter_handle__ = handle


def add_flops_counter_variable_or_reset(module):
  if isinstance(module, torch.nn.Conv2d) or isinstance(module, torch.nn.Linear) \
     or isinstance(module, torch.nn.Conv1d) \
     or isinstance(module, torch.nn.AvgPool2d) or isinstance(module, torch.nn.MaxPool2d) \
     or hasattr(module, 'calculate_flop_self'):
    module.__flops__ = 0


def add_flops_counter_hook_function(module):
  if isinstance(module, torch.nn.Conv2d):
    if not hasattr(module, '__flops_handle__'):
      handle = module.register_forward_hook(conv2d_flops_counter_hook)
      module.__flops_handle__ = handle
  elif isinstance(module, torch.nn.Conv1d):
    if not hasattr(module, '__flops_handle__'):
      handle = module.register_forward_hook(conv1d_flops_counter_hook)
      module.__flops_handle__ = handle
  elif isinstance(module, torch.nn.Linear):
    if not hasattr(module, '__flops_handle__'):
      handle = module.register_forward_hook(fc_flops_counter_hook)
      module.__flops_handle__ = handle
  elif isinstance(module, torch.nn.AvgPool2d) or isinstance(module, torch.nn.MaxPool2d):
    if not hasattr(module, '__flops_handle__'):
      handle = module.register_forward_hook(pool_flops_counter_hook)
      module.__flops_handle__ = handle
  elif hasattr(module, 'calculate_flop_self'):  # self-defined module
    if not hasattr(module, '__flops_handle__'):
      handle = module.register_forward_hook(self_calculate_flops_counter_hook)
      module.__flops_handle__ = handle


def remove_hook_function(module):
  handle_keys = ['__batch_counter_handle__', '__flops_handle__']
  for handle_key in handle_keys:
    if hasattr(module, handle_key):
      handle = getattr(module, handle_key)
      handle.remove()
  keys = ['__flops__', '__batch_counter__'] + handle_keys
  for ckey in keys:
    if hasattr(module, ckey):
      delattr(module, ckey)
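

# Example usage (a minimal sketch, not part of the original module): build a
# tiny CNN with arbitrary layer sizes and report its FLOPs and parameter count.
if __name__ == '__main__':
  demo_model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.AvgPool2d(kernel_size=2),
    nn.Flatten(),
    nn.Linear(8 * 16 * 16, 10),
  )
  flops, params = get_model_infos(demo_model, shape=(1, 3, 32, 32))
  print('FLOPs : {:.3f} MB, Params : {:.3f} MB'.format(flops, params))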