autodl-projects/xautodl/xmisc/torch_utils.py
2021-06-10 21:53:22 +08:00

27 lines
1.0 KiB
Python

#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.06 #
#####################################################
import torch
import torch.nn as nn
import numpy as np
def count_parameters(model_or_parameters, unit="mb"):
if isinstance(model_or_parameters, nn.Module):
counts = sum(np.prod(v.size()) for v in model_or_parameters.parameters())
elif isinstance(model_or_parameters, nn.Parameter):
counts = models_or_parameters.numel()
elif isinstance(model_or_parameters, (list, tuple)):
counts = sum(count_parameters(x, None) for x in models_or_parameters)
else:
counts = sum(np.prod(v.size()) for v in model_or_parameters)
if unit.lower() == "kb" or unit.lower() == "k":
counts /= 1e3
elif unit.lower() == "mb" or unit.lower() == "m":
counts /= 1e6
elif unit.lower() == "gb" or unit.lower() == "g":
counts /= 1e9
elif unit is not None:
raise ValueError("Unknow unit: {:}".format(unit))
return counts