This commit is contained in:
HamsterMimi
2024-01-23 10:08:45 +08:00
parent 1a57decf65
commit 3f6d16e791
92 changed files with 12855 additions and 41 deletions

2
.idea/MeCo.iml generated
View File

@@ -2,7 +2,7 @@
<module type="PYTHON_MODULE" version="4">
<component name="NewModuleRootManager">
<content url="file://$MODULE_DIR$" />
<orderEntry type="inheritedJdk" />
<orderEntry type="jdk" jdkName="Remote Python 3.8.16 (sftp://jty@172.16.214.99:7712/jty/anaconda3/envs/zero-cost-nas/bin/python3.8)" jdkType="Python SDK" />
<orderEntry type="sourceFolder" forTests="false" />
</component>
</module>

289
.idea/deployment.xml generated
View File

@@ -1,120 +1,372 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="PublishConfigData" remoteFilesAllowedToDisappearOnAutoupload="false">
<component name="PublishConfigData" autoUpload="Always" serverName="jty@172.16.214.99:7712 password (21)" remoteFilesAllowedToDisappearOnAutoupload="false">
<serverData>
<paths name="ubuntu@172.16.214.100:7712 password">
<paths name="jty@172.16.185.253:22 password">
<serverdata>
<mappings>
<mapping local="$PROJECT_DIR$" web="/" />
</mappings>
</serverdata>
</paths>
<paths name="ubuntu@172.16.214.100:7712 password (10)">
<paths name="jty@172.16.185.253:22 password (10)">
<serverdata>
<mappings>
<mapping local="$PROJECT_DIR$" web="/" />
</mappings>
</serverdata>
</paths>
<paths name="ubuntu@172.16.214.100:7712 password (11)">
<paths name="jty@172.16.185.253:22 password (11)">
<serverdata>
<mappings>
<mapping local="$PROJECT_DIR$" web="/" />
</mappings>
</serverdata>
</paths>
<paths name="ubuntu@172.16.214.100:7712 password (2)">
<paths name="jty@172.16.185.253:22 password (2)">
<serverdata>
<mappings>
<mapping local="$PROJECT_DIR$" web="/" />
</mappings>
</serverdata>
</paths>
<paths name="ubuntu@172.16.214.100:7712 password (3)">
<paths name="jty@172.16.185.253:22 password (3)">
<serverdata>
<mappings>
<mapping local="$PROJECT_DIR$" web="/" />
</mappings>
</serverdata>
</paths>
<paths name="ubuntu@172.16.214.100:7712 password (4)">
<paths name="jty@172.16.185.253:22 password (4)">
<serverdata>
<mappings>
<mapping local="$PROJECT_DIR$" web="/" />
</mappings>
</serverdata>
</paths>
<paths name="ubuntu@172.16.214.100:7712 password (5)">
<paths name="jty@172.16.185.253:22 password (5)">
<serverdata>
<mappings>
<mapping local="$PROJECT_DIR$" web="/" />
</mappings>
</serverdata>
</paths>
<paths name="ubuntu@172.16.214.100:7712 password (6)">
<paths name="jty@172.16.185.253:22 password (6)">
<serverdata>
<mappings>
<mapping local="$PROJECT_DIR$" web="/" />
</mappings>
</serverdata>
</paths>
<paths name="ubuntu@172.16.214.100:7712 password (7)">
<paths name="jty@172.16.185.253:22 password (7)">
<serverdata>
<mappings>
<mapping local="$PROJECT_DIR$" web="/" />
</mappings>
</serverdata>
</paths>
<paths name="ubuntu@172.16.214.100:7712 password (8)">
<paths name="jty@172.16.185.253:22 password (8)">
<serverdata>
<mappings>
<mapping local="$PROJECT_DIR$" web="/" />
</mappings>
</serverdata>
</paths>
<paths name="ubuntu@172.16.214.100:7712 password (9)">
<paths name="jty@172.16.185.253:22 password (9)">
<serverdata>
<mappings>
<mapping local="$PROJECT_DIR$" web="/" />
</mappings>
</serverdata>
</paths>
<paths name="ubuntu@172.16.214.99:7712 password">
<paths name="jty@172.16.214.100:7712 password">
<serverdata>
<mappings>
<mapping local="$PROJECT_DIR$" web="/" />
</mappings>
</serverdata>
</paths>
<paths name="ubuntu@172.16.214.99:7712 password (2)">
<paths name="jty@172.16.214.100:7712 password (10)">
<serverdata>
<mappings>
<mapping local="$PROJECT_DIR$" web="/" />
</mappings>
</serverdata>
</paths>
<paths name="ubuntu@172.16.214.99:7712 password (3)">
<paths name="jty@172.16.214.100:7712 password (11)">
<serverdata>
<mappings>
<mapping local="$PROJECT_DIR$" web="/" />
</mappings>
</serverdata>
</paths>
<paths name="ubuntu@172.16.214.99:7712 password (4)">
<paths name="jty@172.16.214.100:7712 password (12)">
<serverdata>
<mappings>
<mapping local="$PROJECT_DIR$" web="/" />
</mappings>
</serverdata>
</paths>
<paths name="ubuntu@172.16.214.99:7712 password (5)">
<paths name="jty@172.16.214.100:7712 password (13)">
<serverdata>
<mappings>
<mapping local="$PROJECT_DIR$" web="/" />
</mappings>
</serverdata>
</paths>
<paths name="ubuntu@172.16.214.99:7712 password (6)">
<paths name="jty@172.16.214.100:7712 password (14)">
<serverdata>
<mappings>
<mapping local="$PROJECT_DIR$" web="/" />
</mappings>
</serverdata>
</paths>
<paths name="jty@172.16.214.100:7712 password (15)">
<serverdata>
<mappings>
<mapping local="$PROJECT_DIR$" web="/" />
</mappings>
</serverdata>
</paths>
<paths name="jty@172.16.214.100:7712 password (16)">
<serverdata>
<mappings>
<mapping local="$PROJECT_DIR$" web="/" />
</mappings>
</serverdata>
</paths>
<paths name="jty@172.16.214.100:7712 password (17)">
<serverdata>
<mappings>
<mapping local="$PROJECT_DIR$" web="/" />
</mappings>
</serverdata>
</paths>
<paths name="jty@172.16.214.100:7712 password (18)">
<serverdata>
<mappings>
<mapping local="$PROJECT_DIR$" web="/" />
</mappings>
</serverdata>
</paths>
<paths name="jty@172.16.214.100:7712 password (19)">
<serverdata>
<mappings>
<mapping local="$PROJECT_DIR$" web="/" />
</mappings>
</serverdata>
</paths>
<paths name="jty@172.16.214.100:7712 password (2)">
<serverdata>
<mappings>
<mapping local="$PROJECT_DIR$" web="/" />
</mappings>
</serverdata>
</paths>
<paths name="jty@172.16.214.100:7712 password (20)">
<serverdata>
<mappings>
<mapping local="$PROJECT_DIR$" web="/" />
</mappings>
</serverdata>
</paths>
<paths name="jty@172.16.214.100:7712 password (21)">
<serverdata>
<mappings>
<mapping local="$PROJECT_DIR$" web="/" />
</mappings>
</serverdata>
</paths>
<paths name="jty@172.16.214.100:7712 password (3)">
<serverdata>
<mappings>
<mapping local="$PROJECT_DIR$" web="/" />
</mappings>
</serverdata>
</paths>
<paths name="jty@172.16.214.100:7712 password (4)">
<serverdata>
<mappings>
<mapping local="$PROJECT_DIR$" web="/" />
</mappings>
</serverdata>
</paths>
<paths name="jty@172.16.214.100:7712 password (5)">
<serverdata>
<mappings>
<mapping local="$PROJECT_DIR$" web="/" />
</mappings>
</serverdata>
</paths>
<paths name="jty@172.16.214.100:7712 password (6)">
<serverdata>
<mappings>
<mapping local="$PROJECT_DIR$" web="/" />
</mappings>
</serverdata>
</paths>
<paths name="jty@172.16.214.100:7712 password (7)">
<serverdata>
<mappings>
<mapping local="$PROJECT_DIR$" web="/" />
</mappings>
</serverdata>
</paths>
<paths name="jty@172.16.214.100:7712 password (8)">
<serverdata>
<mappings>
<mapping local="$PROJECT_DIR$" web="/" />
</mappings>
</serverdata>
</paths>
<paths name="jty@172.16.214.100:7712 password (9)">
<serverdata>
<mappings>
<mapping local="$PROJECT_DIR$" web="/" />
</mappings>
</serverdata>
</paths>
<paths name="jty@172.16.214.99:7712 password">
<serverdata>
<mappings>
<mapping local="$PROJECT_DIR$" web="/" />
</mappings>
</serverdata>
</paths>
<paths name="jty@172.16.214.99:7712 password (10)">
<serverdata>
<mappings>
<mapping local="$PROJECT_DIR$" web="/" />
</mappings>
</serverdata>
</paths>
<paths name="jty@172.16.214.99:7712 password (11)">
<serverdata>
<mappings>
<mapping local="$PROJECT_DIR$" web="/" />
</mappings>
</serverdata>
</paths>
<paths name="jty@172.16.214.99:7712 password (12)">
<serverdata>
<mappings>
<mapping local="$PROJECT_DIR$" web="/" />
</mappings>
</serverdata>
</paths>
<paths name="jty@172.16.214.99:7712 password (13)">
<serverdata>
<mappings>
<mapping local="$PROJECT_DIR$" web="/" />
</mappings>
</serverdata>
</paths>
<paths name="jty@172.16.214.99:7712 password (14)">
<serverdata>
<mappings>
<mapping local="$PROJECT_DIR$" web="/" />
</mappings>
</serverdata>
</paths>
<paths name="jty@172.16.214.99:7712 password (15)">
<serverdata>
<mappings>
<mapping local="$PROJECT_DIR$" web="/" />
</mappings>
</serverdata>
</paths>
<paths name="jty@172.16.214.99:7712 password (16)">
<serverdata>
<mappings>
<mapping local="$PROJECT_DIR$" web="/" />
</mappings>
</serverdata>
</paths>
<paths name="jty@172.16.214.99:7712 password (17)">
<serverdata>
<mappings>
<mapping local="$PROJECT_DIR$" web="/" />
</mappings>
</serverdata>
</paths>
<paths name="jty@172.16.214.99:7712 password (18)">
<serverdata>
<mappings>
<mapping local="$PROJECT_DIR$" web="/" />
</mappings>
</serverdata>
</paths>
<paths name="jty@172.16.214.99:7712 password (19)">
<serverdata>
<mappings>
<mapping deploy="/tmp/pycharm_project_928" local="$PROJECT_DIR$" />
</mappings>
</serverdata>
</paths>
<paths name="jty@172.16.214.99:7712 password (2)">
<serverdata>
<mappings>
<mapping local="$PROJECT_DIR$" web="/" />
</mappings>
</serverdata>
</paths>
<paths name="jty@172.16.214.99:7712 password (20)">
<serverdata>
<mappings>
<mapping deploy="/jty/jty" local="$PROJECT_DIR$" />
</mappings>
</serverdata>
</paths>
<paths name="jty@172.16.214.99:7712 password (21)">
<serverdata>
<mappings>
<mapping deploy="/jty/jty/meco_submit" local="$PROJECT_DIR$" />
</mappings>
</serverdata>
</paths>
<paths name="jty@172.16.214.99:7712 password (3)">
<serverdata>
<mappings>
<mapping local="$PROJECT_DIR$" web="/" />
</mappings>
</serverdata>
</paths>
<paths name="jty@172.16.214.99:7712 password (4)">
<serverdata>
<mappings>
<mapping local="$PROJECT_DIR$" web="/" />
</mappings>
</serverdata>
</paths>
<paths name="jty@172.16.214.99:7712 password (5)">
<serverdata>
<mappings>
<mapping local="$PROJECT_DIR$" web="/" />
</mappings>
</serverdata>
</paths>
<paths name="jty@172.16.214.99:7712 password (6)">
<serverdata>
<mappings>
<mapping local="$PROJECT_DIR$" web="/" />
</mappings>
</serverdata>
</paths>
<paths name="jty@172.16.214.99:7712 password (7)">
<serverdata>
<mappings>
<mapping local="$PROJECT_DIR$" web="/" />
</mappings>
</serverdata>
</paths>
<paths name="jty@172.16.214.99:7712 password (8)">
<serverdata>
<mappings>
<mapping local="$PROJECT_DIR$" web="/" />
</mappings>
</serverdata>
</paths>
<paths name="jty@172.16.214.99:7712 password (9)">
<serverdata>
<mappings>
<mapping local="$PROJECT_DIR$" web="/" />
@@ -122,5 +374,6 @@
</serverdata>
</paths>
</serverData>
<option name="myAutoUpload" value="ALWAYS" />
</component>
</project>

5
.idea/misc.xml generated
View File

@@ -1,4 +1,7 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="ProjectRootManager" version="2" project-jdk-name="Remote Python 3.8.16 (sftp://ubuntu@172.16.214.100:7712/jty/anaconda3/envs/meco/bin/python3.8)" project-jdk-type="Python SDK" />
<component name="Black">
<option name="sdkName" value="Remote Python 3.8.16 (sftp://jty@172.16.214.99:7712/jty/anaconda3/envs/zero-cost-nas/bin/python3.8)" />
</component>
<component name="ProjectRootManager" version="2" project-jdk-name="Remote Python 3.8.16 (sftp://jty@172.16.214.99:7712/jty/anaconda3/envs/zero-cost-nas/bin/python3.8)" project-jdk-type="Python SDK" />
</project>

View File

@@ -11,24 +11,60 @@ from models import get_cell_based_tiny_net
import pickle
def get_score(net, x, device, measure='meco'):
result_list = []
def forward_hook(module, data_input, data_output):
fea = data_output[0].clone().detach()
n = torch.tensor(fea.shape[0])
fea = fea.reshape(n, -1)
if measure == 'meco':
corr = torch.corrcoef(fea)
corr[torch.isnan(corr)] = 0
corr[torch.isinf(corr)] = 0
values = torch.linalg.eig(corr)[0]
result = torch.min(torch.real(values))
elif measure == 'meco_opt':
idxs = random.sample(range(n), 8)
fea = fea[idxs, :]
corr = torch.corrcoef(fea)
corr[torch.isnan(corr)] = 0
corr[torch.isinf(corr)] = 0
values = torch.linalg.eig(corr)[0]
result = torch.min(torch.real(values)) * n / 8
result_list.append(result)
for name, modules in net.named_modules():
modules.register_forward_hook(forward_hook)
x = x.to(device)
net(x)
results = torch.tensor(result_list)
results = results[torch.logical_not(torch.isnan(results))]
results = results[torch.logical_not(torch.isinf(results))]
res = torch.sum(results)
result_list.clear()
return res.item()
def get_num_classes(args):
return 100 if args.dataset == 'cifar100' else 10 if args.dataset == 'cifar10' else 120
def parse_arguments():
parser = argparse.ArgumentParser(description='Zero-cost Metrics for NAS-Bench-201')
parser.add_argument('--api_loc', default='../data/NAS-Bench-201-v1_0-e61699.pth',
type=str, help='path to API')
# parser.add_argument('--api_loc', default='../data/NAS-Bench-201-v1_0-e61699.pth',
# type=str, help='path to API')
parser.add_argument('--outdir', default='./',
type=str, help='output directory')
parser.add_argument('--search_space', type=str, default='tss', choices=['tss', 'sss'])
parser.add_argument('--init_w_type', type=str, default='none',
help='weight initialization (before pruning) type [none, xavier, kaiming, zero, one]')
parser.add_argument('--init_b_type', type=str, default='none',
help='bias initialization (before pruning) type [none, xavier, kaiming, zero, one]')
parser.add_argument('--batch_size', default=64, type=int)
parser.add_argument('--dataset', type=str, default='ImageNet16-120',
parser.add_argument('--measure', type=str, default='meco', choices=['meco', 'meco_opt'])
parser.add_argument('--batch_size', default=1, type=int)
parser.add_argument('--dataset', type=str, default='cifar10',
help='dataset to use [cifar10, cifar100, ImageNet16-120]')
parser.add_argument('--gpu', type=int, default=5, help='GPU index to work on')
parser.add_argument('--gpu', type=int, default=0, help='GPU index to work on')
parser.add_argument('--data_size', type=int, default=32, help='data_size')
parser.add_argument('--num_data_workers', type=int, default=2, help='number of workers for dataloaders')
parser.add_argument('--dataload', type=str, default='appoint', help='random, grasp, appoint supported')
@@ -49,11 +85,9 @@ if __name__ == '__main__':
args = parse_arguments()
print(args.device)
if args.noacc:
api = pickle.load(open(args.api_loc,'rb'))
else:
from nas_201_api import NASBench201API as API
api = API(args.api_loc)
from nats_bench import create
api = create(None, args.search_space, fast_mode=True, verbose=False)
torch.manual_seed(args.seed)
torch.backends.cudnn.deterministic = True
@@ -61,9 +95,6 @@ if __name__ == '__main__':
train_loader, val_loader = get_cifar_dataloaders(args.batch_size, args.batch_size, args.dataset, args.num_data_workers, resize=args.data_size)
x, y = next(iter(train_loader))
# random data
# x = torch.rand((args.batch_size, 3, args.data_size, args.data_size))
# y = 0
cached_res = []
pre = 'cf' if 'cifar' in args.dataset else 'im'
@@ -81,7 +112,6 @@ if __name__ == '__main__':
break
res = {'i': i, 'arch': arch_str}
# print(arch_str)
if args.search_space == 'tss':
net = nasbench2.get_model_from_arch_str(arch_str, get_num_classes(args))
arch_str2 = nasbench2.get_arch_str_from_model(net)
@@ -91,21 +121,22 @@ if __name__ == '__main__':
raise ValueError
elif args.search_space == 'sss':
config = api.get_net_config(i, args.dataset)
# print(config)
net = get_cell_based_tiny_net(config)
net.to(args.device)
# print(net)
init_net(net, args.init_w_type, args.init_b_type)
# print(x.size(), y)
measures = get_score(net, x, i, args.device)
measures = get_score(net, x, args.device, measure=args.measure)
res['meco'] = measures
res[f'{args.measure}'] = measures
if not args.noacc:
info = api.get_more_info(i, 'cifar10-valid' if args.dataset == 'cifar10' else args.dataset, iepoch=None,
hp='200', is_random=False)
if args.search_space == 'tss':
info = api.get_more_info(i, 'cifar10-valid' if args.dataset == 'cifar10' else args.dataset, iepoch=None,
hp='200', is_random=False)
else:
info = api.get_more_info(i, 'cifar10-valid' if args.dataset == 'cifar10' else args.dataset, iepoch=None,
hp='90', is_random=False)
trainacc = info['train-accuracy']
valacc = info['valid-accuracy']

View File

@@ -0,0 +1,24 @@
import pandas as pd
import pickle
# path = 'result/sss_cf10_meco.p'
path = 'nb2_sss_cf10_seed42_dlappoint_dlinfo1_initwnone_initbnone_1.p'
meco = []
accs = []
with open(path, 'rb') as f:
while True:
try:
fl = pickle.load(f)
meco.append(fl['meco'])
accs.append(fl['testacc'])
except:
break
N = len(meco)
print(N)
df = pd.DataFrame({
'meco':meco[:N],
'acc': accs[:N]
})
print(df.corr(method='spearman'))

View File

@@ -0,0 +1,16 @@
# Copyright 2021 Samsung Electronics Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
from .version import *

View File

@@ -0,0 +1,133 @@
# Copyright 2021 Samsung Electronics Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
from torchvision.datasets import MNIST, CIFAR10, CIFAR100, SVHN
from torchvision.transforms import Compose, ToTensor, Normalize
from torchvision import transforms
from torch.utils.data import TensorDataset, DataLoader
import torch
from .imagenet16 import *
def get_cifar_dataloaders(train_batch_size, test_batch_size, dataset, num_workers, resize=None, datadir='_dataset'):
# print(dataset)
if 'ImageNet16' in dataset:
mean = [x / 255 for x in [122.68, 116.66, 104.01]]
std = [x / 255 for x in [63.22, 61.26 , 65.09]]
size, pad = 16, 2
elif 'cifar' in dataset:
mean = (0.4914, 0.4822, 0.4465)
std = (0.2023, 0.1994, 0.2010)
size, pad = 32, 4
elif 'svhn' in dataset:
mean = (0.5, 0.5, 0.5)
std = (0.5, 0.5, 0.5)
size, pad = 32, 0
elif dataset == 'ImageNet1k':
from .h5py_dataset import H5Dataset
size,pad = 224,2
mean = (0.485, 0.456, 0.406)
std = (0.229, 0.224, 0.225)
#resize = 256
elif 'random' in dataset:
mean = (0.5, 0.5, 0.5)
std = (1, 1, 1)
size, pad = 32, 0
if resize is None:
resize = size
train_transform = transforms.Compose([
transforms.RandomCrop(size, padding=pad),
transforms.Resize(resize),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize(mean,std),
])
test_transform = transforms.Compose([
transforms.Resize(resize),
transforms.ToTensor(),
transforms.Normalize(mean,std),
])
if dataset == 'cifar10':
train_dataset = CIFAR10(datadir, True, train_transform, download=True)
test_dataset = CIFAR10(datadir, False, test_transform, download=True)
elif dataset == 'cifar100':
train_dataset = CIFAR100(datadir, True, train_transform, download=True)
test_dataset = CIFAR100(datadir, False, test_transform, download=True)
elif dataset == 'svhn':
train_dataset = SVHN(datadir, split='train', transform=train_transform, download=True)
test_dataset = SVHN(datadir, split='test', transform=test_transform, download=True)
elif dataset == 'ImageNet16-120':
train_dataset = ImageNet16(os.path.join(datadir, 'ImageNet16'), True , train_transform, 120)
test_dataset = ImageNet16(os.path.join(datadir, 'ImageNet16'), False, test_transform , 120)
elif dataset == 'ImageNet1k':
train_dataset = H5Dataset(os.path.join(datadir, 'imagenet-train-256.h5'), transform=train_transform)
test_dataset = H5Dataset(os.path.join(datadir, 'imagenet-val-256.h5'), transform=test_transform)
else:
raise ValueError('There are no more cifars or imagenets.')
train_loader = DataLoader(
train_dataset,
train_batch_size,
shuffle=True,
num_workers=num_workers,
pin_memory=True)
test_loader = DataLoader(
test_dataset,
test_batch_size,
shuffle=False,
num_workers=num_workers,
pin_memory=True)
return train_loader, test_loader
def get_mnist_dataloaders(train_batch_size, val_batch_size, num_workers):
data_transform = Compose([transforms.ToTensor()])
# Normalise? transforms.Normalize((0.1307,), (0.3081,))
train_dataset = MNIST("_dataset", True, data_transform, download=True)
test_dataset = MNIST("_dataset", False, data_transform, download=True)
train_loader = DataLoader(
train_dataset,
train_batch_size,
shuffle=True,
num_workers=num_workers,
pin_memory=True)
test_loader = DataLoader(
test_dataset,
val_batch_size,
shuffle=False,
num_workers=num_workers,
pin_memory=True)
return train_loader, test_loader
if __name__ == '__main__':
tr, te = get_cifar_dataloaders(64, 64, 'random', 2, resize=None, datadir='_dataset')
for x, y in tr:
print(x.size(), y.size())
break

View File

@@ -0,0 +1,55 @@
# Copyright 2021 Samsung Electronics Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
import h5py
import numpy as np
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader
class H5Dataset(Dataset):
def __init__(self, h5_path, transform=None):
self.h5_path = h5_path
self.h5_file = None
self.length = len(h5py.File(h5_path, 'r'))
self.transform = transform
def __getitem__(self, index):
#loading in getitem allows us to use multiple processes for data loading
#because hdf5 files aren't pickelable so can't transfer them across processes
# https://discuss.pytorch.org/t/hdf5-a-data-format-for-pytorch/40379
# https://discuss.pytorch.org/t/dataloader-when-num-worker-0-there-is-bug/25643/16
# TODO possible look at __getstate__ and __setstate__ as a more elegant solution
if self.h5_file is None:
self.h5_file = h5py.File(self.h5_path, 'r')
record = self.h5_file[str(index)]
if self.transform:
x = Image.fromarray(record['data'][()])
x = self.transform(x)
else:
x = torch.from_numpy(record['data'][()])
y = record['target'][()]
y = torch.from_numpy(np.asarray(y))
return (x,y)
def __len__(self):
return self.length

View File

@@ -0,0 +1,142 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
import hashlib
import os
import sys
import numpy as np
import torch.utils.data as data
from PIL import Image
if sys.version_info[0] == 2:
import cPickle as pickle
else:
import pickle
def calculate_md5(fpath, chunk_size=1024 * 1024):
md5 = hashlib.md5()
with open(fpath, 'rb') as f:
for chunk in iter(lambda: f.read(chunk_size), b''):
md5.update(chunk)
return md5.hexdigest()
def check_md5(fpath, md5, **kwargs):
return md5 == calculate_md5(fpath, **kwargs)
def check_integrity(fpath, md5=None):
if not os.path.isfile(fpath):
print(fpath)
return False
if md5 is None:
return True
else:
return check_md5(fpath, md5)
class ImageNet16(data.Dataset):
# http://image-net.org/download-images
# A Downsampled Variant of ImageNet as an Alternative to the CIFAR datasets
# https://arxiv.org/pdf/1707.08819.pdf
train_list = [
['train_data_batch_1', '27846dcaa50de8e21a7d1a35f30f0e91'],
['train_data_batch_2', 'c7254a054e0e795c69120a5727050e3f'],
['train_data_batch_3', '4333d3df2e5ffb114b05d2ffc19b1e87'],
['train_data_batch_4', '1620cdf193304f4a92677b695d70d10f'],
['train_data_batch_5', '348b3c2fdbb3940c4e9e834affd3b18d'],
['train_data_batch_6', '6e765307c242a1b3d7d5ef9139b48945'],
['train_data_batch_7', '564926d8cbf8fc4818ba23d2faac7564'],
['train_data_batch_8', 'f4755871f718ccb653440b9dd0ebac66'],
['train_data_batch_9', 'bb6dd660c38c58552125b1a92f86b5d4'],
['train_data_batch_10', '8f03f34ac4b42271a294f91bf480f29b'],
]
valid_list = [
['val_data', '3410e3017fdaefba8d5073aaa65e4bd6'],
]
def __init__(self, root, train, transform, use_num_of_class_only=None):
self.root = root
self.transform = transform
self.train = train # training set or valid set
if not self._check_integrity(): raise RuntimeError('Dataset not found or corrupted.')
if self.train:
downloaded_list = self.train_list
else:
downloaded_list = self.valid_list
self.data = []
self.targets = []
# now load the picked numpy arrays
for i, (file_name, checksum) in enumerate(downloaded_list):
file_path = os.path.join(self.root, file_name)
# print ('Load {:}/{:02d}-th : {:}'.format(i, len(downloaded_list), file_path))
with open(file_path, 'rb') as f:
if sys.version_info[0] == 2:
entry = pickle.load(f)
else:
entry = pickle.load(f, encoding='latin1')
self.data.append(entry['data'])
self.targets.extend(entry['labels'])
self.data = np.vstack(self.data).reshape(-1, 3, 16, 16)
self.data = self.data.transpose((0, 2, 3, 1)) # convert to HWC
if use_num_of_class_only is not None:
assert isinstance(use_num_of_class_only,
int) and use_num_of_class_only > 0 and use_num_of_class_only < 1000, 'invalid use_num_of_class_only : {:}'.format(
use_num_of_class_only)
new_data, new_targets = [], []
for I, L in zip(self.data, self.targets):
if 1 <= L <= use_num_of_class_only:
new_data.append(I)
new_targets.append(L)
self.data = new_data
self.targets = new_targets
# self.mean.append(entry['mean'])
# self.mean = np.vstack(self.mean).reshape(-1, 3, 16, 16)
# self.mean = np.mean(np.mean(np.mean(self.mean, axis=0), axis=1), axis=1)
# print ('Mean : {:}'.format(self.mean))
# temp = self.data - np.reshape(self.mean, (1, 1, 1, 3))
# std_data = np.std(temp, axis=0)
# std_data = np.mean(np.mean(std_data, axis=0), axis=0)
# print ('Std : {:}'.format(std_data))
def __getitem__(self, index):
img, target = self.data[index], self.targets[index] - 1
img = Image.fromarray(img)
if self.transform is not None:
img = self.transform(img)
return img, target
def __len__(self):
return len(self.data)
def _check_integrity(self):
root = self.root
for fentry in (self.train_list + self.valid_list):
filename, md5 = fentry[0], fentry[1]
fpath = os.path.join(root, filename)
if not check_integrity(fpath, md5):
return False
return True
#
if __name__ == '__main__':
train = ImageNet16('/data02/dongxuanyi/.torch/cifar.python/ImageNet16', True, None)
valid = ImageNet16('/data02/dongxuanyi/.torch/cifar.python/ImageNet16', False, None)
print(len(train))
print(len(valid))
image, label = train[111]
trainX = ImageNet16('/data02/dongxuanyi/.torch/cifar.python/ImageNet16', True, None, 200)
validX = ImageNet16('/data02/dongxuanyi/.torch/cifar.python/ImageNet16', False, None, 200)
print(len(trainX))
print(len(validX))
# import pdb; pdb.set_trace()

View File

@@ -0,0 +1,19 @@
# Copyright 2021 Samsung Electronics Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
from os.path import dirname, basename, isfile, join
import glob
modules = glob.glob(join(dirname(__file__), "*.py"))
__all__ = [ basename(f)[:-3] for f in modules if isfile(f) and not f.endswith('__init__.py')]

View File

@@ -0,0 +1,251 @@
# Copyright 2021 Samsung Electronics Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""Builds the Pytorch computational graph.
Tensors flowing into a single vertex are added together for all vertices
except the output, which is concatenated instead. Tensors flowing out of input
are always added.
If interior edge channels don't match, drop the extra channels (channels are
guaranteed non-decreasing). Tensors flowing out of the input as always
projected instead.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import math
from .nasbench1_ops import *
import torch
import torch.nn as nn
import torch.nn.functional as F
class Network(nn.Module):
def __init__(self, spec, stem_out, num_stacks, num_mods, num_classes, bn=True):
super(Network, self).__init__()
self.spec=spec
self.stem_out=stem_out
self.num_stacks=num_stacks
self.num_mods=num_mods
self.num_classes=num_classes
self.layers = nn.ModuleList([])
in_channels = 3
out_channels = stem_out
# initial stem convolution
stem_conv = ConvBnRelu(in_channels, out_channels, 3, 1, 1, bn=bn)
self.layers.append(stem_conv)
in_channels = out_channels
for stack_num in range(num_stacks):
if stack_num > 0:
downsample = nn.MaxPool2d(kernel_size=2, stride=2)
self.layers.append(downsample)
out_channels *= 2
for _ in range(num_mods):
cell = Cell(spec, in_channels, out_channels, bn=bn)
self.layers.append(cell)
in_channels = out_channels
self.classifier = nn.Linear(out_channels, num_classes)
self._initialize_weights()
def forward(self, x):
for _, layer in enumerate(self.layers):
x = layer(x)
out = torch.mean(x, (2, 3))
out = self.classifier(out)
return out
def get_prunable_copy(self, bn=False):
model_new = Network(self.spec, self.stem_out, self.num_stacks, self.num_mods, self.num_classes, bn=bn)
#TODO this is quite brittle and doesn't work with nn.Sequential when bn is different
# it is only required to maintain initialization -- maybe init after get_punable_copy?
model_new.load_state_dict(self.state_dict(), strict=False)
model_new.train()
return model_new
def _initialize_weights(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
m.weight.data.normal_(0, math.sqrt(2.0 / n))
if m.bias is not None:
m.bias.data.zero_()
elif isinstance(m, nn.BatchNorm2d):
m.weight.data.fill_(1)
m.bias.data.zero_()
elif isinstance(m, nn.Linear):
n = m.weight.size(1)
m.weight.data.normal_(0, 0.01)
m.bias.data.zero_()
class Cell(nn.Module):
"""
Builds the model using the adjacency matrix and op labels specified. Channels
controls the module output channel count but the interior channels are
determined via equally splitting the channel count whenever there is a
concatenation of Tensors.
"""
def __init__(self, spec, in_channels, out_channels, bn=True):
super(Cell, self).__init__()
self.spec = spec
self.num_vertices = np.shape(self.spec.matrix)[0]
# vertex_channels[i] = number of output channels of vertex i
self.vertex_channels = ComputeVertexChannels(in_channels, out_channels, self.spec.matrix)
#self.vertex_channels = [in_channels] + [out_channels] * (self.num_vertices - 1)
# operation for each node
self.vertex_op = nn.ModuleList([None])
for t in range(1, self.num_vertices-1):
op = OP_MAP[spec.ops[t]](self.vertex_channels[t], self.vertex_channels[t], bn=bn)
self.vertex_op.append(op)
# operation for input on each vertex
self.input_op = nn.ModuleList([None])
for t in range(1, self.num_vertices):
if self.spec.matrix[0, t]:
self.input_op.append(Projection(in_channels, self.vertex_channels[t], bn=bn))
else:
self.input_op.append(None)
def forward(self, x):
tensors = [x]
out_concat = []
for t in range(1, self.num_vertices-1):
fan_in = [Truncate(tensors[src], self.vertex_channels[t]) for src in range(1, t) if self.spec.matrix[src, t]]
if self.spec.matrix[0, t]:
fan_in.append(self.input_op[t](x))
# perform operation on node
#vertex_input = torch.stack(fan_in, dim=0).sum(dim=0)
vertex_input = sum(fan_in)
#vertex_input = sum(fan_in) / len(fan_in)
vertex_output = self.vertex_op[t](vertex_input)
tensors.append(vertex_output)
if self.spec.matrix[t, self.num_vertices-1]:
out_concat.append(tensors[t])
if not out_concat:
assert self.spec.matrix[0, self.num_vertices-1]
outputs = self.input_op[self.num_vertices-1](tensors[0])
else:
if len(out_concat) == 1:
outputs = out_concat[0]
else:
outputs = torch.cat(out_concat, 1)
if self.spec.matrix[0, self.num_vertices-1]:
outputs += self.input_op[self.num_vertices-1](tensors[0])
#if self.spec.matrix[0, self.num_vertices-1]:
# out_concat.append(self.input_op[self.num_vertices-1](tensors[0]))
#outputs = sum(out_concat) / len(out_concat)
return outputs
def Projection(in_channels, out_channels, bn=True):
"""1x1 projection (as in ResNet) followed by batch normalization and ReLU."""
return ConvBnRelu(in_channels, out_channels, 1, bn=bn)
def Truncate(inputs, channels):
"""Slice the inputs to channels if necessary."""
input_channels = inputs.size()[1]
if input_channels < channels:
raise ValueError('input channel < output channels for truncate')
elif input_channels == channels:
return inputs # No truncation necessary
else:
# Truncation should only be necessary when channel division leads to
# vertices with +1 channels. The input vertex should always be projected to
# the minimum channel count.
assert input_channels - channels == 1
return inputs[:, :channels, :, :]
def ComputeVertexChannels(in_channels, out_channels, matrix):
"""Computes the number of channels at every vertex.
Given the input channels and output channels, this calculates the number of
channels at each interior vertex. Interior vertices have the same number of
channels as the max of the channels of the vertices it feeds into. The output
channels are divided amongst the vertices that are directly connected to it.
When the division is not even, some vertices may receive an extra channel to
compensate.
Returns:
list of channel counts, in order of the vertices.
"""
num_vertices = np.shape(matrix)[0]
vertex_channels = [0] * num_vertices
vertex_channels[0] = in_channels
vertex_channels[num_vertices - 1] = out_channels
if num_vertices == 2:
# Edge case where module only has input and output vertices
return vertex_channels
# Compute the in-degree ignoring input, axis 0 is the src vertex and axis 1 is
# the dst vertex. Summing over 0 gives the in-degree count of each vertex.
in_degree = np.sum(matrix[1:], axis=0)
interior_channels = out_channels // in_degree[num_vertices - 1]
correction = out_channels % in_degree[num_vertices - 1] # Remainder to add
# Set channels of vertices that flow directly to output
for v in range(1, num_vertices - 1):
if matrix[v, num_vertices - 1]:
vertex_channels[v] = interior_channels
if correction:
vertex_channels[v] += 1
correction -= 1
# Set channels for all other vertices to the max of the out edges, going
# backwards. (num_vertices - 2) index skipped because it only connects to
# output.
for v in range(num_vertices - 3, 0, -1):
if not matrix[v, num_vertices - 1]:
for dst in range(v + 1, num_vertices - 1):
if matrix[v, dst]:
vertex_channels[v] = max(vertex_channels[v], vertex_channels[dst])
assert vertex_channels[v] > 0
# Sanity check, verify that channels never increase and final channels add up.
final_fan_in = 0
for v in range(1, num_vertices - 1):
if matrix[v, num_vertices - 1]:
final_fan_in += vertex_channels[v]
for dst in range(v + 1, num_vertices - 1):
if matrix[v, dst]:
assert vertex_channels[v] >= vertex_channels[dst]
assert final_fan_in == out_channels or num_vertices == 2
# num_vertices == 2 means only input/output nodes, so 0 fan-in
return vertex_channels

View File

@@ -0,0 +1,83 @@
# Copyright 2021 Samsung Electronics Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""Base operations used by the modules in this search space."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import torch
import torch.nn as nn
import torch.nn.functional as F
class ConvBnRelu(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=0, bn=True):
super(ConvBnRelu, self).__init__()
if bn:
self.conv_bn_relu = nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=False),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=False)
)
else:
self.conv_bn_relu = nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=False),
nn.ReLU(inplace=False)
)
def forward(self, x):
return self.conv_bn_relu(x)
class Conv3x3BnRelu(nn.Module):
"""3x3 convolution with batch norm and ReLU activation."""
def __init__(self, in_channels, out_channels, bn=True):
super(Conv3x3BnRelu, self).__init__()
self.conv3x3 = ConvBnRelu(in_channels, out_channels, 3, 1, 1, bn=bn)
def forward(self, x):
x = self.conv3x3(x)
return x
class Conv1x1BnRelu(nn.Module):
"""1x1 convolution with batch norm and ReLU activation."""
def __init__(self, in_channels, out_channels, bn=True):
super(Conv1x1BnRelu, self).__init__()
self.conv1x1 = ConvBnRelu(in_channels, out_channels, 1, 1, 0, bn=bn)
def forward(self, x):
x = self.conv1x1(x)
return x
class MaxPool3x3(nn.Module):
"""3x3 max pool with no subsampling."""
def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, bn=None):
super(MaxPool3x3, self).__init__()
self.maxpool = nn.MaxPool2d(kernel_size, stride, padding)
def forward(self, x):
x = self.maxpool(x)
return x
# Commas should not be used in op names
OP_MAP = {
'conv3x3-bn-relu': Conv3x3BnRelu,
'conv1x1-bn-relu': Conv1x1BnRelu,
'maxpool3x3': MaxPool3x3
}

View File

@@ -0,0 +1,295 @@
# Copyright 2021 Samsung Electronics Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""Model specification for module connectivity individuals.
This module handles pruning the unused parts of the computation graph but should
avoid creating any TensorFlow models (this is done inside model_builder.py).
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import copy
import hashlib
import itertools
import numpy as np
# Graphviz is optional and only required for visualization.
try:
import graphviz # pylint: disable=g-import-not-at-top
except ImportError:
pass
def _ToModelSpec(mat, ops):
return ModelSpec(mat, ops)
def gen_is_edge_fn(bits):
"""Generate a boolean function for the edge connectivity.
Given a bitstring FEDCBA and a 4x4 matrix, the generated matrix is
[[0, A, B, D],
[0, 0, C, E],
[0, 0, 0, F],
[0, 0, 0, 0]]
Note that this function is agnostic to the actual matrix dimension due to
order in which elements are filled out (column-major, starting from least
significant bit). For example, the same FEDCBA bitstring (0-padded) on a 5x5
matrix is
[[0, A, B, D, 0],
[0, 0, C, E, 0],
[0, 0, 0, F, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0]]
Args:
bits: integer which will be interpreted as a bit mask.
Returns:
vectorized function that returns True when an edge is present.
"""
def is_edge(x, y):
"""Is there an edge from x to y (0-indexed)?"""
if x >= y:
return 0
# Map x, y to index into bit string
index = x + (y * (y - 1) // 2)
return (bits >> index) % 2 == 1
return np.vectorize(is_edge)
def is_full_dag(matrix):
"""Full DAG == all vertices on a path from vert 0 to (V-1).
i.e. no disconnected or "hanging" vertices.
It is sufficient to check for:
1) no rows of 0 except for row V-1 (only output vertex has no out-edges)
2) no cols of 0 except for col 0 (only input vertex has no in-edges)
Args:
matrix: V x V upper-triangular adjacency matrix
Returns:
True if the there are no dangling vertices.
"""
shape = np.shape(matrix)
rows = matrix[:shape[0]-1, :] == 0
rows = np.all(rows, axis=1) # Any row with all 0 will be True
rows_bad = np.any(rows)
cols = matrix[:, 1:] == 0
cols = np.all(cols, axis=0) # Any col with all 0 will be True
cols_bad = np.any(cols)
return (not rows_bad) and (not cols_bad)
def num_edges(matrix):
"""Computes number of edges in adjacency matrix."""
return np.sum(matrix)
def hash_module(matrix, labeling):
"""Computes a graph-invariance MD5 hash of the matrix and label pair.
Args:
matrix: np.ndarray square upper-triangular adjacency matrix.
labeling: list of int labels of length equal to both dimensions of
matrix.
Returns:
MD5 hash of the matrix and labeling.
"""
vertices = np.shape(matrix)[0]
in_edges = np.sum(matrix, axis=0).tolist()
out_edges = np.sum(matrix, axis=1).tolist()
assert len(in_edges) == len(out_edges) == len(labeling)
hashes = list(zip(out_edges, in_edges, labeling))
hashes = [hashlib.md5(str(h).encode('utf-8')).hexdigest() for h in hashes]
# Computing this up to the diameter is probably sufficient but since the
# operation is fast, it is okay to repeat more times.
for _ in range(vertices):
new_hashes = []
for v in range(vertices):
in_neighbors = [hashes[w] for w in range(vertices) if matrix[w, v]]
out_neighbors = [hashes[w] for w in range(vertices) if matrix[v, w]]
new_hashes.append(hashlib.md5(
(''.join(sorted(in_neighbors)) + '|' +
''.join(sorted(out_neighbors)) + '|' +
hashes[v]).encode('utf-8')).hexdigest())
hashes = new_hashes
fingerprint = hashlib.md5(str(sorted(hashes)).encode('utf-8')).hexdigest()
return fingerprint
def permute_graph(graph, label, permutation):
"""Permutes the graph and labels based on permutation.
Args:
graph: np.ndarray adjacency matrix.
label: list of labels of same length as graph dimensions.
permutation: a permutation list of ints of same length as graph dimensions.
Returns:
np.ndarray where vertex permutation[v] is vertex v from the original graph
"""
# vertex permutation[v] in new graph is vertex v in the old graph
forward_perm = zip(permutation, list(range(len(permutation))))
inverse_perm = [x[1] for x in sorted(forward_perm)]
edge_fn = lambda x, y: graph[inverse_perm[x], inverse_perm[y]] == 1
new_matrix = np.fromfunction(np.vectorize(edge_fn),
(len(label), len(label)),
dtype=np.int8)
new_label = [label[inverse_perm[i]] for i in range(len(label))]
return new_matrix, new_label
def is_isomorphic(graph1, graph2):
"""Exhaustively checks if 2 graphs are isomorphic."""
matrix1, label1 = np.array(graph1[0]), graph1[1]
matrix2, label2 = np.array(graph2[0]), graph2[1]
assert np.shape(matrix1) == np.shape(matrix2)
assert len(label1) == len(label2)
vertices = np.shape(matrix1)[0]
# Note: input and output in our constrained graphs always map to themselves
# but this script does not enforce that.
for perm in itertools.permutations(range(0, vertices)):
pmatrix1, plabel1 = permute_graph(matrix1, label1, perm)
if np.array_equal(pmatrix1, matrix2) and plabel1 == label2:
return True
return False
class ModelSpec(object):
"""Model specification given adjacency matrix and labeling."""
def __init__(self, matrix, ops, data_format='channels_last'):
"""Initialize the module spec.
Args:
matrix: ndarray or nested list with shape [V, V] for the adjacency matrix.
ops: V-length list of labels for the base ops used. The first and last
elements are ignored because they are the input and output vertices
which have no operations. The elements are retained to keep consistent
indexing.
data_format: channels_last or channels_first.
Raises:
ValueError: invalid matrix or ops
"""
if not isinstance(matrix, np.ndarray):
matrix = np.array(matrix)
shape = np.shape(matrix)
if len(shape) != 2 or shape[0] != shape[1]:
raise ValueError('matrix must be square')
if shape[0] != len(ops):
raise ValueError('length of ops must match matrix dimensions')
if not is_upper_triangular(matrix):
raise ValueError('matrix must be upper triangular')
# Both the original and pruned matrices are deep copies of the matrix and
# ops so any changes to those after initialization are not recognized by the
# spec.
self.original_matrix = copy.deepcopy(matrix)
# print(self.original_matrix)
self.original_ops = copy.deepcopy(ops)
self.matrix = copy.deepcopy(matrix)
self.ops = copy.deepcopy(ops)
self.valid_spec = True
self._prune()
self.data_format = data_format
def _prune(self):
"""Prune the extraneous parts of the graph.
General procedure:
1) Remove parts of graph not connected to input.
2) Remove parts of graph not connected to output.
3) Reorder the vertices so that they are consecutive after steps 1 and 2.
These 3 steps can be combined by deleting the rows and columns of the
vertices that are not reachable from both the input and output (in reverse).
"""
num_vertices = np.shape(self.original_matrix)[0]
# DFS forward from input
visited_from_input = set([0])
frontier = [0]
while frontier:
top = frontier.pop()
for v in range(top + 1, num_vertices):
if self.original_matrix[top, v] and v not in visited_from_input:
visited_from_input.add(v)
frontier.append(v)
# DFS backward from output
visited_from_output = set([num_vertices - 1])
frontier = [num_vertices - 1]
while frontier:
top = frontier.pop()
for v in range(0, top):
if self.original_matrix[v, top] and v not in visited_from_output:
visited_from_output.add(v)
frontier.append(v)
# Any vertex that isn't connected to both input and output is extraneous to
# the computation graph.
extraneous = set(range(num_vertices)).difference(
visited_from_input.intersection(visited_from_output))
# If the non-extraneous graph is less than 2 vertices, the input is not
# connected to the output and the spec is invalid.
if len(extraneous) > num_vertices - 2:
self.matrix = None
self.ops = None
self.valid_spec = False
return
self.matrix = np.delete(self.matrix, list(extraneous), axis=0)
self.matrix = np.delete(self.matrix, list(extraneous), axis=1)
for index in sorted(extraneous, reverse=True):
del self.ops[index]
def hash_spec(self, canonical_ops):
"""Computes the isomorphism-invariant graph hash of this spec.
Args:
canonical_ops: list of operations in the canonical ordering which they
were assigned (i.e. the order provided in the config['available_ops']).
Returns:
MD5 hash of this spec which can be used to query the dataset.
"""
# Invert the operations back to integer label indices used in graph gen.
labeling = [-1] + [canonical_ops.index(op) for op in self.ops[1:-1]] + [-2]
return graph_util.hash_module(self.matrix, labeling)
def visualize(self):
"""Creates a dot graph. Can be visualized in colab directly."""
num_vertices = np.shape(self.matrix)[0]
g = graphviz.Digraph()
g.node(str(0), 'input')
for v in range(1, num_vertices - 1):
g.node(str(v), self.ops[v])
g.node(str(num_vertices - 1), 'output')
for src in range(num_vertices - 1):
for dst in range(src + 1, num_vertices):
if self.matrix[src, dst]:
g.edge(str(src), str(dst))
return g
def is_upper_triangular(matrix):
"""True if matrix is 0 on diagonal and below."""
for src in range(np.shape(matrix)[0]):
for dst in range(0, src + 1):
if matrix[src, dst] != 0:
return False
return True

View File

@@ -0,0 +1,140 @@
# Copyright 2021 Samsung Electronics Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
import os
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
from .nasbench2_ops import *
def gen_searchcell_mask_from_arch_str(arch_str):
nodes = arch_str.split('+')
nodes = [node[1:-1].split('|') for node in nodes]
nodes = [[op_and_input.split('~') for op_and_input in node] for node in nodes]
keep_mask = []
for curr_node_idx in range(len(nodes)):
for prev_node_idx in range(curr_node_idx+1):
_op = [edge[0] for edge in nodes[curr_node_idx] if int(edge[1]) == prev_node_idx]
assert len(_op) == 1, 'The arch string does not follow the assumption of 1 connection between two nodes.'
for _op_name in OPS.keys():
keep_mask.append(_op[0] == _op_name)
return keep_mask
def get_model_from_arch_str(arch_str, num_classes, use_bn=True, init_channels=16):
keep_mask = gen_searchcell_mask_from_arch_str(arch_str)
net = NAS201Model(arch_str=arch_str, num_classes=num_classes, use_bn=use_bn, keep_mask=keep_mask, stem_ch=init_channels)
return net
def get_super_model(num_classes, use_bn=True):
net = NAS201Model(arch_str=arch_str, num_classes=num_classes, use_bn=use_bn)
return net
class NAS201Model(nn.Module):
def __init__(self, arch_str, num_classes, use_bn=True, keep_mask=None, stem_ch=16):
super(NAS201Model, self).__init__()
self.arch_str=arch_str
self.num_classes=num_classes
self.use_bn= use_bn
self.stem = stem(out_channels=stem_ch, use_bn=use_bn)
self.stack_cell1 = nn.Sequential(*[SearchCell(in_channels=stem_ch, out_channels=stem_ch, stride=1, affine=False, track_running_stats=False, use_bn=use_bn, keep_mask=keep_mask) for i in range(5)])
self.reduction1 = reduction(in_channels=stem_ch, out_channels=stem_ch*2)
self.stack_cell2 = nn.Sequential(*[SearchCell(in_channels=stem_ch*2, out_channels=stem_ch*2, stride=1, affine=False, track_running_stats=False, use_bn=use_bn, keep_mask=keep_mask) for i in range(5)])
self.reduction2 = reduction(in_channels=stem_ch*2, out_channels=stem_ch*4)
self.stack_cell3 = nn.Sequential(*[SearchCell(in_channels=stem_ch*4, out_channels=stem_ch*4, stride=1, affine=False, track_running_stats=False, use_bn=use_bn, keep_mask=keep_mask) for i in range(5)])
# self.top = top(in_dims=stem_ch*4, num_classes=num_classes, use_bn=use_bn)
self.top = top(in_dims=stem_ch*4, use_bn=use_bn)
self.classifier = nn.Linear(stem_ch*4, num_classes)
self.pre_GAP = nn.Sequential(nn.BatchNorm2d(stem_ch * 4), nn.ReLU(inplace=True))
def forward(self, x):
x = self.stem(x)
x = self.stack_cell1(x)
x = self.reduction1(x)
x = self.stack_cell2(x)
x = self.reduction2(x)
x = self.stack_cell3(x)
x = self.top(x)
x = self.classifier(x)
return x
def forward_pre_GAP(self, x):
x = self.stem(x)
x = self.stack_cell1(x)
x = self.reduction1(x)
x = self.stack_cell2(x)
x = self.reduction2(x)
x = self.stack_cell3(x)
x = self.pre_GAP(x)
return x
def get_prunable_copy(self, bn=False):
model_new = get_model_from_arch_str(self.arch_str, self.num_classes, use_bn=bn)
#TODO this is quite brittle and doesn't work with nn.Sequential when bn is different
# it is only required to maintain initialization -- maybe init after get_punable_copy?
model_new.load_state_dict(self.state_dict(), strict=False)
model_new.train()
return model_new
def get_arch_str_from_model(net):
search_cell = net.stack_cell1[0].options
keep_mask = net.stack_cell1[0].keep_mask
num_nodes = net.stack_cell1[0].num_nodes
nodes = []
idx = 0
for curr_node in range(num_nodes -1):
edges = []
for prev_node in range(curr_node+1): # n-1 prev nodes
for _op_name in OPS.keys():
if keep_mask[idx]:
edges.append(f'{_op_name}~{prev_node}')
idx += 1
node_str = '|'.join(edges)
node_str = f'|{node_str}|'
nodes.append(node_str)
arch_str = '+'.join(nodes)
return arch_str
if __name__ == "__main__":
arch_str = '|nor_conv_3x3~0|+|none~0|none~1|+|avg_pool_3x3~0|nor_conv_3x3~1|nor_conv_3x3~2|'
n = get_model_from_arch_str(arch_str=arch_str, num_classes=10)
print(n.stack_cell1[0])
arch_str2 = get_arch_str_from_model(n)
print(arch_str)
print(arch_str2)
print(f'Are the two arch strings same? {arch_str == arch_str2}')

View File

@@ -0,0 +1,166 @@
# Copyright 2021 Samsung Electronics Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
import os
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
class ReLUConvBN(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size, stride, padding, dilation, affine, track_running_stats=True, use_bn=True, name='ReLUConvBN'):
super(ReLUConvBN, self).__init__()
self.name = name
if use_bn:
self.op = nn.Sequential(
nn.ReLU(inplace=False),
nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation, bias=not affine),
nn.BatchNorm2d(out_channels, affine=affine, track_running_stats=track_running_stats)
)
else:
self.op = nn.Sequential(
nn.ReLU(inplace=False),
nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation, bias=not affine)
)
def forward(self, x):
return self.op(x)
class Identity(nn.Module):
def __init__(self, name='Identity'):
self.name = name
super(Identity, self).__init__()
def forward(self, x):
return x
class Zero(nn.Module):
def __init__(self, stride, name='Zero'):
self.name = name
super(Zero, self).__init__()
self.stride = stride
def forward(self, x):
if self.stride == 1:
return x.mul(0.)
return x[:,:,::self.stride,::self.stride].mul(0.)
class POOLING(nn.Module):
def __init__(self, kernel_size, stride, padding, name='POOLING'):
super(POOLING, self).__init__()
self.name = name
self.avgpool = nn.AvgPool2d(kernel_size=kernel_size, stride=1, padding=1, count_include_pad=False)
def forward(self, x):
return self.avgpool(x)
class reduction(nn.Module):
def __init__(self, in_channels, out_channels):
super(reduction, self).__init__()
self.residual = nn.Sequential(
nn.AvgPool2d(kernel_size=2, stride=2, padding=0),
nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1, padding=0, bias=False))
self.conv_a = ReLUConvBN(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=2, padding=1, dilation=1, affine=True, track_running_stats=True)
self.conv_b = ReLUConvBN(in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1, dilation=1, affine=True, track_running_stats=True)
def forward(self, x):
basicblock = self.conv_a(x)
basicblock = self.conv_b(basicblock)
residual = self.residual(x)
return residual + basicblock
class stem(nn.Module):
def __init__(self, out_channels, use_bn=True):
super(stem, self).__init__()
if use_bn:
self.net = nn.Sequential(
nn.Conv2d(in_channels=3, out_channels=out_channels, kernel_size=3, padding=1, bias=False),
nn.BatchNorm2d(out_channels))
else:
self.net = nn.Sequential(
nn.Conv2d(in_channels=3, out_channels=out_channels, kernel_size=3, padding=1, bias=False)
)
def forward(self, x):
return self.net(x)
class top(nn.Module):
# def __init__(self, in_dims, num_classes, use_bn=True):
def __init__(self, in_dims, use_bn=True):
super(top, self).__init__()
if use_bn:
self.lastact = nn.Sequential(nn.BatchNorm2d(in_dims), nn.ReLU(inplace=True))
else:
self.lastact = nn.ReLU(inplace=True)
self.global_pooling = nn.AdaptiveAvgPool2d(1)
# self.classifier = nn.Linear(in_dims, num_classes)
def forward(self, x):
x = self.lastact(x)
x = self.global_pooling(x)
x = x.view(x.size(0), -1)
# logits = self.classifier(x)
# return logits
return x
class SearchCell(nn.Module):
def __init__(self, in_channels, out_channels, stride, affine, track_running_stats, use_bn=True, num_nodes=4, keep_mask=None):
super(SearchCell, self).__init__()
self.num_nodes = num_nodes
self.options = nn.ModuleList()
for curr_node in range(self.num_nodes-1):
for prev_node in range(curr_node+1):
for _op_name in OPS.keys():
op = OPS[_op_name](in_channels, out_channels, stride, affine, track_running_stats, use_bn)
self.options.append(op)
if keep_mask is not None:
self.keep_mask = keep_mask
else:
self.keep_mask = [True]*len(self.options)
def forward(self, x):
outs = [x]
idx = 0
for curr_node in range(self.num_nodes-1):
edges_in = []
for prev_node in range(curr_node+1): # n-1 prev nodes
for op_idx in range(len(OPS.keys())):
if self.keep_mask[idx]:
edges_in.append(self.options[idx](outs[prev_node]))
idx += 1
node_output = sum(edges_in)
outs.append(node_output)
return outs[-1]
OPS = {
'none' : lambda in_channels, out_channels, stride, affine, track_running_stats, use_bn: Zero(stride, name='none'),
'avg_pool_3x3' : lambda in_channels, out_channels, stride, affine, track_running_stats, use_bn: POOLING(3, 1, 1, name='avg_pool_3x3'),
'nor_conv_3x3' : lambda in_channels, out_channels, stride, affine, track_running_stats, use_bn: ReLUConvBN(in_channels, out_channels, 3, 1, 1, 1, affine, track_running_stats, use_bn, name='nor_conv_3x3'),
'nor_conv_1x1' : lambda in_channels, out_channels, stride, affine, track_running_stats, use_bn: ReLUConvBN(in_channels, out_channels, 1, 1, 0, 1, affine, track_running_stats, use_bn, name='nor_conv_1x1'),
'skip_connect' : lambda in_channels, out_channels, stride, affine, track_running_stats, use_bn: Identity(name='skip_connect'),
}

View File

@@ -0,0 +1,19 @@
# Copyright 2021 Samsung Electronics Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
from os.path import dirname, basename, isfile, join
import glob
modules = glob.glob(join(dirname(__file__), "*.py"))
__all__ = [ basename(f)[:-3] for f in modules if isfile(f) and not f.endswith('__init__.py')]

View File

@@ -0,0 +1,69 @@
# Copyright 2021 Samsung Electronics Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
available_measures = []
_measure_impls = {}
def measure(name, bn=True, copy_net=True, force_clean=True, **impl_args):
def make_impl(func):
def measure_impl(net_orig, device, *args, **kwargs):
if copy_net:
net = net_orig.get_prunable_copy(bn=bn).to(device)
else:
net = net_orig
ret = func(net, *args, **kwargs, **impl_args)
if copy_net and force_clean:
import gc
import torch
del net
torch.cuda.empty_cache()
gc.collect()
return ret
global _measure_impls
if name in _measure_impls:
raise KeyError(f'Duplicated measure! {name}')
available_measures.append(name)
_measure_impls[name] = measure_impl
return func
return make_impl
def calc_measure(name, net, device, *args, **kwargs):
return _measure_impls[name](net, device, *args, **kwargs)
def load_all():
# from . import grad_norm
# from . import snip
# from . import grasp
# from . import fisher
# from . import jacob_cov
# from . import plain
# from . import synflow
# from . import var
# from . import cor
# from . import norm
from . import meco
# from . import zico
# from . import gradsign
# from . import ntk
# from . import zen
# TODO: should we do that by default?
load_all()

View File

@@ -0,0 +1,53 @@
# Copyright 2021 Samsung Electronics Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
import time
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
import numpy as np
import torch
from . import measure
def get_score(net, x, target, device, split_data):
result_list = []
def forward_hook(module, data_input, data_output):
corr = np.mean(np.corrcoef(data_input[0].detach().cpu().numpy()))
result_list.append(corr)
net.classifier.register_forward_hook(forward_hook)
N = x.shape[0]
for sp in range(split_data):
st = sp * N // split_data
en = (sp + 1) * N // split_data
y = net(x[st:en])
cor = result_list[0].item()
result_list.clear()
return cor
@measure('cor', bn=True)
def compute_norm(net, inputs, targets, split_data=1, loss_fn=None):
device = inputs.device
# Compute gradients (but don't apply them)
net.zero_grad()
try:
cor= get_score(net, inputs, targets, device, split_data=split_data)
except Exception as e:
print(e)
cor= np.nan
return cor

View File

@@ -0,0 +1,67 @@
# Copyright 2021 Samsung Electronics Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
import copy
import time
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
import numpy as np
from torch import nn
from . import measure
def get_score(net, x, target, device, split_data):
result_list = []
result_t = []
def forward_hook(module, data_input, data_output):
s = time.time()
fea = data_output[0].detach().cpu().numpy()
fea = fea.reshape(fea.shape[0], -1)
result = 1 / np.var(np.corrcoef(fea))
e = time.time()
t = e - s
result_list.append(result)
result_t.append(t)
for name, modules in net.named_modules():
modules.register_forward_hook(forward_hook)
N = x.shape[0]
for sp in range(split_data):
st = sp * N // split_data
en = (sp + 1) * N // split_data
y = net(x[st:en])
results = np.array(result_list)
results = results[np.logical_not(np.isnan(results))]
v = np.sum(results)
t = sum(result_t)
result_list.clear()
result_t.clear()
return v, t
@measure('cova', bn=True)
def compute_cova(net, inputs, targets, split_data=1, loss_fn=None):
device = inputs.device
# Compute gradients (but don't apply them)
net.zero_grad()
try:
cova, t = get_score(net, inputs, targets, device, split_data=split_data)
except Exception as e:
print(e)
cova, t = np.nan, None
return cova, t

View File

@@ -0,0 +1,107 @@
# Copyright 2021 Samsung Electronics Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
import types
from . import measure
from ..p_utils import get_layer_metric_array, reshape_elements
def fisher_forward_conv2d(self, x):
x = F.conv2d(x, self.weight, self.bias, self.stride,
self.padding, self.dilation, self.groups)
#intercept and store the activations after passing through 'hooked' identity op
self.act = self.dummy(x)
return self.act
def fisher_forward_linear(self, x):
x = F.linear(x, self.weight, self.bias)
self.act = self.dummy(x)
return self.act
@measure('fisher', bn=True, mode='channel')
def compute_fisher_per_weight(net, inputs, targets, loss_fn, mode, split_data=1):
device = inputs.device
if mode == 'param':
raise ValueError('Fisher pruning does not support parameter pruning.')
net.train()
all_hooks = []
for layer in net.modules():
if isinstance(layer, nn.Conv2d) or isinstance(layer, nn.Linear):
#variables/op needed for fisher computation
layer.fisher = None
layer.act = 0.
layer.dummy = nn.Identity()
#replace forward method of conv/linear
if isinstance(layer, nn.Conv2d):
layer.forward = types.MethodType(fisher_forward_conv2d, layer)
if isinstance(layer, nn.Linear):
layer.forward = types.MethodType(fisher_forward_linear, layer)
#function to call during backward pass (hooked on identity op at output of layer)
def hook_factory(layer):
def hook(module, grad_input, grad_output):
act = layer.act.detach()
grad = grad_output[0].detach()
if len(act.shape) > 2:
g_nk = torch.sum((act * grad), list(range(2,len(act.shape))))
else:
g_nk = act * grad
del_k = g_nk.pow(2).mean(0).mul(0.5)
if layer.fisher is None:
layer.fisher = del_k
else:
layer.fisher += del_k
del layer.act #without deleting this, a nasty memory leak occurs! related: https://discuss.pytorch.org/t/memory-leak-when-using-forward-hook-and-backward-hook-simultaneously/27555
return hook
#register backward hook on identity fcn to compute fisher info
layer.dummy.register_backward_hook(hook_factory(layer))
N = inputs.shape[0]
for sp in range(split_data):
st=sp*N//split_data
en=(sp+1)*N//split_data
net.zero_grad()
outputs = net(inputs[st:en])
loss = loss_fn(outputs, targets[st:en])
loss.backward()
# retrieve fisher info
def fisher(layer):
if layer.fisher is not None:
return torch.abs(layer.fisher.detach())
else:
return torch.zeros(layer.weight.shape[0]) #size=ch
grads_abs_ch = get_layer_metric_array(net, fisher, mode)
#broadcast channel value here to all parameters in that channel
#to be compatible with stuff downstream (which expects per-parameter metrics)
#TODO cleanup on the selectors/apply_prune_mask side (?)
shapes = get_layer_metric_array(net, lambda l : l.weight.shape[1:], mode)
grads_abs = reshape_elements(grads_abs_ch, shapes, device)
return grads_abs

View File

@@ -0,0 +1,38 @@
# Copyright 2021 Samsung Electronics Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
import torch
import torch.nn.functional as F
import copy
from . import measure
from ..p_utils import get_layer_metric_array
@measure('grad_norm', bn=True)
def get_grad_norm_arr(net, inputs, targets, loss_fn, split_data=1, skip_grad=False):
net.zero_grad()
N = inputs.shape[0]
for sp in range(split_data):
st=sp*N//split_data
en=(sp+1)*N//split_data
outputs = net.forward(inputs[st:en])
loss = loss_fn(outputs, targets[st:en])
loss.backward()
grad_norm_arr = get_layer_metric_array(net, lambda l: l.weight.grad.norm() if l.weight.grad is not None else torch.zeros_like(l.weight), mode='param')
return grad_norm_arr

View File

@@ -0,0 +1,76 @@
# Copyright 2021 Samsung Electronics Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
import torch
from torch import nn
import numpy as np
from . import measure
def get_flattened_metric(net, metric):
grad_list = []
for layer in net.modules():
if isinstance(layer, nn.Conv2d) or isinstance(layer, nn.Linear):
grad_list.append(metric(layer).flatten())
flattened_grad = np.concatenate(grad_list)
return flattened_grad
def get_grad_conflict(net, inputs, targets, loss_fn):
N = inputs.shape[0]
batch_grad = []
for i in range(N):
net.zero_grad()
outputs = net.forward(inputs[[i]])
loss = loss_fn(outputs, targets[[i]])
loss.backward()
flattened_grad = get_flattened_metric(net, lambda
l: l.weight.grad.data.clone().cpu().numpy() if l.weight.grad is not None else torch.zeros_like(
l.weight).clone().cpu().numpy())
batch_grad.append(flattened_grad)
batch_grad = np.stack(batch_grad)
direction_code = np.sign(batch_grad)
direction_code = abs(direction_code.sum(axis=0))
score = np.nansum(direction_code)
return score
def get_gradsign(input, target, net, device, loss_fn):
s = []
net = net.to(device)
x, target = input, target
# x2 = torch.clone(x)
# x2 = x2.to(device)
x, target = x.to(device), target.to(device)
s.append(get_grad_conflict(net=net, inputs=x, targets=target, loss_fn=loss_fn))
s = np.mean(s)
return s
@measure('gradsign', bn=True)
def compute_gradsign(net, inputs, targets, split_data=1, loss_fn=None):
device = inputs.device
# Compute gradients (but don't apply them)
net.zero_grad()
try:
gradsign = get_gradsign(inputs, targets, net, device, loss_fn)
except Exception as e:
print(e)
gradsign= np.nan
return gradsign

View File

@@ -0,0 +1,87 @@
# Copyright 2021 Samsung Electronics Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.autograd as autograd
from . import measure
from ..p_utils import get_layer_metric_array
@measure('grasp', bn=True, mode='param')
def compute_grasp_per_weight(net, inputs, targets, mode, loss_fn, T=1, num_iters=1, split_data=1):
# get all applicable weights
weights = []
for layer in net.modules():
if isinstance(layer, nn.Conv2d) or isinstance(layer, nn.Linear):
weights.append(layer.weight)
layer.weight.requires_grad_(True) # TODO isn't this already true?
# NOTE original code had some input/target splitting into 2
# I am guessing this was because of GPU mem limit
net.zero_grad()
N = inputs.shape[0]
for sp in range(split_data):
st=sp*N//split_data
en=(sp+1)*N//split_data
#forward/grad pass #1
grad_w = None
for _ in range(num_iters):
#TODO get new data, otherwise num_iters is useless!
outputs = net.forward(inputs[st:en])/T
loss = loss_fn(outputs, targets[st:en])
grad_w_p = autograd.grad(loss, weights, allow_unused=True)
if grad_w is None:
grad_w = list(grad_w_p)
else:
for idx in range(len(grad_w)):
grad_w[idx] += grad_w_p[idx]
for sp in range(split_data):
st=sp*N//split_data
en=(sp+1)*N//split_data
# forward/grad pass #2
outputs = net.forward(inputs[st:en])/T
loss = loss_fn(outputs, targets[st:en])
grad_f = autograd.grad(loss, weights, create_graph=True, allow_unused=True)
# accumulate gradients computed in previous step and call backwards
z, count = 0,0
for layer in net.modules():
if isinstance(layer, nn.Conv2d) or isinstance(layer, nn.Linear):
if grad_w[count] is not None:
z += (grad_w[count].data * grad_f[count]).sum()
count += 1
z.backward()
# compute final sensitivity metric and put in grads
def grasp(layer):
if layer.weight.grad is not None:
return -layer.weight.data * layer.weight.grad # -theta_q Hg
#NOTE in the grasp code they take the *bottom* (1-p)% of values
#but we take the *top* (1-p)%, therefore we remove the -ve sign
#EDIT accuracy seems to be negatively correlated with this metric, so we add -ve sign here!
else:
return torch.zeros_like(layer.weight)
grads = get_layer_metric_array(net, grasp, mode)
return grads

View File

@@ -0,0 +1,57 @@
# Copyright 2021 Samsung Electronics Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
import torch
import numpy as np
from . import measure
def get_batch_jacobian(net, x, target, device, split_data):
x.requires_grad_(True)
N = x.shape[0]
for sp in range(split_data):
st=sp*N//split_data
en=(sp+1)*N//split_data
y = net(x[st:en])
y.backward(torch.ones_like(y))
jacob = x.grad.detach()
x.requires_grad_(False)
return jacob, target.detach()
def eval_score(jacob, labels=None):
corrs = np.corrcoef(jacob)
v, _ = np.linalg.eig(corrs)
k = 1e-5
return -np.sum(np.log(v + k) + 1./(v + k))
@measure('jacob_cov', bn=True)
def compute_jacob_cov(net, inputs, targets, split_data=1, loss_fn=None):
device = inputs.device
# Compute gradients (but don't apply them)
net.zero_grad()
jacobs, labels = get_batch_jacobian(net, inputs, targets, device, split_data=split_data)
jacobs = jacobs.reshape(jacobs.size(0), -1).cpu().numpy()
try:
jc = eval_score(jacobs, labels)
except Exception as e:
print(e)
jc = np.nan
return jc

View File

@@ -0,0 +1,22 @@
# Copyright 2021 Samsung Electronics Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
from . import measure
from ..p_utils import get_layer_metric_array
@measure('l2_norm', copy_net=False, mode='param')
def get_l2_norm_array(net, inputs, targets, mode, split_data=1):
return get_layer_metric_array(net, lambda l: l.weight.norm(), mode=mode)

View File

@@ -0,0 +1,63 @@
# Copyright 2021 Samsung Electronics Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
import time
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
import numpy as np
import torch
from . import measure
def get_score(net, x, target, device, split_data):
result_list = []
def forward_hook(module, data_input, data_output):
s = time.time()
mean = torch.mean(data_input[0])
e = time.time()
t = e - s
result_list.append(mean)
result_list.append(t)
net.classifier.register_forward_hook(forward_hook)
N = x.shape[0]
for sp in range(split_data):
st = sp * N // split_data
en = (sp + 1) * N // split_data
# t1 = time.time()
y = net(x[st:en])
# t2 = time.time()
# print('var:', t2-t1)
m = result_list[0].item()
t = result_list[1]
result_list.clear()
return m, t
@measure('mean', bn=True)
def compute_mean(net, inputs, targets, split_data=1, loss_fn=None):
device = inputs.device
# Compute gradients (but don't apply them)
net.zero_grad()
# print('var:', features.shape)
try:
mean, t = get_score(net, inputs, targets, device, split_data=split_data)
except Exception as e:
print(e)
mean, t = np.nan, None
# print(jc)
# print(f'var time: {t} s')
return mean, t

View File

@@ -0,0 +1,73 @@
# Copyright 2021 Samsung Electronics Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
import copy
import time
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
import numpy as np
import torch
from torch import nn
from . import measure
def get_score(net, x, target, device, split_data):
result_list = []
x = torch.randn(size=(1, 3, 64, 64)).to(device)
net.to(device)
def forward_hook(module, data_input, data_output):
fea = data_output[0].detach()
fea = fea.reshape(fea.shape[0], -1)
n = fea.shape[0]
corr = torch.corrcoef(fea)
corr[torch.isnan(corr)] = 0
corr[torch.isinf(corr)] = 0
values = torch.linalg.eig(corr)[0]
# result = np.real(np.min(values)) / np.real(np.max(values))
result = torch.min(torch.real(values))
result_list.append(result)
for name, modules in net.named_modules():
modules.register_forward_hook(forward_hook)
N = x.shape[0]
for sp in range(split_data):
st = sp * N // split_data
en = (sp + 1) * N // split_data
y = net(x[st:en])
# break
results = torch.tensor(result_list)
results = results[torch.logical_not(torch.isnan(results))]
v = torch.sum(results)
result_list.clear()
return v.item()
@measure('meco', bn=True)
def compute_meco(net, inputs, targets, split_data=1, loss_fn=None):
device = inputs.device
# Compute gradients (but don't apply them)
net.zero_grad()
try:
meco = get_score(net, inputs, targets, device, split_data=split_data)
except Exception as e:
print(e)
meco = np.nan, None
return meco

View File

@@ -0,0 +1,55 @@
# Copyright 2021 Samsung Electronics Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
import time
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
import numpy as np
import torch
from . import measure
def get_score(net, x, target, device, split_data):
result_list = []
def forward_hook(module, data_input, data_output):
norm = torch.norm(data_input[0])
result_list.append(norm)
net.classifier.register_forward_hook(forward_hook)
N = x.shape[0]
for sp in range(split_data):
st = sp * N // split_data
en = (sp + 1) * N // split_data
y = net(x[st:en])
n = result_list[0].item()
result_list.clear()
return n
@measure('norm', bn=True)
def compute_norm(net, inputs, targets, split_data=1, loss_fn=None):
device = inputs.device
# Compute gradients (but don't apply them)
net.zero_grad()
# print('var:', feature.shape)
try:
norm, t = get_score(net, inputs, targets, device, split_data=split_data)
except Exception as e:
print(e)
norm, t = np.nan, None
# print(jc)
# print(f'norm time: {t} s')
return norm, t

View File

@@ -0,0 +1,94 @@
# Copyright 2021 Samsung Electronics Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
import torch
import numpy as np
from . import measure
def recal_bn(network, inputs, targets, recalbn, device):
for m in network.modules():
if isinstance(m, torch.nn.BatchNorm2d):
m.running_mean.data.fill_(0)
m.running_var.data.fill_(0)
m.num_batches_tracked.data.zero_()
m.momentum = None
network.train()
with torch.no_grad():
for i, (inputs, targets) in enumerate(zip(inputs, targets)):
if i >= recalbn: break
inputs = inputs.cuda(device=device, non_blocking=True)
_, _ = network(inputs)
return network
def get_ntk_n(inputs, targets, network, device, recalbn=0, train_mode=False, num_batch=1):
device = device
# if recalbn > 0:
# network = recal_bn(network, xloader, recalbn, device)
# if network_2 is not None:
# network_2 = recal_bn(network_2, xloader, recalbn, device)
network.eval()
networks = []
networks.append(network)
ntks = []
# if train_mode:
# networks.train()
# else:
# networks.eval()
######
grads = [[] for _ in range(len(networks))]
for i in range(num_batch):
if num_batch > 0 and i >= num_batch: break
inputs = inputs.cuda(device=device, non_blocking=True)
for net_idx, network in enumerate(networks):
network.zero_grad()
# print(inputs.size())
inputs_ = inputs.clone().cuda(device=device, non_blocking=True)
logit = network(inputs_)
if isinstance(logit, tuple):
logit = logit[1] # 201 networks: return features and logits
for _idx in range(len(inputs_)):
logit[_idx:_idx + 1].backward(torch.ones_like(logit[_idx:_idx + 1]), retain_graph=True)
grad = []
for name, W in network.named_parameters():
if 'weight' in name and W.grad is not None:
grad.append(W.grad.view(-1).detach())
grads[net_idx].append(torch.cat(grad, -1))
network.zero_grad()
torch.cuda.empty_cache()
######
grads = [torch.stack(_grads, 0) for _grads in grads]
ntks = [torch.einsum('nc,mc->nm', [_grads, _grads]) for _grads in grads]
for ntk in ntks:
eigenvalues, _ = torch.linalg.eigh(ntk) # ascending
conds = np.nan_to_num((eigenvalues[-1] / eigenvalues[0]).item(), copy=True, nan=100000.0)
return conds
@measure('ntk', bn=True)
def compute_ntk(net, inputs, targets, split_data=1, loss_fn=None):
device = inputs.device
# Compute gradients (but don't apply them)
net.zero_grad()
try:
conds = get_ntk_n(inputs, targets, net, device)
except Exception as e:
print(e)
conds= np.nan
return conds

View File

@@ -0,0 +1,16 @@
import time
import torch
from . import measure
from ..p_utils import get_layer_metric_array
@measure('param_count', copy_net=False, mode='param')
def get_param_count_array(net, inputs, targets, mode, loss_fn, split_data=1):
s = time.time()
count = get_layer_metric_array(net, lambda l: torch.tensor(sum(p.numel() for p in l.parameters() if p.requires_grad)), mode=mode)
e = time.time()
t = e - s
# print(f'param_count time: {t} s')
return count, t

View File

@@ -0,0 +1,71 @@
# Copyright 2021 Samsung Electronics Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
import copy
import time
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
import numpy as np
from torch import nn
# import pandas as pd
from . import measure
def get_score(net, x, target, device, split_data):
result_list = []
result_t = []
def forward_hook(module, data_input, data_output):
s = time.time()
fea = data_output[0].detach().cpu().numpy()
fea = fea.reshape(fea.shape[0], -1)
# result = 1 / np.var(np.corrcoef(fea))
result = np.var(np.corrcoef(fea))
e = time.time()
t = e - s
result_list.append(result)
result_t.append(t)
for name, modules in net.named_modules():
modules.register_forward_hook(forward_hook)
N = x.shape[0]
for sp in range(split_data):
st = sp * N // split_data
en = (sp + 1) * N // split_data
y = net(x[st:en])
# print(y)
results = np.array(result_list)
results = results[np.logical_not(np.isnan(results))]
v = np.sum(results)
t = sum(result_t)
result_list.clear()
result_t.clear()
return v, t
@measure('pearson', bn=True)
def compute_pearson(net, inputs, targets, split_data=1, loss_fn=None):
device = inputs.device
# Compute gradients (but don't apply them)
net.zero_grad()
try:
pearson, t = get_score(net, inputs, targets, device, split_data=split_data)
except Exception as e:
print(e)
pearson, t = np.nan, None
return pearson, t

View File

@@ -0,0 +1,44 @@
# Copyright 2021 Samsung Electronics Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
import torch
import torch.nn.functional as F
from . import measure
from ..p_utils import get_layer_metric_array
@measure('plain', bn=True, mode='param')
def compute_plain_per_weight(net, inputs, targets, mode, loss_fn, split_data=1):
net.zero_grad()
N = inputs.shape[0]
for sp in range(split_data):
st=sp*N//split_data
en=(sp+1)*N//split_data
outputs = net.forward(inputs[st:en])
loss = loss_fn(outputs, targets[st:en])
loss.backward()
# select the gradients that we want to use for search/prune
def plain(layer):
if layer.weight.grad is not None:
return layer.weight.grad * layer.weight
else:
return torch.zeros_like(layer.weight)
grads_abs = get_layer_metric_array(net, plain, mode)
return grads_abs

View File

@@ -0,0 +1,69 @@
# Copyright 2021 Samsung Electronics Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import copy
import types
from . import measure
from ..p_utils import get_layer_metric_array
def snip_forward_conv2d(self, x):
return F.conv2d(x, self.weight * self.weight_mask, self.bias,
self.stride, self.padding, self.dilation, self.groups)
def snip_forward_linear(self, x):
return F.linear(x, self.weight * self.weight_mask, self.bias)
@measure('snip', bn=True, mode='param')
def compute_snip_per_weight(net, inputs, targets, mode, loss_fn, split_data=1):
for layer in net.modules():
if isinstance(layer, nn.Conv2d) or isinstance(layer, nn.Linear):
layer.weight_mask = nn.Parameter(torch.ones_like(layer.weight))
layer.weight.requires_grad = False
# Override the forward methods:
if isinstance(layer, nn.Conv2d):
layer.forward = types.MethodType(snip_forward_conv2d, layer)
if isinstance(layer, nn.Linear):
layer.forward = types.MethodType(snip_forward_linear, layer)
# Compute gradients (but don't apply them)
net.zero_grad()
N = inputs.shape[0]
for sp in range(split_data):
st=sp*N//split_data
en=(sp+1)*N//split_data
outputs = net.forward(inputs[st:en])
loss = loss_fn(outputs, targets[st:en])
loss.backward()
# select the gradients that we want to use for search/prune
def snip(layer):
if layer.weight_mask.grad is not None:
return torch.abs(layer.weight_mask.grad)
else:
return torch.zeros_like(layer.weight)
grads_abs = get_layer_metric_array(net, snip, mode)
return grads_abs

View File

@@ -0,0 +1,69 @@
# Copyright 2021 Samsung Electronics Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
import torch
from . import measure
from ..p_utils import get_layer_metric_array
@measure('synflow', bn=False, mode='param')
@measure('synflow_bn', bn=True, mode='param')
def compute_synflow_per_weight(net, inputs, targets, mode, split_data=1, loss_fn=None):
device = inputs.device
#convert params to their abs. Keep sign for converting it back.
@torch.no_grad()
def linearize(net):
signs = {}
for name, param in net.state_dict().items():
signs[name] = torch.sign(param)
param.abs_()
return signs
#convert to orig values
@torch.no_grad()
def nonlinearize(net, signs):
for name, param in net.state_dict().items():
if 'weight_mask' not in name:
param.mul_(signs[name])
# keep signs of all params
signs = linearize(net)
# Compute gradients with input of 1s
net.zero_grad()
net.double()
input_dim = list(inputs[0,:].shape)
inputs = torch.ones([1] + input_dim).double().to(device)
output = net.forward(inputs)
torch.sum(output).backward()
# select the gradients that we want to use for search/prune
def synflow(layer):
if layer.weight.grad is not None:
return torch.abs(layer.weight * layer.weight.grad)
else:
return torch.zeros_like(layer.weight)
grads_abs = get_layer_metric_array(net, synflow, mode)
# apply signs of all params
nonlinearize(net, signs)
return grads_abs

View File

@@ -0,0 +1,55 @@
# Copyright 2021 Samsung Electronics Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
import time
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
import numpy as np
import torch
from . import measure
def get_score(net, x, target, device, split_data):
result_list = []
def forward_hook(module, data_input, data_output):
var = torch.var(data_input[0])
result_list.append(var)
net.classifier.register_forward_hook(forward_hook)
N = x.shape[0]
for sp in range(split_data):
st = sp * N // split_data
en = (sp + 1) * N // split_data
y = net(x[st:en])
v = result_list[0].item()
result_list.clear()
return v
@measure('var', bn=True)
def compute_var(net, inputs, targets, split_data=1, loss_fn=None):
device = inputs.device
# Compute gradients (but don't apply them)
net.zero_grad()
# print('var:', feature.shape)
try:
var= get_score(net, inputs, targets, device, split_data=split_data)
except Exception as e:
print(e)
var= np.nan
# print(jc)
# print(f'var time: {t} s')
return var

View File

@@ -0,0 +1,110 @@
# Copyright 2021 Samsung Electronics Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
import torch
from torch import nn
import numpy as np
from . import measure
def network_weight_gaussian_init(net: nn.Module):
with torch.no_grad():
for n, m in net.named_modules():
if isinstance(m, nn.Conv2d):
nn.init.normal_(m.weight)
if hasattr(m, 'bias') and m.bias is not None:
nn.init.zeros_(m.bias)
elif isinstance(m, nn.BatchNorm2d):
try:
nn.init.ones_(m.weight)
nn.init.zeros_(m.bias)
except:
pass
elif isinstance(m, nn.Linear):
nn.init.normal_(m.weight)
if hasattr(m, 'bias') and m.bias is not None:
nn.init.zeros_(m.bias)
else:
continue
return net
def get_zen(gpu, model, mixup_gamma=1e-2, resolution=32, batch_size=64, repeat=32,
fp16=False):
info = {}
nas_score_list = []
if gpu is not None:
device = torch.device(gpu)
else:
device = torch.device('cpu')
if fp16:
dtype = torch.half
else:
dtype = torch.float32
with torch.no_grad():
for repeat_count in range(repeat):
network_weight_gaussian_init(model)
input = torch.randn(size=[batch_size, 3, resolution, resolution], device=device, dtype=dtype)
input2 = torch.randn(size=[batch_size, 3, resolution, resolution], device=device, dtype=dtype)
mixup_input = input + mixup_gamma * input2
output = model.forward_pre_GAP(input)
mixup_output = model.forward_pre_GAP(mixup_input)
nas_score = torch.sum(torch.abs(output - mixup_output), dim=[1, 2, 3])
nas_score = torch.mean(nas_score)
# compute BN scaling
log_bn_scaling_factor = 0.0
for m in model.modules():
if isinstance(m, nn.BatchNorm2d):
try:
bn_scaling_factor = torch.sqrt(torch.mean(m.running_var))
log_bn_scaling_factor += torch.log(bn_scaling_factor)
except:
pass
pass
pass
nas_score = torch.log(nas_score) + log_bn_scaling_factor
nas_score_list.append(float(nas_score))
std_nas_score = np.std(nas_score_list)
avg_precision = 1.96 * std_nas_score / np.sqrt(len(nas_score_list))
avg_nas_score = np.mean(nas_score_list)
info = float(avg_nas_score)
return info
@measure('zen', bn=True)
def compute_zen(net, inputs, targets, split_data=1, loss_fn=None):
device = inputs.device
# Compute gradients (but don't apply them)
net.zero_grad()
try:
zen = get_zen(device,net)
except Exception as e:
print(e)
zen= np.nan
return zen

View File

@@ -0,0 +1,106 @@
# Copyright 2021 Samsung Electronics Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
import time
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
import numpy as np
import torch
from . import measure
from torch import nn
from ...dataset import get_cifar_dataloaders
def getgrad(model: torch.nn.Module, grad_dict: dict, step_iter=0):
if step_iter == 0:
for name, mod in model.named_modules():
if isinstance(mod, nn.Conv2d) or isinstance(mod, nn.Linear):
# print(mod.weight.grad.data.size())
# print(mod.weight.data.size())
try:
grad_dict[name] = [mod.weight.grad.data.cpu().reshape(-1).numpy()]
except:
continue
else:
for name, mod in model.named_modules():
if isinstance(mod, nn.Conv2d) or isinstance(mod, nn.Linear):
try:
grad_dict[name].append(mod.weight.grad.data.cpu().reshape(-1).numpy())
except:
continue
return grad_dict
def caculate_zico(grad_dict):
allgrad_array = None
for i, modname in enumerate(grad_dict.keys()):
grad_dict[modname] = np.array(grad_dict[modname])
nsr_mean_sum = 0
nsr_mean_sum_abs = 0
nsr_mean_avg = 0
nsr_mean_avg_abs = 0
for j, modname in enumerate(grad_dict.keys()):
nsr_std = np.std(grad_dict[modname], axis=0)
# print(grad_dict[modname].shape)
# print(grad_dict[modname].shape, nsr_std.shape)
nonzero_idx = np.nonzero(nsr_std)[0]
nsr_mean_abs = np.mean(np.abs(grad_dict[modname]), axis=0)
tmpsum = np.sum(nsr_mean_abs[nonzero_idx] / nsr_std[nonzero_idx])
if tmpsum == 0:
pass
else:
nsr_mean_sum_abs += np.log(tmpsum)
nsr_mean_avg_abs += np.log(np.mean(nsr_mean_abs[nonzero_idx] / nsr_std[nonzero_idx]))
return nsr_mean_sum_abs
def getzico(network, inputs, targets, loss_fn, split_data=2):
grad_dict = {}
network.train()
device = inputs.device
network.to(device)
N = inputs.shape[0]
split_data = 2
for sp in range(split_data):
st = sp * N // split_data
en = (sp + 1) * N // split_data
outputs = network.forward(inputs[st:en])
loss = loss_fn(outputs, targets[st:en])
loss.backward()
grad_dict = getgrad(network, grad_dict, sp)
# print(grad_dict)
res = caculate_zico(grad_dict)
return res
@measure('zico', bn=True)
def compute_zico(net, inputs, targets, split_data=2, loss_fn=None):
# Compute gradients (but don't apply them)
net.zero_grad()
# print('var:', feature.shape)
try:
zico = getzico(net, inputs, targets, loss_fn, split_data=split_data)
except Exception as e:
print(e)
zico= np.nan
# print(jc)
# print(f'var time: {t} s')
return zico

View File

@@ -0,0 +1,83 @@
# Copyright 2021 Samsung Electronics Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from ..models import *
def get_some_data(train_dataloader, num_batches, device):
traindata = []
dataloader_iter = iter(train_dataloader)
for _ in range(num_batches):
traindata.append(next(dataloader_iter))
inputs = torch.cat([a for a,_ in traindata])
targets = torch.cat([b for _,b in traindata])
inputs = inputs.to(device)
targets = targets.to(device)
return inputs, targets
def get_some_data_grasp(train_dataloader, num_classes, samples_per_class, device):
datas = [[] for _ in range(num_classes)]
labels = [[] for _ in range(num_classes)]
mark = dict()
dataloader_iter = iter(train_dataloader)
while True:
inputs, targets = next(dataloader_iter)
for idx in range(inputs.shape[0]):
x, y = inputs[idx:idx+1], targets[idx:idx+1]
category = y.item()
if len(datas[category]) == samples_per_class:
mark[category] = True
continue
datas[category].append(x)
labels[category].append(y)
if len(mark) == num_classes:
break
x = torch.cat([torch.cat(_, 0) for _ in datas]).to(device)
y = torch.cat([torch.cat(_) for _ in labels]).view(-1).to(device)
return x, y
def get_layer_metric_array(net, metric, mode):
metric_array = []
for layer in net.modules():
if mode=='channel' and hasattr(layer,'dont_ch_prune'):
continue
if isinstance(layer, nn.Conv2d) or isinstance(layer, nn.Linear):
metric_array.append(metric(layer))
return metric_array
def reshape_elements(elements, shapes, device):
def broadcast_val(elements, shapes):
ret_grads = []
for e,sh in zip(elements, shapes):
ret_grads.append(torch.stack([torch.Tensor(sh).fill_(v) for v in e], dim=0).to(device))
return ret_grads
if type(elements[0]) == list:
outer = []
for e,sh in zip(elements, shapes):
outer.append(broadcast_val(e,sh))
return outer
else:
return broadcast_val(elements, shapes)
def count_parameters(model):
return sum(p.numel() for p in model.parameters() if p.requires_grad)

View File

@@ -0,0 +1,116 @@
# Copyright 2021 Samsung Electronics Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
from .p_utils import *
from . import measures
import types
import copy
def no_op(self,x):
return x
def copynet(self, bn):
net = copy.deepcopy(self)
if bn==False:
for l in net.modules():
if isinstance(l,nn.BatchNorm2d) or isinstance(l,nn.BatchNorm1d) :
l.forward = types.MethodType(no_op, l)
return net
def find_measures_arrays(net_orig, trainloader, dataload_info, device, measure_names=None, loss_fn=F.cross_entropy):
if measure_names is None:
measure_names = measures.available_measures
dataload, num_imgs_or_batches, num_classes = dataload_info
if not hasattr(net_orig,'get_prunable_copy'):
net_orig.get_prunable_copy = types.MethodType(copynet, net_orig)
#move to cpu to free up mem
torch.cuda.empty_cache()
net_orig = net_orig.cpu()
torch.cuda.empty_cache()
#given 1 minibatch of data
if dataload == 'random':
inputs, targets = get_some_data(trainloader, num_batches=num_imgs_or_batches, device=device)
elif dataload == 'grasp':
inputs, targets = get_some_data_grasp(trainloader, num_classes, samples_per_class=num_imgs_or_batches, device=device)
else:
raise NotImplementedError(f'dataload {dataload} is not supported')
done, ds = False, 1
measure_values = {}
while not done:
try:
for measure_name in measure_names:
if measure_name not in measure_values:
val = measures.calc_measure(measure_name, net_orig, device, inputs, targets, loss_fn=loss_fn, split_data=ds)
measure_values[measure_name] = val
done = True
except RuntimeError as e:
if 'out of memory' in str(e):
done=False
if ds == inputs.shape[0]//2:
raise ValueError(f'Can\'t split data anymore, but still unable to run. Something is wrong')
ds += 1
while inputs.shape[0] % ds != 0:
ds += 1
torch.cuda.empty_cache()
print(f'Caught CUDA OOM, retrying with data split into {ds} parts')
else:
raise e
net_orig = net_orig.to(device).train()
return measure_values
def find_measures(net_orig, # neural network
dataloader, # a data loader (typically for training data)
dataload_info, # a tuple with (dataload_type = {random, grasp}, number_of_batches_for_random_or_images_per_class_for_grasp, number of classes)
device, # GPU/CPU device used
loss_fn=F.cross_entropy, # loss function to use within the zero-cost metrics
measure_names=None, # an array of measure names to compute, if left blank, all measures are computed by default
measures_arr=None): # [not used] if the measures are already computed but need to be summarized, pass them here
#Given a neural net
#and some information about the input data (dataloader)
#and loss function (loss_fn)
#this function returns an array of zero-cost proxy metrics.
def sum_arr(arr):
sum = 0.
for i in range(len(arr)):
sum += torch.sum(arr[i])
return sum.item()
if measures_arr is None:
measures_arr = find_measures_arrays(net_orig, dataloader, dataload_info, device, loss_fn=loss_fn, measure_names=measure_names)
measures = {}
for k,v in measures_arr.items():
if k in ['jacob_cov', 'var', 'cor', 'norm', 'meco', 'zico', 'ntk', 'gradsign', 'zen']:
measures[k] = v
else:
measures[k] = sum_arr(v)
return measures

View File

@@ -0,0 +1,51 @@
# Copyright 2021 Samsung Electronics Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
version = '1.0.0'
repo = 'unknown'
commit = 'unknown'
has_repo = False
try:
import git
from pathlib import Path
try:
r = git.Repo(Path(__file__).parents[1])
has_repo = True
if not r.remotes:
repo = 'local'
else:
repo = r.remotes.origin.url
commit = r.head.commit.hexsha
if r.is_dirty():
commit += ' (dirty)'
except git.InvalidGitRepositoryError:
raise ImportError()
except ImportError:
pass
try:
from . import _dist_info as info
assert not has_repo, '_dist_info should not exist when repo is in place'
assert version == info.version
repo = info.repo
commit = info.commit
except (ImportError, SystemError):
pass
__all__ = ['version', 'repo', 'commit', 'has_repo']

View File

@@ -0,0 +1,84 @@
# Copyright 2021 Samsung Electronics Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
import torch.nn as nn
def init_net(net, w_type, b_type):
if w_type == 'none':
pass
elif w_type == 'xavier':
net.apply(init_weights_vs)
elif w_type == 'kaiming':
net.apply(init_weights_he)
elif w_type == 'zero':
net.apply(init_weights_zero)
elif w_type == 'one':
net.apply(init_weights_one)
else:
raise NotImplementedError(f'init_type={w_type} is not supported.')
if b_type == 'none':
pass
elif b_type == 'xavier':
net.apply(init_bias_vs)
elif b_type == 'kaiming':
net.apply(init_bias_he)
elif b_type == 'zero':
net.apply(init_bias_zero)
elif b_type == 'one':
net.apply(init_bias_one)
else:
raise NotImplementedError(f'init_type={b_type} is not supported.')
def init_weights_vs(m):
if type(m) == nn.Linear or type(m) == nn.Conv2d:
nn.init.xavier_normal_(m.weight)
def init_bias_vs(m):
if type(m) == nn.Linear or type(m) == nn.Conv2d:
if m.bias is not None:
nn.init.xavier_normal_(m.bias)
def init_weights_he(m):
if type(m) == nn.Linear or type(m) == nn.Conv2d:
nn.init.kaiming_normal_(m.weight)
def init_bias_he(m):
if type(m) == nn.Linear or type(m) == nn.Conv2d:
if m.bias is not None:
nn.init.kaiming_normal_(m.bias)
def init_weights_zero(m):
if type(m) == nn.Linear or type(m) == nn.Conv2d:
m.weight.data.fill_(.0)
def init_weights_one(m):
if type(m) == nn.Linear or type(m) == nn.Conv2d:
m.weight.data.fill_(1.)
def init_bias_zero(m):
if type(m) == nn.Linear or type(m) == nn.Conv2d:
if m.bias is not None:
m.bias.data.fill_(.0)
def init_bias_one(m):
if type(m) == nn.Linear or type(m) == nn.Conv2d:
if m.bias is not None:
m.bias.data.fill_(1.)

View File

@@ -0,0 +1,117 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
import math, torch
import torch.nn as nn
import torch.nn.functional as F
from .initialization import initialize_resnet
class Bottleneck(nn.Module):
def __init__(self, nChannels, growthRate):
super(Bottleneck, self).__init__()
interChannels = 4 * growthRate
self.bn1 = nn.BatchNorm2d(nChannels)
self.conv1 = nn.Conv2d(nChannels, interChannels, kernel_size=1, bias=False)
self.bn2 = nn.BatchNorm2d(interChannels)
self.conv2 = nn.Conv2d(
interChannels, growthRate, kernel_size=3, padding=1, bias=False
)
def forward(self, x):
out = self.conv1(F.relu(self.bn1(x)))
out = self.conv2(F.relu(self.bn2(out)))
out = torch.cat((x, out), 1)
return out
class SingleLayer(nn.Module):
def __init__(self, nChannels, growthRate):
super(SingleLayer, self).__init__()
self.bn1 = nn.BatchNorm2d(nChannels)
self.conv1 = nn.Conv2d(
nChannels, growthRate, kernel_size=3, padding=1, bias=False
)
def forward(self, x):
out = self.conv1(F.relu(self.bn1(x)))
out = torch.cat((x, out), 1)
return out
class Transition(nn.Module):
def __init__(self, nChannels, nOutChannels):
super(Transition, self).__init__()
self.bn1 = nn.BatchNorm2d(nChannels)
self.conv1 = nn.Conv2d(nChannels, nOutChannels, kernel_size=1, bias=False)
def forward(self, x):
out = self.conv1(F.relu(self.bn1(x)))
out = F.avg_pool2d(out, 2)
return out
class DenseNet(nn.Module):
def __init__(self, growthRate, depth, reduction, nClasses, bottleneck):
super(DenseNet, self).__init__()
if bottleneck:
nDenseBlocks = int((depth - 4) / 6)
else:
nDenseBlocks = int((depth - 4) / 3)
self.message = "CifarDenseNet : block : {:}, depth : {:}, reduction : {:}, growth-rate = {:}, class = {:}".format(
"bottleneck" if bottleneck else "basic",
depth,
reduction,
growthRate,
nClasses,
)
nChannels = 2 * growthRate
self.conv1 = nn.Conv2d(3, nChannels, kernel_size=3, padding=1, bias=False)
self.dense1 = self._make_dense(nChannels, growthRate, nDenseBlocks, bottleneck)
nChannels += nDenseBlocks * growthRate
nOutChannels = int(math.floor(nChannels * reduction))
self.trans1 = Transition(nChannels, nOutChannels)
nChannels = nOutChannels
self.dense2 = self._make_dense(nChannels, growthRate, nDenseBlocks, bottleneck)
nChannels += nDenseBlocks * growthRate
nOutChannels = int(math.floor(nChannels * reduction))
self.trans2 = Transition(nChannels, nOutChannels)
nChannels = nOutChannels
self.dense3 = self._make_dense(nChannels, growthRate, nDenseBlocks, bottleneck)
nChannels += nDenseBlocks * growthRate
self.act = nn.Sequential(
nn.BatchNorm2d(nChannels), nn.ReLU(inplace=True), nn.AvgPool2d(8)
)
self.fc = nn.Linear(nChannels, nClasses)
self.apply(initialize_resnet)
def get_message(self):
return self.message
def _make_dense(self, nChannels, growthRate, nDenseBlocks, bottleneck):
layers = []
for i in range(int(nDenseBlocks)):
if bottleneck:
layers.append(Bottleneck(nChannels, growthRate))
else:
layers.append(SingleLayer(nChannels, growthRate))
nChannels += growthRate
return nn.Sequential(*layers)
def forward(self, inputs):
out = self.conv1(inputs)
out = self.trans1(self.dense1(out))
out = self.trans2(self.dense2(out))
out = self.dense3(out)
features = self.act(out)
features = features.view(features.size(0), -1)
out = self.fc(features)
return features, out

View File

@@ -0,0 +1,180 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from .initialization import initialize_resnet
from .SharedUtils import additive_func
class Downsample(nn.Module):
def __init__(self, nIn, nOut, stride):
super(Downsample, self).__init__()
assert stride == 2 and nOut == 2 * nIn, "stride:{} IO:{},{}".format(
stride, nIn, nOut
)
self.in_dim = nIn
self.out_dim = nOut
self.avg = nn.AvgPool2d(kernel_size=2, stride=2, padding=0)
self.conv = nn.Conv2d(nIn, nOut, kernel_size=1, stride=1, padding=0, bias=False)
def forward(self, x):
x = self.avg(x)
out = self.conv(x)
return out
class ConvBNReLU(nn.Module):
def __init__(self, nIn, nOut, kernel, stride, padding, bias, relu):
super(ConvBNReLU, self).__init__()
self.conv = nn.Conv2d(
nIn, nOut, kernel_size=kernel, stride=stride, padding=padding, bias=bias
)
self.bn = nn.BatchNorm2d(nOut)
if relu:
self.relu = nn.ReLU(inplace=True)
else:
self.relu = None
self.out_dim = nOut
self.num_conv = 1
def forward(self, x):
conv = self.conv(x)
bn = self.bn(conv)
if self.relu:
return self.relu(bn)
else:
return bn
class ResNetBasicblock(nn.Module):
expansion = 1
def __init__(self, inplanes, planes, stride):
super(ResNetBasicblock, self).__init__()
assert stride == 1 or stride == 2, "invalid stride {:}".format(stride)
self.conv_a = ConvBNReLU(inplanes, planes, 3, stride, 1, False, True)
self.conv_b = ConvBNReLU(planes, planes, 3, 1, 1, False, False)
if stride == 2:
self.downsample = Downsample(inplanes, planes, stride)
elif inplanes != planes:
self.downsample = ConvBNReLU(inplanes, planes, 1, 1, 0, False, False)
else:
self.downsample = None
self.out_dim = planes
self.num_conv = 2
def forward(self, inputs):
basicblock = self.conv_a(inputs)
basicblock = self.conv_b(basicblock)
if self.downsample is not None:
residual = self.downsample(inputs)
else:
residual = inputs
out = additive_func(residual, basicblock)
return F.relu(out, inplace=True)
class ResNetBottleneck(nn.Module):
expansion = 4
def __init__(self, inplanes, planes, stride):
super(ResNetBottleneck, self).__init__()
assert stride == 1 or stride == 2, "invalid stride {:}".format(stride)
self.conv_1x1 = ConvBNReLU(inplanes, planes, 1, 1, 0, False, True)
self.conv_3x3 = ConvBNReLU(planes, planes, 3, stride, 1, False, True)
self.conv_1x4 = ConvBNReLU(
planes, planes * self.expansion, 1, 1, 0, False, False
)
if stride == 2:
self.downsample = Downsample(inplanes, planes * self.expansion, stride)
elif inplanes != planes * self.expansion:
self.downsample = ConvBNReLU(
inplanes, planes * self.expansion, 1, 1, 0, False, False
)
else:
self.downsample = None
self.out_dim = planes * self.expansion
self.num_conv = 3
def forward(self, inputs):
bottleneck = self.conv_1x1(inputs)
bottleneck = self.conv_3x3(bottleneck)
bottleneck = self.conv_1x4(bottleneck)
if self.downsample is not None:
residual = self.downsample(inputs)
else:
residual = inputs
out = additive_func(residual, bottleneck)
return F.relu(out, inplace=True)
class CifarResNet(nn.Module):
def __init__(self, block_name, depth, num_classes, zero_init_residual):
super(CifarResNet, self).__init__()
# Model type specifies number of layers for CIFAR-10 and CIFAR-100 model
if block_name == "ResNetBasicblock":
block = ResNetBasicblock
assert (depth - 2) % 6 == 0, "depth should be one of 20, 32, 44, 56, 110"
layer_blocks = (depth - 2) // 6
elif block_name == "ResNetBottleneck":
block = ResNetBottleneck
assert (depth - 2) % 9 == 0, "depth should be one of 164"
layer_blocks = (depth - 2) // 9
else:
raise ValueError("invalid block : {:}".format(block_name))
self.message = "CifarResNet : Block : {:}, Depth : {:}, Layers for each block : {:}".format(
block_name, depth, layer_blocks
)
self.num_classes = num_classes
self.channels = [16]
self.layers = nn.ModuleList([ConvBNReLU(3, 16, 3, 1, 1, False, True)])
for stage in range(3):
for iL in range(layer_blocks):
iC = self.channels[-1]
planes = 16 * (2 ** stage)
stride = 2 if stage > 0 and iL == 0 else 1
module = block(iC, planes, stride)
self.channels.append(module.out_dim)
self.layers.append(module)
self.message += "\nstage={:}, ilayer={:02d}/{:02d}, block={:03d}, iC={:3d}, oC={:3d}, stride={:}".format(
stage,
iL,
layer_blocks,
len(self.layers) - 1,
iC,
module.out_dim,
stride,
)
self.avgpool = nn.AvgPool2d(8)
self.classifier = nn.Linear(module.out_dim, num_classes)
assert (
sum(x.num_conv for x in self.layers) + 1 == depth
), "invalid depth check {:} vs {:}".format(
sum(x.num_conv for x in self.layers) + 1, depth
)
self.apply(initialize_resnet)
if zero_init_residual:
for m in self.modules():
if isinstance(m, ResNetBasicblock):
nn.init.constant_(m.conv_b.bn.weight, 0)
elif isinstance(m, ResNetBottleneck):
nn.init.constant_(m.conv_1x4.bn.weight, 0)
def get_message(self):
return self.message
def forward(self, inputs):
x = inputs
for i, layer in enumerate(self.layers):
x = layer(x)
features = self.avgpool(x)
features = features.view(features.size(0), -1)
logits = self.classifier(features)
return features, logits

View File

@@ -0,0 +1,115 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from .initialization import initialize_resnet
class WideBasicblock(nn.Module):
def __init__(self, inplanes, planes, stride, dropout=False):
super(WideBasicblock, self).__init__()
self.bn_a = nn.BatchNorm2d(inplanes)
self.conv_a = nn.Conv2d(
inplanes, planes, kernel_size=3, stride=stride, padding=1, bias=False
)
self.bn_b = nn.BatchNorm2d(planes)
if dropout:
self.dropout = nn.Dropout2d(p=0.5, inplace=True)
else:
self.dropout = None
self.conv_b = nn.Conv2d(
planes, planes, kernel_size=3, stride=1, padding=1, bias=False
)
if inplanes != planes:
self.downsample = nn.Conv2d(
inplanes, planes, kernel_size=1, stride=stride, padding=0, bias=False
)
else:
self.downsample = None
def forward(self, x):
basicblock = self.bn_a(x)
basicblock = F.relu(basicblock)
basicblock = self.conv_a(basicblock)
basicblock = self.bn_b(basicblock)
basicblock = F.relu(basicblock)
if self.dropout is not None:
basicblock = self.dropout(basicblock)
basicblock = self.conv_b(basicblock)
if self.downsample is not None:
x = self.downsample(x)
return x + basicblock
class CifarWideResNet(nn.Module):
"""
ResNet optimized for the Cifar dataset, as specified in
https://arxiv.org/abs/1512.03385.pdf
"""
def __init__(self, depth, widen_factor, num_classes, dropout):
super(CifarWideResNet, self).__init__()
# Model type specifies number of layers for CIFAR-10 and CIFAR-100 model
assert (depth - 4) % 6 == 0, "depth should be one of 20, 32, 44, 56, 110"
layer_blocks = (depth - 4) // 6
print(
"CifarPreResNet : Depth : {} , Layers for each block : {}".format(
depth, layer_blocks
)
)
self.num_classes = num_classes
self.dropout = dropout
self.conv_3x3 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)
self.message = "Wide ResNet : depth={:}, widen_factor={:}, class={:}".format(
depth, widen_factor, num_classes
)
self.inplanes = 16
self.stage_1 = self._make_layer(
WideBasicblock, 16 * widen_factor, layer_blocks, 1
)
self.stage_2 = self._make_layer(
WideBasicblock, 32 * widen_factor, layer_blocks, 2
)
self.stage_3 = self._make_layer(
WideBasicblock, 64 * widen_factor, layer_blocks, 2
)
self.lastact = nn.Sequential(
nn.BatchNorm2d(64 * widen_factor), nn.ReLU(inplace=True)
)
self.avgpool = nn.AvgPool2d(8)
self.classifier = nn.Linear(64 * widen_factor, num_classes)
self.apply(initialize_resnet)
def get_message(self):
return self.message
def _make_layer(self, block, planes, blocks, stride):
layers = []
layers.append(block(self.inplanes, planes, stride, self.dropout))
self.inplanes = planes
for i in range(1, blocks):
layers.append(block(self.inplanes, planes, 1, self.dropout))
return nn.Sequential(*layers)
def forward(self, x):
x = self.conv_3x3(x)
x = self.stage_1(x)
x = self.stage_2(x)
x = self.stage_3(x)
x = self.lastact(x)
x = self.avgpool(x)
features = x.view(x.size(0), -1)
outs = self.classifier(features)
return features, outs

View File

@@ -0,0 +1,117 @@
# MobileNetV2: Inverted Residuals and Linear Bottlenecks, CVPR 2018
from torch import nn
from .initialization import initialize_resnet
class ConvBNReLU(nn.Module):
def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1):
super(ConvBNReLU, self).__init__()
padding = (kernel_size - 1) // 2
self.conv = nn.Conv2d(
in_planes,
out_planes,
kernel_size,
stride,
padding,
groups=groups,
bias=False,
)
self.bn = nn.BatchNorm2d(out_planes)
self.relu = nn.ReLU6(inplace=True)
def forward(self, x):
out = self.conv(x)
out = self.bn(out)
out = self.relu(out)
return out
class InvertedResidual(nn.Module):
def __init__(self, inp, oup, stride, expand_ratio):
super(InvertedResidual, self).__init__()
self.stride = stride
assert stride in [1, 2]
hidden_dim = int(round(inp * expand_ratio))
self.use_res_connect = self.stride == 1 and inp == oup
layers = []
if expand_ratio != 1:
# pw
layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1))
layers.extend(
[
# dw
ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim),
# pw-linear
nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
nn.BatchNorm2d(oup),
]
)
self.conv = nn.Sequential(*layers)
def forward(self, x):
if self.use_res_connect:
return x + self.conv(x)
else:
return self.conv(x)
class MobileNetV2(nn.Module):
def __init__(
self, num_classes, width_mult, input_channel, last_channel, block_name, dropout
):
super(MobileNetV2, self).__init__()
if block_name == "InvertedResidual":
block = InvertedResidual
else:
raise ValueError("invalid block name : {:}".format(block_name))
inverted_residual_setting = [
# t, c, n, s
[1, 16, 1, 1],
[6, 24, 2, 2],
[6, 32, 3, 2],
[6, 64, 4, 2],
[6, 96, 3, 1],
[6, 160, 3, 2],
[6, 320, 1, 1],
]
# building first layer
input_channel = int(input_channel * width_mult)
self.last_channel = int(last_channel * max(1.0, width_mult))
features = [ConvBNReLU(3, input_channel, stride=2)]
# building inverted residual blocks
for t, c, n, s in inverted_residual_setting:
output_channel = int(c * width_mult)
for i in range(n):
stride = s if i == 0 else 1
features.append(
block(input_channel, output_channel, stride, expand_ratio=t)
)
input_channel = output_channel
# building last several layers
features.append(ConvBNReLU(input_channel, self.last_channel, kernel_size=1))
# make it nn.Sequential
self.features = nn.Sequential(*features)
# building classifier
self.classifier = nn.Sequential(
nn.Dropout(dropout),
nn.Linear(self.last_channel, num_classes),
)
self.message = "MobileNetV2 : width_mult={:}, in-C={:}, last-C={:}, block={:}, dropout={:}".format(
width_mult, input_channel, last_channel, block_name, dropout
)
# weight initialization
self.apply(initialize_resnet)
def get_message(self):
return self.message
def forward(self, inputs):
features = self.features(inputs)
vectors = features.mean([2, 3])
predicts = self.classifier(vectors)
return features, predicts

View File

@@ -0,0 +1,217 @@
# Deep Residual Learning for Image Recognition, CVPR 2016
import torch.nn as nn
from .initialization import initialize_resnet
def conv3x3(in_planes, out_planes, stride=1, groups=1):
return nn.Conv2d(
in_planes,
out_planes,
kernel_size=3,
stride=stride,
padding=1,
groups=groups,
bias=False,
)
def conv1x1(in_planes, out_planes, stride=1):
return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
class BasicBlock(nn.Module):
expansion = 1
def __init__(
self, inplanes, planes, stride=1, downsample=None, groups=1, base_width=64
):
super(BasicBlock, self).__init__()
if groups != 1 or base_width != 64:
raise ValueError("BasicBlock only supports groups=1 and base_width=64")
# Both self.conv1 and self.downsample layers downsample the input when stride != 1
self.conv1 = conv3x3(inplanes, planes, stride)
self.bn1 = nn.BatchNorm2d(planes)
self.relu = nn.ReLU(inplace=True)
self.conv2 = conv3x3(planes, planes)
self.bn2 = nn.BatchNorm2d(planes)
self.downsample = downsample
self.stride = stride
def forward(self, x):
identity = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
if self.downsample is not None:
identity = self.downsample(x)
out += identity
out = self.relu(out)
return out
class Bottleneck(nn.Module):
expansion = 4
def __init__(
self, inplanes, planes, stride=1, downsample=None, groups=1, base_width=64
):
super(Bottleneck, self).__init__()
width = int(planes * (base_width / 64.0)) * groups
# Both self.conv2 and self.downsample layers downsample the input when stride != 1
self.conv1 = conv1x1(inplanes, width)
self.bn1 = nn.BatchNorm2d(width)
self.conv2 = conv3x3(width, width, stride, groups)
self.bn2 = nn.BatchNorm2d(width)
self.conv3 = conv1x1(width, planes * self.expansion)
self.bn3 = nn.BatchNorm2d(planes * self.expansion)
self.relu = nn.ReLU(inplace=True)
self.downsample = downsample
self.stride = stride
def forward(self, x):
identity = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
out = self.relu(out)
out = self.conv3(out)
out = self.bn3(out)
if self.downsample is not None:
identity = self.downsample(x)
out += identity
out = self.relu(out)
return out
class ResNet(nn.Module):
def __init__(
self,
block_name,
layers,
deep_stem,
num_classes,
zero_init_residual,
groups,
width_per_group,
):
super(ResNet, self).__init__()
# planes = [int(width_per_group * groups * 2 ** i) for i in range(4)]
if block_name == "BasicBlock":
block = BasicBlock
elif block_name == "Bottleneck":
block = Bottleneck
else:
raise ValueError("invalid block-name : {:}".format(block_name))
if not deep_stem:
self.conv = nn.Sequential(
nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False),
nn.BatchNorm2d(64),
nn.ReLU(inplace=True),
)
else:
self.conv = nn.Sequential(
nn.Conv2d(3, 32, kernel_size=3, stride=2, padding=1, bias=False),
nn.BatchNorm2d(32),
nn.ReLU(inplace=True),
nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1, bias=False),
nn.BatchNorm2d(32),
nn.ReLU(inplace=True),
nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1, bias=False),
nn.BatchNorm2d(64),
nn.ReLU(inplace=True),
)
self.inplanes = 64
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
self.layer1 = self._make_layer(
block, 64, layers[0], stride=1, groups=groups, base_width=width_per_group
)
self.layer2 = self._make_layer(
block, 128, layers[1], stride=2, groups=groups, base_width=width_per_group
)
self.layer3 = self._make_layer(
block, 256, layers[2], stride=2, groups=groups, base_width=width_per_group
)
self.layer4 = self._make_layer(
block, 512, layers[3], stride=2, groups=groups, base_width=width_per_group
)
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
self.fc = nn.Linear(512 * block.expansion, num_classes)
self.message = (
"block = {:}, layers = {:}, deep_stem = {:}, num_classes = {:}".format(
block, layers, deep_stem, num_classes
)
)
self.apply(initialize_resnet)
# Zero-initialize the last BN in each residual branch,
# so that the residual branch starts with zeros, and each residual block behaves like an identity.
# This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
if zero_init_residual:
for m in self.modules():
if isinstance(m, Bottleneck):
nn.init.constant_(m.bn3.weight, 0)
elif isinstance(m, BasicBlock):
nn.init.constant_(m.bn2.weight, 0)
def _make_layer(self, block, planes, blocks, stride, groups, base_width):
downsample = None
if stride != 1 or self.inplanes != planes * block.expansion:
if stride == 2:
downsample = nn.Sequential(
nn.AvgPool2d(kernel_size=2, stride=2, padding=0),
conv1x1(self.inplanes, planes * block.expansion, 1),
nn.BatchNorm2d(planes * block.expansion),
)
elif stride == 1:
downsample = nn.Sequential(
conv1x1(self.inplanes, planes * block.expansion, stride),
nn.BatchNorm2d(planes * block.expansion),
)
else:
raise ValueError("invalid stride [{:}] for downsample".format(stride))
layers = []
layers.append(
block(self.inplanes, planes, stride, downsample, groups, base_width)
)
self.inplanes = planes * block.expansion
for _ in range(1, blocks):
layers.append(block(self.inplanes, planes, 1, None, groups, base_width))
return nn.Sequential(*layers)
def get_message(self):
return self.message
def forward(self, x):
x = self.conv(x)
x = self.maxpool(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
features = self.avgpool(x)
features = features.view(features.size(0), -1)
logits = self.fc(features)
return features, logits

View File

@@ -0,0 +1,37 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
#####################################################
import torch
import torch.nn as nn
def additive_func(A, B):
assert A.dim() == B.dim() and A.size(0) == B.size(0), "{:} vs {:}".format(
A.size(), B.size()
)
C = min(A.size(1), B.size(1))
if A.size(1) == B.size(1):
return A + B
elif A.size(1) < B.size(1):
out = B.clone()
out[:, :C] += A
return out
else:
out = A.clone()
out[:, :C] += B
return out
def change_key(key, value):
def func(m):
if hasattr(m, key):
setattr(m, key, value)
return func
def parse_channel_info(xstring):
blocks = xstring.split(" ")
blocks = [x.split("-") for x in blocks]
blocks = [[int(_) for _ in x] for x in blocks]
return blocks

View File

@@ -0,0 +1,329 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
from os import path as osp
from typing import List, Text
import torch
__all__ = [
"change_key",
"get_cell_based_tiny_net",
"get_search_spaces",
"get_cifar_models",
"get_imagenet_models",
"obtain_model",
"obtain_search_model",
"load_net_from_checkpoint",
"CellStructure",
"CellArchitectures",
]
# useful modules
from xautodl.config_utils import dict2config
from .SharedUtils import change_key
from .cell_searchs import CellStructure, CellArchitectures
# Cell-based NAS Models
def get_cell_based_tiny_net(config):
if isinstance(config, dict):
config = dict2config(config, None) # to support the argument being a dict
# print(config)
super_type = getattr(config, "super_type", "basic")
# print(super_type)
group_names = ["DARTS-V1", "DARTS-V2", "GDAS", "SETN", "ENAS", "RANDOM", "generic"]
if super_type == "basic" and config.name in group_names:
from .cell_searchs import nas201_super_nets as nas_super_nets
try:
return nas_super_nets[config.name](
config.C,
config.N,
config.max_nodes,
config.num_classes,
config.space,
config.affine,
config.track_running_stats,
)
except:
return nas_super_nets[config.name](
config.C, config.N, config.max_nodes, config.num_classes, config.space
)
elif super_type == "search-shape":
from .shape_searchs import GenericNAS301Model
genotype = CellStructure.str2structure(config.genotype)
return GenericNAS301Model(
config.candidate_Cs,
config.max_num_Cs,
genotype,
config.num_classes,
config.affine,
config.track_running_stats,
)
elif super_type == "nasnet-super":
from .cell_searchs import nasnet_super_nets as nas_super_nets
return nas_super_nets[config.name](
config.C,
config.N,
config.steps,
config.multiplier,
config.stem_multiplier,
config.num_classes,
config.space,
config.affine,
config.track_running_stats,
)
elif config.name == "infer.tiny":
from .cell_infers import TinyNetwork
if hasattr(config, "genotype"):
genotype = config.genotype
elif hasattr(config, "arch_str"):
genotype = CellStructure.str2structure(config.arch_str)
else:
raise ValueError(
"Can not find genotype from this config : {:}".format(config)
)
return TinyNetwork(config.C, config.N, genotype, config.num_classes)
# sss 网络用到的
elif config.name == "infer.shape.tiny":
from .shape_infers import DynamicShapeTinyNet
if isinstance(config.channels, str):
channels = tuple([int(x) for x in config.channels.split(":")])
else:
channels = config.channels
genotype = CellStructure.str2structure(config.genotype)
return DynamicShapeTinyNet(channels, genotype, config.num_classes)
elif config.name == "infer.nasnet-cifar":
from .cell_infers import NASNetonCIFAR
raise NotImplementedError
else:
raise ValueError("invalid network name : {:}".format(config.name))
# obtain the search space, i.e., a dict mapping the operation name into a python-function for this op
def get_search_spaces(xtype, name) -> List[Text]:
if xtype == "cell" or xtype == "tss": # The topology search space.
from .cell_operations import SearchSpaceNames
assert name in SearchSpaceNames, "invalid name [{:}] in {:}".format(
name, SearchSpaceNames.keys()
)
return SearchSpaceNames[name]
elif xtype == "sss": # The size search space.
if name in ["nats-bench", "nats-bench-size"]:
return {"candidates": [8, 16, 24, 32, 40, 48, 56, 64], "numbers": 5}
else:
raise ValueError("Invalid name : {:}".format(name))
else:
raise ValueError("invalid search-space type is {:}".format(xtype))
def get_cifar_models(config, extra_path=None):
super_type = getattr(config, "super_type", "basic")
if super_type == "basic":
from .CifarResNet import CifarResNet
from .CifarDenseNet import DenseNet
from .CifarWideResNet import CifarWideResNet
if config.arch == "resnet":
return CifarResNet(
config.module, config.depth, config.class_num, config.zero_init_residual
)
elif config.arch == "densenet":
return DenseNet(
config.growthRate,
config.depth,
config.reduction,
config.class_num,
config.bottleneck,
)
elif config.arch == "wideresnet":
return CifarWideResNet(
config.depth, config.wide_factor, config.class_num, config.dropout
)
else:
raise ValueError("invalid module type : {:}".format(config.arch))
elif super_type.startswith("infer"):
from .shape_infers import InferWidthCifarResNet
from .shape_infers import InferDepthCifarResNet
from .shape_infers import InferCifarResNet
from .cell_infers import NASNetonCIFAR
assert len(super_type.split("-")) == 2, "invalid super_type : {:}".format(
super_type
)
infer_mode = super_type.split("-")[1]
if infer_mode == "width":
return InferWidthCifarResNet(
config.module,
config.depth,
config.xchannels,
config.class_num,
config.zero_init_residual,
)
elif infer_mode == "depth":
return InferDepthCifarResNet(
config.module,
config.depth,
config.xblocks,
config.class_num,
config.zero_init_residual,
)
elif infer_mode == "shape":
return InferCifarResNet(
config.module,
config.depth,
config.xblocks,
config.xchannels,
config.class_num,
config.zero_init_residual,
)
elif infer_mode == "nasnet.cifar":
genotype = config.genotype
if extra_path is not None: # reload genotype by extra_path
if not osp.isfile(extra_path):
raise ValueError("invalid extra_path : {:}".format(extra_path))
xdata = torch.load(extra_path)
current_epoch = xdata["epoch"]
genotype = xdata["genotypes"][current_epoch - 1]
C = config.C if hasattr(config, "C") else config.ichannel
N = config.N if hasattr(config, "N") else config.layers
return NASNetonCIFAR(
C, N, config.stem_multi, config.class_num, genotype, config.auxiliary
)
else:
raise ValueError("invalid infer-mode : {:}".format(infer_mode))
else:
raise ValueError("invalid super-type : {:}".format(super_type))
def get_imagenet_models(config):
super_type = getattr(config, "super_type", "basic")
if super_type == "basic":
from .ImageNet_ResNet import ResNet
from .ImageNet_MobileNetV2 import MobileNetV2
if config.arch == "resnet":
return ResNet(
config.block_name,
config.layers,
config.deep_stem,
config.class_num,
config.zero_init_residual,
config.groups,
config.width_per_group,
)
elif config.arch == "mobilenet_v2":
return MobileNetV2(
config.class_num,
config.width_multi,
config.input_channel,
config.last_channel,
"InvertedResidual",
config.dropout,
)
else:
raise ValueError("invalid arch : {:}".format(config.arch))
elif super_type.startswith("infer"): # NAS searched architecture
assert len(super_type.split("-")) == 2, "invalid super_type : {:}".format(
super_type
)
infer_mode = super_type.split("-")[1]
if infer_mode == "shape":
from .shape_infers import InferImagenetResNet
from .shape_infers import InferMobileNetV2
if config.arch == "resnet":
return InferImagenetResNet(
config.block_name,
config.layers,
config.xblocks,
config.xchannels,
config.deep_stem,
config.class_num,
config.zero_init_residual,
)
elif config.arch == "MobileNetV2":
return InferMobileNetV2(
config.class_num, config.xchannels, config.xblocks, config.dropout
)
else:
raise ValueError("invalid arch-mode : {:}".format(config.arch))
else:
raise ValueError("invalid infer-mode : {:}".format(infer_mode))
else:
raise ValueError("invalid super-type : {:}".format(super_type))
# Try to obtain the network by config.
def obtain_model(config, extra_path=None):
if config.dataset == "cifar":
return get_cifar_models(config, extra_path)
elif config.dataset == "imagenet":
return get_imagenet_models(config)
else:
raise ValueError("invalid dataset in the model config : {:}".format(config))
def obtain_search_model(config):
if config.dataset == "cifar":
if config.arch == "resnet":
from .shape_searchs import SearchWidthCifarResNet
from .shape_searchs import SearchDepthCifarResNet
from .shape_searchs import SearchShapeCifarResNet
if config.search_mode == "width":
return SearchWidthCifarResNet(
config.module, config.depth, config.class_num
)
elif config.search_mode == "depth":
return SearchDepthCifarResNet(
config.module, config.depth, config.class_num
)
elif config.search_mode == "shape":
return SearchShapeCifarResNet(
config.module, config.depth, config.class_num
)
else:
raise ValueError("invalid search mode : {:}".format(config.search_mode))
elif config.arch == "simres":
from .shape_searchs import SearchWidthSimResNet
if config.search_mode == "width":
return SearchWidthSimResNet(config.depth, config.class_num)
else:
raise ValueError("invalid search mode : {:}".format(config.search_mode))
else:
raise ValueError(
"invalid arch : {:} for dataset [{:}]".format(
config.arch, config.dataset
)
)
elif config.dataset == "imagenet":
from .shape_searchs import SearchShapeImagenetResNet
assert config.search_mode == "shape", "invalid search-mode : {:}".format(
config.search_mode
)
if config.arch == "resnet":
return SearchShapeImagenetResNet(
config.block_name, config.layers, config.deep_stem, config.class_num
)
else:
raise ValueError("invalid model config : {:}".format(config))
else:
raise ValueError("invalid dataset in the model config : {:}".format(config))
def load_net_from_checkpoint(checkpoint):
assert osp.isfile(checkpoint), "checkpoint {:} does not exist".format(checkpoint)
checkpoint = torch.load(checkpoint)
model_config = dict2config(checkpoint["model-config"], None)
model = obtain_model(model_config)
model.load_state_dict(checkpoint["base-model"])
return model

View File

@@ -0,0 +1,5 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
#####################################################
from .tiny_network import TinyNetwork
from .nasnet_cifar import NASNetonCIFAR

View File

@@ -0,0 +1,155 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
#####################################################
import torch
import torch.nn as nn
from copy import deepcopy
from xautodl.models.cell_operations import OPS
# Cell for NAS-Bench-201
class InferCell(nn.Module):
def __init__(
self, genotype, C_in, C_out, stride, affine=True, track_running_stats=True
):
super(InferCell, self).__init__()
self.layers = nn.ModuleList()
self.node_IN = []
self.node_IX = []
self.genotype = deepcopy(genotype)
for i in range(1, len(genotype)):
node_info = genotype[i - 1]
cur_index = []
cur_innod = []
for (op_name, op_in) in node_info:
if op_in == 0:
layer = OPS[op_name](
C_in, C_out, stride, affine, track_running_stats
)
else:
layer = OPS[op_name](C_out, C_out, 1, affine, track_running_stats)
cur_index.append(len(self.layers))
cur_innod.append(op_in)
self.layers.append(layer)
self.node_IX.append(cur_index)
self.node_IN.append(cur_innod)
self.nodes = len(genotype)
self.in_dim = C_in
self.out_dim = C_out
def extra_repr(self):
string = "info :: nodes={nodes}, inC={in_dim}, outC={out_dim}".format(
**self.__dict__
)
laystr = []
for i, (node_layers, node_innods) in enumerate(zip(self.node_IX, self.node_IN)):
y = [
"I{:}-L{:}".format(_ii, _il)
for _il, _ii in zip(node_layers, node_innods)
]
x = "{:}<-({:})".format(i + 1, ",".join(y))
laystr.append(x)
return (
string
+ ", [{:}]".format(" | ".join(laystr))
+ ", {:}".format(self.genotype.tostr())
)
def forward(self, inputs):
nodes = [inputs]
for i, (node_layers, node_innods) in enumerate(zip(self.node_IX, self.node_IN)):
node_feature = sum(
self.layers[_il](nodes[_ii])
for _il, _ii in zip(node_layers, node_innods)
)
nodes.append(node_feature)
return nodes[-1]
# Learning Transferable Architectures for Scalable Image Recognition, CVPR 2018
class NASNetInferCell(nn.Module):
def __init__(
self,
genotype,
C_prev_prev,
C_prev,
C,
reduction,
reduction_prev,
affine,
track_running_stats,
):
super(NASNetInferCell, self).__init__()
self.reduction = reduction
if reduction_prev:
self.preprocess0 = OPS["skip_connect"](
C_prev_prev, C, 2, affine, track_running_stats
)
else:
self.preprocess0 = OPS["nor_conv_1x1"](
C_prev_prev, C, 1, affine, track_running_stats
)
self.preprocess1 = OPS["nor_conv_1x1"](
C_prev, C, 1, affine, track_running_stats
)
if not reduction:
nodes, concats = genotype["normal"], genotype["normal_concat"]
else:
nodes, concats = genotype["reduce"], genotype["reduce_concat"]
self._multiplier = len(concats)
self._concats = concats
self._steps = len(nodes)
self._nodes = nodes
self.edges = nn.ModuleDict()
for i, node in enumerate(nodes):
for in_node in node:
name, j = in_node[0], in_node[1]
stride = 2 if reduction and j < 2 else 1
node_str = "{:}<-{:}".format(i + 2, j)
self.edges[node_str] = OPS[name](
C, C, stride, affine, track_running_stats
)
# [TODO] to support drop_prob in this function..
def forward(self, s0, s1, unused_drop_prob):
s0 = self.preprocess0(s0)
s1 = self.preprocess1(s1)
states = [s0, s1]
for i, node in enumerate(self._nodes):
clist = []
for in_node in node:
name, j = in_node[0], in_node[1]
node_str = "{:}<-{:}".format(i + 2, j)
op = self.edges[node_str]
clist.append(op(states[j]))
states.append(sum(clist))
return torch.cat([states[x] for x in self._concats], dim=1)
class AuxiliaryHeadCIFAR(nn.Module):
def __init__(self, C, num_classes):
"""assuming input size 8x8"""
super(AuxiliaryHeadCIFAR, self).__init__()
self.features = nn.Sequential(
nn.ReLU(inplace=True),
nn.AvgPool2d(
5, stride=3, padding=0, count_include_pad=False
), # image size = 2 x 2
nn.Conv2d(C, 128, 1, bias=False),
nn.BatchNorm2d(128),
nn.ReLU(inplace=True),
nn.Conv2d(128, 768, 2, bias=False),
nn.BatchNorm2d(768),
nn.ReLU(inplace=True),
)
self.classifier = nn.Linear(768, num_classes)
def forward(self, x):
x = self.features(x)
x = self.classifier(x.view(x.size(0), -1))
return x

View File

@@ -0,0 +1,118 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
#####################################################
import torch
import torch.nn as nn
from copy import deepcopy
from .cells import NASNetInferCell as InferCell, AuxiliaryHeadCIFAR
# The macro structure is based on NASNet
class NASNetonCIFAR(nn.Module):
def __init__(
self,
C,
N,
stem_multiplier,
num_classes,
genotype,
auxiliary,
affine=True,
track_running_stats=True,
):
super(NASNetonCIFAR, self).__init__()
self._C = C
self._layerN = N
self.stem = nn.Sequential(
nn.Conv2d(3, C * stem_multiplier, kernel_size=3, padding=1, bias=False),
nn.BatchNorm2d(C * stem_multiplier),
)
# config for each layer
layer_channels = (
[C] * N + [C * 2] + [C * 2] * (N - 1) + [C * 4] + [C * 4] * (N - 1)
)
layer_reductions = (
[False] * N + [True] + [False] * (N - 1) + [True] + [False] * (N - 1)
)
C_prev_prev, C_prev, C_curr, reduction_prev = (
C * stem_multiplier,
C * stem_multiplier,
C,
False,
)
self.auxiliary_index = None
self.auxiliary_head = None
self.cells = nn.ModuleList()
for index, (C_curr, reduction) in enumerate(
zip(layer_channels, layer_reductions)
):
cell = InferCell(
genotype,
C_prev_prev,
C_prev,
C_curr,
reduction,
reduction_prev,
affine,
track_running_stats,
)
self.cells.append(cell)
C_prev_prev, C_prev, reduction_prev = (
C_prev,
cell._multiplier * C_curr,
reduction,
)
if reduction and C_curr == C * 4 and auxiliary:
self.auxiliary_head = AuxiliaryHeadCIFAR(C_prev, num_classes)
self.auxiliary_index = index
self._Layer = len(self.cells)
self.lastact = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True))
self.global_pooling = nn.AdaptiveAvgPool2d(1)
self.classifier = nn.Linear(C_prev, num_classes)
self.drop_path_prob = -1
def update_drop_path(self, drop_path_prob):
self.drop_path_prob = drop_path_prob
def auxiliary_param(self):
if self.auxiliary_head is None:
return []
else:
return list(self.auxiliary_head.parameters())
def get_message(self):
string = self.extra_repr()
for i, cell in enumerate(self.cells):
string += "\n {:02d}/{:02d} :: {:}".format(
i, len(self.cells), cell.extra_repr()
)
return string
def extra_repr(self):
return "{name}(C={_C}, N={_layerN}, L={_Layer})".format(
name=self.__class__.__name__, **self.__dict__
)
def forward(self, inputs):
stem_feature, logits_aux = self.stem(inputs), None
cell_results = [stem_feature, stem_feature]
for i, cell in enumerate(self.cells):
cell_feature = cell(cell_results[-2], cell_results[-1], self.drop_path_prob)
cell_results.append(cell_feature)
if (
self.auxiliary_index is not None
and i == self.auxiliary_index
and self.training
):
logits_aux = self.auxiliary_head(cell_results[-1])
out = self.lastact(cell_results[-1])
out = self.global_pooling(out)
out = out.view(out.size(0), -1)
logits = self.classifier(out)
if logits_aux is None:
return out, logits
else:
return out, [logits, logits_aux]

View File

@@ -0,0 +1,63 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
#####################################################
import torch.nn as nn
from ..cell_operations import ResNetBasicblock
from .cells import InferCell
# The macro structure for architectures in NAS-Bench-201
class TinyNetwork(nn.Module):
def __init__(self, C, N, genotype, num_classes):
super(TinyNetwork, self).__init__()
self._C = C
self._layerN = N
self.stem = nn.Sequential(
nn.Conv2d(3, C, kernel_size=3, padding=1, bias=False), nn.BatchNorm2d(C)
)
layer_channels = [C] * N + [C * 2] + [C * 2] * N + [C * 4] + [C * 4] * N
layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N
C_prev = C
self.cells = nn.ModuleList()
for index, (C_curr, reduction) in enumerate(
zip(layer_channels, layer_reductions)
):
if reduction:
cell = ResNetBasicblock(C_prev, C_curr, 2, True)
else:
cell = InferCell(genotype, C_prev, C_curr, 1)
self.cells.append(cell)
C_prev = cell.out_dim
self._Layer = len(self.cells)
self.lastact = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True))
self.global_pooling = nn.AdaptiveAvgPool2d(1)
self.classifier = nn.Linear(C_prev, num_classes)
def get_message(self):
string = self.extra_repr()
for i, cell in enumerate(self.cells):
string += "\n {:02d}/{:02d} :: {:}".format(
i, len(self.cells), cell.extra_repr()
)
return string
def extra_repr(self):
return "{name}(C={_C}, N={_layerN}, L={_Layer})".format(
name=self.__class__.__name__, **self.__dict__
)
def forward(self, inputs):
feature = self.stem(inputs)
for i, cell in enumerate(self.cells):
feature = cell(feature)
out = self.lastact(feature)
out = self.global_pooling(out)
out = out.view(out.size(0), -1)
logits = self.classifier(out)
return out, logits

View File

@@ -0,0 +1,553 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
import torch
import torch.nn as nn
__all__ = ["OPS", "RAW_OP_CLASSES", "ResNetBasicblock", "SearchSpaceNames"]
OPS = {
"none": lambda C_in, C_out, stride, affine, track_running_stats: Zero(
C_in, C_out, stride
),
"avg_pool_3x3": lambda C_in, C_out, stride, affine, track_running_stats: POOLING(
C_in, C_out, stride, "avg", affine, track_running_stats
),
"max_pool_3x3": lambda C_in, C_out, stride, affine, track_running_stats: POOLING(
C_in, C_out, stride, "max", affine, track_running_stats
),
"nor_conv_7x7": lambda C_in, C_out, stride, affine, track_running_stats: ReLUConvBN(
C_in,
C_out,
(7, 7),
(stride, stride),
(3, 3),
(1, 1),
affine,
track_running_stats,
),
"nor_conv_3x3": lambda C_in, C_out, stride, affine, track_running_stats: ReLUConvBN(
C_in,
C_out,
(3, 3),
(stride, stride),
(1, 1),
(1, 1),
affine,
track_running_stats,
),
"nor_conv_1x1": lambda C_in, C_out, stride, affine, track_running_stats: ReLUConvBN(
C_in,
C_out,
(1, 1),
(stride, stride),
(0, 0),
(1, 1),
affine,
track_running_stats,
),
"dua_sepc_3x3": lambda C_in, C_out, stride, affine, track_running_stats: DualSepConv(
C_in,
C_out,
(3, 3),
(stride, stride),
(1, 1),
(1, 1),
affine,
track_running_stats,
),
"dua_sepc_5x5": lambda C_in, C_out, stride, affine, track_running_stats: DualSepConv(
C_in,
C_out,
(5, 5),
(stride, stride),
(2, 2),
(1, 1),
affine,
track_running_stats,
),
"dil_sepc_3x3": lambda C_in, C_out, stride, affine, track_running_stats: SepConv(
C_in,
C_out,
(3, 3),
(stride, stride),
(2, 2),
(2, 2),
affine,
track_running_stats,
),
"dil_sepc_5x5": lambda C_in, C_out, stride, affine, track_running_stats: SepConv(
C_in,
C_out,
(5, 5),
(stride, stride),
(4, 4),
(2, 2),
affine,
track_running_stats,
),
"skip_connect": lambda C_in, C_out, stride, affine, track_running_stats: Identity()
if stride == 1 and C_in == C_out
else FactorizedReduce(C_in, C_out, stride, affine, track_running_stats),
}
CONNECT_NAS_BENCHMARK = ["none", "skip_connect", "nor_conv_3x3"]
NAS_BENCH_201 = ["none", "skip_connect", "nor_conv_1x1", "nor_conv_3x3", "avg_pool_3x3"]
DARTS_SPACE = [
"none",
"skip_connect",
"dua_sepc_3x3",
"dua_sepc_5x5",
"dil_sepc_3x3",
"dil_sepc_5x5",
"avg_pool_3x3",
"max_pool_3x3",
]
SearchSpaceNames = {
"connect-nas": CONNECT_NAS_BENCHMARK,
"nats-bench": NAS_BENCH_201,
"nas-bench-201": NAS_BENCH_201,
"darts": DARTS_SPACE,
}
class ReLUConvBN(nn.Module):
def __init__(
self,
C_in,
C_out,
kernel_size,
stride,
padding,
dilation,
affine,
track_running_stats=True,
):
super(ReLUConvBN, self).__init__()
self.op = nn.Sequential(
nn.ReLU(inplace=False),
nn.Conv2d(
C_in,
C_out,
kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
bias=not affine,
),
nn.BatchNorm2d(
C_out, affine=affine, track_running_stats=track_running_stats
),
)
def forward(self, x):
return self.op(x)
class SepConv(nn.Module):
def __init__(
self,
C_in,
C_out,
kernel_size,
stride,
padding,
dilation,
affine,
track_running_stats=True,
):
super(SepConv, self).__init__()
self.op = nn.Sequential(
nn.ReLU(inplace=False),
nn.Conv2d(
C_in,
C_in,
kernel_size=kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
groups=C_in,
bias=False,
),
nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=not affine),
nn.BatchNorm2d(
C_out, affine=affine, track_running_stats=track_running_stats
),
)
def forward(self, x):
return self.op(x)
class DualSepConv(nn.Module):
def __init__(
self,
C_in,
C_out,
kernel_size,
stride,
padding,
dilation,
affine,
track_running_stats=True,
):
super(DualSepConv, self).__init__()
self.op_a = SepConv(
C_in,
C_in,
kernel_size,
stride,
padding,
dilation,
affine,
track_running_stats,
)
self.op_b = SepConv(
C_in, C_out, kernel_size, 1, padding, dilation, affine, track_running_stats
)
def forward(self, x):
x = self.op_a(x)
x = self.op_b(x)
return x
class ResNetBasicblock(nn.Module):
def __init__(self, inplanes, planes, stride, affine=True, track_running_stats=True):
super(ResNetBasicblock, self).__init__()
assert stride == 1 or stride == 2, "invalid stride {:}".format(stride)
self.conv_a = ReLUConvBN(
inplanes, planes, 3, stride, 1, 1, affine, track_running_stats
)
self.conv_b = ReLUConvBN(
planes, planes, 3, 1, 1, 1, affine, track_running_stats
)
if stride == 2:
self.downsample = nn.Sequential(
nn.AvgPool2d(kernel_size=2, stride=2, padding=0),
nn.Conv2d(
inplanes, planes, kernel_size=1, stride=1, padding=0, bias=False
),
)
elif inplanes != planes:
self.downsample = ReLUConvBN(
inplanes, planes, 1, 1, 0, 1, affine, track_running_stats
)
else:
self.downsample = None
self.in_dim = inplanes
self.out_dim = planes
self.stride = stride
self.num_conv = 2
def extra_repr(self):
string = "{name}(inC={in_dim}, outC={out_dim}, stride={stride})".format(
name=self.__class__.__name__, **self.__dict__
)
return string
def forward(self, inputs):
basicblock = self.conv_a(inputs)
basicblock = self.conv_b(basicblock)
if self.downsample is not None:
residual = self.downsample(inputs)
else:
residual = inputs
return residual + basicblock
class POOLING(nn.Module):
def __init__(
self, C_in, C_out, stride, mode, affine=True, track_running_stats=True
):
super(POOLING, self).__init__()
if C_in == C_out:
self.preprocess = None
else:
self.preprocess = ReLUConvBN(
C_in, C_out, 1, 1, 0, 1, affine, track_running_stats
)
if mode == "avg":
self.op = nn.AvgPool2d(3, stride=stride, padding=1, count_include_pad=False)
elif mode == "max":
self.op = nn.MaxPool2d(3, stride=stride, padding=1)
else:
raise ValueError("Invalid mode={:} in POOLING".format(mode))
def forward(self, inputs):
if self.preprocess:
x = self.preprocess(inputs)
else:
x = inputs
return self.op(x)
class Identity(nn.Module):
def __init__(self):
super(Identity, self).__init__()
def forward(self, x):
return x
class Zero(nn.Module):
def __init__(self, C_in, C_out, stride):
super(Zero, self).__init__()
self.C_in = C_in
self.C_out = C_out
self.stride = stride
self.is_zero = True
def forward(self, x):
if self.C_in == self.C_out:
if self.stride == 1:
return x.mul(0.0)
else:
return x[:, :, :: self.stride, :: self.stride].mul(0.0)
else:
shape = list(x.shape)
shape[1] = self.C_out
zeros = x.new_zeros(shape, dtype=x.dtype, device=x.device)
return zeros
def extra_repr(self):
return "C_in={C_in}, C_out={C_out}, stride={stride}".format(**self.__dict__)
class FactorizedReduce(nn.Module):
def __init__(self, C_in, C_out, stride, affine, track_running_stats):
super(FactorizedReduce, self).__init__()
self.stride = stride
self.C_in = C_in
self.C_out = C_out
self.relu = nn.ReLU(inplace=False)
if stride == 2:
# assert C_out % 2 == 0, 'C_out : {:}'.format(C_out)
C_outs = [C_out // 2, C_out - C_out // 2]
self.convs = nn.ModuleList()
for i in range(2):
self.convs.append(
nn.Conv2d(
C_in, C_outs[i], 1, stride=stride, padding=0, bias=not affine
)
)
self.pad = nn.ConstantPad2d((0, 1, 0, 1), 0)
elif stride == 1:
self.conv = nn.Conv2d(
C_in, C_out, 1, stride=stride, padding=0, bias=not affine
)
else:
raise ValueError("Invalid stride : {:}".format(stride))
self.bn = nn.BatchNorm2d(
C_out, affine=affine, track_running_stats=track_running_stats
)
def forward(self, x):
if self.stride == 2:
x = self.relu(x)
y = self.pad(x)
out = torch.cat([self.convs[0](x), self.convs[1](y[:, :, 1:, 1:])], dim=1)
else:
out = self.conv(x)
out = self.bn(out)
return out
def extra_repr(self):
return "C_in={C_in}, C_out={C_out}, stride={stride}".format(**self.__dict__)
# Auto-ReID: Searching for a Part-Aware ConvNet for Person Re-Identification, ICCV 2019
class PartAwareOp(nn.Module):
def __init__(self, C_in, C_out, stride, part=4):
super().__init__()
self.part = 4
self.hidden = C_in // 3
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.local_conv_list = nn.ModuleList()
for i in range(self.part):
self.local_conv_list.append(
nn.Sequential(
nn.ReLU(),
nn.Conv2d(C_in, self.hidden, 1),
nn.BatchNorm2d(self.hidden, affine=True),
)
)
self.W_K = nn.Linear(self.hidden, self.hidden)
self.W_Q = nn.Linear(self.hidden, self.hidden)
if stride == 2:
self.last = FactorizedReduce(C_in + self.hidden, C_out, 2)
elif stride == 1:
self.last = FactorizedReduce(C_in + self.hidden, C_out, 1)
else:
raise ValueError("Invalid Stride : {:}".format(stride))
def forward(self, x):
batch, C, H, W = x.size()
assert H >= self.part, "input size too small : {:} vs {:}".format(
x.shape, self.part
)
IHs = [0]
for i in range(self.part):
IHs.append(min(H, int((i + 1) * (float(H) / self.part))))
local_feat_list = []
for i in range(self.part):
feature = x[:, :, IHs[i] : IHs[i + 1], :]
xfeax = self.avg_pool(feature)
xfea = self.local_conv_list[i](xfeax)
local_feat_list.append(xfea)
part_feature = torch.cat(local_feat_list, dim=2).view(batch, -1, self.part)
part_feature = part_feature.transpose(1, 2).contiguous()
part_K = self.W_K(part_feature)
part_Q = self.W_Q(part_feature).transpose(1, 2).contiguous()
weight_att = torch.bmm(part_K, part_Q)
attention = torch.softmax(weight_att, dim=2)
aggreateF = torch.bmm(attention, part_feature).transpose(1, 2).contiguous()
features = []
for i in range(self.part):
feature = aggreateF[:, :, i : i + 1].expand(
batch, self.hidden, IHs[i + 1] - IHs[i]
)
feature = feature.view(batch, self.hidden, IHs[i + 1] - IHs[i], 1)
features.append(feature)
features = torch.cat(features, dim=2).expand(batch, self.hidden, H, W)
final_fea = torch.cat((x, features), dim=1)
outputs = self.last(final_fea)
return outputs
def drop_path(x, drop_prob):
if drop_prob > 0.0:
keep_prob = 1.0 - drop_prob
mask = x.new_zeros(x.size(0), 1, 1, 1)
mask = mask.bernoulli_(keep_prob)
x = torch.div(x, keep_prob)
x.mul_(mask)
return x
# Searching for A Robust Neural Architecture in Four GPU Hours
class GDAS_Reduction_Cell(nn.Module):
def __init__(
self, C_prev_prev, C_prev, C, reduction_prev, affine, track_running_stats
):
super(GDAS_Reduction_Cell, self).__init__()
if reduction_prev:
self.preprocess0 = FactorizedReduce(
C_prev_prev, C, 2, affine, track_running_stats
)
else:
self.preprocess0 = ReLUConvBN(
C_prev_prev, C, 1, 1, 0, 1, affine, track_running_stats
)
self.preprocess1 = ReLUConvBN(
C_prev, C, 1, 1, 0, 1, affine, track_running_stats
)
self.reduction = True
self.ops1 = nn.ModuleList(
[
nn.Sequential(
nn.ReLU(inplace=False),
nn.Conv2d(
C,
C,
(1, 3),
stride=(1, 2),
padding=(0, 1),
groups=8,
bias=not affine,
),
nn.Conv2d(
C,
C,
(3, 1),
stride=(2, 1),
padding=(1, 0),
groups=8,
bias=not affine,
),
nn.BatchNorm2d(
C, affine=affine, track_running_stats=track_running_stats
),
nn.ReLU(inplace=False),
nn.Conv2d(C, C, 1, stride=1, padding=0, bias=not affine),
nn.BatchNorm2d(
C, affine=affine, track_running_stats=track_running_stats
),
),
nn.Sequential(
nn.ReLU(inplace=False),
nn.Conv2d(
C,
C,
(1, 3),
stride=(1, 2),
padding=(0, 1),
groups=8,
bias=not affine,
),
nn.Conv2d(
C,
C,
(3, 1),
stride=(2, 1),
padding=(1, 0),
groups=8,
bias=not affine,
),
nn.BatchNorm2d(
C, affine=affine, track_running_stats=track_running_stats
),
nn.ReLU(inplace=False),
nn.Conv2d(C, C, 1, stride=1, padding=0, bias=not affine),
nn.BatchNorm2d(
C, affine=affine, track_running_stats=track_running_stats
),
),
]
)
self.ops2 = nn.ModuleList(
[
nn.Sequential(
nn.MaxPool2d(3, stride=2, padding=1),
nn.BatchNorm2d(
C, affine=affine, track_running_stats=track_running_stats
),
),
nn.Sequential(
nn.MaxPool2d(3, stride=2, padding=1),
nn.BatchNorm2d(
C, affine=affine, track_running_stats=track_running_stats
),
),
]
)
@property
def multiplier(self):
return 4
def forward(self, s0, s1, drop_prob=-1):
s0 = self.preprocess0(s0)
s1 = self.preprocess1(s1)
X0 = self.ops1[0](s0)
X1 = self.ops1[1](s1)
if self.training and drop_prob > 0.0:
X0, X1 = drop_path(X0, drop_prob), drop_path(X1, drop_prob)
# X2 = self.ops2[0] (X0+X1)
X2 = self.ops2[0](s0)
X3 = self.ops2[1](s1)
if self.training and drop_prob > 0.0:
X2, X3 = drop_path(X2, drop_prob), drop_path(X3, drop_prob)
return torch.cat([X0, X1, X2, X3], dim=1)
# To manage the useful classes in this file.
RAW_OP_CLASSES = {"gdas_reduction": GDAS_Reduction_Cell}

View File

@@ -0,0 +1,33 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
# The macro structure is defined in NAS-Bench-201
from .search_model_darts import TinyNetworkDarts
from .search_model_gdas import TinyNetworkGDAS
from .search_model_setn import TinyNetworkSETN
from .search_model_enas import TinyNetworkENAS
from .search_model_random import TinyNetworkRANDOM
from .generic_model import GenericNAS201Model
from .genotypes import Structure as CellStructure, architectures as CellArchitectures
# NASNet-based macro structure
from .search_model_gdas_nasnet import NASNetworkGDAS
from .search_model_gdas_frc_nasnet import NASNetworkGDAS_FRC
from .search_model_darts_nasnet import NASNetworkDARTS
nas201_super_nets = {
"DARTS-V1": TinyNetworkDarts,
"DARTS-V2": TinyNetworkDarts,
"GDAS": TinyNetworkGDAS,
"SETN": TinyNetworkSETN,
"ENAS": TinyNetworkENAS,
"RANDOM": TinyNetworkRANDOM,
"generic": GenericNAS201Model,
}
nasnet_super_nets = {
"GDAS": NASNetworkGDAS,
"GDAS_FRC": NASNetworkGDAS_FRC,
"DARTS": NASNetworkDARTS,
}

View File

@@ -0,0 +1,14 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
import torch
from search_model_enas_utils import Controller
def main():
controller = Controller(6, 4)
predictions = controller()
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,366 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.07 #
#####################################################
import torch, random
import torch.nn as nn
from copy import deepcopy
from typing import Text
from torch.distributions.categorical import Categorical
from ..cell_operations import ResNetBasicblock, drop_path
from .search_cells import NAS201SearchCell as SearchCell
from .genotypes import Structure
class Controller(nn.Module):
# we refer to https://github.com/TDeVries/enas_pytorch/blob/master/models/controller.py
def __init__(
self,
edge2index,
op_names,
max_nodes,
lstm_size=32,
lstm_num_layers=2,
tanh_constant=2.5,
temperature=5.0,
):
super(Controller, self).__init__()
# assign the attributes
self.max_nodes = max_nodes
self.num_edge = len(edge2index)
self.edge2index = edge2index
self.num_ops = len(op_names)
self.op_names = op_names
self.lstm_size = lstm_size
self.lstm_N = lstm_num_layers
self.tanh_constant = tanh_constant
self.temperature = temperature
# create parameters
self.register_parameter(
"input_vars", nn.Parameter(torch.Tensor(1, 1, lstm_size))
)
self.w_lstm = nn.LSTM(
input_size=self.lstm_size,
hidden_size=self.lstm_size,
num_layers=self.lstm_N,
)
self.w_embd = nn.Embedding(self.num_ops, self.lstm_size)
self.w_pred = nn.Linear(self.lstm_size, self.num_ops)
nn.init.uniform_(self.input_vars, -0.1, 0.1)
nn.init.uniform_(self.w_lstm.weight_hh_l0, -0.1, 0.1)
nn.init.uniform_(self.w_lstm.weight_ih_l0, -0.1, 0.1)
nn.init.uniform_(self.w_embd.weight, -0.1, 0.1)
nn.init.uniform_(self.w_pred.weight, -0.1, 0.1)
def convert_structure(self, _arch):
genotypes = []
for i in range(1, self.max_nodes):
xlist = []
for j in range(i):
node_str = "{:}<-{:}".format(i, j)
op_index = _arch[self.edge2index[node_str]]
op_name = self.op_names[op_index]
xlist.append((op_name, j))
genotypes.append(tuple(xlist))
return Structure(genotypes)
def forward(self):
inputs, h0 = self.input_vars, None
log_probs, entropys, sampled_arch = [], [], []
for iedge in range(self.num_edge):
outputs, h0 = self.w_lstm(inputs, h0)
logits = self.w_pred(outputs)
logits = logits / self.temperature
logits = self.tanh_constant * torch.tanh(logits)
# distribution
op_distribution = Categorical(logits=logits)
op_index = op_distribution.sample()
sampled_arch.append(op_index.item())
op_log_prob = op_distribution.log_prob(op_index)
log_probs.append(op_log_prob.view(-1))
op_entropy = op_distribution.entropy()
entropys.append(op_entropy.view(-1))
# obtain the input embedding for the next step
inputs = self.w_embd(op_index)
return (
torch.sum(torch.cat(log_probs)),
torch.sum(torch.cat(entropys)),
self.convert_structure(sampled_arch),
)
class GenericNAS201Model(nn.Module):
def __init__(
self, C, N, max_nodes, num_classes, search_space, affine, track_running_stats
):
super(GenericNAS201Model, self).__init__()
self._C = C
self._layerN = N
self._max_nodes = max_nodes
self._stem = nn.Sequential(
nn.Conv2d(3, C, kernel_size=3, padding=1, bias=False), nn.BatchNorm2d(C)
)
layer_channels = [C] * N + [C * 2] + [C * 2] * N + [C * 4] + [C * 4] * N
layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N
C_prev, num_edge, edge2index = C, None, None
self._cells = nn.ModuleList()
for index, (C_curr, reduction) in enumerate(
zip(layer_channels, layer_reductions)
):
if reduction:
cell = ResNetBasicblock(C_prev, C_curr, 2)
else:
cell = SearchCell(
C_prev,
C_curr,
1,
max_nodes,
search_space,
affine,
track_running_stats,
)
if num_edge is None:
num_edge, edge2index = cell.num_edges, cell.edge2index
else:
assert (
num_edge == cell.num_edges and edge2index == cell.edge2index
), "invalid {:} vs. {:}.".format(num_edge, cell.num_edges)
self._cells.append(cell)
C_prev = cell.out_dim
self._op_names = deepcopy(search_space)
self._Layer = len(self._cells)
self.edge2index = edge2index
self.lastact = nn.Sequential(
nn.BatchNorm2d(
C_prev, affine=affine, track_running_stats=track_running_stats
),
nn.ReLU(inplace=True),
)
self.global_pooling = nn.AdaptiveAvgPool2d(1)
self.classifier = nn.Linear(C_prev, num_classes)
self._num_edge = num_edge
# algorithm related
self.arch_parameters = nn.Parameter(
1e-3 * torch.randn(num_edge, len(search_space))
)
self._mode = None
self.dynamic_cell = None
self._tau = None
self._algo = None
self._drop_path = None
self.verbose = False
def set_algo(self, algo: Text):
# used for searching
assert self._algo is None, "This functioin can only be called once."
self._algo = algo
if algo == "enas":
self.controller = Controller(
self.edge2index, self._op_names, self._max_nodes
)
else:
self.arch_parameters = nn.Parameter(
1e-3 * torch.randn(self._num_edge, len(self._op_names))
)
if algo == "gdas":
self._tau = 10
def set_cal_mode(self, mode, dynamic_cell=None):
assert mode in ["gdas", "enas", "urs", "joint", "select", "dynamic"]
self._mode = mode
if mode == "dynamic":
self.dynamic_cell = deepcopy(dynamic_cell)
else:
self.dynamic_cell = None
def set_drop_path(self, progress, drop_path_rate):
if drop_path_rate is None:
self._drop_path = None
elif progress is None:
self._drop_path = drop_path_rate
else:
self._drop_path = progress * drop_path_rate
@property
def mode(self):
return self._mode
@property
def drop_path(self):
return self._drop_path
@property
def weights(self):
xlist = list(self._stem.parameters())
xlist += list(self._cells.parameters())
xlist += list(self.lastact.parameters())
xlist += list(self.global_pooling.parameters())
xlist += list(self.classifier.parameters())
return xlist
def set_tau(self, tau):
self._tau = tau
@property
def tau(self):
return self._tau
@property
def alphas(self):
if self._algo == "enas":
return list(self.controller.parameters())
else:
return [self.arch_parameters]
@property
def message(self):
string = self.extra_repr()
for i, cell in enumerate(self._cells):
string += "\n {:02d}/{:02d} :: {:}".format(
i, len(self._cells), cell.extra_repr()
)
return string
def show_alphas(self):
with torch.no_grad():
if self._algo == "enas":
return "w_pred :\n{:}".format(self.controller.w_pred.weight)
else:
return "arch-parameters :\n{:}".format(
nn.functional.softmax(self.arch_parameters, dim=-1).cpu()
)
def extra_repr(self):
return "{name}(C={_C}, Max-Nodes={_max_nodes}, N={_layerN}, L={_Layer}, alg={_algo})".format(
name=self.__class__.__name__, **self.__dict__
)
@property
def genotype(self):
genotypes = []
for i in range(1, self._max_nodes):
xlist = []
for j in range(i):
node_str = "{:}<-{:}".format(i, j)
with torch.no_grad():
weights = self.arch_parameters[self.edge2index[node_str]]
op_name = self._op_names[weights.argmax().item()]
xlist.append((op_name, j))
genotypes.append(tuple(xlist))
return Structure(genotypes)
def dync_genotype(self, use_random=False):
genotypes = []
with torch.no_grad():
alphas_cpu = nn.functional.softmax(self.arch_parameters, dim=-1)
for i in range(1, self._max_nodes):
xlist = []
for j in range(i):
node_str = "{:}<-{:}".format(i, j)
if use_random:
op_name = random.choice(self._op_names)
else:
weights = alphas_cpu[self.edge2index[node_str]]
op_index = torch.multinomial(weights, 1).item()
op_name = self._op_names[op_index]
xlist.append((op_name, j))
genotypes.append(tuple(xlist))
return Structure(genotypes)
def get_log_prob(self, arch):
with torch.no_grad():
logits = nn.functional.log_softmax(self.arch_parameters, dim=-1)
select_logits = []
for i, node_info in enumerate(arch.nodes):
for op, xin in node_info:
node_str = "{:}<-{:}".format(i + 1, xin)
op_index = self._op_names.index(op)
select_logits.append(logits[self.edge2index[node_str], op_index])
return sum(select_logits).item()
def return_topK(self, K, use_random=False):
archs = Structure.gen_all(self._op_names, self._max_nodes, False)
pairs = [(self.get_log_prob(arch), arch) for arch in archs]
if K < 0 or K >= len(archs):
K = len(archs)
if use_random:
return random.sample(archs, K)
else:
sorted_pairs = sorted(pairs, key=lambda x: -x[0])
return_pairs = [sorted_pairs[_][1] for _ in range(K)]
return return_pairs
def normalize_archp(self):
if self.mode == "gdas":
while True:
gumbels = -torch.empty_like(self.arch_parameters).exponential_().log()
logits = (self.arch_parameters.log_softmax(dim=1) + gumbels) / self.tau
probs = nn.functional.softmax(logits, dim=1)
index = probs.max(-1, keepdim=True)[1]
one_h = torch.zeros_like(logits).scatter_(-1, index, 1.0)
hardwts = one_h - probs.detach() + probs
if (
(torch.isinf(gumbels).any())
or (torch.isinf(probs).any())
or (torch.isnan(probs).any())
):
continue
else:
break
with torch.no_grad():
hardwts_cpu = hardwts.detach().cpu()
return hardwts, hardwts_cpu, index, "GUMBEL"
else:
alphas = nn.functional.softmax(self.arch_parameters, dim=-1)
index = alphas.max(-1, keepdim=True)[1]
with torch.no_grad():
alphas_cpu = alphas.detach().cpu()
return alphas, alphas_cpu, index, "SOFTMAX"
def forward(self, inputs):
alphas, alphas_cpu, index, verbose_str = self.normalize_archp()
feature = self._stem(inputs)
for i, cell in enumerate(self._cells):
if isinstance(cell, SearchCell):
if self.mode == "urs":
feature = cell.forward_urs(feature)
if self.verbose:
verbose_str += "-forward_urs"
elif self.mode == "select":
feature = cell.forward_select(feature, alphas_cpu)
if self.verbose:
verbose_str += "-forward_select"
elif self.mode == "joint":
feature = cell.forward_joint(feature, alphas)
if self.verbose:
verbose_str += "-forward_joint"
elif self.mode == "dynamic":
feature = cell.forward_dynamic(feature, self.dynamic_cell)
if self.verbose:
verbose_str += "-forward_dynamic"
elif self.mode == "gdas":
feature = cell.forward_gdas(feature, alphas, index)
if self.verbose:
verbose_str += "-forward_gdas"
elif self.mode == "gdas_v1":
feature = cell.forward_gdas_v1(feature, alphas, index)
if self.verbose:
verbose_str += "-forward_gdas_v1"
else:
raise ValueError("invalid mode={:}".format(self.mode))
else:
feature = cell(feature)
if self.drop_path is not None:
feature = drop_path(feature, self.drop_path)
if self.verbose and random.random() < 0.001:
print(verbose_str)
out = self.lastact(feature)
out = self.global_pooling(out)
out = out.view(out.size(0), -1)
logits = self.classifier(out)
return out, logits

View File

@@ -0,0 +1,274 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
from copy import deepcopy
def get_combination(space, num):
combs = []
for i in range(num):
if i == 0:
for func in space:
combs.append([(func, i)])
else:
new_combs = []
for string in combs:
for func in space:
xstring = string + [(func, i)]
new_combs.append(xstring)
combs = new_combs
return combs
class Structure:
def __init__(self, genotype):
assert isinstance(genotype, list) or isinstance(
genotype, tuple
), "invalid class of genotype : {:}".format(type(genotype))
self.node_num = len(genotype) + 1
self.nodes = []
self.node_N = []
for idx, node_info in enumerate(genotype):
assert isinstance(node_info, list) or isinstance(
node_info, tuple
), "invalid class of node_info : {:}".format(type(node_info))
assert len(node_info) >= 1, "invalid length : {:}".format(len(node_info))
for node_in in node_info:
assert isinstance(node_in, list) or isinstance(
node_in, tuple
), "invalid class of in-node : {:}".format(type(node_in))
assert (
len(node_in) == 2 and node_in[1] <= idx
), "invalid in-node : {:}".format(node_in)
self.node_N.append(len(node_info))
self.nodes.append(tuple(deepcopy(node_info)))
def tolist(self, remove_str):
# convert this class to the list, if remove_str is 'none', then remove the 'none' operation.
# note that we re-order the input node in this function
# return the-genotype-list and success [if unsuccess, it is not a connectivity]
genotypes = []
for node_info in self.nodes:
node_info = list(node_info)
node_info = sorted(node_info, key=lambda x: (x[1], x[0]))
node_info = tuple(filter(lambda x: x[0] != remove_str, node_info))
if len(node_info) == 0:
return None, False
genotypes.append(node_info)
return genotypes, True
def node(self, index):
assert index > 0 and index <= len(self), "invalid index={:} < {:}".format(
index, len(self)
)
return self.nodes[index]
def tostr(self):
strings = []
for node_info in self.nodes:
string = "|".join([x[0] + "~{:}".format(x[1]) for x in node_info])
string = "|{:}|".format(string)
strings.append(string)
return "+".join(strings)
def check_valid(self):
nodes = {0: True}
for i, node_info in enumerate(self.nodes):
sums = []
for op, xin in node_info:
if op == "none" or nodes[xin] is False:
x = False
else:
x = True
sums.append(x)
nodes[i + 1] = sum(sums) > 0
return nodes[len(self.nodes)]
def to_unique_str(self, consider_zero=False):
# this is used to identify the isomorphic cell, which rerquires the prior knowledge of operation
# two operations are special, i.e., none and skip_connect
nodes = {0: "0"}
for i_node, node_info in enumerate(self.nodes):
cur_node = []
for op, xin in node_info:
if consider_zero is None:
x = "(" + nodes[xin] + ")" + "@{:}".format(op)
elif consider_zero:
if op == "none" or nodes[xin] == "#":
x = "#" # zero
elif op == "skip_connect":
x = nodes[xin]
else:
x = "(" + nodes[xin] + ")" + "@{:}".format(op)
else:
if op == "skip_connect":
x = nodes[xin]
else:
x = "(" + nodes[xin] + ")" + "@{:}".format(op)
cur_node.append(x)
nodes[i_node + 1] = "+".join(sorted(cur_node))
return nodes[len(self.nodes)]
def check_valid_op(self, op_names):
for node_info in self.nodes:
for inode_edge in node_info:
# assert inode_edge[0] in op_names, 'invalid op-name : {:}'.format(inode_edge[0])
if inode_edge[0] not in op_names:
return False
return True
def __repr__(self):
return "{name}({node_num} nodes with {node_info})".format(
name=self.__class__.__name__, node_info=self.tostr(), **self.__dict__
)
def __len__(self):
return len(self.nodes) + 1
def __getitem__(self, index):
return self.nodes[index]
@staticmethod
def str2structure(xstr):
if isinstance(xstr, Structure):
return xstr
assert isinstance(xstr, str), "must take string (not {:}) as input".format(
type(xstr)
)
nodestrs = xstr.split("+")
genotypes = []
for i, node_str in enumerate(nodestrs):
inputs = list(filter(lambda x: x != "", node_str.split("|")))
for xinput in inputs:
assert len(xinput.split("~")) == 2, "invalid input length : {:}".format(
xinput
)
inputs = (xi.split("~") for xi in inputs)
input_infos = tuple((op, int(IDX)) for (op, IDX) in inputs)
genotypes.append(input_infos)
return Structure(genotypes)
@staticmethod
def str2fullstructure(xstr, default_name="none"):
assert isinstance(xstr, str), "must take string (not {:}) as input".format(
type(xstr)
)
nodestrs = xstr.split("+")
genotypes = []
for i, node_str in enumerate(nodestrs):
inputs = list(filter(lambda x: x != "", node_str.split("|")))
for xinput in inputs:
assert len(xinput.split("~")) == 2, "invalid input length : {:}".format(
xinput
)
inputs = (xi.split("~") for xi in inputs)
input_infos = list((op, int(IDX)) for (op, IDX) in inputs)
all_in_nodes = list(x[1] for x in input_infos)
for j in range(i):
if j not in all_in_nodes:
input_infos.append((default_name, j))
node_info = sorted(input_infos, key=lambda x: (x[1], x[0]))
genotypes.append(tuple(node_info))
return Structure(genotypes)
@staticmethod
def gen_all(search_space, num, return_ori):
assert isinstance(search_space, list) or isinstance(
search_space, tuple
), "invalid class of search-space : {:}".format(type(search_space))
assert (
num >= 2
), "There should be at least two nodes in a neural cell instead of {:}".format(
num
)
all_archs = get_combination(search_space, 1)
for i, arch in enumerate(all_archs):
all_archs[i] = [tuple(arch)]
for inode in range(2, num):
cur_nodes = get_combination(search_space, inode)
new_all_archs = []
for previous_arch in all_archs:
for cur_node in cur_nodes:
new_all_archs.append(previous_arch + [tuple(cur_node)])
all_archs = new_all_archs
if return_ori:
return all_archs
else:
return [Structure(x) for x in all_archs]
ResNet_CODE = Structure(
[
(("nor_conv_3x3", 0),), # node-1
(("nor_conv_3x3", 1),), # node-2
(("skip_connect", 0), ("skip_connect", 2)),
] # node-3
)
AllConv3x3_CODE = Structure(
[
(("nor_conv_3x3", 0),), # node-1
(("nor_conv_3x3", 0), ("nor_conv_3x3", 1)), # node-2
(("nor_conv_3x3", 0), ("nor_conv_3x3", 1), ("nor_conv_3x3", 2)),
] # node-3
)
AllFull_CODE = Structure(
[
(
("skip_connect", 0),
("nor_conv_1x1", 0),
("nor_conv_3x3", 0),
("avg_pool_3x3", 0),
), # node-1
(
("skip_connect", 0),
("nor_conv_1x1", 0),
("nor_conv_3x3", 0),
("avg_pool_3x3", 0),
("skip_connect", 1),
("nor_conv_1x1", 1),
("nor_conv_3x3", 1),
("avg_pool_3x3", 1),
), # node-2
(
("skip_connect", 0),
("nor_conv_1x1", 0),
("nor_conv_3x3", 0),
("avg_pool_3x3", 0),
("skip_connect", 1),
("nor_conv_1x1", 1),
("nor_conv_3x3", 1),
("avg_pool_3x3", 1),
("skip_connect", 2),
("nor_conv_1x1", 2),
("nor_conv_3x3", 2),
("avg_pool_3x3", 2),
),
] # node-3
)
AllConv1x1_CODE = Structure(
[
(("nor_conv_1x1", 0),), # node-1
(("nor_conv_1x1", 0), ("nor_conv_1x1", 1)), # node-2
(("nor_conv_1x1", 0), ("nor_conv_1x1", 1), ("nor_conv_1x1", 2)),
] # node-3
)
AllIdentity_CODE = Structure(
[
(("skip_connect", 0),), # node-1
(("skip_connect", 0), ("skip_connect", 1)), # node-2
(("skip_connect", 0), ("skip_connect", 1), ("skip_connect", 2)),
] # node-3
)
architectures = {
"resnet": ResNet_CODE,
"all_c3x3": AllConv3x3_CODE,
"all_c1x1": AllConv1x1_CODE,
"all_idnt": AllIdentity_CODE,
"all_full": AllFull_CODE,
}

View File

@@ -0,0 +1,267 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
import math, random, torch
import warnings
import torch.nn as nn
import torch.nn.functional as F
from copy import deepcopy
from ..cell_operations import OPS
# This module is used for NAS-Bench-201, represents a small search space with a complete DAG
class NAS201SearchCell(nn.Module):
def __init__(
self,
C_in,
C_out,
stride,
max_nodes,
op_names,
affine=False,
track_running_stats=True,
):
super(NAS201SearchCell, self).__init__()
self.op_names = deepcopy(op_names)
self.edges = nn.ModuleDict()
self.max_nodes = max_nodes
self.in_dim = C_in
self.out_dim = C_out
for i in range(1, max_nodes):
for j in range(i):
node_str = "{:}<-{:}".format(i, j)
if j == 0:
xlists = [
OPS[op_name](C_in, C_out, stride, affine, track_running_stats)
for op_name in op_names
]
else:
xlists = [
OPS[op_name](C_in, C_out, 1, affine, track_running_stats)
for op_name in op_names
]
self.edges[node_str] = nn.ModuleList(xlists)
self.edge_keys = sorted(list(self.edges.keys()))
self.edge2index = {key: i for i, key in enumerate(self.edge_keys)}
self.num_edges = len(self.edges)
def extra_repr(self):
string = "info :: {max_nodes} nodes, inC={in_dim}, outC={out_dim}".format(
**self.__dict__
)
return string
def forward(self, inputs, weightss):
nodes = [inputs]
for i in range(1, self.max_nodes):
inter_nodes = []
for j in range(i):
node_str = "{:}<-{:}".format(i, j)
weights = weightss[self.edge2index[node_str]]
inter_nodes.append(
sum(
layer(nodes[j]) * w
for layer, w in zip(self.edges[node_str], weights)
)
)
nodes.append(sum(inter_nodes))
return nodes[-1]
# GDAS
def forward_gdas(self, inputs, hardwts, index):
nodes = [inputs]
for i in range(1, self.max_nodes):
inter_nodes = []
for j in range(i):
node_str = "{:}<-{:}".format(i, j)
weights = hardwts[self.edge2index[node_str]]
argmaxs = index[self.edge2index[node_str]].item()
weigsum = sum(
weights[_ie] * edge(nodes[j]) if _ie == argmaxs else weights[_ie]
for _ie, edge in enumerate(self.edges[node_str])
)
inter_nodes.append(weigsum)
nodes.append(sum(inter_nodes))
return nodes[-1]
# GDAS Variant: https://github.com/D-X-Y/AutoDL-Projects/issues/119
def forward_gdas_v1(self, inputs, hardwts, index):
nodes = [inputs]
for i in range(1, self.max_nodes):
inter_nodes = []
for j in range(i):
node_str = "{:}<-{:}".format(i, j)
weights = hardwts[self.edge2index[node_str]]
argmaxs = index[self.edge2index[node_str]].item()
weigsum = weights[argmaxs] * self.edges[node_str](nodes[j])
inter_nodes.append(weigsum)
nodes.append(sum(inter_nodes))
return nodes[-1]
# joint
def forward_joint(self, inputs, weightss):
nodes = [inputs]
for i in range(1, self.max_nodes):
inter_nodes = []
for j in range(i):
node_str = "{:}<-{:}".format(i, j)
weights = weightss[self.edge2index[node_str]]
# aggregation = sum( layer(nodes[j]) * w for layer, w in zip(self.edges[node_str], weights) ) / weights.numel()
aggregation = sum(
layer(nodes[j]) * w
for layer, w in zip(self.edges[node_str], weights)
)
inter_nodes.append(aggregation)
nodes.append(sum(inter_nodes))
return nodes[-1]
# uniform random sampling per iteration, SETN
def forward_urs(self, inputs):
nodes = [inputs]
for i in range(1, self.max_nodes):
while True: # to avoid select zero for all ops
sops, has_non_zero = [], False
for j in range(i):
node_str = "{:}<-{:}".format(i, j)
candidates = self.edges[node_str]
select_op = random.choice(candidates)
sops.append(select_op)
if not hasattr(select_op, "is_zero") or select_op.is_zero is False:
has_non_zero = True
if has_non_zero:
break
inter_nodes = []
for j, select_op in enumerate(sops):
inter_nodes.append(select_op(nodes[j]))
nodes.append(sum(inter_nodes))
return nodes[-1]
# select the argmax
def forward_select(self, inputs, weightss):
nodes = [inputs]
for i in range(1, self.max_nodes):
inter_nodes = []
for j in range(i):
node_str = "{:}<-{:}".format(i, j)
weights = weightss[self.edge2index[node_str]]
inter_nodes.append(
self.edges[node_str][weights.argmax().item()](nodes[j])
)
# inter_nodes.append( sum( layer(nodes[j]) * w for layer, w in zip(self.edges[node_str], weights) ) )
nodes.append(sum(inter_nodes))
return nodes[-1]
# forward with a specific structure
def forward_dynamic(self, inputs, structure):
nodes = [inputs]
for i in range(1, self.max_nodes):
cur_op_node = structure.nodes[i - 1]
inter_nodes = []
for op_name, j in cur_op_node:
node_str = "{:}<-{:}".format(i, j)
op_index = self.op_names.index(op_name)
inter_nodes.append(self.edges[node_str][op_index](nodes[j]))
nodes.append(sum(inter_nodes))
return nodes[-1]
# Learning Transferable Architectures for Scalable Image Recognition, CVPR 2018
class MixedOp(nn.Module):
def __init__(self, space, C, stride, affine, track_running_stats):
super(MixedOp, self).__init__()
self._ops = nn.ModuleList()
for primitive in space:
op = OPS[primitive](C, C, stride, affine, track_running_stats)
self._ops.append(op)
def forward_gdas(self, x, weights, index):
return self._ops[index](x) * weights[index]
def forward_darts(self, x, weights):
return sum(w * op(x) for w, op in zip(weights, self._ops))
class NASNetSearchCell(nn.Module):
def __init__(
self,
space,
steps,
multiplier,
C_prev_prev,
C_prev,
C,
reduction,
reduction_prev,
affine,
track_running_stats,
):
super(NASNetSearchCell, self).__init__()
self.reduction = reduction
self.op_names = deepcopy(space)
if reduction_prev:
self.preprocess0 = OPS["skip_connect"](
C_prev_prev, C, 2, affine, track_running_stats
)
else:
self.preprocess0 = OPS["nor_conv_1x1"](
C_prev_prev, C, 1, affine, track_running_stats
)
self.preprocess1 = OPS["nor_conv_1x1"](
C_prev, C, 1, affine, track_running_stats
)
self._steps = steps
self._multiplier = multiplier
self._ops = nn.ModuleList()
self.edges = nn.ModuleDict()
for i in range(self._steps):
for j in range(2 + i):
node_str = "{:}<-{:}".format(
i, j
) # indicate the edge from node-(j) to node-(i+2)
stride = 2 if reduction and j < 2 else 1
op = MixedOp(space, C, stride, affine, track_running_stats)
self.edges[node_str] = op
self.edge_keys = sorted(list(self.edges.keys()))
self.edge2index = {key: i for i, key in enumerate(self.edge_keys)}
self.num_edges = len(self.edges)
@property
def multiplier(self):
return self._multiplier
def forward_gdas(self, s0, s1, weightss, indexs):
s0 = self.preprocess0(s0)
s1 = self.preprocess1(s1)
states = [s0, s1]
for i in range(self._steps):
clist = []
for j, h in enumerate(states):
node_str = "{:}<-{:}".format(i, j)
op = self.edges[node_str]
weights = weightss[self.edge2index[node_str]]
index = indexs[self.edge2index[node_str]].item()
clist.append(op.forward_gdas(h, weights, index))
states.append(sum(clist))
return torch.cat(states[-self._multiplier :], dim=1)
def forward_darts(self, s0, s1, weightss):
s0 = self.preprocess0(s0)
s1 = self.preprocess1(s1)
states = [s0, s1]
for i in range(self._steps):
clist = []
for j, h in enumerate(states):
node_str = "{:}<-{:}".format(i, j)
op = self.edges[node_str]
weights = weightss[self.edge2index[node_str]]
clist.append(op.forward_darts(h, weights))
states.append(sum(clist))
return torch.cat(states[-self._multiplier :], dim=1)

View File

@@ -0,0 +1,122 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
########################################################
# DARTS: Differentiable Architecture Search, ICLR 2019 #
########################################################
import torch
import torch.nn as nn
from copy import deepcopy
from ..cell_operations import ResNetBasicblock
from .search_cells import NAS201SearchCell as SearchCell
from .genotypes import Structure
class TinyNetworkDarts(nn.Module):
def __init__(
self, C, N, max_nodes, num_classes, search_space, affine, track_running_stats
):
super(TinyNetworkDarts, self).__init__()
self._C = C
self._layerN = N
self.max_nodes = max_nodes
self.stem = nn.Sequential(
nn.Conv2d(3, C, kernel_size=3, padding=1, bias=False), nn.BatchNorm2d(C)
)
layer_channels = [C] * N + [C * 2] + [C * 2] * N + [C * 4] + [C * 4] * N
layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N
C_prev, num_edge, edge2index = C, None, None
self.cells = nn.ModuleList()
for index, (C_curr, reduction) in enumerate(
zip(layer_channels, layer_reductions)
):
if reduction:
cell = ResNetBasicblock(C_prev, C_curr, 2)
else:
cell = SearchCell(
C_prev,
C_curr,
1,
max_nodes,
search_space,
affine,
track_running_stats,
)
if num_edge is None:
num_edge, edge2index = cell.num_edges, cell.edge2index
else:
assert (
num_edge == cell.num_edges and edge2index == cell.edge2index
), "invalid {:} vs. {:}.".format(num_edge, cell.num_edges)
self.cells.append(cell)
C_prev = cell.out_dim
self.op_names = deepcopy(search_space)
self._Layer = len(self.cells)
self.edge2index = edge2index
self.lastact = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True))
self.global_pooling = nn.AdaptiveAvgPool2d(1)
self.classifier = nn.Linear(C_prev, num_classes)
self.arch_parameters = nn.Parameter(
1e-3 * torch.randn(num_edge, len(search_space))
)
def get_weights(self):
xlist = list(self.stem.parameters()) + list(self.cells.parameters())
xlist += list(self.lastact.parameters()) + list(
self.global_pooling.parameters()
)
xlist += list(self.classifier.parameters())
return xlist
def get_alphas(self):
return [self.arch_parameters]
def show_alphas(self):
with torch.no_grad():
return "arch-parameters :\n{:}".format(
nn.functional.softmax(self.arch_parameters, dim=-1).cpu()
)
def get_message(self):
string = self.extra_repr()
for i, cell in enumerate(self.cells):
string += "\n {:02d}/{:02d} :: {:}".format(
i, len(self.cells), cell.extra_repr()
)
return string
def extra_repr(self):
return "{name}(C={_C}, Max-Nodes={max_nodes}, N={_layerN}, L={_Layer})".format(
name=self.__class__.__name__, **self.__dict__
)
def genotype(self):
genotypes = []
for i in range(1, self.max_nodes):
xlist = []
for j in range(i):
node_str = "{:}<-{:}".format(i, j)
with torch.no_grad():
weights = self.arch_parameters[self.edge2index[node_str]]
op_name = self.op_names[weights.argmax().item()]
xlist.append((op_name, j))
genotypes.append(tuple(xlist))
return Structure(genotypes)
def forward(self, inputs):
alphas = nn.functional.softmax(self.arch_parameters, dim=-1)
feature = self.stem(inputs)
for i, cell in enumerate(self.cells):
if isinstance(cell, SearchCell):
feature = cell(feature, alphas)
else:
feature = cell(feature)
out = self.lastact(feature)
out = self.global_pooling(out)
out = out.view(out.size(0), -1)
logits = self.classifier(out)
return out, logits

View File

@@ -0,0 +1,178 @@
####################
# DARTS, ICLR 2019 #
####################
import torch
import torch.nn as nn
from copy import deepcopy
from typing import List, Text, Dict
from .search_cells import NASNetSearchCell as SearchCell
# The macro structure is based on NASNet
class NASNetworkDARTS(nn.Module):
def __init__(
self,
C: int,
N: int,
steps: int,
multiplier: int,
stem_multiplier: int,
num_classes: int,
search_space: List[Text],
affine: bool,
track_running_stats: bool,
):
super(NASNetworkDARTS, self).__init__()
self._C = C
self._layerN = N
self._steps = steps
self._multiplier = multiplier
self.stem = nn.Sequential(
nn.Conv2d(3, C * stem_multiplier, kernel_size=3, padding=1, bias=False),
nn.BatchNorm2d(C * stem_multiplier),
)
# config for each layer
layer_channels = (
[C] * N + [C * 2] + [C * 2] * (N - 1) + [C * 4] + [C * 4] * (N - 1)
)
layer_reductions = (
[False] * N + [True] + [False] * (N - 1) + [True] + [False] * (N - 1)
)
num_edge, edge2index = None, None
C_prev_prev, C_prev, C_curr, reduction_prev = (
C * stem_multiplier,
C * stem_multiplier,
C,
False,
)
self.cells = nn.ModuleList()
for index, (C_curr, reduction) in enumerate(
zip(layer_channels, layer_reductions)
):
cell = SearchCell(
search_space,
steps,
multiplier,
C_prev_prev,
C_prev,
C_curr,
reduction,
reduction_prev,
affine,
track_running_stats,
)
if num_edge is None:
num_edge, edge2index = cell.num_edges, cell.edge2index
else:
assert (
num_edge == cell.num_edges and edge2index == cell.edge2index
), "invalid {:} vs. {:}.".format(num_edge, cell.num_edges)
self.cells.append(cell)
C_prev_prev, C_prev, reduction_prev = C_prev, multiplier * C_curr, reduction
self.op_names = deepcopy(search_space)
self._Layer = len(self.cells)
self.edge2index = edge2index
self.lastact = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True))
self.global_pooling = nn.AdaptiveAvgPool2d(1)
self.classifier = nn.Linear(C_prev, num_classes)
self.arch_normal_parameters = nn.Parameter(
1e-3 * torch.randn(num_edge, len(search_space))
)
self.arch_reduce_parameters = nn.Parameter(
1e-3 * torch.randn(num_edge, len(search_space))
)
def get_weights(self) -> List[torch.nn.Parameter]:
xlist = list(self.stem.parameters()) + list(self.cells.parameters())
xlist += list(self.lastact.parameters()) + list(
self.global_pooling.parameters()
)
xlist += list(self.classifier.parameters())
return xlist
def get_alphas(self) -> List[torch.nn.Parameter]:
return [self.arch_normal_parameters, self.arch_reduce_parameters]
def show_alphas(self) -> Text:
with torch.no_grad():
A = "arch-normal-parameters :\n{:}".format(
nn.functional.softmax(self.arch_normal_parameters, dim=-1).cpu()
)
B = "arch-reduce-parameters :\n{:}".format(
nn.functional.softmax(self.arch_reduce_parameters, dim=-1).cpu()
)
return "{:}\n{:}".format(A, B)
def get_message(self) -> Text:
string = self.extra_repr()
for i, cell in enumerate(self.cells):
string += "\n {:02d}/{:02d} :: {:}".format(
i, len(self.cells), cell.extra_repr()
)
return string
def extra_repr(self) -> Text:
return "{name}(C={_C}, N={_layerN}, steps={_steps}, multiplier={_multiplier}, L={_Layer})".format(
name=self.__class__.__name__, **self.__dict__
)
def genotype(self) -> Dict[Text, List]:
def _parse(weights):
gene = []
for i in range(self._steps):
edges = []
for j in range(2 + i):
node_str = "{:}<-{:}".format(i, j)
ws = weights[self.edge2index[node_str]]
for k, op_name in enumerate(self.op_names):
if op_name == "none":
continue
edges.append((op_name, j, ws[k]))
# (TODO) xuanyidong:
# Here the selected two edges might come from the same input node.
# And this case could be a problem that two edges will collapse into a single one
# due to our assumption -- at most one edge from an input node during evaluation.
edges = sorted(edges, key=lambda x: -x[-1])
selected_edges = edges[:2]
gene.append(tuple(selected_edges))
return gene
with torch.no_grad():
gene_normal = _parse(
torch.softmax(self.arch_normal_parameters, dim=-1).cpu().numpy()
)
gene_reduce = _parse(
torch.softmax(self.arch_reduce_parameters, dim=-1).cpu().numpy()
)
return {
"normal": gene_normal,
"normal_concat": list(
range(2 + self._steps - self._multiplier, self._steps + 2)
),
"reduce": gene_reduce,
"reduce_concat": list(
range(2 + self._steps - self._multiplier, self._steps + 2)
),
}
def forward(self, inputs):
normal_w = nn.functional.softmax(self.arch_normal_parameters, dim=1)
reduce_w = nn.functional.softmax(self.arch_reduce_parameters, dim=1)
s0 = s1 = self.stem(inputs)
for i, cell in enumerate(self.cells):
if cell.reduction:
ww = reduce_w
else:
ww = normal_w
s0, s1 = s1, cell.forward_darts(s0, s1, ww)
out = self.lastact(s1)
out = self.global_pooling(out)
out = out.view(out.size(0), -1)
logits = self.classifier(out)
return out, logits

View File

@@ -0,0 +1,114 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##########################################################################
# Efficient Neural Architecture Search via Parameters Sharing, ICML 2018 #
##########################################################################
import torch
import torch.nn as nn
from copy import deepcopy
from ..cell_operations import ResNetBasicblock
from .search_cells import NAS201SearchCell as SearchCell
from .genotypes import Structure
from .search_model_enas_utils import Controller
class TinyNetworkENAS(nn.Module):
def __init__(
self, C, N, max_nodes, num_classes, search_space, affine, track_running_stats
):
super(TinyNetworkENAS, self).__init__()
self._C = C
self._layerN = N
self.max_nodes = max_nodes
self.stem = nn.Sequential(
nn.Conv2d(3, C, kernel_size=3, padding=1, bias=False), nn.BatchNorm2d(C)
)
layer_channels = [C] * N + [C * 2] + [C * 2] * N + [C * 4] + [C * 4] * N
layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N
C_prev, num_edge, edge2index = C, None, None
self.cells = nn.ModuleList()
for index, (C_curr, reduction) in enumerate(
zip(layer_channels, layer_reductions)
):
if reduction:
cell = ResNetBasicblock(C_prev, C_curr, 2)
else:
cell = SearchCell(
C_prev,
C_curr,
1,
max_nodes,
search_space,
affine,
track_running_stats,
)
if num_edge is None:
num_edge, edge2index = cell.num_edges, cell.edge2index
else:
assert (
num_edge == cell.num_edges and edge2index == cell.edge2index
), "invalid {:} vs. {:}.".format(num_edge, cell.num_edges)
self.cells.append(cell)
C_prev = cell.out_dim
self.op_names = deepcopy(search_space)
self._Layer = len(self.cells)
self.edge2index = edge2index
self.lastact = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True))
self.global_pooling = nn.AdaptiveAvgPool2d(1)
self.classifier = nn.Linear(C_prev, num_classes)
# to maintain the sampled architecture
self.sampled_arch = None
def update_arch(self, _arch):
if _arch is None:
self.sampled_arch = None
elif isinstance(_arch, Structure):
self.sampled_arch = _arch
elif isinstance(_arch, (list, tuple)):
genotypes = []
for i in range(1, self.max_nodes):
xlist = []
for j in range(i):
node_str = "{:}<-{:}".format(i, j)
op_index = _arch[self.edge2index[node_str]]
op_name = self.op_names[op_index]
xlist.append((op_name, j))
genotypes.append(tuple(xlist))
self.sampled_arch = Structure(genotypes)
else:
raise ValueError("invalid type of input architecture : {:}".format(_arch))
return self.sampled_arch
def create_controller(self):
return Controller(len(self.edge2index), len(self.op_names))
def get_message(self):
string = self.extra_repr()
for i, cell in enumerate(self.cells):
string += "\n {:02d}/{:02d} :: {:}".format(
i, len(self.cells), cell.extra_repr()
)
return string
def extra_repr(self):
return "{name}(C={_C}, Max-Nodes={max_nodes}, N={_layerN}, L={_Layer})".format(
name=self.__class__.__name__, **self.__dict__
)
def forward(self, inputs):
feature = self.stem(inputs)
for i, cell in enumerate(self.cells):
if isinstance(cell, SearchCell):
feature = cell.forward_dynamic(feature, self.sampled_arch)
else:
feature = cell(feature)
out = self.lastact(feature)
out = self.global_pooling(out)
out = out.view(out.size(0), -1)
logits = self.classifier(out)
return out, logits

View File

@@ -0,0 +1,74 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##########################################################################
# Efficient Neural Architecture Search via Parameters Sharing, ICML 2018 #
##########################################################################
import torch
import torch.nn as nn
from torch.distributions.categorical import Categorical
class Controller(nn.Module):
# we refer to https://github.com/TDeVries/enas_pytorch/blob/master/models/controller.py
def __init__(
self,
num_edge,
num_ops,
lstm_size=32,
lstm_num_layers=2,
tanh_constant=2.5,
temperature=5.0,
):
super(Controller, self).__init__()
# assign the attributes
self.num_edge = num_edge
self.num_ops = num_ops
self.lstm_size = lstm_size
self.lstm_N = lstm_num_layers
self.tanh_constant = tanh_constant
self.temperature = temperature
# create parameters
self.register_parameter(
"input_vars", nn.Parameter(torch.Tensor(1, 1, lstm_size))
)
self.w_lstm = nn.LSTM(
input_size=self.lstm_size,
hidden_size=self.lstm_size,
num_layers=self.lstm_N,
)
self.w_embd = nn.Embedding(self.num_ops, self.lstm_size)
self.w_pred = nn.Linear(self.lstm_size, self.num_ops)
nn.init.uniform_(self.input_vars, -0.1, 0.1)
nn.init.uniform_(self.w_lstm.weight_hh_l0, -0.1, 0.1)
nn.init.uniform_(self.w_lstm.weight_ih_l0, -0.1, 0.1)
nn.init.uniform_(self.w_embd.weight, -0.1, 0.1)
nn.init.uniform_(self.w_pred.weight, -0.1, 0.1)
def forward(self):
inputs, h0 = self.input_vars, None
log_probs, entropys, sampled_arch = [], [], []
for iedge in range(self.num_edge):
outputs, h0 = self.w_lstm(inputs, h0)
logits = self.w_pred(outputs)
logits = logits / self.temperature
logits = self.tanh_constant * torch.tanh(logits)
# distribution
op_distribution = Categorical(logits=logits)
op_index = op_distribution.sample()
sampled_arch.append(op_index.item())
op_log_prob = op_distribution.log_prob(op_index)
log_probs.append(op_log_prob.view(-1))
op_entropy = op_distribution.entropy()
entropys.append(op_entropy.view(-1))
# obtain the input embedding for the next step
inputs = self.w_embd(op_index)
return (
torch.sum(torch.cat(log_probs)),
torch.sum(torch.cat(entropys)),
sampled_arch,
)

View File

@@ -0,0 +1,142 @@
###########################################################################
# Searching for A Robust Neural Architecture in Four GPU Hours, CVPR 2019 #
###########################################################################
import torch
import torch.nn as nn
from copy import deepcopy
from ..cell_operations import ResNetBasicblock
from .search_cells import NAS201SearchCell as SearchCell
from .genotypes import Structure
class TinyNetworkGDAS(nn.Module):
# def __init__(self, C, N, max_nodes, num_classes, search_space, affine=False, track_running_stats=True):
def __init__(
self, C, N, max_nodes, num_classes, search_space, affine, track_running_stats
):
super(TinyNetworkGDAS, self).__init__()
self._C = C
self._layerN = N
self.max_nodes = max_nodes
self.stem = nn.Sequential(
nn.Conv2d(3, C, kernel_size=3, padding=1, bias=False), nn.BatchNorm2d(C)
)
layer_channels = [C] * N + [C * 2] + [C * 2] * N + [C * 4] + [C * 4] * N
layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N
C_prev, num_edge, edge2index = C, None, None
self.cells = nn.ModuleList()
for index, (C_curr, reduction) in enumerate(
zip(layer_channels, layer_reductions)
):
if reduction:
cell = ResNetBasicblock(C_prev, C_curr, 2)
else:
cell = SearchCell(
C_prev,
C_curr,
1,
max_nodes,
search_space,
affine,
track_running_stats,
)
if num_edge is None:
num_edge, edge2index = cell.num_edges, cell.edge2index
else:
assert (
num_edge == cell.num_edges and edge2index == cell.edge2index
), "invalid {:} vs. {:}.".format(num_edge, cell.num_edges)
self.cells.append(cell)
C_prev = cell.out_dim
self.op_names = deepcopy(search_space)
self._Layer = len(self.cells)
self.edge2index = edge2index
self.lastact = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True))
self.global_pooling = nn.AdaptiveAvgPool2d(1)
self.classifier = nn.Linear(C_prev, num_classes)
self.arch_parameters = nn.Parameter(
1e-3 * torch.randn(num_edge, len(search_space))
)
self.tau = 10
def get_weights(self):
xlist = list(self.stem.parameters()) + list(self.cells.parameters())
xlist += list(self.lastact.parameters()) + list(
self.global_pooling.parameters()
)
xlist += list(self.classifier.parameters())
return xlist
def set_tau(self, tau):
self.tau = tau
def get_tau(self):
return self.tau
def get_alphas(self):
return [self.arch_parameters]
def show_alphas(self):
with torch.no_grad():
return "arch-parameters :\n{:}".format(
nn.functional.softmax(self.arch_parameters, dim=-1).cpu()
)
def get_message(self):
string = self.extra_repr()
for i, cell in enumerate(self.cells):
string += "\n {:02d}/{:02d} :: {:}".format(
i, len(self.cells), cell.extra_repr()
)
return string
def extra_repr(self):
return "{name}(C={_C}, Max-Nodes={max_nodes}, N={_layerN}, L={_Layer})".format(
name=self.__class__.__name__, **self.__dict__
)
def genotype(self):
genotypes = []
for i in range(1, self.max_nodes):
xlist = []
for j in range(i):
node_str = "{:}<-{:}".format(i, j)
with torch.no_grad():
weights = self.arch_parameters[self.edge2index[node_str]]
op_name = self.op_names[weights.argmax().item()]
xlist.append((op_name, j))
genotypes.append(tuple(xlist))
return Structure(genotypes)
def forward(self, inputs):
while True:
gumbels = -torch.empty_like(self.arch_parameters).exponential_().log()
logits = (self.arch_parameters.log_softmax(dim=1) + gumbels) / self.tau
probs = nn.functional.softmax(logits, dim=1)
index = probs.max(-1, keepdim=True)[1]
one_h = torch.zeros_like(logits).scatter_(-1, index, 1.0)
hardwts = one_h - probs.detach() + probs
if (
(torch.isinf(gumbels).any())
or (torch.isinf(probs).any())
or (torch.isnan(probs).any())
):
continue
else:
break
feature = self.stem(inputs)
for i, cell in enumerate(self.cells):
if isinstance(cell, SearchCell):
feature = cell.forward_gdas(feature, hardwts, index)
else:
feature = cell(feature)
out = self.lastact(feature)
out = self.global_pooling(out)
out = out.view(out.size(0), -1)
logits = self.classifier(out)
return out, logits

View File

@@ -0,0 +1,200 @@
###########################################################################
# Searching for A Robust Neural Architecture in Four GPU Hours, CVPR 2019 #
###########################################################################
import torch
import torch.nn as nn
from copy import deepcopy
from .search_cells import NASNetSearchCell as SearchCell
from ..cell_operations import RAW_OP_CLASSES
# The macro structure is based on NASNet
class NASNetworkGDAS_FRC(nn.Module):
def __init__(
self,
C,
N,
steps,
multiplier,
stem_multiplier,
num_classes,
search_space,
affine,
track_running_stats,
):
super(NASNetworkGDAS_FRC, self).__init__()
self._C = C
self._layerN = N
self._steps = steps
self._multiplier = multiplier
self.stem = nn.Sequential(
nn.Conv2d(3, C * stem_multiplier, kernel_size=3, padding=1, bias=False),
nn.BatchNorm2d(C * stem_multiplier),
)
# config for each layer
layer_channels = (
[C] * N + [C * 2] + [C * 2] * (N - 1) + [C * 4] + [C * 4] * (N - 1)
)
layer_reductions = (
[False] * N + [True] + [False] * (N - 1) + [True] + [False] * (N - 1)
)
num_edge, edge2index = None, None
C_prev_prev, C_prev, C_curr, reduction_prev = (
C * stem_multiplier,
C * stem_multiplier,
C,
False,
)
self.cells = nn.ModuleList()
for index, (C_curr, reduction) in enumerate(
zip(layer_channels, layer_reductions)
):
if reduction:
cell = RAW_OP_CLASSES["gdas_reduction"](
C_prev_prev,
C_prev,
C_curr,
reduction_prev,
affine,
track_running_stats,
)
else:
cell = SearchCell(
search_space,
steps,
multiplier,
C_prev_prev,
C_prev,
C_curr,
reduction,
reduction_prev,
affine,
track_running_stats,
)
if num_edge is None:
num_edge, edge2index = cell.num_edges, cell.edge2index
else:
assert (
reduction
or num_edge == cell.num_edges
and edge2index == cell.edge2index
), "invalid {:} vs. {:}.".format(num_edge, cell.num_edges)
self.cells.append(cell)
C_prev_prev, C_prev, reduction_prev = (
C_prev,
cell.multiplier * C_curr,
reduction,
)
self.op_names = deepcopy(search_space)
self._Layer = len(self.cells)
self.edge2index = edge2index
self.lastact = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True))
self.global_pooling = nn.AdaptiveAvgPool2d(1)
self.classifier = nn.Linear(C_prev, num_classes)
self.arch_parameters = nn.Parameter(
1e-3 * torch.randn(num_edge, len(search_space))
)
self.tau = 10
def get_weights(self):
xlist = list(self.stem.parameters()) + list(self.cells.parameters())
xlist += list(self.lastact.parameters()) + list(
self.global_pooling.parameters()
)
xlist += list(self.classifier.parameters())
return xlist
def set_tau(self, tau):
self.tau = tau
def get_tau(self):
return self.tau
def get_alphas(self):
return [self.arch_parameters]
def show_alphas(self):
with torch.no_grad():
A = "arch-normal-parameters :\n{:}".format(
nn.functional.softmax(self.arch_parameters, dim=-1).cpu()
)
return "{:}".format(A)
def get_message(self):
string = self.extra_repr()
for i, cell in enumerate(self.cells):
string += "\n {:02d}/{:02d} :: {:}".format(
i, len(self.cells), cell.extra_repr()
)
return string
def extra_repr(self):
return "{name}(C={_C}, N={_layerN}, steps={_steps}, multiplier={_multiplier}, L={_Layer})".format(
name=self.__class__.__name__, **self.__dict__
)
def genotype(self):
def _parse(weights):
gene = []
for i in range(self._steps):
edges = []
for j in range(2 + i):
node_str = "{:}<-{:}".format(i, j)
ws = weights[self.edge2index[node_str]]
for k, op_name in enumerate(self.op_names):
if op_name == "none":
continue
edges.append((op_name, j, ws[k]))
edges = sorted(edges, key=lambda x: -x[-1])
selected_edges = edges[:2]
gene.append(tuple(selected_edges))
return gene
with torch.no_grad():
gene_normal = _parse(
torch.softmax(self.arch_parameters, dim=-1).cpu().numpy()
)
return {
"normal": gene_normal,
"normal_concat": list(
range(2 + self._steps - self._multiplier, self._steps + 2)
),
}
def forward(self, inputs):
def get_gumbel_prob(xins):
while True:
gumbels = -torch.empty_like(xins).exponential_().log()
logits = (xins.log_softmax(dim=1) + gumbels) / self.tau
probs = nn.functional.softmax(logits, dim=1)
index = probs.max(-1, keepdim=True)[1]
one_h = torch.zeros_like(logits).scatter_(-1, index, 1.0)
hardwts = one_h - probs.detach() + probs
if (
(torch.isinf(gumbels).any())
or (torch.isinf(probs).any())
or (torch.isnan(probs).any())
):
continue
else:
break
return hardwts, index
hardwts, index = get_gumbel_prob(self.arch_parameters)
s0 = s1 = self.stem(inputs)
for i, cell in enumerate(self.cells):
if cell.reduction:
s0, s1 = s1, cell(s0, s1)
else:
s0, s1 = s1, cell.forward_gdas(s0, s1, hardwts, index)
out = self.lastact(s1)
out = self.global_pooling(out)
out = out.view(out.size(0), -1)
logits = self.classifier(out)
return out, logits

View File

@@ -0,0 +1,197 @@
###########################################################################
# Searching for A Robust Neural Architecture in Four GPU Hours, CVPR 2019 #
###########################################################################
import torch
import torch.nn as nn
from copy import deepcopy
from .search_cells import NASNetSearchCell as SearchCell
# The macro structure is based on NASNet
class NASNetworkGDAS(nn.Module):
def __init__(
self,
C,
N,
steps,
multiplier,
stem_multiplier,
num_classes,
search_space,
affine,
track_running_stats,
):
super(NASNetworkGDAS, self).__init__()
self._C = C
self._layerN = N
self._steps = steps
self._multiplier = multiplier
self.stem = nn.Sequential(
nn.Conv2d(3, C * stem_multiplier, kernel_size=3, padding=1, bias=False),
nn.BatchNorm2d(C * stem_multiplier),
)
# config for each layer
layer_channels = (
[C] * N + [C * 2] + [C * 2] * (N - 1) + [C * 4] + [C * 4] * (N - 1)
)
layer_reductions = (
[False] * N + [True] + [False] * (N - 1) + [True] + [False] * (N - 1)
)
num_edge, edge2index = None, None
C_prev_prev, C_prev, C_curr, reduction_prev = (
C * stem_multiplier,
C * stem_multiplier,
C,
False,
)
self.cells = nn.ModuleList()
for index, (C_curr, reduction) in enumerate(
zip(layer_channels, layer_reductions)
):
cell = SearchCell(
search_space,
steps,
multiplier,
C_prev_prev,
C_prev,
C_curr,
reduction,
reduction_prev,
affine,
track_running_stats,
)
if num_edge is None:
num_edge, edge2index = cell.num_edges, cell.edge2index
else:
assert (
num_edge == cell.num_edges and edge2index == cell.edge2index
), "invalid {:} vs. {:}.".format(num_edge, cell.num_edges)
self.cells.append(cell)
C_prev_prev, C_prev, reduction_prev = C_prev, multiplier * C_curr, reduction
self.op_names = deepcopy(search_space)
self._Layer = len(self.cells)
self.edge2index = edge2index
self.lastact = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True))
self.global_pooling = nn.AdaptiveAvgPool2d(1)
self.classifier = nn.Linear(C_prev, num_classes)
self.arch_normal_parameters = nn.Parameter(
1e-3 * torch.randn(num_edge, len(search_space))
)
self.arch_reduce_parameters = nn.Parameter(
1e-3 * torch.randn(num_edge, len(search_space))
)
self.tau = 10
def get_weights(self):
xlist = list(self.stem.parameters()) + list(self.cells.parameters())
xlist += list(self.lastact.parameters()) + list(
self.global_pooling.parameters()
)
xlist += list(self.classifier.parameters())
return xlist
def set_tau(self, tau):
self.tau = tau
def get_tau(self):
return self.tau
def get_alphas(self):
return [self.arch_normal_parameters, self.arch_reduce_parameters]
def show_alphas(self):
with torch.no_grad():
A = "arch-normal-parameters :\n{:}".format(
nn.functional.softmax(self.arch_normal_parameters, dim=-1).cpu()
)
B = "arch-reduce-parameters :\n{:}".format(
nn.functional.softmax(self.arch_reduce_parameters, dim=-1).cpu()
)
return "{:}\n{:}".format(A, B)
def get_message(self):
string = self.extra_repr()
for i, cell in enumerate(self.cells):
string += "\n {:02d}/{:02d} :: {:}".format(
i, len(self.cells), cell.extra_repr()
)
return string
def extra_repr(self):
return "{name}(C={_C}, N={_layerN}, steps={_steps}, multiplier={_multiplier}, L={_Layer})".format(
name=self.__class__.__name__, **self.__dict__
)
def genotype(self):
def _parse(weights):
gene = []
for i in range(self._steps):
edges = []
for j in range(2 + i):
node_str = "{:}<-{:}".format(i, j)
ws = weights[self.edge2index[node_str]]
for k, op_name in enumerate(self.op_names):
if op_name == "none":
continue
edges.append((op_name, j, ws[k]))
edges = sorted(edges, key=lambda x: -x[-1])
selected_edges = edges[:2]
gene.append(tuple(selected_edges))
return gene
with torch.no_grad():
gene_normal = _parse(
torch.softmax(self.arch_normal_parameters, dim=-1).cpu().numpy()
)
gene_reduce = _parse(
torch.softmax(self.arch_reduce_parameters, dim=-1).cpu().numpy()
)
return {
"normal": gene_normal,
"normal_concat": list(
range(2 + self._steps - self._multiplier, self._steps + 2)
),
"reduce": gene_reduce,
"reduce_concat": list(
range(2 + self._steps - self._multiplier, self._steps + 2)
),
}
def forward(self, inputs):
def get_gumbel_prob(xins):
while True:
gumbels = -torch.empty_like(xins).exponential_().log()
logits = (xins.log_softmax(dim=1) + gumbels) / self.tau
probs = nn.functional.softmax(logits, dim=1)
index = probs.max(-1, keepdim=True)[1]
one_h = torch.zeros_like(logits).scatter_(-1, index, 1.0)
hardwts = one_h - probs.detach() + probs
if (
(torch.isinf(gumbels).any())
or (torch.isinf(probs).any())
or (torch.isnan(probs).any())
):
continue
else:
break
return hardwts, index
normal_hardwts, normal_index = get_gumbel_prob(self.arch_normal_parameters)
reduce_hardwts, reduce_index = get_gumbel_prob(self.arch_reduce_parameters)
s0 = s1 = self.stem(inputs)
for i, cell in enumerate(self.cells):
if cell.reduction:
hardwts, index = reduce_hardwts, reduce_index
else:
hardwts, index = normal_hardwts, normal_index
s0, s1 = s1, cell.forward_gdas(s0, s1, hardwts, index)
out = self.lastact(s1)
out = self.global_pooling(out)
out = out.view(out.size(0), -1)
logits = self.classifier(out)
return out, logits

View File

@@ -0,0 +1,102 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##############################################################################
# Random Search and Reproducibility for Neural Architecture Search, UAI 2019 #
##############################################################################
import torch, random
import torch.nn as nn
from copy import deepcopy
from ..cell_operations import ResNetBasicblock
from .search_cells import NAS201SearchCell as SearchCell
from .genotypes import Structure
class TinyNetworkRANDOM(nn.Module):
def __init__(
self, C, N, max_nodes, num_classes, search_space, affine, track_running_stats
):
super(TinyNetworkRANDOM, self).__init__()
self._C = C
self._layerN = N
self.max_nodes = max_nodes
self.stem = nn.Sequential(
nn.Conv2d(3, C, kernel_size=3, padding=1, bias=False), nn.BatchNorm2d(C)
)
layer_channels = [C] * N + [C * 2] + [C * 2] * N + [C * 4] + [C * 4] * N
layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N
C_prev, num_edge, edge2index = C, None, None
self.cells = nn.ModuleList()
for index, (C_curr, reduction) in enumerate(
zip(layer_channels, layer_reductions)
):
if reduction:
cell = ResNetBasicblock(C_prev, C_curr, 2)
else:
cell = SearchCell(
C_prev,
C_curr,
1,
max_nodes,
search_space,
affine,
track_running_stats,
)
if num_edge is None:
num_edge, edge2index = cell.num_edges, cell.edge2index
else:
assert (
num_edge == cell.num_edges and edge2index == cell.edge2index
), "invalid {:} vs. {:}.".format(num_edge, cell.num_edges)
self.cells.append(cell)
C_prev = cell.out_dim
self.op_names = deepcopy(search_space)
self._Layer = len(self.cells)
self.edge2index = edge2index
self.lastact = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True))
self.global_pooling = nn.AdaptiveAvgPool2d(1)
self.classifier = nn.Linear(C_prev, num_classes)
self.arch_cache = None
def get_message(self):
string = self.extra_repr()
for i, cell in enumerate(self.cells):
string += "\n {:02d}/{:02d} :: {:}".format(
i, len(self.cells), cell.extra_repr()
)
return string
def extra_repr(self):
return "{name}(C={_C}, Max-Nodes={max_nodes}, N={_layerN}, L={_Layer})".format(
name=self.__class__.__name__, **self.__dict__
)
def random_genotype(self, set_cache):
genotypes = []
for i in range(1, self.max_nodes):
xlist = []
for j in range(i):
node_str = "{:}<-{:}".format(i, j)
op_name = random.choice(self.op_names)
xlist.append((op_name, j))
genotypes.append(tuple(xlist))
arch = Structure(genotypes)
if set_cache:
self.arch_cache = arch
return arch
def forward(self, inputs):
feature = self.stem(inputs)
for i, cell in enumerate(self.cells):
if isinstance(cell, SearchCell):
feature = cell.forward_dynamic(feature, self.arch_cache)
else:
feature = cell(feature)
out = self.lastact(feature)
out = self.global_pooling(out)
out = out.view(out.size(0), -1)
logits = self.classifier(out)
return out, logits

View File

@@ -0,0 +1,178 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
######################################################################################
# One-Shot Neural Architecture Search via Self-Evaluated Template Network, ICCV 2019 #
######################################################################################
import torch, random
import torch.nn as nn
from copy import deepcopy
from ..cell_operations import ResNetBasicblock
from .search_cells import NAS201SearchCell as SearchCell
from .genotypes import Structure
class TinyNetworkSETN(nn.Module):
def __init__(
self, C, N, max_nodes, num_classes, search_space, affine, track_running_stats
):
super(TinyNetworkSETN, self).__init__()
self._C = C
self._layerN = N
self.max_nodes = max_nodes
self.stem = nn.Sequential(
nn.Conv2d(3, C, kernel_size=3, padding=1, bias=False), nn.BatchNorm2d(C)
)
layer_channels = [C] * N + [C * 2] + [C * 2] * N + [C * 4] + [C * 4] * N
layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N
C_prev, num_edge, edge2index = C, None, None
self.cells = nn.ModuleList()
for index, (C_curr, reduction) in enumerate(
zip(layer_channels, layer_reductions)
):
if reduction:
cell = ResNetBasicblock(C_prev, C_curr, 2)
else:
cell = SearchCell(
C_prev,
C_curr,
1,
max_nodes,
search_space,
affine,
track_running_stats,
)
if num_edge is None:
num_edge, edge2index = cell.num_edges, cell.edge2index
else:
assert (
num_edge == cell.num_edges and edge2index == cell.edge2index
), "invalid {:} vs. {:}.".format(num_edge, cell.num_edges)
self.cells.append(cell)
C_prev = cell.out_dim
self.op_names = deepcopy(search_space)
self._Layer = len(self.cells)
self.edge2index = edge2index
self.lastact = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True))
self.global_pooling = nn.AdaptiveAvgPool2d(1)
self.classifier = nn.Linear(C_prev, num_classes)
self.arch_parameters = nn.Parameter(
1e-3 * torch.randn(num_edge, len(search_space))
)
self.mode = "urs"
self.dynamic_cell = None
def set_cal_mode(self, mode, dynamic_cell=None):
assert mode in ["urs", "joint", "select", "dynamic"]
self.mode = mode
if mode == "dynamic":
self.dynamic_cell = deepcopy(dynamic_cell)
else:
self.dynamic_cell = None
def get_cal_mode(self):
return self.mode
def get_weights(self):
xlist = list(self.stem.parameters()) + list(self.cells.parameters())
xlist += list(self.lastact.parameters()) + list(
self.global_pooling.parameters()
)
xlist += list(self.classifier.parameters())
return xlist
def get_alphas(self):
return [self.arch_parameters]
def get_message(self):
string = self.extra_repr()
for i, cell in enumerate(self.cells):
string += "\n {:02d}/{:02d} :: {:}".format(
i, len(self.cells), cell.extra_repr()
)
return string
def extra_repr(self):
return "{name}(C={_C}, Max-Nodes={max_nodes}, N={_layerN}, L={_Layer})".format(
name=self.__class__.__name__, **self.__dict__
)
def genotype(self):
genotypes = []
for i in range(1, self.max_nodes):
xlist = []
for j in range(i):
node_str = "{:}<-{:}".format(i, j)
with torch.no_grad():
weights = self.arch_parameters[self.edge2index[node_str]]
op_name = self.op_names[weights.argmax().item()]
xlist.append((op_name, j))
genotypes.append(tuple(xlist))
return Structure(genotypes)
def dync_genotype(self, use_random=False):
genotypes = []
with torch.no_grad():
alphas_cpu = nn.functional.softmax(self.arch_parameters, dim=-1)
for i in range(1, self.max_nodes):
xlist = []
for j in range(i):
node_str = "{:}<-{:}".format(i, j)
if use_random:
op_name = random.choice(self.op_names)
else:
weights = alphas_cpu[self.edge2index[node_str]]
op_index = torch.multinomial(weights, 1).item()
op_name = self.op_names[op_index]
xlist.append((op_name, j))
genotypes.append(tuple(xlist))
return Structure(genotypes)
def get_log_prob(self, arch):
with torch.no_grad():
logits = nn.functional.log_softmax(self.arch_parameters, dim=-1)
select_logits = []
for i, node_info in enumerate(arch.nodes):
for op, xin in node_info:
node_str = "{:}<-{:}".format(i + 1, xin)
op_index = self.op_names.index(op)
select_logits.append(logits[self.edge2index[node_str], op_index])
return sum(select_logits).item()
def return_topK(self, K):
archs = Structure.gen_all(self.op_names, self.max_nodes, False)
pairs = [(self.get_log_prob(arch), arch) for arch in archs]
if K < 0 or K >= len(archs):
K = len(archs)
sorted_pairs = sorted(pairs, key=lambda x: -x[0])
return_pairs = [sorted_pairs[_][1] for _ in range(K)]
return return_pairs
def forward(self, inputs):
alphas = nn.functional.softmax(self.arch_parameters, dim=-1)
with torch.no_grad():
alphas_cpu = alphas.detach().cpu()
feature = self.stem(inputs)
for i, cell in enumerate(self.cells):
if isinstance(cell, SearchCell):
if self.mode == "urs":
feature = cell.forward_urs(feature)
elif self.mode == "select":
feature = cell.forward_select(feature, alphas_cpu)
elif self.mode == "joint":
feature = cell.forward_joint(feature, alphas)
elif self.mode == "dynamic":
feature = cell.forward_dynamic(feature, self.dynamic_cell)
else:
raise ValueError("invalid mode={:}".format(self.mode))
else:
feature = cell(feature)
out = self.lastact(feature)
out = self.global_pooling(out)
out = out.view(out.size(0), -1)
logits = self.classifier(out)
return out, logits

View File

@@ -0,0 +1,205 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
######################################################################################
# One-Shot Neural Architecture Search via Self-Evaluated Template Network, ICCV 2019 #
######################################################################################
import torch
import torch.nn as nn
from copy import deepcopy
from typing import List, Text, Dict
from .search_cells import NASNetSearchCell as SearchCell
# The macro structure is based on NASNet
class NASNetworkSETN(nn.Module):
def __init__(
self,
C: int,
N: int,
steps: int,
multiplier: int,
stem_multiplier: int,
num_classes: int,
search_space: List[Text],
affine: bool,
track_running_stats: bool,
):
super(NASNetworkSETN, self).__init__()
self._C = C
self._layerN = N
self._steps = steps
self._multiplier = multiplier
self.stem = nn.Sequential(
nn.Conv2d(3, C * stem_multiplier, kernel_size=3, padding=1, bias=False),
nn.BatchNorm2d(C * stem_multiplier),
)
# config for each layer
layer_channels = (
[C] * N + [C * 2] + [C * 2] * (N - 1) + [C * 4] + [C * 4] * (N - 1)
)
layer_reductions = (
[False] * N + [True] + [False] * (N - 1) + [True] + [False] * (N - 1)
)
num_edge, edge2index = None, None
C_prev_prev, C_prev, C_curr, reduction_prev = (
C * stem_multiplier,
C * stem_multiplier,
C,
False,
)
self.cells = nn.ModuleList()
for index, (C_curr, reduction) in enumerate(
zip(layer_channels, layer_reductions)
):
cell = SearchCell(
search_space,
steps,
multiplier,
C_prev_prev,
C_prev,
C_curr,
reduction,
reduction_prev,
affine,
track_running_stats,
)
if num_edge is None:
num_edge, edge2index = cell.num_edges, cell.edge2index
else:
assert (
num_edge == cell.num_edges and edge2index == cell.edge2index
), "invalid {:} vs. {:}.".format(num_edge, cell.num_edges)
self.cells.append(cell)
C_prev_prev, C_prev, reduction_prev = C_prev, multiplier * C_curr, reduction
self.op_names = deepcopy(search_space)
self._Layer = len(self.cells)
self.edge2index = edge2index
self.lastact = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True))
self.global_pooling = nn.AdaptiveAvgPool2d(1)
self.classifier = nn.Linear(C_prev, num_classes)
self.arch_normal_parameters = nn.Parameter(
1e-3 * torch.randn(num_edge, len(search_space))
)
self.arch_reduce_parameters = nn.Parameter(
1e-3 * torch.randn(num_edge, len(search_space))
)
self.mode = "urs"
self.dynamic_cell = None
def set_cal_mode(self, mode, dynamic_cell=None):
assert mode in ["urs", "joint", "select", "dynamic"]
self.mode = mode
if mode == "dynamic":
self.dynamic_cell = deepcopy(dynamic_cell)
else:
self.dynamic_cell = None
def get_weights(self):
xlist = list(self.stem.parameters()) + list(self.cells.parameters())
xlist += list(self.lastact.parameters()) + list(
self.global_pooling.parameters()
)
xlist += list(self.classifier.parameters())
return xlist
def get_alphas(self):
return [self.arch_normal_parameters, self.arch_reduce_parameters]
def show_alphas(self):
with torch.no_grad():
A = "arch-normal-parameters :\n{:}".format(
nn.functional.softmax(self.arch_normal_parameters, dim=-1).cpu()
)
B = "arch-reduce-parameters :\n{:}".format(
nn.functional.softmax(self.arch_reduce_parameters, dim=-1).cpu()
)
return "{:}\n{:}".format(A, B)
def get_message(self):
string = self.extra_repr()
for i, cell in enumerate(self.cells):
string += "\n {:02d}/{:02d} :: {:}".format(
i, len(self.cells), cell.extra_repr()
)
return string
def extra_repr(self):
return "{name}(C={_C}, N={_layerN}, steps={_steps}, multiplier={_multiplier}, L={_Layer})".format(
name=self.__class__.__name__, **self.__dict__
)
def dync_genotype(self, use_random=False):
genotypes = []
with torch.no_grad():
alphas_cpu = nn.functional.softmax(self.arch_parameters, dim=-1)
for i in range(1, self.max_nodes):
xlist = []
for j in range(i):
node_str = "{:}<-{:}".format(i, j)
if use_random:
op_name = random.choice(self.op_names)
else:
weights = alphas_cpu[self.edge2index[node_str]]
op_index = torch.multinomial(weights, 1).item()
op_name = self.op_names[op_index]
xlist.append((op_name, j))
genotypes.append(tuple(xlist))
return Structure(genotypes)
def genotype(self):
def _parse(weights):
gene = []
for i in range(self._steps):
edges = []
for j in range(2 + i):
node_str = "{:}<-{:}".format(i, j)
ws = weights[self.edge2index[node_str]]
for k, op_name in enumerate(self.op_names):
if op_name == "none":
continue
edges.append((op_name, j, ws[k]))
edges = sorted(edges, key=lambda x: -x[-1])
selected_edges = edges[:2]
gene.append(tuple(selected_edges))
return gene
with torch.no_grad():
gene_normal = _parse(
torch.softmax(self.arch_normal_parameters, dim=-1).cpu().numpy()
)
gene_reduce = _parse(
torch.softmax(self.arch_reduce_parameters, dim=-1).cpu().numpy()
)
return {
"normal": gene_normal,
"normal_concat": list(
range(2 + self._steps - self._multiplier, self._steps + 2)
),
"reduce": gene_reduce,
"reduce_concat": list(
range(2 + self._steps - self._multiplier, self._steps + 2)
),
}
def forward(self, inputs):
normal_hardwts = nn.functional.softmax(self.arch_normal_parameters, dim=-1)
reduce_hardwts = nn.functional.softmax(self.arch_reduce_parameters, dim=-1)
s0 = s1 = self.stem(inputs)
for i, cell in enumerate(self.cells):
# [TODO]
raise NotImplementedError
if cell.reduction:
hardwts, index = reduce_hardwts, reduce_index
else:
hardwts, index = normal_hardwts, normal_index
s0, s1 = s1, cell.forward_gdas(s0, s1, hardwts, index)
out = self.lastact(s1)
out = self.global_pooling(out)
out = out.view(out.size(0), -1)
logits = self.classifier(out)
return out, logits

View File

@@ -0,0 +1,74 @@
import torch
import torch.nn as nn
def copy_conv(module, init):
assert isinstance(module, nn.Conv2d), "invalid module : {:}".format(module)
assert isinstance(init, nn.Conv2d), "invalid module : {:}".format(init)
new_i, new_o = module.in_channels, module.out_channels
module.weight.copy_(init.weight.detach()[:new_o, :new_i])
if module.bias is not None:
module.bias.copy_(init.bias.detach()[:new_o])
def copy_bn(module, init):
assert isinstance(module, nn.BatchNorm2d), "invalid module : {:}".format(module)
assert isinstance(init, nn.BatchNorm2d), "invalid module : {:}".format(init)
num_features = module.num_features
if module.weight is not None:
module.weight.copy_(init.weight.detach()[:num_features])
if module.bias is not None:
module.bias.copy_(init.bias.detach()[:num_features])
if module.running_mean is not None:
module.running_mean.copy_(init.running_mean.detach()[:num_features])
if module.running_var is not None:
module.running_var.copy_(init.running_var.detach()[:num_features])
def copy_fc(module, init):
assert isinstance(module, nn.Linear), "invalid module : {:}".format(module)
assert isinstance(init, nn.Linear), "invalid module : {:}".format(init)
new_i, new_o = module.in_features, module.out_features
module.weight.copy_(init.weight.detach()[:new_o, :new_i])
if module.bias is not None:
module.bias.copy_(init.bias.detach()[:new_o])
def copy_base(module, init):
assert type(module).__name__ in [
"ConvBNReLU",
"Downsample",
], "invalid module : {:}".format(module)
assert type(init).__name__ in [
"ConvBNReLU",
"Downsample",
], "invalid module : {:}".format(init)
if module.conv is not None:
copy_conv(module.conv, init.conv)
if module.bn is not None:
copy_bn(module.bn, init.bn)
def copy_basic(module, init):
copy_base(module.conv_a, init.conv_a)
copy_base(module.conv_b, init.conv_b)
if module.downsample is not None:
if init.downsample is not None:
copy_base(module.downsample, init.downsample)
# else:
# import pdb; pdb.set_trace()
def init_from_model(network, init_model):
with torch.no_grad():
copy_fc(network.classifier, init_model.classifier)
for base, target in zip(init_model.layers, network.layers):
assert (
type(base).__name__ == type(target).__name__
), "invalid type : {:} vs {:}".format(base, target)
if type(base).__name__ == "ConvBNReLU":
copy_base(target, base)
elif type(base).__name__ == "ResNetBasicblock":
copy_basic(target, base)
else:
raise ValueError("unknown type name : {:}".format(type(base).__name__))

View File

@@ -0,0 +1,16 @@
import torch
import torch.nn as nn
def initialize_resnet(m):
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.BatchNorm2d):
nn.init.constant_(m.weight, 1)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.Linear):
nn.init.normal_(m.weight, 0, 0.01)
nn.init.constant_(m.bias, 0)

View File

@@ -0,0 +1,287 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
#####################################################
import torch.nn as nn
import torch.nn.functional as F
from ..initialization import initialize_resnet
class ConvBNReLU(nn.Module):
def __init__(
self, nIn, nOut, kernel, stride, padding, bias, has_avg, has_bn, has_relu
):
super(ConvBNReLU, self).__init__()
if has_avg:
self.avg = nn.AvgPool2d(kernel_size=2, stride=2, padding=0)
else:
self.avg = None
self.conv = nn.Conv2d(
nIn,
nOut,
kernel_size=kernel,
stride=stride,
padding=padding,
dilation=1,
groups=1,
bias=bias,
)
if has_bn:
self.bn = nn.BatchNorm2d(nOut)
else:
self.bn = None
if has_relu:
self.relu = nn.ReLU(inplace=True)
else:
self.relu = None
def forward(self, inputs):
if self.avg:
out = self.avg(inputs)
else:
out = inputs
conv = self.conv(out)
if self.bn:
out = self.bn(conv)
else:
out = conv
if self.relu:
out = self.relu(out)
else:
out = out
return out
class ResNetBasicblock(nn.Module):
num_conv = 2
expansion = 1
def __init__(self, iCs, stride):
super(ResNetBasicblock, self).__init__()
assert stride == 1 or stride == 2, "invalid stride {:}".format(stride)
assert isinstance(iCs, tuple) or isinstance(
iCs, list
), "invalid type of iCs : {:}".format(iCs)
assert len(iCs) == 3, "invalid lengths of iCs : {:}".format(iCs)
self.conv_a = ConvBNReLU(
iCs[0],
iCs[1],
3,
stride,
1,
False,
has_avg=False,
has_bn=True,
has_relu=True,
)
self.conv_b = ConvBNReLU(
iCs[1], iCs[2], 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=False
)
residual_in = iCs[0]
if stride == 2:
self.downsample = ConvBNReLU(
iCs[0],
iCs[2],
1,
1,
0,
False,
has_avg=True,
has_bn=False,
has_relu=False,
)
residual_in = iCs[2]
elif iCs[0] != iCs[2]:
self.downsample = ConvBNReLU(
iCs[0],
iCs[2],
1,
1,
0,
False,
has_avg=False,
has_bn=True,
has_relu=False,
)
else:
self.downsample = None
# self.out_dim = max(residual_in, iCs[2])
self.out_dim = iCs[2]
def forward(self, inputs):
basicblock = self.conv_a(inputs)
basicblock = self.conv_b(basicblock)
if self.downsample is not None:
residual = self.downsample(inputs)
else:
residual = inputs
out = residual + basicblock
return F.relu(out, inplace=True)
class ResNetBottleneck(nn.Module):
expansion = 4
num_conv = 3
def __init__(self, iCs, stride):
super(ResNetBottleneck, self).__init__()
assert stride == 1 or stride == 2, "invalid stride {:}".format(stride)
assert isinstance(iCs, tuple) or isinstance(
iCs, list
), "invalid type of iCs : {:}".format(iCs)
assert len(iCs) == 4, "invalid lengths of iCs : {:}".format(iCs)
self.conv_1x1 = ConvBNReLU(
iCs[0], iCs[1], 1, 1, 0, False, has_avg=False, has_bn=True, has_relu=True
)
self.conv_3x3 = ConvBNReLU(
iCs[1],
iCs[2],
3,
stride,
1,
False,
has_avg=False,
has_bn=True,
has_relu=True,
)
self.conv_1x4 = ConvBNReLU(
iCs[2], iCs[3], 1, 1, 0, False, has_avg=False, has_bn=True, has_relu=False
)
residual_in = iCs[0]
if stride == 2:
self.downsample = ConvBNReLU(
iCs[0],
iCs[3],
1,
1,
0,
False,
has_avg=True,
has_bn=False,
has_relu=False,
)
residual_in = iCs[3]
elif iCs[0] != iCs[3]:
self.downsample = ConvBNReLU(
iCs[0],
iCs[3],
1,
1,
0,
False,
has_avg=False,
has_bn=False,
has_relu=False,
)
residual_in = iCs[3]
else:
self.downsample = None
# self.out_dim = max(residual_in, iCs[3])
self.out_dim = iCs[3]
def forward(self, inputs):
bottleneck = self.conv_1x1(inputs)
bottleneck = self.conv_3x3(bottleneck)
bottleneck = self.conv_1x4(bottleneck)
if self.downsample is not None:
residual = self.downsample(inputs)
else:
residual = inputs
out = residual + bottleneck
return F.relu(out, inplace=True)
class InferCifarResNet(nn.Module):
def __init__(
self, block_name, depth, xblocks, xchannels, num_classes, zero_init_residual
):
super(InferCifarResNet, self).__init__()
# Model type specifies number of layers for CIFAR-10 and CIFAR-100 model
if block_name == "ResNetBasicblock":
block = ResNetBasicblock
assert (depth - 2) % 6 == 0, "depth should be one of 20, 32, 44, 56, 110"
layer_blocks = (depth - 2) // 6
elif block_name == "ResNetBottleneck":
block = ResNetBottleneck
assert (depth - 2) % 9 == 0, "depth should be one of 164"
layer_blocks = (depth - 2) // 9
else:
raise ValueError("invalid block : {:}".format(block_name))
assert len(xblocks) == 3, "invalid xblocks : {:}".format(xblocks)
self.message = (
"InferWidthCifarResNet : Depth : {:} , Layers for each block : {:}".format(
depth, layer_blocks
)
)
self.num_classes = num_classes
self.xchannels = xchannels
self.layers = nn.ModuleList(
[
ConvBNReLU(
xchannels[0],
xchannels[1],
3,
1,
1,
False,
has_avg=False,
has_bn=True,
has_relu=True,
)
]
)
last_channel_idx = 1
for stage in range(3):
for iL in range(layer_blocks):
num_conv = block.num_conv
iCs = self.xchannels[last_channel_idx : last_channel_idx + num_conv + 1]
stride = 2 if stage > 0 and iL == 0 else 1
module = block(iCs, stride)
last_channel_idx += num_conv
self.xchannels[last_channel_idx] = module.out_dim
self.layers.append(module)
self.message += "\nstage={:}, ilayer={:02d}/{:02d}, block={:03d}, iCs={:}, oC={:3d}, stride={:}".format(
stage,
iL,
layer_blocks,
len(self.layers) - 1,
iCs,
module.out_dim,
stride,
)
if iL + 1 == xblocks[stage]: # reach the maximum depth
out_channel = module.out_dim
for iiL in range(iL + 1, layer_blocks):
last_channel_idx += num_conv
self.xchannels[last_channel_idx] = module.out_dim
break
self.avgpool = nn.AvgPool2d(8)
self.classifier = nn.Linear(self.xchannels[-1], num_classes)
self.apply(initialize_resnet)
if zero_init_residual:
for m in self.modules():
if isinstance(m, ResNetBasicblock):
nn.init.constant_(m.conv_b.bn.weight, 0)
elif isinstance(m, ResNetBottleneck):
nn.init.constant_(m.conv_1x4.bn.weight, 0)
def get_message(self):
return self.message
def forward(self, inputs):
x = inputs
for i, layer in enumerate(self.layers):
x = layer(x)
features = self.avgpool(x)
features = features.view(features.size(0), -1)
logits = self.classifier(features)
return features, logits

View File

@@ -0,0 +1,263 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
#####################################################
import torch.nn as nn
import torch.nn.functional as F
from ..initialization import initialize_resnet
class ConvBNReLU(nn.Module):
def __init__(
self, nIn, nOut, kernel, stride, padding, bias, has_avg, has_bn, has_relu
):
super(ConvBNReLU, self).__init__()
if has_avg:
self.avg = nn.AvgPool2d(kernel_size=2, stride=2, padding=0)
else:
self.avg = None
self.conv = nn.Conv2d(
nIn,
nOut,
kernel_size=kernel,
stride=stride,
padding=padding,
dilation=1,
groups=1,
bias=bias,
)
if has_bn:
self.bn = nn.BatchNorm2d(nOut)
else:
self.bn = None
if has_relu:
self.relu = nn.ReLU(inplace=True)
else:
self.relu = None
def forward(self, inputs):
if self.avg:
out = self.avg(inputs)
else:
out = inputs
conv = self.conv(out)
if self.bn:
out = self.bn(conv)
else:
out = conv
if self.relu:
out = self.relu(out)
else:
out = out
return out
class ResNetBasicblock(nn.Module):
num_conv = 2
expansion = 1
def __init__(self, inplanes, planes, stride):
super(ResNetBasicblock, self).__init__()
assert stride == 1 or stride == 2, "invalid stride {:}".format(stride)
self.conv_a = ConvBNReLU(
inplanes,
planes,
3,
stride,
1,
False,
has_avg=False,
has_bn=True,
has_relu=True,
)
self.conv_b = ConvBNReLU(
planes, planes, 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=False
)
if stride == 2:
self.downsample = ConvBNReLU(
inplanes,
planes,
1,
1,
0,
False,
has_avg=True,
has_bn=False,
has_relu=False,
)
elif inplanes != planes:
self.downsample = ConvBNReLU(
inplanes,
planes,
1,
1,
0,
False,
has_avg=False,
has_bn=True,
has_relu=False,
)
else:
self.downsample = None
self.out_dim = planes
def forward(self, inputs):
basicblock = self.conv_a(inputs)
basicblock = self.conv_b(basicblock)
if self.downsample is not None:
residual = self.downsample(inputs)
else:
residual = inputs
out = residual + basicblock
return F.relu(out, inplace=True)
class ResNetBottleneck(nn.Module):
expansion = 4
num_conv = 3
def __init__(self, inplanes, planes, stride):
super(ResNetBottleneck, self).__init__()
assert stride == 1 or stride == 2, "invalid stride {:}".format(stride)
self.conv_1x1 = ConvBNReLU(
inplanes, planes, 1, 1, 0, False, has_avg=False, has_bn=True, has_relu=True
)
self.conv_3x3 = ConvBNReLU(
planes,
planes,
3,
stride,
1,
False,
has_avg=False,
has_bn=True,
has_relu=True,
)
self.conv_1x4 = ConvBNReLU(
planes,
planes * self.expansion,
1,
1,
0,
False,
has_avg=False,
has_bn=True,
has_relu=False,
)
if stride == 2:
self.downsample = ConvBNReLU(
inplanes,
planes * self.expansion,
1,
1,
0,
False,
has_avg=True,
has_bn=False,
has_relu=False,
)
elif inplanes != planes * self.expansion:
self.downsample = ConvBNReLU(
inplanes,
planes * self.expansion,
1,
1,
0,
False,
has_avg=False,
has_bn=False,
has_relu=False,
)
else:
self.downsample = None
self.out_dim = planes * self.expansion
def forward(self, inputs):
bottleneck = self.conv_1x1(inputs)
bottleneck = self.conv_3x3(bottleneck)
bottleneck = self.conv_1x4(bottleneck)
if self.downsample is not None:
residual = self.downsample(inputs)
else:
residual = inputs
out = residual + bottleneck
return F.relu(out, inplace=True)
class InferDepthCifarResNet(nn.Module):
def __init__(self, block_name, depth, xblocks, num_classes, zero_init_residual):
super(InferDepthCifarResNet, self).__init__()
# Model type specifies number of layers for CIFAR-10 and CIFAR-100 model
if block_name == "ResNetBasicblock":
block = ResNetBasicblock
assert (depth - 2) % 6 == 0, "depth should be one of 20, 32, 44, 56, 110"
layer_blocks = (depth - 2) // 6
elif block_name == "ResNetBottleneck":
block = ResNetBottleneck
assert (depth - 2) % 9 == 0, "depth should be one of 164"
layer_blocks = (depth - 2) // 9
else:
raise ValueError("invalid block : {:}".format(block_name))
assert len(xblocks) == 3, "invalid xblocks : {:}".format(xblocks)
self.message = (
"InferWidthCifarResNet : Depth : {:} , Layers for each block : {:}".format(
depth, layer_blocks
)
)
self.num_classes = num_classes
self.layers = nn.ModuleList(
[
ConvBNReLU(
3, 16, 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=True
)
]
)
self.channels = [16]
for stage in range(3):
for iL in range(layer_blocks):
iC = self.channels[-1]
planes = 16 * (2 ** stage)
stride = 2 if stage > 0 and iL == 0 else 1
module = block(iC, planes, stride)
self.channels.append(module.out_dim)
self.layers.append(module)
self.message += "\nstage={:}, ilayer={:02d}/{:02d}, block={:03d}, iC={:}, oC={:3d}, stride={:}".format(
stage,
iL,
layer_blocks,
len(self.layers) - 1,
planes,
module.out_dim,
stride,
)
if iL + 1 == xblocks[stage]: # reach the maximum depth
break
self.avgpool = nn.AvgPool2d(8)
self.classifier = nn.Linear(self.channels[-1], num_classes)
self.apply(initialize_resnet)
if zero_init_residual:
for m in self.modules():
if isinstance(m, ResNetBasicblock):
nn.init.constant_(m.conv_b.bn.weight, 0)
elif isinstance(m, ResNetBottleneck):
nn.init.constant_(m.conv_1x4.bn.weight, 0)
def get_message(self):
return self.message
def forward(self, inputs):
x = inputs
for i, layer in enumerate(self.layers):
x = layer(x)
features = self.avgpool(x)
features = features.view(features.size(0), -1)
logits = self.classifier(features)
return features, logits

View File

@@ -0,0 +1,277 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
#####################################################
import torch.nn as nn
import torch.nn.functional as F
from ..initialization import initialize_resnet
class ConvBNReLU(nn.Module):
def __init__(
self, nIn, nOut, kernel, stride, padding, bias, has_avg, has_bn, has_relu
):
super(ConvBNReLU, self).__init__()
if has_avg:
self.avg = nn.AvgPool2d(kernel_size=2, stride=2, padding=0)
else:
self.avg = None
self.conv = nn.Conv2d(
nIn,
nOut,
kernel_size=kernel,
stride=stride,
padding=padding,
dilation=1,
groups=1,
bias=bias,
)
if has_bn:
self.bn = nn.BatchNorm2d(nOut)
else:
self.bn = None
if has_relu:
self.relu = nn.ReLU(inplace=True)
else:
self.relu = None
def forward(self, inputs):
if self.avg:
out = self.avg(inputs)
else:
out = inputs
conv = self.conv(out)
if self.bn:
out = self.bn(conv)
else:
out = conv
if self.relu:
out = self.relu(out)
else:
out = out
return out
class ResNetBasicblock(nn.Module):
num_conv = 2
expansion = 1
def __init__(self, iCs, stride):
super(ResNetBasicblock, self).__init__()
assert stride == 1 or stride == 2, "invalid stride {:}".format(stride)
assert isinstance(iCs, tuple) or isinstance(
iCs, list
), "invalid type of iCs : {:}".format(iCs)
assert len(iCs) == 3, "invalid lengths of iCs : {:}".format(iCs)
self.conv_a = ConvBNReLU(
iCs[0],
iCs[1],
3,
stride,
1,
False,
has_avg=False,
has_bn=True,
has_relu=True,
)
self.conv_b = ConvBNReLU(
iCs[1], iCs[2], 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=False
)
residual_in = iCs[0]
if stride == 2:
self.downsample = ConvBNReLU(
iCs[0],
iCs[2],
1,
1,
0,
False,
has_avg=True,
has_bn=False,
has_relu=False,
)
residual_in = iCs[2]
elif iCs[0] != iCs[2]:
self.downsample = ConvBNReLU(
iCs[0],
iCs[2],
1,
1,
0,
False,
has_avg=False,
has_bn=True,
has_relu=False,
)
else:
self.downsample = None
# self.out_dim = max(residual_in, iCs[2])
self.out_dim = iCs[2]
def forward(self, inputs):
basicblock = self.conv_a(inputs)
basicblock = self.conv_b(basicblock)
if self.downsample is not None:
residual = self.downsample(inputs)
else:
residual = inputs
out = residual + basicblock
return F.relu(out, inplace=True)
class ResNetBottleneck(nn.Module):
expansion = 4
num_conv = 3
def __init__(self, iCs, stride):
super(ResNetBottleneck, self).__init__()
assert stride == 1 or stride == 2, "invalid stride {:}".format(stride)
assert isinstance(iCs, tuple) or isinstance(
iCs, list
), "invalid type of iCs : {:}".format(iCs)
assert len(iCs) == 4, "invalid lengths of iCs : {:}".format(iCs)
self.conv_1x1 = ConvBNReLU(
iCs[0], iCs[1], 1, 1, 0, False, has_avg=False, has_bn=True, has_relu=True
)
self.conv_3x3 = ConvBNReLU(
iCs[1],
iCs[2],
3,
stride,
1,
False,
has_avg=False,
has_bn=True,
has_relu=True,
)
self.conv_1x4 = ConvBNReLU(
iCs[2], iCs[3], 1, 1, 0, False, has_avg=False, has_bn=True, has_relu=False
)
residual_in = iCs[0]
if stride == 2:
self.downsample = ConvBNReLU(
iCs[0],
iCs[3],
1,
1,
0,
False,
has_avg=True,
has_bn=False,
has_relu=False,
)
residual_in = iCs[3]
elif iCs[0] != iCs[3]:
self.downsample = ConvBNReLU(
iCs[0],
iCs[3],
1,
1,
0,
False,
has_avg=False,
has_bn=False,
has_relu=False,
)
residual_in = iCs[3]
else:
self.downsample = None
# self.out_dim = max(residual_in, iCs[3])
self.out_dim = iCs[3]
def forward(self, inputs):
bottleneck = self.conv_1x1(inputs)
bottleneck = self.conv_3x3(bottleneck)
bottleneck = self.conv_1x4(bottleneck)
if self.downsample is not None:
residual = self.downsample(inputs)
else:
residual = inputs
out = residual + bottleneck
return F.relu(out, inplace=True)
class InferWidthCifarResNet(nn.Module):
def __init__(self, block_name, depth, xchannels, num_classes, zero_init_residual):
super(InferWidthCifarResNet, self).__init__()
# Model type specifies number of layers for CIFAR-10 and CIFAR-100 model
if block_name == "ResNetBasicblock":
block = ResNetBasicblock
assert (depth - 2) % 6 == 0, "depth should be one of 20, 32, 44, 56, 110"
layer_blocks = (depth - 2) // 6
elif block_name == "ResNetBottleneck":
block = ResNetBottleneck
assert (depth - 2) % 9 == 0, "depth should be one of 164"
layer_blocks = (depth - 2) // 9
else:
raise ValueError("invalid block : {:}".format(block_name))
self.message = (
"InferWidthCifarResNet : Depth : {:} , Layers for each block : {:}".format(
depth, layer_blocks
)
)
self.num_classes = num_classes
self.xchannels = xchannels
self.layers = nn.ModuleList(
[
ConvBNReLU(
xchannels[0],
xchannels[1],
3,
1,
1,
False,
has_avg=False,
has_bn=True,
has_relu=True,
)
]
)
last_channel_idx = 1
for stage in range(3):
for iL in range(layer_blocks):
num_conv = block.num_conv
iCs = self.xchannels[last_channel_idx : last_channel_idx + num_conv + 1]
stride = 2 if stage > 0 and iL == 0 else 1
module = block(iCs, stride)
last_channel_idx += num_conv
self.xchannels[last_channel_idx] = module.out_dim
self.layers.append(module)
self.message += "\nstage={:}, ilayer={:02d}/{:02d}, block={:03d}, iCs={:}, oC={:3d}, stride={:}".format(
stage,
iL,
layer_blocks,
len(self.layers) - 1,
iCs,
module.out_dim,
stride,
)
self.avgpool = nn.AvgPool2d(8)
self.classifier = nn.Linear(self.xchannels[-1], num_classes)
self.apply(initialize_resnet)
if zero_init_residual:
for m in self.modules():
if isinstance(m, ResNetBasicblock):
nn.init.constant_(m.conv_b.bn.weight, 0)
elif isinstance(m, ResNetBottleneck):
nn.init.constant_(m.conv_1x4.bn.weight, 0)
def get_message(self):
return self.message
def forward(self, inputs):
x = inputs
for i, layer in enumerate(self.layers):
x = layer(x)
features = self.avgpool(x)
features = features.view(features.size(0), -1)
logits = self.classifier(features)
return features, logits

View File

@@ -0,0 +1,324 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
#####################################################
import torch.nn as nn
import torch.nn.functional as F
from ..initialization import initialize_resnet
class ConvBNReLU(nn.Module):
num_conv = 1
def __init__(
self, nIn, nOut, kernel, stride, padding, bias, has_avg, has_bn, has_relu
):
super(ConvBNReLU, self).__init__()
if has_avg:
self.avg = nn.AvgPool2d(kernel_size=2, stride=2, padding=0)
else:
self.avg = None
self.conv = nn.Conv2d(
nIn,
nOut,
kernel_size=kernel,
stride=stride,
padding=padding,
dilation=1,
groups=1,
bias=bias,
)
if has_bn:
self.bn = nn.BatchNorm2d(nOut)
else:
self.bn = None
if has_relu:
self.relu = nn.ReLU(inplace=True)
else:
self.relu = None
def forward(self, inputs):
if self.avg:
out = self.avg(inputs)
else:
out = inputs
conv = self.conv(out)
if self.bn:
out = self.bn(conv)
else:
out = conv
if self.relu:
out = self.relu(out)
else:
out = out
return out
class ResNetBasicblock(nn.Module):
num_conv = 2
expansion = 1
def __init__(self, iCs, stride):
super(ResNetBasicblock, self).__init__()
assert stride == 1 or stride == 2, "invalid stride {:}".format(stride)
assert isinstance(iCs, tuple) or isinstance(
iCs, list
), "invalid type of iCs : {:}".format(iCs)
assert len(iCs) == 3, "invalid lengths of iCs : {:}".format(iCs)
self.conv_a = ConvBNReLU(
iCs[0],
iCs[1],
3,
stride,
1,
False,
has_avg=False,
has_bn=True,
has_relu=True,
)
self.conv_b = ConvBNReLU(
iCs[1], iCs[2], 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=False
)
residual_in = iCs[0]
if stride == 2:
self.downsample = ConvBNReLU(
iCs[0],
iCs[2],
1,
1,
0,
False,
has_avg=True,
has_bn=True,
has_relu=False,
)
residual_in = iCs[2]
elif iCs[0] != iCs[2]:
self.downsample = ConvBNReLU(
iCs[0],
iCs[2],
1,
1,
0,
False,
has_avg=False,
has_bn=True,
has_relu=False,
)
else:
self.downsample = None
# self.out_dim = max(residual_in, iCs[2])
self.out_dim = iCs[2]
def forward(self, inputs):
basicblock = self.conv_a(inputs)
basicblock = self.conv_b(basicblock)
if self.downsample is not None:
residual = self.downsample(inputs)
else:
residual = inputs
out = residual + basicblock
return F.relu(out, inplace=True)
class ResNetBottleneck(nn.Module):
expansion = 4
num_conv = 3
def __init__(self, iCs, stride):
super(ResNetBottleneck, self).__init__()
assert stride == 1 or stride == 2, "invalid stride {:}".format(stride)
assert isinstance(iCs, tuple) or isinstance(
iCs, list
), "invalid type of iCs : {:}".format(iCs)
assert len(iCs) == 4, "invalid lengths of iCs : {:}".format(iCs)
self.conv_1x1 = ConvBNReLU(
iCs[0], iCs[1], 1, 1, 0, False, has_avg=False, has_bn=True, has_relu=True
)
self.conv_3x3 = ConvBNReLU(
iCs[1],
iCs[2],
3,
stride,
1,
False,
has_avg=False,
has_bn=True,
has_relu=True,
)
self.conv_1x4 = ConvBNReLU(
iCs[2], iCs[3], 1, 1, 0, False, has_avg=False, has_bn=True, has_relu=False
)
residual_in = iCs[0]
if stride == 2:
self.downsample = ConvBNReLU(
iCs[0],
iCs[3],
1,
1,
0,
False,
has_avg=True,
has_bn=True,
has_relu=False,
)
residual_in = iCs[3]
elif iCs[0] != iCs[3]:
self.downsample = ConvBNReLU(
iCs[0],
iCs[3],
1,
1,
0,
False,
has_avg=False,
has_bn=True,
has_relu=False,
)
residual_in = iCs[3]
else:
self.downsample = None
# self.out_dim = max(residual_in, iCs[3])
self.out_dim = iCs[3]
def forward(self, inputs):
bottleneck = self.conv_1x1(inputs)
bottleneck = self.conv_3x3(bottleneck)
bottleneck = self.conv_1x4(bottleneck)
if self.downsample is not None:
residual = self.downsample(inputs)
else:
residual = inputs
out = residual + bottleneck
return F.relu(out, inplace=True)
class InferImagenetResNet(nn.Module):
def __init__(
self,
block_name,
layers,
xblocks,
xchannels,
deep_stem,
num_classes,
zero_init_residual,
):
super(InferImagenetResNet, self).__init__()
# Model type specifies number of layers for CIFAR-10 and CIFAR-100 model
if block_name == "BasicBlock":
block = ResNetBasicblock
elif block_name == "Bottleneck":
block = ResNetBottleneck
else:
raise ValueError("invalid block : {:}".format(block_name))
assert len(xblocks) == len(
layers
), "invalid layers : {:} vs xblocks : {:}".format(layers, xblocks)
self.message = "InferImagenetResNet : Depth : {:} -> {:}, Layers for each block : {:}".format(
sum(layers) * block.num_conv, sum(xblocks) * block.num_conv, xblocks
)
self.num_classes = num_classes
self.xchannels = xchannels
if not deep_stem:
self.layers = nn.ModuleList(
[
ConvBNReLU(
xchannels[0],
xchannels[1],
7,
2,
3,
False,
has_avg=False,
has_bn=True,
has_relu=True,
)
]
)
last_channel_idx = 1
else:
self.layers = nn.ModuleList(
[
ConvBNReLU(
xchannels[0],
xchannels[1],
3,
2,
1,
False,
has_avg=False,
has_bn=True,
has_relu=True,
),
ConvBNReLU(
xchannels[1],
xchannels[2],
3,
1,
1,
False,
has_avg=False,
has_bn=True,
has_relu=True,
),
]
)
last_channel_idx = 2
self.layers.append(nn.MaxPool2d(kernel_size=3, stride=2, padding=1))
for stage, layer_blocks in enumerate(layers):
for iL in range(layer_blocks):
num_conv = block.num_conv
iCs = self.xchannels[last_channel_idx : last_channel_idx + num_conv + 1]
stride = 2 if stage > 0 and iL == 0 else 1
module = block(iCs, stride)
last_channel_idx += num_conv
self.xchannels[last_channel_idx] = module.out_dim
self.layers.append(module)
self.message += "\nstage={:}, ilayer={:02d}/{:02d}, block={:03d}, iCs={:}, oC={:3d}, stride={:}".format(
stage,
iL,
layer_blocks,
len(self.layers) - 1,
iCs,
module.out_dim,
stride,
)
if iL + 1 == xblocks[stage]: # reach the maximum depth
out_channel = module.out_dim
for iiL in range(iL + 1, layer_blocks):
last_channel_idx += num_conv
self.xchannels[last_channel_idx] = module.out_dim
break
assert last_channel_idx + 1 == len(self.xchannels), "{:} vs {:}".format(
last_channel_idx, len(self.xchannels)
)
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
self.classifier = nn.Linear(self.xchannels[-1], num_classes)
self.apply(initialize_resnet)
if zero_init_residual:
for m in self.modules():
if isinstance(m, ResNetBasicblock):
nn.init.constant_(m.conv_b.bn.weight, 0)
elif isinstance(m, ResNetBottleneck):
nn.init.constant_(m.conv_1x4.bn.weight, 0)
def get_message(self):
return self.message
def forward(self, inputs):
x = inputs
for i, layer in enumerate(self.layers):
x = layer(x)
features = self.avgpool(x)
features = features.view(features.size(0), -1)
logits = self.classifier(features)
return features, logits

View File

@@ -0,0 +1,176 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
#####################################################
# MobileNetV2: Inverted Residuals and Linear Bottlenecks, CVPR 2018
#####################################################
from torch import nn
from ..initialization import initialize_resnet
from ..SharedUtils import parse_channel_info
class ConvBNReLU(nn.Module):
def __init__(
self,
in_planes,
out_planes,
kernel_size,
stride,
groups,
has_bn=True,
has_relu=True,
):
super(ConvBNReLU, self).__init__()
padding = (kernel_size - 1) // 2
self.conv = nn.Conv2d(
in_planes,
out_planes,
kernel_size,
stride,
padding,
groups=groups,
bias=False,
)
if has_bn:
self.bn = nn.BatchNorm2d(out_planes)
else:
self.bn = None
if has_relu:
self.relu = nn.ReLU6(inplace=True)
else:
self.relu = None
def forward(self, x):
out = self.conv(x)
if self.bn:
out = self.bn(out)
if self.relu:
out = self.relu(out)
return out
class InvertedResidual(nn.Module):
def __init__(self, channels, stride, expand_ratio, additive):
super(InvertedResidual, self).__init__()
self.stride = stride
assert stride in [1, 2], "invalid stride : {:}".format(stride)
assert len(channels) in [2, 3], "invalid channels : {:}".format(channels)
if len(channels) == 2:
layers = []
else:
layers = [ConvBNReLU(channels[0], channels[1], 1, 1, 1)]
layers.extend(
[
# dw
ConvBNReLU(channels[-2], channels[-2], 3, stride, channels[-2]),
# pw-linear
ConvBNReLU(channels[-2], channels[-1], 1, 1, 1, True, False),
]
)
self.conv = nn.Sequential(*layers)
self.additive = additive
if self.additive and channels[0] != channels[-1]:
self.shortcut = ConvBNReLU(channels[0], channels[-1], 1, 1, 1, True, False)
else:
self.shortcut = None
self.out_dim = channels[-1]
def forward(self, x):
out = self.conv(x)
# if self.additive: return additive_func(out, x)
if self.shortcut:
return out + self.shortcut(x)
else:
return out
class InferMobileNetV2(nn.Module):
def __init__(self, num_classes, xchannels, xblocks, dropout):
super(InferMobileNetV2, self).__init__()
block = InvertedResidual
inverted_residual_setting = [
# t, c, n, s
[1, 16, 1, 1],
[6, 24, 2, 2],
[6, 32, 3, 2],
[6, 64, 4, 2],
[6, 96, 3, 1],
[6, 160, 3, 2],
[6, 320, 1, 1],
]
assert len(inverted_residual_setting) == len(
xblocks
), "invalid number of layers : {:} vs {:}".format(
len(inverted_residual_setting), len(xblocks)
)
for block_num, ir_setting in zip(xblocks, inverted_residual_setting):
assert block_num <= ir_setting[2], "{:} vs {:}".format(
block_num, ir_setting
)
xchannels = parse_channel_info(xchannels)
# for i, chs in enumerate(xchannels):
# if i > 0: assert chs[0] == xchannels[i-1][-1], 'Layer[{:}] is invalid {:} vs {:}'.format(i, xchannels[i-1], chs)
self.xchannels = xchannels
self.message = "InferMobileNetV2 : xblocks={:}".format(xblocks)
# building first layer
features = [ConvBNReLU(xchannels[0][0], xchannels[0][1], 3, 2, 1)]
last_channel_idx = 1
# building inverted residual blocks
for stage, (t, c, n, s) in enumerate(inverted_residual_setting):
for i in range(n):
stride = s if i == 0 else 1
additv = True if i > 0 else False
module = block(self.xchannels[last_channel_idx], stride, t, additv)
features.append(module)
self.message += "\nstage={:}, ilayer={:02d}/{:02d}, block={:03d}, Cs={:}, stride={:}, expand={:}, original-C={:}".format(
stage,
i,
n,
len(features),
self.xchannels[last_channel_idx],
stride,
t,
c,
)
last_channel_idx += 1
if i + 1 == xblocks[stage]:
out_channel = module.out_dim
for iiL in range(i + 1, n):
last_channel_idx += 1
self.xchannels[last_channel_idx][0] = module.out_dim
break
# building last several layers
features.append(
ConvBNReLU(
self.xchannels[last_channel_idx][0],
self.xchannels[last_channel_idx][1],
1,
1,
1,
)
)
assert last_channel_idx + 2 == len(self.xchannels), "{:} vs {:}".format(
last_channel_idx, len(self.xchannels)
)
# make it nn.Sequential
self.features = nn.Sequential(*features)
# building classifier
self.classifier = nn.Sequential(
nn.Dropout(dropout),
nn.Linear(self.xchannels[last_channel_idx][1], num_classes),
)
# weight initialization
self.apply(initialize_resnet)
def get_message(self):
return self.message
def forward(self, inputs):
features = self.features(inputs)
vectors = features.mean([2, 3])
predicts = self.classifier(vectors)
return features, predicts

View File

@@ -0,0 +1,74 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
#####################################################
from typing import List, Text, Any
import torch.nn as nn
from ..cell_operations import ResNetBasicblock
from ..cell_infers.cells import InferCell
class DynamicShapeTinyNet(nn.Module):
def __init__(self, channels: List[int], genotype: Any, num_classes: int):
super(DynamicShapeTinyNet, self).__init__()
self._channels = channels
if len(channels) % 3 != 2:
raise ValueError("invalid number of layers : {:}".format(len(channels)))
self._num_stage = N = len(channels) // 3
self.stem = nn.Sequential(
nn.Conv2d(3, channels[0], kernel_size=3, padding=1, bias=False),
nn.BatchNorm2d(channels[0]),
)
# layer_channels = [C ] * N + [C*2 ] + [C*2 ] * N + [C*4 ] + [C*4 ] * N
layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N
c_prev = channels[0]
self.cells = nn.ModuleList()
for index, (c_curr, reduction) in enumerate(zip(channels, layer_reductions)):
if reduction:
cell = ResNetBasicblock(c_prev, c_curr, 2, True)
else:
cell = InferCell(genotype, c_prev, c_curr, 1)
self.cells.append(cell)
c_prev = cell.out_dim
self._num_layer = len(self.cells)
self.lastact = nn.Sequential(nn.BatchNorm2d(c_prev), nn.ReLU(inplace=True))
self.global_pooling = nn.AdaptiveAvgPool2d(1)
self.classifier = nn.Linear(c_prev, num_classes)
def get_message(self) -> Text:
string = self.extra_repr()
for i, cell in enumerate(self.cells):
string += "\n {:02d}/{:02d} :: {:}".format(
i, len(self.cells), cell.extra_repr()
)
return string
def extra_repr(self):
return "{name}(C={_channels}, N={_num_stage}, L={_num_layer})".format(
name=self.__class__.__name__, **self.__dict__
)
def forward(self, inputs):
feature = self.stem(inputs)
for i, cell in enumerate(self.cells):
feature = cell(feature)
out = self.lastact(feature)
out = self.global_pooling(out)
out = out.view(out.size(0), -1)
logits = self.classifier(out)
return logits
def forward_pre_GAP(self, inputs):
feature = self.stem(inputs)
for i, cell in enumerate(self.cells):
feature = cell(feature)
out = self.lastact(feature)
return out

View File

@@ -0,0 +1,9 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
#####################################################
from .InferCifarResNet_width import InferWidthCifarResNet
from .InferImagenetResNet import InferImagenetResNet
from .InferCifarResNet_depth import InferDepthCifarResNet
from .InferCifarResNet import InferCifarResNet
from .InferMobileNetV2 import InferMobileNetV2
from .InferTinyCellNet import DynamicShapeTinyNet

View File

@@ -0,0 +1,5 @@
def parse_channel_info(xstring):
blocks = xstring.split(" ")
blocks = [x.split("-") for x in blocks]
blocks = [[int(_) for _ in x] for x in blocks]
return blocks

View File

@@ -0,0 +1,760 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
import math, torch
from collections import OrderedDict
from bisect import bisect_right
import torch.nn as nn
from ..initialization import initialize_resnet
from ..SharedUtils import additive_func
from .SoftSelect import select2withP, ChannelWiseInter
from .SoftSelect import linear_forward
from .SoftSelect import get_width_choices
def get_depth_choices(nDepth, return_num):
if nDepth == 2:
choices = (1, 2)
elif nDepth == 3:
choices = (1, 2, 3)
elif nDepth > 3:
choices = list(range(1, nDepth + 1, 2))
if choices[-1] < nDepth:
choices.append(nDepth)
else:
raise ValueError("invalid nDepth : {:}".format(nDepth))
if return_num:
return len(choices)
else:
return choices
def conv_forward(inputs, conv, choices):
iC = conv.in_channels
fill_size = list(inputs.size())
fill_size[1] = iC - fill_size[1]
filled = torch.zeros(fill_size, device=inputs.device)
xinputs = torch.cat((inputs, filled), dim=1)
outputs = conv(xinputs)
selecteds = [outputs[:, :oC] for oC in choices]
return selecteds
class ConvBNReLU(nn.Module):
num_conv = 1
def __init__(
self, nIn, nOut, kernel, stride, padding, bias, has_avg, has_bn, has_relu
):
super(ConvBNReLU, self).__init__()
self.InShape = None
self.OutShape = None
self.choices = get_width_choices(nOut)
self.register_buffer("choices_tensor", torch.Tensor(self.choices))
if has_avg:
self.avg = nn.AvgPool2d(kernel_size=2, stride=2, padding=0)
else:
self.avg = None
self.conv = nn.Conv2d(
nIn,
nOut,
kernel_size=kernel,
stride=stride,
padding=padding,
dilation=1,
groups=1,
bias=bias,
)
# if has_bn : self.bn = nn.BatchNorm2d(nOut)
# else : self.bn = None
self.has_bn = has_bn
self.BNs = nn.ModuleList()
for i, _out in enumerate(self.choices):
self.BNs.append(nn.BatchNorm2d(_out))
if has_relu:
self.relu = nn.ReLU(inplace=True)
else:
self.relu = None
self.in_dim = nIn
self.out_dim = nOut
self.search_mode = "basic"
def get_flops(self, channels, check_range=True, divide=1):
iC, oC = channels
if check_range:
assert (
iC <= self.conv.in_channels and oC <= self.conv.out_channels
), "{:} vs {:} | {:} vs {:}".format(
iC, self.conv.in_channels, oC, self.conv.out_channels
)
assert (
isinstance(self.InShape, tuple) and len(self.InShape) == 2
), "invalid in-shape : {:}".format(self.InShape)
assert (
isinstance(self.OutShape, tuple) and len(self.OutShape) == 2
), "invalid out-shape : {:}".format(self.OutShape)
# conv_per_position_flops = self.conv.kernel_size[0] * self.conv.kernel_size[1] * iC * oC / self.conv.groups
conv_per_position_flops = (
self.conv.kernel_size[0] * self.conv.kernel_size[1] * 1.0 / self.conv.groups
)
all_positions = self.OutShape[0] * self.OutShape[1]
flops = (conv_per_position_flops * all_positions / divide) * iC * oC
if self.conv.bias is not None:
flops += all_positions / divide
return flops
def get_range(self):
return [self.choices]
def forward(self, inputs):
if self.search_mode == "basic":
return self.basic_forward(inputs)
elif self.search_mode == "search":
return self.search_forward(inputs)
else:
raise ValueError("invalid search_mode = {:}".format(self.search_mode))
def search_forward(self, tuple_inputs):
assert (
isinstance(tuple_inputs, tuple) and len(tuple_inputs) == 5
), "invalid type input : {:}".format(type(tuple_inputs))
inputs, expected_inC, probability, index, prob = tuple_inputs
index, prob = torch.squeeze(index).tolist(), torch.squeeze(prob)
probability = torch.squeeze(probability)
assert len(index) == 2, "invalid length : {:}".format(index)
# compute expected flop
# coordinates = torch.arange(self.x_range[0], self.x_range[1]+1).type_as(probability)
expected_outC = (self.choices_tensor * probability).sum()
expected_flop = self.get_flops([expected_inC, expected_outC], False, 1e6)
if self.avg:
out = self.avg(inputs)
else:
out = inputs
# convolutional layer
out_convs = conv_forward(out, self.conv, [self.choices[i] for i in index])
out_bns = [self.BNs[idx](out_conv) for idx, out_conv in zip(index, out_convs)]
# merge
out_channel = max([x.size(1) for x in out_bns])
outA = ChannelWiseInter(out_bns[0], out_channel)
outB = ChannelWiseInter(out_bns[1], out_channel)
out = outA * prob[0] + outB * prob[1]
# out = additive_func(out_bns[0]*prob[0], out_bns[1]*prob[1])
if self.relu:
out = self.relu(out)
else:
out = out
return out, expected_outC, expected_flop
def basic_forward(self, inputs):
if self.avg:
out = self.avg(inputs)
else:
out = inputs
conv = self.conv(out)
if self.has_bn:
out = self.BNs[-1](conv)
else:
out = conv
if self.relu:
out = self.relu(out)
else:
out = out
if self.InShape is None:
self.InShape = (inputs.size(-2), inputs.size(-1))
self.OutShape = (out.size(-2), out.size(-1))
return out
class ResNetBasicblock(nn.Module):
expansion = 1
num_conv = 2
def __init__(self, inplanes, planes, stride):
super(ResNetBasicblock, self).__init__()
assert stride == 1 or stride == 2, "invalid stride {:}".format(stride)
self.conv_a = ConvBNReLU(
inplanes,
planes,
3,
stride,
1,
False,
has_avg=False,
has_bn=True,
has_relu=True,
)
self.conv_b = ConvBNReLU(
planes, planes, 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=False
)
if stride == 2:
self.downsample = ConvBNReLU(
inplanes,
planes,
1,
1,
0,
False,
has_avg=True,
has_bn=False,
has_relu=False,
)
elif inplanes != planes:
self.downsample = ConvBNReLU(
inplanes,
planes,
1,
1,
0,
False,
has_avg=False,
has_bn=True,
has_relu=False,
)
else:
self.downsample = None
self.out_dim = planes
self.search_mode = "basic"
def get_range(self):
return self.conv_a.get_range() + self.conv_b.get_range()
def get_flops(self, channels):
assert len(channels) == 3, "invalid channels : {:}".format(channels)
flop_A = self.conv_a.get_flops([channels[0], channels[1]])
flop_B = self.conv_b.get_flops([channels[1], channels[2]])
if hasattr(self.downsample, "get_flops"):
flop_C = self.downsample.get_flops([channels[0], channels[-1]])
else:
flop_C = 0
if (
channels[0] != channels[-1] and self.downsample is None
): # this short-cut will be added during the infer-train
flop_C = (
channels[0]
* channels[-1]
* self.conv_b.OutShape[0]
* self.conv_b.OutShape[1]
)
return flop_A + flop_B + flop_C
def forward(self, inputs):
if self.search_mode == "basic":
return self.basic_forward(inputs)
elif self.search_mode == "search":
return self.search_forward(inputs)
else:
raise ValueError("invalid search_mode = {:}".format(self.search_mode))
def search_forward(self, tuple_inputs):
assert (
isinstance(tuple_inputs, tuple) and len(tuple_inputs) == 5
), "invalid type input : {:}".format(type(tuple_inputs))
inputs, expected_inC, probability, indexes, probs = tuple_inputs
assert indexes.size(0) == 2 and probs.size(0) == 2 and probability.size(0) == 2
out_a, expected_inC_a, expected_flop_a = self.conv_a(
(inputs, expected_inC, probability[0], indexes[0], probs[0])
)
out_b, expected_inC_b, expected_flop_b = self.conv_b(
(out_a, expected_inC_a, probability[1], indexes[1], probs[1])
)
if self.downsample is not None:
residual, _, expected_flop_c = self.downsample(
(inputs, expected_inC, probability[1], indexes[1], probs[1])
)
else:
residual, expected_flop_c = inputs, 0
out = additive_func(residual, out_b)
return (
nn.functional.relu(out, inplace=True),
expected_inC_b,
sum([expected_flop_a, expected_flop_b, expected_flop_c]),
)
def basic_forward(self, inputs):
basicblock = self.conv_a(inputs)
basicblock = self.conv_b(basicblock)
if self.downsample is not None:
residual = self.downsample(inputs)
else:
residual = inputs
out = additive_func(residual, basicblock)
return nn.functional.relu(out, inplace=True)
class ResNetBottleneck(nn.Module):
expansion = 4
num_conv = 3
def __init__(self, inplanes, planes, stride):
super(ResNetBottleneck, self).__init__()
assert stride == 1 or stride == 2, "invalid stride {:}".format(stride)
self.conv_1x1 = ConvBNReLU(
inplanes, planes, 1, 1, 0, False, has_avg=False, has_bn=True, has_relu=True
)
self.conv_3x3 = ConvBNReLU(
planes,
planes,
3,
stride,
1,
False,
has_avg=False,
has_bn=True,
has_relu=True,
)
self.conv_1x4 = ConvBNReLU(
planes,
planes * self.expansion,
1,
1,
0,
False,
has_avg=False,
has_bn=True,
has_relu=False,
)
if stride == 2:
self.downsample = ConvBNReLU(
inplanes,
planes * self.expansion,
1,
1,
0,
False,
has_avg=True,
has_bn=False,
has_relu=False,
)
elif inplanes != planes * self.expansion:
self.downsample = ConvBNReLU(
inplanes,
planes * self.expansion,
1,
1,
0,
False,
has_avg=False,
has_bn=True,
has_relu=False,
)
else:
self.downsample = None
self.out_dim = planes * self.expansion
self.search_mode = "basic"
def get_range(self):
return (
self.conv_1x1.get_range()
+ self.conv_3x3.get_range()
+ self.conv_1x4.get_range()
)
def get_flops(self, channels):
assert len(channels) == 4, "invalid channels : {:}".format(channels)
flop_A = self.conv_1x1.get_flops([channels[0], channels[1]])
flop_B = self.conv_3x3.get_flops([channels[1], channels[2]])
flop_C = self.conv_1x4.get_flops([channels[2], channels[3]])
if hasattr(self.downsample, "get_flops"):
flop_D = self.downsample.get_flops([channels[0], channels[-1]])
else:
flop_D = 0
if (
channels[0] != channels[-1] and self.downsample is None
): # this short-cut will be added during the infer-train
flop_D = (
channels[0]
* channels[-1]
* self.conv_1x4.OutShape[0]
* self.conv_1x4.OutShape[1]
)
return flop_A + flop_B + flop_C + flop_D
def forward(self, inputs):
if self.search_mode == "basic":
return self.basic_forward(inputs)
elif self.search_mode == "search":
return self.search_forward(inputs)
else:
raise ValueError("invalid search_mode = {:}".format(self.search_mode))
def basic_forward(self, inputs):
bottleneck = self.conv_1x1(inputs)
bottleneck = self.conv_3x3(bottleneck)
bottleneck = self.conv_1x4(bottleneck)
if self.downsample is not None:
residual = self.downsample(inputs)
else:
residual = inputs
out = additive_func(residual, bottleneck)
return nn.functional.relu(out, inplace=True)
def search_forward(self, tuple_inputs):
assert (
isinstance(tuple_inputs, tuple) and len(tuple_inputs) == 5
), "invalid type input : {:}".format(type(tuple_inputs))
inputs, expected_inC, probability, indexes, probs = tuple_inputs
assert indexes.size(0) == 3 and probs.size(0) == 3 and probability.size(0) == 3
out_1x1, expected_inC_1x1, expected_flop_1x1 = self.conv_1x1(
(inputs, expected_inC, probability[0], indexes[0], probs[0])
)
out_3x3, expected_inC_3x3, expected_flop_3x3 = self.conv_3x3(
(out_1x1, expected_inC_1x1, probability[1], indexes[1], probs[1])
)
out_1x4, expected_inC_1x4, expected_flop_1x4 = self.conv_1x4(
(out_3x3, expected_inC_3x3, probability[2], indexes[2], probs[2])
)
if self.downsample is not None:
residual, _, expected_flop_c = self.downsample(
(inputs, expected_inC, probability[2], indexes[2], probs[2])
)
else:
residual, expected_flop_c = inputs, 0
out = additive_func(residual, out_1x4)
return (
nn.functional.relu(out, inplace=True),
expected_inC_1x4,
sum(
[
expected_flop_1x1,
expected_flop_3x3,
expected_flop_1x4,
expected_flop_c,
]
),
)
class SearchShapeCifarResNet(nn.Module):
def __init__(self, block_name, depth, num_classes):
super(SearchShapeCifarResNet, self).__init__()
# Model type specifies number of layers for CIFAR-10 and CIFAR-100 model
if block_name == "ResNetBasicblock":
block = ResNetBasicblock
assert (depth - 2) % 6 == 0, "depth should be one of 20, 32, 44, 56, 110"
layer_blocks = (depth - 2) // 6
elif block_name == "ResNetBottleneck":
block = ResNetBottleneck
assert (depth - 2) % 9 == 0, "depth should be one of 164"
layer_blocks = (depth - 2) // 9
else:
raise ValueError("invalid block : {:}".format(block_name))
self.message = (
"SearchShapeCifarResNet : Depth : {:} , Layers for each block : {:}".format(
depth, layer_blocks
)
)
self.num_classes = num_classes
self.channels = [16]
self.layers = nn.ModuleList(
[
ConvBNReLU(
3, 16, 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=True
)
]
)
self.InShape = None
self.depth_info = OrderedDict()
self.depth_at_i = OrderedDict()
for stage in range(3):
cur_block_choices = get_depth_choices(layer_blocks, False)
assert (
cur_block_choices[-1] == layer_blocks
), "stage={:}, {:} vs {:}".format(stage, cur_block_choices, layer_blocks)
self.message += (
"\nstage={:} ::: depth-block-choices={:} for {:} blocks.".format(
stage, cur_block_choices, layer_blocks
)
)
block_choices, xstart = [], len(self.layers)
for iL in range(layer_blocks):
iC = self.channels[-1]
planes = 16 * (2 ** stage)
stride = 2 if stage > 0 and iL == 0 else 1
module = block(iC, planes, stride)
self.channels.append(module.out_dim)
self.layers.append(module)
self.message += "\nstage={:}, ilayer={:02d}/{:02d}, block={:03d}, iC={:3d}, oC={:3d}, stride={:}".format(
stage,
iL,
layer_blocks,
len(self.layers) - 1,
iC,
module.out_dim,
stride,
)
# added for depth
layer_index = len(self.layers) - 1
if iL + 1 in cur_block_choices:
block_choices.append(layer_index)
if iL + 1 == layer_blocks:
self.depth_info[layer_index] = {
"choices": block_choices,
"stage": stage,
"xstart": xstart,
}
self.depth_info_list = []
for xend, info in self.depth_info.items():
self.depth_info_list.append((xend, info))
xstart, xstage = info["xstart"], info["stage"]
for ilayer in range(xstart, xend + 1):
idx = bisect_right(info["choices"], ilayer - 1)
self.depth_at_i[ilayer] = (xstage, idx)
self.avgpool = nn.AvgPool2d(8)
self.classifier = nn.Linear(module.out_dim, num_classes)
self.InShape = None
self.tau = -1
self.search_mode = "basic"
# assert sum(x.num_conv for x in self.layers) + 1 == depth, 'invalid depth check {:} vs {:}'.format(sum(x.num_conv for x in self.layers)+1, depth)
# parameters for width
self.Ranges = []
self.layer2indexRange = []
for i, layer in enumerate(self.layers):
start_index = len(self.Ranges)
self.Ranges += layer.get_range()
self.layer2indexRange.append((start_index, len(self.Ranges)))
assert len(self.Ranges) + 1 == depth, "invalid depth check {:} vs {:}".format(
len(self.Ranges) + 1, depth
)
self.register_parameter(
"width_attentions",
nn.Parameter(torch.Tensor(len(self.Ranges), get_width_choices(None))),
)
self.register_parameter(
"depth_attentions",
nn.Parameter(torch.Tensor(3, get_depth_choices(layer_blocks, True))),
)
nn.init.normal_(self.width_attentions, 0, 0.01)
nn.init.normal_(self.depth_attentions, 0, 0.01)
self.apply(initialize_resnet)
def arch_parameters(self, LR=None):
if LR is None:
return [self.width_attentions, self.depth_attentions]
else:
return [
{"params": self.width_attentions, "lr": LR},
{"params": self.depth_attentions, "lr": LR},
]
def base_parameters(self):
return (
list(self.layers.parameters())
+ list(self.avgpool.parameters())
+ list(self.classifier.parameters())
)
def get_flop(self, mode, config_dict, extra_info):
if config_dict is not None:
config_dict = config_dict.copy()
# select channels
channels = [3]
for i, weight in enumerate(self.width_attentions):
if mode == "genotype":
with torch.no_grad():
probe = nn.functional.softmax(weight, dim=0)
C = self.Ranges[i][torch.argmax(probe).item()]
elif mode == "max":
C = self.Ranges[i][-1]
elif mode == "fix":
C = int(math.sqrt(extra_info) * self.Ranges[i][-1])
elif mode == "random":
assert isinstance(extra_info, float), "invalid extra_info : {:}".format(
extra_info
)
with torch.no_grad():
prob = nn.functional.softmax(weight, dim=0)
approximate_C = int(math.sqrt(extra_info) * self.Ranges[i][-1])
for j in range(prob.size(0)):
prob[j] = 1 / (
abs(j - (approximate_C - self.Ranges[i][j])) + 0.2
)
C = self.Ranges[i][torch.multinomial(prob, 1, False).item()]
else:
raise ValueError("invalid mode : {:}".format(mode))
channels.append(C)
# select depth
if mode == "genotype":
with torch.no_grad():
depth_probs = nn.functional.softmax(self.depth_attentions, dim=1)
choices = torch.argmax(depth_probs, dim=1).cpu().tolist()
elif mode == "max" or mode == "fix":
choices = [depth_probs.size(1) - 1 for _ in range(depth_probs.size(0))]
elif mode == "random":
with torch.no_grad():
depth_probs = nn.functional.softmax(self.depth_attentions, dim=1)
choices = torch.multinomial(depth_probs, 1, False).cpu().tolist()
else:
raise ValueError("invalid mode : {:}".format(mode))
selected_layers = []
for choice, xvalue in zip(choices, self.depth_info_list):
xtemp = xvalue[1]["choices"][choice] - xvalue[1]["xstart"] + 1
selected_layers.append(xtemp)
flop = 0
for i, layer in enumerate(self.layers):
s, e = self.layer2indexRange[i]
xchl = tuple(channels[s : e + 1])
if i in self.depth_at_i:
xstagei, xatti = self.depth_at_i[i]
if xatti <= choices[xstagei]: # leave this depth
flop += layer.get_flops(xchl)
else:
flop += 0 # do not use this layer
else:
flop += layer.get_flops(xchl)
# the last fc layer
flop += channels[-1] * self.classifier.out_features
if config_dict is None:
return flop / 1e6
else:
config_dict["xchannels"] = channels
config_dict["xblocks"] = selected_layers
config_dict["super_type"] = "infer-shape"
config_dict["estimated_FLOP"] = flop / 1e6
return flop / 1e6, config_dict
def get_arch_info(self):
string = (
"for depth and width, there are {:} + {:} attention probabilities.".format(
len(self.depth_attentions), len(self.width_attentions)
)
)
string += "\n{:}".format(self.depth_info)
discrepancy = []
with torch.no_grad():
for i, att in enumerate(self.depth_attentions):
prob = nn.functional.softmax(att, dim=0)
prob = prob.cpu()
selc = prob.argmax().item()
prob = prob.tolist()
prob = ["{:.3f}".format(x) for x in prob]
xstring = "{:03d}/{:03d}-th : {:}".format(
i, len(self.depth_attentions), " ".join(prob)
)
logt = ["{:.4f}".format(x) for x in att.cpu().tolist()]
xstring += " || {:17s}".format(" ".join(logt))
prob = sorted([float(x) for x in prob])
disc = prob[-1] - prob[-2]
xstring += " || discrepancy={:.2f} || select={:}/{:}".format(
disc, selc, len(prob)
)
discrepancy.append(disc)
string += "\n{:}".format(xstring)
string += "\n-----------------------------------------------"
for i, att in enumerate(self.width_attentions):
prob = nn.functional.softmax(att, dim=0)
prob = prob.cpu()
selc = prob.argmax().item()
prob = prob.tolist()
prob = ["{:.3f}".format(x) for x in prob]
xstring = "{:03d}/{:03d}-th : {:}".format(
i, len(self.width_attentions), " ".join(prob)
)
logt = ["{:.3f}".format(x) for x in att.cpu().tolist()]
xstring += " || {:52s}".format(" ".join(logt))
prob = sorted([float(x) for x in prob])
disc = prob[-1] - prob[-2]
xstring += " || dis={:.2f} || select={:}/{:}".format(
disc, selc, len(prob)
)
discrepancy.append(disc)
string += "\n{:}".format(xstring)
return string, discrepancy
def set_tau(self, tau_max, tau_min, epoch_ratio):
assert (
epoch_ratio >= 0 and epoch_ratio <= 1
), "invalid epoch-ratio : {:}".format(epoch_ratio)
tau = tau_min + (tau_max - tau_min) * (1 + math.cos(math.pi * epoch_ratio)) / 2
self.tau = tau
def get_message(self):
return self.message
def forward(self, inputs):
if self.search_mode == "basic":
return self.basic_forward(inputs)
elif self.search_mode == "search":
return self.search_forward(inputs)
else:
raise ValueError("invalid search_mode = {:}".format(self.search_mode))
def search_forward(self, inputs):
flop_width_probs = nn.functional.softmax(self.width_attentions, dim=1)
flop_depth_probs = nn.functional.softmax(self.depth_attentions, dim=1)
flop_depth_probs = torch.flip(
torch.cumsum(torch.flip(flop_depth_probs, [1]), 1), [1]
)
selected_widths, selected_width_probs = select2withP(
self.width_attentions, self.tau
)
selected_depth_probs = select2withP(self.depth_attentions, self.tau, True)
with torch.no_grad():
selected_widths = selected_widths.cpu()
x, last_channel_idx, expected_inC, flops = inputs, 0, 3, []
feature_maps = []
for i, layer in enumerate(self.layers):
selected_w_index = selected_widths[
last_channel_idx : last_channel_idx + layer.num_conv
]
selected_w_probs = selected_width_probs[
last_channel_idx : last_channel_idx + layer.num_conv
]
layer_prob = flop_width_probs[
last_channel_idx : last_channel_idx + layer.num_conv
]
x, expected_inC, expected_flop = layer(
(x, expected_inC, layer_prob, selected_w_index, selected_w_probs)
)
feature_maps.append(x)
last_channel_idx += layer.num_conv
if i in self.depth_info: # aggregate the information
choices = self.depth_info[i]["choices"]
xstagei = self.depth_info[i]["stage"]
# print ('iL={:}, choices={:}, stage={:}, probs={:}'.format(i, choices, xstagei, selected_depth_probs[xstagei].cpu().tolist()))
# for A, W in zip(choices, selected_depth_probs[xstagei]):
# print('Size = {:}, W = {:}'.format(feature_maps[A].size(), W))
possible_tensors = []
max_C = max(feature_maps[A].size(1) for A in choices)
for tempi, A in enumerate(choices):
xtensor = ChannelWiseInter(feature_maps[A], max_C)
# drop_ratio = 1-(tempi+1.0)/len(choices)
# xtensor = drop_path(xtensor, drop_ratio)
possible_tensors.append(xtensor)
weighted_sum = sum(
xtensor * W
for xtensor, W in zip(
possible_tensors, selected_depth_probs[xstagei]
)
)
x = weighted_sum
if i in self.depth_at_i:
xstagei, xatti = self.depth_at_i[i]
x_expected_flop = flop_depth_probs[xstagei, xatti] * expected_flop
else:
x_expected_flop = expected_flop
flops.append(x_expected_flop)
flops.append(expected_inC * (self.classifier.out_features * 1.0 / 1e6))
features = self.avgpool(x)
features = features.view(features.size(0), -1)
logits = linear_forward(features, self.classifier)
return logits, torch.stack([sum(flops)])
def basic_forward(self, inputs):
if self.InShape is None:
self.InShape = (inputs.size(-2), inputs.size(-1))
x = inputs
for i, layer in enumerate(self.layers):
x = layer(x)
features = self.avgpool(x)
features = features.view(features.size(0), -1)
logits = self.classifier(features)
return features, logits

View File

@@ -0,0 +1,515 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
import math, torch
from collections import OrderedDict
from bisect import bisect_right
import torch.nn as nn
from ..initialization import initialize_resnet
from ..SharedUtils import additive_func
from .SoftSelect import select2withP, ChannelWiseInter
from .SoftSelect import linear_forward
from .SoftSelect import get_width_choices
def get_depth_choices(nDepth, return_num):
if nDepth == 2:
choices = (1, 2)
elif nDepth == 3:
choices = (1, 2, 3)
elif nDepth > 3:
choices = list(range(1, nDepth + 1, 2))
if choices[-1] < nDepth:
choices.append(nDepth)
else:
raise ValueError("invalid nDepth : {:}".format(nDepth))
if return_num:
return len(choices)
else:
return choices
class ConvBNReLU(nn.Module):
num_conv = 1
def __init__(
self, nIn, nOut, kernel, stride, padding, bias, has_avg, has_bn, has_relu
):
super(ConvBNReLU, self).__init__()
self.InShape = None
self.OutShape = None
self.choices = get_width_choices(nOut)
self.register_buffer("choices_tensor", torch.Tensor(self.choices))
if has_avg:
self.avg = nn.AvgPool2d(kernel_size=2, stride=2, padding=0)
else:
self.avg = None
self.conv = nn.Conv2d(
nIn,
nOut,
kernel_size=kernel,
stride=stride,
padding=padding,
dilation=1,
groups=1,
bias=bias,
)
if has_bn:
self.bn = nn.BatchNorm2d(nOut)
else:
self.bn = None
if has_relu:
self.relu = nn.ReLU(inplace=False)
else:
self.relu = None
self.in_dim = nIn
self.out_dim = nOut
def get_flops(self, divide=1):
iC, oC = self.in_dim, self.out_dim
assert (
iC <= self.conv.in_channels and oC <= self.conv.out_channels
), "{:} vs {:} | {:} vs {:}".format(
iC, self.conv.in_channels, oC, self.conv.out_channels
)
assert (
isinstance(self.InShape, tuple) and len(self.InShape) == 2
), "invalid in-shape : {:}".format(self.InShape)
assert (
isinstance(self.OutShape, tuple) and len(self.OutShape) == 2
), "invalid out-shape : {:}".format(self.OutShape)
# conv_per_position_flops = self.conv.kernel_size[0] * self.conv.kernel_size[1] * iC * oC / self.conv.groups
conv_per_position_flops = (
self.conv.kernel_size[0] * self.conv.kernel_size[1] * 1.0 / self.conv.groups
)
all_positions = self.OutShape[0] * self.OutShape[1]
flops = (conv_per_position_flops * all_positions / divide) * iC * oC
if self.conv.bias is not None:
flops += all_positions / divide
return flops
def forward(self, inputs):
if self.avg:
out = self.avg(inputs)
else:
out = inputs
conv = self.conv(out)
if self.bn:
out = self.bn(conv)
else:
out = conv
if self.relu:
out = self.relu(out)
else:
out = out
if self.InShape is None:
self.InShape = (inputs.size(-2), inputs.size(-1))
self.OutShape = (out.size(-2), out.size(-1))
return out
class ResNetBasicblock(nn.Module):
expansion = 1
num_conv = 2
def __init__(self, inplanes, planes, stride):
super(ResNetBasicblock, self).__init__()
assert stride == 1 or stride == 2, "invalid stride {:}".format(stride)
self.conv_a = ConvBNReLU(
inplanes,
planes,
3,
stride,
1,
False,
has_avg=False,
has_bn=True,
has_relu=True,
)
self.conv_b = ConvBNReLU(
planes, planes, 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=False
)
if stride == 2:
self.downsample = ConvBNReLU(
inplanes,
planes,
1,
1,
0,
False,
has_avg=True,
has_bn=False,
has_relu=False,
)
elif inplanes != planes:
self.downsample = ConvBNReLU(
inplanes,
planes,
1,
1,
0,
False,
has_avg=False,
has_bn=True,
has_relu=False,
)
else:
self.downsample = None
self.out_dim = planes
self.search_mode = "basic"
def get_flops(self, divide=1):
flop_A = self.conv_a.get_flops(divide)
flop_B = self.conv_b.get_flops(divide)
if hasattr(self.downsample, "get_flops"):
flop_C = self.downsample.get_flops(divide)
else:
flop_C = 0
return flop_A + flop_B + flop_C
def forward(self, inputs):
basicblock = self.conv_a(inputs)
basicblock = self.conv_b(basicblock)
if self.downsample is not None:
residual = self.downsample(inputs)
else:
residual = inputs
out = additive_func(residual, basicblock)
return nn.functional.relu(out, inplace=True)
class ResNetBottleneck(nn.Module):
expansion = 4
num_conv = 3
def __init__(self, inplanes, planes, stride):
super(ResNetBottleneck, self).__init__()
assert stride == 1 or stride == 2, "invalid stride {:}".format(stride)
self.conv_1x1 = ConvBNReLU(
inplanes, planes, 1, 1, 0, False, has_avg=False, has_bn=True, has_relu=True
)
self.conv_3x3 = ConvBNReLU(
planes,
planes,
3,
stride,
1,
False,
has_avg=False,
has_bn=True,
has_relu=True,
)
self.conv_1x4 = ConvBNReLU(
planes,
planes * self.expansion,
1,
1,
0,
False,
has_avg=False,
has_bn=True,
has_relu=False,
)
if stride == 2:
self.downsample = ConvBNReLU(
inplanes,
planes * self.expansion,
1,
1,
0,
False,
has_avg=True,
has_bn=False,
has_relu=False,
)
elif inplanes != planes * self.expansion:
self.downsample = ConvBNReLU(
inplanes,
planes * self.expansion,
1,
1,
0,
False,
has_avg=False,
has_bn=True,
has_relu=False,
)
else:
self.downsample = None
self.out_dim = planes * self.expansion
self.search_mode = "basic"
def get_range(self):
return (
self.conv_1x1.get_range()
+ self.conv_3x3.get_range()
+ self.conv_1x4.get_range()
)
def get_flops(self, divide):
flop_A = self.conv_1x1.get_flops(divide)
flop_B = self.conv_3x3.get_flops(divide)
flop_C = self.conv_1x4.get_flops(divide)
if hasattr(self.downsample, "get_flops"):
flop_D = self.downsample.get_flops(divide)
else:
flop_D = 0
return flop_A + flop_B + flop_C + flop_D
def forward(self, inputs):
bottleneck = self.conv_1x1(inputs)
bottleneck = self.conv_3x3(bottleneck)
bottleneck = self.conv_1x4(bottleneck)
if self.downsample is not None:
residual = self.downsample(inputs)
else:
residual = inputs
out = additive_func(residual, bottleneck)
return nn.functional.relu(out, inplace=True)
class SearchDepthCifarResNet(nn.Module):
def __init__(self, block_name, depth, num_classes):
super(SearchDepthCifarResNet, self).__init__()
# Model type specifies number of layers for CIFAR-10 and CIFAR-100 model
if block_name == "ResNetBasicblock":
block = ResNetBasicblock
assert (depth - 2) % 6 == 0, "depth should be one of 20, 32, 44, 56, 110"
layer_blocks = (depth - 2) // 6
elif block_name == "ResNetBottleneck":
block = ResNetBottleneck
assert (depth - 2) % 9 == 0, "depth should be one of 164"
layer_blocks = (depth - 2) // 9
else:
raise ValueError("invalid block : {:}".format(block_name))
self.message = (
"SearchShapeCifarResNet : Depth : {:} , Layers for each block : {:}".format(
depth, layer_blocks
)
)
self.num_classes = num_classes
self.channels = [16]
self.layers = nn.ModuleList(
[
ConvBNReLU(
3, 16, 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=True
)
]
)
self.InShape = None
self.depth_info = OrderedDict()
self.depth_at_i = OrderedDict()
for stage in range(3):
cur_block_choices = get_depth_choices(layer_blocks, False)
assert (
cur_block_choices[-1] == layer_blocks
), "stage={:}, {:} vs {:}".format(stage, cur_block_choices, layer_blocks)
self.message += (
"\nstage={:} ::: depth-block-choices={:} for {:} blocks.".format(
stage, cur_block_choices, layer_blocks
)
)
block_choices, xstart = [], len(self.layers)
for iL in range(layer_blocks):
iC = self.channels[-1]
planes = 16 * (2 ** stage)
stride = 2 if stage > 0 and iL == 0 else 1
module = block(iC, planes, stride)
self.channels.append(module.out_dim)
self.layers.append(module)
self.message += "\nstage={:}, ilayer={:02d}/{:02d}, block={:03d}, iC={:3d}, oC={:3d}, stride={:}".format(
stage,
iL,
layer_blocks,
len(self.layers) - 1,
iC,
module.out_dim,
stride,
)
# added for depth
layer_index = len(self.layers) - 1
if iL + 1 in cur_block_choices:
block_choices.append(layer_index)
if iL + 1 == layer_blocks:
self.depth_info[layer_index] = {
"choices": block_choices,
"stage": stage,
"xstart": xstart,
}
self.depth_info_list = []
for xend, info in self.depth_info.items():
self.depth_info_list.append((xend, info))
xstart, xstage = info["xstart"], info["stage"]
for ilayer in range(xstart, xend + 1):
idx = bisect_right(info["choices"], ilayer - 1)
self.depth_at_i[ilayer] = (xstage, idx)
self.avgpool = nn.AvgPool2d(8)
self.classifier = nn.Linear(module.out_dim, num_classes)
self.InShape = None
self.tau = -1
self.search_mode = "basic"
# assert sum(x.num_conv for x in self.layers) + 1 == depth, 'invalid depth check {:} vs {:}'.format(sum(x.num_conv for x in self.layers)+1, depth)
self.register_parameter(
"depth_attentions",
nn.Parameter(torch.Tensor(3, get_depth_choices(layer_blocks, True))),
)
nn.init.normal_(self.depth_attentions, 0, 0.01)
self.apply(initialize_resnet)
def arch_parameters(self):
return [self.depth_attentions]
def base_parameters(self):
return (
list(self.layers.parameters())
+ list(self.avgpool.parameters())
+ list(self.classifier.parameters())
)
def get_flop(self, mode, config_dict, extra_info):
if config_dict is not None:
config_dict = config_dict.copy()
# select depth
if mode == "genotype":
with torch.no_grad():
depth_probs = nn.functional.softmax(self.depth_attentions, dim=1)
choices = torch.argmax(depth_probs, dim=1).cpu().tolist()
elif mode == "max":
choices = [depth_probs.size(1) - 1 for _ in range(depth_probs.size(0))]
elif mode == "random":
with torch.no_grad():
depth_probs = nn.functional.softmax(self.depth_attentions, dim=1)
choices = torch.multinomial(depth_probs, 1, False).cpu().tolist()
else:
raise ValueError("invalid mode : {:}".format(mode))
selected_layers = []
for choice, xvalue in zip(choices, self.depth_info_list):
xtemp = xvalue[1]["choices"][choice] - xvalue[1]["xstart"] + 1
selected_layers.append(xtemp)
flop = 0
for i, layer in enumerate(self.layers):
if i in self.depth_at_i:
xstagei, xatti = self.depth_at_i[i]
if xatti <= choices[xstagei]: # leave this depth
flop += layer.get_flops()
else:
flop += 0 # do not use this layer
else:
flop += layer.get_flops()
# the last fc layer
flop += self.classifier.in_features * self.classifier.out_features
if config_dict is None:
return flop / 1e6
else:
config_dict["xblocks"] = selected_layers
config_dict["super_type"] = "infer-depth"
config_dict["estimated_FLOP"] = flop / 1e6
return flop / 1e6, config_dict
def get_arch_info(self):
string = "for depth, there are {:} attention probabilities.".format(
len(self.depth_attentions)
)
string += "\n{:}".format(self.depth_info)
discrepancy = []
with torch.no_grad():
for i, att in enumerate(self.depth_attentions):
prob = nn.functional.softmax(att, dim=0)
prob = prob.cpu()
selc = prob.argmax().item()
prob = prob.tolist()
prob = ["{:.3f}".format(x) for x in prob]
xstring = "{:03d}/{:03d}-th : {:}".format(
i, len(self.depth_attentions), " ".join(prob)
)
logt = ["{:.4f}".format(x) for x in att.cpu().tolist()]
xstring += " || {:17s}".format(" ".join(logt))
prob = sorted([float(x) for x in prob])
disc = prob[-1] - prob[-2]
xstring += " || discrepancy={:.2f} || select={:}/{:}".format(
disc, selc, len(prob)
)
discrepancy.append(disc)
string += "\n{:}".format(xstring)
return string, discrepancy
def set_tau(self, tau_max, tau_min, epoch_ratio):
assert (
epoch_ratio >= 0 and epoch_ratio <= 1
), "invalid epoch-ratio : {:}".format(epoch_ratio)
tau = tau_min + (tau_max - tau_min) * (1 + math.cos(math.pi * epoch_ratio)) / 2
self.tau = tau
def get_message(self):
return self.message
def forward(self, inputs):
if self.search_mode == "basic":
return self.basic_forward(inputs)
elif self.search_mode == "search":
return self.search_forward(inputs)
else:
raise ValueError("invalid search_mode = {:}".format(self.search_mode))
def search_forward(self, inputs):
flop_depth_probs = nn.functional.softmax(self.depth_attentions, dim=1)
flop_depth_probs = torch.flip(
torch.cumsum(torch.flip(flop_depth_probs, [1]), 1), [1]
)
selected_depth_probs = select2withP(self.depth_attentions, self.tau, True)
x, flops = inputs, []
feature_maps = []
for i, layer in enumerate(self.layers):
layer_i = layer(x)
feature_maps.append(layer_i)
if i in self.depth_info: # aggregate the information
choices = self.depth_info[i]["choices"]
xstagei = self.depth_info[i]["stage"]
possible_tensors = []
for tempi, A in enumerate(choices):
xtensor = feature_maps[A]
possible_tensors.append(xtensor)
weighted_sum = sum(
xtensor * W
for xtensor, W in zip(
possible_tensors, selected_depth_probs[xstagei]
)
)
x = weighted_sum
else:
x = layer_i
if i in self.depth_at_i:
xstagei, xatti = self.depth_at_i[i]
# print ('layer-{:03d}, stage={:}, att={:}, prob={:}, flop={:}'.format(i, xstagei, xatti, flop_depth_probs[xstagei, xatti].item(), layer.get_flops(1e6)))
x_expected_flop = flop_depth_probs[xstagei, xatti] * layer.get_flops(
1e6
)
else:
x_expected_flop = layer.get_flops(1e6)
flops.append(x_expected_flop)
flops.append(
(self.classifier.in_features * self.classifier.out_features * 1.0 / 1e6)
)
features = self.avgpool(x)
features = features.view(features.size(0), -1)
logits = linear_forward(features, self.classifier)
return logits, torch.stack([sum(flops)])
def basic_forward(self, inputs):
if self.InShape is None:
self.InShape = (inputs.size(-2), inputs.size(-1))
x = inputs
for i, layer in enumerate(self.layers):
x = layer(x)
features = self.avgpool(x)
features = features.view(features.size(0), -1)
logits = self.classifier(features)
return features, logits

View File

@@ -0,0 +1,619 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
import math, torch
import torch.nn as nn
from ..initialization import initialize_resnet
from ..SharedUtils import additive_func
from .SoftSelect import select2withP, ChannelWiseInter
from .SoftSelect import linear_forward
from .SoftSelect import get_width_choices as get_choices
def conv_forward(inputs, conv, choices):
iC = conv.in_channels
fill_size = list(inputs.size())
fill_size[1] = iC - fill_size[1]
filled = torch.zeros(fill_size, device=inputs.device)
xinputs = torch.cat((inputs, filled), dim=1)
outputs = conv(xinputs)
selecteds = [outputs[:, :oC] for oC in choices]
return selecteds
class ConvBNReLU(nn.Module):
num_conv = 1
def __init__(
self, nIn, nOut, kernel, stride, padding, bias, has_avg, has_bn, has_relu
):
super(ConvBNReLU, self).__init__()
self.InShape = None
self.OutShape = None
self.choices = get_choices(nOut)
self.register_buffer("choices_tensor", torch.Tensor(self.choices))
if has_avg:
self.avg = nn.AvgPool2d(kernel_size=2, stride=2, padding=0)
else:
self.avg = None
self.conv = nn.Conv2d(
nIn,
nOut,
kernel_size=kernel,
stride=stride,
padding=padding,
dilation=1,
groups=1,
bias=bias,
)
# if has_bn : self.bn = nn.BatchNorm2d(nOut)
# else : self.bn = None
self.has_bn = has_bn
self.BNs = nn.ModuleList()
for i, _out in enumerate(self.choices):
self.BNs.append(nn.BatchNorm2d(_out))
if has_relu:
self.relu = nn.ReLU(inplace=True)
else:
self.relu = None
self.in_dim = nIn
self.out_dim = nOut
self.search_mode = "basic"
def get_flops(self, channels, check_range=True, divide=1):
iC, oC = channels
if check_range:
assert (
iC <= self.conv.in_channels and oC <= self.conv.out_channels
), "{:} vs {:} | {:} vs {:}".format(
iC, self.conv.in_channels, oC, self.conv.out_channels
)
assert (
isinstance(self.InShape, tuple) and len(self.InShape) == 2
), "invalid in-shape : {:}".format(self.InShape)
assert (
isinstance(self.OutShape, tuple) and len(self.OutShape) == 2
), "invalid out-shape : {:}".format(self.OutShape)
# conv_per_position_flops = self.conv.kernel_size[0] * self.conv.kernel_size[1] * iC * oC / self.conv.groups
conv_per_position_flops = (
self.conv.kernel_size[0] * self.conv.kernel_size[1] * 1.0 / self.conv.groups
)
all_positions = self.OutShape[0] * self.OutShape[1]
flops = (conv_per_position_flops * all_positions / divide) * iC * oC
if self.conv.bias is not None:
flops += all_positions / divide
return flops
def get_range(self):
return [self.choices]
def forward(self, inputs):
if self.search_mode == "basic":
return self.basic_forward(inputs)
elif self.search_mode == "search":
return self.search_forward(inputs)
else:
raise ValueError("invalid search_mode = {:}".format(self.search_mode))
def search_forward(self, tuple_inputs):
assert (
isinstance(tuple_inputs, tuple) and len(tuple_inputs) == 5
), "invalid type input : {:}".format(type(tuple_inputs))
inputs, expected_inC, probability, index, prob = tuple_inputs
index, prob = torch.squeeze(index).tolist(), torch.squeeze(prob)
probability = torch.squeeze(probability)
assert len(index) == 2, "invalid length : {:}".format(index)
# compute expected flop
# coordinates = torch.arange(self.x_range[0], self.x_range[1]+1).type_as(probability)
expected_outC = (self.choices_tensor * probability).sum()
expected_flop = self.get_flops([expected_inC, expected_outC], False, 1e6)
if self.avg:
out = self.avg(inputs)
else:
out = inputs
# convolutional layer
out_convs = conv_forward(out, self.conv, [self.choices[i] for i in index])
out_bns = [self.BNs[idx](out_conv) for idx, out_conv in zip(index, out_convs)]
# merge
out_channel = max([x.size(1) for x in out_bns])
outA = ChannelWiseInter(out_bns[0], out_channel)
outB = ChannelWiseInter(out_bns[1], out_channel)
out = outA * prob[0] + outB * prob[1]
# out = additive_func(out_bns[0]*prob[0], out_bns[1]*prob[1])
if self.relu:
out = self.relu(out)
else:
out = out
return out, expected_outC, expected_flop
def basic_forward(self, inputs):
if self.avg:
out = self.avg(inputs)
else:
out = inputs
conv = self.conv(out)
if self.has_bn:
out = self.BNs[-1](conv)
else:
out = conv
if self.relu:
out = self.relu(out)
else:
out = out
if self.InShape is None:
self.InShape = (inputs.size(-2), inputs.size(-1))
self.OutShape = (out.size(-2), out.size(-1))
return out
class ResNetBasicblock(nn.Module):
expansion = 1
num_conv = 2
def __init__(self, inplanes, planes, stride):
super(ResNetBasicblock, self).__init__()
assert stride == 1 or stride == 2, "invalid stride {:}".format(stride)
self.conv_a = ConvBNReLU(
inplanes,
planes,
3,
stride,
1,
False,
has_avg=False,
has_bn=True,
has_relu=True,
)
self.conv_b = ConvBNReLU(
planes, planes, 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=False
)
if stride == 2:
self.downsample = ConvBNReLU(
inplanes,
planes,
1,
1,
0,
False,
has_avg=True,
has_bn=False,
has_relu=False,
)
elif inplanes != planes:
self.downsample = ConvBNReLU(
inplanes,
planes,
1,
1,
0,
False,
has_avg=False,
has_bn=True,
has_relu=False,
)
else:
self.downsample = None
self.out_dim = planes
self.search_mode = "basic"
def get_range(self):
return self.conv_a.get_range() + self.conv_b.get_range()
def get_flops(self, channels):
assert len(channels) == 3, "invalid channels : {:}".format(channels)
flop_A = self.conv_a.get_flops([channels[0], channels[1]])
flop_B = self.conv_b.get_flops([channels[1], channels[2]])
if hasattr(self.downsample, "get_flops"):
flop_C = self.downsample.get_flops([channels[0], channels[-1]])
else:
flop_C = 0
if (
channels[0] != channels[-1] and self.downsample is None
): # this short-cut will be added during the infer-train
flop_C = (
channels[0]
* channels[-1]
* self.conv_b.OutShape[0]
* self.conv_b.OutShape[1]
)
return flop_A + flop_B + flop_C
def forward(self, inputs):
if self.search_mode == "basic":
return self.basic_forward(inputs)
elif self.search_mode == "search":
return self.search_forward(inputs)
else:
raise ValueError("invalid search_mode = {:}".format(self.search_mode))
def search_forward(self, tuple_inputs):
assert (
isinstance(tuple_inputs, tuple) and len(tuple_inputs) == 5
), "invalid type input : {:}".format(type(tuple_inputs))
inputs, expected_inC, probability, indexes, probs = tuple_inputs
assert indexes.size(0) == 2 and probs.size(0) == 2 and probability.size(0) == 2
out_a, expected_inC_a, expected_flop_a = self.conv_a(
(inputs, expected_inC, probability[0], indexes[0], probs[0])
)
out_b, expected_inC_b, expected_flop_b = self.conv_b(
(out_a, expected_inC_a, probability[1], indexes[1], probs[1])
)
if self.downsample is not None:
residual, _, expected_flop_c = self.downsample(
(inputs, expected_inC, probability[1], indexes[1], probs[1])
)
else:
residual, expected_flop_c = inputs, 0
out = additive_func(residual, out_b)
return (
nn.functional.relu(out, inplace=True),
expected_inC_b,
sum([expected_flop_a, expected_flop_b, expected_flop_c]),
)
def basic_forward(self, inputs):
basicblock = self.conv_a(inputs)
basicblock = self.conv_b(basicblock)
if self.downsample is not None:
residual = self.downsample(inputs)
else:
residual = inputs
out = additive_func(residual, basicblock)
return nn.functional.relu(out, inplace=True)
class ResNetBottleneck(nn.Module):
expansion = 4
num_conv = 3
def __init__(self, inplanes, planes, stride):
super(ResNetBottleneck, self).__init__()
assert stride == 1 or stride == 2, "invalid stride {:}".format(stride)
self.conv_1x1 = ConvBNReLU(
inplanes, planes, 1, 1, 0, False, has_avg=False, has_bn=True, has_relu=True
)
self.conv_3x3 = ConvBNReLU(
planes,
planes,
3,
stride,
1,
False,
has_avg=False,
has_bn=True,
has_relu=True,
)
self.conv_1x4 = ConvBNReLU(
planes,
planes * self.expansion,
1,
1,
0,
False,
has_avg=False,
has_bn=True,
has_relu=False,
)
if stride == 2:
self.downsample = ConvBNReLU(
inplanes,
planes * self.expansion,
1,
1,
0,
False,
has_avg=True,
has_bn=False,
has_relu=False,
)
elif inplanes != planes * self.expansion:
self.downsample = ConvBNReLU(
inplanes,
planes * self.expansion,
1,
1,
0,
False,
has_avg=False,
has_bn=True,
has_relu=False,
)
else:
self.downsample = None
self.out_dim = planes * self.expansion
self.search_mode = "basic"
def get_range(self):
return (
self.conv_1x1.get_range()
+ self.conv_3x3.get_range()
+ self.conv_1x4.get_range()
)
def get_flops(self, channels):
assert len(channels) == 4, "invalid channels : {:}".format(channels)
flop_A = self.conv_1x1.get_flops([channels[0], channels[1]])
flop_B = self.conv_3x3.get_flops([channels[1], channels[2]])
flop_C = self.conv_1x4.get_flops([channels[2], channels[3]])
if hasattr(self.downsample, "get_flops"):
flop_D = self.downsample.get_flops([channels[0], channels[-1]])
else:
flop_D = 0
if (
channels[0] != channels[-1] and self.downsample is None
): # this short-cut will be added during the infer-train
flop_D = (
channels[0]
* channels[-1]
* self.conv_1x4.OutShape[0]
* self.conv_1x4.OutShape[1]
)
return flop_A + flop_B + flop_C + flop_D
def forward(self, inputs):
if self.search_mode == "basic":
return self.basic_forward(inputs)
elif self.search_mode == "search":
return self.search_forward(inputs)
else:
raise ValueError("invalid search_mode = {:}".format(self.search_mode))
def basic_forward(self, inputs):
bottleneck = self.conv_1x1(inputs)
bottleneck = self.conv_3x3(bottleneck)
bottleneck = self.conv_1x4(bottleneck)
if self.downsample is not None:
residual = self.downsample(inputs)
else:
residual = inputs
out = additive_func(residual, bottleneck)
return nn.functional.relu(out, inplace=True)
def search_forward(self, tuple_inputs):
assert (
isinstance(tuple_inputs, tuple) and len(tuple_inputs) == 5
), "invalid type input : {:}".format(type(tuple_inputs))
inputs, expected_inC, probability, indexes, probs = tuple_inputs
assert indexes.size(0) == 3 and probs.size(0) == 3 and probability.size(0) == 3
out_1x1, expected_inC_1x1, expected_flop_1x1 = self.conv_1x1(
(inputs, expected_inC, probability[0], indexes[0], probs[0])
)
out_3x3, expected_inC_3x3, expected_flop_3x3 = self.conv_3x3(
(out_1x1, expected_inC_1x1, probability[1], indexes[1], probs[1])
)
out_1x4, expected_inC_1x4, expected_flop_1x4 = self.conv_1x4(
(out_3x3, expected_inC_3x3, probability[2], indexes[2], probs[2])
)
if self.downsample is not None:
residual, _, expected_flop_c = self.downsample(
(inputs, expected_inC, probability[2], indexes[2], probs[2])
)
else:
residual, expected_flop_c = inputs, 0
out = additive_func(residual, out_1x4)
return (
nn.functional.relu(out, inplace=True),
expected_inC_1x4,
sum(
[
expected_flop_1x1,
expected_flop_3x3,
expected_flop_1x4,
expected_flop_c,
]
),
)
class SearchWidthCifarResNet(nn.Module):
def __init__(self, block_name, depth, num_classes):
super(SearchWidthCifarResNet, self).__init__()
# Model type specifies number of layers for CIFAR-10 and CIFAR-100 model
if block_name == "ResNetBasicblock":
block = ResNetBasicblock
assert (depth - 2) % 6 == 0, "depth should be one of 20, 32, 44, 56, 110"
layer_blocks = (depth - 2) // 6
elif block_name == "ResNetBottleneck":
block = ResNetBottleneck
assert (depth - 2) % 9 == 0, "depth should be one of 164"
layer_blocks = (depth - 2) // 9
else:
raise ValueError("invalid block : {:}".format(block_name))
self.message = (
"SearchWidthCifarResNet : Depth : {:} , Layers for each block : {:}".format(
depth, layer_blocks
)
)
self.num_classes = num_classes
self.channels = [16]
self.layers = nn.ModuleList(
[
ConvBNReLU(
3, 16, 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=True
)
]
)
self.InShape = None
for stage in range(3):
for iL in range(layer_blocks):
iC = self.channels[-1]
planes = 16 * (2 ** stage)
stride = 2 if stage > 0 and iL == 0 else 1
module = block(iC, planes, stride)
self.channels.append(module.out_dim)
self.layers.append(module)
self.message += "\nstage={:}, ilayer={:02d}/{:02d}, block={:03d}, iC={:3d}, oC={:3d}, stride={:}".format(
stage,
iL,
layer_blocks,
len(self.layers) - 1,
iC,
module.out_dim,
stride,
)
self.avgpool = nn.AvgPool2d(8)
self.classifier = nn.Linear(module.out_dim, num_classes)
self.InShape = None
self.tau = -1
self.search_mode = "basic"
# assert sum(x.num_conv for x in self.layers) + 1 == depth, 'invalid depth check {:} vs {:}'.format(sum(x.num_conv for x in self.layers)+1, depth)
# parameters for width
self.Ranges = []
self.layer2indexRange = []
for i, layer in enumerate(self.layers):
start_index = len(self.Ranges)
self.Ranges += layer.get_range()
self.layer2indexRange.append((start_index, len(self.Ranges)))
assert len(self.Ranges) + 1 == depth, "invalid depth check {:} vs {:}".format(
len(self.Ranges) + 1, depth
)
self.register_parameter(
"width_attentions",
nn.Parameter(torch.Tensor(len(self.Ranges), get_choices(None))),
)
nn.init.normal_(self.width_attentions, 0, 0.01)
self.apply(initialize_resnet)
def arch_parameters(self):
return [self.width_attentions]
def base_parameters(self):
return (
list(self.layers.parameters())
+ list(self.avgpool.parameters())
+ list(self.classifier.parameters())
)
def get_flop(self, mode, config_dict, extra_info):
if config_dict is not None:
config_dict = config_dict.copy()
# weights = [F.softmax(x, dim=0) for x in self.width_attentions]
channels = [3]
for i, weight in enumerate(self.width_attentions):
if mode == "genotype":
with torch.no_grad():
probe = nn.functional.softmax(weight, dim=0)
C = self.Ranges[i][torch.argmax(probe).item()]
elif mode == "max":
C = self.Ranges[i][-1]
elif mode == "fix":
C = int(math.sqrt(extra_info) * self.Ranges[i][-1])
elif mode == "random":
assert isinstance(extra_info, float), "invalid extra_info : {:}".format(
extra_info
)
with torch.no_grad():
prob = nn.functional.softmax(weight, dim=0)
approximate_C = int(math.sqrt(extra_info) * self.Ranges[i][-1])
for j in range(prob.size(0)):
prob[j] = 1 / (
abs(j - (approximate_C - self.Ranges[i][j])) + 0.2
)
C = self.Ranges[i][torch.multinomial(prob, 1, False).item()]
else:
raise ValueError("invalid mode : {:}".format(mode))
channels.append(C)
flop = 0
for i, layer in enumerate(self.layers):
s, e = self.layer2indexRange[i]
xchl = tuple(channels[s : e + 1])
flop += layer.get_flops(xchl)
# the last fc layer
flop += channels[-1] * self.classifier.out_features
if config_dict is None:
return flop / 1e6
else:
config_dict["xchannels"] = channels
config_dict["super_type"] = "infer-width"
config_dict["estimated_FLOP"] = flop / 1e6
return flop / 1e6, config_dict
def get_arch_info(self):
string = "for width, there are {:} attention probabilities.".format(
len(self.width_attentions)
)
discrepancy = []
with torch.no_grad():
for i, att in enumerate(self.width_attentions):
prob = nn.functional.softmax(att, dim=0)
prob = prob.cpu()
selc = prob.argmax().item()
prob = prob.tolist()
prob = ["{:.3f}".format(x) for x in prob]
xstring = "{:03d}/{:03d}-th : {:}".format(
i, len(self.width_attentions), " ".join(prob)
)
logt = ["{:.3f}".format(x) for x in att.cpu().tolist()]
xstring += " || {:52s}".format(" ".join(logt))
prob = sorted([float(x) for x in prob])
disc = prob[-1] - prob[-2]
xstring += " || dis={:.2f} || select={:}/{:}".format(
disc, selc, len(prob)
)
discrepancy.append(disc)
string += "\n{:}".format(xstring)
return string, discrepancy
def set_tau(self, tau_max, tau_min, epoch_ratio):
assert (
epoch_ratio >= 0 and epoch_ratio <= 1
), "invalid epoch-ratio : {:}".format(epoch_ratio)
tau = tau_min + (tau_max - tau_min) * (1 + math.cos(math.pi * epoch_ratio)) / 2
self.tau = tau
def get_message(self):
return self.message
def forward(self, inputs):
if self.search_mode == "basic":
return self.basic_forward(inputs)
elif self.search_mode == "search":
return self.search_forward(inputs)
else:
raise ValueError("invalid search_mode = {:}".format(self.search_mode))
def search_forward(self, inputs):
flop_probs = nn.functional.softmax(self.width_attentions, dim=1)
selected_widths, selected_probs = select2withP(self.width_attentions, self.tau)
with torch.no_grad():
selected_widths = selected_widths.cpu()
x, last_channel_idx, expected_inC, flops = inputs, 0, 3, []
for i, layer in enumerate(self.layers):
selected_w_index = selected_widths[
last_channel_idx : last_channel_idx + layer.num_conv
]
selected_w_probs = selected_probs[
last_channel_idx : last_channel_idx + layer.num_conv
]
layer_prob = flop_probs[
last_channel_idx : last_channel_idx + layer.num_conv
]
x, expected_inC, expected_flop = layer(
(x, expected_inC, layer_prob, selected_w_index, selected_w_probs)
)
last_channel_idx += layer.num_conv
flops.append(expected_flop)
flops.append(expected_inC * (self.classifier.out_features * 1.0 / 1e6))
features = self.avgpool(x)
features = features.view(features.size(0), -1)
logits = linear_forward(features, self.classifier)
return logits, torch.stack([sum(flops)])
def basic_forward(self, inputs):
if self.InShape is None:
self.InShape = (inputs.size(-2), inputs.size(-1))
x = inputs
for i, layer in enumerate(self.layers):
x = layer(x)
features = self.avgpool(x)
features = features.view(features.size(0), -1)
logits = self.classifier(features)
return features, logits

View File

@@ -0,0 +1,766 @@
import math, torch
from collections import OrderedDict
from bisect import bisect_right
import torch.nn as nn
from ..initialization import initialize_resnet
from ..SharedUtils import additive_func
from .SoftSelect import select2withP, ChannelWiseInter
from .SoftSelect import linear_forward
from .SoftSelect import get_width_choices
def get_depth_choices(layers):
min_depth = min(layers)
info = {"num": min_depth}
for i, depth in enumerate(layers):
choices = []
for j in range(1, min_depth + 1):
choices.append(int(float(depth) * j / min_depth))
info[i] = choices
return info
def conv_forward(inputs, conv, choices):
iC = conv.in_channels
fill_size = list(inputs.size())
fill_size[1] = iC - fill_size[1]
filled = torch.zeros(fill_size, device=inputs.device)
xinputs = torch.cat((inputs, filled), dim=1)
outputs = conv(xinputs)
selecteds = [outputs[:, :oC] for oC in choices]
return selecteds
class ConvBNReLU(nn.Module):
num_conv = 1
def __init__(
self,
nIn,
nOut,
kernel,
stride,
padding,
bias,
has_avg,
has_bn,
has_relu,
last_max_pool=False,
):
super(ConvBNReLU, self).__init__()
self.InShape = None
self.OutShape = None
self.choices = get_width_choices(nOut)
self.register_buffer("choices_tensor", torch.Tensor(self.choices))
if has_avg:
self.avg = nn.AvgPool2d(kernel_size=2, stride=2, padding=0)
else:
self.avg = None
self.conv = nn.Conv2d(
nIn,
nOut,
kernel_size=kernel,
stride=stride,
padding=padding,
dilation=1,
groups=1,
bias=bias,
)
# if has_bn : self.bn = nn.BatchNorm2d(nOut)
# else : self.bn = None
self.has_bn = has_bn
self.BNs = nn.ModuleList()
for i, _out in enumerate(self.choices):
self.BNs.append(nn.BatchNorm2d(_out))
if has_relu:
self.relu = nn.ReLU(inplace=True)
else:
self.relu = None
if last_max_pool:
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
else:
self.maxpool = None
self.in_dim = nIn
self.out_dim = nOut
self.search_mode = "basic"
def get_flops(self, channels, check_range=True, divide=1):
iC, oC = channels
if check_range:
assert (
iC <= self.conv.in_channels and oC <= self.conv.out_channels
), "{:} vs {:} | {:} vs {:}".format(
iC, self.conv.in_channels, oC, self.conv.out_channels
)
assert (
isinstance(self.InShape, tuple) and len(self.InShape) == 2
), "invalid in-shape : {:}".format(self.InShape)
assert (
isinstance(self.OutShape, tuple) and len(self.OutShape) == 2
), "invalid out-shape : {:}".format(self.OutShape)
# conv_per_position_flops = self.conv.kernel_size[0] * self.conv.kernel_size[1] * iC * oC / self.conv.groups
conv_per_position_flops = (
self.conv.kernel_size[0] * self.conv.kernel_size[1] * 1.0 / self.conv.groups
)
all_positions = self.OutShape[0] * self.OutShape[1]
flops = (conv_per_position_flops * all_positions / divide) * iC * oC
if self.conv.bias is not None:
flops += all_positions / divide
return flops
def get_range(self):
return [self.choices]
def forward(self, inputs):
if self.search_mode == "basic":
return self.basic_forward(inputs)
elif self.search_mode == "search":
return self.search_forward(inputs)
else:
raise ValueError("invalid search_mode = {:}".format(self.search_mode))
def search_forward(self, tuple_inputs):
assert (
isinstance(tuple_inputs, tuple) and len(tuple_inputs) == 5
), "invalid type input : {:}".format(type(tuple_inputs))
inputs, expected_inC, probability, index, prob = tuple_inputs
index, prob = torch.squeeze(index).tolist(), torch.squeeze(prob)
probability = torch.squeeze(probability)
assert len(index) == 2, "invalid length : {:}".format(index)
# compute expected flop
# coordinates = torch.arange(self.x_range[0], self.x_range[1]+1).type_as(probability)
expected_outC = (self.choices_tensor * probability).sum()
expected_flop = self.get_flops([expected_inC, expected_outC], False, 1e6)
if self.avg:
out = self.avg(inputs)
else:
out = inputs
# convolutional layer
out_convs = conv_forward(out, self.conv, [self.choices[i] for i in index])
out_bns = [self.BNs[idx](out_conv) for idx, out_conv in zip(index, out_convs)]
# merge
out_channel = max([x.size(1) for x in out_bns])
outA = ChannelWiseInter(out_bns[0], out_channel)
outB = ChannelWiseInter(out_bns[1], out_channel)
out = outA * prob[0] + outB * prob[1]
# out = additive_func(out_bns[0]*prob[0], out_bns[1]*prob[1])
if self.relu:
out = self.relu(out)
if self.maxpool:
out = self.maxpool(out)
return out, expected_outC, expected_flop
def basic_forward(self, inputs):
if self.avg:
out = self.avg(inputs)
else:
out = inputs
conv = self.conv(out)
if self.has_bn:
out = self.BNs[-1](conv)
else:
out = conv
if self.relu:
out = self.relu(out)
else:
out = out
if self.InShape is None:
self.InShape = (inputs.size(-2), inputs.size(-1))
self.OutShape = (out.size(-2), out.size(-1))
if self.maxpool:
out = self.maxpool(out)
return out
class ResNetBasicblock(nn.Module):
expansion = 1
num_conv = 2
def __init__(self, inplanes, planes, stride):
super(ResNetBasicblock, self).__init__()
assert stride == 1 or stride == 2, "invalid stride {:}".format(stride)
self.conv_a = ConvBNReLU(
inplanes,
planes,
3,
stride,
1,
False,
has_avg=False,
has_bn=True,
has_relu=True,
)
self.conv_b = ConvBNReLU(
planes, planes, 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=False
)
if stride == 2:
self.downsample = ConvBNReLU(
inplanes,
planes,
1,
1,
0,
False,
has_avg=True,
has_bn=True,
has_relu=False,
)
elif inplanes != planes:
self.downsample = ConvBNReLU(
inplanes,
planes,
1,
1,
0,
False,
has_avg=False,
has_bn=True,
has_relu=False,
)
else:
self.downsample = None
self.out_dim = planes
self.search_mode = "basic"
def get_range(self):
return self.conv_a.get_range() + self.conv_b.get_range()
def get_flops(self, channels):
assert len(channels) == 3, "invalid channels : {:}".format(channels)
flop_A = self.conv_a.get_flops([channels[0], channels[1]])
flop_B = self.conv_b.get_flops([channels[1], channels[2]])
if hasattr(self.downsample, "get_flops"):
flop_C = self.downsample.get_flops([channels[0], channels[-1]])
else:
flop_C = 0
if (
channels[0] != channels[-1] and self.downsample is None
): # this short-cut will be added during the infer-train
flop_C = (
channels[0]
* channels[-1]
* self.conv_b.OutShape[0]
* self.conv_b.OutShape[1]
)
return flop_A + flop_B + flop_C
def forward(self, inputs):
if self.search_mode == "basic":
return self.basic_forward(inputs)
elif self.search_mode == "search":
return self.search_forward(inputs)
else:
raise ValueError("invalid search_mode = {:}".format(self.search_mode))
def search_forward(self, tuple_inputs):
assert (
isinstance(tuple_inputs, tuple) and len(tuple_inputs) == 5
), "invalid type input : {:}".format(type(tuple_inputs))
inputs, expected_inC, probability, indexes, probs = tuple_inputs
assert indexes.size(0) == 2 and probs.size(0) == 2 and probability.size(0) == 2
# import pdb; pdb.set_trace()
out_a, expected_inC_a, expected_flop_a = self.conv_a(
(inputs, expected_inC, probability[0], indexes[0], probs[0])
)
out_b, expected_inC_b, expected_flop_b = self.conv_b(
(out_a, expected_inC_a, probability[1], indexes[1], probs[1])
)
if self.downsample is not None:
residual, _, expected_flop_c = self.downsample(
(inputs, expected_inC, probability[1], indexes[1], probs[1])
)
else:
residual, expected_flop_c = inputs, 0
out = additive_func(residual, out_b)
return (
nn.functional.relu(out, inplace=True),
expected_inC_b,
sum([expected_flop_a, expected_flop_b, expected_flop_c]),
)
def basic_forward(self, inputs):
basicblock = self.conv_a(inputs)
basicblock = self.conv_b(basicblock)
if self.downsample is not None:
residual = self.downsample(inputs)
else:
residual = inputs
out = additive_func(residual, basicblock)
return nn.functional.relu(out, inplace=True)
class ResNetBottleneck(nn.Module):
expansion = 4
num_conv = 3
def __init__(self, inplanes, planes, stride):
super(ResNetBottleneck, self).__init__()
assert stride == 1 or stride == 2, "invalid stride {:}".format(stride)
self.conv_1x1 = ConvBNReLU(
inplanes, planes, 1, 1, 0, False, has_avg=False, has_bn=True, has_relu=True
)
self.conv_3x3 = ConvBNReLU(
planes,
planes,
3,
stride,
1,
False,
has_avg=False,
has_bn=True,
has_relu=True,
)
self.conv_1x4 = ConvBNReLU(
planes,
planes * self.expansion,
1,
1,
0,
False,
has_avg=False,
has_bn=True,
has_relu=False,
)
if stride == 2:
self.downsample = ConvBNReLU(
inplanes,
planes * self.expansion,
1,
1,
0,
False,
has_avg=True,
has_bn=True,
has_relu=False,
)
elif inplanes != planes * self.expansion:
self.downsample = ConvBNReLU(
inplanes,
planes * self.expansion,
1,
1,
0,
False,
has_avg=False,
has_bn=True,
has_relu=False,
)
else:
self.downsample = None
self.out_dim = planes * self.expansion
self.search_mode = "basic"
def get_range(self):
return (
self.conv_1x1.get_range()
+ self.conv_3x3.get_range()
+ self.conv_1x4.get_range()
)
def get_flops(self, channels):
assert len(channels) == 4, "invalid channels : {:}".format(channels)
flop_A = self.conv_1x1.get_flops([channels[0], channels[1]])
flop_B = self.conv_3x3.get_flops([channels[1], channels[2]])
flop_C = self.conv_1x4.get_flops([channels[2], channels[3]])
if hasattr(self.downsample, "get_flops"):
flop_D = self.downsample.get_flops([channels[0], channels[-1]])
else:
flop_D = 0
if (
channels[0] != channels[-1] and self.downsample is None
): # this short-cut will be added during the infer-train
flop_D = (
channels[0]
* channels[-1]
* self.conv_1x4.OutShape[0]
* self.conv_1x4.OutShape[1]
)
return flop_A + flop_B + flop_C + flop_D
def forward(self, inputs):
if self.search_mode == "basic":
return self.basic_forward(inputs)
elif self.search_mode == "search":
return self.search_forward(inputs)
else:
raise ValueError("invalid search_mode = {:}".format(self.search_mode))
def basic_forward(self, inputs):
bottleneck = self.conv_1x1(inputs)
bottleneck = self.conv_3x3(bottleneck)
bottleneck = self.conv_1x4(bottleneck)
if self.downsample is not None:
residual = self.downsample(inputs)
else:
residual = inputs
out = additive_func(residual, bottleneck)
return nn.functional.relu(out, inplace=True)
def search_forward(self, tuple_inputs):
assert (
isinstance(tuple_inputs, tuple) and len(tuple_inputs) == 5
), "invalid type input : {:}".format(type(tuple_inputs))
inputs, expected_inC, probability, indexes, probs = tuple_inputs
assert indexes.size(0) == 3 and probs.size(0) == 3 and probability.size(0) == 3
out_1x1, expected_inC_1x1, expected_flop_1x1 = self.conv_1x1(
(inputs, expected_inC, probability[0], indexes[0], probs[0])
)
out_3x3, expected_inC_3x3, expected_flop_3x3 = self.conv_3x3(
(out_1x1, expected_inC_1x1, probability[1], indexes[1], probs[1])
)
out_1x4, expected_inC_1x4, expected_flop_1x4 = self.conv_1x4(
(out_3x3, expected_inC_3x3, probability[2], indexes[2], probs[2])
)
if self.downsample is not None:
residual, _, expected_flop_c = self.downsample(
(inputs, expected_inC, probability[2], indexes[2], probs[2])
)
else:
residual, expected_flop_c = inputs, 0
out = additive_func(residual, out_1x4)
return (
nn.functional.relu(out, inplace=True),
expected_inC_1x4,
sum(
[
expected_flop_1x1,
expected_flop_3x3,
expected_flop_1x4,
expected_flop_c,
]
),
)
class SearchShapeImagenetResNet(nn.Module):
def __init__(self, block_name, layers, deep_stem, num_classes):
super(SearchShapeImagenetResNet, self).__init__()
# Model type specifies number of layers for CIFAR-10 and CIFAR-100 model
if block_name == "BasicBlock":
block = ResNetBasicblock
elif block_name == "Bottleneck":
block = ResNetBottleneck
else:
raise ValueError("invalid block : {:}".format(block_name))
self.message = (
"SearchShapeCifarResNet : Depth : {:} , Layers for each block : {:}".format(
sum(layers) * block.num_conv, layers
)
)
self.num_classes = num_classes
if not deep_stem:
self.layers = nn.ModuleList(
[
ConvBNReLU(
3,
64,
7,
2,
3,
False,
has_avg=False,
has_bn=True,
has_relu=True,
last_max_pool=True,
)
]
)
self.channels = [64]
else:
self.layers = nn.ModuleList(
[
ConvBNReLU(
3, 32, 3, 2, 1, False, has_avg=False, has_bn=True, has_relu=True
),
ConvBNReLU(
32,
64,
3,
1,
1,
False,
has_avg=False,
has_bn=True,
has_relu=True,
last_max_pool=True,
),
]
)
self.channels = [32, 64]
meta_depth_info = get_depth_choices(layers)
self.InShape = None
self.depth_info = OrderedDict()
self.depth_at_i = OrderedDict()
for stage, layer_blocks in enumerate(layers):
cur_block_choices = meta_depth_info[stage]
assert (
cur_block_choices[-1] == layer_blocks
), "stage={:}, {:} vs {:}".format(stage, cur_block_choices, layer_blocks)
block_choices, xstart = [], len(self.layers)
for iL in range(layer_blocks):
iC = self.channels[-1]
planes = 64 * (2 ** stage)
stride = 2 if stage > 0 and iL == 0 else 1
module = block(iC, planes, stride)
self.channels.append(module.out_dim)
self.layers.append(module)
self.message += "\nstage={:}, ilayer={:02d}/{:02d}, block={:03d}, iC={:3d}, oC={:3d}, stride={:}".format(
stage,
iL,
layer_blocks,
len(self.layers) - 1,
iC,
module.out_dim,
stride,
)
# added for depth
layer_index = len(self.layers) - 1
if iL + 1 in cur_block_choices:
block_choices.append(layer_index)
if iL + 1 == layer_blocks:
self.depth_info[layer_index] = {
"choices": block_choices,
"stage": stage,
"xstart": xstart,
}
self.depth_info_list = []
for xend, info in self.depth_info.items():
self.depth_info_list.append((xend, info))
xstart, xstage = info["xstart"], info["stage"]
for ilayer in range(xstart, xend + 1):
idx = bisect_right(info["choices"], ilayer - 1)
self.depth_at_i[ilayer] = (xstage, idx)
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
self.classifier = nn.Linear(module.out_dim, num_classes)
self.InShape = None
self.tau = -1
self.search_mode = "basic"
# assert sum(x.num_conv for x in self.layers) + 1 == depth, 'invalid depth check {:} vs {:}'.format(sum(x.num_conv for x in self.layers)+1, depth)
# parameters for width
self.Ranges = []
self.layer2indexRange = []
for i, layer in enumerate(self.layers):
start_index = len(self.Ranges)
self.Ranges += layer.get_range()
self.layer2indexRange.append((start_index, len(self.Ranges)))
self.register_parameter(
"width_attentions",
nn.Parameter(torch.Tensor(len(self.Ranges), get_width_choices(None))),
)
self.register_parameter(
"depth_attentions",
nn.Parameter(torch.Tensor(len(layers), meta_depth_info["num"])),
)
nn.init.normal_(self.width_attentions, 0, 0.01)
nn.init.normal_(self.depth_attentions, 0, 0.01)
self.apply(initialize_resnet)
def arch_parameters(self, LR=None):
if LR is None:
return [self.width_attentions, self.depth_attentions]
else:
return [
{"params": self.width_attentions, "lr": LR},
{"params": self.depth_attentions, "lr": LR},
]
def base_parameters(self):
return (
list(self.layers.parameters())
+ list(self.avgpool.parameters())
+ list(self.classifier.parameters())
)
def get_flop(self, mode, config_dict, extra_info):
if config_dict is not None:
config_dict = config_dict.copy()
# select channels
channels = [3]
for i, weight in enumerate(self.width_attentions):
if mode == "genotype":
with torch.no_grad():
probe = nn.functional.softmax(weight, dim=0)
C = self.Ranges[i][torch.argmax(probe).item()]
else:
raise ValueError("invalid mode : {:}".format(mode))
channels.append(C)
# select depth
if mode == "genotype":
with torch.no_grad():
depth_probs = nn.functional.softmax(self.depth_attentions, dim=1)
choices = torch.argmax(depth_probs, dim=1).cpu().tolist()
else:
raise ValueError("invalid mode : {:}".format(mode))
selected_layers = []
for choice, xvalue in zip(choices, self.depth_info_list):
xtemp = xvalue[1]["choices"][choice] - xvalue[1]["xstart"] + 1
selected_layers.append(xtemp)
flop = 0
for i, layer in enumerate(self.layers):
s, e = self.layer2indexRange[i]
xchl = tuple(channels[s : e + 1])
if i in self.depth_at_i:
xstagei, xatti = self.depth_at_i[i]
if xatti <= choices[xstagei]: # leave this depth
flop += layer.get_flops(xchl)
else:
flop += 0 # do not use this layer
else:
flop += layer.get_flops(xchl)
# the last fc layer
flop += channels[-1] * self.classifier.out_features
if config_dict is None:
return flop / 1e6
else:
config_dict["xchannels"] = channels
config_dict["xblocks"] = selected_layers
config_dict["super_type"] = "infer-shape"
config_dict["estimated_FLOP"] = flop / 1e6
return flop / 1e6, config_dict
def get_arch_info(self):
string = (
"for depth and width, there are {:} + {:} attention probabilities.".format(
len(self.depth_attentions), len(self.width_attentions)
)
)
string += "\n{:}".format(self.depth_info)
discrepancy = []
with torch.no_grad():
for i, att in enumerate(self.depth_attentions):
prob = nn.functional.softmax(att, dim=0)
prob = prob.cpu()
selc = prob.argmax().item()
prob = prob.tolist()
prob = ["{:.3f}".format(x) for x in prob]
xstring = "{:03d}/{:03d}-th : {:}".format(
i, len(self.depth_attentions), " ".join(prob)
)
logt = ["{:.4f}".format(x) for x in att.cpu().tolist()]
xstring += " || {:17s}".format(" ".join(logt))
prob = sorted([float(x) for x in prob])
disc = prob[-1] - prob[-2]
xstring += " || discrepancy={:.2f} || select={:}/{:}".format(
disc, selc, len(prob)
)
discrepancy.append(disc)
string += "\n{:}".format(xstring)
string += "\n-----------------------------------------------"
for i, att in enumerate(self.width_attentions):
prob = nn.functional.softmax(att, dim=0)
prob = prob.cpu()
selc = prob.argmax().item()
prob = prob.tolist()
prob = ["{:.3f}".format(x) for x in prob]
xstring = "{:03d}/{:03d}-th : {:}".format(
i, len(self.width_attentions), " ".join(prob)
)
logt = ["{:.3f}".format(x) for x in att.cpu().tolist()]
xstring += " || {:52s}".format(" ".join(logt))
prob = sorted([float(x) for x in prob])
disc = prob[-1] - prob[-2]
xstring += " || dis={:.2f} || select={:}/{:}".format(
disc, selc, len(prob)
)
discrepancy.append(disc)
string += "\n{:}".format(xstring)
return string, discrepancy
def set_tau(self, tau_max, tau_min, epoch_ratio):
assert (
epoch_ratio >= 0 and epoch_ratio <= 1
), "invalid epoch-ratio : {:}".format(epoch_ratio)
tau = tau_min + (tau_max - tau_min) * (1 + math.cos(math.pi * epoch_ratio)) / 2
self.tau = tau
def get_message(self):
return self.message
def forward(self, inputs):
if self.search_mode == "basic":
return self.basic_forward(inputs)
elif self.search_mode == "search":
return self.search_forward(inputs)
else:
raise ValueError("invalid search_mode = {:}".format(self.search_mode))
def search_forward(self, inputs):
flop_width_probs = nn.functional.softmax(self.width_attentions, dim=1)
flop_depth_probs = nn.functional.softmax(self.depth_attentions, dim=1)
flop_depth_probs = torch.flip(
torch.cumsum(torch.flip(flop_depth_probs, [1]), 1), [1]
)
selected_widths, selected_width_probs = select2withP(
self.width_attentions, self.tau
)
selected_depth_probs = select2withP(self.depth_attentions, self.tau, True)
with torch.no_grad():
selected_widths = selected_widths.cpu()
x, last_channel_idx, expected_inC, flops = inputs, 0, 3, []
feature_maps = []
for i, layer in enumerate(self.layers):
selected_w_index = selected_widths[
last_channel_idx : last_channel_idx + layer.num_conv
]
selected_w_probs = selected_width_probs[
last_channel_idx : last_channel_idx + layer.num_conv
]
layer_prob = flop_width_probs[
last_channel_idx : last_channel_idx + layer.num_conv
]
x, expected_inC, expected_flop = layer(
(x, expected_inC, layer_prob, selected_w_index, selected_w_probs)
)
feature_maps.append(x)
last_channel_idx += layer.num_conv
if i in self.depth_info: # aggregate the information
choices = self.depth_info[i]["choices"]
xstagei = self.depth_info[i]["stage"]
# print ('iL={:}, choices={:}, stage={:}, probs={:}'.format(i, choices, xstagei, selected_depth_probs[xstagei].cpu().tolist()))
# for A, W in zip(choices, selected_depth_probs[xstagei]):
# print('Size = {:}, W = {:}'.format(feature_maps[A].size(), W))
possible_tensors = []
max_C = max(feature_maps[A].size(1) for A in choices)
for tempi, A in enumerate(choices):
xtensor = ChannelWiseInter(feature_maps[A], max_C)
possible_tensors.append(xtensor)
weighted_sum = sum(
xtensor * W
for xtensor, W in zip(
possible_tensors, selected_depth_probs[xstagei]
)
)
x = weighted_sum
if i in self.depth_at_i:
xstagei, xatti = self.depth_at_i[i]
x_expected_flop = flop_depth_probs[xstagei, xatti] * expected_flop
else:
x_expected_flop = expected_flop
flops.append(x_expected_flop)
flops.append(expected_inC * (self.classifier.out_features * 1.0 / 1e6))
features = self.avgpool(x)
features = features.view(features.size(0), -1)
logits = linear_forward(features, self.classifier)
return logits, torch.stack([sum(flops)])
def basic_forward(self, inputs):
if self.InShape is None:
self.InShape = (inputs.size(-2), inputs.size(-1))
x = inputs
for i, layer in enumerate(self.layers):
x = layer(x)
features = self.avgpool(x)
features = features.view(features.size(0), -1)
logits = self.classifier(features)
return features, logits

View File

@@ -0,0 +1,466 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
import math, torch
import torch.nn as nn
from ..initialization import initialize_resnet
from ..SharedUtils import additive_func
from .SoftSelect import select2withP, ChannelWiseInter
from .SoftSelect import linear_forward
from .SoftSelect import get_width_choices as get_choices
def conv_forward(inputs, conv, choices):
iC = conv.in_channels
fill_size = list(inputs.size())
fill_size[1] = iC - fill_size[1]
filled = torch.zeros(fill_size, device=inputs.device)
xinputs = torch.cat((inputs, filled), dim=1)
outputs = conv(xinputs)
selecteds = [outputs[:, :oC] for oC in choices]
return selecteds
class ConvBNReLU(nn.Module):
num_conv = 1
def __init__(
self, nIn, nOut, kernel, stride, padding, bias, has_avg, has_bn, has_relu
):
super(ConvBNReLU, self).__init__()
self.InShape = None
self.OutShape = None
self.choices = get_choices(nOut)
self.register_buffer("choices_tensor", torch.Tensor(self.choices))
if has_avg:
self.avg = nn.AvgPool2d(kernel_size=2, stride=2, padding=0)
else:
self.avg = None
self.conv = nn.Conv2d(
nIn,
nOut,
kernel_size=kernel,
stride=stride,
padding=padding,
dilation=1,
groups=1,
bias=bias,
)
# if has_bn : self.bn = nn.BatchNorm2d(nOut)
# else : self.bn = None
self.has_bn = has_bn
self.BNs = nn.ModuleList()
for i, _out in enumerate(self.choices):
self.BNs.append(nn.BatchNorm2d(_out))
if has_relu:
self.relu = nn.ReLU(inplace=True)
else:
self.relu = None
self.in_dim = nIn
self.out_dim = nOut
self.search_mode = "basic"
def get_flops(self, channels, check_range=True, divide=1):
iC, oC = channels
if check_range:
assert (
iC <= self.conv.in_channels and oC <= self.conv.out_channels
), "{:} vs {:} | {:} vs {:}".format(
iC, self.conv.in_channels, oC, self.conv.out_channels
)
assert (
isinstance(self.InShape, tuple) and len(self.InShape) == 2
), "invalid in-shape : {:}".format(self.InShape)
assert (
isinstance(self.OutShape, tuple) and len(self.OutShape) == 2
), "invalid out-shape : {:}".format(self.OutShape)
# conv_per_position_flops = self.conv.kernel_size[0] * self.conv.kernel_size[1] * iC * oC / self.conv.groups
conv_per_position_flops = (
self.conv.kernel_size[0] * self.conv.kernel_size[1] * 1.0 / self.conv.groups
)
all_positions = self.OutShape[0] * self.OutShape[1]
flops = (conv_per_position_flops * all_positions / divide) * iC * oC
if self.conv.bias is not None:
flops += all_positions / divide
return flops
def get_range(self):
return [self.choices]
def forward(self, inputs):
if self.search_mode == "basic":
return self.basic_forward(inputs)
elif self.search_mode == "search":
return self.search_forward(inputs)
else:
raise ValueError("invalid search_mode = {:}".format(self.search_mode))
def search_forward(self, tuple_inputs):
assert (
isinstance(tuple_inputs, tuple) and len(tuple_inputs) == 5
), "invalid type input : {:}".format(type(tuple_inputs))
inputs, expected_inC, probability, index, prob = tuple_inputs
index, prob = torch.squeeze(index).tolist(), torch.squeeze(prob)
probability = torch.squeeze(probability)
assert len(index) == 2, "invalid length : {:}".format(index)
# compute expected flop
# coordinates = torch.arange(self.x_range[0], self.x_range[1]+1).type_as(probability)
expected_outC = (self.choices_tensor * probability).sum()
expected_flop = self.get_flops([expected_inC, expected_outC], False, 1e6)
if self.avg:
out = self.avg(inputs)
else:
out = inputs
# convolutional layer
out_convs = conv_forward(out, self.conv, [self.choices[i] for i in index])
out_bns = [self.BNs[idx](out_conv) for idx, out_conv in zip(index, out_convs)]
# merge
out_channel = max([x.size(1) for x in out_bns])
outA = ChannelWiseInter(out_bns[0], out_channel)
outB = ChannelWiseInter(out_bns[1], out_channel)
out = outA * prob[0] + outB * prob[1]
# out = additive_func(out_bns[0]*prob[0], out_bns[1]*prob[1])
if self.relu:
out = self.relu(out)
else:
out = out
return out, expected_outC, expected_flop
def basic_forward(self, inputs):
if self.avg:
out = self.avg(inputs)
else:
out = inputs
conv = self.conv(out)
if self.has_bn:
out = self.BNs[-1](conv)
else:
out = conv
if self.relu:
out = self.relu(out)
else:
out = out
if self.InShape is None:
self.InShape = (inputs.size(-2), inputs.size(-1))
self.OutShape = (out.size(-2), out.size(-1))
return out
class SimBlock(nn.Module):
expansion = 1
num_conv = 1
def __init__(self, inplanes, planes, stride):
super(SimBlock, self).__init__()
assert stride == 1 or stride == 2, "invalid stride {:}".format(stride)
self.conv = ConvBNReLU(
inplanes,
planes,
3,
stride,
1,
False,
has_avg=False,
has_bn=True,
has_relu=True,
)
if stride == 2:
self.downsample = ConvBNReLU(
inplanes,
planes,
1,
1,
0,
False,
has_avg=True,
has_bn=False,
has_relu=False,
)
elif inplanes != planes:
self.downsample = ConvBNReLU(
inplanes,
planes,
1,
1,
0,
False,
has_avg=False,
has_bn=True,
has_relu=False,
)
else:
self.downsample = None
self.out_dim = planes
self.search_mode = "basic"
def get_range(self):
return self.conv.get_range()
def get_flops(self, channels):
assert len(channels) == 2, "invalid channels : {:}".format(channels)
flop_A = self.conv.get_flops([channels[0], channels[1]])
if hasattr(self.downsample, "get_flops"):
flop_C = self.downsample.get_flops([channels[0], channels[-1]])
else:
flop_C = 0
if (
channels[0] != channels[-1] and self.downsample is None
): # this short-cut will be added during the infer-train
flop_C = (
channels[0]
* channels[-1]
* self.conv.OutShape[0]
* self.conv.OutShape[1]
)
return flop_A + flop_C
def forward(self, inputs):
if self.search_mode == "basic":
return self.basic_forward(inputs)
elif self.search_mode == "search":
return self.search_forward(inputs)
else:
raise ValueError("invalid search_mode = {:}".format(self.search_mode))
def search_forward(self, tuple_inputs):
assert (
isinstance(tuple_inputs, tuple) and len(tuple_inputs) == 5
), "invalid type input : {:}".format(type(tuple_inputs))
inputs, expected_inC, probability, indexes, probs = tuple_inputs
assert (
indexes.size(0) == 1 and probs.size(0) == 1 and probability.size(0) == 1
), "invalid size : {:}, {:}, {:}".format(
indexes.size(), probs.size(), probability.size()
)
out, expected_next_inC, expected_flop = self.conv(
(inputs, expected_inC, probability[0], indexes[0], probs[0])
)
if self.downsample is not None:
residual, _, expected_flop_c = self.downsample(
(inputs, expected_inC, probability[-1], indexes[-1], probs[-1])
)
else:
residual, expected_flop_c = inputs, 0
out = additive_func(residual, out)
return (
nn.functional.relu(out, inplace=True),
expected_next_inC,
sum([expected_flop, expected_flop_c]),
)
def basic_forward(self, inputs):
basicblock = self.conv(inputs)
if self.downsample is not None:
residual = self.downsample(inputs)
else:
residual = inputs
out = additive_func(residual, basicblock)
return nn.functional.relu(out, inplace=True)
class SearchWidthSimResNet(nn.Module):
def __init__(self, depth, num_classes):
super(SearchWidthSimResNet, self).__init__()
assert (
depth - 2
) % 3 == 0, "depth should be one of 5, 8, 11, 14, ... instead of {:}".format(
depth
)
layer_blocks = (depth - 2) // 3
self.message = (
"SearchWidthSimResNet : Depth : {:} , Layers for each block : {:}".format(
depth, layer_blocks
)
)
self.num_classes = num_classes
self.channels = [16]
self.layers = nn.ModuleList(
[
ConvBNReLU(
3, 16, 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=True
)
]
)
self.InShape = None
for stage in range(3):
for iL in range(layer_blocks):
iC = self.channels[-1]
planes = 16 * (2 ** stage)
stride = 2 if stage > 0 and iL == 0 else 1
module = SimBlock(iC, planes, stride)
self.channels.append(module.out_dim)
self.layers.append(module)
self.message += "\nstage={:}, ilayer={:02d}/{:02d}, block={:03d}, iC={:3d}, oC={:3d}, stride={:}".format(
stage,
iL,
layer_blocks,
len(self.layers) - 1,
iC,
module.out_dim,
stride,
)
self.avgpool = nn.AvgPool2d(8)
self.classifier = nn.Linear(module.out_dim, num_classes)
self.InShape = None
self.tau = -1
self.search_mode = "basic"
# assert sum(x.num_conv for x in self.layers) + 1 == depth, 'invalid depth check {:} vs {:}'.format(sum(x.num_conv for x in self.layers)+1, depth)
# parameters for width
self.Ranges = []
self.layer2indexRange = []
for i, layer in enumerate(self.layers):
start_index = len(self.Ranges)
self.Ranges += layer.get_range()
self.layer2indexRange.append((start_index, len(self.Ranges)))
assert len(self.Ranges) + 1 == depth, "invalid depth check {:} vs {:}".format(
len(self.Ranges) + 1, depth
)
self.register_parameter(
"width_attentions",
nn.Parameter(torch.Tensor(len(self.Ranges), get_choices(None))),
)
nn.init.normal_(self.width_attentions, 0, 0.01)
self.apply(initialize_resnet)
def arch_parameters(self):
return [self.width_attentions]
def base_parameters(self):
return (
list(self.layers.parameters())
+ list(self.avgpool.parameters())
+ list(self.classifier.parameters())
)
def get_flop(self, mode, config_dict, extra_info):
if config_dict is not None:
config_dict = config_dict.copy()
# weights = [F.softmax(x, dim=0) for x in self.width_attentions]
channels = [3]
for i, weight in enumerate(self.width_attentions):
if mode == "genotype":
with torch.no_grad():
probe = nn.functional.softmax(weight, dim=0)
C = self.Ranges[i][torch.argmax(probe).item()]
elif mode == "max":
C = self.Ranges[i][-1]
elif mode == "fix":
C = int(math.sqrt(extra_info) * self.Ranges[i][-1])
elif mode == "random":
assert isinstance(extra_info, float), "invalid extra_info : {:}".format(
extra_info
)
with torch.no_grad():
prob = nn.functional.softmax(weight, dim=0)
approximate_C = int(math.sqrt(extra_info) * self.Ranges[i][-1])
for j in range(prob.size(0)):
prob[j] = 1 / (
abs(j - (approximate_C - self.Ranges[i][j])) + 0.2
)
C = self.Ranges[i][torch.multinomial(prob, 1, False).item()]
else:
raise ValueError("invalid mode : {:}".format(mode))
channels.append(C)
flop = 0
for i, layer in enumerate(self.layers):
s, e = self.layer2indexRange[i]
xchl = tuple(channels[s : e + 1])
flop += layer.get_flops(xchl)
# the last fc layer
flop += channels[-1] * self.classifier.out_features
if config_dict is None:
return flop / 1e6
else:
config_dict["xchannels"] = channels
config_dict["super_type"] = "infer-width"
config_dict["estimated_FLOP"] = flop / 1e6
return flop / 1e6, config_dict
def get_arch_info(self):
string = "for width, there are {:} attention probabilities.".format(
len(self.width_attentions)
)
discrepancy = []
with torch.no_grad():
for i, att in enumerate(self.width_attentions):
prob = nn.functional.softmax(att, dim=0)
prob = prob.cpu()
selc = prob.argmax().item()
prob = prob.tolist()
prob = ["{:.3f}".format(x) for x in prob]
xstring = "{:03d}/{:03d}-th : {:}".format(
i, len(self.width_attentions), " ".join(prob)
)
logt = ["{:.3f}".format(x) for x in att.cpu().tolist()]
xstring += " || {:52s}".format(" ".join(logt))
prob = sorted([float(x) for x in prob])
disc = prob[-1] - prob[-2]
xstring += " || dis={:.2f} || select={:}/{:}".format(
disc, selc, len(prob)
)
discrepancy.append(disc)
string += "\n{:}".format(xstring)
return string, discrepancy
def set_tau(self, tau_max, tau_min, epoch_ratio):
assert (
epoch_ratio >= 0 and epoch_ratio <= 1
), "invalid epoch-ratio : {:}".format(epoch_ratio)
tau = tau_min + (tau_max - tau_min) * (1 + math.cos(math.pi * epoch_ratio)) / 2
self.tau = tau
def get_message(self):
return self.message
def forward(self, inputs):
if self.search_mode == "basic":
return self.basic_forward(inputs)
elif self.search_mode == "search":
return self.search_forward(inputs)
else:
raise ValueError("invalid search_mode = {:}".format(self.search_mode))
def search_forward(self, inputs):
flop_probs = nn.functional.softmax(self.width_attentions, dim=1)
selected_widths, selected_probs = select2withP(self.width_attentions, self.tau)
with torch.no_grad():
selected_widths = selected_widths.cpu()
x, last_channel_idx, expected_inC, flops = inputs, 0, 3, []
for i, layer in enumerate(self.layers):
selected_w_index = selected_widths[
last_channel_idx : last_channel_idx + layer.num_conv
]
selected_w_probs = selected_probs[
last_channel_idx : last_channel_idx + layer.num_conv
]
layer_prob = flop_probs[
last_channel_idx : last_channel_idx + layer.num_conv
]
x, expected_inC, expected_flop = layer(
(x, expected_inC, layer_prob, selected_w_index, selected_w_probs)
)
last_channel_idx += layer.num_conv
flops.append(expected_flop)
flops.append(expected_inC * (self.classifier.out_features * 1.0 / 1e6))
features = self.avgpool(x)
features = features.view(features.size(0), -1)
logits = linear_forward(features, self.classifier)
return logits, torch.stack([sum(flops)])
def basic_forward(self, inputs):
if self.InShape is None:
self.InShape = (inputs.size(-2), inputs.size(-1))
x = inputs
for i, layer in enumerate(self.layers):
x = layer(x)
features = self.avgpool(x)
features = features.view(features.size(0), -1)
logits = self.classifier(features)
return features, logits

View File

@@ -0,0 +1,128 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
import math, torch
import torch.nn as nn
def select2withP(logits, tau, just_prob=False, num=2, eps=1e-7):
if tau <= 0:
new_logits = logits
probs = nn.functional.softmax(new_logits, dim=1)
else:
while True: # a trick to avoid the gumbels bug
gumbels = -torch.empty_like(logits).exponential_().log()
new_logits = (logits.log_softmax(dim=1) + gumbels) / tau
probs = nn.functional.softmax(new_logits, dim=1)
if (
(not torch.isinf(gumbels).any())
and (not torch.isinf(probs).any())
and (not torch.isnan(probs).any())
):
break
if just_prob:
return probs
# with torch.no_grad(): # add eps for unexpected torch error
# probs = nn.functional.softmax(new_logits, dim=1)
# selected_index = torch.multinomial(probs + eps, 2, False)
with torch.no_grad(): # add eps for unexpected torch error
probs = probs.cpu()
selected_index = torch.multinomial(probs + eps, num, False).to(logits.device)
selected_logit = torch.gather(new_logits, 1, selected_index)
selcted_probs = nn.functional.softmax(selected_logit, dim=1)
return selected_index, selcted_probs
def ChannelWiseInter(inputs, oC, mode="v2"):
if mode == "v1":
return ChannelWiseInterV1(inputs, oC)
elif mode == "v2":
return ChannelWiseInterV2(inputs, oC)
else:
raise ValueError("invalid mode : {:}".format(mode))
def ChannelWiseInterV1(inputs, oC):
assert inputs.dim() == 4, "invalid dimension : {:}".format(inputs.size())
def start_index(a, b, c):
return int(math.floor(float(a * c) / b))
def end_index(a, b, c):
return int(math.ceil(float((a + 1) * c) / b))
batch, iC, H, W = inputs.size()
outputs = torch.zeros((batch, oC, H, W), dtype=inputs.dtype, device=inputs.device)
if iC == oC:
return inputs
for ot in range(oC):
istartT, iendT = start_index(ot, oC, iC), end_index(ot, oC, iC)
values = inputs[:, istartT:iendT].mean(dim=1)
outputs[:, ot, :, :] = values
return outputs
def ChannelWiseInterV2(inputs, oC):
assert inputs.dim() == 4, "invalid dimension : {:}".format(inputs.size())
batch, C, H, W = inputs.size()
if C == oC:
return inputs
else:
return nn.functional.adaptive_avg_pool3d(inputs, (oC, H, W))
# inputs_5D = inputs.view(batch, 1, C, H, W)
# otputs_5D = nn.functional.interpolate(inputs_5D, (oC,H,W), None, 'area', None)
# otputs = otputs_5D.view(batch, oC, H, W)
# otputs_5D = nn.functional.interpolate(inputs_5D, (oC,H,W), None, 'trilinear', False)
# return otputs
def linear_forward(inputs, linear):
if linear is None:
return inputs
iC = inputs.size(1)
weight = linear.weight[:, :iC]
if linear.bias is None:
bias = None
else:
bias = linear.bias
return nn.functional.linear(inputs, weight, bias)
def get_width_choices(nOut):
xsrange = [0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]
if nOut is None:
return len(xsrange)
else:
Xs = [int(nOut * i) for i in xsrange]
# xs = [ int(nOut * i // 10) for i in range(2, 11)]
# Xs = [x for i, x in enumerate(xs) if i+1 == len(xs) or xs[i+1] > x+1]
Xs = sorted(list(set(Xs)))
return tuple(Xs)
def get_depth_choices(nDepth):
if nDepth is None:
return 3
else:
assert nDepth >= 3, "nDepth should be greater than 2 vs {:}".format(nDepth)
if nDepth == 1:
return (1, 1, 1)
elif nDepth == 2:
return (1, 1, 2)
elif nDepth >= 3:
return (nDepth // 3, nDepth * 2 // 3, nDepth)
else:
raise ValueError("invalid Depth : {:}".format(nDepth))
def drop_path(x, drop_prob):
if drop_prob > 0.0:
keep_prob = 1.0 - drop_prob
mask = x.new_zeros(x.size(0), 1, 1, 1)
mask = mask.bernoulli_(keep_prob)
x = x * (mask / keep_prob)
# x.div_(keep_prob)
# x.mul_(mask)
return x

View File

@@ -0,0 +1,9 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
from .SearchCifarResNet_width import SearchWidthCifarResNet
from .SearchCifarResNet_depth import SearchDepthCifarResNet
from .SearchCifarResNet import SearchShapeCifarResNet
from .SearchSimResNet_width import SearchWidthSimResNet
from .SearchImagenetResNet import SearchShapeImagenetResNet
from .generic_size_tiny_cell_model import GenericNAS301Model

View File

@@ -0,0 +1,209 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
#####################################################
# Here, we utilized three techniques to search for the number of channels:
# - channel-wise interpolation from "Network Pruning via Transformable Architecture Search, NeurIPS 2019"
# - masking + Gumbel-Softmax (mask_gumbel) from "FBNetV2: Differentiable Neural Architecture Search for Spatial and Channel Dimensions, CVPR 2020"
# - masking + sampling (mask_rl) from "Can Weight Sharing Outperform Random Architecture Search? An Investigation With TuNAS, CVPR 2020"
from typing import List, Text, Any
import random, torch
import torch.nn as nn
from ..cell_operations import ResNetBasicblock
from ..cell_infers.cells import InferCell
from .SoftSelect import select2withP, ChannelWiseInter
class GenericNAS301Model(nn.Module):
def __init__(
self,
candidate_Cs: List[int],
max_num_Cs: int,
genotype: Any,
num_classes: int,
affine: bool,
track_running_stats: bool,
):
super(GenericNAS301Model, self).__init__()
self._max_num_Cs = max_num_Cs
self._candidate_Cs = candidate_Cs
if max_num_Cs % 3 != 2:
raise ValueError("invalid number of layers : {:}".format(max_num_Cs))
self._num_stage = N = max_num_Cs // 3
self._max_C = max(candidate_Cs)
stem = nn.Sequential(
nn.Conv2d(3, self._max_C, kernel_size=3, padding=1, bias=not affine),
nn.BatchNorm2d(
self._max_C, affine=affine, track_running_stats=track_running_stats
),
)
layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N
c_prev = self._max_C
self._cells = nn.ModuleList()
self._cells.append(stem)
for index, reduction in enumerate(layer_reductions):
if reduction:
cell = ResNetBasicblock(c_prev, self._max_C, 2, True)
else:
cell = InferCell(
genotype, c_prev, self._max_C, 1, affine, track_running_stats
)
self._cells.append(cell)
c_prev = cell.out_dim
self._num_layer = len(self._cells)
self.lastact = nn.Sequential(
nn.BatchNorm2d(
c_prev, affine=affine, track_running_stats=track_running_stats
),
nn.ReLU(inplace=True),
)
self.global_pooling = nn.AdaptiveAvgPool2d(1)
self.classifier = nn.Linear(c_prev, num_classes)
# algorithm related
self.register_buffer("_tau", torch.zeros(1))
self._algo = None
self._warmup_ratio = None
def set_algo(self, algo: Text):
# used for searching
assert self._algo is None, "This functioin can only be called once."
assert algo in ["mask_gumbel", "mask_rl", "tas"], "invalid algo : {:}".format(
algo
)
self._algo = algo
self._arch_parameters = nn.Parameter(
1e-3 * torch.randn(self._max_num_Cs, len(self._candidate_Cs))
)
# if algo == 'mask_gumbel' or algo == 'mask_rl':
self.register_buffer(
"_masks", torch.zeros(len(self._candidate_Cs), max(self._candidate_Cs))
)
for i in range(len(self._candidate_Cs)):
self._masks.data[i, : self._candidate_Cs[i]] = 1
@property
def tau(self):
return self._tau
def set_tau(self, tau):
self._tau.data[:] = tau
@property
def warmup_ratio(self):
return self._warmup_ratio
def set_warmup_ratio(self, ratio: float):
self._warmup_ratio = ratio
@property
def weights(self):
xlist = list(self._cells.parameters())
xlist += list(self.lastact.parameters())
xlist += list(self.global_pooling.parameters())
xlist += list(self.classifier.parameters())
return xlist
@property
def alphas(self):
return [self._arch_parameters]
def show_alphas(self):
with torch.no_grad():
return "arch-parameters :\n{:}".format(
nn.functional.softmax(self._arch_parameters, dim=-1).cpu()
)
@property
def random(self):
cs = []
for i in range(self._max_num_Cs):
index = random.randint(0, len(self._candidate_Cs) - 1)
cs.append(str(self._candidate_Cs[index]))
return ":".join(cs)
@property
def genotype(self):
cs = []
for i in range(self._max_num_Cs):
with torch.no_grad():
index = self._arch_parameters[i].argmax().item()
cs.append(str(self._candidate_Cs[index]))
return ":".join(cs)
def get_message(self) -> Text:
string = self.extra_repr()
for i, cell in enumerate(self._cells):
string += "\n {:02d}/{:02d} :: {:}".format(
i, len(self._cells), cell.extra_repr()
)
return string
def extra_repr(self):
return "{name}(candidates={_candidate_Cs}, num={_max_num_Cs}, N={_num_stage}, L={_num_layer})".format(
name=self.__class__.__name__, **self.__dict__
)
def forward(self, inputs):
feature = inputs
log_probs = []
for i, cell in enumerate(self._cells):
feature = cell(feature)
# apply different searching algorithms
idx = max(0, i - 1)
if self._warmup_ratio is not None:
if random.random() < self._warmup_ratio:
mask = self._masks[-1]
else:
mask = self._masks[random.randint(0, len(self._masks) - 1)]
feature = feature * mask.view(1, -1, 1, 1)
elif self._algo == "mask_gumbel":
weights = nn.functional.gumbel_softmax(
self._arch_parameters[idx : idx + 1], tau=self.tau, dim=-1
)
mask = torch.matmul(weights, self._masks).view(1, -1, 1, 1)
feature = feature * mask
elif self._algo == "tas":
selected_cs, selected_probs = select2withP(
self._arch_parameters[idx : idx + 1], self.tau, num=2
)
with torch.no_grad():
i1, i2 = selected_cs.cpu().view(-1).tolist()
c1, c2 = self._candidate_Cs[i1], self._candidate_Cs[i2]
out_channel = max(c1, c2)
out1 = ChannelWiseInter(feature[:, :c1], out_channel)
out2 = ChannelWiseInter(feature[:, :c2], out_channel)
out = out1 * selected_probs[0, 0] + out2 * selected_probs[0, 1]
if feature.shape[1] == out.shape[1]:
feature = out
else:
miss = torch.zeros(
feature.shape[0],
feature.shape[1] - out.shape[1],
feature.shape[2],
feature.shape[3],
device=feature.device,
)
feature = torch.cat((out, miss), dim=1)
elif self._algo == "mask_rl":
prob = nn.functional.softmax(
self._arch_parameters[idx : idx + 1], dim=-1
)
dist = torch.distributions.Categorical(prob)
action = dist.sample()
log_probs.append(dist.log_prob(action))
mask = self._masks[action.item()].view(1, -1, 1, 1)
feature = feature * mask
else:
raise ValueError("invalid algorithm : {:}".format(self._algo))
out = self.lastact(feature)
out = self.global_pooling(out)
out = out.view(out.size(0), -1)
logits = self.classifier(out)
return out, logits, log_probs

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.