upload
zero-cost-nas/foresight/pruners/predictive.py (new file, 116 lines)
@@ -0,0 +1,116 @@
# Copyright 2021 Samsung Electronics Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================

import torch
import torch.nn as nn
import torch.nn.functional as F

from .p_utils import *
from . import measures

import types
import copy

def no_op(self, x):
    # identity forward pass, used below to bypass batchnorm layers
    return x


def copynet(self, bn):
    # Return a deep copy of the network. When bn is False, the forward of every
    # BatchNorm1d/BatchNorm2d module in the copy is replaced with the identity
    # above, so the copy behaves as if batch normalisation were removed.
    net = copy.deepcopy(self)
    if not bn:
        for l in net.modules():
            if isinstance(l, nn.BatchNorm2d) or isinstance(l, nn.BatchNorm1d):
                l.forward = types.MethodType(no_op, l)
    return net
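
# Illustrative note (not part of the original file): after copynet(m, bn=False)
# the copied model's BatchNorm layers act as identities. A hypothetical check:
#
#     m = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.BatchNorm2d(8))
#     m_nobn = copynet(m, bn=False)
#     x = torch.randn(2, 3, 32, 32)
#     # with bn disabled, the BatchNorm stage returns its input unchanged,
#     # so the full copy reduces to just the convolution
#     assert torch.equal(m_nobn(x), m_nobn[0](x))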


def find_measures_arrays(net_orig, trainloader, dataload_info, device, measure_names=None, loss_fn=F.cross_entropy):
    if measure_names is None:
        measure_names = measures.available_measures

    dataload, num_imgs_or_batches, num_classes = dataload_info

    if not hasattr(net_orig, 'get_prunable_copy'):
        net_orig.get_prunable_copy = types.MethodType(copynet, net_orig)

    # move to cpu to free up mem
    torch.cuda.empty_cache()
    net_orig = net_orig.cpu()
    torch.cuda.empty_cache()

    # fetch a single minibatch of data according to the requested loading scheme
    if dataload == 'random':
        inputs, targets = get_some_data(trainloader, num_batches=num_imgs_or_batches, device=device)
    elif dataload == 'grasp':
        inputs, targets = get_some_data_grasp(trainloader, num_classes, samples_per_class=num_imgs_or_batches, device=device)
    else:
        raise NotImplementedError(f'dataload {dataload} is not supported')

    done, ds = False, 1
    measure_values = {}

    # compute every requested measure; on CUDA OOM, retry with the batch split
    # into progressively more parts (ds) until everything fits in memory
    while not done:
        try:
            for measure_name in measure_names:
                if measure_name not in measure_values:
                    val = measures.calc_measure(measure_name, net_orig, device, inputs, targets, loss_fn=loss_fn, split_data=ds)
                    measure_values[measure_name] = val

            done = True
        except RuntimeError as e:
            if 'out of memory' in str(e):
                done = False
                if ds == inputs.shape[0] // 2:
                    raise ValueError("Can't split data anymore, but still unable to run. Something is wrong")
                # pick the next split count that divides the batch size evenly
                ds += 1
                while inputs.shape[0] % ds != 0:
                    ds += 1
                torch.cuda.empty_cache()
                print(f'Caught CUDA OOM, retrying with data split into {ds} parts')
            else:
                raise e

    net_orig = net_orig.to(device).train()
    return measure_values

def find_measures(net_orig,                  # neural network
                  dataloader,                # a data loader (typically for training data)
                  dataload_info,             # a tuple (dataload_type = {random, grasp}, number of batches for random / images per class for grasp, number of classes)
                  device,                    # GPU/CPU device to use
                  loss_fn=F.cross_entropy,   # loss function to use within the zero-cost metrics
                  measure_names=None,        # a list of measure names to compute; if None, all available measures are computed
                  measures_arr=None):        # [not used] if the measures are already computed but need to be summarized, pass them here

    # Given a neural net, some information about the input data (dataloader),
    # and a loss function (loss_fn), this function returns a dict of zero-cost
    # proxy metrics, one scalar value per measure.

    def sum_arr(arr):
        # reduce a list of per-layer tensors to a single scalar
        total = 0.
        for t in arr:
            total += torch.sum(t)
        return total.item()

    if measures_arr is None:
        measures_arr = find_measures_arrays(net_orig, dataloader, dataload_info, device, loss_fn=loss_fn, measure_names=measure_names)

    measures = {}
    for k, v in measures_arr.items():
        if k in ['jacob_cov', 'var', 'cor', 'norm', 'meco', 'zico']:
            # these measures already come back as scalars
            measures[k] = v
        else:
            # per-layer arrays are summed into a single scalar
            measures[k] = sum_arr(v)

    return measures
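

# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original file). It assumes random
# CIFAR-10-shaped data and that 'synflow' is among the registered measures;
# adapt the model, loader, and measure names to your setup. Because this module
# uses relative imports, run it as a module, e.g.
# `python -m foresight.pruners.predictive`.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    from torch.utils.data import DataLoader, TensorDataset

    # a tiny stand-in network; any nn.Module the measures can handle works
    net = nn.Sequential(
        nn.Conv2d(3, 16, 3, padding=1), nn.BatchNorm2d(16), nn.ReLU(),
        nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(16, 10))

    # random images/labels shaped like CIFAR-10, just to exercise the API
    data = TensorDataset(torch.randn(64, 3, 32, 32), torch.randint(0, 10, (64,)))
    loader = DataLoader(data, batch_size=32, shuffle=True)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # dataload_info = ('random', 1, 10): one random minibatch, 10 classes
    scores = find_measures(net, loader, ('random', 1, 10), device,
                           measure_names=['synflow'])
    print(scores)  # e.g. {'synflow': <float>}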