Update tests for torch/cuda
This commit is contained in:
		| @@ -9,15 +9,19 @@ def count_parameters_in_MB(model): | ||||
|  | ||||
| def count_parameters(model_or_parameters, unit="mb"): | ||||
|     if isinstance(model_or_parameters, nn.Module): | ||||
|         counts = np.sum(np.prod(v.size()) for v in model_or_parameters.parameters()) | ||||
|         counts = sum(np.prod(v.size()) for v in model_or_parameters.parameters()) | ||||
|     elif isinstance(models_or_parameters, nn.Parameter): | ||||
|         counts = models_or_parameters.numel() | ||||
|     elif isinstance(models_or_parameters, (list, tuple)): | ||||
|         counts = sum(count_parameters(x, None) for x in models_or_parameters) | ||||
|     else: | ||||
|         counts = np.sum(np.prod(v.size()) for v in model_or_parameters) | ||||
|     if unit.lower() == "mb": | ||||
|         counts /= 1e6 | ||||
|     elif unit.lower() == "kb": | ||||
|         counts /= 1e3 | ||||
|     elif unit.lower() == "gb": | ||||
|         counts /= 1e9 | ||||
|         counts = sum(np.prod(v.size()) for v in model_or_parameters) | ||||
|     if unit.lower() == "kb" or unit.lower() == "k": | ||||
|         counts /= 2 ** 10  # changed from 1e3 to 2^10 | ||||
|     elif unit.lower() == "mb" or unit.lower() == "m": | ||||
|         counts /= 2 ** 20  # changed from 1e6 to 2^20 | ||||
|     elif unit.lower() == "gb" or unit.lower() == "g": | ||||
|         counts /= 2 ** 30  # changed from 1e9 to 2^30 | ||||
|     elif unit is not None: | ||||
|         raise ValueError("Unknow unit: {:}".format(unit)) | ||||
|     return counts | ||||
|   | ||||
		Reference in New Issue
	
	Block a user