MeCo/correlation/foresight/pruners/measures/zico.py
HamsterMimi 3f6d16e791 update
2024-01-23 10:08:45 +08:00

106 lines
3.4 KiB
Python

# Copyright 2021 Samsung Electronics Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
import time
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
import numpy as np
import torch
from . import measure
from torch import nn
from ...dataset import get_cifar_dataloaders
def getgrad(model: torch.nn.Module, grad_dict: dict, step_iter=0):
    """Record the current weight gradients of every Conv2d/Linear layer.

    Walks ``model.named_modules()`` and, for each ``nn.Conv2d`` or
    ``nn.Linear`` module whose weight currently holds a gradient, stores a
    flattened NumPy copy of that gradient in ``grad_dict`` under the
    module's name.

    Args:
        model: Network whose ``weight.grad`` tensors are read.
        grad_dict: Mapping from module name to a list of flattened gradient
            arrays. It is updated in place.
        step_iter: ``0`` starts a fresh one-element list for each layer
            (replacing any existing entry). Any other value appends to the
            list created earlier and skips layers that have no entry yet.

    Returns:
        The same ``grad_dict`` object that was passed in.
    """
    for name, mod in model.named_modules():
        if not isinstance(mod, (nn.Conv2d, nn.Linear)):
            continue

        weight = getattr(mod, "weight", None)
        grad = None if weight is None else weight.grad
        if grad is None:
            # Layers without a gradient (frozen, unused in the forward
            # pass, ...) are skipped rather than treated as an error.
            continue

        try:
            # Tensor.numpy() shares memory with a CPU tensor, and autograd
            # accumulates into ``.grad`` in place on the next backward().
            # Without the copy, samples stored at an earlier step would be
            # silently overwritten by later steps.
            flat = grad.detach().cpu().reshape(-1).numpy().copy()
        except (TypeError, RuntimeError):
            # Gradient cannot be represented as a NumPy array (unsupported
            # dtype or layout); keep the collection best-effort.
            continue

        if step_iter == 0:
            grad_dict[name] = [flat]
        elif name in grad_dict:
            grad_dict[name].append(flat)
    return grad_dict
def caculate_zico(grad_dict):
    """Compute the ZiCo score from per-layer gradient samples.

    For each layer, the list of gradient samples is stacked so that axis 0
    runs over the samples. Per weight, the mean absolute gradient is divided
    by the standard deviation of the gradient across samples. Weights whose
    standard deviation is exactly zero are left out. The ratios are summed
    per layer and the natural log of each non-zero layer sum is added to
    the score.

    Args:
        grad_dict: Mapping from layer name to a sequence of equally sized
            gradient arrays (one per sample). It is not modified.

    Returns:
        The accumulated score, or ``0`` when no layer contributes (empty
        input, or every weight has zero standard deviation).
    """
    score = 0
    for samples in grad_dict.values():
        stacked = np.asarray(samples)
        grad_std = np.std(stacked, axis=0)
        # Only weights whose gradient actually varies between samples can
        # be used; a zero std would mean dividing by zero.
        varying = np.nonzero(grad_std)[0]
        mean_abs = np.mean(np.abs(stacked), axis=0)
        layer_sum = np.sum(mean_abs[varying] / grad_std[varying])
        # A layer sum of exactly zero is skipped so log(0) is never taken.
        if layer_sum != 0:
            score += np.log(layer_sum)
    return score
def getzico(network, inputs, targets, loss_fn, split_data=2):
    """Run forward/backward on two halves of a batch and score the gradients.

    The network is put in training mode and moved to the device of
    ``inputs``. The batch is cut into two contiguous slices; for each slice
    the loss is back-propagated and the layer gradients are recorded with
    ``getgrad``. The recorded gradients are then reduced to a single score
    by ``caculate_zico``.

    Args:
        network: Model to evaluate.
        inputs: Input batch; its first dimension is the one that is split.
        targets: Targets aligned with ``inputs`` along the first dimension.
        loss_fn: Callable invoked as ``loss_fn(outputs, targets)``.
        split_data: Accepted for interface compatibility but currently
            ignored: the number of splits is fixed to 2 below.

    Returns:
        Whatever ``caculate_zico`` returns for the collected gradients.
    """
    network.train()
    network.to(inputs.device)
    batch_size = inputs.shape[0]

    # The caller-supplied value is deliberately overwritten; the split count
    # is always 2 here.
    split_data = 2

    # NOTE(review): gradients are not cleared between the backward passes in
    # this function, so the second snapshot is taken after both backward
    # calls have run. Confirm whether that accumulation is intended.
    per_layer_grads = {}
    for chunk in range(split_data):
        lo = chunk * batch_size // split_data
        hi = (chunk + 1) * batch_size // split_data
        loss = loss_fn(network.forward(inputs[lo:hi]), targets[lo:hi])
        loss.backward()
        per_layer_grads = getgrad(network, per_layer_grads, chunk)

    return caculate_zico(per_layer_grads)
@measure('zico', bn=True)
def compute_zico(net, inputs, targets, split_data=2, loss_fn=None):
    """Entry point registered as the ``'zico'`` measure.

    Clears the network's gradients, then delegates to ``getzico``. Any
    exception raised while computing the score is printed and converted
    into ``np.nan`` instead of being propagated to the caller.

    Args:
        net: Model to score.
        inputs: Input batch forwarded to ``getzico``.
        targets: Targets forwarded to ``getzico``.
        split_data: Forwarded to ``getzico`` as ``split_data``.
        loss_fn: Loss callable forwarded to ``getzico``.

    Returns:
        The value returned by ``getzico``, or ``np.nan`` on failure.
    """
    # Start from clean gradient buffers (gradients are computed but never
    # applied to the weights here).
    net.zero_grad()
    try:
        return getzico(net, inputs, targets, loss_fn, split_data=split_data)
    except Exception as err:
        # Best-effort boundary: report the problem and signal "no score".
        print(err)
        return np.nan