MeCo/zero-cost-nas/foresight/pruners/measures/param_count.py
HamsterMimi 189df25fd3 upload
2023-05-04 13:09:03 +08:00

16 lines
475 B
Python

import time
import torch
from . import measure
from ..p_utils import get_layer_metric_array
@measure('param_count', copy_net=False, mode='param')
def get_param_count_array(net, inputs, targets, mode, loss_fn, split_data=1):
    """Count trainable parameters per layer of ``net`` and time the computation.

    For every layer that ``get_layer_metric_array`` visits, the per-layer metric
    is the total number of elements across that layer's parameters with
    ``requires_grad=True``, wrapped in a 0-d ``torch`` tensor. Which layers are
    visited depends on ``get_layer_metric_array`` and ``mode``, and is not
    visible here.

    Args:
        net: The network to measure (presumably a ``torch.nn.Module``; it must
            be accepted by ``get_layer_metric_array``).
        inputs: Unused.
        targets: Unused.
        mode: Forwarded unchanged to ``get_layer_metric_array``.
        loss_fn: Unused.
        split_data: Unused.

    The unused arguments are presumably kept so that this function matches the
    common signature shared by all measures registered via ``@measure``.
    TODO confirm against the ``measure`` registry.

    Returns:
        A ``(count, t)`` tuple. ``count`` is whatever ``get_layer_metric_array``
        returns for the per-layer metric. ``t`` is the elapsed time of that
        computation, in seconds.
    """
    # perf_counter is monotonic and high-resolution, unlike time.time(), so the
    # reported duration cannot be skewed by system clock adjustments.
    start = time.perf_counter()
    count = get_layer_metric_array(
        net,
        # Module.parameters() recurses into submodules by default; frozen
        # (requires_grad=False) parameters are excluded from the count.
        lambda layer: torch.tensor(
            sum(p.numel() for p in layer.parameters() if p.requires_grad)
        ),
        mode=mode,
    )
    elapsed = time.perf_counter() - start
    return count, elapsed