upload
zero-cost-nas/foresight/__init__.py (new file, 16 lines)
@@ -0,0 +1,16 @@
# Copyright 2021 Samsung Electronics Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================

from .version import *
zero-cost-nas/foresight/dataset.py (new file, 121 lines)
@@ -0,0 +1,121 @@

# Copyright 2021 Samsung Electronics Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================

import os

from torch.utils.data import DataLoader
from torchvision.datasets import MNIST, CIFAR10, CIFAR100, SVHN
from torchvision.transforms import Compose, ToTensor, Normalize
from torchvision import transforms

from .imagenet16 import *


def get_cifar_dataloaders(train_batch_size, test_batch_size, dataset, num_workers, resize=None, datadir='_dataset'):

    if 'ImageNet16' in dataset:
        mean = [x / 255 for x in [122.68, 116.66, 104.01]]
        std = [x / 255 for x in [63.22, 61.26, 65.09]]
        size, pad = 16, 2
    elif 'cifar' in dataset:
        mean = (0.4914, 0.4822, 0.4465)
        std = (0.2023, 0.1994, 0.2010)
        size, pad = 32, 4
    elif 'svhn' in dataset:
        mean = (0.5, 0.5, 0.5)
        std = (0.5, 0.5, 0.5)
        size, pad = 32, 0
    elif dataset == 'ImageNet1k':
        from .h5py_dataset import H5Dataset
        size, pad = 224, 2
        mean = (0.485, 0.456, 0.406)
        std = (0.229, 0.224, 0.225)
        #resize = 256
    else:
        # fail fast instead of hitting a NameError on `mean` below
        raise ValueError(f'Unknown dataset: {dataset}')

    if resize is None:
        resize = size

    train_transform = transforms.Compose([
        transforms.RandomCrop(size, padding=pad),
        transforms.Resize(resize),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])

    test_transform = transforms.Compose([
        transforms.Resize(resize),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])

    if dataset == 'cifar10':
        train_dataset = CIFAR10(datadir, True, train_transform, download=True)
        test_dataset = CIFAR10(datadir, False, test_transform, download=True)
    elif dataset == 'cifar100':
        train_dataset = CIFAR100(datadir, True, train_transform, download=True)
        test_dataset = CIFAR100(datadir, False, test_transform, download=True)
    elif dataset == 'svhn':
        train_dataset = SVHN(datadir, split='train', transform=train_transform, download=True)
        test_dataset = SVHN(datadir, split='test', transform=test_transform, download=True)
    elif dataset == 'ImageNet16-120':
        train_dataset = ImageNet16(os.path.join(datadir, 'ImageNet16'), True, train_transform, 120)
        test_dataset = ImageNet16(os.path.join(datadir, 'ImageNet16'), False, test_transform, 120)
    elif dataset == 'ImageNet1k':
        train_dataset = H5Dataset(os.path.join(datadir, 'imagenet-train-256.h5'), transform=train_transform)
        test_dataset = H5Dataset(os.path.join(datadir, 'imagenet-val-256.h5'), transform=test_transform)
    else:
        raise ValueError(f'Unsupported dataset: {dataset}')

    train_loader = DataLoader(
        train_dataset,
        train_batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=True)
    test_loader = DataLoader(
        test_dataset,
        test_batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True)

    return train_loader, test_loader


def get_mnist_dataloaders(train_batch_size, val_batch_size, num_workers):

    data_transform = Compose([transforms.ToTensor()])

    # Optional normalization: transforms.Normalize((0.1307,), (0.3081,))

    train_dataset = MNIST("_dataset", True, data_transform, download=True)
    test_dataset = MNIST("_dataset", False, data_transform, download=True)

    train_loader = DataLoader(
        train_dataset,
        train_batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=True)
    test_loader = DataLoader(
        test_dataset,
        val_batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True)

    return train_loader, test_loader
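For reference, a minimal sketch of driving these loaders (assumes torchvision can download CIFAR-10 into the default `_dataset` directory):

    from foresight.dataset import get_cifar_dataloaders

    # 64-image batches for train and test, 2 loader worker processes
    train_loader, test_loader = get_cifar_dataloaders(64, 64, 'cifar10', 2)
    images, labels = next(iter(train_loader))
    print(images.shape)  # torch.Size([64, 3, 32, 32])
    print(labels.shape)  # torch.Size([64])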
zero-cost-nas/foresight/h5py_dataset.py (new file, 55 lines)
@@ -0,0 +1,55 @@
# Copyright 2021 Samsung Electronics Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================

import h5py
import numpy as np
from PIL import Image

import torch
from torch.utils.data import Dataset


class H5Dataset(Dataset):
    def __init__(self, h5_path, transform=None):
        self.h5_path = h5_path
        self.h5_file = None
        # open briefly just to count the records, then close the handle
        with h5py.File(h5_path, 'r') as f:
            self.length = len(f)
        self.transform = transform

    def __getitem__(self, index):

        # Opening the file lazily here (rather than in __init__) allows multiple
        # processes to be used for data loading: open hdf5 handles aren't
        # picklable, so they can't be transferred across processes.
        # https://discuss.pytorch.org/t/hdf5-a-data-format-for-pytorch/40379
        # https://discuss.pytorch.org/t/dataloader-when-num-worker-0-there-is-bug/25643/16
        # TODO possibly look at __getstate__ and __setstate__ as a more elegant solution
        if self.h5_file is None:
            self.h5_file = h5py.File(self.h5_path, 'r')

        record = self.h5_file[str(index)]

        if self.transform:
            x = Image.fromarray(record['data'][()])
            x = self.transform(x)
        else:
            x = torch.from_numpy(record['data'][()])

        y = record['target'][()]
        y = torch.from_numpy(np.asarray(y))

        return (x, y)

    def __len__(self):
        return self.length
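Because the HDF5 handle is opened lazily inside __getitem__, each DataLoader worker opens its own file, which is what makes num_workers > 0 safe here. A minimal sketch, assuming an HDF5 file laid out as this class expects (one group per string index, each holding 'data' and 'target'; the path is illustrative):

    from torch.utils.data import DataLoader
    from foresight.h5py_dataset import H5Dataset

    ds = H5Dataset('_dataset/imagenet-val-256.h5')  # hypothetical path
    loader = DataLoader(ds, batch_size=32, num_workers=4)
    x, y = next(iter(loader))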
zero-cost-nas/foresight/imagenet16.py (new file, 129 lines)
@@ -0,0 +1,129 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
import os, sys, hashlib, torch
import numpy as np
from PIL import Image
import torch.utils.data as data
if sys.version_info[0] == 2:
  import cPickle as pickle
else:
  import pickle


def calculate_md5(fpath, chunk_size=1024 * 1024):
  # hash the file in 1 MiB chunks to bound memory use
  md5 = hashlib.md5()
  with open(fpath, 'rb') as f:
    for chunk in iter(lambda: f.read(chunk_size), b''):
      md5.update(chunk)
  return md5.hexdigest()


def check_md5(fpath, md5, **kwargs):
  return md5 == calculate_md5(fpath, **kwargs)


def check_integrity(fpath, md5=None):
  if not os.path.isfile(fpath): return False
  if md5 is None: return True
  return check_md5(fpath, md5)


class ImageNet16(data.Dataset):
  # http://image-net.org/download-images
  # A Downsampled Variant of ImageNet as an Alternative to the CIFAR datasets
  # https://arxiv.org/pdf/1707.08819.pdf

  train_list = [
        ['train_data_batch_1', '27846dcaa50de8e21a7d1a35f30f0e91'],
        ['train_data_batch_2', 'c7254a054e0e795c69120a5727050e3f'],
        ['train_data_batch_3', '4333d3df2e5ffb114b05d2ffc19b1e87'],
        ['train_data_batch_4', '1620cdf193304f4a92677b695d70d10f'],
        ['train_data_batch_5', '348b3c2fdbb3940c4e9e834affd3b18d'],
        ['train_data_batch_6', '6e765307c242a1b3d7d5ef9139b48945'],
        ['train_data_batch_7', '564926d8cbf8fc4818ba23d2faac7564'],
        ['train_data_batch_8', 'f4755871f718ccb653440b9dd0ebac66'],
        ['train_data_batch_9', 'bb6dd660c38c58552125b1a92f86b5d4'],
        ['train_data_batch_10','8f03f34ac4b42271a294f91bf480f29b'],
    ]
  valid_list = [
        ['val_data', '3410e3017fdaefba8d5073aaa65e4bd6'],
    ]

  def __init__(self, root, train, transform, use_num_of_class_only=None):
    self.root      = root
    self.transform = transform
    self.train     = train  # training set or valid set
    if not self._check_integrity(): raise RuntimeError('Dataset not found or corrupted.')

    if self.train: downloaded_list = self.train_list
    else         : downloaded_list = self.valid_list
    self.data    = []
    self.targets = []

    # now load the pickled numpy arrays
    for i, (file_name, checksum) in enumerate(downloaded_list):
      file_path = os.path.join(self.root, file_name)
      #print ('Load {:}/{:02d}-th : {:}'.format(i, len(downloaded_list), file_path))
      with open(file_path, 'rb') as f:
        if sys.version_info[0] == 2:
          entry = pickle.load(f)
        else:
          entry = pickle.load(f, encoding='latin1')
        self.data.append(entry['data'])
        self.targets.extend(entry['labels'])
    self.data = np.vstack(self.data).reshape(-1, 3, 16, 16)
    self.data = self.data.transpose((0, 2, 3, 1))  # convert to HWC
    if use_num_of_class_only is not None:
      assert isinstance(use_num_of_class_only, int) and use_num_of_class_only > 0 and use_num_of_class_only < 1000, 'invalid use_num_of_class_only : {:}'.format(use_num_of_class_only)
      new_data, new_targets = [], []
      for I, L in zip(self.data, self.targets):
        if 1 <= L <= use_num_of_class_only:
          new_data.append( I )
          new_targets.append( L )
      self.data    = new_data
      self.targets = new_targets
    #    self.mean.append(entry['mean'])
    #self.mean = np.vstack(self.mean).reshape(-1, 3, 16, 16)
    #self.mean = np.mean(np.mean(np.mean(self.mean, axis=0), axis=1), axis=1)
    #print ('Mean : {:}'.format(self.mean))
    #temp      = self.data - np.reshape(self.mean, (1, 1, 1, 3))
    #std_data  = np.std(temp, axis=0)
    #std_data  = np.mean(np.mean(std_data, axis=0), axis=0)
    #print ('Std  : {:}'.format(std_data))

  def __getitem__(self, index):
    # labels on disk are 1-indexed, so shift to 0-indexed targets
    img, target = self.data[index], self.targets[index] - 1

    img = Image.fromarray(img)

    if self.transform is not None:
      img = self.transform(img)

    return img, target

  def __len__(self):
    return len(self.data)

  def _check_integrity(self):
    root = self.root
    for fentry in (self.train_list + self.valid_list):
      filename, md5 = fentry[0], fentry[1]
      fpath = os.path.join(root, filename)
      if not check_integrity(fpath, md5):
        return False
    return True


if __name__ == '__main__':
  train = ImageNet16('/data02/dongxuanyi/.torch/cifar.python/ImageNet16', True , None)
  valid = ImageNet16('/data02/dongxuanyi/.torch/cifar.python/ImageNet16', False, None)

  print(len(train))
  print(len(valid))
  image, label = train[111]
  trainX = ImageNet16('/data02/dongxuanyi/.torch/cifar.python/ImageNet16', True , None, 200)
  validX = ImageNet16('/data02/dongxuanyi/.torch/cifar.python/ImageNet16', False, None, 200)
  print(len(trainX))
  print(len(validX))
  #import pdb; pdb.set_trace()
zero-cost-nas/foresight/models/__init__.py (new file, 19 lines)
@@ -0,0 +1,19 @@
# Copyright 2021 Samsung Electronics Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================

from os.path import dirname, basename, isfile, join
import glob

modules = glob.glob(join(dirname(__file__), "*.py"))
__all__ = [basename(f)[:-3] for f in modules if isfile(f) and not f.endswith('__init__.py')]
zero-cost-nas/foresight/models/nasbench1.py (new file, 251 lines)
@@ -0,0 +1,251 @@
# Copyright 2021 Samsung Electronics Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================

"""Builds the PyTorch computational graph.

Tensors flowing into a single vertex are added together for all vertices
except the output, which is concatenated instead. Tensors flowing out of the
input are always projected instead.

If interior edge channels don't match, drop the extra channels (channels are
guaranteed non-decreasing).
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import math

from .nasbench1_ops import *

import torch
import torch.nn as nn
import torch.nn.functional as F


class Network(nn.Module):
    def __init__(self, spec, stem_out, num_stacks, num_mods, num_classes, bn=True):
        super(Network, self).__init__()

        self.spec = spec
        self.stem_out = stem_out
        self.num_stacks = num_stacks
        self.num_mods = num_mods
        self.num_classes = num_classes

        self.layers = nn.ModuleList([])

        in_channels = 3
        out_channels = stem_out

        # initial stem convolution
        stem_conv = ConvBnRelu(in_channels, out_channels, 3, 1, 1, bn=bn)
        self.layers.append(stem_conv)

        in_channels = out_channels
        for stack_num in range(num_stacks):
            if stack_num > 0:
                downsample = nn.MaxPool2d(kernel_size=2, stride=2)
                self.layers.append(downsample)

                out_channels *= 2

            for _ in range(num_mods):
                cell = Cell(spec, in_channels, out_channels, bn=bn)
                self.layers.append(cell)
                in_channels = out_channels

        self.classifier = nn.Linear(out_channels, num_classes)

        self._initialize_weights()

    def forward(self, x):
        for _, layer in enumerate(self.layers):
            x = layer(x)
        out = torch.mean(x, (2, 3))
        out = self.classifier(out)

        return out

    def get_prunable_copy(self, bn=False):

        model_new = Network(self.spec, self.stem_out, self.num_stacks, self.num_mods, self.num_classes, bn=bn)

        # TODO: this is quite brittle and doesn't work with nn.Sequential when bn
        # differs; it is only required to maintain initialization -- maybe init
        # after get_prunable_copy?
        model_new.load_state_dict(self.state_dict(), strict=False)
        model_new.train()

        return model_new

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2.0 / n))
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                n = m.weight.size(1)
                m.weight.data.normal_(0, 0.01)
                m.bias.data.zero_()


class Cell(nn.Module):
    """
    Builds the module using the specified adjacency matrix and op labels.
    `out_channels` controls the module output channel count, while the interior
    channels are determined by equally splitting the channel count whenever
    there is a concatenation of tensors.
    """
    def __init__(self, spec, in_channels, out_channels, bn=True):
        super(Cell, self).__init__()

        self.spec = spec
        self.num_vertices = np.shape(self.spec.matrix)[0]

        # vertex_channels[i] = number of output channels of vertex i
        self.vertex_channels = ComputeVertexChannels(in_channels, out_channels, self.spec.matrix)
        #self.vertex_channels = [in_channels] + [out_channels] * (self.num_vertices - 1)

        # operation for each node
        self.vertex_op = nn.ModuleList([None])
        for t in range(1, self.num_vertices - 1):
            op = OP_MAP[spec.ops[t]](self.vertex_channels[t], self.vertex_channels[t], bn=bn)
            self.vertex_op.append(op)

        # operation for input on each vertex
        self.input_op = nn.ModuleList([None])
        for t in range(1, self.num_vertices):
            if self.spec.matrix[0, t]:
                self.input_op.append(Projection(in_channels, self.vertex_channels[t], bn=bn))
            else:
                self.input_op.append(None)

    def forward(self, x):
        tensors = [x]

        out_concat = []
        for t in range(1, self.num_vertices - 1):
            fan_in = [Truncate(tensors[src], self.vertex_channels[t]) for src in range(1, t) if self.spec.matrix[src, t]]

            if self.spec.matrix[0, t]:
                fan_in.append(self.input_op[t](x))

            # perform operation on node
            #vertex_input = torch.stack(fan_in, dim=0).sum(dim=0)
            vertex_input = sum(fan_in)
            #vertex_input = sum(fan_in) / len(fan_in)
            vertex_output = self.vertex_op[t](vertex_input)

            tensors.append(vertex_output)
            if self.spec.matrix[t, self.num_vertices - 1]:
                out_concat.append(tensors[t])

        if not out_concat:
            assert self.spec.matrix[0, self.num_vertices - 1]
            outputs = self.input_op[self.num_vertices - 1](tensors[0])
        else:
            if len(out_concat) == 1:
                outputs = out_concat[0]
            else:
                outputs = torch.cat(out_concat, 1)

            if self.spec.matrix[0, self.num_vertices - 1]:
                outputs += self.input_op[self.num_vertices - 1](tensors[0])

            #if self.spec.matrix[0, self.num_vertices-1]:
            #    out_concat.append(self.input_op[self.num_vertices-1](tensors[0]))
            #outputs = sum(out_concat) / len(out_concat)

        return outputs


def Projection(in_channels, out_channels, bn=True):
    """1x1 projection (as in ResNet) followed by batch normalization and ReLU."""
    return ConvBnRelu(in_channels, out_channels, 1, bn=bn)


def Truncate(inputs, channels):
    """Slice the inputs to channels if necessary."""
    input_channels = inputs.size()[1]
    if input_channels < channels:
        raise ValueError('input channel < output channels for truncate')
    elif input_channels == channels:
        return inputs   # No truncation necessary
    else:
        # Truncation should only be necessary when channel division leads to
        # vertices with +1 channels. The input vertex should always be projected
        # to the minimum channel count.
        assert input_channels - channels == 1
        return inputs[:, :channels, :, :]


def ComputeVertexChannels(in_channels, out_channels, matrix):
    """Computes the number of channels at every vertex.

    Given the input channels and output channels, this calculates the number of
    channels at each interior vertex. Interior vertices have the same number of
    channels as the max of the channels of the vertices they feed into. The
    output channels are divided amongst the vertices that are directly
    connected to it. When the division is not even, some vertices may receive
    an extra channel to compensate.

    Returns:
        list of channel counts, in order of the vertices.
    """
    num_vertices = np.shape(matrix)[0]

    vertex_channels = [0] * num_vertices
    vertex_channels[0] = in_channels
    vertex_channels[num_vertices - 1] = out_channels

    if num_vertices == 2:
        # Edge case where module only has input and output vertices
        return vertex_channels

    # Compute the in-degree ignoring input; axis 0 is the src vertex and axis 1
    # is the dst vertex. Summing over axis 0 gives the in-degree count of each
    # vertex.
    in_degree = np.sum(matrix[1:], axis=0)
    interior_channels = out_channels // in_degree[num_vertices - 1]
    correction = out_channels % in_degree[num_vertices - 1]  # Remainder to add

    # Set channels of vertices that flow directly to output
    for v in range(1, num_vertices - 1):
        if matrix[v, num_vertices - 1]:
            vertex_channels[v] = interior_channels
            if correction:
                vertex_channels[v] += 1
                correction -= 1

    # Set channels for all other vertices to the max of the out edges, going
    # backwards. (num_vertices - 2) index skipped because it only connects to
    # output.
    for v in range(num_vertices - 3, 0, -1):
        if not matrix[v, num_vertices - 1]:
            for dst in range(v + 1, num_vertices - 1):
                if matrix[v, dst]:
                    vertex_channels[v] = max(vertex_channels[v], vertex_channels[dst])
        assert vertex_channels[v] > 0

    # Sanity check: verify that channels never increase and final channels add up.
    final_fan_in = 0
    for v in range(1, num_vertices - 1):
        if matrix[v, num_vertices - 1]:
            final_fan_in += vertex_channels[v]
        for dst in range(v + 1, num_vertices - 1):
            if matrix[v, dst]:
                assert vertex_channels[v] >= vertex_channels[dst]
    assert final_fan_in == out_channels or num_vertices == 2
    # num_vertices == 2 means only input/output nodes, so 0 fan-in

    return vertex_channels
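A sketch of how a spec plugs into Network. The 7-vertex matrix/ops format follows NAS-Bench-101; stem_out=128, num_stacks=3, num_mods=3 are the usual NAS-Bench-101 settings and the import paths assume the package layout above:

    import torch
    from foresight.models.nasbench1 import Network
    from foresight.models.nasbench1_spec import ModelSpec

    matrix = [[0, 1, 1, 0, 0, 0, 0],   # input -> v1, v2
              [0, 0, 0, 1, 0, 0, 0],   # v1 -> v3
              [0, 0, 0, 0, 1, 0, 0],   # v2 -> v4
              [0, 0, 0, 0, 0, 1, 0],   # v3 -> v5
              [0, 0, 0, 0, 0, 0, 1],   # v4 -> output
              [0, 0, 0, 0, 0, 0, 1],   # v5 -> output
              [0, 0, 0, 0, 0, 0, 0]]
    ops = ['input', 'conv3x3-bn-relu', 'conv1x1-bn-relu', 'conv3x3-bn-relu',
           'maxpool3x3', 'conv3x3-bn-relu', 'output']
    net = Network(ModelSpec(matrix, ops), stem_out=128, num_stacks=3, num_mods=3, num_classes=10)
    print(net(torch.randn(2, 3, 32, 32)).shape)  # torch.Size([2, 10])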
zero-cost-nas/foresight/models/nasbench1_ops.py (new file, 83 lines)
@@ -0,0 +1,83 @@
# Copyright 2021 Samsung Electronics Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================

"""Base operations used by the modules in this search space."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import torch
import torch.nn as nn
import torch.nn.functional as F


class ConvBnRelu(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=0, bn=True):
        super(ConvBnRelu, self).__init__()

        if bn:
            self.conv_bn_relu = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=False),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=False)
            )
        else:
            self.conv_bn_relu = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=False),
                nn.ReLU(inplace=False)
            )

    def forward(self, x):
        return self.conv_bn_relu(x)


class Conv3x3BnRelu(nn.Module):
    """3x3 convolution with batch norm and ReLU activation."""
    def __init__(self, in_channels, out_channels, bn=True):
        super(Conv3x3BnRelu, self).__init__()

        self.conv3x3 = ConvBnRelu(in_channels, out_channels, 3, 1, 1, bn=bn)

    def forward(self, x):
        x = self.conv3x3(x)
        return x


class Conv1x1BnRelu(nn.Module):
    """1x1 convolution with batch norm and ReLU activation."""
    def __init__(self, in_channels, out_channels, bn=True):
        super(Conv1x1BnRelu, self).__init__()

        self.conv1x1 = ConvBnRelu(in_channels, out_channels, 1, 1, 0, bn=bn)

    def forward(self, x):
        x = self.conv1x1(x)
        return x


class MaxPool3x3(nn.Module):
    """3x3 max pool with no subsampling."""
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, bn=None):
        super(MaxPool3x3, self).__init__()

        self.maxpool = nn.MaxPool2d(kernel_size, stride, padding)

    def forward(self, x):
        x = self.maxpool(x)
        return x


# Commas should not be used in op names
OP_MAP = {
    'conv3x3-bn-relu': Conv3x3BnRelu,
    'conv1x1-bn-relu': Conv1x1BnRelu,
    'maxpool3x3': MaxPool3x3
}
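Every entry in OP_MAP shares the (in_channels, out_channels, bn) constructor signature, so a cell can instantiate ops uniformly by label; a small sketch:

    import torch
    from foresight.models.nasbench1_ops import OP_MAP

    op = OP_MAP['conv3x3-bn-relu'](16, 32, bn=True)  # 3x3 conv, stride 1, padding 1
    print(op(torch.randn(1, 16, 8, 8)).shape)  # torch.Size([1, 32, 8, 8])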
zero-cost-nas/foresight/models/nasbench1_spec.py (new file, 294 lines)
@@ -0,0 +1,294 @@
# Copyright 2021 Samsung Electronics Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================

"""Model specification for module connectivity individuals.

This module handles pruning the unused parts of the computation graph but
should avoid creating any TensorFlow models (that is done inside
model_builder.py).
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import copy
import hashlib
import itertools
import numpy as np


# Graphviz is optional and only required for visualization.
try:
  import graphviz   # pylint: disable=g-import-not-at-top
except ImportError:
  pass


def _ToModelSpec(mat, ops):
    return ModelSpec(mat, ops)


def gen_is_edge_fn(bits):
  """Generate a boolean function for the edge connectivity.

  Given a bitstring FEDCBA and a 4x4 matrix, the generated matrix is
    [[0, A, B, D],
     [0, 0, C, E],
     [0, 0, 0, F],
     [0, 0, 0, 0]]

  Note that this function is agnostic to the actual matrix dimension due to the
  order in which elements are filled out (column-major, starting from the least
  significant bit). For example, the same FEDCBA bitstring (0-padded) on a 5x5
  matrix is
    [[0, A, B, D, 0],
     [0, 0, C, E, 0],
     [0, 0, 0, F, 0],
     [0, 0, 0, 0, 0],
     [0, 0, 0, 0, 0]]

  Args:
    bits: integer which will be interpreted as a bit mask.

  Returns:
    vectorized function that returns True when an edge is present.
  """
  def is_edge(x, y):
    """Is there an edge from x to y (0-indexed)?"""
    if x >= y:
      return 0
    # Map x, y to index into bit string
    index = x + (y * (y - 1) // 2)
    return (bits >> index) % 2 == 1

  return np.vectorize(is_edge)


def is_full_dag(matrix):
  """Full DAG == all vertices on a path from vertex 0 to (V-1).

  i.e. no disconnected or "hanging" vertices.

  It is sufficient to check for:
    1) no rows of 0 except for row V-1 (only the output vertex has no out-edges)
    2) no cols of 0 except for col 0 (only the input vertex has no in-edges)

  Args:
    matrix: V x V upper-triangular adjacency matrix

  Returns:
    True if there are no dangling vertices.
  """
  shape = np.shape(matrix)

  rows = matrix[:shape[0]-1, :] == 0
  rows = np.all(rows, axis=1)     # Any row with all 0 will be True
  rows_bad = np.any(rows)

  cols = matrix[:, 1:] == 0
  cols = np.all(cols, axis=0)     # Any col with all 0 will be True
  cols_bad = np.any(cols)

  return (not rows_bad) and (not cols_bad)


def num_edges(matrix):
  """Computes the number of edges in the adjacency matrix."""
  return np.sum(matrix)


def hash_module(matrix, labeling):
  """Computes a graph-invariant MD5 hash of the matrix and label pair.

  Args:
    matrix: np.ndarray square upper-triangular adjacency matrix.
    labeling: list of int labels of length equal to both dimensions of
      matrix.

  Returns:
    MD5 hash of the matrix and labeling.
  """
  vertices = np.shape(matrix)[0]
  in_edges = np.sum(matrix, axis=0).tolist()
  out_edges = np.sum(matrix, axis=1).tolist()

  assert len(in_edges) == len(out_edges) == len(labeling)
  hashes = list(zip(out_edges, in_edges, labeling))
  hashes = [hashlib.md5(str(h).encode('utf-8')).hexdigest() for h in hashes]
  # Computing this up to the diameter is probably sufficient but since the
  # operation is fast, it is okay to repeat more times.
  for _ in range(vertices):
    new_hashes = []
    for v in range(vertices):
      in_neighbors = [hashes[w] for w in range(vertices) if matrix[w, v]]
      out_neighbors = [hashes[w] for w in range(vertices) if matrix[v, w]]
      new_hashes.append(hashlib.md5(
          (''.join(sorted(in_neighbors)) + '|' +
           ''.join(sorted(out_neighbors)) + '|' +
           hashes[v]).encode('utf-8')).hexdigest())
    hashes = new_hashes
  fingerprint = hashlib.md5(str(sorted(hashes)).encode('utf-8')).hexdigest()

  return fingerprint


def permute_graph(graph, label, permutation):
  """Permutes the graph and labels based on permutation.

  Args:
    graph: np.ndarray adjacency matrix.
    label: list of labels of same length as graph dimensions.
    permutation: a permutation list of ints of same length as graph dimensions.

  Returns:
    np.ndarray where vertex permutation[v] is vertex v from the original graph
  """
  # vertex permutation[v] in the new graph is vertex v in the old graph
  forward_perm = zip(permutation, list(range(len(permutation))))
  inverse_perm = [x[1] for x in sorted(forward_perm)]
  edge_fn = lambda x, y: graph[inverse_perm[x], inverse_perm[y]] == 1
  new_matrix = np.fromfunction(np.vectorize(edge_fn),
                               (len(label), len(label)),
                               dtype=np.int8)
  new_label = [label[inverse_perm[i]] for i in range(len(label))]
  return new_matrix, new_label


def is_isomorphic(graph1, graph2):
  """Exhaustively checks if 2 graphs are isomorphic."""
  matrix1, label1 = np.array(graph1[0]), graph1[1]
  matrix2, label2 = np.array(graph2[0]), graph2[1]
  assert np.shape(matrix1) == np.shape(matrix2)
  assert len(label1) == len(label2)

  vertices = np.shape(matrix1)[0]
  # Note: input and output in our constrained graphs always map to themselves
  # but this script does not enforce that.
  for perm in itertools.permutations(range(0, vertices)):
    pmatrix1, plabel1 = permute_graph(matrix1, label1, perm)
    if np.array_equal(pmatrix1, matrix2) and plabel1 == label2:
      return True

  return False


class ModelSpec(object):
  """Model specification given adjacency matrix and labeling."""

  def __init__(self, matrix, ops, data_format='channels_last'):
    """Initialize the module spec.

    Args:
      matrix: ndarray or nested list with shape [V, V] for the adjacency matrix.
      ops: V-length list of labels for the base ops used. The first and last
        elements are ignored because they are the input and output vertices
        which have no operations. The elements are retained to keep consistent
        indexing.
      data_format: channels_last or channels_first.

    Raises:
      ValueError: invalid matrix or ops
    """
    if not isinstance(matrix, np.ndarray):
      matrix = np.array(matrix)
    shape = np.shape(matrix)
    if len(shape) != 2 or shape[0] != shape[1]:
      raise ValueError('matrix must be square')
    if shape[0] != len(ops):
      raise ValueError('length of ops must match matrix dimensions')
    if not is_upper_triangular(matrix):
      raise ValueError('matrix must be upper triangular')

    # Both the original and pruned matrices are deep copies of the matrix and
    # ops so any changes to those after initialization are not recognized by
    # the spec.
    self.original_matrix = copy.deepcopy(matrix)
    self.original_ops = copy.deepcopy(ops)

    self.matrix = copy.deepcopy(matrix)
    self.ops = copy.deepcopy(ops)
    self.valid_spec = True
    self._prune()

    self.data_format = data_format

  def _prune(self):
    """Prune the extraneous parts of the graph.

    General procedure:
      1) Remove parts of graph not connected to input.
      2) Remove parts of graph not connected to output.
      3) Reorder the vertices so that they are consecutive after steps 1 and 2.

    These 3 steps can be combined by deleting the rows and columns of the
    vertices that are not reachable from both the input and output (in
    reverse).
    """
    num_vertices = np.shape(self.original_matrix)[0]

    # DFS forward from input
    visited_from_input = set([0])
    frontier = [0]
    while frontier:
      top = frontier.pop()
      for v in range(top + 1, num_vertices):
        if self.original_matrix[top, v] and v not in visited_from_input:
          visited_from_input.add(v)
          frontier.append(v)

    # DFS backward from output
    visited_from_output = set([num_vertices - 1])
    frontier = [num_vertices - 1]
    while frontier:
      top = frontier.pop()
      for v in range(0, top):
        if self.original_matrix[v, top] and v not in visited_from_output:
          visited_from_output.add(v)
          frontier.append(v)

    # Any vertex that isn't connected to both input and output is extraneous to
    # the computation graph.
    extraneous = set(range(num_vertices)).difference(
        visited_from_input.intersection(visited_from_output))

    # If the non-extraneous graph is less than 2 vertices, the input is not
    # connected to the output and the spec is invalid.
    if len(extraneous) > num_vertices - 2:
      self.matrix = None
      self.ops = None
      self.valid_spec = False
      return

    self.matrix = np.delete(self.matrix, list(extraneous), axis=0)
    self.matrix = np.delete(self.matrix, list(extraneous), axis=1)
    for index in sorted(extraneous, reverse=True):
      del self.ops[index]

  def hash_spec(self, canonical_ops):
    """Computes the isomorphism-invariant graph hash of this spec.

    Args:
      canonical_ops: list of operations in the canonical ordering which they
        were assigned (i.e. the order provided in the config['available_ops']).

    Returns:
      MD5 hash of this spec which can be used to query the dataset.
    """
    # Invert the operations back to integer label indices used in graph gen.
    labeling = [-1] + [canonical_ops.index(op) for op in self.ops[1:-1]] + [-2]
    # hash_module is defined at module level above (the upstream NASBench code
    # calls it via graph_util.hash_module, but graph_util is not imported here).
    return hash_module(self.matrix, labeling)

  def visualize(self):
    """Creates a dot graph. Can be visualized in colab directly."""
    num_vertices = np.shape(self.matrix)[0]
    g = graphviz.Digraph()
    g.node(str(0), 'input')
    for v in range(1, num_vertices - 1):
      g.node(str(v), self.ops[v])
    g.node(str(num_vertices - 1), 'output')

    for src in range(num_vertices - 1):
      for dst in range(src + 1, num_vertices):
        if self.matrix[src, dst]:
          g.edge(str(src), str(dst))

    return g


def is_upper_triangular(matrix):
  """True if the matrix is 0 on the diagonal and below."""
  for src in range(np.shape(matrix)[0]):
    for dst in range(0, src + 1):
      if matrix[src, dst] != 0:
        return False

  return True
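A sketch of the pruning behaviour: vertex 2 below has an in-edge but no path to the output, so _prune drops it and the spec shrinks from 4 to 3 vertices (example values are illustrative):

    from foresight.models.nasbench1_spec import ModelSpec

    matrix = [[0, 1, 1, 0],
              [0, 0, 0, 1],
              [0, 0, 0, 0],   # vertex 2 never reaches the output
              [0, 0, 0, 0]]
    ops = ['input', 'conv3x3-bn-relu', 'maxpool3x3', 'output']
    spec = ModelSpec(matrix, ops)
    print(spec.valid_spec)    # True
    print(spec.matrix.shape)  # (3, 3) -- the dangling vertex was pruned
    print(spec.ops)           # ['input', 'conv3x3-bn-relu', 'output']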
zero-cost-nas/foresight/models/nasbench2.py (new file, 121 lines)
@@ -0,0 +1,121 @@
# Copyright 2021 Samsung Electronics Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================

import os
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
from .nasbench2_ops import *


def gen_searchcell_mask_from_arch_str(arch_str):
    nodes = arch_str.split('+')
    nodes = [node[1:-1].split('|') for node in nodes]
    nodes = [[op_and_input.split('~') for op_and_input in node] for node in nodes]

    keep_mask = []
    for curr_node_idx in range(len(nodes)):
        for prev_node_idx in range(curr_node_idx + 1):
            _op = [edge[0] for edge in nodes[curr_node_idx] if int(edge[1]) == prev_node_idx]
            assert len(_op) == 1, 'The arch string does not follow the assumption of 1 connection between two nodes.'
            for _op_name in OPS.keys():
                keep_mask.append(_op[0] == _op_name)
    return keep_mask


def get_model_from_arch_str(arch_str, num_classes, use_bn=True, init_channels=16):
    keep_mask = gen_searchcell_mask_from_arch_str(arch_str)
    net = NAS201Model(arch_str=arch_str, num_classes=num_classes, use_bn=use_bn, keep_mask=keep_mask, stem_ch=init_channels)
    return net


def get_super_model(num_classes, use_bn=True):
    # keep_mask=None keeps every candidate op in each search cell, i.e. the supernet
    net = NAS201Model(arch_str=None, num_classes=num_classes, use_bn=use_bn)
    return net


class NAS201Model(nn.Module):

    def __init__(self, arch_str, num_classes, use_bn=True, keep_mask=None, stem_ch=16):
        super(NAS201Model, self).__init__()
        self.arch_str = arch_str
        self.num_classes = num_classes
        self.use_bn = use_bn

        self.stem = stem(out_channels=stem_ch, use_bn=use_bn)
        self.stack_cell1 = nn.Sequential(*[SearchCell(in_channels=stem_ch, out_channels=stem_ch, stride=1, affine=False, track_running_stats=False, use_bn=use_bn, keep_mask=keep_mask) for i in range(5)])
        self.reduction1 = reduction(in_channels=stem_ch, out_channels=stem_ch*2)
        self.stack_cell2 = nn.Sequential(*[SearchCell(in_channels=stem_ch*2, out_channels=stem_ch*2, stride=1, affine=False, track_running_stats=False, use_bn=use_bn, keep_mask=keep_mask) for i in range(5)])
        self.reduction2 = reduction(in_channels=stem_ch*2, out_channels=stem_ch*4)
        self.stack_cell3 = nn.Sequential(*[SearchCell(in_channels=stem_ch*4, out_channels=stem_ch*4, stride=1, affine=False, track_running_stats=False, use_bn=use_bn, keep_mask=keep_mask) for i in range(5)])
        self.top = top(in_dims=stem_ch*4, num_classes=num_classes, use_bn=use_bn)

    def forward(self, x):
        x = self.stem(x)

        x = self.stack_cell1(x)
        x = self.reduction1(x)

        x = self.stack_cell2(x)
        x = self.reduction2(x)

        x = self.stack_cell3(x)

        x = self.top(x)
        return x

    def get_prunable_copy(self, bn=False):
        model_new = get_model_from_arch_str(self.arch_str, self.num_classes, use_bn=bn)

        # TODO: this is quite brittle and doesn't work with nn.Sequential when bn
        # differs; it is only required to maintain initialization -- maybe init
        # after get_prunable_copy?
        model_new.load_state_dict(self.state_dict(), strict=False)
        model_new.train()

        return model_new


def get_arch_str_from_model(net):
    search_cell = net.stack_cell1[0].options
    keep_mask = net.stack_cell1[0].keep_mask
    num_nodes = net.stack_cell1[0].num_nodes

    nodes = []
    idx = 0
    for curr_node in range(num_nodes - 1):
        edges = []
        for prev_node in range(curr_node + 1):  # connections from all previous nodes
            for _op_name in OPS.keys():
                if keep_mask[idx]:
                    edges.append(f'{_op_name}~{prev_node}')
                idx += 1
        node_str = '|'.join(edges)
        node_str = f'|{node_str}|'
        nodes.append(node_str)
    arch_str = '+'.join(nodes)
    return arch_str


if __name__ == "__main__":
    arch_str = '|nor_conv_3x3~0|+|none~0|none~1|+|avg_pool_3x3~0|nor_conv_3x3~1|nor_conv_3x3~2|'

    n = get_model_from_arch_str(arch_str=arch_str, num_classes=10)
    print(n.stack_cell1[0])

    arch_str2 = get_arch_str_from_model(n)
    print(arch_str)
    print(arch_str2)
    print(f'Are the two arch strings the same? {arch_str == arch_str2}')
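get_prunable_copy is the hook the zero-cost metrics consume: it clones the model with BatchNorm disabled while leaving the original weights untouched. A minimal sketch of that flow, reusing the arch string from the demo above:

    import torch
    from foresight.models.nasbench2 import get_model_from_arch_str

    arch_str = '|nor_conv_3x3~0|+|none~0|none~1|+|avg_pool_3x3~0|nor_conv_3x3~1|nor_conv_3x3~2|'
    net = get_model_from_arch_str(arch_str, num_classes=10)
    proxy = net.get_prunable_copy(bn=False)  # same weights, BN layers dropped
    print(proxy(torch.randn(2, 3, 32, 32)).shape)  # torch.Size([2, 10])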
zero-cost-nas/foresight/models/nasbench2_ops.py (new file, 164 lines)
							| @@ -0,0 +1,164 @@ | ||||
| # Copyright 2021 Samsung Electronics Co., Ltd. | ||||
| # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | ||||
| # You may obtain a copy of the License at | ||||
|  | ||||
| #     http://www.apache.org/licenses/LICENSE-2.0 | ||||
|  | ||||
| # Unless required by applicable law or agreed to in writing, software | ||||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
| # ============================================================================= | ||||
|  | ||||
| import os | ||||
| import argparse | ||||
| import torch | ||||
| import torch.nn as nn | ||||
| import torch.nn.functional as F | ||||
|  | ||||
| class ReLUConvBN(nn.Module): | ||||
|  | ||||
|     def __init__(self, in_channels, out_channels, kernel_size, stride, padding, dilation, affine, track_running_stats=True, use_bn=True, name='ReLUConvBN'): | ||||
|         super(ReLUConvBN, self).__init__() | ||||
|         self.name = name | ||||
|         if use_bn: | ||||
|             self.op = nn.Sequential( | ||||
|                 nn.ReLU(inplace=False), | ||||
|                 nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation, bias=not affine), | ||||
|                 nn.BatchNorm2d(out_channels, affine=affine, track_running_stats=track_running_stats) | ||||
|                 ) | ||||
|         else: | ||||
|             self.op = nn.Sequential( | ||||
|                 nn.ReLU(inplace=False), | ||||
|                 nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation, bias=not affine) | ||||
|                 ) | ||||
|  | ||||
|     def forward(self, x): | ||||
|         return self.op(x) | ||||
|  | ||||
| class Identity(nn.Module): | ||||
|     def __init__(self, name='Identity'): | ||||
|         super(Identity, self).__init__() | ||||
|         self.name = name | ||||
|  | ||||
|     def forward(self, x): | ||||
|         return x | ||||
|  | ||||
| class Zero(nn.Module): | ||||
|  | ||||
|     def __init__(self, stride, name='Zero'): | ||||
|         super(Zero, self).__init__() | ||||
|         self.name = name | ||||
|         self.stride = stride | ||||
|  | ||||
|     def forward(self, x): | ||||
|         if self.stride == 1: | ||||
|             return x.mul(0.) | ||||
|         return x[:, :, ::self.stride, ::self.stride].mul(0.) | ||||
|  | ||||
| class POOLING(nn.Module): | ||||
|     def __init__(self, kernel_size, stride, padding, name='POOLING'): | ||||
|         super(POOLING, self).__init__() | ||||
|         self.name = name | ||||
|         # honour the stride/padding arguments instead of silently hardcoding them | ||||
|         # (callers in OPS below pass stride=1, padding=1, so behavior is unchanged) | ||||
|         self.avgpool = nn.AvgPool2d(kernel_size=kernel_size, stride=stride, padding=padding, count_include_pad=False) | ||||
|  | ||||
|     def forward(self, x): | ||||
|         return self.avgpool(x) | ||||
|  | ||||
|  | ||||
| class reduction(nn.Module): | ||||
|     def __init__(self, in_channels, out_channels): | ||||
|         super(reduction, self).__init__() | ||||
|         self.residual = nn.Sequential( | ||||
|                             nn.AvgPool2d(kernel_size=2, stride=2, padding=0), | ||||
|                             nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1, padding=0, bias=False)) | ||||
|  | ||||
|         self.conv_a = ReLUConvBN(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=2, padding=1, dilation=1, affine=True, track_running_stats=True) | ||||
|         self.conv_b = ReLUConvBN(in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1, dilation=1, affine=True, track_running_stats=True) | ||||
|  | ||||
|     def forward(self, x): | ||||
|         basicblock = self.conv_a(x) | ||||
|         basicblock = self.conv_b(basicblock) | ||||
|         residual = self.residual(x) | ||||
|         return residual + basicblock | ||||
|  | ||||
| class stem(nn.Module): | ||||
|     def __init__(self, out_channels, use_bn=True): | ||||
|         super(stem, self).__init__() | ||||
|         if use_bn: | ||||
|             self.net = nn.Sequential( | ||||
|                     nn.Conv2d(in_channels=3, out_channels=out_channels, kernel_size=3, padding=1, bias=False), | ||||
|                     nn.BatchNorm2d(out_channels)) | ||||
|         else: | ||||
|             self.net = nn.Sequential( | ||||
|                     nn.Conv2d(in_channels=3, out_channels=out_channels, kernel_size=3, padding=1, bias=False) | ||||
|             ) | ||||
|  | ||||
|     def forward(self, x): | ||||
|         return self.net(x) | ||||
|  | ||||
| class top(nn.Module): | ||||
|     def __init__(self, in_dims, num_classes, use_bn=True): | ||||
|         super(top, self).__init__() | ||||
|         if use_bn: | ||||
|             self.lastact = nn.Sequential(nn.BatchNorm2d(in_dims), nn.ReLU(inplace=True)) | ||||
|         else: | ||||
|             self.lastact = nn.ReLU(inplace=True) | ||||
|         self.global_pooling = nn.AdaptiveAvgPool2d(1) | ||||
|         self.classifier  = nn.Linear(in_dims, num_classes) | ||||
|  | ||||
|     def forward(self, x): | ||||
|         x = self.lastact(x) | ||||
|         x = self.global_pooling(x) | ||||
|         x = x.view(x.size(0), -1) | ||||
|         logits = self.classifier(x) | ||||
|         return logits | ||||
|  | ||||
|  | ||||
| class SearchCell(nn.Module): | ||||
|  | ||||
|     def __init__(self, in_channels, out_channels, stride, affine, track_running_stats, use_bn=True, num_nodes=4, keep_mask=None): | ||||
|         super(SearchCell, self).__init__() | ||||
|         self.num_nodes = num_nodes | ||||
|         self.options = nn.ModuleList() | ||||
|         for curr_node in range(self.num_nodes-1): | ||||
|             for prev_node in range(curr_node+1):  | ||||
|                 for _op_name in OPS.keys(): | ||||
|                     op = OPS[_op_name](in_channels, out_channels, stride, affine, track_running_stats, use_bn) | ||||
|                     self.options.append(op) | ||||
|  | ||||
|         if keep_mask is not None: | ||||
|             self.keep_mask = keep_mask | ||||
|         else: | ||||
|             self.keep_mask = [True]*len(self.options) | ||||
|  | ||||
|     def forward(self, x): | ||||
|         outs = [x] | ||||
|  | ||||
|         idx = 0 | ||||
|         for curr_node in range(self.num_nodes-1): | ||||
|             edges_in = [] | ||||
|             for prev_node in range(curr_node+1): # one edge from each previous node | ||||
|                 for op_idx in range(len(OPS.keys())): | ||||
|                     if self.keep_mask[idx]: | ||||
|                         edges_in.append(self.options[idx](outs[prev_node])) | ||||
|                     idx += 1 | ||||
|             node_output = sum(edges_in) | ||||
|             outs.append(node_output) | ||||
|          | ||||
|         return outs[-1] | ||||
|  | ||||
|  | ||||
|  | ||||
| OPS = { | ||||
|     'none' : lambda in_channels, out_channels, stride, affine, track_running_stats, use_bn: Zero(stride, name='none'), | ||||
|     'avg_pool_3x3' : lambda in_channels, out_channels, stride, affine, track_running_stats, use_bn: POOLING(3, 1, 1, name='avg_pool_3x3'), | ||||
|     'nor_conv_3x3' : lambda in_channels, out_channels, stride, affine, track_running_stats, use_bn: ReLUConvBN(in_channels, out_channels, 3, 1, 1, 1, affine, track_running_stats, use_bn, name='nor_conv_3x3'), | ||||
|     'nor_conv_1x1' : lambda in_channels, out_channels, stride, affine, track_running_stats, use_bn: ReLUConvBN(in_channels, out_channels, 1, 1, 0, 1, affine, track_running_stats, use_bn, name='nor_conv_1x1'), | ||||
|     'skip_connect' : lambda in_channels, out_channels, stride, affine, track_running_stats, use_bn: Identity(name='skip_connect'), | ||||
| } | ||||
|  | ||||
|  | ||||
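| # ---------------------------------------------------------------------------- | ||||
| # Illustrative usage sketch (an added note, not part of the original file): | ||||
| # channel sizes below are arbitrary; OPS and SearchCell are defined above. | ||||
| if __name__ == "__main__": | ||||
|     op = OPS['nor_conv_3x3'](16, 16, 1, True, True, True) | ||||
|     cell = SearchCell(in_channels=16, out_channels=16, stride=1, affine=True, | ||||
|                       track_running_stats=True) | ||||
|     x = torch.randn(2, 16, 32, 32) | ||||
|     print(op(x).shape)    # torch.Size([2, 16, 32, 32]) | ||||
|     print(cell(x).shape)  # torch.Size([2, 16, 32, 32]) | ||||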
zero-cost-nas/foresight/pruners/__init__.py (19 lines, new file)
| @@ -0,0 +1,19 @@ | ||||
| # Copyright 2021 Samsung Electronics Co., Ltd. | ||||
| # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | ||||
| # You may obtain a copy of the License at | ||||
|  | ||||
| #     http://www.apache.org/licenses/LICENSE-2.0 | ||||
|  | ||||
| # Unless required by applicable law or agreed to in writing, software | ||||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
| # ============================================================================= | ||||
|  | ||||
| from os.path import dirname, basename, isfile, join | ||||
| import glob | ||||
| modules = glob.glob(join(dirname(__file__), "*.py")) | ||||
| __all__ = [basename(f)[:-3] for f in modules if isfile(f) and not f.endswith('__init__.py')] | ||||
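| # Added note: in this package __all__ resolves to the sibling modules | ||||
| # (e.g. ['p_utils', 'predictive'], order unspecified); __init__.py itself | ||||
| # is excluded by the filter above. | ||||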
zero-cost-nas/foresight/pruners/measures/__init__.py (66 lines, new file)
| @@ -0,0 +1,66 @@ | ||||
| # Copyright 2021 Samsung Electronics Co., Ltd. | ||||
| # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | ||||
| # You may obtain a copy of the License at | ||||
|  | ||||
| #     http://www.apache.org/licenses/LICENSE-2.0 | ||||
|  | ||||
| # Unless required by applicable law or agreed to in writing, software | ||||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
| # ============================================================================= | ||||
|  | ||||
|  | ||||
| available_measures = [] | ||||
| _measure_impls = {} | ||||
|  | ||||
|  | ||||
| def measure(name, bn=True, copy_net=True, force_clean=True, **impl_args): | ||||
|     def make_impl(func): | ||||
|         def measure_impl(net_orig, device, *args, **kwargs): | ||||
|             if copy_net: | ||||
|                 net = net_orig.get_prunable_copy(bn=bn).to(device) | ||||
|             else: | ||||
|                 net = net_orig | ||||
|             ret = func(net, *args, **kwargs, **impl_args) | ||||
|             if copy_net and force_clean: | ||||
|                 import gc | ||||
|                 import torch | ||||
|                 del net | ||||
|                 torch.cuda.empty_cache() | ||||
|                 gc.collect() | ||||
|             return ret | ||||
|  | ||||
|         global _measure_impls | ||||
|         if name in _measure_impls: | ||||
|             raise KeyError(f'Duplicated measure! {name}') | ||||
|         available_measures.append(name) | ||||
|         _measure_impls[name] = measure_impl | ||||
|         return func | ||||
|     return make_impl | ||||
|  | ||||
|  | ||||
| def calc_measure(name, net, device, *args, **kwargs): | ||||
|     return _measure_impls[name](net, device, *args, **kwargs) | ||||
|  | ||||
|  | ||||
| def load_all(): | ||||
|     from . import grad_norm | ||||
|     from . import snip | ||||
|     from . import grasp | ||||
|     from . import fisher | ||||
|     from . import jacob_cov | ||||
|     from . import plain | ||||
|     from . import synflow | ||||
|     from . import var | ||||
|     from . import cor | ||||
|     from . import norm | ||||
|     from . import meco | ||||
|     from . import zico | ||||
|  | ||||
|  | ||||
| # TODO: should we do that by default? | ||||
| load_all() | ||||
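|  | ||||
| # Illustrative sketch (an added note): how a new measure would be registered | ||||
| # and invoked. `my_score` is a hypothetical name; the decorator wraps it so it | ||||
| # runs on a pruned copy obtained via net.get_prunable_copy (attached in | ||||
| # predictive.py). Kept as comments so importing this module stays side-effect free. | ||||
| # | ||||
| #   @measure('my_score', bn=True) | ||||
| #   def my_score(net, inputs, targets, loss_fn=None, split_data=1): | ||||
| #       return sum(p.abs().sum().item() for p in net.parameters()) | ||||
| # | ||||
| #   score = calc_measure('my_score', net, device, inputs, targets) | ||||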
zero-cost-nas/foresight/pruners/measures/cor.py (53 lines, new file)
| @@ -0,0 +1,53 @@ | ||||
| # Copyright 2021 Samsung Electronics Co., Ltd. | ||||
| # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | ||||
| # You may obtain a copy of the License at | ||||
|  | ||||
| #     http://www.apache.org/licenses/LICENSE-2.0 | ||||
|  | ||||
| # Unless required by applicable law or agreed to in writing, software | ||||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
| # ============================================================================= | ||||
|  | ||||
| import numpy as np | ||||
| import torch | ||||
|  | ||||
| from . import measure | ||||
|  | ||||
|  | ||||
| def get_score(net, x, target, device, split_data): | ||||
|     result_list = [] | ||||
|     def forward_hook(module, data_input, data_output): | ||||
|         # mean pairwise correlation of the features entering the classifier | ||||
|         corr = np.mean(np.corrcoef(data_input[0].detach().cpu().numpy())) | ||||
|         result_list.append(corr) | ||||
|     hook_handle = net.classifier.register_forward_hook(forward_hook) | ||||
|  | ||||
|     N = x.shape[0] | ||||
|     for sp in range(split_data): | ||||
|         st = sp * N // split_data | ||||
|         en = (sp + 1) * N // split_data | ||||
|         y = net(x[st:en]) | ||||
|     hook_handle.remove()  # avoid stacking hooks across repeated calls | ||||
|     cor = result_list[0].item() | ||||
|     result_list.clear() | ||||
|     return cor | ||||
|  | ||||
|  | ||||
|  | ||||
| @measure('cor', bn=True) | ||||
| def compute_cor(net, inputs, targets, split_data=1, loss_fn=None): | ||||
|     device = inputs.device | ||||
|     # no backward pass is needed; just clear any stale gradients | ||||
|     net.zero_grad() | ||||
|  | ||||
|     try: | ||||
|         cor = get_score(net, inputs, targets, device, split_data=split_data) | ||||
|     except Exception as e: | ||||
|         print(e) | ||||
|         cor = np.nan | ||||
|  | ||||
|     return cor | ||||
zero-cost-nas/foresight/pruners/measures/fisher.py (107 lines, new file)
| @@ -0,0 +1,107 @@ | ||||
| # Copyright 2021 Samsung Electronics Co., Ltd. | ||||
| # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | ||||
| # You may obtain a copy of the License at | ||||
|  | ||||
| #     http://www.apache.org/licenses/LICENSE-2.0 | ||||
|  | ||||
| # Unless required by applicable law or agreed to in writing, software | ||||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
| # ============================================================================= | ||||
|  | ||||
| import torch | ||||
| import torch.nn as nn | ||||
| import torch.nn.functional as F | ||||
|  | ||||
| import types | ||||
|  | ||||
| from . import measure | ||||
| from ..p_utils import get_layer_metric_array, reshape_elements | ||||
|  | ||||
|  | ||||
| def fisher_forward_conv2d(self, x): | ||||
|     x = F.conv2d(x, self.weight, self.bias, self.stride, | ||||
|                     self.padding, self.dilation, self.groups) | ||||
|     #intercept and store the activations after passing through 'hooked' identity op | ||||
|     self.act = self.dummy(x) | ||||
|     return self.act | ||||
|  | ||||
| def fisher_forward_linear(self, x): | ||||
|     x = F.linear(x, self.weight, self.bias) | ||||
|     self.act = self.dummy(x) | ||||
|     return self.act | ||||
|  | ||||
| @measure('fisher', bn=True, mode='channel') | ||||
| def compute_fisher_per_weight(net, inputs, targets, loss_fn, mode, split_data=1): | ||||
|      | ||||
|     device = inputs.device | ||||
|  | ||||
|     if mode == 'param': | ||||
|         raise ValueError('Fisher pruning does not support parameter pruning.') | ||||
|  | ||||
|     net.train() | ||||
|     all_hooks = [] | ||||
|     for layer in net.modules(): | ||||
|         if isinstance(layer, nn.Conv2d) or isinstance(layer, nn.Linear): | ||||
|             #variables/op needed for fisher computation | ||||
|             layer.fisher = None | ||||
|             layer.act = 0. | ||||
|             layer.dummy = nn.Identity() | ||||
|  | ||||
|             #replace forward method of conv/linear | ||||
|             if isinstance(layer, nn.Conv2d): | ||||
|                 layer.forward = types.MethodType(fisher_forward_conv2d, layer) | ||||
|             if isinstance(layer, nn.Linear): | ||||
|                 layer.forward = types.MethodType(fisher_forward_linear, layer) | ||||
|  | ||||
|             #function to call during backward pass (hooked on identity op at output of layer) | ||||
|             def hook_factory(layer): | ||||
|                 def hook(module, grad_input, grad_output): | ||||
|                     act = layer.act.detach() | ||||
|                     grad = grad_output[0].detach() | ||||
|                     if len(act.shape) > 2: | ||||
|                         g_nk = torch.sum((act * grad), list(range(2,len(act.shape)))) | ||||
|                     else: | ||||
|                         g_nk = act * grad | ||||
|                     del_k = g_nk.pow(2).mean(0).mul(0.5) | ||||
|                     if layer.fisher is None: | ||||
|                         layer.fisher = del_k | ||||
|                     else: | ||||
|                         layer.fisher += del_k | ||||
|                     del layer.act #without deleting this, a nasty memory leak occurs! related: https://discuss.pytorch.org/t/memory-leak-when-using-forward-hook-and-backward-hook-simultaneously/27555 | ||||
|                 return hook | ||||
|  | ||||
|             #register backward hook on identity fcn to compute fisher info | ||||
|             layer.dummy.register_backward_hook(hook_factory(layer)) | ||||
|  | ||||
|     N = inputs.shape[0] | ||||
|     for sp in range(split_data): | ||||
|         st=sp*N//split_data | ||||
|         en=(sp+1)*N//split_data | ||||
|  | ||||
|         net.zero_grad() | ||||
|         outputs = net(inputs[st:en]) | ||||
|         loss = loss_fn(outputs, targets[st:en]) | ||||
|         loss.backward() | ||||
|  | ||||
|     # retrieve fisher info | ||||
|     def fisher(layer): | ||||
|         if layer.fisher is not None: | ||||
|             return torch.abs(layer.fisher.detach()) | ||||
|         else: | ||||
|             return torch.zeros(layer.weight.shape[0]) #size=ch | ||||
|  | ||||
|     grads_abs_ch = get_layer_metric_array(net, fisher, mode) | ||||
|  | ||||
|     #broadcast channel value here to all parameters in that channel | ||||
|     #to be compatible with stuff downstream (which expects per-parameter metrics) | ||||
|     #TODO cleanup on the selectors/apply_prune_mask side (?) | ||||
|     shapes = get_layer_metric_array(net, lambda l : l.weight.shape[1:], mode) | ||||
|  | ||||
|     grads_abs = reshape_elements(grads_abs_ch, shapes, device) | ||||
|  | ||||
|     return grads_abs | ||||
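|  | ||||
| # Added note: the per-channel saliency accumulated in the hook above is | ||||
| # del_k = 0.5 * mean_n[(sum_{h,w} a_{n,k,h,w} * g_{n,k,h,w})^2], an empirical | ||||
| # Fisher-style estimate of the loss change from zeroing out channel k. | ||||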
zero-cost-nas/foresight/pruners/measures/grad_norm.py (38 lines, new file)
| @@ -0,0 +1,38 @@ | ||||
| # Copyright 2021 Samsung Electronics Co., Ltd. | ||||
| # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | ||||
| # You may obtain a copy of the License at | ||||
|  | ||||
| #     http://www.apache.org/licenses/LICENSE-2.0 | ||||
|  | ||||
| # Unless required by applicable law or agreed to in writing, software | ||||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
| # ============================================================================= | ||||
|  | ||||
| import torch | ||||
| import torch.nn.functional as F | ||||
|  | ||||
| import copy | ||||
|  | ||||
| from . import measure | ||||
| from ..p_utils import get_layer_metric_array | ||||
|  | ||||
| @measure('grad_norm', bn=True) | ||||
| def get_grad_norm_arr(net, inputs, targets, loss_fn, split_data=1, skip_grad=False): | ||||
|     net.zero_grad() | ||||
|     N = inputs.shape[0] | ||||
|     for sp in range(split_data): | ||||
|         st=sp*N//split_data | ||||
|         en=(sp+1)*N//split_data | ||||
|  | ||||
|         outputs = net.forward(inputs[st:en]) | ||||
|         loss = loss_fn(outputs, targets[st:en]) | ||||
|         loss.backward() | ||||
|  | ||||
|     # gradients accumulate across splits, so compute the norms once afterwards | ||||
|     grad_norm_arr = get_layer_metric_array(net, lambda l: l.weight.grad.norm() if l.weight.grad is not None else torch.zeros_like(l.weight), mode='param') | ||||
|  | ||||
|     return grad_norm_arr | ||||
zero-cost-nas/foresight/pruners/measures/grasp.py (87 lines, new file)
| @@ -0,0 +1,87 @@ | ||||
| # Copyright 2021 Samsung Electronics Co., Ltd. | ||||
| # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | ||||
| # You may obtain a copy of the License at | ||||
|  | ||||
| #     http://www.apache.org/licenses/LICENSE-2.0 | ||||
|  | ||||
| # Unless required by applicable law or agreed to in writing, software | ||||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
| # ============================================================================= | ||||
|  | ||||
| import torch | ||||
| import torch.nn as nn | ||||
| import torch.nn.functional as F | ||||
| import torch.autograd as autograd | ||||
|  | ||||
| from . import measure | ||||
| from ..p_utils import get_layer_metric_array | ||||
|  | ||||
|  | ||||
| @measure('grasp', bn=True, mode='param') | ||||
| def compute_grasp_per_weight(net, inputs, targets, mode, loss_fn, T=1, num_iters=1, split_data=1): | ||||
|  | ||||
|     # get all applicable weights | ||||
|     weights = [] | ||||
|     for layer in net.modules(): | ||||
|         if isinstance(layer, nn.Conv2d) or isinstance(layer, nn.Linear): | ||||
|             weights.append(layer.weight) | ||||
|             layer.weight.requires_grad_(True) # TODO isn't this already true? | ||||
|  | ||||
|     # NOTE original code had some input/target splitting into 2 | ||||
|     # I am guessing this was because of GPU mem limit | ||||
|     net.zero_grad() | ||||
|     N = inputs.shape[0] | ||||
|     for sp in range(split_data): | ||||
|         st=sp*N//split_data | ||||
|         en=(sp+1)*N//split_data | ||||
|  | ||||
|         #forward/grad pass #1 | ||||
|         grad_w = None | ||||
|         for _ in range(num_iters): | ||||
|             #TODO get new data, otherwise num_iters is useless! | ||||
|             outputs = net.forward(inputs[st:en])/T | ||||
|             loss = loss_fn(outputs, targets[st:en]) | ||||
|             grad_w_p = autograd.grad(loss, weights, allow_unused=True) | ||||
|             if grad_w is None: | ||||
|                 grad_w = list(grad_w_p) | ||||
|             else: | ||||
|                 for idx in range(len(grad_w)): | ||||
|                     grad_w[idx] += grad_w_p[idx] | ||||
|  | ||||
|      | ||||
|     for sp in range(split_data): | ||||
|         st=sp*N//split_data | ||||
|         en=(sp+1)*N//split_data | ||||
|  | ||||
|         # forward/grad pass #2 | ||||
|         outputs = net.forward(inputs[st:en])/T | ||||
|         loss = loss_fn(outputs, targets[st:en]) | ||||
|         grad_f = autograd.grad(loss, weights, create_graph=True, allow_unused=True) | ||||
|          | ||||
|         # accumulate gradients computed in previous step and call backwards | ||||
|         z, count = 0,0 | ||||
|         for layer in net.modules(): | ||||
|             if isinstance(layer, nn.Conv2d) or isinstance(layer, nn.Linear): | ||||
|                 if grad_w[count] is not None: | ||||
|                     z += (grad_w[count].data * grad_f[count]).sum() | ||||
|                 count += 1 | ||||
|         z.backward() | ||||
|  | ||||
|     # compute final sensitivity metric and put in grads | ||||
|     def grasp(layer): | ||||
|         if layer.weight.grad is not None: | ||||
|             return -layer.weight.data * layer.weight.grad   # -theta_q Hg | ||||
|             #NOTE in the grasp code they take the *bottom* (1-p)% of values | ||||
|             #but we take the *top* (1-p)%, therefore we remove the -ve sign | ||||
|             #EDIT accuracy seems to be negatively correlated with this metric, so we add -ve sign here! | ||||
|         else: | ||||
|             return torch.zeros_like(layer.weight) | ||||
|      | ||||
|     grads = get_layer_metric_array(net, grasp, mode) | ||||
|  | ||||
|     return grads | ||||
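|  | ||||
| # Added note: z = sum_l <stop_grad(grad_w), grad_f> over layers, so z.backward() | ||||
| # deposits an approximation of Hg (Hessian times gradient) into weight.grad; | ||||
| # the returned per-weight score is therefore -theta * Hg, as in GRASP. | ||||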
zero-cost-nas/foresight/pruners/measures/jacob_cov.py (57 lines, new file)
| @@ -0,0 +1,57 @@ | ||||
| # Copyright 2021 Samsung Electronics Co., Ltd. | ||||
| # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | ||||
| # You may obtain a copy of the License at | ||||
|  | ||||
| #     http://www.apache.org/licenses/LICENSE-2.0 | ||||
|  | ||||
| # Unless required by applicable law or agreed to in writing, software | ||||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
| # ============================================================================= | ||||
|  | ||||
| import torch | ||||
| import numpy as np | ||||
|  | ||||
| from . import measure | ||||
|  | ||||
|  | ||||
| def get_batch_jacobian(net, x, target, device, split_data): | ||||
|     x.requires_grad_(True) | ||||
|  | ||||
|     N = x.shape[0] | ||||
|     for sp in range(split_data): | ||||
|         st=sp*N//split_data | ||||
|         en=(sp+1)*N//split_data | ||||
|         y = net(x[st:en]) | ||||
|         y.backward(torch.ones_like(y)) | ||||
|  | ||||
|     jacob = x.grad.detach() | ||||
|     x.requires_grad_(False) | ||||
|     return jacob, target.detach() | ||||
|  | ||||
| def eval_score(jacob, labels=None): | ||||
|     corrs = np.corrcoef(jacob) | ||||
|     v, _  = np.linalg.eig(corrs) | ||||
|     k = 1e-5 | ||||
|     return -np.sum(np.log(v + k) + 1./(v + k)) | ||||
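|  | ||||
| # Added note, worked example: for two perfectly correlated inputs the | ||||
| # eigenvalues of corrs are roughly [2, 0]; with k=1e-5 the score is | ||||
| # -(log(2+k) + 1/(2+k) + log(k) + 1/k) ~ -1e5, so highly correlated | ||||
| # per-example Jacobians are heavily penalized. | ||||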
|  | ||||
| @measure('jacob_cov', bn=True) | ||||
| def compute_jacob_cov(net, inputs, targets, split_data=1, loss_fn=None): | ||||
|     device = inputs.device | ||||
|     # Compute gradients (but don't apply them) | ||||
|     net.zero_grad() | ||||
|  | ||||
|     jacobs, labels = get_batch_jacobian(net, inputs, targets, device, split_data=split_data) | ||||
|     jacobs = jacobs.reshape(jacobs.size(0), -1).cpu().numpy() | ||||
|  | ||||
|     try: | ||||
|         jc = eval_score(jacobs, labels) | ||||
|     except Exception as e: | ||||
|         print(e) | ||||
|         jc = np.nan | ||||
|  | ||||
|     return jc | ||||
zero-cost-nas/foresight/pruners/measures/l2_norm.py (22 lines, new file)
| @@ -0,0 +1,22 @@ | ||||
| # Copyright 2021 Samsung Electronics Co., Ltd. | ||||
| # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | ||||
| # You may obtain a copy of the License at | ||||
|  | ||||
| #     http://www.apache.org/licenses/LICENSE-2.0 | ||||
|  | ||||
| # Unless required by applicable law or agreed to in writing, software | ||||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
| # ============================================================================= | ||||
|  | ||||
| from . import measure | ||||
| from ..p_utils import get_layer_metric_array | ||||
|  | ||||
|  | ||||
| @measure('l2_norm', copy_net=False, mode='param') | ||||
| def get_l2_norm_array(net, inputs, targets, mode, split_data=1): | ||||
|     return get_layer_metric_array(net, lambda l: l.weight.norm(), mode=mode) | ||||
zero-cost-nas/foresight/pruners/measures/meco.py (69 lines, new file)
| @@ -0,0 +1,69 @@ | ||||
| # Copyright 2021 Samsung Electronics Co., Ltd. | ||||
| # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | ||||
| # You may obtain a copy of the License at | ||||
|  | ||||
| #     http://www.apache.org/licenses/LICENSE-2.0 | ||||
|  | ||||
| # Unless required by applicable law or agreed to in writing, software | ||||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
| # ============================================================================= | ||||
|  | ||||
| import numpy as np | ||||
| import torch | ||||
| from torch import nn | ||||
|  | ||||
| from . import measure | ||||
|  | ||||
|  | ||||
| def get_score(net, x, target, device, split_data): | ||||
|     result_list = [] | ||||
|  | ||||
|     def forward_hook(module, data_input, data_output): | ||||
|         fea = data_output[0].detach() | ||||
|         fea = fea.reshape(fea.shape[0], -1) | ||||
|         corr = torch.corrcoef(fea) | ||||
|         corr[torch.isnan(corr)] = 0 | ||||
|         corr[torch.isinf(corr)] = 0 | ||||
|         values = torch.linalg.eig(corr)[0] | ||||
|         # result = np.real(np.min(values)) / np.real(np.max(values)) | ||||
|         result = torch.min(torch.real(values)) | ||||
|         result_list.append(result) | ||||
|  | ||||
|     hook_handles = [] | ||||
|     for name, module in net.named_modules(): | ||||
|         hook_handles.append(module.register_forward_hook(forward_hook)) | ||||
|  | ||||
|     N = x.shape[0] | ||||
|     for sp in range(split_data): | ||||
|         st = sp * N // split_data | ||||
|         en = (sp + 1) * N // split_data | ||||
|         y = net(x[st:en]) | ||||
|  | ||||
|     for handle in hook_handles:  # remove hooks so repeated calls don't stack | ||||
|         handle.remove() | ||||
|  | ||||
|     results = torch.stack(result_list) | ||||
|     results = results[torch.logical_not(torch.isnan(results))] | ||||
|     v = torch.sum(results) | ||||
|     result_list.clear() | ||||
|     return v.item() | ||||
|  | ||||
|  | ||||
|  | ||||
| @measure('meco', bn=True) | ||||
| def compute_meco(net, inputs, targets, split_data=1, loss_fn=None): | ||||
|     device = inputs.device | ||||
|     # Compute gradients (but don't apply them) | ||||
|     net.zero_grad() | ||||
|  | ||||
|     try: | ||||
|         meco = get_score(net, inputs, targets, device, split_data=split_data) | ||||
|     except Exception as e: | ||||
|         print(e) | ||||
|         meco = np.nan  # return a scalar on failure, matching the success path | ||||
|     return meco | ||||
zero-cost-nas/foresight/pruners/measures/norm.py (55 lines, new file)
| @@ -0,0 +1,55 @@ | ||||
| # Copyright 2021 Samsung Electronics Co., Ltd. | ||||
| # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | ||||
| # You may obtain a copy of the License at | ||||
|  | ||||
| #     http://www.apache.org/licenses/LICENSE-2.0 | ||||
|  | ||||
| # Unless required by applicable law or agreed to in writing, software | ||||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
| # ============================================================================= | ||||
|  | ||||
| import numpy as np | ||||
| import torch | ||||
|  | ||||
| from . import measure | ||||
|  | ||||
|  | ||||
| def get_score(net, x, target, device, split_data): | ||||
|     result_list = [] | ||||
|     def forward_hook(module, data_input, data_output): | ||||
|         norm = torch.norm(data_input[0]) | ||||
|         result_list.append(norm) | ||||
|     hook_handle = net.classifier.register_forward_hook(forward_hook) | ||||
|  | ||||
|     N = x.shape[0] | ||||
|     for sp in range(split_data): | ||||
|         st = sp * N // split_data | ||||
|         en = (sp + 1) * N // split_data | ||||
|         y = net(x[st:en]) | ||||
|     hook_handle.remove()  # avoid stacking hooks across repeated calls | ||||
|     n = result_list[0].item() | ||||
|     result_list.clear() | ||||
|     return n | ||||
|  | ||||
|  | ||||
|  | ||||
| @measure('norm', bn=True) | ||||
| def compute_norm(net, inputs, targets, split_data=1, loss_fn=None): | ||||
|     device = inputs.device | ||||
|     # no backward pass is needed; just clear any stale gradients | ||||
|     net.zero_grad() | ||||
|  | ||||
|     try: | ||||
|         # get_score returns a single scalar (the original unpacked two values) | ||||
|         norm = get_score(net, inputs, targets, device, split_data=split_data) | ||||
|     except Exception as e: | ||||
|         print(e) | ||||
|         norm = np.nan | ||||
|     return norm | ||||
zero-cost-nas/foresight/pruners/measures/param_count.py (16 lines, new file)
| @@ -0,0 +1,16 @@ | ||||
| import time | ||||
| import torch | ||||
|  | ||||
| from . import measure | ||||
| from ..p_utils import get_layer_metric_array | ||||
|  | ||||
|  | ||||
|  | ||||
| @measure('param_count', copy_net=False, mode='param') | ||||
| def get_param_count_array(net, inputs, targets, mode, loss_fn, split_data=1): | ||||
|     s = time.time() | ||||
|     count = get_layer_metric_array(net, lambda l: torch.tensor(sum(p.numel() for p in l.parameters() if p.requires_grad)), mode=mode) | ||||
|     e = time.time() | ||||
|     t = e - s | ||||
|     # print(f'param_count time: {t} s') | ||||
|     return count, t | ||||
zero-cost-nas/foresight/pruners/measures/plain.py (44 lines, new file)
| @@ -0,0 +1,44 @@ | ||||
| # Copyright 2021 Samsung Electronics Co., Ltd. | ||||
| # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | ||||
| # You may obtain a copy of the License at | ||||
|  | ||||
| #     http://www.apache.org/licenses/LICENSE-2.0 | ||||
|  | ||||
| # Unless required by applicable law or agreed to in writing, software | ||||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
| # ============================================================================= | ||||
|  | ||||
| import torch | ||||
| import torch.nn.functional as F | ||||
|  | ||||
| from . import measure | ||||
| from ..p_utils import get_layer_metric_array | ||||
|  | ||||
|  | ||||
| @measure('plain', bn=True, mode='param') | ||||
| def compute_plain_per_weight(net, inputs, targets, mode, loss_fn, split_data=1): | ||||
|  | ||||
|     net.zero_grad() | ||||
|     N = inputs.shape[0] | ||||
|     for sp in range(split_data): | ||||
|         st=sp*N//split_data | ||||
|         en=(sp+1)*N//split_data | ||||
|  | ||||
|         outputs = net.forward(inputs[st:en]) | ||||
|         loss = loss_fn(outputs, targets[st:en]) | ||||
|         loss.backward() | ||||
|  | ||||
|     # select the gradients that we want to use for search/prune | ||||
|     def plain(layer): | ||||
|         if layer.weight.grad is not None: | ||||
|             return layer.weight.grad * layer.weight | ||||
|         else: | ||||
|             return torch.zeros_like(layer.weight) | ||||
|  | ||||
|     grads_abs = get_layer_metric_array(net, plain, mode) | ||||
|     return grads_abs | ||||
zero-cost-nas/foresight/pruners/measures/snip.py (69 lines, new file)
| @@ -0,0 +1,69 @@ | ||||
| # Copyright 2021 Samsung Electronics Co., Ltd. | ||||
| # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | ||||
| # You may obtain a copy of the License at | ||||
|  | ||||
| #     http://www.apache.org/licenses/LICENSE-2.0 | ||||
|  | ||||
| # Unless required by applicable law or agreed to in writing, software | ||||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
| # ============================================================================= | ||||
|  | ||||
| import numpy as np | ||||
| import torch | ||||
| import torch.nn as nn | ||||
| import torch.nn.functional as F | ||||
|  | ||||
| import copy | ||||
| import types | ||||
|  | ||||
| from . import measure | ||||
| from ..p_utils import get_layer_metric_array | ||||
|  | ||||
|  | ||||
| def snip_forward_conv2d(self, x): | ||||
|         return F.conv2d(x, self.weight * self.weight_mask, self.bias, | ||||
|                         self.stride, self.padding, self.dilation, self.groups) | ||||
|  | ||||
| def snip_forward_linear(self, x): | ||||
|         return F.linear(x, self.weight * self.weight_mask, self.bias) | ||||
|  | ||||
| @measure('snip', bn=True, mode='param') | ||||
| def compute_snip_per_weight(net, inputs, targets, mode, loss_fn, split_data=1): | ||||
|     for layer in net.modules(): | ||||
|         if isinstance(layer, nn.Conv2d) or isinstance(layer, nn.Linear): | ||||
|             layer.weight_mask = nn.Parameter(torch.ones_like(layer.weight)) | ||||
|             layer.weight.requires_grad = False | ||||
|  | ||||
|         # Override the forward methods: | ||||
|         if isinstance(layer, nn.Conv2d): | ||||
|             layer.forward = types.MethodType(snip_forward_conv2d, layer) | ||||
|  | ||||
|         if isinstance(layer, nn.Linear): | ||||
|             layer.forward = types.MethodType(snip_forward_linear, layer) | ||||
|  | ||||
|     # Compute gradients (but don't apply them) | ||||
|     net.zero_grad() | ||||
|     N = inputs.shape[0] | ||||
|     for sp in range(split_data): | ||||
|         st=sp*N//split_data | ||||
|         en=(sp+1)*N//split_data | ||||
|      | ||||
|         outputs = net.forward(inputs[st:en]) | ||||
|         loss = loss_fn(outputs, targets[st:en]) | ||||
|         loss.backward() | ||||
|  | ||||
|     # select the gradients that we want to use for search/prune | ||||
|     def snip(layer): | ||||
|         if layer.weight_mask.grad is not None: | ||||
|             return torch.abs(layer.weight_mask.grad) | ||||
|         else: | ||||
|             return torch.zeros_like(layer.weight) | ||||
|      | ||||
|     grads_abs = get_layer_metric_array(net, snip, mode) | ||||
|  | ||||
|     return grads_abs | ||||
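|  | ||||
| # Added note: with weight_mask frozen at 1, the chain rule gives | ||||
| # dL/dmask = dL/dw_eff * w, so |weight_mask.grad| equals |w * dL/dw| -- | ||||
| # the SNIP connection-sensitivity score. | ||||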
zero-cost-nas/foresight/pruners/measures/synflow.py (69 lines, new file)
| @@ -0,0 +1,69 @@ | ||||
| # Copyright 2021 Samsung Electronics Co., Ltd. | ||||
| # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | ||||
| # You may obtain a copy of the License at | ||||
|  | ||||
| #     http://www.apache.org/licenses/LICENSE-2.0 | ||||
|  | ||||
| # Unless required by applicable law or agreed to in writing, software | ||||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
| # ============================================================================= | ||||
|  | ||||
| import torch | ||||
|  | ||||
| from . import measure | ||||
| from ..p_utils import get_layer_metric_array | ||||
|  | ||||
|  | ||||
| @measure('synflow', bn=False, mode='param') | ||||
| @measure('synflow_bn', bn=True, mode='param') | ||||
| def compute_synflow_per_weight(net, inputs, targets, mode, split_data=1, loss_fn=None): | ||||
|  | ||||
|     device = inputs.device | ||||
|  | ||||
|     #convert params to their abs. Keep sign for converting it back. | ||||
|     @torch.no_grad() | ||||
|     def linearize(net): | ||||
|         signs = {} | ||||
|         for name, param in net.state_dict().items(): | ||||
|             signs[name] = torch.sign(param) | ||||
|             param.abs_() | ||||
|         return signs | ||||
|  | ||||
|     #convert to orig values | ||||
|     @torch.no_grad() | ||||
|     def nonlinearize(net, signs): | ||||
|         for name, param in net.state_dict().items(): | ||||
|             if 'weight_mask' not in name: | ||||
|                 param.mul_(signs[name]) | ||||
|  | ||||
|     # keep signs of all params | ||||
|     signs = linearize(net) | ||||
|      | ||||
|     # Compute gradients with input of 1s  | ||||
|     net.zero_grad() | ||||
|     net.double() | ||||
|     input_dim = list(inputs[0,:].shape) | ||||
|     inputs = torch.ones([1] + input_dim).double().to(device) | ||||
|     output = net.forward(inputs) | ||||
|     torch.sum(output).backward()  | ||||
|  | ||||
|     # select the gradients that we want to use for search/prune | ||||
|     def synflow(layer): | ||||
|         if layer.weight.grad is not None: | ||||
|             return torch.abs(layer.weight * layer.weight.grad) | ||||
|         else: | ||||
|             return torch.zeros_like(layer.weight) | ||||
|  | ||||
|     grads_abs = get_layer_metric_array(net, synflow, mode) | ||||
|  | ||||
|     # apply signs of all params | ||||
|     nonlinearize(net, signs) | ||||
|  | ||||
|     return grads_abs | ||||
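|  | ||||
| # Added note: with all parameters made positive and an all-ones input, | ||||
| # R = sum(net(1)) is a positive polynomial in the weights, and | ||||
| # |w * dR/dw| is the per-parameter synaptic-flow score. Signs are restored | ||||
| # afterwards, but the network is left in float64 (see net.double() above). | ||||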
|  | ||||
|  | ||||
zero-cost-nas/foresight/pruners/measures/var.py (55 lines, new file)
| @@ -0,0 +1,55 @@ | ||||
| # Copyright 2021 Samsung Electronics Co., Ltd. | ||||
| # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | ||||
| # You may obtain a copy of the License at | ||||
|  | ||||
| #     http://www.apache.org/licenses/LICENSE-2.0 | ||||
|  | ||||
| # Unless required by applicable law or agreed to in writing, software | ||||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
| # ============================================================================= | ||||
|  | ||||
| import numpy as np | ||||
| import torch | ||||
|  | ||||
| from . import measure | ||||
|  | ||||
|  | ||||
| def get_score(net, x, target, device, split_data): | ||||
|     result_list = [] | ||||
|     def forward_hook(module, data_input, data_output): | ||||
|         var = torch.var(data_input[0]) | ||||
|         result_list.append(var) | ||||
|     hook_handle = net.classifier.register_forward_hook(forward_hook) | ||||
|  | ||||
|     N = x.shape[0] | ||||
|     for sp in range(split_data): | ||||
|         st = sp * N // split_data | ||||
|         en = (sp + 1) * N // split_data | ||||
|         y = net(x[st:en]) | ||||
|     hook_handle.remove()  # avoid stacking hooks across repeated calls | ||||
|     v = result_list[0].item() | ||||
|     result_list.clear() | ||||
|     return v | ||||
|  | ||||
|  | ||||
|  | ||||
| @measure('var', bn=True) | ||||
| def compute_var(net, inputs, targets, split_data=1, loss_fn=None): | ||||
|     device = inputs.device | ||||
|     # no backward pass is needed; just clear any stale gradients | ||||
|     net.zero_grad() | ||||
|  | ||||
|     try: | ||||
|         var = get_score(net, inputs, targets, device, split_data=split_data) | ||||
|     except Exception as e: | ||||
|         print(e) | ||||
|         var = np.nan | ||||
|     return var | ||||
zero-cost-nas/foresight/pruners/measures/zico.py (106 lines, new file)
| @@ -0,0 +1,106 @@ | ||||
| # Copyright 2021 Samsung Electronics Co., Ltd. | ||||
| # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | ||||
| # You may obtain a copy of the License at | ||||
|  | ||||
| #     http://www.apache.org/licenses/LICENSE-2.0 | ||||
|  | ||||
| # Unless required by applicable law or agreed to in writing, software | ||||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
| # ============================================================================= | ||||
|  | ||||
| import numpy as np | ||||
| import torch | ||||
|  | ||||
| from . import measure | ||||
| from torch import nn | ||||
|  | ||||
| from ...dataset import get_cifar_dataloaders | ||||
|  | ||||
|  | ||||
| def getgrad(model: torch.nn.Module, grad_dict: dict, step_iter=0): | ||||
|     if step_iter == 0: | ||||
|         for name, mod in model.named_modules(): | ||||
|             if isinstance(mod, nn.Conv2d) or isinstance(mod, nn.Linear): | ||||
|                 try: | ||||
|                     grad_dict[name] = [mod.weight.grad.data.cpu().reshape(-1).numpy()] | ||||
|                 except AttributeError:  # layer received no gradient | ||||
|                     continue | ||||
|     else: | ||||
|         for name, mod in model.named_modules(): | ||||
|             if isinstance(mod, nn.Conv2d) or isinstance(mod, nn.Linear): | ||||
|                 try: | ||||
|                     grad_dict[name].append(mod.weight.grad.data.cpu().reshape(-1).numpy()) | ||||
|                 except (AttributeError, KeyError):  # no gradient, or skipped at step 0 | ||||
|                     continue | ||||
|     return grad_dict | ||||
|  | ||||
|  | ||||
| def calculate_zico(grad_dict): | ||||
|     for modname in grad_dict.keys(): | ||||
|         grad_dict[modname] = np.array(grad_dict[modname]) | ||||
|     nsr_mean_sum_abs = 0 | ||||
|     nsr_mean_avg_abs = 0 | ||||
|     for modname in grad_dict.keys(): | ||||
|         nsr_std = np.std(grad_dict[modname], axis=0) | ||||
|         nonzero_idx = np.nonzero(nsr_std)[0] | ||||
|         nsr_mean_abs = np.mean(np.abs(grad_dict[modname]), axis=0) | ||||
|         tmpsum = np.sum(nsr_mean_abs[nonzero_idx] / nsr_std[nonzero_idx]) | ||||
|         if tmpsum != 0: | ||||
|             nsr_mean_sum_abs += np.log(tmpsum) | ||||
|             nsr_mean_avg_abs += np.log(np.mean(nsr_mean_abs[nonzero_idx] / nsr_std[nonzero_idx])) | ||||
|     return nsr_mean_sum_abs | ||||
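|  | ||||
| # Added note: the returned score is sum_layers log( sum_w E[|g_w|] / std(g_w) ), | ||||
| # with the mean and std taken over the per-split gradient samples collected | ||||
| # in getgrad. | ||||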
|  | ||||
|  | ||||
| def getzico(network, inputs, targets, loss_fn, split_data=2): | ||||
|     grad_dict = {} | ||||
|     network.train() | ||||
|     device = inputs.device | ||||
|     network.to(device) | ||||
|     N = inputs.shape[0] | ||||
|  | ||||
|     for sp in range(split_data): | ||||
|         st = sp * N // split_data | ||||
|         en = (sp + 1) * N // split_data | ||||
|         network.zero_grad()  # each split must yield an independent gradient sample | ||||
|         outputs = network.forward(inputs[st:en]) | ||||
|         loss = loss_fn(outputs, targets[st:en]) | ||||
|         loss.backward() | ||||
|         grad_dict = getgrad(network, grad_dict, sp) | ||||
|     res = calculate_zico(grad_dict) | ||||
|     return res | ||||
|  | ||||
|  | ||||
|  | ||||
|  | ||||
|  | ||||
| @measure('zico', bn=True) | ||||
| def compute_zico(net, inputs, targets, split_data=2, loss_fn=None): | ||||
|  | ||||
|     # Compute gradients (but don't apply them) | ||||
|     net.zero_grad() | ||||
|  | ||||
|     try: | ||||
|         zico = getzico(net, inputs, targets, loss_fn, split_data=split_data) | ||||
|     except Exception as e: | ||||
|         print(e) | ||||
|         zico = np.nan | ||||
|     return zico | ||||
zero-cost-nas/foresight/pruners/p_utils.py (83 lines, new file)
| @@ -0,0 +1,83 @@ | ||||
| # Copyright 2021 Samsung Electronics Co., Ltd. | ||||
| # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | ||||
| # You may obtain a copy of the License at | ||||
|  | ||||
| #     http://www.apache.org/licenses/LICENSE-2.0 | ||||
|  | ||||
| # Unless required by applicable law or agreed to in writing, software | ||||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
| # ============================================================================= | ||||
|  | ||||
| import numpy as np | ||||
| import torch | ||||
| import torch.nn as nn | ||||
| import torch.nn.functional as F | ||||
|  | ||||
| from ..models import * | ||||
|  | ||||
| def get_some_data(train_dataloader, num_batches, device): | ||||
|     traindata = [] | ||||
|     dataloader_iter = iter(train_dataloader) | ||||
|     for _ in range(num_batches): | ||||
|         traindata.append(next(dataloader_iter)) | ||||
|     inputs  = torch.cat([a for a,_ in traindata]) | ||||
|     targets = torch.cat([b for _,b in traindata]) | ||||
|     inputs = inputs.to(device) | ||||
|     targets = targets.to(device) | ||||
|     return inputs, targets | ||||
|  | ||||
| def get_some_data_grasp(train_dataloader, num_classes, samples_per_class, device): | ||||
|     datas = [[] for _ in range(num_classes)] | ||||
|     labels = [[] for _ in range(num_classes)] | ||||
|     mark = dict() | ||||
|     dataloader_iter = iter(train_dataloader) | ||||
|     while True: | ||||
|         inputs, targets = next(dataloader_iter) | ||||
|         for idx in range(inputs.shape[0]): | ||||
|             x, y = inputs[idx:idx+1], targets[idx:idx+1] | ||||
|             category = y.item() | ||||
|             if len(datas[category]) == samples_per_class: | ||||
|                 mark[category] = True | ||||
|                 continue | ||||
|             datas[category].append(x) | ||||
|             labels[category].append(y) | ||||
|         if len(mark) == num_classes: | ||||
|             break | ||||
|  | ||||
|     x = torch.cat([torch.cat(_, 0) for _ in datas]).to(device)  | ||||
|     y = torch.cat([torch.cat(_) for _ in labels]).view(-1).to(device) | ||||
|     return x, y | ||||
|  | ||||
| def get_layer_metric_array(net, metric, mode): | ||||
|     # Apply `metric` to every prunable (Conv2d/Linear) layer; in 'channel' | ||||
|     # mode, layers flagged with `dont_ch_prune` are skipped. | ||||
|     metric_array = [] | ||||
|  | ||||
|     for layer in net.modules(): | ||||
|         if mode == 'channel' and hasattr(layer, 'dont_ch_prune'): | ||||
|             continue | ||||
|         if isinstance(layer, (nn.Conv2d, nn.Linear)): | ||||
|             metric_array.append(metric(layer)) | ||||
|  | ||||
|     return metric_array | ||||
|  | ||||
| def reshape_elements(elements, shapes, device): | ||||
|     # Broadcast each scalar in `elements` to a tensor of the matching layer | ||||
|     # shape; entries of `shapes` are expected to be torch.Size objects | ||||
|     # (e.g. layer.weight.shape) -- with a plain list, torch.Tensor(sh) would | ||||
|     # build a 1-D tensor of values instead of a tensor of that shape. | ||||
|     def broadcast_val(elements, shapes): | ||||
|         ret_grads = [] | ||||
|         for e, sh in zip(elements, shapes): | ||||
|             ret_grads.append(torch.stack([torch.Tensor(sh).fill_(v) for v in e], dim=0).to(device)) | ||||
|         return ret_grads | ||||
|     if isinstance(elements[0], list): | ||||
|         return [broadcast_val(e, sh) for e, sh in zip(elements, shapes)] | ||||
|     else: | ||||
|         return broadcast_val(elements, shapes) | ||||
|  | ||||
| def count_parameters(model): | ||||
|     return sum(p.numel() for p in model.parameters() if p.requires_grad) | ||||
|  | ||||
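A minimal usage sketch for the helpers above; `net` and `train_loader` are
placeholders for any PyTorch model and DataLoader, not names from this commit:

    from foresight.pruners.p_utils import (
        get_some_data, get_some_data_grasp, get_layer_metric_array, count_parameters)

    # two batches concatenated into a single probe batch on the GPU
    inputs, targets = get_some_data(train_loader, num_batches=2, device='cuda')

    # GraSP-style class-balanced alternative: 2 images for each of 10 classes
    xs, ys = get_some_data_grasp(train_loader, num_classes=10, samples_per_class=2, device='cuda')

    # per-layer scalar metric over every Conv2d/Linear layer
    weight_norms = get_layer_metric_array(net, lambda l: l.weight.norm().item(), mode='param')

    print(count_parameters(net))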
							
								
								
									
116  zero-cost-nas/foresight/pruners/predictive.py  Normal file
| @@ -0,0 +1,116 @@ | ||||
| # Copyright 2021 Samsung Electronics Co., Ltd. | ||||
| # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | ||||
| # You may obtain a copy of the License at | ||||
|  | ||||
| #     http://www.apache.org/licenses/LICENSE-2.0 | ||||
|  | ||||
| # Unless required by applicable law or agreed to in writing, software | ||||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
| # ============================================================================= | ||||
|  | ||||
| import torch | ||||
| import torch.nn as nn | ||||
| import torch.nn.functional as F | ||||
|  | ||||
| from .p_utils import * | ||||
| from . import measures | ||||
|  | ||||
| import types | ||||
| import copy | ||||
|  | ||||
|  | ||||
| def no_op(self, x): | ||||
|     # identity forward, used to disable batch-norm layers in copies | ||||
|     return x | ||||
|  | ||||
| def copynet(self, bn): | ||||
|     # Deep-copy the network; when bn is False, replace every BatchNorm | ||||
|     # forward with a no-op so the copy runs without batch normalization. | ||||
|     net = copy.deepcopy(self) | ||||
|     if not bn: | ||||
|         for l in net.modules(): | ||||
|             if isinstance(l, (nn.BatchNorm2d, nn.BatchNorm1d)): | ||||
|                 l.forward = types.MethodType(no_op, l) | ||||
|     return net | ||||
|  | ||||
| def find_measures_arrays(net_orig, trainloader, dataload_info, device, measure_names=None, loss_fn=F.cross_entropy): | ||||
|     if measure_names is None: | ||||
|         measure_names = measures.available_measures | ||||
|  | ||||
|     dataload, num_imgs_or_batches, num_classes = dataload_info | ||||
|  | ||||
|     if not hasattr(net_orig, 'get_prunable_copy'): | ||||
|         # attach a copy method so individual measures can request BN-free copies | ||||
|         net_orig.get_prunable_copy = types.MethodType(copynet, net_orig) | ||||
|  | ||||
|     # move the network to the CPU to free up GPU memory before loading data | ||||
|     torch.cuda.empty_cache() | ||||
|     net_orig = net_orig.cpu() | ||||
|     torch.cuda.empty_cache() | ||||
|  | ||||
|     # load the probe data: whole batches ('random') or a class-balanced sample ('grasp') | ||||
|     if dataload == 'random': | ||||
|         inputs, targets = get_some_data(trainloader, num_batches=num_imgs_or_batches, device=device) | ||||
|     elif dataload == 'grasp': | ||||
|         inputs, targets = get_some_data_grasp(trainloader, num_classes, samples_per_class=num_imgs_or_batches, device=device) | ||||
|     else: | ||||
|         raise NotImplementedError(f'dataload {dataload} is not supported') | ||||
|  | ||||
|     # ds = number of equal chunks the probe batch is split into; grown on OOM | ||||
|     done, ds = False, 1 | ||||
|     measure_values = {} | ||||
|  | ||||
|     while not done: | ||||
|         try: | ||||
|             for measure_name in measure_names: | ||||
|                 if measure_name not in measure_values: | ||||
|                     val = measures.calc_measure(measure_name, net_orig, device, inputs, targets, loss_fn=loss_fn, split_data=ds) | ||||
|                     measure_values[measure_name] = val | ||||
|  | ||||
|             done = True | ||||
|         except RuntimeError as e: | ||||
|             if 'out of memory' in str(e): | ||||
|                 done = False | ||||
|                 if ds == inputs.shape[0] // 2: | ||||
|                     raise ValueError('Cannot split the data any further, but still running out of memory') | ||||
|                 # advance to the next divisor of the batch size | ||||
|                 ds += 1 | ||||
|                 while inputs.shape[0] % ds != 0: | ||||
|                     ds += 1 | ||||
|                 torch.cuda.empty_cache() | ||||
|                 print(f'Caught CUDA OOM, retrying with data split into {ds} parts') | ||||
|             else: | ||||
|                 raise e | ||||
|  | ||||
|     net_orig = net_orig.to(device).train() | ||||
|     return measure_values | ||||
|  | ||||
| def find_measures(net_orig,                  # neural network | ||||
|                   dataloader,                # a data loader (typically for training data) | ||||
|                   dataload_info,             # a tuple with (dataload_type = {random, grasp}, number_of_batches_for_random_or_images_per_class_for_grasp, number of classes) | ||||
|                   device,                    # GPU/CPU device used | ||||
|                   loss_fn=F.cross_entropy,   # loss function to use within the zero-cost metrics | ||||
|                   measure_names=None,        # list of measure names to compute; if None, all available measures are computed | ||||
|                   measures_arr=None):        # [not used] if the measures are already computed but need to be summarized, pass them here | ||||
|      | ||||
|     # Given a neural net, a dataloader, and a loss function, this function | ||||
|     # returns a dict of zero-cost proxy metrics. | ||||
|  | ||||
|     def sum_arr(arr): | ||||
|         total = 0. | ||||
|         for i in range(len(arr)): | ||||
|             total += torch.sum(arr[i]) | ||||
|         return total.item() | ||||
|  | ||||
|     if measures_arr is None: | ||||
|         measures_arr = find_measures_arrays(net_orig, dataloader, dataload_info, device, loss_fn=loss_fn, measure_names=measure_names) | ||||
|  | ||||
|     # scalar-valued measures pass through; per-layer arrays are summed | ||||
|     # (named measure_scores to avoid shadowing the `measures` module) | ||||
|     measure_scores = {} | ||||
|     for k, v in measures_arr.items(): | ||||
|         if k in ['jacob_cov', 'var', 'cor', 'norm', 'meco', 'zico']: | ||||
|             measure_scores[k] = v | ||||
|         else: | ||||
|             measure_scores[k] = sum_arr(v) | ||||
|  | ||||
|     return measure_scores | ||||
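A sketch of a typical find_measures call; the network, loader, and chosen
measure names here are illustrative assumptions, not part of this commit:

    import torch
    import torch.nn.functional as F
    from foresight.pruners import predictive

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # 'random' dataload with 1 batch; 10 classes as in CIFAR-10
    scores = predictive.find_measures(
        net, train_loader,
        dataload_info=('random', 1, 10),
        device=device,
        loss_fn=F.cross_entropy,
        measure_names=['synflow', 'snip'])
    # scores maps each measure name to a scalar proxy value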
							
								
								
									
51  zero-cost-nas/foresight/version.py  Normal file
| @@ -0,0 +1,51 @@ | ||||
| # Copyright 2021 Samsung Electronics Co., Ltd. | ||||
| # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | ||||
| # You may obtain a copy of the License at | ||||
|  | ||||
| #     http://www.apache.org/licenses/LICENSE-2.0 | ||||
|  | ||||
| # Unless required by applicable law or agreed to in writing, software | ||||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
| # ============================================================================= | ||||
|  | ||||
| version = '1.0.0' | ||||
| repo = 'unknown' | ||||
| commit = 'unknown' | ||||
| has_repo = False | ||||
|  | ||||
| try: | ||||
|     import git | ||||
|     from pathlib import Path | ||||
|  | ||||
|     try: | ||||
|         r = git.Repo(Path(__file__).parents[1]) | ||||
|         has_repo = True | ||||
|  | ||||
|         if not r.remotes: | ||||
|             repo = 'local' | ||||
|         else: | ||||
|             repo = r.remotes.origin.url | ||||
|  | ||||
|         commit = r.head.commit.hexsha | ||||
|         if r.is_dirty(): | ||||
|             commit += ' (dirty)' | ||||
|     except git.InvalidGitRepositoryError: | ||||
|         # not running from a git checkout; treat it the same as gitpython missing | ||||
|         raise ImportError() | ||||
| except ImportError: | ||||
|     pass | ||||
|  | ||||
| try: | ||||
|     # packaged installs ship a _dist_info module recording repo and commit | ||||
|     from . import _dist_info as info | ||||
|     assert not has_repo, '_dist_info should not exist when repo is in place' | ||||
|     assert version == info.version | ||||
|     repo = info.repo | ||||
|     commit = info.commit | ||||
| except (ImportError, SystemError): | ||||
|     pass | ||||
|  | ||||
| __all__ = ['version', 'repo', 'commit', 'has_repo'] | ||||
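For reference, a packaged install is expected to ship a `_dist_info` module
shaped roughly like this; the sketch below is inferred from the asserts above
and is not a file in this commit:

    # foresight/_dist_info.py
    version = '1.0.0'  # must match version.py
    repo = '<remote URL recorded at packaging time>'
    commit = '<commit hexsha recorded at packaging time>'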
							
								
								
									
68  zero-cost-nas/foresight/weight_initializers.py  Normal file
| @@ -0,0 +1,68 @@ | ||||
| # Copyright 2021 Samsung Electronics Co., Ltd. | ||||
| # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | ||||
| # You may obtain a copy of the License at | ||||
|  | ||||
| #     http://www.apache.org/licenses/LICENSE-2.0 | ||||
|  | ||||
| # Unless required by applicable law or agreed to in writing, software | ||||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
| # ============================================================================= | ||||
|  | ||||
| import torch.nn as nn | ||||
|  | ||||
| def init_net(net, w_type, b_type): | ||||
|     # Re-initialize all Conv2d/Linear weights and biases in `net` according | ||||
|     # to the requested schemes; 'none' leaves the parameters untouched. | ||||
|     if w_type == 'none': | ||||
|         pass | ||||
|     elif w_type == 'xavier': | ||||
|         net.apply(init_weights_vs) | ||||
|     elif w_type == 'kaiming': | ||||
|         net.apply(init_weights_he) | ||||
|     elif w_type == 'zero': | ||||
|         net.apply(init_weights_zero) | ||||
|     else: | ||||
|         raise NotImplementedError(f'init_type={w_type} is not supported.') | ||||
|  | ||||
|     if b_type == 'none': | ||||
|         pass | ||||
|     elif b_type == 'xavier': | ||||
|         net.apply(init_bias_vs) | ||||
|     elif b_type == 'kaiming': | ||||
|         net.apply(init_bias_he) | ||||
|     elif b_type == 'zero': | ||||
|         net.apply(init_bias_zero) | ||||
|     else: | ||||
|         raise NotImplementedError(f'init_type={b_type} is not supported.') | ||||
|  | ||||
| def init_weights_vs(m): | ||||
|     if isinstance(m, (nn.Linear, nn.Conv2d)): | ||||
|         nn.init.xavier_normal_(m.weight) | ||||
|  | ||||
| def init_bias_vs(m): | ||||
|     if isinstance(m, (nn.Linear, nn.Conv2d)): | ||||
|         if m.bias is not None: | ||||
|             # xavier_normal_ needs a >=2-D tensor; initialize via a 2-D view | ||||
|             nn.init.xavier_normal_(m.bias.unsqueeze(0)) | ||||
|  | ||||
| def init_weights_he(m): | ||||
|     if isinstance(m, (nn.Linear, nn.Conv2d)): | ||||
|         nn.init.kaiming_normal_(m.weight) | ||||
|  | ||||
| def init_bias_he(m): | ||||
|     if isinstance(m, (nn.Linear, nn.Conv2d)): | ||||
|         if m.bias is not None: | ||||
|             # kaiming_normal_ needs a >=2-D tensor; initialize via a 2-D view | ||||
|             nn.init.kaiming_normal_(m.bias.unsqueeze(0)) | ||||
|  | ||||
| def init_weights_zero(m): | ||||
|     if isinstance(m, (nn.Linear, nn.Conv2d)): | ||||
|         m.weight.data.fill_(0.0) | ||||
|  | ||||
| def init_bias_zero(m): | ||||
|     if isinstance(m, (nn.Linear, nn.Conv2d)): | ||||
|         if m.bias is not None: | ||||
|             m.bias.data.fill_(0.0) | ||||
|  | ||||
|      | ||||
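A quick usage sketch for init_net; the toy network below is a placeholder:

    import torch.nn as nn
    from foresight.weight_initializers import init_net

    net = nn.Sequential(
        nn.Conv2d(3, 16, 3, padding=1),
        nn.Flatten(),
        nn.Linear(16 * 32 * 32, 10))

    # Kaiming-normal weights, zero biases
    init_net(net, w_type='kaiming', b_type='zero')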