271 lines
11 KiB
Python
271 lines
11 KiB
Python
import os.path as osp
|
|
import numpy as np
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.utils.data
|
|
import torchvision.transforms as transforms
|
|
import torchvision.datasets as dset
|
|
from pdb import set_trace as bp
|
|
from operator import mul
|
|
from functools import reduce
|
|
import copy
|
|
# Number of target classes for each dataset name; looked up by get_datasets()
# to produce its third return value.
Dataset2Class = dict([
    ('cifar10', 10),
    ('cifar100', 100),
    ('imagenet-1k-s', 1000),
    ('imagenet-1k', 1000),
])
|
|
|
|
|
|
class CUTOUT(object):
    """Zero out one rectangular patch of a (C, H, W) image tensor.

    The patch centre is drawn uniformly (row first, then column) from NumPy's
    global RNG. The patch extends ``length // 2`` pixels on each side of the
    centre and is clipped to the image bounds. The input tensor is modified in
    place and also returned.
    """

    def __init__(self, length):
        self.length = length

    def __repr__(self):
        return ('{name}(length={length})'.format(name=type(self).__name__, **self.__dict__))

    def __call__(self, img):
        height, width = img.size(1), img.size(2)
        half = self.length // 2

        centre_row = np.random.randint(height)
        centre_col = np.random.randint(width)

        top, bottom = np.clip([centre_row - half, centre_row + half], 0, height)
        left, right = np.clip([centre_col - half, centre_col + half], 0, width)

        keep = np.ones((height, width), np.float32)
        keep[top:bottom, left:right] = 0.

        # Broadcast the (H, W) mask over every channel and apply it in place.
        img *= torch.from_numpy(keep).expand_as(img)
        return img
|
|
|
|
|
|
# Eigenvalues ('eigval', shape (3,)) and eigenvectors ('eigvec', shape (3, 3))
# labelled as ImageNet PCA statistics; presumably the RGB colour PCA used for
# AlexNet-style lighting noise — TODO confirm. Not referenced elsewhere in the
# visible part of this file.
imagenet_pca = dict(
    eigval=np.asarray([0.2175, 0.0188, 0.0045]),
    eigvec=np.asarray([
        [-0.5675, 0.7192, 0.4009],
        [-0.5808, -0.0045, -0.8140],
        [-0.5836, -0.6948, 0.4203],
    ]),
)
|
|
|
|
|
|
class RandChannel(object):
    """Keep a random subset of ``num_channel`` channels of a (C, ...) tensor.

    Channels are drawn without replacement from NumPy's global RNG and are
    returned in ascending channel order. When ``num_channel`` equals C, the
    output contains every channel in the original order.
    """

    def __init__(self, num_channel):
        self.num_channel = num_channel

    def __repr__(self):
        return ('{name}(num_channel={num_channel})'.format(name=type(self).__name__, **self.__dict__))

    def __call__(self, img):
        drawn = np.random.choice(img.size(0), size=self.num_channel, replace=False)
        index = torch.tensor(sorted(drawn.tolist()), dtype=torch.long)
        return img.index_select(0, index)
|
|
|
|
|
|
def get_datasets(name, root, input_size, cutout=-1):
    """Build the train/test datasets (with their transforms) for ``name``.

    Args:
        name: 'cifar10', 'cifar100', 'imagenet-1k' or 'imagenet-1k-s'. Names
            starting with 'ImageNet16' get normalization stats and transforms,
            but no dataset loader is implemented for them here, so they still
            end in a TypeError.
        root: dataset root. CIFAR is downloaded into it if missing; the
            ImageNet variants read ImageFolder trees from ``root/train`` and
            ``root/val``.
        input_size: (C, H, W) or (B, C, H, W); the batch dim is ignored and H
            must equal W. C is the number of channels RandChannel keeps in the
            training transform; H is the RandomCrop size for CIFAR, ImageNet16
            and 'imagenet-1k'.
        cutout: CUTOUT patch length applied to CIFAR/ImageNet16 training
            images; values <= 0 disable it.

    Returns:
        (train_data, test_data, class_num), with class_num from Dataset2Class.

    Raises:
        TypeError: unknown dataset name (or an ImageNet16 variant).
        ValueError: a name starting with 'imagenet-1k' that is neither
            'imagenet-1k' nor 'imagenet-1k-s'.
    """
    assert len(input_size) in [3, 4]
    if len(input_size) == 4:
        input_size = input_size[1:]  # drop the batch dimension
    assert input_size[1] == input_size[2]
    num_channel, crop_size = input_size[0], input_size[1]

    # Per-channel normalization statistics.
    if name == 'cifar10':
        mean = [x / 255 for x in [125.3, 123.0, 113.9]]
        std = [x / 255 for x in [63.0, 62.1, 66.7]]
    elif name == 'cifar100':
        mean = [x / 255 for x in [129.3, 124.1, 112.4]]
        std = [x / 255 for x in [68.2, 65.4, 70.4]]
    elif name.startswith('imagenet-1k'):
        mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
    elif name.startswith('ImageNet16'):
        mean = [x / 255 for x in [122.68, 116.66, 104.01]]
        std = [x / 255 for x in [63.22, 61.26, 65.09]]
    else:
        raise TypeError("Unknown dataset : {:}".format(name))
    normalize = transforms.Normalize(mean, std)

    # Data augmentation.
    if name in ('cifar10', 'cifar100') or name.startswith('ImageNet16'):
        train_list = [transforms.RandomCrop(crop_size, padding=4), transforms.ToTensor(), normalize,
                      RandChannel(num_channel)]
        if cutout > 0:
            train_list.append(CUTOUT(cutout))
        test_transform = transforms.Compose([transforms.ToTensor(), normalize])
    else:
        # Only 'imagenet-1k*' names survive the statistics chain above.
        if name == 'imagenet-1k':
            # Resize's default interpolation is bilinear (the former explicit ``interpolation=2``).
            train_list = [transforms.Resize((32, 32)), transforms.RandomCrop(crop_size, padding=0)]
        elif name == 'imagenet-1k-s':
            # Previously this list was immediately overwritten with [], silently dropping the crop.
            train_list = [transforms.RandomResizedCrop(32, scale=(0.2, 1.0))]
        else:
            raise ValueError('invalid name : {:}'.format(name))
        train_list += [transforms.ToTensor(), normalize, RandChannel(num_channel)]
        test_transform = transforms.Compose([transforms.Resize(40), transforms.CenterCrop(32),
                                             transforms.ToTensor(), normalize])
    train_transform = transforms.Compose(train_list)

    if name == 'cifar10':
        train_data = dset.CIFAR10(root, train=True, transform=train_transform, download=True)
        test_data = dset.CIFAR10(root, train=False, transform=test_transform, download=True)
        assert len(train_data) == 50000 and len(test_data) == 10000
    elif name == 'cifar100':
        train_data = dset.CIFAR100(root, train=True, transform=train_transform, download=True)
        test_data = dset.CIFAR100(root, train=False, transform=test_transform, download=True)
        assert len(train_data) == 50000 and len(test_data) == 10000
    elif name.startswith('imagenet-1k'):
        train_data = dset.ImageFolder(osp.join(root, 'train'), train_transform)
        test_data = dset.ImageFolder(osp.join(root, 'val'), test_transform)
    else:
        # ImageNet16 variants get transforms above but have no loader here.
        raise TypeError("Unknown dataset : {:}".format(name))

    return train_data, test_data, Dataset2Class[name]
|
|
|
|
|
|
class LinearRegionCount(object):
    """Count distinct ReLU activation patterns (linear regions) over a set of samples.

    ``update2D`` records, per sample, the sign pattern of its post-ReLU
    features into a fixed-size buffer; ``getLinearReginCount`` returns how
    many distinct patterns were recorded. ``update1D`` is an alternative,
    string-based recorder that fills ``ActPattern``.
    """

    def __init__(self, n_samples):
        self.ActPattern = {}  # pattern string -> 1, filled by update1D
        self.n_LR = -1  # cached region count; -1 means "not computed yet"
        self.n_samples = n_samples  # number of rows in the activation buffer
        self.ptr = 0  # number of rows written so far
        self.activations = None  # (n_samples, n_neuron) buffer, allocated lazily

    @torch.no_grad()
    def update2D(self, activations):
        """Append the sign patterns of a (batch, n_neuron) activation matrix.

        Assumes ``activations`` are ReLU outputs (non-negative), so each sign is
        0 or 1. The buffer is allocated on the first call, on the same device
        as ``activations``.
        """
        n_batch, n_neuron = activations.size(0), activations.size(1)
        self.n_neuron = n_neuron
        if self.activations is None:
            self.activations = torch.zeros(self.n_samples, n_neuron, device=activations.device)
        self.activations[self.ptr:self.ptr + n_batch] = torch.sign(activations)  # after ReLU
        self.ptr += n_batch

    @torch.no_grad()
    def calc_LR(self):
        """Compute and cache ``n_LR`` from the recorded patterns, then free the buffer.

        Raises:
            RuntimeError: if no activations were recorded via update2D.
        """
        if self.activations is None:
            raise RuntimeError('no activations recorded; call update2D before counting regions')
        # Use only rows actually written: untouched zero rows would otherwise
        # show up as an extra (all-inactive) pattern.
        acts = self.activations[:self.ptr]
        # fp16 halves the memory of the (n, n) products on GPU; CPU uses fp32
        # because half-precision matmul is not available on every CPU build.
        acts = acts.to(torch.float16 if acts.is_cuda else torch.float32)
        # res[i, j] = A_i . (1 - A_j): units active for sample i but not for j.
        res = torch.matmul(acts, (1 - acts).T)
        # Symmetrize out of place: ``res += res.T`` reads and writes overlapping
        # memory. Entry (i, j) is now non-zero iff samples i and j differ.
        res = res + res.T
        res = 1 - torch.sign(res)  # 1 where two patterns are identical (incl. the diagonal)
        # Per sample: how many samples share its pattern (>= 1). Sum in fp32 so
        # large counts stay exact.
        res = res.sum(1, dtype=torch.float32)
        # A group of k identical samples contributes k * (1 / k) = 1 region.
        self.n_LR = (1. / res).sum().item()
        del acts, res
        self.activations = None
        torch.cuda.empty_cache()

    @torch.no_grad()
    def update1D(self, activationList):
        """Record one pattern string built from a dict of 1-D activations.

        Each element contributes '1' if it is > 0 and '0' otherwise; the strings
        of all dict values are concatenated in iteration order and stored as a
        key of ``ActPattern``.
        """
        code_string = ''
        for value in activationList.values():
            code_string += ''.join('1' if value[i] > 0 else '0' for i in range(value.size()[0]))
        if code_string not in self.ActPattern:
            self.ActPattern[code_string] = 1

    def getLinearReginCount(self):
        """Return the number of distinct patterns, computing it on first call."""
        if self.n_LR == -1:
            self.calc_LR()
        return self.n_LR
|
|
|
|
|
|
class Linear_Region_Collector:
    """Estimate how many linear regions each model produces over real input batches.

    For every model copy, forward hooks on its ``nn.ReLU`` modules capture the
    outputs of ReLUs whose input is 4-D. Per sample, the concatenated features
    are fed to a LinearRegionCount, which counts the distinct sign patterns
    over ``input_size[0] * sample_batch`` samples. Model copies and inputs are
    moved to CUDA.
    """

    def __init__(self, models=None, input_size=(64, 3, 32, 32), sample_batch=100, dataset='cifar100', data_path=None, seed=0):
        self.models = []
        self.input_size = input_size  # BCHW
        self.sample_batch = sample_batch
        self.input_numel = reduce(mul, self.input_size, 1)
        self.interFeature = []
        self.dataset = dataset
        self.data_path = data_path
        # NOTE(review): self.seed is set before reinit(), so reinit() sees
        # seed == self.seed and does not reseed torch here; only a different
        # seed passed to a later reinit() reseeds.
        self.seed = seed
        self.reinit([] if models is None else models, input_size, sample_batch, seed)

    def reinit(self, ori_models=None, input_size=None, sample_batch=None, seed=None, weights=None):
        """(Re)build model copies, region counters, the data loader and the seed.

        Args:
            ori_models: networks to copy. Each copy gets ``proj_weights = weights``,
                every ``candidate_flags[i]`` (i < ``num_edge``) set to False, eval
                mode and ReLU hooks. ``None`` keeps the current copies.
            input_size: new BCHW input size, or None to keep the current one.
            sample_batch: new number of batches per measurement, or None.
            seed: if given and different from the stored seed, reseeds torch's
                CPU and CUDA generators.
            weights: value assigned to every copy's ``proj_weights``.
        """
        # Apply size changes first so the counters created below match them.
        if input_size is not None:
            self.input_size = input_size  # BCHW
            self.input_numel = reduce(mul, self.input_size, 1)
        if sample_batch is not None:
            self.sample_batch = sample_batch
        if (input_size is not None or sample_batch is not None) and self.data_path is not None:
            self.train_data, _, _ = get_datasets(self.dataset, self.data_path, self.input_size, -1)
            self.train_loader = torch.utils.data.DataLoader(self.train_data, batch_size=self.input_size[0], num_workers=16, pin_memory=True, drop_last=True, shuffle=True)
            self.loader = iter(self.train_loader)

        if ori_models is not None:
            models = []
            for network in ori_models:
                network = network.cuda()  # NOTE: moves the caller's network in place as well
                net = copy.deepcopy(network)
                net.proj_weights = weights
                for i in range(net.num_edge):
                    net.candidate_flags[i] = False
                net.eval()
                models.append(net)
            self.models = models
            for model in self.models:
                self.register_hook(model)
                model.cuda(device=torch.cuda.current_device())
        self.LRCounts = [LinearRegionCount(self.input_size[0] * self.sample_batch) for _ in self.models]

        if seed is not None and seed != self.seed:
            self.seed = seed
            torch.manual_seed(seed)
            torch.cuda.manual_seed(seed)
        self.interFeature = []
        torch.cuda.empty_cache()

    def clear(self):
        """Reset the region counters and captured features, keeping the models."""
        self.LRCounts = [LinearRegionCount(self.input_size[0] * self.sample_batch) for _ in range(len(self.models))]
        self.interFeature = []
        torch.cuda.empty_cache()

    def register_hook(self, model):
        """Attach ``hook_in_forward`` to every ``nn.ReLU`` module of ``model``."""
        for m in model.modules():
            if isinstance(m, nn.ReLU):
                m.register_forward_hook(hook=self.hook_in_forward)

    def hook_in_forward(self, module, input, output):
        """Forward hook: keep the ReLU output when its input tensor is 4-D."""
        if isinstance(input, tuple) and len(input[0].size()) == 4:
            self.interFeature.append(output.detach())  # for ReLU

    def forward_batch_sample(self):
        """Feed ``sample_batch`` loader batches to every model; return each model's region count."""
        for _ in range(self.sample_batch):
            try:
                inputs, _targets = next(self.loader)
            except StopIteration:
                # Loader exhausted: start a new pass over the training data.
                self.loader = iter(self.train_loader)
                inputs, _targets = next(self.loader)
            for model, lr_count in zip(self.models, self.LRCounts):
                self.forward(model, lr_count, inputs)
        return [lr_count.getLinearReginCount() for lr_count in self.LRCounts]

    def forward(self, model, LRCount, input_data):
        """Run ``model`` on ``input_data`` and record the hooked ReLU features in ``LRCount``."""
        self.interFeature = []
        with torch.no_grad():
            model.forward(input_data.cuda())
            if len(self.interFeature) == 0:
                return
            # One row per sample: all hooked feature maps flattened side by side.
            feature_data = torch.cat([f.reshape(input_data.size(0), -1) for f in self.interFeature], 1)
            LRCount.update2D(feature_data)