first commit
commit 8d6225609e

README.md (new file, 33 lines)
@@ -0,0 +1,33 @@
# Sample-Wise Activation Patterns for Ultra-Fast NAS <br/> (ICLR 2024 Spotlight)

Training-free metrics (a.k.a. zero-cost proxies) are widely used to avoid resource-intensive neural network training, especially in Neural Architecture Search (NAS). Recent studies show that existing training-free metrics have several limitations, such as weak correlation with ground-truth performance and poor generalisation across search spaces and tasks. Hence, we propose Sample-Wise Activation Patterns and its derivative, the SWAP-Score, a novel high-performance training-free metric that measures the expressivity of a network over a batch of input samples. The SWAP-Score is strongly correlated with ground-truth performance across various search spaces and tasks, outperforming 15 existing training-free metrics on NAS-Bench-101/201/301 and TransNAS-Bench-101.

# Usage

The following instructions demonstrate how to evaluate a network's performance with the SWAP-Score.

**/src/metrics/swap.py** contains the core components of the SWAP-Score.

**/datasets/DARTS_archs_CIFAR10.csv** contains 1000 architectures (randomly sampled from the DARTS search space) along with their CIFAR-10 validation accuracies (trained for 200 epochs).
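The CSV appears to have no header row (it is read with an explicit `names=` in `correlation.py`): each line is a DARTS genotype string followed by its validation accuracy. A quick way to inspect it, mirroring how the script reads the file:

```
import pandas as pd

archs = pd.read_csv('datasets/DARTS_archs_CIFAR10.csv', names=['genotype', 'valid_acc'])
print(archs.head())
```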
* Install the necessary dependencies (a fresh virtual environment is recommended):

```
cd SWAP
pip install -r requirements.txt
```

* Calculate the correlation between the SWAP-Score and the CIFAR-10 validation accuracies of the 1000 DARTS architectures (a standalone single-network scoring sketch follows these commands):

```
python correlation.py
```
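To score a single architecture directly, the `SWAP` class from `src/metrics/swap.py` can be used on its own. A minimal sketch distilled from `correlation.py`; `genotype_str` is a placeholder for any genotype string from the CSV:

```
import torch
from src.metrics.swap import SWAP
from src.search_space.networks import Network, Genotype  # Genotype is needed by eval()
from src.utils.utilities import network_weight_gaussian_init
from src.datasets.utilities import get_datasets

device = torch.device('cpu')
train_data, _, _ = get_datasets('cifar10', 'datasets', (16, 3, 32, 32), -1)
inputs, _ = next(iter(torch.utils.data.DataLoader(train_data, batch_size=16)))

network = Network(3, 10, 1, eval(genotype_str)).to(device)
network.apply(network_weight_gaussian_init)

swap = SWAP(model=network, inputs=inputs, device=device, seed=0)
print(swap.forward())  # a higher SWAP-Score indicates higher expressivity
```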
If you use or build on our code, please consider citing our paper:

```
@inproceedings{peng2024swapnas,
  title={{SWAP}-{NAS}: Sample-Wise Activation Patterns for Ultra-fast {NAS}},
  author={Yameng Peng and Andy Song and Haytham M. Fayek and Vic Ciesielski and Xiaojun Chang},
  booktitle={The Twelfth International Conference on Learning Representations},
  year={2024},
  url={https://openreview.net/forum?id=tveiUXU2aa}
}
```
correlation.py (new file, 66 lines)
@@ -0,0 +1,66 @@
import os
os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'
import argparse
import torch
import torch.nn as nn
import numpy as np
import pandas as pd
from scipy import stats
from src.utils.utilities import *
from src.metrics.swap import SWAP
from src.datasets.utilities import get_datasets
from src.search_space.networks import *

# Settings for console outputs
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)
warnings.simplefilter(action='ignore', category=UserWarning)

parser = argparse.ArgumentParser()

# general settings
parser.add_argument('--data_path', default="datasets", type=str, nargs='?', help='path to the image dataset (datasets or datasets/ILSVRC/Data/CLS-LOC)')
parser.add_argument('--seed', default=0, type=int, help='random seed')
parser.add_argument('--device', default="mps", type=str, nargs='?', help='device to use (cpu, mps or cuda)')
parser.add_argument('--repeats', default=32, type=int, nargs='?', help='number of times the training-free metric is computed')
parser.add_argument('--input_samples', default=16, type=int, nargs='?', help='input batch size for the training-free metric')

args = parser.parse_args()

if __name__ == "__main__":
    device = torch.device(args.device)

    # 1000 DARTS genotypes with their CIFAR-10 validation accuracies
    arch_info = pd.read_csv(args.data_path + '/DARTS_archs_CIFAR10.csv', names=['genotype', 'valid_acc'], sep=',')

    # a single fixed batch of CIFAR-10 images, shared by all evaluations
    train_data, _, _ = get_datasets('cifar10', args.data_path, (args.input_samples, 3, 32, 32), -1)
    train_loader = torch.utils.data.DataLoader(train_data, batch_size=args.input_samples, num_workers=0, pin_memory=True)
    loader = iter(train_loader)
    inputs, _ = next(loader)

    results = []

    for index, i in arch_info.iterrows():
        print(f'Evaluating network: {index}')

        network = Network(3, 10, 1, eval(i.genotype))
        network = network.to(device)

        swap = SWAP(model=network, inputs=inputs, device=device, seed=args.seed)

        # average the SWAP-Score over several random weight re-initialisations
        swap_score = []
        for _ in range(args.repeats):
            network = network.apply(network_weight_gaussian_init)
            swap.reinit()
            swap_score.append(swap.forward())
            swap.clear()

        results.append([np.mean(swap_score), i.valid_acc])

    results = pd.DataFrame(results, columns=['swap_score', 'valid_acc'])
    print()
    print(f"Spearman's Correlation Coefficient: {stats.spearmanr(results.swap_score, results.valid_acc)[0]}")
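All evaluation settings are exposed as command-line flags, so the default MPS device can be overridden, e.g. (flags exactly as defined in the script above):

```
python correlation.py --device cuda --repeats 32 --input_samples 16
```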
datasets/DARTS_archs_CIFAR10.csv (new file, 1000 lines)
(file diff suppressed because it is too large)
requirements.txt (new file, 5 lines)
@@ -0,0 +1,5 @@
numpy>=1.24.2
pandas>=1.5.3
scipy>=1.10.0
torch>=2.0.1
torchvision>=0.15.2
src/__init__.py (new file, 0 lines)
src/datasets/DownsampledImageNet.py (new file, 109 lines)
@@ -0,0 +1,109 @@
import os, hashlib
import numpy as np
from PIL import Image
import torch.utils.data as data
import pickle  # Python 2 fallback removed: torch>=2.0.1 requires Python 3


def calculate_md5(fpath, chunk_size=1024 * 1024):
    md5 = hashlib.md5()
    with open(fpath, 'rb') as f:
        for chunk in iter(lambda: f.read(chunk_size), b''):
            md5.update(chunk)
    return md5.hexdigest()


def check_md5(fpath, md5, **kwargs):
    return md5 == calculate_md5(fpath, **kwargs)


def check_integrity(fpath, md5=None):
    if not os.path.isfile(fpath): return False
    if md5 is None: return True
    else: return check_md5(fpath, md5)


class ImageNet16(data.Dataset):
    # http://image-net.org/download-images
    # A Downsampled Variant of ImageNet as an Alternative to the CIFAR datasets
    # https://arxiv.org/pdf/1707.08819.pdf

    train_list = [
        ['train_data_batch_1', '27846dcaa50de8e21a7d1a35f30f0e91'],
        ['train_data_batch_2', 'c7254a054e0e795c69120a5727050e3f'],
        ['train_data_batch_3', '4333d3df2e5ffb114b05d2ffc19b1e87'],
        ['train_data_batch_4', '1620cdf193304f4a92677b695d70d10f'],
        ['train_data_batch_5', '348b3c2fdbb3940c4e9e834affd3b18d'],
        ['train_data_batch_6', '6e765307c242a1b3d7d5ef9139b48945'],
        ['train_data_batch_7', '564926d8cbf8fc4818ba23d2faac7564'],
        ['train_data_batch_8', 'f4755871f718ccb653440b9dd0ebac66'],
        ['train_data_batch_9', 'bb6dd660c38c58552125b1a92f86b5d4'],
        ['train_data_batch_10', '8f03f34ac4b42271a294f91bf480f29b'],
    ]
    valid_list = [
        ['val_data', '3410e3017fdaefba8d5073aaa65e4bd6'],
    ]

    def __init__(self, root, train, transform, use_num_of_class_only=None):
        self.root = root
        self.transform = transform
        self.train = train  # training set or valid set
        if not self._check_integrity(): raise RuntimeError('Dataset not found or corrupted.')

        if self.train: downloaded_list = self.train_list
        else: downloaded_list = self.valid_list
        self.data = []
        self.targets = []

        # now load the pickled numpy arrays
        for i, (file_name, checksum) in enumerate(downloaded_list):
            file_path = os.path.join(self.root, file_name)
            with open(file_path, 'rb') as f:
                entry = pickle.load(f, encoding='latin1')
                self.data.append(entry['data'])
                self.targets.extend(entry['labels'])
        self.data = np.vstack(self.data).reshape(-1, 3, 16, 16)
        self.data = self.data.transpose((0, 2, 3, 1))  # convert to HWC
        if use_num_of_class_only is not None:
            assert isinstance(use_num_of_class_only, int) and 0 < use_num_of_class_only < 1000, 'invalid use_num_of_class_only : {:}'.format(use_num_of_class_only)
            # keep only the samples whose label falls within the first N classes
            new_data, new_targets = [], []
            for I, L in zip(self.data, self.targets):
                if 1 <= L <= use_num_of_class_only:
                    new_data.append(I)
                    new_targets.append(L)
            self.data = new_data
            self.targets = new_targets

    def __getitem__(self, index):
        img, target = self.data[index], self.targets[index] - 1
        img = Image.fromarray(img)
        if self.transform is not None:
            img = self.transform(img)
        return img, target

    def __len__(self):
        return len(self.data)

    def _check_integrity(self):
        root = self.root
        for fentry in (self.train_list + self.valid_list):
            filename, md5 = fentry[0], fentry[1]
            fpath = os.path.join(root, filename)
            if not check_integrity(fpath, md5):
                return False
        return True


if __name__ == '__main__':
    pass
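The loader above is normally consumed through `get_datasets` (next file), but it can also be constructed directly. A sketch, assuming the 16x16 ImageNet batch files have been downloaded to `datasets/ImageNet16/`:

```
import torchvision.transforms as transforms
from src.datasets.DownsampledImageNet import ImageNet16

train_data = ImageNet16('datasets/ImageNet16', True, transforms.ToTensor(), 120)  # ImageNet16-120
print(len(train_data))  # 151700 training images, per the asserts in get_datasets
```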
src/datasets/__init__.py (new file, 0 lines)
src/datasets/utilities.py (new file, 115 lines)
@@ -0,0 +1,115 @@
import os.path as osp
import numpy as np
import torch
import torchvision.transforms as transforms
import torchvision.datasets as dset
from .DownsampledImageNet import ImageNet16


Dataset2Class = {'cifar10': 10,
                 'cifar100': 100,
                 'imagenet-1k-s': 1000,
                 'imagenet-1k': 1000,
                 'ImageNet16': 1000,
                 'ImageNet16-120': 120,
                 'ImageNet16-150': 150,
                 'ImageNet16-200': 200}


class CUTOUT(object):
    # CUTOUT was referenced below but not defined in this file; this is the
    # standard Cutout transform (zero out a random square patch), added so
    # that cutout > 0 works.
    def __init__(self, length):
        self.length = length

    def __repr__(self):
        return ('{name}(length={length})'.format(name=self.__class__.__name__, **self.__dict__))

    def __call__(self, img):
        h, w = img.size(1), img.size(2)
        mask = np.ones((h, w), np.float32)
        y = np.random.randint(h)
        x = np.random.randint(w)
        y1 = np.clip(y - self.length // 2, 0, h)
        y2 = np.clip(y + self.length // 2, 0, h)
        x1 = np.clip(x - self.length // 2, 0, w)
        x2 = np.clip(x + self.length // 2, 0, w)
        mask[y1:y2, x1:x2] = 0.
        mask = torch.from_numpy(mask)
        mask = mask.expand_as(img)
        img *= mask
        return img


class RandChannel(object):
    # randomly pick channels from input
    def __init__(self, num_channel):
        self.num_channel = num_channel

    def __repr__(self):
        return ('{name}(num_channel={num_channel})'.format(name=self.__class__.__name__, **self.__dict__))

    def __call__(self, img):
        channel = img.size(0)
        channel_choice = sorted(np.random.choice(list(range(channel)), size=self.num_channel, replace=False))
        return torch.index_select(img, 0, torch.Tensor(channel_choice).long())


def get_datasets(name, root, input_size, cutout=-1):
    assert len(input_size) in [3, 4]
    if len(input_size) == 4:
        input_size = input_size[1:]
    assert input_size[1] == input_size[2]

    if name == 'cifar10':
        mean = [0.49139968, 0.48215827, 0.44653124]
        std = [0.24703233, 0.24348505, 0.26158768]
    elif name == 'cifar100':
        mean = [0.5071, 0.4865, 0.4409]
        std = [0.2673, 0.2564, 0.2762]
    elif name.startswith('imagenet-1k'):
        mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
    elif name.startswith('ImageNet16'):
        mean = [0.481098, 0.45749, 0.407882]
        std = [0.247922, 0.240235, 0.255255]
    else:
        raise TypeError("Unknown dataset : {:}".format(name))

    # Data augmentation (the CIFAR and ImageNet16 branches were identical and are merged)
    if name in ('cifar10', 'cifar100') or name.startswith('ImageNet16'):
        lists = [transforms.RandomCrop(input_size[1], padding=0), transforms.ToTensor(), transforms.Normalize(mean, std), RandChannel(input_size[0])]
        if cutout > 0: lists += [CUTOUT(cutout)]
        train_transform = transforms.Compose(lists)
        test_transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean, std)])
    elif name.startswith('imagenet-1k'):
        normalize = transforms.Normalize(mean=mean, std=std)
        if name == 'imagenet-1k':
            xlists = [transforms.Resize((input_size[1], input_size[1]), interpolation=2),
                      transforms.RandomCrop(input_size[1], padding=0)]
        elif name == 'imagenet-1k-s':
            # (a stray `xlists = []` that overwrote this list has been removed)
            xlists = [transforms.RandomResizedCrop(input_size[1], scale=(0.2, 1.0))]
        else:
            raise ValueError('invalid name : {:}'.format(name))
        xlists.append(transforms.ToTensor())
        xlists.append(normalize)
        xlists.append(RandChannel(input_size[0]))
        train_transform = transforms.Compose(xlists)
        test_transform = transforms.Compose([transforms.Resize(input_size[1]), transforms.CenterCrop(input_size[1]), transforms.ToTensor(), normalize])
    else:
        raise TypeError("Unknown dataset : {:}".format(name))

    if name == 'cifar10':
        train_data = dset.CIFAR10(root, train=True, transform=train_transform, download=True)
        test_data = dset.CIFAR10(root, train=False, transform=test_transform, download=True)
        assert len(train_data) == 50000 and len(test_data) == 10000
    elif name == 'cifar100':
        train_data = dset.CIFAR100(root, train=True, transform=train_transform, download=True)
        test_data = dset.CIFAR100(root, train=False, transform=test_transform, download=True)
        assert len(train_data) == 50000 and len(test_data) == 10000
    elif name.startswith('imagenet-1k'):
        train_data = dset.ImageFolder(osp.join(root, 'train'), train_transform)
        test_data = dset.ImageFolder(osp.join(root, 'val'), test_transform)
    elif name == 'ImageNet16':
        root = osp.join(root, 'ImageNet16')
        train_data = ImageNet16(root, True, train_transform)
        test_data = ImageNet16(root, False, test_transform)
        assert len(train_data) == 1281167 and len(test_data) == 50000
    elif name == 'ImageNet16-120':
        root = osp.join(root, 'ImageNet16')
        train_data = ImageNet16(root, True, train_transform, 120)
        test_data = ImageNet16(root, False, test_transform, 120)
        assert len(train_data) == 151700 and len(test_data) == 6000
    elif name == 'ImageNet16-150':
        root = osp.join(root, 'ImageNet16')
        train_data = ImageNet16(root, True, train_transform, 150)
        test_data = ImageNet16(root, False, test_transform, 150)
        assert len(train_data) == 190272 and len(test_data) == 7500
    elif name == 'ImageNet16-200':
        root = osp.join(root, 'ImageNet16')
        train_data = ImageNet16(root, True, train_transform, 200)
        test_data = ImageNet16(root, False, test_transform, 200)
        assert len(train_data) == 254775 and len(test_data) == 10000
    else:
        raise TypeError("Unknown dataset : {:}".format(name))

    class_num = Dataset2Class[name]
    return train_data, test_data, class_num
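`correlation.py` uses this helper to build the scoring batch; a minimal sketch of the call, with the same arguments as that script (CIFAR-10 is downloaded automatically if absent):

```
from src.datasets.utilities import get_datasets

# (batch, channels, height, width) input shape; cutout disabled (-1)
train_data, test_data, class_num = get_datasets('cifar10', 'datasets', (16, 3, 32, 32), -1)
print(class_num)  # 10
```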
src/metrics/__init__.py (new file, 0 lines)
src/metrics/swap.py (new file, 99 lines)
@@ -0,0 +1,99 @@
import numpy as np
import torch
import torch.nn as nn
from src.utils.utilities import count_parameters


def cal_regular_factor(model, mu, sigma):
    # Gaussian penalty on the parameter count, used by the regularised SWAP-Score
    model_params = torch.as_tensor(count_parameters(model))
    regular_factor = torch.exp(-(torch.pow((model_params - mu), 2) / sigma))
    return regular_factor


class SampleWiseActivationPatterns(object):
    def __init__(self, device):
        self.swap = -1
        self.activations = None
        self.device = device

    @torch.no_grad()
    def collect_activations(self, activations):
        # binarise the (post-ReLU, hence non-negative) activations: each entry
        # becomes 1 if the unit fired for that sample, 0 otherwise
        self.activations = torch.sign(activations)

    @torch.no_grad()
    def calSWAP(self, regular_factor):
        self.activations = self.activations.T  # transpose the activation matrix: (samples, neurons) -> (neurons, samples)
        # the SWAP value is the number of distinct sample-wise activation patterns
        self.swap = torch.unique(self.activations, dim=0).size(0)

        del self.activations
        self.activations = None
        torch.cuda.empty_cache()

        return self.swap * regular_factor


class SWAP:
    def __init__(self, model=None, inputs=None, device='cuda', seed=0, regular=False, mu=None, sigma=None):
        self.model = model
        self.interFeature = []
        self.seed = seed
        self.regular_factor = 1
        self.inputs = inputs
        self.device = device

        if regular and mu is not None and sigma is not None:
            self.regular_factor = cal_regular_factor(self.model, mu, sigma).item()

        self.reinit(self.model, self.seed)

    def reinit(self, model=None, seed=None):
        if model is not None:
            self.model = model
            self.register_hook(self.model)
            self.swap = SampleWiseActivationPatterns(self.device)

        if seed is not None and seed != self.seed:
            self.seed = seed
            torch.manual_seed(seed)
            torch.cuda.manual_seed(seed)
            del self.interFeature
            self.interFeature = []
            torch.cuda.empty_cache()

    def clear(self):
        self.swap = SampleWiseActivationPatterns(self.device)
        del self.interFeature
        self.interFeature = []
        torch.cuda.empty_cache()

    def register_hook(self, model):
        # record the output of every ReLU during the forward pass
        for n, m in model.named_modules():
            if isinstance(m, nn.ReLU):
                m.register_forward_hook(hook=self.hook_in_forward)

    def hook_in_forward(self, module, input, output):
        if isinstance(input, tuple) and len(input[0].size()) == 4:
            self.interFeature.append(output.detach())

    def forward(self):
        self.interFeature = []
        with torch.no_grad():
            self.model.forward(self.inputs.to(self.device))
            if len(self.interFeature) == 0: return
            # flatten every hooked feature map to (batch, -1) and concatenate
            activations = torch.cat([f.view(self.inputs.size(0), -1) for f in self.interFeature], 1)
            self.swap.collect_activations(activations)

            return self.swap.calSWAP(self.regular_factor)
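To make the quantity computed by `calSWAP` concrete, here is a toy, self-contained illustration (not part of the repo API): each ReLU unit is reduced to a binary fired/not-fired code per sample, and the score counts the distinct per-unit codes.

```
import torch

acts = torch.tensor([[0.3, 0.0, 1.2],
                     [0.0, 0.5, 1.7]])        # (samples=2, neurons=3), post-ReLU
patterns = torch.sign(acts).T                 # (neurons, samples), entries in {0., 1.}
print(torch.unique(patterns, dim=0).size(0))  # 3 distinct sample-wise patterns
```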
src/search_space/__init__.py (new file, 0 lines)
src/search_space/networks.py (new file, 105 lines)
@@ -0,0 +1,105 @@
from .operations import *
import torch
import torch.nn as nn
from collections import namedtuple

Genotype = namedtuple('Genotype', 'normal normal_concat reduce reduce_concat')


def drop_path(x, drop_prob):
    # simplified drop-path: element-wise dropout rather than whole-path zeroing
    if drop_prob > 0.:
        x = nn.functional.dropout(x, p=drop_prob)
    return x


class Cell(nn.Module):

    def __init__(self, genotype, C_prev_prev, C_prev, C, reduction, reduction_prev):
        super(Cell, self).__init__()

        if reduction_prev:
            self.preprocess0 = FactorizedReduce(C_prev_prev, C)
        else:
            self.preprocess0 = ReLUConvBN(C_prev_prev, C, 1, 1, 0, 1, True)
        self.preprocess1 = ReLUConvBN(C_prev, C, 1, 1, 0, 1, True)

        if reduction:
            op_names, indices = zip(*genotype.reduce)
            concat = genotype.reduce_concat  # 2,3,4,5
        else:
            op_names, indices = zip(*genotype.normal)
            concat = genotype.normal_concat  # 2,3,4,5
        self._compile(C, op_names, indices, concat, reduction)

    def _compile(self, C, op_names, indices, concat, reduction):
        assert len(op_names) == len(indices)
        self._steps = len(op_names) // 2  # 4
        self._concat = concat  # 2,3,4,5
        self.multiplier = len(concat)  # 4
        self._ops = nn.ModuleList()

        for name, index in zip(op_names, indices):
            stride = 2 if reduction and index < 2 else 1
            op = OPS[name](C, C, stride, True)
            self._ops += [op]
        self._indices = indices

    def forward(self, s0, s1, drop_prob):
        s0 = self.preprocess0(s0)
        s1 = self.preprocess1(s1)

        states = [s0, s1]
        for i in range(self._steps):
            h1 = states[self._indices[2 * i]]
            h2 = states[self._indices[2 * i + 1]]
            op1 = self._ops[2 * i]
            op2 = self._ops[2 * i + 1]
            h1 = op1(h1)
            h2 = op2(h2)
            if self.training and drop_prob > 0.:
                if not isinstance(op1, Identity):
                    h1 = drop_path(h1, drop_prob)
                if not isinstance(op2, Identity):
                    h2 = drop_path(h2, drop_prob)
            s = h1 + h2
            states += [s]
        return torch.cat([states[i] for i in self._concat], dim=1)


class Network(nn.Module):

    def __init__(self, C, num_classes, layers, genotype):
        self.drop_path_prob = 0.
        super(Network, self).__init__()

        self._layers = layers

        C_prev_prev, C_prev, C_curr = C, C, C

        self.cells = nn.ModuleList()
        reduction_prev = False

        for i in range(layers):
            if i in [layers // 3, 2 * layers // 3]:
                C_curr *= 2
                reduction = True
            else:
                reduction = False
            cell = Cell(genotype, C_prev_prev, C_prev, C_curr, reduction, reduction_prev)
            reduction_prev = reduction
            self.cells += [cell]
            C_prev_prev, C_prev = C_prev, cell.multiplier * C_curr

        self.global_pooling = nn.AdaptiveAvgPool2d(1)
        self.classifier = nn.Linear(C_prev, num_classes)

    def forward(self, input):
        s0 = s1 = input

        for i, cell in enumerate(self.cells):
            s0, s1 = s1, cell(s0, s1, self.drop_path_prob)

        out = self.global_pooling(s1)
        out = out.view(out.size(0), -1)
        logits = self.classifier(out)
        return logits  # was `return out`; the SWAP metric ignores this value and hooks the ReLUs directly
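For illustration, a network can be instantiated from a hand-written genotype. The 8 pairs below are an arbitrary toy example, not an architecture from the paper; real genotypes come from `datasets/DARTS_archs_CIFAR10.csv`:

```
import torch
from src.search_space.networks import Network, Genotype

toy_geno = Genotype(
    normal=[('sep_conv_3x3', 0), ('sep_conv_3x3', 1)] * 4,
    normal_concat=[2, 3, 4, 5],
    reduce=[('max_pool_3x3', 0), ('max_pool_3x3', 1),
            ('max_pool_3x3', 0), ('skip_connect', 2),
            ('max_pool_3x3', 0), ('skip_connect', 2),
            ('max_pool_3x3', 0), ('skip_connect', 2)],
    reduce_concat=[2, 3, 4, 5])

net = Network(3, 10, 1, toy_geno)  # C=3, 10 classes, 1 cell, as in correlation.py
print(net(torch.randn(2, 3, 32, 32)).shape)  # torch.Size([2, 10])
```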
src/search_space/operations.py (new file, 147 lines)
@@ -0,0 +1,147 @@
import torch
import torch.nn as nn

OPS = {
    'none': lambda C_in, C_out, stride, affine: Zero(C_in, C_out, stride),
    'avg_pool_3x3': lambda C_in, C_out, stride, affine: POOLING(C_in, C_out, stride, 'avg', affine),
    'max_pool_3x3': lambda C_in, C_out, stride, affine: POOLING(C_in, C_out, stride, 'max', affine),
    'skip_connect': lambda C_in, C_out, stride, affine: Identity() if stride == 1 and C_in == C_out else FactorizedReduce(C_in, C_out, stride, affine),
    'sep_conv_3x3': lambda C_in, C_out, stride, affine: SepConv(C_in, C_out, 3, stride, 1, affine),
    'sep_conv_5x5': lambda C_in, C_out, stride, affine: SepConv(C_in, C_out, 5, stride, 2, affine),
    'dil_conv_3x3': lambda C_in, C_out, stride, affine: DilConv(C_in, C_out, 3, stride, 2, 2, affine),
    'dil_conv_5x5': lambda C_in, C_out, stride, affine: DilConv(C_in, C_out, 5, stride, 4, 2, affine),
}


class ReLUConvBN(nn.Module):

    def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation, affine, track_running_stats=True):
        super(ReLUConvBN, self).__init__()
        self.op = nn.Sequential(
            nn.ReLU(inplace=False),
            nn.Conv2d(C_in, C_out, kernel_size, stride=stride, padding=padding, dilation=dilation, bias=False),
            nn.BatchNorm2d(C_out, affine=affine, track_running_stats=track_running_stats)
        )

    def forward(self, x):
        return self.op(x)


class DilConv(nn.Module):
    # dilated depthwise-separable convolution

    def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation, affine=True, track_running_stats=True):
        super(DilConv, self).__init__()
        self.op = nn.Sequential(
            nn.ReLU(inplace=False),
            nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation,
                      groups=C_in, bias=False),
            nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False),
            nn.BatchNorm2d(C_out, affine=affine, track_running_stats=track_running_stats),
        )

    def forward(self, x):
        return self.op(x)


class SepConv(nn.Module):
    # depthwise-separable convolution, applied twice

    def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True, track_running_stats=True):
        super(SepConv, self).__init__()
        self.op = nn.Sequential(
            nn.ReLU(inplace=False),
            nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=stride, padding=padding, groups=C_in, bias=False),
            nn.Conv2d(C_in, C_in, kernel_size=1, padding=0, bias=False),
            nn.BatchNorm2d(C_in, affine=affine, track_running_stats=track_running_stats),
            nn.ReLU(inplace=False),
            nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=1, padding=padding, groups=C_in, bias=False),
            nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False),
            nn.BatchNorm2d(C_out, affine=affine, track_running_stats=track_running_stats),
        )

    def forward(self, x):
        return self.op(x)


class Identity(nn.Module):

    def __init__(self):
        super(Identity, self).__init__()

    def forward(self, x):
        return x


class FactorizedReduce(nn.Module):
    # halves the spatial resolution with two offset 1x1 convolutions

    def __init__(self, C_in, C_out, stride=2, affine=True, track_running_stats=True):
        super(FactorizedReduce, self).__init__()
        self.stride = stride
        self.C_in = C_in
        self.C_out = C_out
        self.relu = nn.ReLU(inplace=False)
        if stride == 2:
            C_outs = [C_out // 2, C_out - C_out // 2]
            self.convs = nn.ModuleList()
            for i in range(2):
                self.convs.append(nn.Conv2d(C_in, C_outs[i], 1, stride=stride, padding=0, bias=False))
            self.pad = nn.ConstantPad2d((0, 1, 0, 1), 0)
        elif stride == 1:
            self.conv = nn.Conv2d(C_in, C_out, 1, stride=stride, padding=0, bias=False)
        else:
            raise ValueError('Invalid stride : {:}'.format(stride))
        self.bn = nn.BatchNorm2d(C_out, affine=affine, track_running_stats=track_running_stats)

    def forward(self, x):
        if self.stride == 2:
            x = self.relu(x)
            y = self.pad(x)
            out = torch.cat([self.convs[0](x), self.convs[1](y[:, :, 1:, 1:])], dim=1)
        else:
            out = self.conv(x)
        out = self.bn(out)
        return out

    def extra_repr(self):
        return 'C_in={C_in}, C_out={C_out}, stride={stride}'.format(**self.__dict__)


class Zero(nn.Module):

    def __init__(self, C_in, C_out, stride):
        super(Zero, self).__init__()
        self.C_in = C_in
        self.C_out = C_out
        self.stride = stride
        self.is_zero = True

    def forward(self, x):
        if self.C_in == self.C_out:
            if self.stride == 1: return x.mul(0.)
            else: return x[:, :, ::self.stride, ::self.stride].mul(0.)
        else:
            shape = list(x.shape)
            shape[1] = self.C_out
            zeros = x.new_zeros(shape, dtype=x.dtype, device=x.device)
            return zeros

    def extra_repr(self):
        return 'C_in={C_in}, C_out={C_out}, stride={stride}'.format(**self.__dict__)


class POOLING(nn.Module):

    def __init__(self, C_in, C_out, stride, mode, affine=True, track_running_stats=True):
        super(POOLING, self).__init__()
        if C_in == C_out:
            self.preprocess = None
        else:
            self.preprocess = ReLUConvBN(C_in, C_out, 1, 1, 0, 1, affine, track_running_stats)
        if mode == 'avg': self.op = nn.AvgPool2d(3, stride=stride, padding=1, count_include_pad=False)
        elif mode == 'max': self.op = nn.MaxPool2d(3, stride=stride, padding=1)
        else: raise ValueError('Invalid mode={:} in POOLING'.format(mode))

    def forward(self, inputs):
        if self.preprocess: x = self.preprocess(inputs)
        else: x = inputs
        return self.op(x)
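Each entry in `OPS` shares the signature `(C_in, C_out, stride, affine)`, so candidate operations can be built uniformly. A quick sketch:

```
import torch
from src.search_space.operations import OPS

op = OPS['sep_conv_3x3'](16, 16, 1, True)  # C_in=16, C_out=16, stride=1, affine=True
x = torch.randn(2, 16, 32, 32)
print(op(x).shape)  # torch.Size([2, 16, 32, 32])
```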
src/utils/__init__.py (new file, 0 lines)
src/utils/utilities.py (new file, 38 lines)
@@ -0,0 +1,38 @@
import numpy as np
import torch
import torch.nn as nn


class Model(object):
    def __init__(self):
        self.arch = None
        self.geno = None
        self.score = None


def count_parameters(model):
    # parameter count in thousands, excluding auxiliary heads
    return np.sum([np.prod(v.size()) for name, v in model.named_parameters() if "auxiliary" not in name]) / 1e3


def network_weight_gaussian_init(net: nn.Module):
    with torch.no_grad():
        for m in net.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.normal_(m.weight)
                if hasattr(m, 'bias') and m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight)
                if hasattr(m, 'bias') and m.bias is not None:
                    nn.init.zeros_(m.bias)
            else:
                continue
    return net


if __name__ == '__main__':
    pass
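`network_weight_gaussian_init` re-draws every weight before each SWAP evaluation, and `count_parameters` feeds `cal_regular_factor` in `src/metrics/swap.py`. A small sketch on an arbitrary `nn.Module` (a bare convolution, chosen purely for brevity):

```
import torch.nn as nn
from src.utils.utilities import count_parameters, network_weight_gaussian_init

net = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU())
net = network_weight_gaussian_init(net)  # weights ~ N(0, 1), biases zeroed
print(count_parameters(net))             # 0.224 (trainable parameters, in thousands)
```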