MeCo/nasbench201/linear_region.py

271 lines
11 KiB
Python
Raw Normal View History

2023-05-04 07:42:06 +02:00
import os.path as osp
import numpy as np
import torch
import torch.nn as nn
import torch.utils.data
import torchvision.transforms as transforms
import torchvision.datasets as dset
from pdb import set_trace as bp
from operator import mul
from functools import reduce
import copy
# Number of target classes for each dataset name; get_datasets() looks the
# final class count up here.
Dataset2Class = dict(
    [
        ('cifar10', 10),
        ('cifar100', 100),
        ('imagenet-1k-s', 1000),
        ('imagenet-1k', 1000),
    ]
)
class CUTOUT(object):
    """Cutout augmentation: zero one square patch of an image tensor.

    Dimensions 1 and 2 of the tensor are treated as height and width.  The
    patch is centred on a uniformly drawn pixel (row drawn first, then column)
    and extends ``length // 2`` pixels to each side, clipped to the borders.
    The tensor is modified in place and also returned.
    """

    def __init__(self, length):
        self.length = length

    def __repr__(self):
        return '{name}(length={length})'.format(name=type(self).__name__, **vars(self))

    def __call__(self, img):
        height = img.size(1)
        width = img.size(2)
        half = self.length // 2

        centre_row = np.random.randint(height)
        centre_col = np.random.randint(width)

        top, bottom = np.clip([centre_row - half, centre_row + half], 0, height)
        left, right = np.clip([centre_col - half, centre_col + half], 0, width)

        keep = np.ones((height, width), np.float32)
        keep[top:bottom, left:right] = 0.
        img *= torch.from_numpy(keep).expand_as(img)
        return img
# Eigenvalues / eigenvectors of ImageNet RGB pixel statistics.
# NOTE(review): presumably meant for PCA "lighting" noise augmentation;
# nothing in the visible code reads this table — confirm before relying on it.
imagenet_pca = dict(
    eigval=np.asarray([0.2175, 0.0188, 0.0045]),
    eigvec=np.asarray(
        [
            [-0.5675, 0.7192, 0.4009],
            [-0.5808, -0.0045, -0.8140],
            [-0.5836, -0.6948, 0.4203],
        ]
    ),
)
class RandChannel(object):
    """Keep ``num_channel`` randomly chosen channels (dim 0) of an image tensor.

    Channels are drawn without replacement and kept in their original order.
    """

    def __init__(self, num_channel):
        self.num_channel = num_channel

    def __repr__(self):
        return '{name}(num_channel={num_channel})'.format(name=type(self).__name__, **vars(self))

    def __call__(self, img):
        candidates = np.arange(img.size(0))
        picked = np.random.choice(candidates, size=self.num_channel, replace=False)
        # Sorting keeps the surviving channels in their original relative order.
        index = torch.tensor(sorted(picked), dtype=torch.long)
        return img.index_select(0, index)
def get_datasets(name, root, input_size, cutout=-1):
    """Build the train/test datasets and transforms for a named dataset.

    Args:
        name: 'cifar10', 'cifar100', 'imagenet-1k' or 'imagenet-1k-s'.  Names
            starting with 'ImageNet16' get normalisation and transforms, but no
            dataset is constructed for them here, so they end in TypeError.
        root: dataset root.  For the ImageNet variants it must contain 'train'
            and 'val' sub-folders (read with ImageFolder).
        input_size: (C, H, W) or (B, C, H, W); H must equal W.  C is the number
            of channels RandChannel keeps in training samples; H is the random
            crop size for CIFAR / ImageNet16 / 'imagenet-1k'.
        cutout: Cutout patch length for CIFAR / ImageNet16; <= 0 disables it.

    Returns:
        (train_data, test_data, class_num)

    Raises:
        TypeError: unknown dataset name.
        ValueError: a name starting with 'imagenet-1k' that is not a known variant.
    """
    assert len(input_size) in [3, 4]
    if len(input_size) == 4:
        input_size = input_size[1:]  # drop the leading batch dimension
    assert input_size[1] == input_size[2]

    # Per-channel normalisation statistics.
    if name == 'cifar10':
        mean = [x / 255 for x in [125.3, 123.0, 113.9]]
        std = [x / 255 for x in [63.0, 62.1, 66.7]]
    elif name == 'cifar100':
        mean = [x / 255 for x in [129.3, 124.1, 112.4]]
        std = [x / 255 for x in [68.2, 65.4, 70.4]]
    elif name.startswith('imagenet-1k'):
        mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
    elif name.startswith('ImageNet16'):
        mean = [x / 255 for x in [122.68, 116.66, 104.01]]
        std = [x / 255 for x in [63.22, 61.26, 65.09]]
    else:
        raise TypeError("Unknow dataset : {:}".format(name))
    normalize = transforms.Normalize(mean, std)

    # Data augmentation.  Every training pipeline ends with RandChannel so a
    # sample carries input_size[0] channels; test pipelines keep all channels.
    if name in ('cifar10', 'cifar100') or name.startswith('ImageNet16'):
        lists = [
            transforms.RandomCrop(input_size[1], padding=4),
            transforms.ToTensor(),
            normalize,
            RandChannel(input_size[0]),
        ]
        if cutout > 0:
            lists.append(CUTOUT(cutout))
        train_transform = transforms.Compose(lists)
        test_transform = transforms.Compose([transforms.ToTensor(), normalize])
    else:  # name starts with 'imagenet-1k' (any other name raised above)
        if name == 'imagenet-1k':
            xlists = [
                transforms.Resize((32, 32), interpolation=2),
                transforms.RandomCrop(input_size[1], padding=0),
            ]
        elif name == 'imagenet-1k-s':
            # The crop must be kept: without it ImageFolder images stay at their
            # native, varying sizes and the DataLoader cannot stack them.
            xlists = [transforms.RandomResizedCrop(32, scale=(0.2, 1.0))]
        else:
            raise ValueError('invalid name : {:}'.format(name))
        xlists += [transforms.ToTensor(), normalize, RandChannel(input_size[0])]
        train_transform = transforms.Compose(xlists)
        test_transform = transforms.Compose(
            [transforms.Resize(40), transforms.CenterCrop(32), transforms.ToTensor(), normalize]
        )

    # Dataset construction.
    if name == 'cifar10':
        train_data = dset.CIFAR10(root, train=True, transform=train_transform, download=True)
        test_data = dset.CIFAR10(root, train=False, transform=test_transform, download=True)
        assert len(train_data) == 50000 and len(test_data) == 10000
    elif name == 'cifar100':
        train_data = dset.CIFAR100(root, train=True, transform=train_transform, download=True)
        test_data = dset.CIFAR100(root, train=False, transform=test_transform, download=True)
        assert len(train_data) == 50000 and len(test_data) == 10000
    elif name.startswith('imagenet-1k'):
        train_data = dset.ImageFolder(osp.join(root, 'train'), train_transform)
        test_data = dset.ImageFolder(osp.join(root, 'val'), test_transform)
    else:
        # ImageNet16* reaches here: transforms exist above, but no loader is
        # implemented in this function.
        raise TypeError("Unknow dataset : {:}".format(name))
    class_num = Dataset2Class[name]
    return train_data, test_data, class_num
class LinearRegionCount(object):
    """Count distinct ReLU activation patterns ("linear regions") over samples.

    ``update2D`` stores the on/off pattern of every neuron for a batch of
    samples; ``getLinearReginCount`` returns how many distinct patterns were
    stored.  The count is computed once and cached, and the stored patterns
    are freed afterwards, so an instance serves a single measurement.
    """

    def __init__(self, n_samples):
        self.ActPattern = {}  # activation-code string -> 1, filled by update1D
        self.n_LR = -1  # cached region count; -1 means "not computed yet"
        self.n_samples = n_samples  # row capacity of the pattern buffer
        self.ptr = 0  # next free row in the pattern buffer
        self.activations = None  # (n_samples, n_neuron) buffer, allocated lazily

    @torch.no_grad()
    def update2D(self, activations):
        """Append one batch of activation patterns.

        Args:
            activations: 2-D tensor (batch, n_neuron) of post-ReLU values; the
                sign of each entry is stored (0 or 1 for non-negative input).
        """
        n_batch, n_neuron = activations.size(0), activations.size(1)
        self.n_neuron = n_neuron
        if self.activations is None:
            # Allocate next to the incoming data so the class also works on CPU.
            self.activations = torch.zeros(self.n_samples, n_neuron, device=activations.device)
        self.activations[self.ptr:self.ptr + n_batch] = torch.sign(activations)  # after ReLU
        self.ptr += n_batch

    @torch.no_grad()
    def calc_LR(self):
        """Compute and cache the number of distinct stored patterns in ``n_LR``.

        For binary rows A and B, A*(1-B) + (1-A)*B summed over neurons is their
        Hamming distance, so a zero entry of the pairwise matrix marks two
        identical patterns.  Every sample contributes 1 / (number of samples
        sharing its pattern); the contributions add up to the number of
        distinct patterns.  The pattern buffer is released afterwards.
        """
        # Only rows written by update2D; unfilled zero rows would otherwise be
        # counted as an extra (all-off) region.
        acts = self.activations[:self.ptr]
        # Half precision saves GPU memory; half matmul on CPU is not supported
        # by older PyTorch releases, so fall back to float32 there.
        dtype = torch.half if acts.is_cuda else torch.float
        res = torch.matmul(acts.to(dtype), (1 - acts).T.to(dtype))  # each element in res: A * (1 - B)
        # Out-of-place on purpose: res.T aliases res, and PyTorch may reject
        # in-place ops whose input overlaps the output.
        res = res + res.T  # A * (1 - B) + (1 - A) * B: non-zero <=> two different regions
        res = 1 - torch.sign(res)  # non-zero now marks two identical regions
        res = res.sum(1)  # for each sample: how many samples share its region
        res = 1. / res.float()  # contribution of each (possibly repeated) region
        self.n_LR = res.sum().item()  # number of unique regions
        del acts, res
        self.activations = None
        torch.cuda.empty_cache()

    @torch.no_grad()
    def update1D(self, activationList):
        """Record the on/off code of one sample in ``ActPattern``.

        Args:
            activationList: dict of tensors indexed element-wise along dim 0
                (presumably 1-D).  The values are concatenated in dict order into
                a '0'/'1' string ('1' where the activation is > 0).  This does
                not affect ``n_LR``.
        """
        bits = []
        for value in activationList.values():
            bits.extend('1' if value[i] > 0 else '0' for i in range(value.size(0)))
        self.ActPattern.setdefault(''.join(bits), 1)

    def getLinearReginCount(self):
        """Return the number of distinct patterns, computing it on the first call."""
        if self.n_LR == -1:
            self.calc_LR()
        return self.n_LR
class Linear_Region_Collector:
    """Measure the number of linear regions of one or more networks on training data.

    Forward hooks on every ``nn.ReLU`` capture activations of 4-D inputs.
    ``forward_batch_sample`` pushes ``sample_batch`` batches of
    ``input_size[0]`` images through each model and returns one region count
    per model.  Call ``clear()`` (or ``reinit``) before measuring again: the
    per-model counters are sized for exactly one call.
    """

    def __init__(self, models=None, input_size=(64, 3, 32, 32), sample_batch=100, dataset='cifar100', data_path=None, seed=0):
        self.models = []
        self.input_size = input_size  # BCHW
        self.sample_batch = sample_batch
        self.input_numel = reduce(mul, self.input_size, 1)
        self.interFeature = []
        self.dataset = dataset
        self.data_path = data_path
        self.seed = seed
        # NOTE(review): self.seed already equals `seed` here, so reinit() does
        # not call torch.manual_seed during construction — confirm intended.
        self.reinit(models, input_size, sample_batch, seed)

    def _prepare_model(self, network, weights):
        """Return an eval-mode deep copy of ``network`` with all candidate flags cleared.

        ``network.cuda()`` moves the caller's module to the GPU as well.  The
        attributes proj_weights / num_edge / candidate_flags come from the
        project's network class (presumably a one-shot supernet — confirm).
        """
        network = network.cuda()
        net = copy.deepcopy(network)
        net.proj_weights = weights
        for i in range(net.num_edge):
            net.candidate_flags[i] = False
        net.eval()
        return net

    def reinit(self, ori_models=None, input_size=None, sample_batch=None, seed=None, weights=None):
        """(Re)configure sizes, models, counters, data loader and seed.

        Arguments left as None keep their current value.  New region counters
        are always created.  The data loader is rebuilt when ``input_size`` or
        ``sample_batch`` is given and a ``data_path`` is set.

        Args:
            ori_models: iterable of networks to copy and hook, or None to keep
                the current models.
            input_size: BCHW tuple; B is the loader batch size.
            sample_batch: number of batches per measurement.
            seed: applied to torch's CPU and CUDA RNGs when it differs from the
                stored seed.
            weights: assigned to ``proj_weights`` of every copied network.
        """
        # Update sizes first: the LinearRegionCount buffers below depend on them.
        if input_size is not None:
            self.input_size = input_size  # BCHW
            self.input_numel = reduce(mul, self.input_size, 1)
        if sample_batch is not None:
            self.sample_batch = sample_batch

        if ori_models is not None:
            models = [self._prepare_model(network, weights) for network in ori_models]
            del self.models
            self.models = models
            for model in self.models:
                self.register_hook(model)
                model.cuda(device=torch.cuda.current_device())
        self.LRCounts = [LinearRegionCount(self.input_size[0] * self.sample_batch) for _ in range(len(self.models))]

        if (input_size is not None or sample_batch is not None) and self.data_path is not None:
            self.train_data, _, _ = get_datasets(self.dataset, self.data_path, self.input_size, -1)
            self.train_loader = torch.utils.data.DataLoader(
                self.train_data, batch_size=self.input_size[0], num_workers=16,
                pin_memory=True, drop_last=True, shuffle=True)
            self.loader = iter(self.train_loader)

        if seed is not None and seed != self.seed:
            self.seed = seed
            torch.manual_seed(seed)
            torch.cuda.manual_seed(seed)
        del self.interFeature
        self.interFeature = []
        torch.cuda.empty_cache()

    def clear(self):
        """Reset the region counters and captured features for a new measurement."""
        self.LRCounts = [LinearRegionCount(self.input_size[0] * self.sample_batch) for _ in range(len(self.models))]
        del self.interFeature
        self.interFeature = []
        torch.cuda.empty_cache()

    def register_hook(self, model):
        """Attach ``hook_in_forward`` to every ``nn.ReLU`` inside ``model``."""
        for m in model.modules():
            if isinstance(m, nn.ReLU):
                m.register_forward_hook(hook=self.hook_in_forward)

    def hook_in_forward(self, module, input, output):
        """Forward hook: keep the ReLU output when its input is 4-D."""
        if isinstance(input, tuple) and len(input[0].size()) == 4:
            self.interFeature.append(output.detach())  # for ReLU

    def forward_batch_sample(self):
        """Run ``sample_batch`` training batches through every model.

        Returns:
            list with one linear-region count per model.
        """
        for _ in range(self.sample_batch):
            try:
                inputs, _ = next(self.loader)
            except StopIteration:
                # Loader exhausted: start a new pass over the training data.
                del self.loader
                self.loader = iter(self.train_loader)
                inputs, _ = next(self.loader)
            for model, LRCount in zip(self.models, self.LRCounts):
                self.forward(model, LRCount, inputs)
        return [LRCount.getLinearReginCount() for LRCount in self.LRCounts]

    def forward(self, model, LRCount, input_data):
        """Forward one batch and feed the flattened, concatenated ReLU outputs to ``LRCount``."""
        self.interFeature = []
        with torch.no_grad():
            model.forward(input_data.cuda())
            if len(self.interFeature) == 0:
                return
            # reshape (not view) also copes with non-contiguous hook outputs.
            feature_data = torch.cat([f.reshape(input_data.size(0), -1) for f in self.interFeature], 1)
            LRCount.update2D(feature_data)