first commit
This commit is contained in:
		
							
								
								
									
										33
									
								
								README.md
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										33
									
								
								README.md
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,33 @@ | |||||||
|  | # Sample-Wise Activation Patterns for Ultra-Fast NAS <br/> (ICLR 2024 Spotlight) | ||||||
|  | Training-free metrics (a.k.a. zero-cost proxies) are widely used to avoid resource-intensive neural network training, especially in Neural Architecture Search (NAS). Recent studies show that existing training-free metrics have several limitations, such as limited correlation and poor generalisation across different search spaces and tasks. Hence, we propose Sample-Wise Activation Patterns and its derivative, SWAP-Score, a novel high-performance training-free metric. It measures the expressivity of networks over a batch of input samples. The SWAP-Score is strongly correlated with ground-truth performance across various search spaces and tasks, outperforming 15 existing training-free metrics on NAS-Bench-101/201/301 and TransNAS-Bench-101. | ||||||
|  |  | ||||||
|  | # Usage | ||||||
|  |  | ||||||
|  | The following instructions demonstrate how to evaluate a network's performance with the SWAP-Score. | ||||||
|  |  | ||||||
|  | **/src/metrics/swap.py** contains the core components of SWAP-Score.  | ||||||
|  |  | ||||||
|  | **/datasets/DARTS_archs_CIFAR10.csv** contains 1000 architectures (randomly sampled from DARTS space) along with their CIFAR-10 validation accuracies (trained for 200 epochs). | ||||||
|  |  | ||||||
|  | * Install necessary dependencies (a new virtual environment is suggested). | ||||||
|  | ``` | ||||||
|  | cd SWAP | ||||||
|  | pip install -r requirements.txt | ||||||
|  | ``` | ||||||
|  | * Calculate the correlation between SWAP-Score and CIFAR-10 validation accuracies of 1000 DARTS architectures. | ||||||
|  | ``` | ||||||
|  | python correlation.py | ||||||
|  | ``` | ||||||
|  |  | ||||||
|  |  | ||||||
|  | If you use or build on our code, please consider citing our paper: | ||||||
|  | ``` | ||||||
|  | @inproceedings{ | ||||||
|  | peng2024swapnas, | ||||||
|  | title={{SWAP}-{NAS}: Sample-Wise Activation Patterns for Ultra-fast {NAS}}, | ||||||
|  | author={Yameng Peng and Andy Song and Haytham M. Fayek and Vic Ciesielski and Xiaojun Chang}, | ||||||
|  | booktitle={The Twelfth International Conference on Learning Representations}, | ||||||
|  | year={2024}, | ||||||
|  | url={https://openreview.net/forum?id=tveiUXU2aa} | ||||||
|  | } | ||||||
|  | ``` | ||||||
							
								
								
									
										66
									
								
								correlation.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										66
									
								
								correlation.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,66 @@ | |||||||
import os
os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'  # let MPS fall back to CPU for unsupported ops
import argparse
import torch
import torch.nn as nn
import numpy as np
import pandas as pd
from scipy import stats
from src.utils.utilities import *
from src.metrics.swap import SWAP
from src.datasets.utilities import get_datasets
from src.search_space.networks import *

# Settings for console outputs
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)
warnings.simplefilter(action='ignore', category=UserWarning)

parser = argparse.ArgumentParser()

# general setting
parser.add_argument('--data_path', default="datasets", type=str, nargs='?', help='path to the image dataset (datasets or datasets/ILSVRC/Data/CLS-LOC)')
parser.add_argument('--seed', default=0, type=int, help='random seed')
parser.add_argument('--device', default="mps", type=str, nargs='?', help='setup device (cpu, mps or cuda)')
parser.add_argument('--repeats', default=32, type=int, nargs='?', help='times of calculating the training-free metric')
parser.add_argument('--input_samples', default=16, type=int, nargs='?', help='input batch size for training-free metric')

args = parser.parse_args()

if __name__ == "__main__":

    device = torch.device(args.device)

    # Each CSV row is (genotype string, CIFAR-10 validation accuracy); the CSV has no header.
    arch_info = pd.read_csv(args.data_path+'/DARTS_archs_CIFAR10.csv', names=['genotype', 'valid_acc'], sep=',')

    # One fixed batch of CIFAR-10 images is drawn once and reused for every architecture.
    train_data, _, _ = get_datasets('cifar10', args.data_path, (args.input_samples, 3, 32, 32), -1)
    train_loader = torch.utils.data.DataLoader(train_data, batch_size=args.input_samples, num_workers=0, pin_memory=True)
    loader = iter(train_loader)
    inputs, _ = next(loader)

    results = []

    for index, i in arch_info.iterrows():
        print(f'Evaluating network: {index}')

        # NOTE(review): eval() on the genotype string executes arbitrary code —
        # only run with a trusted CSV.
        network = Network(3, 10, 1, eval(i.genotype))
        network = network.to(device)

        swap = SWAP(model=network, inputs=inputs, device=device, seed=args.seed)

        swap_score = []

        # Average the SWAP score over several random Gaussian weight initialisations.
        for _ in range(args.repeats):
            network = network.apply(network_weight_gaussian_init)
            swap.reinit()
            swap_score.append(swap.forward())
            swap.clear()

        results.append([np.mean(swap_score), i.valid_acc])

    results = pd.DataFrame(results, columns=['swap_score', 'valid_acc'])
    print()
    print(f'Spearman\'s Correlation Coefficient: {stats.spearmanr(results.swap_score, results.valid_acc)[0]}')
							
								
								
									
										1000
									
								
								datasets/DARTS_archs_CIFAR10.csv
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										1000
									
								
								datasets/DARTS_archs_CIFAR10.csv
									
									
									
									
									
										Normal file
									
								
							
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							
							
								
								
									
										5
									
								
								requirements.txt
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										5
									
								
								requirements.txt
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,5 @@ | |||||||
|  | numpy>=1.24.2 | ||||||
|  | pandas>=1.5.3 | ||||||
|  | scipy>=1.10.0 | ||||||
|  | torch>=2.0.1 | ||||||
|  | torchvision>=0.15.2 | ||||||
							
								
								
									
										0
									
								
								src/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										0
									
								
								src/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
								
								
									
										109
									
								
								src/datasets/DownsampledImageNet.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										109
									
								
								src/datasets/DownsampledImageNet.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,109 @@ | |||||||
|  | import os, sys, hashlib | ||||||
|  | import numpy as np | ||||||
|  | from PIL import Image | ||||||
|  | import torch.utils.data as data | ||||||
|  | if sys.version_info[0] == 2: | ||||||
|  |     import cPickle as pickle | ||||||
|  | else: | ||||||
|  |     import pickle | ||||||
|  |  | ||||||
|  |  | ||||||
def calculate_md5(fpath, chunk_size=1024 * 1024):
    """Stream the file at *fpath* in chunk_size pieces and return its md5 hex digest."""
    digest = hashlib.md5()
    with open(fpath, 'rb') as stream:
        while True:
            block = stream.read(chunk_size)
            if not block:
                break
            digest.update(block)
    return digest.hexdigest()
|  |  | ||||||
|  |  | ||||||
def check_md5(fpath, md5, **kwargs):
    """Return True iff the file at *fpath* hashes to the expected md5 hex digest."""
    return calculate_md5(fpath, **kwargs) == md5
|  |  | ||||||
|  |  | ||||||
def check_integrity(fpath, md5=None):
    """Return True iff *fpath* exists and (when *md5* is given) matches that digest."""
    if not os.path.isfile(fpath):
        return False
    if md5 is None:
        return True
    return check_md5(fpath, md5)
|  |  | ||||||
|  |  | ||||||
class ImageNet16(data.Dataset):
    # http://image-net.org/download-images
    # A Downsampled Variant of ImageNet as an Alternative to the CIFAR datasets
    # https://arxiv.org/pdf/1707.08819.pdf
    # Loads 16x16 downsampled ImageNet stored as CIFAR-style pickled batch files.

    # (batch file name, expected md5 digest) pairs checked by _check_integrity().
    train_list = [
                ['train_data_batch_1', '27846dcaa50de8e21a7d1a35f30f0e91'],
                ['train_data_batch_2', 'c7254a054e0e795c69120a5727050e3f'],
                ['train_data_batch_3', '4333d3df2e5ffb114b05d2ffc19b1e87'],
                ['train_data_batch_4', '1620cdf193304f4a92677b695d70d10f'],
                ['train_data_batch_5', '348b3c2fdbb3940c4e9e834affd3b18d'],
                ['train_data_batch_6', '6e765307c242a1b3d7d5ef9139b48945'],
                ['train_data_batch_7', '564926d8cbf8fc4818ba23d2faac7564'],
                ['train_data_batch_8', 'f4755871f718ccb653440b9dd0ebac66'],
                ['train_data_batch_9', 'bb6dd660c38c58552125b1a92f86b5d4'],
                ['train_data_batch_10','8f03f34ac4b42271a294f91bf480f29b'],
        ]
    valid_list = [
                ['val_data', '3410e3017fdaefba8d5073aaa65e4bd6'],
        ]

    def __init__(self, root, train, transform, use_num_of_class_only=None):
        """Load the pickled batches from *root*.

        Args:
            root: directory containing the batch files listed above.
            train: True -> training batches, False -> validation batch.
            transform: optional transform applied to each PIL image in __getitem__.
            use_num_of_class_only: if set (int in (0, 1000)), keep only samples
                whose (1-based) label is <= this value.

        Raises:
            RuntimeError: if any batch file is missing or fails its md5 check.
        """
        self.root      = root
        self.transform = transform
        self.train     = train  # training set or valid set
        if not self._check_integrity(): raise RuntimeError('Dataset not found or corrupted.')

        if self.train: downloaded_list = self.train_list
        else         : downloaded_list = self.valid_list
        self.data    = []
        self.targets = []

        # now load the picked numpy arrays
        for i, (file_name, checksum) in enumerate(downloaded_list):
            file_path = os.path.join(self.root, file_name)
            #print ('Load {:}/{:02d}-th : {:}'.format(i, len(downloaded_list), file_path))
            with open(file_path, 'rb') as f:
                if sys.version_info[0] == 2:
                    entry = pickle.load(f)
                else:
                    # latin1 encoding lets Python 3 read Python-2 pickles
                    entry = pickle.load(f, encoding='latin1')
                self.data.append(entry['data'])
                self.targets.extend(entry['labels'])
        self.data = np.vstack(self.data).reshape(-1, 3, 16, 16)
        self.data = self.data.transpose((0, 2, 3, 1))  # convert to HWC
        if use_num_of_class_only is not None:
            assert isinstance(use_num_of_class_only, int) and use_num_of_class_only > 0 and use_num_of_class_only < 1000, 'invalid use_num_of_class_only : {:}'.format(use_num_of_class_only)
            new_data, new_targets = [], []
            # labels are 1-based here; keep classes 1..use_num_of_class_only
            for I, L in zip(self.data, self.targets):
                if 1 <= L <= use_num_of_class_only:
                    new_data.append( I )
                    new_targets.append( L )
            self.data    = new_data
            self.targets = new_targets


    def __getitem__(self, index):
        # stored labels are 1-based; shift to the 0-based convention
        img, target = self.data[index], self.targets[index] - 1

        img = Image.fromarray(img)

        if self.transform is not None:
            img = self.transform(img)

        return img, target

    def __len__(self):
        return len(self.data)

    def _check_integrity(self):
        """True iff every train and valid batch file exists with the expected md5."""
        root = self.root
        for fentry in (self.train_list + self.valid_list):
            filename, md5 = fentry[0], fentry[1]
            fpath = os.path.join(root, filename)
            if not check_integrity(fpath, md5):
                return False
        return True


if __name__ == '__main__':
    pass
							
								
								
									
										0
									
								
								src/datasets/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										0
									
								
								src/datasets/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
								
								
									
										115
									
								
								src/datasets/utilities.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										115
									
								
								src/datasets/utilities.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,115 @@ | |||||||
|  | import os.path as osp | ||||||
|  | import numpy as np | ||||||
|  | import torch | ||||||
|  | import torchvision.transforms as transforms | ||||||
|  | import torchvision.datasets as dset | ||||||
|  | from .DownsampledImageNet import ImageNet16 | ||||||
|  | from sklearn.model_selection import StratifiedKFold | ||||||
|  |  | ||||||
# Maps each dataset name accepted by get_datasets() to its number of classes.
Dataset2Class = {'cifar10': 10,
                 'cifar100': 100,
                 'imagenet-1k-s': 1000,
                 'imagenet-1k': 1000,
                 'ImageNet16' : 1000,
                 'ImageNet16-120': 120,
                 'ImageNet16-150': 150,
                 'ImageNet16-200': 200}
|  |  | ||||||
class RandChannel(object):
    """Transform that keeps a random subset of `num_channel` channels from the
    input image tensor, preserving ascending channel order."""

    def __init__(self, num_channel):
        self.num_channel = num_channel

    def __repr__(self):
        return ('{name}(num_channel={num_channel})'.format(name=self.__class__.__name__, **self.__dict__))

    def __call__(self, img):
        n_channels = img.size(0)
        # sample num_channel distinct channel indices, then restore their order
        picked = np.random.choice(np.arange(n_channels), size=self.num_channel, replace=False)
        idx = torch.as_tensor(sorted(picked), dtype=torch.long)
        return torch.index_select(img, 0, idx)
|  |  | ||||||
|  |  | ||||||
def get_datasets(name, root, input_size, cutout=-1):
    """Build train/test datasets (with augmentation) for a named benchmark.

    Args:
        name (str): one of 'cifar10', 'cifar100', 'imagenet-1k',
            'imagenet-1k-s', 'ImageNet16', 'ImageNet16-120/150/200'.
        root (str): dataset root directory.
        input_size: (C, H, W) or (N, C, H, W); a leading batch dim is dropped.
        cutout (int): cutout length; <= 0 disables the cutout augmentation.

    Returns:
        (train_data, test_data, class_num)

    Raises:
        TypeError: for an unrecognised dataset name.
        ValueError: for an unrecognised imagenet-1k variant.
    """
    assert len(input_size) in [3, 4]
    if len(input_size) == 4:
        input_size = input_size[1:]  # drop the batch dimension
    assert input_size[1] == input_size[2]  # square inputs expected

    # Channel-wise normalisation statistics.
    if name == 'cifar10':
        mean = [0.49139968, 0.48215827, 0.44653124]
        std  = [0.24703233, 0.24348505, 0.26158768]
    elif name == 'cifar100':
        mean = [0.5071, 0.4865, 0.4409]
        std  = [0.2673, 0.2564, 0.2762]
    elif name.startswith('imagenet-1k'):
        mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
    elif name.startswith('ImageNet16'):
        mean = [0.481098, 0.45749, 0.407882]
        std  = [0.247922, 0.240235, 0.255255]
    else:
        raise TypeError("Unknown dataset : {:}".format(name))

    # Data Argumentation
    if name == 'cifar10' or name == 'cifar100':
        lists = [transforms.RandomCrop(input_size[1], padding=0), transforms.ToTensor(), transforms.Normalize(mean, std), RandChannel(input_size[0])]
        # NOTE(review): CUTOUT is not defined or imported in this module, so
        # cutout > 0 would raise NameError — confirm the intended import.
        if cutout > 0 : lists += [CUTOUT(cutout)]
        train_transform = transforms.Compose(lists)
        test_transform  = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean, std)])
    elif name.startswith('ImageNet16'):
        lists = [transforms.RandomCrop(input_size[1], padding=0), transforms.ToTensor(), transforms.Normalize(mean, std), RandChannel(input_size[0])]
        if cutout > 0 : lists += [CUTOUT(cutout)]
        train_transform = transforms.Compose(lists)
        test_transform  = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean, std)])
    elif name.startswith('imagenet-1k'):
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        if name == 'imagenet-1k':
            xlists    = []
            xlists.append(transforms.Resize((input_size[1], input_size[1]), interpolation=2))
            xlists.append(transforms.RandomCrop(input_size[1], padding=0))
        elif name == 'imagenet-1k-s':
            # Bug fix: the original immediately overwrote this list with
            # `xlists = []`, silently dropping the RandomResizedCrop and
            # leaving variable-size PIL images for ToTensor.
            xlists = [transforms.RandomResizedCrop(input_size[1], scale=(0.2, 1.0))]
        else: raise ValueError('invalid name : {:}'.format(name))
        xlists.append(transforms.ToTensor())
        xlists.append(normalize)
        xlists.append(RandChannel(input_size[0]))
        train_transform = transforms.Compose(xlists)
        test_transform = transforms.Compose([transforms.Resize(input_size[1]), transforms.CenterCrop(input_size[1]), transforms.ToTensor(), normalize])
    else:
        raise TypeError("Unknown dataset : {:}".format(name))

    # Dataset construction; split sizes are asserted against the published splits.
    if name == 'cifar10':
        train_data = dset.CIFAR10 (root, train=True , transform=train_transform, download=True)
        test_data  = dset.CIFAR10 (root, train=False, transform=test_transform , download=True)
        assert len(train_data) == 50000 and len(test_data) == 10000
    elif name == 'cifar100':
        train_data = dset.CIFAR100(root, train=True , transform=train_transform, download=True)
        test_data  = dset.CIFAR100(root, train=False, transform=test_transform , download=True)
        assert len(train_data) == 50000 and len(test_data) == 10000
    elif name.startswith('imagenet-1k'):
        train_data = dset.ImageFolder(osp.join(root, 'train'), train_transform)
        test_data  = dset.ImageFolder(osp.join(root, 'val'),   test_transform)
    elif name == 'ImageNet16':
        root = osp.join(root, 'ImageNet16')
        train_data = ImageNet16(root, True , train_transform)
        test_data  = ImageNet16(root, False, test_transform)
        assert len(train_data) == 1281167 and len(test_data) == 50000
    elif name == 'ImageNet16-120':
        root = osp.join(root, 'ImageNet16')
        train_data = ImageNet16(root, True , train_transform, 120)
        test_data  = ImageNet16(root, False, test_transform , 120)
        assert len(train_data) == 151700 and len(test_data) == 6000
    elif name == 'ImageNet16-150':
        root = osp.join(root, 'ImageNet16')
        train_data = ImageNet16(root, True , train_transform, 150)
        test_data  = ImageNet16(root, False, test_transform , 150)
        assert len(train_data) == 190272 and len(test_data) == 7500
    elif name == 'ImageNet16-200':
        root = osp.join(root, 'ImageNet16')
        train_data = ImageNet16(root, True , train_transform, 200)
        test_data  = ImageNet16(root, False, test_transform , 200)
        assert len(train_data) == 254775 and len(test_data) == 10000
    else: raise TypeError("Unknown dataset : {:}".format(name))

    class_num = Dataset2Class[name]
    return train_data, test_data, class_num
							
								
								
									
										0
									
								
								src/metrics/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										0
									
								
								src/metrics/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
								
								
									
										99
									
								
								src/metrics/swap.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										99
									
								
								src/metrics/swap.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,99 @@ | |||||||
|  | import numpy as np | ||||||
|  | import torch | ||||||
|  | import torch.nn as nn | ||||||
|  | from src.utils.utilities import count_parameters | ||||||
|  |  | ||||||
def cal_regular_factor(model, mu, sigma):
    """Gaussian-style penalty on the model's parameter count, centred at mu
    with width sigma; returns a tensor in (0, 1]."""
    n_params = torch.as_tensor(count_parameters(model))
    deviation = torch.pow(n_params - mu, 2) / sigma
    return torch.exp(-deviation)
|  |  | ||||||
|  |  | ||||||
class SampleWiseActivationPatterns(object):
    """Stores the binary (sign) activation pattern of one forward pass and
    scores a network by the number of unique per-neuron patterns."""

    def __init__(self, device):
        self.swap = -1           # last computed SWAP score (-1 = not yet computed)
        self.activations = None  # (n_sample, n_neuron) sign matrix from collect_activations
        self.device = device

    @torch.no_grad()
    def collect_activations(self, activations):
        """Record the sign pattern of a (n_sample, n_neuron) activation matrix.

        Fix: the original pre-allocated a zeros tensor on `self.device` and then
        immediately overwrote the attribute with `torch.sign(activations)`; the
        dead allocation is removed (behavior unchanged).
        """
        self.activations = torch.sign(activations)

    @torch.no_grad()
    def calSWAP(self, regular_factor):
        """Return the SWAP score: count of unique neuron-wise sign patterns,
        scaled by `regular_factor`. Frees the stored activations."""
        # transpose the activation matrix: (samples, neurons) to (neurons, samples)
        # so each row is one neuron's activation pattern over the batch
        self.activations = self.activations.T
        self.swap = torch.unique(self.activations, dim=0).size(0)

        del self.activations
        self.activations = None
        torch.cuda.empty_cache()

        return self.swap * regular_factor
|  |  | ||||||
|  |  | ||||||
class SWAP:
    """SWAP-Score driver: registers forward hooks on every nn.ReLU of a model,
    runs one forward pass on a fixed input batch, and scores the concatenated
    activation patterns via SampleWiseActivationPatterns."""

    def __init__(self, model=None, inputs = None, device='cuda', seed=0, regular=False, mu=None, sigma=None):
        # model:   network to score (hooks are registered on its ReLU modules)
        # inputs:  fixed input batch reused by every forward() call
        # regular: when True and mu/sigma are given, the score is scaled by a
        #          parameter-count penalty from cal_regular_factor()
        self.model = model
        self.interFeature = []   # per-ReLU outputs captured by the forward hooks
        self.seed = seed
        self.regular_factor = 1
        self.inputs = inputs
        self.device = device

        if regular and mu is not None and sigma is not None:
            self.regular_factor = cal_regular_factor(self.model, mu, sigma).item()

        self.reinit(self.model, self.seed)

    def reinit(self, model=None, seed=None):
        """Reset captured features; optionally attach to a new model and reseed.

        NOTE(review): hooks are (re-)registered every time a model is passed,
        so passing the same model twice would double the hooks — confirm
        callers pass a model only once (correlation.py calls reinit() bare).
        """
        if model is not None:
            self.model = model
            self.register_hook(self.model)
            self.swap = SampleWiseActivationPatterns(self.device)

        # only reseed when the seed actually changes; the constructor's own
        # call (seed == self.seed) is a deliberate no-op
        if seed is not None and seed != self.seed:
            self.seed = seed
            torch.manual_seed(seed)
            torch.cuda.manual_seed(seed)
        del self.interFeature
        self.interFeature = []
        torch.cuda.empty_cache()

    def clear(self):
        """Drop accumulated activation state between repeated evaluations."""
        self.swap = SampleWiseActivationPatterns(self.device)
        del self.interFeature
        self.interFeature = []
        torch.cuda.empty_cache()

    def register_hook(self, model):
        # capture the output of every ReLU module in the network
        for n, m in model.named_modules():
            if isinstance(m, nn.ReLU):
                m.register_forward_hook(hook=self.hook_in_forward)

    def hook_in_forward(self, module, input, output):
        # only record ReLUs fed by 4-D (conv-style) inputs
        if isinstance(input, tuple) and len(input[0].size()) == 4:
            self.interFeature.append(output.detach())

    def forward(self):
        """Run one forward pass on the stored inputs and return the SWAP score
        (returns None when no ReLU activations were captured)."""
        self.interFeature = []
        with torch.no_grad():
            self.model.forward(self.inputs.to(self.device))
            if len(self.interFeature) == 0: return
            # flatten each captured map to (batch, -1) and concatenate across layers
            activtions = torch.cat([f.view(self.inputs.size(0), -1) for f in self.interFeature], 1)
            self.swap.collect_activations(activtions)

            return self.swap.calSWAP(self.regular_factor)
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
							
								
								
									
										0
									
								
								src/search_space/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										0
									
								
								src/search_space/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
								
								
									
										105
									
								
								src/search_space/networks.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										105
									
								
								src/search_space/networks.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,105 @@ | |||||||
|  | from .operations import * | ||||||
|  | import torch | ||||||
|  | import torch.nn as nn | ||||||
|  | import torch.nn.functional as F | ||||||
|  | from collections import namedtuple | ||||||
|  |  | ||||||
# Architecture description: (op name, input index) pairs plus the node ids whose
# outputs are concatenated, for the normal and reduction cells (DARTS convention).
Genotype = namedtuple('Genotype', 'normal normal_concat reduce reduce_concat')
|  |  | ||||||
def drop_path(x, drop_prob):
  """Apply element-wise dropout with probability drop_prob; no-op when <= 0.

  NOTE(review): this drops individual elements via F.dropout rather than
  masking whole paths per sample — confirm against the intended DARTS scheme.
  """
  if drop_prob > 0.:
    return nn.functional.dropout(x, p=drop_prob)
  return x
|  |  | ||||||
class Cell(nn.Module):
  """One DARTS cell: a DAG of operations decoded from a genotype.

  A cell takes the outputs of the two previous cells (s0, s1), preprocesses
  them to a common channel count C, applies the genotype's ops pairwise over
  intermediate nodes, and concatenates the nodes listed in the genotype's
  concat field along the channel dimension.
  """

  def __init__(self, genotype, C_prev_prev, C_prev, C, reduction, reduction_prev):
    # reduction:      this cell halves spatial resolution (stride-2 ops on inputs)
    # reduction_prev: the previous cell was a reduction cell, so s0 needs a
    #                 FactorizedReduce to match spatial size
    super(Cell, self).__init__()

    if reduction_prev:
      self.preprocess0 = FactorizedReduce(C_prev_prev, C)
    else:
      self.preprocess0 = ReLUConvBN(C_prev_prev, C, 1, 1, 0, 1, True)
    self.preprocess1 = ReLUConvBN(C_prev, C, 1, 1, 0, 1, True)

    if reduction:
        op_names, indices = zip(*genotype.reduce)
        concat = genotype.reduce_concat # 2,3,4,5
    else:
        op_names, indices = zip(*genotype.normal)
        concat = genotype.normal_concat # 2,3,4,5
    self._compile(C, op_names, indices, concat, reduction)

  def _compile(self, C, op_names, indices, concat, reduction):
    """Instantiate the ops named in the genotype; two ops feed each node."""
    assert len(op_names) == len(indices)
    self._steps = len(op_names) // 2 # 4
    self._concat = concat # 2,3,4,5
    self.multiplier = len(concat) # 4
    self._ops = nn.ModuleList()

    for name, index in zip(op_names, indices):
        # only ops reading the two cell inputs (index < 2) get stride 2 in a reduction cell
        stride = 2 if reduction and index < 2 else 1
        op = OPS[name](C, C, stride, True)
        self._ops += [op]
    self._indices = indices

  def forward(self, s0, s1, drop_prob):
    """Evaluate the cell DAG; drop_prob enables drop_path during training."""
    s0 = self.preprocess0(s0)
    s1 = self.preprocess1(s1)

    states = [s0, s1]
    for i in range(self._steps):
      # each node sums the outputs of two ops applied to earlier states
      h1 = states[self._indices[2*i]]
      h2 = states[self._indices[2*i+1]]
      op1 = self._ops[2*i]
      op2 = self._ops[2*i+1]
      h1 = op1(h1)
      h2 = op2(h2)
      if self.training and drop_prob > 0.:
        # Identity (skip) connections are exempt from drop_path
        if not isinstance(op1, Identity):
          h1 = drop_path(h1, drop_prob)
        if not isinstance(op2, Identity):
          h2 = drop_path(h2, drop_prob)
      s = h1 + h2
      states += [s]
    return torch.cat([states[i] for i in self._concat], dim=1)
|  |  | ||||||
class Network(nn.Module):
    """DARTS-style macro network: a stack of cells decoded from one genotype.

    NOTE(review): forward() computes `logits = self.classifier(out)` but
    returns `out` (the pooled features), discarding the logits — confirm
    whether callers expect features or class logits. correlation.py ignores
    the return value and scores via forward hooks instead.
    """

    def __init__(self, C, num_classes, layers, genotype):
        # C: initial channel count; layers: number of stacked cells.
        # drop_path_prob defaults to 0 (no drop_path unless set externally).
        self.drop_path_prob = 0.
        super(Network, self).__init__()

        self._layers = layers

        C_prev_prev, C_prev, C_curr = C, C, C

        self.cells = nn.ModuleList()
        reduction_prev = False

        for i in range(layers):
            # reduction cells (doubled channels) at 1/3 and 2/3 of the depth
            if i in [layers // 3, 2 * layers // 3]:
                C_curr *= 2
                reduction = True
            else:
                reduction = False
            cell = Cell(genotype, C_prev_prev, C_prev, C_curr, reduction, reduction_prev)
            reduction_prev = reduction
            self.cells += [cell]
            # each cell outputs multiplier * C_curr channels
            C_prev_prev, C_prev = C_prev, cell.multiplier * C_curr

        self.global_pooling = nn.AdaptiveAvgPool2d(1)
        self.classifier = nn.Linear(C_prev, num_classes)

    def forward(self, input):
        # both cell inputs start as the raw input (no stem in this variant)
        s0 = s1 = input

        for i, cell in enumerate(self.cells):
            s0, s1 = s1, cell(s0, s1, self.drop_path_prob)

        out = self.global_pooling(s1)
        out = out.view(out.size(0), -1)
        logits = self.classifier(out)
        return out
|  |  | ||||||
							
								
								
									
										147
									
								
								src/search_space/operations.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										147
									
								
								src/search_space/operations.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,147 @@ | |||||||
|  | import torch | ||||||
|  | import torch.nn as nn | ||||||
|  |  | ||||||
# Candidate operations of the DARTS search space. Each entry maps an op name
# to a factory with signature (C_in, C_out, stride, affine) -> nn.Module.
# 'none' emits zeros; 'skip_connect' is identity unless it must downsample or
# change width (then FactorizedReduce); the dilated convs use padding chosen
# to keep spatial size when stride == 1 (pad = dilation * (k - 1) / 2).
OPS = {
    'none': lambda C_in, C_out, stride, affine: Zero(C_in, C_out, stride),
    'avg_pool_3x3': lambda C_in, C_out, stride, affine: POOLING(C_in, C_out, stride, 'avg', affine),
    'max_pool_3x3': lambda C_in, C_out, stride, affine: POOLING(C_in, C_out, stride, 'max', affine),
    'skip_connect': lambda C_in, C_out, stride, affine: Identity() if stride == 1 and C_in == C_out else FactorizedReduce(C_in, C_out, stride, affine),
    'sep_conv_3x3': lambda C_in, C_out, stride, affine: SepConv(C_in, C_out, 3, stride, 1, affine),
    'sep_conv_5x5': lambda C_in, C_out, stride, affine: SepConv(C_in, C_out, 5, stride, 2, affine),
    'dil_conv_3x3': lambda C_in, C_out, stride, affine: DilConv(C_in, C_out, 3, stride, 2, 2, affine),
    'dil_conv_5x5': lambda C_in, C_out, stride, affine: DilConv(C_in, C_out, 5, stride, 4, 2, affine),
}
|  |  | ||||||
|  |  | ||||||
class ReLUConvBN(nn.Module):
    """ReLU -> Conv2d -> BatchNorm2d applied in sequence.

    The conv has no bias (the following BN would absorb it anyway).
    """

    def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation, affine, track_running_stats=True):
        super(ReLUConvBN, self).__init__()
        layers = [
            nn.ReLU(inplace=False),
            nn.Conv2d(C_in, C_out, kernel_size, stride=stride, padding=padding,
                      dilation=dilation, bias=False),
            nn.BatchNorm2d(C_out, affine=affine, track_running_stats=track_running_stats),
        ]
        self.op = nn.Sequential(*layers)

    def forward(self, x):
        out = self.op(x)
        return out
|  |  | ||||||
|  |  | ||||||
class DilConv(nn.Module):
    """Dilated separable convolution.

    ReLU -> dilated depthwise conv (groups=C_in) -> 1x1 pointwise conv -> BN.
    """

    def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation, affine=True, track_running_stats=True):
        super(DilConv, self).__init__()
        depthwise = nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=stride,
                              padding=padding, dilation=dilation, groups=C_in, bias=False)
        pointwise = nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False)
        self.op = nn.Sequential(
            nn.ReLU(inplace=False),
            depthwise,
            pointwise,
            nn.BatchNorm2d(C_out, affine=affine, track_running_stats=track_running_stats),
        )

    def forward(self, x):
        return self.op(x)
|  |  | ||||||
|  |  | ||||||
class SepConv(nn.Module):
    """Separable convolution applied twice: (ReLU -> depthwise -> pointwise -> BN) x 2.

    Only the first depthwise conv carries the requested stride; the second
    stage always uses stride 1. Only the final pointwise conv maps C_in to
    C_out — everything before it stays at C_in channels.
    """

    def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True, track_running_stats=True):
        super(SepConv, self).__init__()
        stage_one = [
            nn.ReLU(inplace=False),
            nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=stride,
                      padding=padding, groups=C_in, bias=False),
            nn.Conv2d(C_in, C_in, kernel_size=1, padding=0, bias=False),
            nn.BatchNorm2d(C_in, affine=affine, track_running_stats=track_running_stats),
        ]
        stage_two = [
            nn.ReLU(inplace=False),
            nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=1,
                      padding=padding, groups=C_in, bias=False),
            nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False),
            nn.BatchNorm2d(C_out, affine=affine, track_running_stats=track_running_stats),
        ]
        self.op = nn.Sequential(*(stage_one + stage_two))

    def forward(self, x):
        return self.op(x)
|  |  | ||||||
|  |  | ||||||
class Identity(nn.Module):
    """No-op module: forward returns its input unchanged.

    Cell.forward also uses this type to decide where drop_path applies.
    """

    def __init__(self):
        super(Identity, self).__init__()

    def forward(self, x):
        return x
|  |  | ||||||
|  |  | ||||||
class FactorizedReduce(nn.Module):
    """Width/resolution adapter used where skip connections cannot be identity.

    With stride == 2, the output channels are split across two 1x1 convs:
    one reads the activated map as-is, the other reads it shifted by one
    pixel (via pad + crop) so the two subsample different pixel grids; their
    outputs are concatenated. With stride == 1 a single 1x1 conv is used.
    BN is applied last in both cases.
    """

    def __init__(self, C_in, C_out, stride=2, affine=True, track_running_stats=True):
        super(FactorizedReduce, self).__init__()
        self.stride = stride
        self.C_in = C_in
        self.C_out = C_out
        self.relu = nn.ReLU(inplace=False)
        if stride == 2:
            # Split the channels as evenly as possible; the second conv takes
            # the remainder when C_out is odd.
            first_half = C_out // 2
            widths = [first_half, C_out - first_half]
            self.convs = nn.ModuleList(
                nn.Conv2d(C_in, w, 1, stride=stride, padding=0, bias=False) for w in widths
            )
            self.pad = nn.ConstantPad2d((0, 1, 0, 1), 0)
        elif stride == 1:
            self.conv = nn.Conv2d(C_in, C_out, 1, stride=stride, padding=0, bias=False)
        else:
            raise ValueError('Invalid stride : {:}'.format(stride))
        self.bn = nn.BatchNorm2d(C_out, affine=affine, track_running_stats=track_running_stats)

    def forward(self, x):
        if self.stride == 1:
            return self.bn(self.conv(x))
        activated = self.relu(x)
        shifted = self.pad(activated)[:, :, 1:, 1:]
        merged = torch.cat([self.convs[0](activated), self.convs[1](shifted)], dim=1)
        return self.bn(merged)

    def extra_repr(self):
        return 'C_in={C_in}, C_out={C_out}, stride={stride}'.format(**self.__dict__)
|  |  | ||||||
|  |  | ||||||
class Zero(nn.Module):
    """The 'none' operation: emits an all-zero tensor.

    The output has C_out channels and spatial dimensions subsampled by
    `stride`, matching what the other strided ops in this file produce.
    """

    def __init__(self, C_in, C_out, stride):
        super(Zero, self).__init__()
        self.C_in   = C_in
        self.C_out  = C_out
        self.stride = stride
        self.is_zero = True  # lets callers detect the zero op without isinstance checks

    def forward(self, x):
        if self.C_in == self.C_out:
            if self.stride == 1: return x.mul(0.)  # keeps dtype/device and autograd graph
            else               : return x[:,:,::self.stride,::self.stride].mul(0.)
        else:
            shape = list(x.shape)
            shape[1] = self.C_out
            # Bug fix: stride was previously ignored in this channel-changing
            # branch, yielding a spatial size inconsistent with the branch
            # above. Ceil division mirrors len(range(0, H, stride)).
            if self.stride != 1:
                shape[2] = (shape[2] + self.stride - 1) // self.stride
                shape[3] = (shape[3] + self.stride - 1) // self.stride
            # new_zeros already inherits x's dtype and device.
            return x.new_zeros(shape)

    def extra_repr(self):
        return 'C_in={C_in}, C_out={C_out}, stride={stride}'.format(**self.__dict__)
|  |  | ||||||
|  |  | ||||||
class POOLING(nn.Module):
    """3x3 average or max pooling, optionally preceded by a 1x1 ReLUConvBN.

    The 1x1 preprocess is only created when C_in != C_out, since pooling
    itself cannot change the channel count. Average pooling excludes the
    zero padding from its denominator (count_include_pad=False).
    """

    def __init__(self, C_in, C_out, stride, mode, affine=True, track_running_stats=True):
        super(POOLING, self).__init__()
        if C_in == C_out:
            self.preprocess = None
        else:
            self.preprocess = ReLUConvBN(C_in, C_out, 1, 1, 0, 1, affine, track_running_stats)
        if mode == 'avg':
            self.op = nn.AvgPool2d(3, stride=stride, padding=1, count_include_pad=False)
        elif mode == 'max':
            self.op = nn.MaxPool2d(3, stride=stride, padding=1)
        else:
            raise ValueError('Invalid mode={:} in POOLING'.format(mode))

    def forward(self, inputs):
        x = inputs if self.preprocess is None else self.preprocess(inputs)
        return self.op(x)
|  |  | ||||||
							
								
								
									
										0
									
								
								src/utils/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										0
									
								
								src/utils/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
								
								
									
										38
									
								
								src/utils/utilities.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										38
									
								
								src/utils/utilities.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,38 @@ | |||||||
|  | import numpy as np | ||||||
|  | import torch | ||||||
|  | import torch.nn as nn | ||||||
|  |  | ||||||
class Model(object):
    """Mutable record pairing an architecture with its genotype and score.

    All three fields start unset and are filled in as the search proceeds.
    """

    def __init__(self):
        self.arch = self.geno = self.score = None
|  |  | ||||||
def count_parameters(model):
  """Return the parameter count of `model` in thousands (K).

  Parameters whose qualified names contain "auxiliary" (e.g. DARTS
  auxiliary heads) are excluded from the count.

  Fix: the original passed a generator to ``np.sum``, which NumPy has
  deprecated and may reject; the builtin ``sum`` over ``Tensor.numel()``
  is the supported, dependency-free equivalent.
  """
  total = sum(v.numel() for name, v in model.named_parameters() if "auxiliary" not in name)
  return total / 1e3
|  |  | ||||||
|  |  | ||||||
def network_weight_gaussian_init(net: nn.Module):
    """Re-initialize `net` in place and return it.

    Conv2d and Linear weights are drawn from a standard normal; their biases
    (when present) are zeroed. BatchNorm2d/GroupNorm weights are set to one
    and biases to zero. All other module types are left untouched.
    """
    with torch.no_grad():
        for module in net.modules():
            if isinstance(module, (nn.Conv2d, nn.Linear)):
                nn.init.normal_(module.weight)
                if getattr(module, 'bias', None) is not None:
                    nn.init.zeros_(module.bias)
            elif isinstance(module, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.ones_(module.weight)
                nn.init.zeros_(module.bias)
    return net
|  |  | ||||||
|  |  | ||||||
# Module is import-only: no standalone behavior when run as a script.
if __name__ == '__main__':
    pass
|  |  | ||||||
|  |  | ||||||
		Reference in New Issue
	
	Block a user