MeCo/nasbench201/linear_region.py
HamsterMimi 2410fe9f5e update
2023-05-04 13:42:06 +08:00

271 lines
11 KiB
Python

import os.path as osp
import numpy as np
import torch
import torch.nn as nn
import torch.utils.data
import torchvision.transforms as transforms
import torchvision.datasets as dset
from pdb import set_trace as bp
from operator import mul
from functools import reduce
import copy
# Number of target classes for every dataset name that get_datasets can
# actually load; looked up by the exact dataset name string.
Dataset2Class = {
    'cifar10': 10,
    'cifar100': 100,
    'imagenet-1k-s': 1000,
    'imagenet-1k': 1000,
}
class CUTOUT(object):
    """Transform that blanks one randomly centred square patch of a tensor image.

    The patch has side ``length`` (clipped at the image border). Height and
    width are read from dimensions 1 and 2 of the input tensor, and the same
    2-D mask is broadcast over every slice along dimension 0.
    """

    def __init__(self, length):
        self.length = length

    def __repr__(self):
        return '{}(length={})'.format(type(self).__name__, self.length)

    def __call__(self, img):
        """Multiply ``img`` in place by the cut-out mask and return it."""
        height, width = img.size(1), img.size(2)
        half = self.length // 2

        # Patch centre: the row is drawn first, then the column.
        centre_y = np.random.randint(height)
        centre_x = np.random.randint(width)

        # Patch bounds, clipped so the slice never leaves the image.
        top = np.clip(centre_y - half, 0, height)
        bottom = np.clip(centre_y + half, 0, height)
        left = np.clip(centre_x - half, 0, width)
        right = np.clip(centre_x + half, 0, width)

        keep = np.ones((height, width), np.float32)
        keep[top:bottom, left:right] = 0.

        # In-place product: the caller's tensor itself is modified.
        img *= torch.from_numpy(keep).expand_as(img)
        return img
# Three eigenvalues and a 3x3 eigenvector matrix.
# NOTE(review): by name these look like the ImageNet RGB PCA statistics used
# for lighting augmentation; nothing in the visible code reads them - confirm
# before relying on or removing this table.
imagenet_pca = dict(
    eigval=np.asarray([0.2175, 0.0188, 0.0045]),
    eigvec=np.asarray([
        [-0.5675, 0.7192, 0.4009],
        [-0.5808, -0.0045, -0.8140],
        [-0.5836, -0.6948, 0.4203],
    ]),
)
class RandChannel(object):
    """Transform that keeps a random subset of ``num_channel`` input channels.

    Channels are drawn without replacement along dimension 0 and returned in
    ascending index order, so asking for every channel returns the input
    values unchanged.
    """

    def __init__(self, num_channel):
        self.num_channel = num_channel

    def __repr__(self):
        return '{}(num_channel={})'.format(type(self).__name__, self.num_channel)

    def __call__(self, img):
        """Return a new tensor holding the selected channels of ``img``."""
        total = img.size(0)
        # An int population is sampled as if it were np.arange(total).
        picked = np.random.choice(total, size=self.num_channel, replace=False)
        index = torch.from_numpy(np.sort(picked)).long()
        return torch.index_select(img, 0, index)
def get_datasets(name, root, input_size, cutout=-1):
    """Build the train and test datasets (with their transforms) for ``name``.

    Args:
        name: dataset identifier. 'cifar10', 'cifar100', 'imagenet-1k' and
            'imagenet-1k-s' can be loaded. Names starting with 'ImageNet16'
            get normalisation statistics and transforms, but no loader exists
            for them here, so they end in ``TypeError``.
        root: dataset root directory. For the ImageNet variants it must hold
            'train' and 'val' sub-folders readable by ``ImageFolder``.
        input_size: (C, H, W) or (B, C, H, W); a leading batch entry is
            dropped. H must equal W. C is the number of channels kept by
            ``RandChannel`` in the training transform.
        cutout: side length of the CUTOUT patch appended to the training
            transform; values <= 0 disable it. Only used by the CIFAR /
            ImageNet16 pipeline.

    Returns:
        Tuple ``(train_data, test_data, class_num)``.

    Raises:
        TypeError: unknown dataset name, or a name with no loader.
        ValueError: name starts with 'imagenet-1k' but is not one of the two
            supported variants.
        AssertionError: malformed ``input_size`` or unexpected CIFAR size.
    """
    assert len(input_size) in [3, 4]
    if len(input_size) == 4:
        input_size = input_size[1:]
    assert input_size[1] == input_size[2]
    num_channel, crop_size = input_size[0], input_size[1]

    # Per-channel normalisation statistics.
    if name == 'cifar10':
        mean = [x / 255 for x in [125.3, 123.0, 113.9]]
        std = [x / 255 for x in [63.0, 62.1, 66.7]]
    elif name == 'cifar100':
        mean = [x / 255 for x in [129.3, 124.1, 112.4]]
        std = [x / 255 for x in [68.2, 65.4, 70.4]]
    elif name.startswith('imagenet-1k'):
        mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
    elif name.startswith('ImageNet16'):
        mean = [x / 255 for x in [122.68, 116.66, 104.01]]
        std = [x / 255 for x in [63.22, 61.26, 65.09]]
    else:
        raise TypeError("Unknow dataset : {:}".format(name))

    normalize = transforms.Normalize(mean, std)

    # Data augmentation.
    if name in ('cifar10', 'cifar100') or name.startswith('ImageNet16'):
        steps = [
            transforms.RandomCrop(crop_size, padding=4),
            transforms.ToTensor(),
            normalize,
            RandChannel(num_channel),
        ]
        if cutout > 0:
            steps.append(CUTOUT(cutout))
        train_transform = transforms.Compose(steps)
        test_transform = transforms.Compose([transforms.ToTensor(), normalize])
    elif name.startswith('imagenet-1k'):
        if name == 'imagenet-1k':
            steps = [
                transforms.Resize((32, 32), interpolation=2),
                transforms.RandomCrop(crop_size, padding=0),
            ]
        elif name == 'imagenet-1k-s':
            # The crop used to be built and then overwritten by an empty list,
            # which left raw, variably sized images in the pipeline.
            steps = [transforms.RandomResizedCrop(32, scale=(0.2, 1.0))]
        else:
            raise ValueError('invalid name : {:}'.format(name))
        steps += [transforms.ToTensor(), normalize, RandChannel(num_channel)]
        train_transform = transforms.Compose(steps)
        test_transform = transforms.Compose([
            transforms.Resize(40),
            transforms.CenterCrop(32),
            transforms.ToTensor(),
            normalize,
        ])
    else:
        raise TypeError("Unknow dataset : {:}".format(name))

    # Dataset objects.
    if name == 'cifar10':
        train_data = dset.CIFAR10(root, train=True, transform=train_transform, download=True)
        test_data = dset.CIFAR10(root, train=False, transform=test_transform, download=True)
        assert len(train_data) == 50000 and len(test_data) == 10000
    elif name == 'cifar100':
        train_data = dset.CIFAR100(root, train=True, transform=train_transform, download=True)
        test_data = dset.CIFAR100(root, train=False, transform=test_transform, download=True)
        assert len(train_data) == 50000 and len(test_data) == 10000
    elif name.startswith('imagenet-1k'):
        train_data = dset.ImageFolder(osp.join(root, 'train'), train_transform)
        test_data = dset.ImageFolder(osp.join(root, 'val'), test_transform)
    else:
        # Reached by 'ImageNet16*' names: transforms exist but no loader does.
        raise TypeError("Unknow dataset : {:}".format(name))

    class_num = Dataset2Class[name]
    return train_data, test_data, class_num
class LinearRegionCount(object):
    """Counts distinct activation patterns (linear regions) over recorded samples.

    ``update2D`` appends the sign pattern of a batch of activations to a
    buffer with ``n_samples`` rows; ``getLinearReginCount`` then returns the
    number of distinct rows among those recorded. ``update1D`` is a separate
    string-based bookkeeping path that fills ``ActPattern``.
    """

    def __init__(self, n_samples):
        self.ActPattern = {}        # pattern string -> 1, filled by update1D
        self.n_LR = -1              # cached region count; -1 means not computed yet
        self.n_samples = n_samples  # number of rows of the activation buffer
        self.ptr = 0                # index of the next free buffer row
        self.activations = None     # (n_samples, n_neuron) buffer, allocated lazily

    @torch.no_grad()
    def update2D(self, activations):
        """Append the sign of a (batch, neuron) activation tensor to the buffer.

        The buffer is created on first use, on the same device as
        ``activations``.
        """
        n_batch, n_neuron = activations.size(0), activations.size(1)
        self.n_neuron = n_neuron
        if self.activations is None:
            self.activations = torch.zeros(self.n_samples, n_neuron, device=activations.device)
        self.activations[self.ptr:self.ptr + n_batch] = torch.sign(activations)  # after ReLU
        self.ptr += n_batch

    @torch.no_grad()
    def calc_LR(self):
        """Compute ``n_LR`` from the recorded rows and release the buffer.

        Only rows written so far are used. With nothing recorded, the count
        is 0.0.
        """
        if self.activations is None or self.ptr == 0:
            self.n_LR = 0.0
            self.activations = None
            return
        on_gpu = self.activations.is_cuda
        # fp16 keeps the pairwise matrix small on GPU; half-precision matmul is
        # not available on every CPU build, so float32 is used there.
        acts = self.activations[:self.ptr].to(torch.half if on_gpu else torch.float32)
        res = torch.matmul(acts, (1 - acts).T)  # each element: A * (1 - B)
        # Out of place on purpose: the transposed view shares memory with res.
        # Each element becomes A * (1 - B) + (1 - A) * B; non-zero means the
        # two samples lie in different linear regions.
        res = res + res.T
        res = 1 - torch.sign(res)  # non-zero now marks two identical regions
        # Per sample: how many samples share its region. Accumulated in
        # float32 because fp16 cannot represent every integer above 2048.
        res = res.sum(1, dtype=torch.float32)
        res = 1. / res  # contribution of each repeated region
        self.n_LR = res.sum().item()  # contributions add up to the number of unique regions
        del acts, res
        self.activations = None
        if on_gpu:
            torch.cuda.empty_cache()

    @torch.no_grad()
    def update1D(self, activationList):
        """Record one activation pattern from a mapping of 1-D tensors.

        The values are concatenated, in mapping order, into a string holding
        '1' for entries > 0 and '0' otherwise; the string is stored as a key
        of ``ActPattern``.
        """
        bits = []
        for value in activationList.values():
            # One vectorised comparison instead of one per element.
            bits.extend('1' if flag else '0' for flag in (value > 0).tolist())
        code_string = ''.join(bits)
        if code_string not in self.ActPattern:
            self.ActPattern[code_string] = 1

    def getLinearReginCount(self):
        """Return the region count, computing it on the first call."""
        if self.n_LR == -1:
            self.calc_LR()
        return self.n_LR
class Linear_Region_Collector:
    """Feeds training batches through models and counts their linear regions.

    Forward hooks on every ``nn.ReLU`` collect the post-activation outputs.
    For each model these are flattened per sample and passed to a
    ``LinearRegionCount``. Models are moved to CUDA, so a GPU is required
    once models are supplied.

    Args:
        models: networks to analyse (deep-copied); ``None`` means none yet.
        input_size: (B, C, H, W); B is also the DataLoader batch size.
        sample_batch: number of batches drawn per ``forward_batch_sample``.
        dataset: dataset name handed to ``get_datasets``.
        data_path: dataset root; when ``None`` no loader is created.
        seed: stored seed value (see the note in ``__init__``).
    """

    def __init__(self, models=None, input_size=(64, 3, 32, 32), sample_batch=100, dataset='cifar100', data_path=None, seed=0):
        self.models = []
        self.input_size = input_size  # BCHW
        self.sample_batch = sample_batch
        self.input_numel = reduce(mul, self.input_size, 1)
        self.interFeature = []
        self.dataset = dataset
        self.data_path = data_path
        # NOTE(review): self.seed is set before reinit() compares against it,
        # so the constructor's seed never reaches torch.manual_seed. Left
        # unchanged because seeding here would alter global RNG state for
        # callers - confirm the intended behaviour.
        self.seed = seed
        self.reinit([] if models is None else models, input_size, sample_batch, seed)

    @staticmethod
    def _prepare_model(network, weights):
        """Return an eval-mode CUDA deep copy of ``network`` set up for probing."""
        network = network.cuda()  # also moves the caller's module
        net = copy.deepcopy(network)
        net.proj_weights = weights
        # presumably marks every edge as no longer a candidate; verify against
        # the network class, which is defined outside this file
        for i in range(net.num_edge):
            net.candidate_flags[i] = False
        net.eval()
        return net

    def _new_counters(self):
        """Create one fresh counter per model, sized for the current settings."""
        n_samples = self.input_size[0] * self.sample_batch
        return [LinearRegionCount(n_samples) for _ in self.models]

    def reinit(self, ori_models=None, input_size=None, sample_batch=None, seed=None, weights=None):
        """Replace models and/or sampling settings; ``None`` keeps the current value.

        Counters are rebuilt whenever the models or either size setting
        changes. The dataset and loader are rebuilt when a size setting is
        given and ``data_path`` is set. A seed is applied only when it differs
        from the stored one.
        """
        # Update sizes first so the counters below use the new values.
        if input_size is not None:
            self.input_size = input_size  # BCHW
            self.input_numel = reduce(mul, self.input_size, 1)
        if sample_batch is not None:
            self.sample_batch = sample_batch

        if ori_models is not None:
            self.models = [self._prepare_model(network, weights) for network in ori_models]
            for model in self.models:
                self.register_hook(model)

        resized = input_size is not None or sample_batch is not None
        if ori_models is not None or resized:
            self.LRCounts = self._new_counters()

        if resized and self.data_path is not None:
            self.train_data, _, class_num = get_datasets(self.dataset, self.data_path, self.input_size, -1)
            self.train_loader = torch.utils.data.DataLoader(
                self.train_data, batch_size=self.input_size[0], num_workers=16,
                pin_memory=True, drop_last=True, shuffle=True)
            self.loader = iter(self.train_loader)

        if seed is not None and seed != self.seed:
            self.seed = seed
            torch.manual_seed(seed)
            torch.cuda.manual_seed(seed)

        del self.interFeature
        self.interFeature = []
        torch.cuda.empty_cache()

    def clear(self):
        """Drop all recorded activations and start with fresh counters."""
        self.LRCounts = self._new_counters()
        del self.interFeature
        self.interFeature = []
        torch.cuda.empty_cache()

    def register_hook(self, model):
        """Attach the recording hook to every ``nn.ReLU`` inside ``model``."""
        for m in model.modules():
            if isinstance(m, nn.ReLU):
                m.register_forward_hook(hook=self.hook_in_forward)

    def hook_in_forward(self, module, input, output):
        """Forward hook: keep the output when the first input is 4-D."""
        if isinstance(input, tuple) and input[0].dim() == 4:
            self.interFeature.append(output.detach())  # for ReLU

    def forward_batch_sample(self):
        """Run ``sample_batch`` batches through all models.

        Returns:
            List with one linear-region count per model.
        """
        for _ in range(self.sample_batch):
            try:
                inputs, _ = next(self.loader)
            except StopIteration:
                # Loader exhausted: drop it before starting a new pass.
                del self.loader
                self.loader = iter(self.train_loader)
                inputs, _ = next(self.loader)
            for model, LRCount in zip(self.models, self.LRCounts):
                self.forward(model, LRCount, inputs)
        return [LRCount.getLinearReginCount() for LRCount in self.LRCounts]

    def forward(self, model, LRCount, input_data):
        """Forward one batch through ``model`` and record its ReLU pattern."""
        self.interFeature = []
        with torch.no_grad():
            model.forward(input_data.cuda())
        if len(self.interFeature) == 0:
            return
        batch = input_data.size(0)
        # One row per sample: all hooked ReLU outputs, flattened and joined.
        feature_data = torch.cat([f.reshape(batch, -1) for f in self.interFeature], 1)
        LRCount.update2D(feature_data)