Temp / 0.5
This commit is contained in:
		| @@ -4,10 +4,23 @@ import numpy as np | ||||
|  | ||||
|  | ||||
| def count_parameters_in_MB(model): | ||||
|   if isinstance(model, nn.Module): | ||||
|     return np.sum(np.prod(v.size()) for v in model.parameters())/1e6 | ||||
|   return count_parameters(model, "mb") | ||||
|  | ||||
|  | ||||
| def count_parameters(model_or_parameters, unit="mb"): | ||||
|   if isinstance(model_or_parameters, nn.Module): | ||||
|     counts = np.sum(np.prod(v.size()) for v in model_or_parameters.parameters()) | ||||
|   else: | ||||
|     return np.sum(np.prod(v.size()) for v in model)/1e6 | ||||
|     counts = np.sum(np.prod(v.size()) for v in model_or_parameters) | ||||
|   if unit.lower() == "mb": | ||||
|     counts /= 1e6 | ||||
|   elif unit.lower() == "kb": | ||||
|     counts /= 1e3 | ||||
|   elif unit.lower() == "gb": | ||||
|     counts /= 1e9 | ||||
|   elif unit is not None: | ||||
|     raise ValueError("Unknow unit: {:}".format(unit)) | ||||
|   return counts | ||||
|  | ||||
|  | ||||
| def get_model_infos(model, shape): | ||||
|   | ||||
		Reference in New Issue
	
	Block a user