update

zero-cost-nas/foresight/pruners/measures/gradsign.py (new file, 76 lines)
@@ -0,0 +1,76 @@
# Copyright 2021 Samsung Electronics Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================

import numpy as np
import torch
from torch import nn

from . import measure


def get_flattened_metric(net, metric):
    """Apply `metric` to every Conv2d/Linear layer and concatenate the results."""
    grad_list = []
    for layer in net.modules():
        if isinstance(layer, (nn.Conv2d, nn.Linear)):
            grad_list.append(metric(layer).flatten())
    return np.concatenate(grad_list)


def get_grad_conflict(net, inputs, targets, loss_fn):
    """GradSign score: how consistently the per-sample gradients agree in sign."""
    N = inputs.shape[0]
    batch_grad = []
    for i in range(N):
        net.zero_grad()
        outputs = net(inputs[[i]])
        loss = loss_fn(outputs, targets[[i]])
        loss.backward()
        flattened_grad = get_flattened_metric(
            net,
            lambda l: (l.weight.grad.detach().clone().cpu().numpy()
                       if l.weight.grad is not None
                       else torch.zeros_like(l.weight).cpu().numpy()))
        batch_grad.append(flattened_grad)
    batch_grad = np.stack(batch_grad)
    # For each weight, sum the per-sample gradient signs; the magnitude is
    # large exactly when the samples agree on the update direction.
    direction_code = np.abs(np.sign(batch_grad).sum(axis=0))
    return np.nansum(direction_code)


def get_gradsign(inputs, targets, net, device, loss_fn):
    net = net.to(device)
    inputs, targets = inputs.to(device), targets.to(device)
    return get_grad_conflict(net=net, inputs=inputs, targets=targets, loss_fn=loss_fn)


@measure('gradsign', bn=True)
def compute_gradsign(net, inputs, targets, split_data=1, loss_fn=None):
    device = inputs.device
    # Compute gradients (but don't apply them).
    net.zero_grad()
    try:
        gradsign = get_gradsign(inputs, targets, net, device, loss_fn)
    except Exception as e:
        print(e)
        gradsign = np.nan
    return gradsign
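As a sanity check, the measure can be exercised on a toy network without the foresight harness. A minimal sketch, assuming `get_grad_conflict` is importable from the new module; the toy net and batch shape are illustrative, not part of the commit:

    import torch
    import torch.nn.functional as F
    from torch import nn

    # hypothetical import path, following the file location above
    from foresight.pruners.measures.gradsign import get_grad_conflict

    net = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.ReLU(),
                        nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(8, 10))
    inputs = torch.randn(4, 3, 32, 32)       # 4 CIFAR-sized samples
    targets = torch.randint(0, 10, (4,))

    score = get_grad_conflict(net, inputs, targets, F.cross_entropy)
    print(score)  # larger = per-sample gradients agree more often in sign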
zero-cost-nas/foresight/pruners/measures/ntk.py (new file, 94 lines)
@@ -0,0 +1,94 @@
# Copyright 2021 Samsung Electronics Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================

import numpy as np
import torch

from . import measure


def recal_bn(network, inputs, targets, recalbn, device):
    """Reset BatchNorm running statistics and re-estimate them on `recalbn` batches."""
    for m in network.modules():
        if isinstance(m, torch.nn.BatchNorm2d):
            m.running_mean.data.fill_(0)
            m.running_var.data.fill_(0)
            m.num_batches_tracked.data.zero_()
            m.momentum = None  # use a cumulative moving average
    network.train()
    with torch.no_grad():
        for i, (batch_inputs, _) in enumerate(zip(inputs, targets)):
            if i >= recalbn:
                break
            batch_inputs = batch_inputs.to(device=device, non_blocking=True)
            _, _ = network(batch_inputs)  # assumes the net returns (features, logits)
    return network


def get_ntk_n(inputs, targets, network, device, recalbn=0, train_mode=False, num_batch=1):
    # recalbn/train_mode are currently unused; see recal_bn above.
    network.eval()
    networks = [network]
    grads = [[] for _ in range(len(networks))]
    for _ in range(num_batch):
        inputs = inputs.to(device=device, non_blocking=True)
        for net_idx, network in enumerate(networks):
            network.zero_grad()
            inputs_ = inputs.clone().to(device=device, non_blocking=True)
            logit = network(inputs_)
            if isinstance(logit, tuple):
                logit = logit[1]  # NAS-Bench-201 networks return (features, logits)
            for _idx in range(len(inputs_)):
                # Gradient of each sample's summed logits w.r.t. all weight
                # tensors, flattened into one row of the Jacobian.
                logit[_idx:_idx + 1].backward(torch.ones_like(logit[_idx:_idx + 1]),
                                              retain_graph=True)
                grad = []
                for name, W in network.named_parameters():
                    if 'weight' in name and W.grad is not None:
                        grad.append(W.grad.view(-1).detach())
                grads[net_idx].append(torch.cat(grad, -1))
                network.zero_grad()
                torch.cuda.empty_cache()
    # NTK Gram matrix: pairwise inner products of the per-sample gradients.
    grads = [torch.stack(_grads, 0) for _grads in grads]
    ntks = [torch.einsum('nc,mc->nm', [_grads, _grads]) for _grads in grads]
    conds = np.nan
    for ntk in ntks:
        eigenvalues, _ = torch.linalg.eigh(ntk)  # ascending order
        # Note: this is the reciprocal of the usual condition number (smallest
        # eigenvalue over largest), so larger means better conditioned.
        conds = np.nan_to_num((eigenvalues[0] / eigenvalues[-1]).item(),
                              copy=True, nan=100000.0)
    return conds


@measure('ntk', bn=True)
def compute_ntk(net, inputs, targets, split_data=1, loss_fn=None):
    device = inputs.device
    # Compute gradients (but don't apply them).
    net.zero_grad()
    try:
        conds = get_ntk_n(inputs, targets, net, device)
    except Exception as e:
        print(e)
        conds = np.nan
    return conds
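The returned quantity is just the eigenvalue ratio of that Gram matrix. A self-contained sketch of the same computation on random gradients, illustrative only and not part of the commit:

    import numpy as np
    import torch

    grads = torch.randn(8, 100)                    # 8 samples, 100 flattened weights
    ntk = torch.einsum('nc,mc->nm', grads, grads)  # 8x8 NTK Gram matrix
    eigenvalues, _ = torch.linalg.eigh(ntk)        # ascending order
    score = np.nan_to_num((eigenvalues[0] / eigenvalues[-1]).item(), nan=100000.0)
    print(score)  # in (0, 1]; closer to 1 means a better-conditioned NTK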
zero-cost-nas/foresight/pruners/measures/zen.py (new file, 110 lines)
@@ -0,0 +1,110 @@
# Copyright 2021 Samsung Electronics Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================

import numpy as np
import torch
from torch import nn

from . import measure


def network_weight_gaussian_init(net: nn.Module):
    """Re-initialize conv/linear weights from a standard Gaussian (Zen-NAS init)."""
    with torch.no_grad():
        for m in net.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.normal_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.BatchNorm2d):
                if m.affine:  # affine=False BN layers have no weight/bias
                    nn.init.ones_(m.weight)
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
    return net


def get_zen(gpu, model, mixup_gamma=1e-2, resolution=32, batch_size=64, repeat=32,
            fp16=False):
    nas_score_list = []
    device = torch.device(gpu) if gpu is not None else torch.device('cpu')
    dtype = torch.half if fp16 else torch.float32

    with torch.no_grad():
        for _ in range(repeat):
            network_weight_gaussian_init(model)
            input = torch.randn(size=[batch_size, 3, resolution, resolution],
                                device=device, dtype=dtype)
            input2 = torch.randn(size=[batch_size, 3, resolution, resolution],
                                 device=device, dtype=dtype)
            mixup_input = input + mixup_gamma * input2
            # Zen score: sensitivity of the pre-GAP feature map to a small
            # input perturbation (the model must expose forward_pre_GAP).
            output = model.forward_pre_GAP(input)
            mixup_output = model.forward_pre_GAP(mixup_input)

            nas_score = torch.sum(torch.abs(output - mixup_output), dim=[1, 2, 3])
            nas_score = torch.mean(nas_score)

            # Compensate for BN normalization with the log of each layer's scale.
            log_bn_scaling_factor = 0.0
            for m in model.modules():
                if isinstance(m, nn.BatchNorm2d) and m.running_var is not None:
                    bn_scaling_factor = torch.sqrt(torch.mean(m.running_var))
                    log_bn_scaling_factor += torch.log(bn_scaling_factor)
            nas_score = torch.log(nas_score) + log_bn_scaling_factor
            nas_score_list.append(float(nas_score))

    return float(np.mean(nas_score_list))


@measure('zen', bn=True)
def compute_zen(net, inputs, targets, split_data=1, loss_fn=None):
    device = inputs.device
    # Compute gradients (but don't apply them).
    net.zero_grad()
    try:
        zen = get_zen(device, net)
    except Exception as e:
        print(e)
        zen = np.nan
    return zen
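`get_zen` relies on the candidate network exposing `forward_pre_GAP`, i.e. the feature map just before global average pooling. A hypothetical minimal model satisfying that contract, illustrative and not part of the commit:

    import torch
    from torch import nn

    class TinyNet(nn.Module):
        def __init__(self):
            super().__init__()
            self.features = nn.Sequential(
                nn.Conv2d(3, 16, 3, padding=1), nn.BatchNorm2d(16), nn.ReLU())
            self.head = nn.Linear(16, 10)

        def forward_pre_GAP(self, x):
            return self.features(x)                 # NxCxHxW map fed to the Zen score

        def forward(self, x):
            x = self.features(x).mean(dim=[2, 3])   # global average pooling
            return self.head(x)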
@@ -108,7 +108,7 @@ def find_measures(net_orig,                  # neural network
 
     measures = {}
     for k,v in measures_arr.items():
-        if k in ['jacob_cov', 'meco', 'zico']:
+        if k in ['jacob_cov', 'var', 'cor', 'norm', 'meco', 'zico', 'ntk', 'gradsign', 'zen']:
             measures[k] = v
         else:
             measures[k] = sum_arr(v)
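The new measures join this pass-through list because each already returns a single scalar, whereas the remaining measures return per-layer arrays that must be aggregated. For contrast, a sketch of what an aggregator like `sum_arr` typically does; this is hypothetical, the actual helper is defined elsewhere in the same file:

    import torch

    def sum_arr(arr):
        """Collapse a list of per-layer score tensors into one scalar."""
        total = 0.
        for score in arr:
            total += torch.sum(score)
        return total.item()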