update
This commit is contained in:
		| @@ -96,20 +96,6 @@ def project_op(model, input, target, args, cell_type, proj_queue=None, selected_ | |||||||
|  |  | ||||||
|             model.candidate_flags[cell_type][selected_eid] = False |             model.candidate_flags[cell_type][selected_eid] = False | ||||||
|             # print(model.get_projected_weights()) |             # print(model.get_projected_weights()) | ||||||
|             if proj_crit == 'comb': |  | ||||||
|                 synflow = predictive.find_measures(model, |  | ||||||
|                                                    proj_queue, |  | ||||||
|                                                    ('random', 1, n_classes), |  | ||||||
|                                                    torch.device("cuda"), |  | ||||||
|                                                    measure_names=['synflow']) |  | ||||||
|                 var = predictive.find_measures(model, |  | ||||||
|                                                proj_queue, |  | ||||||
|                                                ('random', 1, n_classes), |  | ||||||
|                                                torch.device("cuda"), |  | ||||||
|                                                measure_names=['var']) |  | ||||||
|                 # print(synflow, var) |  | ||||||
|                 comb = np.log(synflow['synflow'] + 1) / (var['var'] + 0.1) |  | ||||||
|                 measures = {'comb': comb} |  | ||||||
|             else: |             else: | ||||||
|                 measures = predictive.find_measures(model, |                 measures = predictive.find_measures(model, | ||||||
|                                                     proj_queue, |                                                     proj_queue, | ||||||
|   | |||||||
| @@ -55,9 +55,6 @@ def load_all(): | |||||||
|     from . import jacob_cov |     from . import jacob_cov | ||||||
|     from . import plain |     from . import plain | ||||||
|     from . import synflow |     from . import synflow | ||||||
|     from . import var |  | ||||||
|     from . import cor |  | ||||||
|     from . import norm |  | ||||||
|     from . import meco |     from . import meco | ||||||
|     from . import zico |     from . import zico | ||||||
|  |  | ||||||
|   | |||||||
| @@ -1,53 +0,0 @@ | |||||||
| # Copyright 2021 Samsung Electronics Co., Ltd. |  | ||||||
| # |  | ||||||
| # Licensed under the Apache License, Version 2.0 (the "License"); |  | ||||||
| # you may not use this file except in compliance with the License. |  | ||||||
| # You may obtain a copy of the License at |  | ||||||
| import time |  | ||||||
|  |  | ||||||
| #     http://www.apache.org/licenses/LICENSE-2.0 |  | ||||||
|  |  | ||||||
| # Unless required by applicable law or agreed to in writing, software |  | ||||||
| # distributed under the License is distributed on an "AS IS" BASIS, |  | ||||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |  | ||||||
| # See the License for the specific language governing permissions and |  | ||||||
| # limitations under the License. |  | ||||||
| # ============================================================================= |  | ||||||
|  |  | ||||||
| import numpy as np |  | ||||||
| import torch |  | ||||||
|  |  | ||||||
| from . import measure |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def get_score(net, x, target, device, split_data): |  | ||||||
|     result_list = [] |  | ||||||
|     def forward_hook(module, data_input, data_output): |  | ||||||
|         corr = np.mean(np.corrcoef(data_input[0].detach().cpu().numpy())) |  | ||||||
|         result_list.append(corr) |  | ||||||
|     net.classifier.register_forward_hook(forward_hook) |  | ||||||
|  |  | ||||||
|     N = x.shape[0] |  | ||||||
|     for sp in range(split_data): |  | ||||||
|         st = sp * N // split_data |  | ||||||
|         en = (sp + 1) * N // split_data |  | ||||||
|         y = net(x[st:en]) |  | ||||||
|     cor = result_list[0].item() |  | ||||||
|     result_list.clear() |  | ||||||
|     return cor |  | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| @measure('cor', bn=True) |  | ||||||
| def compute_norm(net, inputs, targets, split_data=1, loss_fn=None): |  | ||||||
|     device = inputs.device |  | ||||||
|     # Compute gradients (but don't apply them) |  | ||||||
|     net.zero_grad() |  | ||||||
|  |  | ||||||
|     try: |  | ||||||
|         cor= get_score(net, inputs, targets, device, split_data=split_data) |  | ||||||
|     except Exception as e: |  | ||||||
|         print(e) |  | ||||||
|         cor= np.nan |  | ||||||
|  |  | ||||||
|     return cor |  | ||||||
| @@ -1,55 +0,0 @@ | |||||||
| # Copyright 2021 Samsung Electronics Co., Ltd. |  | ||||||
| # |  | ||||||
| # Licensed under the Apache License, Version 2.0 (the "License"); |  | ||||||
| # you may not use this file except in compliance with the License. |  | ||||||
| # You may obtain a copy of the License at |  | ||||||
| import time |  | ||||||
|  |  | ||||||
| #     http://www.apache.org/licenses/LICENSE-2.0 |  | ||||||
|  |  | ||||||
| # Unless required by applicable law or agreed to in writing, software |  | ||||||
| # distributed under the License is distributed on an "AS IS" BASIS, |  | ||||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |  | ||||||
| # See the License for the specific language governing permissions and |  | ||||||
| # limitations under the License. |  | ||||||
| # ============================================================================= |  | ||||||
|  |  | ||||||
| import numpy as np |  | ||||||
| import torch |  | ||||||
|  |  | ||||||
| from . import measure |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def get_score(net, x, target, device, split_data): |  | ||||||
|     result_list = [] |  | ||||||
|     def forward_hook(module, data_input, data_output): |  | ||||||
|         norm = torch.norm(data_input[0]) |  | ||||||
|         result_list.append(norm) |  | ||||||
|     net.classifier.register_forward_hook(forward_hook) |  | ||||||
|  |  | ||||||
|     N = x.shape[0] |  | ||||||
|     for sp in range(split_data): |  | ||||||
|         st = sp * N // split_data |  | ||||||
|         en = (sp + 1) * N // split_data |  | ||||||
|         y = net(x[st:en]) |  | ||||||
|     n = result_list[0].item() |  | ||||||
|     result_list.clear() |  | ||||||
|     return n |  | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| @measure('norm', bn=True) |  | ||||||
| def compute_norm(net, inputs, targets, split_data=1, loss_fn=None): |  | ||||||
|     device = inputs.device |  | ||||||
|     # Compute gradients (but don't apply them) |  | ||||||
|     net.zero_grad() |  | ||||||
|  |  | ||||||
|     # print('var:', feature.shape) |  | ||||||
|     try: |  | ||||||
|         norm, t = get_score(net, inputs, targets, device, split_data=split_data) |  | ||||||
|     except Exception as e: |  | ||||||
|         print(e) |  | ||||||
|         norm, t = np.nan, None |  | ||||||
|     # print(jc) |  | ||||||
|     # print(f'norm time: {t} s') |  | ||||||
|     return norm, t |  | ||||||
| @@ -1,16 +0,0 @@ | |||||||
| import time |  | ||||||
| import torch |  | ||||||
|  |  | ||||||
| from . import measure |  | ||||||
| from ..p_utils import get_layer_metric_array |  | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| @measure('param_count', copy_net=False, mode='param') |  | ||||||
| def get_param_count_array(net, inputs, targets, mode, loss_fn, split_data=1): |  | ||||||
|     s = time.time() |  | ||||||
|     count = get_layer_metric_array(net, lambda l: torch.tensor(sum(p.numel() for p in l.parameters() if p.requires_grad)), mode=mode) |  | ||||||
|     e = time.time() |  | ||||||
|     t = e - s |  | ||||||
|     # print(f'param_count time: {t} s') |  | ||||||
|     return count, t |  | ||||||
| @@ -1,55 +0,0 @@ | |||||||
| # Copyright 2021 Samsung Electronics Co., Ltd. |  | ||||||
| # |  | ||||||
| # Licensed under the Apache License, Version 2.0 (the "License"); |  | ||||||
| # you may not use this file except in compliance with the License. |  | ||||||
| # You may obtain a copy of the License at |  | ||||||
| import time |  | ||||||
|  |  | ||||||
| #     http://www.apache.org/licenses/LICENSE-2.0 |  | ||||||
|  |  | ||||||
| # Unless required by applicable law or agreed to in writing, software |  | ||||||
| # distributed under the License is distributed on an "AS IS" BASIS, |  | ||||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |  | ||||||
| # See the License for the specific language governing permissions and |  | ||||||
| # limitations under the License. |  | ||||||
| # ============================================================================= |  | ||||||
|  |  | ||||||
| import numpy as np |  | ||||||
| import torch |  | ||||||
|  |  | ||||||
| from . import measure |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def get_score(net, x, target, device, split_data): |  | ||||||
|     result_list = [] |  | ||||||
|     def forward_hook(module, data_input, data_output): |  | ||||||
|         var = torch.var(data_input[0]) |  | ||||||
|         result_list.append(var) |  | ||||||
|     net.classifier.register_forward_hook(forward_hook) |  | ||||||
|  |  | ||||||
|     N = x.shape[0] |  | ||||||
|     for sp in range(split_data): |  | ||||||
|         st = sp * N // split_data |  | ||||||
|         en = (sp + 1) * N // split_data |  | ||||||
|         y = net(x[st:en]) |  | ||||||
|     v = result_list[0].item() |  | ||||||
|     result_list.clear() |  | ||||||
|     return v |  | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| @measure('var', bn=True) |  | ||||||
| def compute_var(net, inputs, targets, split_data=1, loss_fn=None): |  | ||||||
|     device = inputs.device |  | ||||||
|     # Compute gradients (but don't apply them) |  | ||||||
|     net.zero_grad() |  | ||||||
|  |  | ||||||
|     # print('var:', feature.shape) |  | ||||||
|     try: |  | ||||||
|         var= get_score(net, inputs, targets, device, split_data=split_data) |  | ||||||
|     except Exception as e: |  | ||||||
|         print(e) |  | ||||||
|         var= np.nan |  | ||||||
|     # print(jc) |  | ||||||
|     # print(f'var time: {t} s') |  | ||||||
|     return var |  | ||||||
| @@ -108,7 +108,7 @@ def find_measures(net_orig,                  # neural network | |||||||
|  |  | ||||||
|     measures = {} |     measures = {} | ||||||
|     for k,v in measures_arr.items(): |     for k,v in measures_arr.items(): | ||||||
|         if k in ['jacob_cov', 'var', 'cor', 'norm', 'meco', 'zico']: |         if k in ['jacob_cov', 'meco', 'zico']: | ||||||
|             measures[k] = v |             measures[k] = v | ||||||
|         else: |         else: | ||||||
|             measures[k] = sum_arr(v) |             measures[k] = sum_arr(v) | ||||||
|   | |||||||
| @@ -223,20 +223,6 @@ def main(): | |||||||
|                 else: |                 else: | ||||||
|                     #score =  score_loop(network, None, train_queue, args.gpu, None, args.proj_crit) |                     #score =  score_loop(network, None, train_queue, args.gpu, None, args.proj_crit) | ||||||
|                     network.requires_feature = False |                     network.requires_feature = False | ||||||
|  |  | ||||||
|                     if args.proj_crit == 'comb': |  | ||||||
|                         synflow = predictive.find_measures(network, |  | ||||||
|                                                            train_queue, |  | ||||||
|                                                            ('random', 1, n_classes), |  | ||||||
|                                                            torch.device("cuda"), |  | ||||||
|                                                            measure_names=['synflow']) |  | ||||||
|                         var = predictive.find_measures(network, |  | ||||||
|                                                        train_queue, |  | ||||||
|                                                        ('random', 1, n_classes), |  | ||||||
|                                                        torch.device("cuda"), |  | ||||||
|                                                        measure_names=['var']) |  | ||||||
|                         comb = np.log(synflow['synflow'] + 1) / (var['var'] + 0.1) |  | ||||||
|                         measures = {'comb': comb} |  | ||||||
|                     else: |                     else: | ||||||
|                         measures = predictive.find_measures(network, |                         measures = predictive.find_measures(network, | ||||||
|                                                             train_queue, |                                                             train_queue, | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user