import copy
import sys

import numpy as np
import torch

sys.path.insert(0, '../')


def Jocab_Score(ori_model, input, target, weights=None):
    # NASWOT-style score: run one batch through a frozen copy of the
    # supernet and measure how distinct the ReLU activation patterns of
    # the inputs are. `target` is kept for interface compatibility.
    model = copy.deepcopy(ori_model)
    model.eval()
    model.proj_weights = weights
    num_edge, num_op = model.num_edge, model.num_op
    for i in range(num_edge):
        model.candidate_flags[i] = False
    batch_size = input.shape[0]
    model.K = torch.zeros(batch_size, batch_size).cuda()

    def counting_forward_hook(module, inp, out):
        # Accumulate the kernel K: for every pair of inputs, count the
        # units on which their binary ReLU activation codes agree.
        try:
            if isinstance(inp, tuple):
                inp = inp[0]
            inp = inp.view(inp.size(0), -1)
            x = (inp > 0).float()             # binary activation code
            K = x @ x.t()                     # agreements on active units
            K2 = (1. - x) @ (1. - x.t())      # agreements on inactive units
            model.K = model.K + K + K2
        except Exception:
            pass

    for name, module in model.named_modules():
        if 'ReLU' in str(type(module)):
            module.register_forward_hook(counting_forward_hook)

    input = input.cuda()
    model(input)
    score = hooklogdet(model.K.cpu().numpy())

    del model
    del input
    return score


def hooklogdet(K, labels=None):
    # The score is the log-determinant of the activation-pattern kernel;
    # slogdet is used for numerical stability.
    s, ld = np.linalg.slogdet(K)
    return ld


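# Minimal usage sketch (illustrative, not part of the original file): it
# rebuilds the same activation-pattern kernel on a toy ReLU MLP, without
# the supernet machinery that Jocab_Score assumes (num_edge,
# candidate_flags, proj_weights). The toy model, shapes, and CPU device
# are assumptions for demonstration only.
def _naswot_demo():
    net = torch.nn.Sequential(
        torch.nn.Linear(32, 64), torch.nn.ReLU(),
        torch.nn.Linear(64, 10), torch.nn.ReLU(),
    )
    net.eval()
    x = torch.randn(16, 32)
    K = torch.zeros(16, 16)

    def hook(module, inp, out):
        a = inp[0].view(inp[0].size(0), -1)
        c = (a > 0).float()                          # binary activation code
        K.add_(c @ c.t() + (1. - c) @ (1. - c.t()))  # same kernel update as above

    for m in net.modules():
        if isinstance(m, torch.nn.ReLU):
            m.register_forward_hook(hook)
    with torch.no_grad():
        net(x)
    return hooklogdet(K.cpu().numpy())  # higher => more distinct patterns

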
# NTK
# ------------------------------------------------------------
# https://github.com/VITA-Group/TENAS/blob/main/lib/procedures/ntk.py
#
def recal_bn(network, xloader, recalbn, device):
    # Reset every BatchNorm2d, then re-estimate its running statistics
    # from the first `recalbn` batches. momentum=None switches BN to a
    # cumulative moving average, so all seen batches are weighted equally.
    for m in network.modules():
        if isinstance(m, torch.nn.BatchNorm2d):
            m.running_mean.data.fill_(0)
            m.running_var.data.fill_(0)
            m.num_batches_tracked.data.zero_()
            m.momentum = None
    network.train()
    with torch.no_grad():
        for i, (inputs, targets) in enumerate(xloader):
            if i >= recalbn:
                break
            inputs = inputs.cuda(device=device, non_blocking=True)
            _, _ = network(inputs)
    return network


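# Usage sketch (illustrative assumption, not part of the original file):
# recal_bn unpacks two outputs from the network, so the toy module below
# returns (features, logits). The architecture and the commented call
# are assumptions for demonstration only.
class _ToyNet(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.body = torch.nn.Sequential(
            torch.nn.Conv2d(3, 8, 3, padding=1),
            torch.nn.BatchNorm2d(8),
            torch.nn.ReLU(),
        )
        self.head = torch.nn.Linear(8, 10)

    def forward(self, x):
        feats = self.body(x).mean(dim=(2, 3))  # global average pool
        return feats, self.head(feats)

# loader = torch.utils.data.DataLoader(dataset, batch_size=64)
# net = recal_bn(_ToyNet().cuda(), loader, recalbn=10, device=0)

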
def get_ntk_n(xloader, networks, recalbn=0, train_mode=False, num_batch=-1, weights=None):
    # Condition number of the empirical NTK for each candidate network:
    # stack per-sample weight gradients, form their Gram matrix, and
    # return lambda_max / lambda_min (lower is considered better).
    device = torch.cuda.current_device()
    copied_networks = []
    for network in networks:
        network = network.cuda(device=device)
        net = copy.deepcopy(network)
        net.proj_weights = weights
        num_edge, num_op = net.num_edge, net.num_op
        for i in range(num_edge):
            net.candidate_flags[i] = False
        if train_mode:
            net.train()
        else:
            net.eval()
        copied_networks.append(net)
    ######
    grads = [[] for _ in range(len(copied_networks))]
    for i, (inputs, targets) in enumerate(xloader):
        if num_batch > 0 and i >= num_batch:
            break
        inputs = inputs.cuda(device=device, non_blocking=True)
        for net_idx, network in enumerate(copied_networks):
            network.zero_grad()
            inputs_ = inputs.clone().cuda(device=device, non_blocking=True)
            logit = network(inputs_)
            if isinstance(logit, tuple):
                logit = logit[1]  # 201 networks: return features and logits
            for _idx in range(len(inputs_)):
                # One backward pass per sample yields that sample's gradient.
                logit[_idx:_idx+1].backward(torch.ones_like(logit[_idx:_idx+1]), retain_graph=True)
                grad = []
                for name, W in network.named_parameters():
                    if 'weight' in name and W.grad is not None:
                        grad.append(W.grad.view(-1).detach())
                grads[net_idx].append(torch.cat(grad, -1))
                network.zero_grad()
                torch.cuda.empty_cache()
    ######
    grads = [torch.stack(_grads, 0) for _grads in grads]
    ntks = [torch.einsum('nc,mc->nm', [_grads, _grads]) for _grads in grads]
    conds = []
    for ntk in ntks:
        # torch.symeig was removed in recent PyTorch; eigvalsh likewise
        # returns eigenvalues in ascending order.
        eigenvalues = torch.linalg.eigvalsh(ntk)
        conds.append(np.nan_to_num((eigenvalues[-1] / eigenvalues[0]).item(), copy=True, nan=100000.0))

    del copied_networks
    return conds
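

# Simplified standalone sketch (illustrative, not part of the original
# file): the NTK condition number for a single plain model on one batch,
# mirroring the gradient-Gram construction in get_ntk_n but without the
# supernet-specific attributes (proj_weights, candidate_flags). Assumes
# `model(inputs)` returns a logits tensor directly.
def _ntk_cond_demo(model, inputs):
    model.zero_grad()
    logit = model(inputs)
    grads = []
    for idx in range(len(inputs)):
        logit[idx:idx + 1].backward(torch.ones_like(logit[idx:idx + 1]),
                                    retain_graph=True)
        g = [W.grad.view(-1).detach()
             for name, W in model.named_parameters()
             if 'weight' in name and W.grad is not None]
        grads.append(torch.cat(g, -1))   # cat copies before grads are zeroed
        model.zero_grad()
    G = torch.stack(grads, 0)
    ntk = torch.einsum('nc,mc->nm', G, G)  # per-sample gradient Gram matrix
    eigs = torch.linalg.eigvalsh(ntk)      # ascending eigenvalues
    return (eigs[-1] / eigs[0]).item()     # condition number

# Example call (assumed shapes):
# cond = _ntk_cond_demo(net.cuda(), torch.randn(8, 3, 32, 32).cuda())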