diff --git a/zero-cost-nas/foresight/pruners/measures/gradsign.py b/zero-cost-nas/foresight/pruners/measures/gradsign.py
new file mode 100644
index 0000000..9d7e1f1
--- /dev/null
+++ b/zero-cost-nas/foresight/pruners/measures/gradsign.py
@@ -0,0 +1,73 @@
+# Copyright 2021 Samsung Electronics Co., Ltd.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+
+import torch
+from torch import nn
+import numpy as np
+
+from . import measure
+
+
+def get_flattened_metric(net, metric):
+    # Evaluate `metric` on every conv/linear layer and flatten the results
+    # into a single vector.
+    grad_list = []
+    for layer in net.modules():
+        if isinstance(layer, (nn.Conv2d, nn.Linear)):
+            grad_list.append(metric(layer).flatten())
+    return np.concatenate(grad_list)
+
+
+def get_grad_conflict(net, inputs, targets, loss_fn):
+    # GradSign: compute one weight-gradient vector per sample, then score
+    # how consistently the gradient signs agree across the batch.
+    N = inputs.shape[0]
+    batch_grad = []
+    for i in range(N):
+        net.zero_grad()
+        outputs = net(inputs[[i]])
+        loss = loss_fn(outputs, targets[[i]])
+        loss.backward()
+        flattened_grad = get_flattened_metric(
+            net,
+            lambda l: l.weight.grad.data.clone().cpu().numpy()
+            if l.weight.grad is not None
+            else torch.zeros_like(l.weight).cpu().numpy())
+        batch_grad.append(flattened_grad)
+    batch_grad = np.stack(batch_grad)
+    direction_code = np.sign(batch_grad)
+    direction_code = abs(direction_code.sum(axis=0))
+    score = np.nansum(direction_code)
+    return score
+
+
+def get_gradsign(inputs, targets, net, device, loss_fn):
+    net = net.to(device)
+    inputs, targets = inputs.to(device), targets.to(device)
+    return get_grad_conflict(net=net, inputs=inputs, targets=targets,
+                             loss_fn=loss_fn)
+
+
+@measure('gradsign', bn=True)
+def compute_gradsign(net, inputs, targets, split_data=1, loss_fn=None):
+    device = inputs.device
+    # Compute gradients (but don't apply them).
+    net.zero_grad()
+    try:
+        gradsign = get_gradsign(inputs, targets, net, device, loss_fn)
+    except Exception as e:
+        print(e)
+        gradsign = np.nan
+    return gradsign
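As a sanity check on the scoring rule in get_grad_conflict() above, here is a tiny standalone illustration (hypothetical numbers, not part of the patch): each row of batch_grad stands in for one per-sample gradient, and the score counts how strongly the gradient signs agree across samples.

    import numpy as np

    batch_grad = np.array([[ 0.5, -0.2,  0.1],
                           [ 0.3,  0.4, -0.1],
                           [ 0.2, -0.3,  0.2]])
    direction_code = abs(np.sign(batch_grad).sum(axis=0))
    print(np.nansum(direction_code))  # 3 + 1 + 1 = 5.0; higher = more agreement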
diff --git a/zero-cost-nas/foresight/pruners/measures/ntk.py b/zero-cost-nas/foresight/pruners/measures/ntk.py
new file mode 100644
index 0000000..8891d19
--- /dev/null
+++ b/zero-cost-nas/foresight/pruners/measures/ntk.py
@@ -0,0 +1,87 @@
+# Copyright 2021 Samsung Electronics Co., Ltd.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+
+import torch
+import numpy as np
+
+from . import measure
+
+
+def recal_bn(network, inputs, targets, recalbn, device):
+    # Reset the BatchNorm running statistics, then re-estimate them with a
+    # few gradient-free forward passes.
+    for m in network.modules():
+        if isinstance(m, torch.nn.BatchNorm2d):
+            m.running_mean.data.fill_(0)
+            m.running_var.data.fill_(0)
+            m.num_batches_tracked.data.zero_()
+            m.momentum = None
+    network.train()
+    with torch.no_grad():
+        for i, (inputs, targets) in enumerate(zip(inputs, targets)):
+            if i >= recalbn:
+                break
+            inputs = inputs.to(device, non_blocking=True)
+            _, _ = network(inputs)
+    return network
+
+
+def get_ntk_n(inputs, targets, network, device, recalbn=0, train_mode=False,
+              num_batch=1):
+    network.eval()
+    networks = [network]
+    grads = [[] for _ in range(len(networks))]
+    for i in range(num_batch):
+        inputs = inputs.to(device, non_blocking=True)
+        for net_idx, network in enumerate(networks):
+            network.zero_grad()
+            inputs_ = inputs.clone().to(device, non_blocking=True)
+            logit = network(inputs_)
+            if isinstance(logit, tuple):
+                logit = logit[1]  # NAS-Bench-201 networks return (features, logits)
+            # One backward pass per sample yields one per-sample gradient row.
+            for _idx in range(len(inputs_)):
+                logit[_idx:_idx + 1].backward(torch.ones_like(logit[_idx:_idx + 1]),
+                                              retain_graph=True)
+                grad = []
+                for name, W in network.named_parameters():
+                    if 'weight' in name and W.grad is not None:
+                        grad.append(W.grad.view(-1).detach())
+                grads[net_idx].append(torch.cat(grad, -1))
+                network.zero_grad()
+                torch.cuda.empty_cache()
+    # NTK = J J^T over the stacked per-sample gradient vectors.
+    grads = [torch.stack(_grads, 0) for _grads in grads]
+    ntks = [torch.einsum('nc,mc->nm', [_grads, _grads]) for _grads in grads]
+    for ntk in ntks:
+        eigenvalues, _ = torch.linalg.eigh(ntk)  # ascending order
+        # Ratio of smallest to largest eigenvalue (the reciprocal of the
+        # usual NTK condition number).
+        conds = np.nan_to_num((eigenvalues[0] / eigenvalues[-1]).item(),
+                              copy=True, nan=100000.0)
+    return conds
+
+
+@measure('ntk', bn=True)
+def compute_ntk(net, inputs, targets, split_data=1, loss_fn=None):
+    device = inputs.device
+    # Compute gradients (but don't apply them).
+    net.zero_grad()
+    try:
+        conds = get_ntk_n(inputs, targets, net, device)
+    except Exception as e:
+        print(e)
+        conds = np.nan
+    return conds
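get_ntk_n() above returns the ratio of the smallest to the largest eigenvalue of the batch NTK, i.e. the reciprocal of the usual condition number lambda_max/lambda_min. A self-contained toy check of the same computation (hypothetical gradient rows, not part of the patch):

    import torch
    import numpy as np

    J = torch.tensor([[1.0, 0.0],
                      [0.9, 0.1]])               # per-sample gradient rows
    ntk = torch.einsum('nc,mc->nm', [J, J])      # NTK Gram matrix J J^T
    eigenvalues = torch.linalg.eigvalsh(ntk)     # ascending
    score = np.nan_to_num((eigenvalues[0] / eigenvalues[-1]).item(), nan=100000.0)
    print(score)  # ~0.003: nearly parallel gradients -> ill-conditioned NTK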
diff --git a/zero-cost-nas/foresight/pruners/measures/zen.py b/zero-cost-nas/foresight/pruners/measures/zen.py
new file mode 100644
index 0000000..8b6e64e
--- /dev/null
+++ b/zero-cost-nas/foresight/pruners/measures/zen.py
@@ -0,0 +1,94 @@
+# Copyright 2021 Samsung Electronics Co., Ltd.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+
+import torch
+from torch import nn
+import numpy as np
+
+from . import measure
+
+
+def network_weight_gaussian_init(net: nn.Module):
+    # Re-draw all conv/linear weights from N(0, 1) and reset BatchNorm
+    # affine parameters; the Zen-Score is defined for randomly
+    # initialized networks.
+    with torch.no_grad():
+        for m in net.modules():
+            if isinstance(m, nn.Conv2d):
+                nn.init.normal_(m.weight)
+                if hasattr(m, 'bias') and m.bias is not None:
+                    nn.init.zeros_(m.bias)
+            elif isinstance(m, nn.BatchNorm2d):
+                try:
+                    nn.init.ones_(m.weight)
+                    nn.init.zeros_(m.bias)
+                except Exception:
+                    pass  # e.g. affine=False
+            elif isinstance(m, nn.Linear):
+                nn.init.normal_(m.weight)
+                if hasattr(m, 'bias') and m.bias is not None:
+                    nn.init.zeros_(m.bias)
+    return net
+
+
+def get_zen(gpu, model, mixup_gamma=1e-2, resolution=32, batch_size=64,
+            repeat=32, fp16=False):
+    nas_score_list = []
+    device = torch.device(gpu) if gpu is not None else torch.device('cpu')
+    dtype = torch.half if fp16 else torch.float32
+
+    with torch.no_grad():
+        for repeat_count in range(repeat):
+            network_weight_gaussian_init(model)
+            input = torch.randn(size=[batch_size, 3, resolution, resolution],
+                                device=device, dtype=dtype)
+            input2 = torch.randn(size=[batch_size, 3, resolution, resolution],
+                                 device=device, dtype=dtype)
+            mixup_input = input + mixup_gamma * input2
+            output = model.forward_pre_GAP(input)  # model must expose pre-GAP features
+            mixup_output = model.forward_pre_GAP(mixup_input)
+
+            # Expected change of the pre-GAP feature map under a small
+            # input perturbation.
+            nas_score = torch.sum(torch.abs(output - mixup_output), dim=[1, 2, 3])
+            nas_score = torch.mean(nas_score)
+
+            # Compensate for the rescaling that BatchNorm applies to the
+            # activations.
+            log_bn_scaling_factor = 0.0
+            for m in model.modules():
+                if isinstance(m, nn.BatchNorm2d):
+                    try:
+                        bn_scaling_factor = torch.sqrt(torch.mean(m.running_var))
+                        log_bn_scaling_factor += torch.log(bn_scaling_factor)
+                    except Exception:
+                        pass
+            nas_score = torch.log(nas_score) + log_bn_scaling_factor
+            nas_score_list.append(float(nas_score))
+
+    avg_nas_score = np.mean(nas_score_list)
+    return float(avg_nas_score)
+
+
+@measure('zen', bn=True)
+def compute_zen(net, inputs, targets, split_data=1, loss_fn=None):
+    device = inputs.device
+    net.zero_grad()
+    try:
+        zen = get_zen(device, net)
+    except Exception as e:
+        print(e)
+        zen = np.nan
+    return zen
diff --git a/zero-cost-nas/foresight/pruners/predictive.py b/zero-cost-nas/foresight/pruners/predictive.py
index dd52b91..7029637 100644
--- a/zero-cost-nas/foresight/pruners/predictive.py
+++ b/zero-cost-nas/foresight/pruners/predictive.py
@@ -108,7 +108,7 @@ def find_measures(net_orig, # neural network
 
     measures = {}
     for k,v in measures_arr.items():
-        if k in ['jacob_cov', 'meco', 'zico']:
+        if k in ['jacob_cov', 'var', 'cor', 'norm', 'meco', 'zico', 'ntk', 'gradsign', 'zen']:
             measures[k] = v
         else:
             measures[k] = sum_arr(v)
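The predictive.py hunk passes the new measure names through unchanged because, like jacob_cov, they already return a single scalar; the remaining measures are per-layer arrays that still get reduced by sum_arr(). With the files above in place, the new scores can be requested by name. A minimal usage sketch, assuming the repo's find_measures() entry point and placeholder net/train_loader/num_classes/device objects:

    import torch.nn.functional as F
    from foresight.pruners import predictive

    scores = predictive.find_measures(
        net, train_loader,
        ('random', 1, num_classes),   # dataload_info: (mode, #batches, #classes)
        device,
        loss_fn=F.cross_entropy,
        measure_names=['gradsign', 'ntk', 'zen'])
    print(scores)  # {'gradsign': ..., 'ntk': ..., 'zen': ...}

Note that the zen measure additionally assumes the candidate model exposes a forward_pre_GAP() method returning the feature map just before global average pooling.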