Initial commit

jack-willturner 2020-06-03 12:59:01 +01:00
commit 357e877e8d
68 changed files with 7189 additions and 0 deletions

2
.gitignore vendored Normal file

@ -0,0 +1,2 @@
*.pth
__pycache__

14
README.md Normal file

@ -0,0 +1,14 @@
# Neural Architecture Search Without Training
**IMPORTANT**: our codebase relies on the NAS-Bench-201 dataset, so it includes code cloned from [this repository](https://github.com/D-X-Y/AutoDL-Projects). We have kept the copyright notices in the cloned code, which name the author of the open-source library that our code relies on.
The datasets can be downloaded by following the instructions in the NAS-Bench-201 README: [https://github.com/D-X-Y/NAS-Bench-201](https://github.com/D-X-Y/NAS-Bench-201).
To reproduce our results exactly:
```
conda env create -f environment.yml
conda activate nas-wot
./reproduce.sh
```
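For context, the experiments query NAS-Bench-201 through the `nas-bench-201` pip package pinned in `environment.yml`. A minimal sketch of loading the benchmark once it has been downloaded (this snippet is not part of the repository, and the `.pth` file name is illustrative):
```python
# Hedged sketch: load the NAS-Bench-201 benchmark and inspect the search space.
from nas_201_api import NASBench201API

api = NASBench201API('NAS-Bench-201-v1_0-e61699.pth')  # path to the downloaded benchmark file
print(len(api))   # number of architectures in the search space (15625)
print(api[0])     # architecture string of the first cell
```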

13
config_utils/__init__.py Normal file

@ -0,0 +1,13 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
from .configure_utils import load_config, dict2config, configure2str
from .basic_args import obtain_basic_args
from .attention_args import obtain_attention_args
from .random_baseline import obtain_RandomSearch_args
from .cls_kd_args import obtain_cls_kd_args
from .cls_init_args import obtain_cls_init_args
from .search_single_args import obtain_search_single_args
from .search_args import obtain_search_args
# for network pruning
from .pruning_args import obtain_pruning_args

22
config_utils/attention_args.py Normal file

@ -0,0 +1,22 @@
import random, argparse
from .share_args import add_shared_args
def obtain_attention_args():
parser = argparse.ArgumentParser(description='Train a classification model on typical image classification datasets.', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--resume' , type=str, help='Resume path.')
parser.add_argument('--init_model' , type=str, help='The initialization model path.')
parser.add_argument('--model_config', type=str, help='The path to the model configuration')
parser.add_argument('--optim_config', type=str, help='The path to the optimizer configuration')
parser.add_argument('--procedure' , type=str, help='The procedure basic prefix.')
parser.add_argument('--att_channel' , type=int, help='The attention channel setting.')
parser.add_argument('--att_spatial' , type=str, help='The attention spatial setting.')
parser.add_argument('--att_active'  , type=str, help='The attention activation setting.')
add_shared_args( parser )
# Optimization options
parser.add_argument('--batch_size', type=int, default=2, help='Batch size for training.')
args = parser.parse_args()
if args.rand_seed is None or args.rand_seed < 0:
args.rand_seed = random.randint(1, 100000)
assert args.save_dir is not None, 'save-path argument can not be None'
return args

24
config_utils/basic_args.py Normal file

@ -0,0 +1,24 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020 #
##################################################
import random, argparse
from .share_args import add_shared_args
def obtain_basic_args():
parser = argparse.ArgumentParser(description='Train a classification model on typical image classification datasets.', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--resume' , type=str, help='Resume path.')
parser.add_argument('--init_model' , type=str, help='The initialization model path.')
parser.add_argument('--model_config', type=str, help='The path to the model configuration')
parser.add_argument('--optim_config', type=str, help='The path to the optimizer configuration')
parser.add_argument('--procedure' , type=str, help='The procedure basic prefix.')
parser.add_argument('--model_source', type=str, default='normal', help='The source of the model definition.')
parser.add_argument('--extra_model_path', type=str, default=None, help='The extra model ckp file (help to indicate the searched architecture).')
add_shared_args( parser )
# Optimization options
parser.add_argument('--batch_size', type=int, default=2, help='Batch size for training.')
args = parser.parse_args()
if args.rand_seed is None or args.rand_seed < 0:
args.rand_seed = random.randint(1, 100000)
assert args.save_dir is not None, 'save-path argument can not be None'
return args

File diff suppressed because one or more lines are too long

20
config_utils/cls_init_args.py Normal file

@ -0,0 +1,20 @@
import random, argparse
from .share_args import add_shared_args
def obtain_cls_init_args():
parser = argparse.ArgumentParser(description='Train a classification model on typical image classification datasets.', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--resume' , type=str, help='Resume path.')
parser.add_argument('--init_model' , type=str, help='The initialization model path.')
parser.add_argument('--model_config', type=str, help='The path to the model configuration')
parser.add_argument('--optim_config', type=str, help='The path to the optimizer configuration')
parser.add_argument('--procedure' , type=str, help='The procedure basic prefix.')
parser.add_argument('--init_checkpoint', type=str, help='The checkpoint path to the initial model.')
add_shared_args( parser )
# Optimization options
parser.add_argument('--batch_size', type=int, default=2, help='Batch size for training.')
args = parser.parse_args()
if args.rand_seed is None or args.rand_seed < 0:
args.rand_seed = random.randint(1, 100000)
assert args.save_dir is not None, 'save-path argument can not be None'
return args

23
config_utils/cls_kd_args.py Normal file

@ -0,0 +1,23 @@
import random, argparse
from .share_args import add_shared_args
def obtain_cls_kd_args():
parser = argparse.ArgumentParser(description='Train a classification model on typical image classification datasets.', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--resume' , type=str, help='Resume path.')
parser.add_argument('--init_model' , type=str, help='The initialization model path.')
parser.add_argument('--model_config', type=str, help='The path to the model configuration')
parser.add_argument('--optim_config', type=str, help='The path to the optimizer configuration')
parser.add_argument('--procedure' , type=str, help='The procedure basic prefix.')
parser.add_argument('--KD_checkpoint', type=str, help='The teacher checkpoint in knowledge distillation.')
parser.add_argument('--KD_alpha' , type=float, help='The alpha parameter in knowledge distillation.')
parser.add_argument('--KD_temperature', type=float, help='The temperature parameter in knowledge distillation.')
#parser.add_argument('--KD_feature', type=float, help='Knowledge distillation at the feature level.')
add_shared_args( parser )
# Optimization options
parser.add_argument('--batch_size', type=int, default=2, help='Batch size for training.')
args = parser.parse_args()
if args.rand_seed is None or args.rand_seed < 0:
args.rand_seed = random.randint(1, 100000)
assert args.save_dir is not None, 'save-path argument can not be None'
return args

106
config_utils/configure_utils.py Normal file

@ -0,0 +1,106 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
import os, json
from os import path as osp
from pathlib import Path
from collections import namedtuple
support_types = ('str', 'int', 'bool', 'float', 'none')
def convert_param(original_lists):
assert isinstance(original_lists, list), 'The type is not right : {:}'.format(original_lists)
ctype, value = original_lists[0], original_lists[1]
assert ctype in support_types, 'Ctype={:}, support={:}'.format(ctype, support_types)
is_list = isinstance(value, list)
if not is_list: value = [value]
outs = []
for x in value:
if ctype == 'int':
x = int(x)
elif ctype == 'str':
x = str(x)
elif ctype == 'bool':
x = bool(int(x))
elif ctype == 'float':
x = float(x)
elif ctype == 'none':
if x.lower() != 'none':
raise ValueError('For the none type, the value must be none instead of {:}'.format(x))
x = None
else:
raise TypeError('Does not know this type : {:}'.format(ctype))
outs.append(x)
if not is_list: outs = outs[0]
return outs
def load_config(path, extra, logger):
path = str(path)
if hasattr(logger, 'log'): logger.log(path)
assert os.path.exists(path), 'Can not find {:}'.format(path)
# Reading data back
with open(path, 'r') as f:
data = json.load(f)
content = { k: convert_param(v) for k,v in data.items()}
assert extra is None or isinstance(extra, dict), 'invalid type of extra : {:}'.format(extra)
if isinstance(extra, dict): content = {**content, **extra}
Arguments = namedtuple('Configure', ' '.join(content.keys()))
content = Arguments(**content)
if hasattr(logger, 'log'): logger.log('{:}'.format(content))
return content
def configure2str(config, xpath=None):
if not isinstance(config, dict):
config = config._asdict()
def cstring(x):
return "\"{:}\"".format(x)
def gtype(x):
if isinstance(x, list): x = x[0]
if isinstance(x, str) : return 'str'
elif isinstance(x, bool) : return 'bool'
elif isinstance(x, int): return 'int'
elif isinstance(x, float): return 'float'
elif x is None : return 'none'
else: raise ValueError('invalid : {:}'.format(x))
def cvalue(x, xtype):
if isinstance(x, list): is_list = True
else:
is_list, x = False, [x]
temps = []
for temp in x:
if xtype == 'bool' : temp = cstring(int(temp))
elif xtype == 'none': temp = cstring('None')
else : temp = cstring(temp)
temps.append( temp )
if is_list:
return "[{:}]".format( ', '.join( temps ) )
else:
return temps[0]
xstrings = []
for key, value in config.items():
xtype = gtype(value)
string = ' {:20s} : [{:8s}, {:}]'.format(cstring(key), cstring(xtype), cvalue(value, xtype))
xstrings.append(string)
Fstring = '{\n' + ',\n'.join(xstrings) + '\n}'
if xpath is not None:
parent = Path(xpath).resolve().parent
parent.mkdir(parents=True, exist_ok=True)
if osp.isfile(xpath): os.remove(xpath)
with open(xpath, "w") as text_file:
text_file.write('{:}'.format(Fstring))
return Fstring
def dict2config(xdict, logger):
assert isinstance(xdict, dict), 'invalid type : {:}'.format( type(xdict) )
Arguments = namedtuple('Configure', ' '.join(xdict.keys()))
content = Arguments(**xdict)
if hasattr(logger, 'log'): logger.log('{:}'.format(content))
return content
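The JSON files consumed by `load_config` store every field as a `[type, value]` pair (the value may itself be a list), with the type drawn from `support_types`. A hedged sketch of that layout and a round trip through `load_config`; the field names below are made up for illustration, and the real configs live elsewhere in the repository:
```python
# Hedged sketch of the [type, value] config layout expected by convert_param/load_config.
import json, tempfile
from config_utils import load_config   # assumes the repository root is on sys.path

example = {
    "scheduler": ["str"  , "cos"],
    "epochs"   : ["int"  , 200],
    "LR"       : ["float", 0.1],
    "gammas"   : ["float", [0.1, 0.1]],   # list values are converted element-wise
    "nesterov" : ["bool" , 1],            # bools are encoded as 0/1
}
with tempfile.NamedTemporaryFile('w', suffix='.json', delete=False) as f:
    json.dump(example, f)

config = load_config(f.name, None, None)  # logger=None is fine: load_config only calls logger.log when present
print(config.epochs, config.LR, config.gammas)  # 200 0.1 [0.1, 0.1]
```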

26
config_utils/pruning_args.py Normal file

@ -0,0 +1,26 @@
import os, sys, time, random, argparse
from .share_args import add_shared_args
def obtain_pruning_args():
parser = argparse.ArgumentParser(description='Train a classification model on typical image classification datasets.', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--resume' , type=str, help='Resume path.')
parser.add_argument('--init_model' , type=str, help='The initialization model path.')
parser.add_argument('--model_config', type=str, help='The path to the model configuration')
parser.add_argument('--optim_config', type=str, help='The path to the optimizer configuration')
parser.add_argument('--procedure' , type=str, help='The procedure basic prefix.')
parser.add_argument('--keep_ratio' , type=float, help='The left channel ratio compared to the original network.')
parser.add_argument('--model_version', type=str, help='The network version.')
parser.add_argument('--KD_alpha' , type=float, help='The alpha parameter in knowledge distillation.')
parser.add_argument('--KD_temperature', type=float, help='The temperature parameter in knowledge distillation.')
parser.add_argument('--Regular_W_feat', type=float, help='The regularization weight on features.')
parser.add_argument('--Regular_W_conv', type=float, help='The regularization weight on convolutions.')
add_shared_args( parser )
# Optimization options
parser.add_argument('--batch_size', type=int, default=2, help='Batch size for training.')
args = parser.parse_args()
if args.rand_seed is None or args.rand_seed < 0:
args.rand_seed = random.randint(1, 100000)
assert args.save_dir is not None, 'save-path argument can not be None'
assert args.keep_ratio > 0 and args.keep_ratio <= 1, 'invalid keep ratio : {:}'.format(args.keep_ratio)
return args

24
config_utils/random_baseline.py Normal file

@ -0,0 +1,24 @@
import os, sys, time, random, argparse
from .share_args import add_shared_args
def obtain_RandomSearch_args():
parser = argparse.ArgumentParser(description='Train a classification model on typical image classification datasets.', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--resume' , type=str, help='Resume path.')
parser.add_argument('--init_model' , type=str, help='The initialization model path.')
parser.add_argument('--expect_flop', type=float, help='The expected flop keep ratio.')
parser.add_argument('--arch_nums'  , type=int, help='The maximum number of randomly generated architectures.')
parser.add_argument('--model_config', type=str, help='The path to the model configuration')
parser.add_argument('--optim_config', type=str, help='The path to the optimizer configuration')
parser.add_argument('--random_mode', type=str, choices=['random', 'fix'], help='The random sampling mode (random or fix).')
parser.add_argument('--procedure' , type=str, help='The procedure basic prefix.')
add_shared_args( parser )
# Optimization options
parser.add_argument('--batch_size', type=int, default=2, help='Batch size for training.')
args = parser.parse_args()
if args.rand_seed is None or args.rand_seed < 0:
args.rand_seed = random.randint(1, 100000)
assert args.save_dir is not None, 'save-path argument can not be None'
#assert args.flop_ratio_min < args.flop_ratio_max, 'flop-ratio {:} vs {:}'.format(args.flop_ratio_min, args.flop_ratio_max)
return args

32
config_utils/search_args.py Normal file

@ -0,0 +1,32 @@
import os, sys, time, random, argparse
from .share_args import add_shared_args
def obtain_search_args():
parser = argparse.ArgumentParser(description='Train a classification model on typical image classification datasets.', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--resume' , type=str, help='Resume path.')
parser.add_argument('--model_config' , type=str, help='The path to the model configuration')
parser.add_argument('--optim_config' , type=str, help='The path to the optimizer configuration')
parser.add_argument('--split_path' , type=str, help='The split file path.')
#parser.add_argument('--arch_para_pure', type=int, help='The architecture-parameter pure or not.')
parser.add_argument('--gumbel_tau_max', type=float, help='The maximum tau for Gumbel.')
parser.add_argument('--gumbel_tau_min', type=float, help='The minimum tau for Gumbel.')
parser.add_argument('--procedure' , type=str, help='The procedure basic prefix.')
parser.add_argument('--FLOP_ratio' , type=float, help='The expected FLOP ratio.')
parser.add_argument('--FLOP_weight' , type=float, help='The loss weight for FLOP.')
parser.add_argument('--FLOP_tolerant' , type=float, help='The tolerant range for FLOP.')
# ablation studies
parser.add_argument('--ablation_num_select', type=int, help='The number of randomly selected channels.')
add_shared_args( parser )
# Optimization options
parser.add_argument('--batch_size' , type=int, default=2, help='Batch size for training.')
args = parser.parse_args()
if args.rand_seed is None or args.rand_seed < 0:
args.rand_seed = random.randint(1, 100000)
assert args.save_dir is not None, 'save-path argument can not be None'
assert args.gumbel_tau_max is not None and args.gumbel_tau_min is not None
assert args.FLOP_tolerant is not None and args.FLOP_tolerant > 0, 'invalid FLOP_tolerant : {:}'.format(args.FLOP_tolerant)
#assert args.arch_para_pure is not None, 'arch_para_pure is not None: {:}'.format(args.arch_para_pure)
#args.arch_para_pure = bool(args.arch_para_pure)
return args

31
config_utils/search_single_args.py Normal file

@ -0,0 +1,31 @@
import os, sys, time, random, argparse
from .share_args import add_shared_args
def obtain_search_single_args():
parser = argparse.ArgumentParser(description='Train a classification model on typical image classification datasets.', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--resume' , type=str, help='Resume path.')
parser.add_argument('--model_config' , type=str, help='The path to the model configuration')
parser.add_argument('--optim_config' , type=str, help='The path to the optimizer configuration')
parser.add_argument('--split_path' , type=str, help='The split file path.')
parser.add_argument('--search_shape' , type=str, help='The shape to be searched.')
#parser.add_argument('--arch_para_pure', type=int, help='The architecture-parameter pure or not.')
parser.add_argument('--gumbel_tau_max', type=float, help='The maximum tau for Gumbel.')
parser.add_argument('--gumbel_tau_min', type=float, help='The minimum tau for Gumbel.')
parser.add_argument('--procedure' , type=str, help='The procedure basic prefix.')
parser.add_argument('--FLOP_ratio' , type=float, help='The expected FLOP ratio.')
parser.add_argument('--FLOP_weight' , type=float, help='The loss weight for FLOP.')
parser.add_argument('--FLOP_tolerant' , type=float, help='The tolerant range for FLOP.')
add_shared_args( parser )
# Optimization options
parser.add_argument('--batch_size' , type=int, default=2, help='Batch size for training.')
args = parser.parse_args()
if args.rand_seed is None or args.rand_seed < 0:
args.rand_seed = random.randint(1, 100000)
assert args.save_dir is not None, 'save-path argument can not be None'
assert args.gumbel_tau_max is not None and args.gumbel_tau_min is not None
assert args.FLOP_tolerant is not None and args.FLOP_tolerant > 0, 'invalid FLOP_tolerant : {:}'.format(args.FLOP_tolerant)
#assert args.arch_para_pure is not None, 'arch_para_pure is not None: {:}'.format(args.arch_para_pure)
#args.arch_para_pure = bool(args.arch_para_pure)
return args

17
config_utils/share_args.py Normal file

@ -0,0 +1,17 @@
import os, sys, time, random, argparse
def add_shared_args( parser ):
# Data Generation
parser.add_argument('--dataset', type=str, help='The dataset name.')
parser.add_argument('--data_path', type=str, help='The path to the dataset.')
parser.add_argument('--cutout_length', type=int, help='The cutout length, negative means not use.')
# Printing
parser.add_argument('--print_freq', type=int, default=100, help='print frequency (default: 100)')
parser.add_argument('--print_freq_eval', type=int, default=100, help='print frequency (default: 100)')
# Checkpoints
parser.add_argument('--eval_frequency', type=int, default=1, help='evaluation frequency (default: 1)')
parser.add_argument('--save_dir', type=str, help='Folder to save checkpoints and log.')
# Acceleration
parser.add_argument('--workers', type=int, default=8, help='number of data loading workers (default: 8)')
# Random Seed
parser.add_argument('--rand_seed', type=int, default=-1, help='manual seed')

129
datasets/DownsampledImageNet.py Normal file

@ -0,0 +1,129 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
import os, sys, hashlib, torch
import numpy as np
from PIL import Image
import torch.utils.data as data
if sys.version_info[0] == 2:
import cPickle as pickle
else:
import pickle
def calculate_md5(fpath, chunk_size=1024 * 1024):
md5 = hashlib.md5()
with open(fpath, 'rb') as f:
for chunk in iter(lambda: f.read(chunk_size), b''):
md5.update(chunk)
return md5.hexdigest()
def check_md5(fpath, md5, **kwargs):
return md5 == calculate_md5(fpath, **kwargs)
def check_integrity(fpath, md5=None):
if not os.path.isfile(fpath): return False
if md5 is None: return True
else : return check_md5(fpath, md5)
class ImageNet16(data.Dataset):
# http://image-net.org/download-images
# A Downsampled Variant of ImageNet as an Alternative to the CIFAR datasets
# https://arxiv.org/pdf/1707.08819.pdf
train_list = [
['train_data_batch_1', '27846dcaa50de8e21a7d1a35f30f0e91'],
['train_data_batch_2', 'c7254a054e0e795c69120a5727050e3f'],
['train_data_batch_3', '4333d3df2e5ffb114b05d2ffc19b1e87'],
['train_data_batch_4', '1620cdf193304f4a92677b695d70d10f'],
['train_data_batch_5', '348b3c2fdbb3940c4e9e834affd3b18d'],
['train_data_batch_6', '6e765307c242a1b3d7d5ef9139b48945'],
['train_data_batch_7', '564926d8cbf8fc4818ba23d2faac7564'],
['train_data_batch_8', 'f4755871f718ccb653440b9dd0ebac66'],
['train_data_batch_9', 'bb6dd660c38c58552125b1a92f86b5d4'],
['train_data_batch_10','8f03f34ac4b42271a294f91bf480f29b'],
]
valid_list = [
['val_data', '3410e3017fdaefba8d5073aaa65e4bd6'],
]
def __init__(self, root, train, transform, use_num_of_class_only=None):
self.root = root
self.transform = transform
self.train = train # training set or valid set
if not self._check_integrity(): raise RuntimeError('Dataset not found or corrupted.')
if self.train: downloaded_list = self.train_list
else : downloaded_list = self.valid_list
self.data = []
self.targets = []
# now load the picked numpy arrays
for i, (file_name, checksum) in enumerate(downloaded_list):
file_path = os.path.join(self.root, file_name)
#print ('Load {:}/{:02d}-th : {:}'.format(i, len(downloaded_list), file_path))
with open(file_path, 'rb') as f:
if sys.version_info[0] == 2:
entry = pickle.load(f)
else:
entry = pickle.load(f, encoding='latin1')
self.data.append(entry['data'])
self.targets.extend(entry['labels'])
self.data = np.vstack(self.data).reshape(-1, 3, 16, 16)
self.data = self.data.transpose((0, 2, 3, 1)) # convert to HWC
if use_num_of_class_only is not None:
assert isinstance(use_num_of_class_only, int) and use_num_of_class_only > 0 and use_num_of_class_only < 1000, 'invalid use_num_of_class_only : {:}'.format(use_num_of_class_only)
new_data, new_targets = [], []
for I, L in zip(self.data, self.targets):
if 1 <= L <= use_num_of_class_only:
new_data.append( I )
new_targets.append( L )
self.data = new_data
self.targets = new_targets
# self.mean.append(entry['mean'])
#self.mean = np.vstack(self.mean).reshape(-1, 3, 16, 16)
#self.mean = np.mean(np.mean(np.mean(self.mean, axis=0), axis=1), axis=1)
#print ('Mean : {:}'.format(self.mean))
#temp = self.data - np.reshape(self.mean, (1, 1, 1, 3))
#std_data = np.std(temp, axis=0)
#std_data = np.mean(np.mean(std_data, axis=0), axis=0)
#print ('Std : {:}'.format(std_data))
def __getitem__(self, index):
img, target = self.data[index], self.targets[index] - 1
img = Image.fromarray(img)
if self.transform is not None:
img = self.transform(img)
return img, target
def __len__(self):
return len(self.data)
def _check_integrity(self):
root = self.root
for fentry in (self.train_list + self.valid_list):
filename, md5 = fentry[0], fentry[1]
fpath = os.path.join(root, filename)
if not check_integrity(fpath, md5):
return False
return True
#
if __name__ == '__main__':
train = ImageNet16('/data02/dongxuanyi/.torch/cifar.python/ImageNet16', True , None)
valid = ImageNet16('/data02/dongxuanyi/.torch/cifar.python/ImageNet16', False, None)
print ( len(train) )
print ( len(valid) )
image, label = train[111]
trainX = ImageNet16('/data02/dongxuanyi/.torch/cifar.python/ImageNet16', True , None, 200)
validX = ImageNet16('/data02/dongxuanyi/.torch/cifar.python/ImageNet16', False , None, 200)
print ( len(trainX) )
print ( len(validX) )
#import pdb; pdb.set_trace()

191
datasets/LandmarkDataset.py Normal file

@ -0,0 +1,191 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
from os import path as osp
from copy import deepcopy as copy
from tqdm import tqdm
import warnings, time, random, numpy as np
from pts_utils import generate_label_map
from xvision import denormalize_points
from xvision import identity2affine, solve2theta, affine2image
from .dataset_utils import pil_loader
from .landmark_utils import PointMeta2V, apply_affine2point, apply_boundary
from .augmentation_utils import CutOut
import torch
import torch.utils.data as data
class LandmarkDataset(data.Dataset):
def __init__(self, transform, sigma, downsample, heatmap_type, shape, use_gray, mean_file, data_indicator, cache_images=None):
self.transform = transform
self.sigma = sigma
self.downsample = downsample
self.heatmap_type = heatmap_type
self.dataset_name = data_indicator
self.shape = shape # [H,W]
self.use_gray = use_gray
assert transform is not None, 'transform : {:}'.format(transform)
self.mean_file = mean_file
if mean_file is None:
self.mean_data = None
warnings.warn('LandmarkDataset initialized with mean_data = None')
else:
assert osp.isfile(mean_file), '{:} is not a file.'.format(mean_file)
self.mean_data = torch.load(mean_file)
self.reset()
self.cutout = None
self.cache_images = cache_images
print ('The general dataset initialization done : {:}'.format(self))
warnings.simplefilter( 'once' )
def __repr__(self):
return ('{name}(point-num={NUM_PTS}, shape={shape}, sigma={sigma}, heatmap_type={heatmap_type}, length={length}, cutout={cutout}, dataset={dataset_name}, mean={mean_file})'.format(name=self.__class__.__name__, **self.__dict__))
def set_cutout(self, length):
if length is not None and length >= 1:
self.cutout = CutOut( int(length) )
else: self.cutout = None
def reset(self, num_pts=-1, boxid='default', only_pts=False):
self.NUM_PTS = num_pts
if only_pts: return
self.length = 0
self.datas = []
self.labels = []
self.NormDistances = []
self.BOXID = boxid
if self.mean_data is None:
self.mean_face = None
else:
self.mean_face = torch.Tensor(self.mean_data[boxid].copy().T)
assert (self.mean_face >= -1).all() and (self.mean_face <= 1).all(), 'mean-{:}-face : {:}'.format(boxid, self.mean_face)
#assert self.dataset_name is not None, 'The dataset name is None'
def __len__(self):
assert len(self.datas) == self.length, 'The length is not correct : {}'.format(self.length)
return self.length
def append(self, data, label, distance):
assert osp.isfile(data), 'The image path is not a file : {:}'.format(data)
self.datas.append( data ) ; self.labels.append( label )
self.NormDistances.append( distance )
self.length = self.length + 1
def load_list(self, file_lists, num_pts, boxindicator, normalizeL, reset):
if reset: self.reset(num_pts, boxindicator)
else : assert self.NUM_PTS == num_pts and self.BOXID == boxindicator, 'The number of points is inconsistent : {:} vs {:}'.format(self.NUM_PTS, num_pts)
if isinstance(file_lists, str): file_lists = [file_lists]
samples = []
for idx, file_path in enumerate(file_lists):
print (':::: load list {:}/{:} : {:}'.format(idx, len(file_lists), file_path))
xdata = torch.load(file_path)
if isinstance(xdata, list) : data = xdata # image or video dataset list
elif isinstance(xdata, dict): data = xdata['datas'] # multi-view dataset list
else: raise ValueError('Invalid Type Error : {:}'.format( type(xdata) ))
samples = samples + data
# samples is a dict, where the key is the image-path and the value is the annotation
# each annotation is a dict, contains 'points' (3,num_pts), and various box
print ('GeneralDataset-V2 : {:} samples'.format(len(samples)))
#for index, annotation in enumerate(samples):
for index in tqdm( range( len(samples) ) ):
annotation = samples[index]
image_path = annotation['current_frame']
points, box = annotation['points'], annotation['box-{:}'.format(boxindicator)]
label = PointMeta2V(self.NUM_PTS, points, box, image_path, self.dataset_name)
if normalizeL is None: normDistance = None
else : normDistance = annotation['normalizeL-{:}'.format(normalizeL)]
self.append(image_path, label, normDistance)
assert len(self.datas) == self.length, 'The length and the data is not right {} vs {}'.format(self.length, len(self.datas))
assert len(self.labels) == self.length, 'The length and the labels is not right {} vs {}'.format(self.length, len(self.labels))
assert len(self.NormDistances) == self.length, 'The length and the NormDistances are not consistent {} vs {}'.format(self.length, len(self.NormDistances))
print ('Load data done for LandmarkDataset, which has {:} images.'.format(self.length))
def __getitem__(self, index):
assert index >= 0 and index < self.length, 'Invalid index : {:}'.format(index)
if self.cache_images is not None and self.datas[index] in self.cache_images:
image = self.cache_images[ self.datas[index] ].clone()
else:
image = pil_loader(self.datas[index], self.use_gray)
target = self.labels[index].copy()
return self._process_(image, target, index)
def _process_(self, image, target, index):
# transform the image and points
image, target, theta = self.transform(image, target)
(C, H, W), (height, width) = image.size(), self.shape
# obtain the visible indicator vector
if target.is_none(): nopoints = True
else : nopoints = False
if index == -1: __path = None
else : __path = self.datas[index]
if isinstance(theta, list) or isinstance(theta, tuple):
affineImage, heatmaps, mask, norm_trans_points, THETA, transpose_theta = [], [], [], [], [], []
for _theta in theta:
_affineImage, _heatmaps, _mask, _norm_trans_points, _theta, _transpose_theta \
= self.__process_affine(image, target, _theta, nopoints, 'P[{:}]@{:}'.format(index, __path))
affineImage.append(_affineImage)
heatmaps.append(_heatmaps)
mask.append(_mask)
norm_trans_points.append(_norm_trans_points)
THETA.append(_theta)
transpose_theta.append(_transpose_theta)
affineImage, heatmaps, mask, norm_trans_points, THETA, transpose_theta = \
torch.stack(affineImage), torch.stack(heatmaps), torch.stack(mask), torch.stack(norm_trans_points), torch.stack(THETA), torch.stack(transpose_theta)
else:
affineImage, heatmaps, mask, norm_trans_points, THETA, transpose_theta = self.__process_affine(image, target, theta, nopoints, 'S[{:}]@{:}'.format(index, __path))
torch_index = torch.IntTensor([index])
torch_nopoints = torch.ByteTensor( [ nopoints ] )
torch_shape = torch.IntTensor([H,W])
return affineImage, heatmaps, mask, norm_trans_points, THETA, transpose_theta, torch_index, torch_nopoints, torch_shape
def __process_affine(self, image, target, theta, nopoints, aux_info=None):
image, target, theta = image.clone(), target.copy(), theta.clone()
(C, H, W), (height, width) = image.size(), self.shape
if nopoints: # do not have label
norm_trans_points = torch.zeros((3, self.NUM_PTS))
heatmaps = torch.zeros((self.NUM_PTS+1, height//self.downsample, width//self.downsample))
mask = torch.ones((self.NUM_PTS+1, 1, 1), dtype=torch.uint8)
transpose_theta = identity2affine(False)
else:
norm_trans_points = apply_affine2point(target.get_points(), theta, (H,W))
norm_trans_points = apply_boundary(norm_trans_points)
real_trans_points = norm_trans_points.clone()
real_trans_points[:2, :] = denormalize_points(self.shape, real_trans_points[:2,:])
heatmaps, mask = generate_label_map(real_trans_points.numpy(), height//self.downsample, width//self.downsample, self.sigma, self.downsample, nopoints, self.heatmap_type) # H*W*C
heatmaps = torch.from_numpy(heatmaps.transpose((2, 0, 1))).type(torch.FloatTensor)
mask = torch.from_numpy(mask.transpose((2, 0, 1))).type(torch.ByteTensor)
if self.mean_face is None:
#warnings.warn('In LandmarkDataset use identity2affine for transpose_theta because self.mean_face is None.')
transpose_theta = identity2affine(False)
else:
if torch.sum(norm_trans_points[2,:] == 1) < 3:
warnings.warn('In LandmarkDataset after transformation, no visible point, using identity instead. Aux: {:}'.format(aux_info))
transpose_theta = identity2affine(False)
else:
transpose_theta = solve2theta(norm_trans_points, self.mean_face.clone())
affineImage = affine2image(image, theta, self.shape)
if self.cutout is not None: affineImage = self.cutout( affineImage )
return affineImage, heatmaps, mask, norm_trans_points, theta, transpose_theta

46
datasets/SearchDatasetWrap.py Normal file

@ -0,0 +1,46 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
import torch, copy, random
import torch.utils.data as data
class SearchDataset(data.Dataset):
def __init__(self, name, data, train_split, valid_split, check=True):
self.datasetname = name
if isinstance(data, (list, tuple)): # new type of SearchDataset
assert len(data) == 2, 'invalid length: {:}'.format( len(data) )
self.train_data = data[0]
self.valid_data = data[1]
self.train_split = train_split.copy()
self.valid_split = valid_split.copy()
self.mode_str = 'V2' # new mode
else:
self.mode_str = 'V1' # old mode
self.data = data
self.train_split = train_split.copy()
self.valid_split = valid_split.copy()
if check:
intersection = set(train_split).intersection(set(valid_split))
assert len(intersection) == 0, 'the split train and validation sets should have no intersection'
self.length = len(self.train_split)
def __repr__(self):
return ('{name}(name={datasetname}, train={tr_L}, valid={val_L}, version={ver})'.format(name=self.__class__.__name__, datasetname=self.datasetname, tr_L=len(self.train_split), val_L=len(self.valid_split), ver=self.mode_str))
def __len__(self):
return self.length
def __getitem__(self, index):
assert index >= 0 and index < self.length, 'invalid index = {:}'.format(index)
train_index = self.train_split[index]
valid_index = random.choice( self.valid_split )
if self.mode_str == 'V1':
train_image, train_label = self.data[train_index]
valid_image, valid_label = self.data[valid_index]
elif self.mode_str == 'V2':
train_image, train_label = self.train_data[train_index]
valid_image, valid_label = self.valid_data[valid_index]
else: raise ValueError('invalid mode : {:}'.format(self.mode_str))
return train_image, train_label, valid_image, valid_label
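A hedged usage sketch of the wrapper above (not part of this commit): in the old "V1" mode a single dataset is indexed by two disjoint splits, and every item pairs a training sample with a randomly drawn validation sample.
```python
import torch
import torchvision.datasets as dset
import torchvision.transforms as transforms
from datasets import SearchDataset   # assumes the repository root is on sys.path

cifar = dset.CIFAR10('./data', train=True, download=True, transform=transforms.ToTensor())
train_split = list(range(0, 25000))
valid_split = list(range(25000, 50000))          # must be disjoint from train_split
search_data = SearchDataset('cifar10', cifar, train_split, valid_split)

loader = torch.utils.data.DataLoader(search_data, batch_size=64, shuffle=True, num_workers=2)
train_x, train_y, valid_x, valid_y = next(iter(loader))
print(train_x.shape, valid_x.shape)              # torch.Size([64, 3, 32, 32]) for both
```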

5
datasets/__init__.py Normal file

@ -0,0 +1,5 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
from .get_dataset_with_transform import get_datasets, get_nas_search_loaders
from .SearchDatasetWrap import SearchDataset

227
datasets/get_dataset_with_transform.py Normal file

@ -0,0 +1,227 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
import os, sys, torch
import os.path as osp
import numpy as np
import torchvision.datasets as dset
import torchvision.transforms as transforms
from copy import deepcopy
from PIL import Image
from .DownsampledImageNet import ImageNet16
from .SearchDatasetWrap import SearchDataset
from config_utils import load_config
Dataset2Class = {'cifar10' : 10,
'cifar100': 100,
'imagenet-1k-s':1000,
'imagenet-1k' : 1000,
'ImageNet16' : 1000,
'ImageNet16-150': 150,
'ImageNet16-120': 120,
'ImageNet16-200': 200}
class CUTOUT(object):
def __init__(self, length):
self.length = length
def __repr__(self):
return ('{name}(length={length})'.format(name=self.__class__.__name__, **self.__dict__))
def __call__(self, img):
h, w = img.size(1), img.size(2)
mask = np.ones((h, w), np.float32)
y = np.random.randint(h)
x = np.random.randint(w)
y1 = np.clip(y - self.length // 2, 0, h)
y2 = np.clip(y + self.length // 2, 0, h)
x1 = np.clip(x - self.length // 2, 0, w)
x2 = np.clip(x + self.length // 2, 0, w)
mask[y1: y2, x1: x2] = 0.
mask = torch.from_numpy(mask)
mask = mask.expand_as(img)
img *= mask
return img
imagenet_pca = {
'eigval': np.asarray([0.2175, 0.0188, 0.0045]),
'eigvec': np.asarray([
[-0.5675, 0.7192, 0.4009],
[-0.5808, -0.0045, -0.8140],
[-0.5836, -0.6948, 0.4203],
])
}
class Lighting(object):
def __init__(self, alphastd,
eigval=imagenet_pca['eigval'],
eigvec=imagenet_pca['eigvec']):
self.alphastd = alphastd
assert eigval.shape == (3,)
assert eigvec.shape == (3, 3)
self.eigval = eigval
self.eigvec = eigvec
def __call__(self, img):
if self.alphastd == 0.:
return img
rnd = np.random.randn(3) * self.alphastd
rnd = rnd.astype('float32')
v = rnd
old_dtype = np.asarray(img).dtype
v = v * self.eigval
v = v.reshape((3, 1))
inc = np.dot(self.eigvec, v).reshape((3,))
img = np.add(img, inc)
if old_dtype == np.uint8:
img = np.clip(img, 0, 255)
img = Image.fromarray(img.astype(old_dtype), 'RGB')
return img
def __repr__(self):
return self.__class__.__name__ + '()'
def get_datasets(name, root, cutout):
if name == 'cifar10':
mean = [x / 255 for x in [125.3, 123.0, 113.9]]
std = [x / 255 for x in [63.0, 62.1, 66.7]]
elif name == 'cifar100':
mean = [x / 255 for x in [129.3, 124.1, 112.4]]
std = [x / 255 for x in [68.2, 65.4, 70.4]]
elif name.startswith('imagenet-1k'):
mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
elif name.startswith('ImageNet16'):
mean = [x / 255 for x in [122.68, 116.66, 104.01]]
std = [x / 255 for x in [63.22, 61.26 , 65.09]]
else:
raise TypeError("Unknow dataset : {:}".format(name))
# Data Argumentation
if name == 'cifar10' or name == 'cifar100':
lists = [transforms.RandomHorizontalFlip(), transforms.RandomCrop(32, padding=4), transforms.ToTensor(), transforms.Normalize(mean, std)]
if cutout > 0 : lists += [CUTOUT(cutout)]
train_transform = transforms.Compose(lists)
test_transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean, std)])
xshape = (1, 3, 32, 32)
elif name.startswith('ImageNet16'):
lists = [transforms.RandomHorizontalFlip(), transforms.RandomCrop(16, padding=2), transforms.ToTensor(), transforms.Normalize(mean, std)]
if cutout > 0 : lists += [CUTOUT(cutout)]
train_transform = transforms.Compose(lists)
test_transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean, std)])
xshape = (1, 3, 16, 16)
elif name == 'tiered':
lists = [transforms.RandomHorizontalFlip(), transforms.RandomCrop(80, padding=4), transforms.ToTensor(), transforms.Normalize(mean, std)]
if cutout > 0 : lists += [CUTOUT(cutout)]
train_transform = transforms.Compose(lists)
test_transform = transforms.Compose([transforms.CenterCrop(80), transforms.ToTensor(), transforms.Normalize(mean, std)])
xshape = (1, 3, 32, 32)
elif name.startswith('imagenet-1k'):
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
if name == 'imagenet-1k':
xlists = [transforms.RandomResizedCrop(224)]
xlists.append(
transforms.ColorJitter(
brightness=0.4,
contrast=0.4,
saturation=0.4,
hue=0.2))
xlists.append( Lighting(0.1))
elif name == 'imagenet-1k-s':
xlists = [transforms.RandomResizedCrop(224, scale=(0.2, 1.0))]
else: raise ValueError('invalid name : {:}'.format(name))
xlists.append( transforms.RandomHorizontalFlip(p=0.5) )
xlists.append( transforms.ToTensor() )
xlists.append( normalize )
train_transform = transforms.Compose(xlists)
test_transform = transforms.Compose([transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(), normalize])
xshape = (1, 3, 224, 224)
else:
raise TypeError("Unknow dataset : {:}".format(name))
if name == 'cifar10':
train_data = dset.CIFAR10 (root, train=True , transform=train_transform, download=True)
test_data = dset.CIFAR10 (root, train=False, transform=test_transform , download=True)
assert len(train_data) == 50000 and len(test_data) == 10000
elif name == 'cifar100':
train_data = dset.CIFAR100(root, train=True , transform=train_transform, download=True)
test_data = dset.CIFAR100(root, train=False, transform=test_transform , download=True)
assert len(train_data) == 50000 and len(test_data) == 10000
elif name.startswith('imagenet-1k'):
train_data = dset.ImageFolder(osp.join(root, 'train'), train_transform)
test_data = dset.ImageFolder(osp.join(root, 'val'), test_transform)
assert len(train_data) == 1281167 and len(test_data) == 50000, 'invalid number of images : {:} & {:} vs {:} & {:}'.format(len(train_data), len(test_data), 1281167, 50000)
elif name == 'ImageNet16':
train_data = ImageNet16(root, True , train_transform)
test_data = ImageNet16(root, False, test_transform)
assert len(train_data) == 1281167 and len(test_data) == 50000
elif name == 'ImageNet16-120':
train_data = ImageNet16(root, True , train_transform, 120)
test_data = ImageNet16(root, False, test_transform , 120)
assert len(train_data) == 151700 and len(test_data) == 6000
elif name == 'ImageNet16-150':
train_data = ImageNet16(root, True , train_transform, 150)
test_data = ImageNet16(root, False, test_transform , 150)
assert len(train_data) == 190272 and len(test_data) == 7500
elif name == 'ImageNet16-200':
train_data = ImageNet16(root, True , train_transform, 200)
test_data = ImageNet16(root, False, test_transform , 200)
assert len(train_data) == 254775 and len(test_data) == 10000
else: raise TypeError("Unknown dataset : {:}".format(name))
class_num = Dataset2Class[name]
return train_data, test_data, xshape, class_num
def get_nas_search_loaders(train_data, valid_data, dataset, config_root, batch_size, workers):
if isinstance(batch_size, (list,tuple)):
batch, test_batch = batch_size
else:
batch, test_batch = batch_size, batch_size
if dataset == 'cifar10':
#split_Fpath = 'configs/nas-benchmark/cifar-split.txt'
cifar_split = load_config('{:}/cifar-split.txt'.format(config_root), None, None)
train_split, valid_split = cifar_split.train, cifar_split.valid # search over the proposed training and validation set
#logger.log('Load split file from {:}'.format(split_Fpath)) # they are two disjoint groups in the original CIFAR-10 training set
# To split data
xvalid_data = deepcopy(train_data)
if hasattr(xvalid_data, 'transforms'): # to avoid a print issue
xvalid_data.transforms = valid_data.transform
xvalid_data.transform = deepcopy( valid_data.transform )
search_data = SearchDataset(dataset, train_data, train_split, valid_split)
# data loader
search_loader = torch.utils.data.DataLoader(search_data, batch_size=batch, shuffle=True , num_workers=workers, pin_memory=True)
train_loader = torch.utils.data.DataLoader(train_data , batch_size=batch, sampler=torch.utils.data.sampler.SubsetRandomSampler(train_split), num_workers=workers, pin_memory=True)
valid_loader = torch.utils.data.DataLoader(xvalid_data, batch_size=test_batch, sampler=torch.utils.data.sampler.SubsetRandomSampler(valid_split), num_workers=workers, pin_memory=True)
elif dataset == 'cifar100':
cifar100_test_split = load_config('{:}/cifar100-test-split.txt'.format(config_root), None, None)
search_train_data = train_data
search_valid_data = deepcopy(valid_data) ; search_valid_data.transform = train_data.transform
search_data = SearchDataset(dataset, [search_train_data,search_valid_data], list(range(len(search_train_data))), cifar100_test_split.xvalid)
search_loader = torch.utils.data.DataLoader(search_data, batch_size=batch, shuffle=True , num_workers=workers, pin_memory=True)
train_loader = torch.utils.data.DataLoader(train_data , batch_size=batch, shuffle=True , num_workers=workers, pin_memory=True)
valid_loader = torch.utils.data.DataLoader(valid_data , batch_size=test_batch, sampler=torch.utils.data.sampler.SubsetRandomSampler(cifar100_test_split.xvalid), num_workers=workers, pin_memory=True)
elif dataset == 'ImageNet16-120':
imagenet_test_split = load_config('{:}/imagenet-16-120-test-split.txt'.format(config_root), None, None)
search_train_data = train_data
search_valid_data = deepcopy(valid_data) ; search_valid_data.transform = train_data.transform
search_data = SearchDataset(dataset, [search_train_data,search_valid_data], list(range(len(search_train_data))), imagenet_test_split.xvalid)
search_loader = torch.utils.data.DataLoader(search_data, batch_size=batch, shuffle=True , num_workers=workers, pin_memory=True)
train_loader = torch.utils.data.DataLoader(train_data , batch_size=batch, shuffle=True , num_workers=workers, pin_memory=True)
valid_loader = torch.utils.data.DataLoader(valid_data , batch_size=test_batch, sampler=torch.utils.data.sampler.SubsetRandomSampler(imagenet_test_split.xvalid), num_workers=workers, pin_memory=True)
else:
raise ValueError('invalid dataset : {:}'.format(dataset))
return search_loader, train_loader, valid_loader
#if __name__ == '__main__':
# train_data, test_data, xshape, class_num = dataset = get_datasets('cifar10', '/data02/dongxuanyi/.torch/cifar.python/', -1)
# import pdb; pdb.set_trace()
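A hedged usage sketch mirroring the commented-out `__main__` above; the data path is a placeholder, and `config_root` must point at the directory containing the `cifar-split.txt` file that `get_nas_search_loaders` loads:
```python
from datasets import get_datasets, get_nas_search_loaders  # assumes the repository root is on sys.path

train_data, test_data, xshape, class_num = get_datasets('cifar10', './data/cifar.python', cutout=-1)
search_loader, train_loader, valid_loader = get_nas_search_loaders(
    train_data, test_data, 'cifar10',
    config_root='./configs',        # wherever cifar-split.txt lives
    batch_size=(256, 512), workers=4)
print(xshape, class_num)   # (1, 3, 32, 32) 10
```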

1
datasets/landmark_utils/__init__.py Normal file

@ -0,0 +1 @@
from .point_meta import PointMeta2V, apply_affine2point, apply_boundary

116
datasets/landmark_utils/point_meta.py Normal file

@ -0,0 +1,116 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
import copy, math, torch, numpy as np
from xvision import normalize_points
from xvision import denormalize_points
class PointMeta():
# points : 3 x num_pts (x, y, occlusion)
# image_size: original [width, height]
def __init__(self, num_point, points, box, image_path, dataset_name):
self.num_point = num_point
if box is not None:
assert (isinstance(box, tuple) or isinstance(box, list)) and len(box) == 4
self.box = torch.Tensor(box)
else: self.box = None
if points is None:
self.points = points
else:
assert len(points.shape) == 2 and points.shape[0] == 3 and points.shape[1] == self.num_point, 'The shape of point is not right : {}'.format( points )
self.points = torch.Tensor(points.copy())
self.image_path = image_path
self.datasets = dataset_name
def __repr__(self):
if self.box is None: boxstr = 'None'
else : boxstr = 'box=[{:.1f}, {:.1f}, {:.1f}, {:.1f}]'.format(*self.box.tolist())
return ('{name}(points={num_point}, '.format(name=self.__class__.__name__, **self.__dict__) + boxstr + ')')
def get_box(self, return_diagonal=False):
if self.box is None: return None
if not return_diagonal:
return self.box.clone()
else:
W = (self.box[2]-self.box[0]).item()
H = (self.box[3]-self.box[1]).item()
return math.sqrt(H*H+W*W)
def get_points(self, ignore_indicator=False):
if ignore_indicator: last = 2
else : last = 3
if self.points is not None: return self.points.clone()[:last, :]
else : return torch.zeros((last, self.num_point))
def is_none(self):
#assert self.box is not None, 'The box should not be None'
return self.points is None
#if self.box is None: return True
#else : return self.points is None
def copy(self):
return copy.deepcopy(self)
def visiable_pts_num(self):
with torch.no_grad():
ans = self.points[2,:] > 0
ans = torch.sum(ans)
ans = ans.item()
return ans
def special_fun(self, indicator):
if indicator == '68to49': # For 300W or 300VW, convert the default 68 points to 49 points.
assert self.num_point == 68, 'num-point must be 68 vs. {:}'.format(self.num_point)
self.num_point = 49
out = torch.ones((68), dtype=torch.uint8)
out[[0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,60,64]] = 0
if self.points is not None: self.points = self.points.clone()[:, out]
else:
raise ValueError('Invalid indicator : {:}'.format( indicator ))
def apply_horizontal_flip(self):
#self.points[0, :] = width - self.points[0, :] - 1
# Mugsy-specific or synthetic
if self.datasets.startswith('HandsyROT'):
ori = np.array(list(range(0, 42)))
pos = np.array(list(range(21,42)) + list(range(0,21)))
self.points[:, pos] = self.points[:, ori]
elif self.datasets.startswith('face68'):
ori = np.array(list(range(0, 68)))
pos = np.array([17,16,15,14,13,12,11,10, 9, 8,7,6,5,4,3,2,1, 27,26,25,24,23,22,21,20,19,18, 28,29,30,31, 36,35,34,33,32, 46,45,44,43,48,47, 40,39,38,37,42,41, 55,54,53,52,51,50,49,60,59,58,57,56,65,64,63,62,61,68,67,66])-1
self.points[:, ori] = self.points[:, pos]
else:
raise ValueError('Does not support {:}'.format(self.datasets))
# shape = (H,W)
def apply_affine2point(points, theta, shape):
assert points.size(0) == 3, 'invalid points shape : {:}'.format(points.size())
with torch.no_grad():
ok_points = points[2,:] == 1
assert torch.sum(ok_points).item() > 0, 'there is no visible point'
points[:2,:] = normalize_points(shape, points[:2,:])
norm_trans_points = ok_points.unsqueeze(0).repeat(3, 1).float()
trans_points, ___ = torch.solve(points[:, ok_points], theta)  # torch.gesv was removed from recent PyTorch; torch.solve takes the same (B, A) arguments
norm_trans_points[:, ok_points] = trans_points
return norm_trans_points
def apply_boundary(norm_trans_points):
with torch.no_grad():
norm_trans_points = norm_trans_points.clone()
oks = torch.stack((norm_trans_points[0]>-1, norm_trans_points[0]<1, norm_trans_points[1]>-1, norm_trans_points[1]<1, norm_trans_points[2]>0))
oks = torch.sum(oks, dim=0) == 5
norm_trans_points[2, :] = oks
return norm_trans_points

20
datasets/test_utils.py Normal file

@ -0,0 +1,20 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
import os
def test_imagenet_data(imagenet):
total_length = len(imagenet)
assert total_length == 1281166 or total_length == 50000, 'The length of ImageNet is wrong : {}'.format(total_length)
map_id = {}
for index in range(total_length):
path, target = imagenet.imgs[index]
folder, image_name = os.path.split(path)
_, folder = os.path.split(folder)
if folder not in map_id:
map_id[folder] = target
else:
assert map_id[folder] == target, 'Class : {} is not {}'.format(folder, target)
assert image_name.find(folder) == 0, '{} is wrong.'.format(path)
print ('Check ImageNet Dataset OK')

50
environment.yml Normal file

@ -0,0 +1,50 @@
name: nas-wot
channels:
- pytorch
- defaults
dependencies:
- _libgcc_mutex=0.1=main
- blas=1.0=mkl
- ca-certificates=2020.1.1=0
- certifi=2020.4.5.1=py38_0
- cudatoolkit=10.2.89=hfd86e86_1
- freetype=2.9.1=h8a8886c_1
- intel-openmp=2020.1=217
- jpeg=9b=h024ee3a_2
- ld_impl_linux-64=2.33.1=h53a641e_7
- libedit=3.1.20181209=hc058e9b_0
- libffi=3.3=he6710b0_1
- libgcc-ng=9.1.0=hdf63c60_0
- libgfortran-ng=7.3.0=hdf63c60_0
- libpng=1.6.37=hbc83047_0
- libstdcxx-ng=9.1.0=hdf63c60_0
- libtiff=4.1.0=h2733197_1
- lz4-c=1.9.2=he6710b0_0
- mkl=2020.1=217
- mkl-service=2.3.0=py38he904b0f_0
- mkl_fft=1.0.15=py38ha843d7b_0
- mkl_random=1.1.1=py38h0573a6f_0
- ncurses=6.2=he6710b0_1
- ninja=1.9.0=py38hfd86e86_0
- numpy=1.18.1=py38h4f9e942_0
- numpy-base=1.18.1=py38hde5b4d6_1
- olefile=0.46=py_0
- openssl=1.1.1g=h7b6447c_0
- pillow=7.1.2=py38hb39fc2d_0
- pip=20.0.2=py38_3
- python=3.8.3=hcff3b4d_0
- pytorch=1.5.0=py3.8_cuda10.2.89_cudnn7.6.5_0
- readline=8.0=h7b6447c_0
- setuptools=46.4.0=py38_0
- six=1.14.0=py38_0
- sqlite=3.31.1=h62c20be_1
- tk=8.6.8=hbc83047_0
- torchvision=0.6.0=py38_cu102
- tqdm=4.46.0=py_0
- wheel=0.34.2=py38_0
- xz=5.2.5=h7b6447c_0
- zlib=1.2.11=h7b6447c_3
- zstd=1.4.4=h0b5b093_3
- pip:
- argparse==1.4.0
- nas-bench-201==1.3

105
models/CifarDenseNet.py Normal file

@ -0,0 +1,105 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
import math, torch
import torch.nn as nn
import torch.nn.functional as F
from .initialization import initialize_resnet
class Bottleneck(nn.Module):
def __init__(self, nChannels, growthRate):
super(Bottleneck, self).__init__()
interChannels = 4*growthRate
self.bn1 = nn.BatchNorm2d(nChannels)
self.conv1 = nn.Conv2d(nChannels, interChannels, kernel_size=1, bias=False)
self.bn2 = nn.BatchNorm2d(interChannels)
self.conv2 = nn.Conv2d(interChannels, growthRate, kernel_size=3, padding=1, bias=False)
def forward(self, x):
out = self.conv1(F.relu(self.bn1(x)))
out = self.conv2(F.relu(self.bn2(out)))
out = torch.cat((x, out), 1)
return out
class SingleLayer(nn.Module):
def __init__(self, nChannels, growthRate):
super(SingleLayer, self).__init__()
self.bn1 = nn.BatchNorm2d(nChannels)
self.conv1 = nn.Conv2d(nChannels, growthRate, kernel_size=3, padding=1, bias=False)
def forward(self, x):
out = self.conv1(F.relu(self.bn1(x)))
out = torch.cat((x, out), 1)
return out
class Transition(nn.Module):
def __init__(self, nChannels, nOutChannels):
super(Transition, self).__init__()
self.bn1 = nn.BatchNorm2d(nChannels)
self.conv1 = nn.Conv2d(nChannels, nOutChannels, kernel_size=1, bias=False)
def forward(self, x):
out = self.conv1(F.relu(self.bn1(x)))
out = F.avg_pool2d(out, 2)
return out
class DenseNet(nn.Module):
def __init__(self, growthRate, depth, reduction, nClasses, bottleneck):
super(DenseNet, self).__init__()
if bottleneck: nDenseBlocks = int( (depth-4) / 6 )
else : nDenseBlocks = int( (depth-4) / 3 )
self.message = 'CifarDenseNet : block : {:}, depth : {:}, reduction : {:}, growth-rate = {:}, class = {:}'.format('bottleneck' if bottleneck else 'basic', depth, reduction, growthRate, nClasses)
nChannels = 2*growthRate
self.conv1 = nn.Conv2d(3, nChannels, kernel_size=3, padding=1, bias=False)
self.dense1 = self._make_dense(nChannels, growthRate, nDenseBlocks, bottleneck)
nChannels += nDenseBlocks*growthRate
nOutChannels = int(math.floor(nChannels*reduction))
self.trans1 = Transition(nChannels, nOutChannels)
nChannels = nOutChannels
self.dense2 = self._make_dense(nChannels, growthRate, nDenseBlocks, bottleneck)
nChannels += nDenseBlocks*growthRate
nOutChannels = int(math.floor(nChannels*reduction))
self.trans2 = Transition(nChannels, nOutChannels)
nChannels = nOutChannels
self.dense3 = self._make_dense(nChannels, growthRate, nDenseBlocks, bottleneck)
nChannels += nDenseBlocks*growthRate
self.act = nn.Sequential(
nn.BatchNorm2d(nChannels), nn.ReLU(inplace=True),
nn.AvgPool2d(8))
self.fc = nn.Linear(nChannels, nClasses)
self.apply(initialize_resnet)
def get_message(self):
return self.message
def _make_dense(self, nChannels, growthRate, nDenseBlocks, bottleneck):
layers = []
for i in range(int(nDenseBlocks)):
if bottleneck:
layers.append(Bottleneck(nChannels, growthRate))
else:
layers.append(SingleLayer(nChannels, growthRate))
nChannels += growthRate
return nn.Sequential(*layers)
def forward(self, inputs):
out = self.conv1( inputs )
out = self.trans1(self.dense1(out))
out = self.trans2(self.dense2(out))
out = self.dense3(out)
features = self.act(out)
features = features.view(features.size(0), -1)
out = self.fc(features)
return features, out

157
models/CifarResNet.py Normal file

@ -0,0 +1,157 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from .initialization import initialize_resnet
from .SharedUtils import additive_func
class Downsample(nn.Module):
def __init__(self, nIn, nOut, stride):
super(Downsample, self).__init__()
assert stride == 2 and nOut == 2*nIn, 'stride:{} IO:{},{}'.format(stride, nIn, nOut)
self.in_dim = nIn
self.out_dim = nOut
self.avg = nn.AvgPool2d(kernel_size=2, stride=2, padding=0)
self.conv = nn.Conv2d(nIn, nOut, kernel_size=1, stride=1, padding=0, bias=False)
def forward(self, x):
x = self.avg(x)
out = self.conv(x)
return out
class ConvBNReLU(nn.Module):
def __init__(self, nIn, nOut, kernel, stride, padding, bias, relu):
super(ConvBNReLU, self).__init__()
self.conv = nn.Conv2d(nIn, nOut, kernel_size=kernel, stride=stride, padding=padding, bias=bias)
self.bn = nn.BatchNorm2d(nOut)
if relu: self.relu = nn.ReLU(inplace=True)
else : self.relu = None
self.out_dim = nOut
self.num_conv = 1
def forward(self, x):
conv = self.conv( x )
bn = self.bn( conv )
if self.relu: return self.relu( bn )
else : return bn
class ResNetBasicblock(nn.Module):
expansion = 1
def __init__(self, inplanes, planes, stride):
super(ResNetBasicblock, self).__init__()
assert stride == 1 or stride == 2, 'invalid stride {:}'.format(stride)
self.conv_a = ConvBNReLU(inplanes, planes, 3, stride, 1, False, True)
self.conv_b = ConvBNReLU( planes, planes, 3, 1, 1, False, False)
if stride == 2:
self.downsample = Downsample(inplanes, planes, stride)
elif inplanes != planes:
self.downsample = ConvBNReLU(inplanes, planes, 1, 1, 0, False, False)
else:
self.downsample = None
self.out_dim = planes
self.num_conv = 2
def forward(self, inputs):
basicblock = self.conv_a(inputs)
basicblock = self.conv_b(basicblock)
if self.downsample is not None:
residual = self.downsample(inputs)
else:
residual = inputs
out = additive_func(residual, basicblock)
return F.relu(out, inplace=True)
class ResNetBottleneck(nn.Module):
expansion = 4
def __init__(self, inplanes, planes, stride):
super(ResNetBottleneck, self).__init__()
assert stride == 1 or stride == 2, 'invalid stride {:}'.format(stride)
self.conv_1x1 = ConvBNReLU(inplanes, planes, 1, 1, 0, False, True)
self.conv_3x3 = ConvBNReLU( planes, planes, 3, stride, 1, False, True)
self.conv_1x4 = ConvBNReLU(planes, planes*self.expansion, 1, 1, 0, False, False)
if stride == 2:
self.downsample = Downsample(inplanes, planes*self.expansion, stride)
elif inplanes != planes*self.expansion:
self.downsample = ConvBNReLU(inplanes, planes*self.expansion, 1, 1, 0, False, False)
else:
self.downsample = None
self.out_dim = planes * self.expansion
self.num_conv = 3
def forward(self, inputs):
bottleneck = self.conv_1x1(inputs)
bottleneck = self.conv_3x3(bottleneck)
bottleneck = self.conv_1x4(bottleneck)
if self.downsample is not None:
residual = self.downsample(inputs)
else:
residual = inputs
out = additive_func(residual, bottleneck)
return F.relu(out, inplace=True)
class CifarResNet(nn.Module):
def __init__(self, block_name, depth, num_classes, zero_init_residual):
super(CifarResNet, self).__init__()
#Model type specifies number of layers for CIFAR-10 and CIFAR-100 model
if block_name == 'ResNetBasicblock':
block = ResNetBasicblock
assert (depth - 2) % 6 == 0, 'depth should be one of 20, 32, 44, 56, 110'
layer_blocks = (depth - 2) // 6
elif block_name == 'ResNetBottleneck':
block = ResNetBottleneck
assert (depth - 2) % 9 == 0, 'depth should be one of 164'
layer_blocks = (depth - 2) // 9
else:
raise ValueError('invalid block : {:}'.format(block_name))
self.message = 'CifarResNet : Block : {:}, Depth : {:}, Layers for each block : {:}'.format(block_name, depth, layer_blocks)
self.num_classes = num_classes
self.channels = [16]
self.layers = nn.ModuleList( [ ConvBNReLU(3, 16, 3, 1, 1, False, True) ] )
for stage in range(3):
for iL in range(layer_blocks):
iC = self.channels[-1]
planes = 16 * (2**stage)
stride = 2 if stage > 0 and iL == 0 else 1
module = block(iC, planes, stride)
self.channels.append( module.out_dim )
self.layers.append ( module )
self.message += "\nstage={:}, ilayer={:02d}/{:02d}, block={:03d}, iC={:3d}, oC={:3d}, stride={:}".format(stage, iL, layer_blocks, len(self.layers)-1, iC, module.out_dim, stride)
self.avgpool = nn.AvgPool2d(8)
self.classifier = nn.Linear(module.out_dim, num_classes)
assert sum(x.num_conv for x in self.layers) + 1 == depth, 'invalid depth check {:} vs {:}'.format(sum(x.num_conv for x in self.layers)+1, depth)
self.apply(initialize_resnet)
if zero_init_residual:
for m in self.modules():
if isinstance(m, ResNetBasicblock):
nn.init.constant_(m.conv_b.bn.weight, 0)
elif isinstance(m, ResNetBottleneck):
nn.init.constant_(m.conv_1x4.bn.weight, 0)
def get_message(self):
return self.message
def forward(self, inputs):
x = inputs
for i, layer in enumerate(self.layers):
x = layer( x )
features = self.avgpool(x)
features = features.view(features.size(0), -1)
logits = self.classifier(features)
return features, logits

94
models/CifarWideResNet.py Normal file
View File

@ -0,0 +1,94 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from .initialization import initialize_resnet
class WideBasicblock(nn.Module):
def __init__(self, inplanes, planes, stride, dropout=False):
super(WideBasicblock, self).__init__()
self.bn_a = nn.BatchNorm2d(inplanes)
self.conv_a = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
self.bn_b = nn.BatchNorm2d(planes)
if dropout:
self.dropout = nn.Dropout2d(p=0.5, inplace=True)
else:
self.dropout = None
self.conv_b = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
if inplanes != planes:
self.downsample = nn.Conv2d(inplanes, planes, kernel_size=1, stride=stride, padding=0, bias=False)
else:
self.downsample = None
def forward(self, x):
basicblock = self.bn_a(x)
basicblock = F.relu(basicblock)
basicblock = self.conv_a(basicblock)
basicblock = self.bn_b(basicblock)
basicblock = F.relu(basicblock)
if self.dropout is not None:
basicblock = self.dropout(basicblock)
basicblock = self.conv_b(basicblock)
if self.downsample is not None:
x = self.downsample(x)
return x + basicblock
class CifarWideResNet(nn.Module):
"""
  Wide ResNet for the CIFAR datasets, as introduced in
  https://arxiv.org/abs/1605.07146
"""
def __init__(self, depth, widen_factor, num_classes, dropout):
super(CifarWideResNet, self).__init__()
    # Model type specifies the number of layers for the CIFAR-10 and CIFAR-100 models
    assert (depth - 4) % 6 == 0, 'depth should be one of 16, 22, 28, 40'
    layer_blocks = (depth - 4) // 6
    print ('CifarWideResNet : Depth : {} , Layers for each block : {}'.format(depth, layer_blocks))
self.num_classes = num_classes
self.dropout = dropout
self.conv_3x3 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)
self.message = 'Wide ResNet : depth={:}, widen_factor={:}, class={:}'.format(depth, widen_factor, num_classes)
self.inplanes = 16
self.stage_1 = self._make_layer(WideBasicblock, 16*widen_factor, layer_blocks, 1)
self.stage_2 = self._make_layer(WideBasicblock, 32*widen_factor, layer_blocks, 2)
self.stage_3 = self._make_layer(WideBasicblock, 64*widen_factor, layer_blocks, 2)
self.lastact = nn.Sequential(nn.BatchNorm2d(64*widen_factor), nn.ReLU(inplace=True))
self.avgpool = nn.AvgPool2d(8)
self.classifier = nn.Linear(64*widen_factor, num_classes)
self.apply(initialize_resnet)
def get_message(self):
return self.message
def _make_layer(self, block, planes, blocks, stride):
layers = []
layers.append(block(self.inplanes, planes, stride, self.dropout))
self.inplanes = planes
for i in range(1, blocks):
layers.append(block(self.inplanes, planes, 1, self.dropout))
return nn.Sequential(*layers)
def forward(self, x):
x = self.conv_3x3(x)
x = self.stage_1(x)
x = self.stage_2(x)
x = self.stage_3(x)
x = self.lastact(x)
x = self.avgpool(x)
features = x.view(x.size(0), -1)
outs = self.classifier(features)
return features, outs

View File

@ -0,0 +1,101 @@
# MobileNetV2: Inverted Residuals and Linear Bottlenecks, CVPR 2018
from torch import nn
from .initialization import initialize_resnet
class ConvBNReLU(nn.Module):
def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1):
super(ConvBNReLU, self).__init__()
padding = (kernel_size - 1) // 2
self.conv = nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, bias=False)
self.bn = nn.BatchNorm2d(out_planes)
self.relu = nn.ReLU6(inplace=True)
def forward(self, x):
out = self.conv( x )
out = self.bn ( out )
out = self.relu( out )
return out
class InvertedResidual(nn.Module):
def __init__(self, inp, oup, stride, expand_ratio):
super(InvertedResidual, self).__init__()
self.stride = stride
assert stride in [1, 2]
hidden_dim = int(round(inp * expand_ratio))
self.use_res_connect = self.stride == 1 and inp == oup
layers = []
if expand_ratio != 1:
# pw
layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1))
layers.extend([
# dw
ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim),
# pw-linear
nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
nn.BatchNorm2d(oup),
])
self.conv = nn.Sequential(*layers)
def forward(self, x):
if self.use_res_connect:
return x + self.conv(x)
else:
return self.conv(x)
class MobileNetV2(nn.Module):
def __init__(self, num_classes, width_mult, input_channel, last_channel, block_name, dropout):
super(MobileNetV2, self).__init__()
if block_name == 'InvertedResidual':
block = InvertedResidual
else:
raise ValueError('invalid block name : {:}'.format(block_name))
inverted_residual_setting = [
      # t (expansion factor), c (output channels), n (number of repeats), s (stride of the first block)
[1, 16 , 1, 1],
[6, 24 , 2, 2],
[6, 32 , 3, 2],
[6, 64 , 4, 2],
[6, 96 , 3, 1],
[6, 160, 3, 2],
[6, 320, 1, 1],
]
# building first layer
input_channel = int(input_channel * width_mult)
self.last_channel = int(last_channel * max(1.0, width_mult))
features = [ConvBNReLU(3, input_channel, stride=2)]
# building inverted residual blocks
for t, c, n, s in inverted_residual_setting:
output_channel = int(c * width_mult)
for i in range(n):
stride = s if i == 0 else 1
features.append(block(input_channel, output_channel, stride, expand_ratio=t))
input_channel = output_channel
# building last several layers
features.append(ConvBNReLU(input_channel, self.last_channel, kernel_size=1))
# make it nn.Sequential
self.features = nn.Sequential(*features)
# building classifier
self.classifier = nn.Sequential(
nn.Dropout(dropout),
nn.Linear(self.last_channel, num_classes),
)
self.message = 'MobileNetV2 : width_mult={:}, in-C={:}, last-C={:}, block={:}, dropout={:}'.format(width_mult, input_channel, last_channel, block_name, dropout)
# weight initialization
self.apply( initialize_resnet )
def get_message(self):
return self.message
def forward(self, inputs):
features = self.features(inputs)
vectors = features.mean([2, 3])
predicts = self.classifier(vectors)
return features, predicts

172
models/ImageNet_ResNet.py Normal file
View File

@ -0,0 +1,172 @@
# Deep Residual Learning for Image Recognition, CVPR 2016
import torch.nn as nn
from .initialization import initialize_resnet
def conv3x3(in_planes, out_planes, stride=1, groups=1):
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, groups=groups, bias=False)
def conv1x1(in_planes, out_planes, stride=1):
return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
class BasicBlock(nn.Module):
expansion = 1
def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, base_width=64):
super(BasicBlock, self).__init__()
if groups != 1 or base_width != 64:
raise ValueError('BasicBlock only supports groups=1 and base_width=64')
# Both self.conv1 and self.downsample layers downsample the input when stride != 1
self.conv1 = conv3x3(inplanes, planes, stride)
self.bn1 = nn.BatchNorm2d(planes)
self.relu = nn.ReLU(inplace=True)
self.conv2 = conv3x3(planes, planes)
self.bn2 = nn.BatchNorm2d(planes)
self.downsample = downsample
self.stride = stride
def forward(self, x):
identity = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
if self.downsample is not None:
identity = self.downsample(x)
out += identity
out = self.relu(out)
return out
class Bottleneck(nn.Module):
expansion = 4
def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, base_width=64):
super(Bottleneck, self).__init__()
width = int(planes * (base_width / 64.)) * groups
# Both self.conv2 and self.downsample layers downsample the input when stride != 1
self.conv1 = conv1x1(inplanes, width)
self.bn1 = nn.BatchNorm2d(width)
self.conv2 = conv3x3(width, width, stride, groups)
self.bn2 = nn.BatchNorm2d(width)
self.conv3 = conv1x1(width, planes * self.expansion)
self.bn3 = nn.BatchNorm2d(planes * self.expansion)
self.relu = nn.ReLU(inplace=True)
self.downsample = downsample
self.stride = stride
def forward(self, x):
identity = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
out = self.relu(out)
out = self.conv3(out)
out = self.bn3(out)
if self.downsample is not None:
identity = self.downsample(x)
out += identity
out = self.relu(out)
return out
class ResNet(nn.Module):
def __init__(self, block_name, layers, deep_stem, num_classes, zero_init_residual, groups, width_per_group):
super(ResNet, self).__init__()
#planes = [int(width_per_group * groups * 2 ** i) for i in range(4)]
if block_name == 'BasicBlock' : block= BasicBlock
elif block_name == 'Bottleneck': block= Bottleneck
else : raise ValueError('invalid block-name : {:}'.format(block_name))
if not deep_stem:
self.conv = nn.Sequential(
nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False),
nn.BatchNorm2d(64), nn.ReLU(inplace=True))
else:
self.conv = nn.Sequential(
nn.Conv2d( 3, 32, kernel_size=3, stride=2, padding=1, bias=False),
nn.BatchNorm2d(32), nn.ReLU(inplace=True),
nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1, bias=False),
nn.BatchNorm2d(32), nn.ReLU(inplace=True),
nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1, bias=False),
nn.BatchNorm2d(64), nn.ReLU(inplace=True))
self.inplanes = 64
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
self.layer1 = self._make_layer(block, 64 , layers[0], stride=1, groups=groups, base_width=width_per_group)
self.layer2 = self._make_layer(block, 128, layers[1], stride=2, groups=groups, base_width=width_per_group)
self.layer3 = self._make_layer(block, 256, layers[2], stride=2, groups=groups, base_width=width_per_group)
self.layer4 = self._make_layer(block, 512, layers[3], stride=2, groups=groups, base_width=width_per_group)
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
self.fc = nn.Linear(512 * block.expansion, num_classes)
self.message = 'block = {:}, layers = {:}, deep_stem = {:}, num_classes = {:}'.format(block, layers, deep_stem, num_classes)
self.apply( initialize_resnet )
# Zero-initialize the last BN in each residual branch,
# so that the residual branch starts with zeros, and each residual block behaves like an identity.
# This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
if zero_init_residual:
for m in self.modules():
if isinstance(m, Bottleneck):
nn.init.constant_(m.bn3.weight, 0)
elif isinstance(m, BasicBlock):
nn.init.constant_(m.bn2.weight, 0)
def _make_layer(self, block, planes, blocks, stride, groups, base_width):
downsample = None
if stride != 1 or self.inplanes != planes * block.expansion:
if stride == 2:
downsample = nn.Sequential(
nn.AvgPool2d(kernel_size=2, stride=2, padding=0),
conv1x1(self.inplanes, planes * block.expansion, 1),
nn.BatchNorm2d(planes * block.expansion),
)
elif stride == 1:
downsample = nn.Sequential(
conv1x1(self.inplanes, planes * block.expansion, stride),
nn.BatchNorm2d(planes * block.expansion),
)
else: raise ValueError('invalid stride [{:}] for downsample'.format(stride))
layers = []
layers.append(block(self.inplanes, planes, stride, downsample, groups, base_width))
self.inplanes = planes * block.expansion
for _ in range(1, blocks):
layers.append(block(self.inplanes, planes, 1, None, groups, base_width))
return nn.Sequential(*layers)
def get_message(self):
return self.message
def forward(self, x):
x = self.conv(x)
x = self.maxpool(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
features = self.avgpool(x)
features = features.view(features.size(0), -1)
logits = self.fc(features)
return features, logits

34
models/SharedUtils.py Normal file
View File

@ -0,0 +1,34 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
#####################################################
import torch
import torch.nn as nn
def additive_func(A, B):
assert A.dim() == B.dim() and A.size(0) == B.size(0), '{:} vs {:}'.format(A.size(), B.size())
C = min(A.size(1), B.size(1))
if A.size(1) == B.size(1):
return A + B
elif A.size(1) < B.size(1):
out = B.clone()
out[:,:C] += A
return out
else:
out = A.clone()
out[:,:C] += B
return out
def change_key(key, value):
def func(m):
if hasattr(m, key):
setattr(m, key, value)
return func
def parse_channel_info(xstring):
blocks = xstring.split(' ')
blocks = [x.split('-') for x in blocks]
blocks = [[int(_) for _ in x] for x in blocks]
return blocks
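if __name__ == '__main__':
  # Minimal smoke test (not part of the original file): additive_func adds the
  # narrower tensor into the leading channels of the wider one, so feature maps
  # with mismatched channel counts can still be summed; parse_channel_info turns
  # strings such as '8-16 16-32' into per-block channel pairs.
  A = torch.zeros(2, 4, 8, 8)
  B = torch.ones (2, 6, 8, 8)
  assert additive_func(A, B).shape == (2, 6, 8, 8)
  assert parse_channel_info('8-16 16-32') == [[8, 16], [16, 32]]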

185
models/__init__.py Normal file
View File

@ -0,0 +1,185 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
from os import path as osp
from typing import List, Text
import torch
__all__ = ['change_key', 'get_cell_based_tiny_net', 'get_search_spaces', 'get_cifar_models', 'get_imagenet_models', \
'obtain_model', 'obtain_search_model', 'load_net_from_checkpoint', \
'CellStructure', 'CellArchitectures'
]
# useful modules
from config_utils import dict2config
from .SharedUtils import change_key
from .cell_searchs import CellStructure, CellArchitectures
# Cell-based NAS Models
def get_cell_based_tiny_net(config):
if isinstance(config, dict): config = dict2config(config, None) # to support the argument being a dict
super_type = getattr(config, 'super_type', 'basic')
group_names = ['DARTS-V1', 'DARTS-V2', 'GDAS', 'SETN', 'ENAS', 'RANDOM']
if super_type == 'basic' and config.name in group_names:
from .cell_searchs import nas201_super_nets as nas_super_nets
try:
return nas_super_nets[config.name](config.C, config.N, config.max_nodes, config.num_classes, config.space, config.affine, config.track_running_stats)
except:
return nas_super_nets[config.name](config.C, config.N, config.max_nodes, config.num_classes, config.space)
elif super_type == 'nasnet-super':
from .cell_searchs import nasnet_super_nets as nas_super_nets
return nas_super_nets[config.name](config.C, config.N, config.steps, config.multiplier, \
config.stem_multiplier, config.num_classes, config.space, config.affine, config.track_running_stats)
elif config.name == 'infer.tiny':
from .cell_infers import TinyNetwork
if hasattr(config, 'genotype'):
genotype = config.genotype
elif hasattr(config, 'arch_str'):
genotype = CellStructure.str2structure(config.arch_str)
else: raise ValueError('Can not find genotype from this config : {:}'.format(config))
return TinyNetwork(config.C, config.N, genotype, config.num_classes)
elif config.name == 'infer.shape.tiny':
from .shape_infers import DynamicShapeTinyNet
if isinstance(config.channels, str):
channels = tuple([int(x) for x in config.channels.split(':')])
else: channels = config.channels
genotype = CellStructure.str2structure(config.genotype)
return DynamicShapeTinyNet(channels, genotype, config.num_classes)
elif config.name == 'infer.nasnet-cifar':
from .cell_infers import NASNetonCIFAR
raise NotImplementedError
else:
raise ValueError('invalid network name : {:}'.format(config.name))
# obtain the search space, i.e., the list of candidate operation names for the given search-space name
def get_search_spaces(xtype, name) -> List[Text]:
if xtype == 'cell':
from .cell_operations import SearchSpaceNames
assert name in SearchSpaceNames, 'invalid name [{:}] in {:}'.format(name, SearchSpaceNames.keys())
return SearchSpaceNames[name]
else:
raise ValueError('invalid search-space type is {:}'.format(xtype))
def get_cifar_models(config, extra_path=None):
super_type = getattr(config, 'super_type', 'basic')
if super_type == 'basic':
from .CifarResNet import CifarResNet
from .CifarDenseNet import DenseNet
from .CifarWideResNet import CifarWideResNet
if config.arch == 'resnet':
return CifarResNet(config.module, config.depth, config.class_num, config.zero_init_residual)
elif config.arch == 'densenet':
return DenseNet(config.growthRate, config.depth, config.reduction, config.class_num, config.bottleneck)
elif config.arch == 'wideresnet':
return CifarWideResNet(config.depth, config.wide_factor, config.class_num, config.dropout)
else:
raise ValueError('invalid module type : {:}'.format(config.arch))
elif super_type.startswith('infer'):
from .shape_infers import InferWidthCifarResNet
from .shape_infers import InferDepthCifarResNet
from .shape_infers import InferCifarResNet
from .cell_infers import NASNetonCIFAR
assert len(super_type.split('-')) == 2, 'invalid super_type : {:}'.format(super_type)
infer_mode = super_type.split('-')[1]
if infer_mode == 'width':
return InferWidthCifarResNet(config.module, config.depth, config.xchannels, config.class_num, config.zero_init_residual)
elif infer_mode == 'depth':
return InferDepthCifarResNet(config.module, config.depth, config.xblocks, config.class_num, config.zero_init_residual)
elif infer_mode == 'shape':
return InferCifarResNet(config.module, config.depth, config.xblocks, config.xchannels, config.class_num, config.zero_init_residual)
elif infer_mode == 'nasnet.cifar':
genotype = config.genotype
if extra_path is not None: # reload genotype by extra_path
if not osp.isfile(extra_path): raise ValueError('invalid extra_path : {:}'.format(extra_path))
xdata = torch.load(extra_path)
current_epoch = xdata['epoch']
genotype = xdata['genotypes'][current_epoch-1]
C = config.C if hasattr(config, 'C') else config.ichannel
N = config.N if hasattr(config, 'N') else config.layers
return NASNetonCIFAR(C, N, config.stem_multi, config.class_num, genotype, config.auxiliary)
else:
raise ValueError('invalid infer-mode : {:}'.format(infer_mode))
else:
raise ValueError('invalid super-type : {:}'.format(super_type))
def get_imagenet_models(config):
super_type = getattr(config, 'super_type', 'basic')
if super_type == 'basic':
from .ImageNet_ResNet import ResNet
from .ImageNet_MobileNetV2 import MobileNetV2
if config.arch == 'resnet':
return ResNet(config.block_name, config.layers, config.deep_stem, config.class_num, config.zero_init_residual, config.groups, config.width_per_group)
elif config.arch == 'mobilenet_v2':
return MobileNetV2(config.class_num, config.width_multi, config.input_channel, config.last_channel, 'InvertedResidual', config.dropout)
else:
raise ValueError('invalid arch : {:}'.format( config.arch ))
elif super_type.startswith('infer'): # NAS searched architecture
assert len(super_type.split('-')) == 2, 'invalid super_type : {:}'.format(super_type)
infer_mode = super_type.split('-')[1]
if infer_mode == 'shape':
from .shape_infers import InferImagenetResNet
from .shape_infers import InferMobileNetV2
if config.arch == 'resnet':
return InferImagenetResNet(config.block_name, config.layers, config.xblocks, config.xchannels, config.deep_stem, config.class_num, config.zero_init_residual)
elif config.arch == "MobileNetV2":
return InferMobileNetV2(config.class_num, config.xchannels, config.xblocks, config.dropout)
else:
raise ValueError('invalid arch-mode : {:}'.format(config.arch))
else:
raise ValueError('invalid infer-mode : {:}'.format(infer_mode))
else:
raise ValueError('invalid super-type : {:}'.format(super_type))
# Try to obtain the network by config.
def obtain_model(config, extra_path=None):
if config.dataset == 'cifar':
return get_cifar_models(config, extra_path)
elif config.dataset == 'imagenet':
return get_imagenet_models(config)
else:
raise ValueError('invalid dataset in the model config : {:}'.format(config))
def obtain_search_model(config):
if config.dataset == 'cifar':
if config.arch == 'resnet':
from .shape_searchs import SearchWidthCifarResNet
from .shape_searchs import SearchDepthCifarResNet
from .shape_searchs import SearchShapeCifarResNet
if config.search_mode == 'width':
return SearchWidthCifarResNet(config.module, config.depth, config.class_num)
elif config.search_mode == 'depth':
return SearchDepthCifarResNet(config.module, config.depth, config.class_num)
elif config.search_mode == 'shape':
return SearchShapeCifarResNet(config.module, config.depth, config.class_num)
else: raise ValueError('invalid search mode : {:}'.format(config.search_mode))
elif config.arch == 'simres':
from .shape_searchs import SearchWidthSimResNet
if config.search_mode == 'width':
return SearchWidthSimResNet(config.depth, config.class_num)
else: raise ValueError('invalid search mode : {:}'.format(config.search_mode))
else:
raise ValueError('invalid arch : {:} for dataset [{:}]'.format(config.arch, config.dataset))
elif config.dataset == 'imagenet':
from .shape_searchs import SearchShapeImagenetResNet
assert config.search_mode == 'shape', 'invalid search-mode : {:}'.format( config.search_mode )
if config.arch == 'resnet':
return SearchShapeImagenetResNet(config.block_name, config.layers, config.deep_stem, config.class_num)
else:
raise ValueError('invalid model config : {:}'.format(config))
else:
raise ValueError('invalid dataset in the model config : {:}'.format(config))
def load_net_from_checkpoint(checkpoint):
assert osp.isfile(checkpoint), 'checkpoint {:} does not exist'.format(checkpoint)
checkpoint = torch.load(checkpoint)
model_config = dict2config(checkpoint['model-config'], None)
model = obtain_model(model_config)
model.load_state_dict(checkpoint['base-model'])
return model
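# ----------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original file). The config keys
# mirror the attributes read by get_cell_based_tiny_net above; the concrete
# values and the import path are placeholders.
#
#   import torch
#   from models import get_cell_based_tiny_net
#   config  = {'name': 'infer.tiny', 'C': 16, 'N': 5, 'num_classes': 10,
#              'arch_str': '|nor_conv_3x3~0|+|skip_connect~0|nor_conv_3x3~1|+|skip_connect~0|none~1|nor_conv_3x3~2|'}
#   network = get_cell_based_tiny_net(config)             # the dict is converted via dict2config
#   features, logits = network(torch.randn(2, 3, 32, 32))
# ----------------------------------------------------------------------------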

View File

@ -0,0 +1,5 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
#####################################################
from .tiny_network import TinyNetwork
from .nasnet_cifar import NASNetonCIFAR

120
models/cell_infers/cells.py Normal file
View File

@ -0,0 +1,120 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
#####################################################
import torch
import torch.nn as nn
from copy import deepcopy
from ..cell_operations import OPS
# Cell for NAS-Bench-201
class InferCell(nn.Module):
def __init__(self, genotype, C_in, C_out, stride):
super(InferCell, self).__init__()
self.layers = nn.ModuleList()
self.node_IN = []
self.node_IX = []
self.genotype = deepcopy(genotype)
for i in range(1, len(genotype)):
node_info = genotype[i-1]
cur_index = []
cur_innod = []
for (op_name, op_in) in node_info:
if op_in == 0:
layer = OPS[op_name](C_in , C_out, stride, True, True)
else:
layer = OPS[op_name](C_out, C_out, 1, True, True)
cur_index.append( len(self.layers) )
cur_innod.append( op_in )
self.layers.append( layer )
self.node_IX.append( cur_index )
self.node_IN.append( cur_innod )
self.nodes = len(genotype)
self.in_dim = C_in
self.out_dim = C_out
def extra_repr(self):
string = 'info :: nodes={nodes}, inC={in_dim}, outC={out_dim}'.format(**self.__dict__)
laystr = []
for i, (node_layers, node_innods) in enumerate(zip(self.node_IX,self.node_IN)):
y = ['I{:}-L{:}'.format(_ii, _il) for _il, _ii in zip(node_layers, node_innods)]
x = '{:}<-({:})'.format(i+1, ','.join(y))
laystr.append( x )
return string + ', [{:}]'.format( ' | '.join(laystr) ) + ', {:}'.format(self.genotype.tostr())
def forward(self, inputs):
nodes = [inputs]
for i, (node_layers, node_innods) in enumerate(zip(self.node_IX,self.node_IN)):
node_feature = sum( self.layers[_il](nodes[_ii]) for _il, _ii in zip(node_layers, node_innods) )
nodes.append( node_feature )
return nodes[-1]
# Learning Transferable Architectures for Scalable Image Recognition, CVPR 2018
class NASNetInferCell(nn.Module):
def __init__(self, genotype, C_prev_prev, C_prev, C, reduction, reduction_prev, affine, track_running_stats):
super(NASNetInferCell, self).__init__()
self.reduction = reduction
if reduction_prev: self.preprocess0 = OPS['skip_connect'](C_prev_prev, C, 2, affine, track_running_stats)
else : self.preprocess0 = OPS['nor_conv_1x1'](C_prev_prev, C, 1, affine, track_running_stats)
self.preprocess1 = OPS['nor_conv_1x1'](C_prev, C, 1, affine, track_running_stats)
if not reduction:
nodes, concats = genotype['normal'], genotype['normal_concat']
else:
nodes, concats = genotype['reduce'], genotype['reduce_concat']
self._multiplier = len(concats)
self._concats = concats
self._steps = len(nodes)
self._nodes = nodes
self.edges = nn.ModuleDict()
for i, node in enumerate(nodes):
for in_node in node:
name, j = in_node[0], in_node[1]
stride = 2 if reduction and j < 2 else 1
node_str = '{:}<-{:}'.format(i+2, j)
self.edges[node_str] = OPS[name](C, C, stride, affine, track_running_stats)
# [TODO] to support drop_prob in this function..
def forward(self, s0, s1, unused_drop_prob):
s0 = self.preprocess0(s0)
s1 = self.preprocess1(s1)
states = [s0, s1]
for i, node in enumerate(self._nodes):
clist = []
for in_node in node:
name, j = in_node[0], in_node[1]
node_str = '{:}<-{:}'.format(i+2, j)
op = self.edges[ node_str ]
clist.append( op(states[j]) )
states.append( sum(clist) )
return torch.cat([states[x] for x in self._concats], dim=1)
class AuxiliaryHeadCIFAR(nn.Module):
def __init__(self, C, num_classes):
"""assuming input size 8x8"""
super(AuxiliaryHeadCIFAR, self).__init__()
self.features = nn.Sequential(
nn.ReLU(inplace=True),
nn.AvgPool2d(5, stride=3, padding=0, count_include_pad=False), # image size = 2 x 2
nn.Conv2d(C, 128, 1, bias=False),
nn.BatchNorm2d(128),
nn.ReLU(inplace=True),
nn.Conv2d(128, 768, 2, bias=False),
nn.BatchNorm2d(768),
nn.ReLU(inplace=True)
)
self.classifier = nn.Linear(768, num_classes)
def forward(self, x):
x = self.features(x)
x = self.classifier(x.view(x.size(0),-1))
return x

View File

@ -0,0 +1,71 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
#####################################################
import torch
import torch.nn as nn
from copy import deepcopy
from .cells import NASNetInferCell as InferCell, AuxiliaryHeadCIFAR
# The macro structure is based on NASNet
class NASNetonCIFAR(nn.Module):
def __init__(self, C, N, stem_multiplier, num_classes, genotype, auxiliary, affine=True, track_running_stats=True):
super(NASNetonCIFAR, self).__init__()
self._C = C
self._layerN = N
self.stem = nn.Sequential(
nn.Conv2d(3, C*stem_multiplier, kernel_size=3, padding=1, bias=False),
nn.BatchNorm2d(C*stem_multiplier))
# config for each layer
layer_channels = [C ] * N + [C*2 ] + [C*2 ] * (N-1) + [C*4 ] + [C*4 ] * (N-1)
layer_reductions = [False] * N + [True] + [False] * (N-1) + [True] + [False] * (N-1)
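    # e.g., with N=2 the schedule is channels [C, C, 2C, 2C, 4C, 4C] and
    # reductions [F, F, T, F, T, F]: two reduction cells separate three stages
    # of normal cells (illustrative comment, not in the original file)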
C_prev_prev, C_prev, C_curr, reduction_prev = C*stem_multiplier, C*stem_multiplier, C, False
self.auxiliary_index = None
self.auxiliary_head = None
self.cells = nn.ModuleList()
for index, (C_curr, reduction) in enumerate(zip(layer_channels, layer_reductions)):
cell = InferCell(genotype, C_prev_prev, C_prev, C_curr, reduction, reduction_prev, affine, track_running_stats)
self.cells.append( cell )
C_prev_prev, C_prev, reduction_prev = C_prev, cell._multiplier*C_curr, reduction
if reduction and C_curr == C*4 and auxiliary:
self.auxiliary_head = AuxiliaryHeadCIFAR(C_prev, num_classes)
self.auxiliary_index = index
self._Layer = len(self.cells)
self.lastact = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True))
self.global_pooling = nn.AdaptiveAvgPool2d(1)
self.classifier = nn.Linear(C_prev, num_classes)
self.drop_path_prob = -1
def update_drop_path(self, drop_path_prob):
self.drop_path_prob = drop_path_prob
def auxiliary_param(self):
if self.auxiliary_head is None: return []
else: return list( self.auxiliary_head.parameters() )
def get_message(self):
string = self.extra_repr()
for i, cell in enumerate(self.cells):
string += '\n {:02d}/{:02d} :: {:}'.format(i, len(self.cells), cell.extra_repr())
return string
def extra_repr(self):
return ('{name}(C={_C}, N={_layerN}, L={_Layer})'.format(name=self.__class__.__name__, **self.__dict__))
def forward(self, inputs):
stem_feature, logits_aux = self.stem(inputs), None
cell_results = [stem_feature, stem_feature]
for i, cell in enumerate(self.cells):
cell_feature = cell(cell_results[-2], cell_results[-1], self.drop_path_prob)
cell_results.append( cell_feature )
if self.auxiliary_index is not None and i == self.auxiliary_index and self.training:
logits_aux = self.auxiliary_head( cell_results[-1] )
out = self.lastact(cell_results[-1])
out = self.global_pooling( out )
out = out.view(out.size(0), -1)
logits = self.classifier(out)
if logits_aux is None: return out, logits
else: return out, [logits, logits_aux]

View File

@ -0,0 +1,58 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
#####################################################
import torch.nn as nn
from ..cell_operations import ResNetBasicblock
from .cells import InferCell
# The macro structure for architectures in NAS-Bench-201
class TinyNetwork(nn.Module):
def __init__(self, C, N, genotype, num_classes):
super(TinyNetwork, self).__init__()
self._C = C
self._layerN = N
self.stem = nn.Sequential(
nn.Conv2d(3, C, kernel_size=3, padding=1, bias=False),
nn.BatchNorm2d(C))
layer_channels = [C ] * N + [C*2 ] + [C*2 ] * N + [C*4 ] + [C*4 ] * N
layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N
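    # e.g., with N=2 the schedule is channels [C, C, 2C, 2C, 2C, 4C, 4C, 4C] and
    # reductions [F, F, T, F, F, T, F, F]; each True entry becomes a ResNetBasicblock
    # reduction cell below (illustrative comment, not in the original file)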
C_prev = C
self.cells = nn.ModuleList()
for index, (C_curr, reduction) in enumerate(zip(layer_channels, layer_reductions)):
if reduction:
cell = ResNetBasicblock(C_prev, C_curr, 2, True)
else:
cell = InferCell(genotype, C_prev, C_curr, 1)
self.cells.append( cell )
C_prev = cell.out_dim
self._Layer= len(self.cells)
self.lastact = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True))
self.global_pooling = nn.AdaptiveAvgPool2d(1)
self.classifier = nn.Linear(C_prev, num_classes)
def get_message(self):
string = self.extra_repr()
for i, cell in enumerate(self.cells):
string += '\n {:02d}/{:02d} :: {:}'.format(i, len(self.cells), cell.extra_repr())
return string
def extra_repr(self):
return ('{name}(C={_C}, N={_layerN}, L={_Layer})'.format(name=self.__class__.__name__, **self.__dict__))
def forward(self, inputs):
feature = self.stem(inputs)
for i, cell in enumerate(self.cells):
feature = cell(feature)
out = self.lastact(feature)
out = self.global_pooling( out )
out = out.view(out.size(0), -1)
logits = self.classifier(out)
return out, logits

297
models/cell_operations.py Normal file
View File

@ -0,0 +1,297 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
import torch
import torch.nn as nn
__all__ = ['OPS', 'ResNetBasicblock', 'SearchSpaceNames']
OPS = {
'none' : lambda C_in, C_out, stride, affine, track_running_stats: Zero(C_in, C_out, stride),
'avg_pool_3x3' : lambda C_in, C_out, stride, affine, track_running_stats: POOLING(C_in, C_out, stride, 'avg', affine, track_running_stats),
'max_pool_3x3' : lambda C_in, C_out, stride, affine, track_running_stats: POOLING(C_in, C_out, stride, 'max', affine, track_running_stats),
'nor_conv_7x7' : lambda C_in, C_out, stride, affine, track_running_stats: ReLUConvBN(C_in, C_out, (7,7), (stride,stride), (3,3), (1,1), affine, track_running_stats),
'nor_conv_3x3' : lambda C_in, C_out, stride, affine, track_running_stats: ReLUConvBN(C_in, C_out, (3,3), (stride,stride), (1,1), (1,1), affine, track_running_stats),
'nor_conv_1x1' : lambda C_in, C_out, stride, affine, track_running_stats: ReLUConvBN(C_in, C_out, (1,1), (stride,stride), (0,0), (1,1), affine, track_running_stats),
'dua_sepc_3x3' : lambda C_in, C_out, stride, affine, track_running_stats: DualSepConv(C_in, C_out, (3,3), (stride,stride), (1,1), (1,1), affine, track_running_stats),
'dua_sepc_5x5' : lambda C_in, C_out, stride, affine, track_running_stats: DualSepConv(C_in, C_out, (5,5), (stride,stride), (2,2), (1,1), affine, track_running_stats),
'dil_sepc_3x3' : lambda C_in, C_out, stride, affine, track_running_stats: SepConv(C_in, C_out, (3,3), (stride,stride), (2,2), (2,2), affine, track_running_stats),
'dil_sepc_5x5' : lambda C_in, C_out, stride, affine, track_running_stats: SepConv(C_in, C_out, (5,5), (stride,stride), (4,4), (2,2), affine, track_running_stats),
'skip_connect' : lambda C_in, C_out, stride, affine, track_running_stats: Identity() if stride == 1 and C_in == C_out else FactorizedReduce(C_in, C_out, stride, affine, track_running_stats),
}
CONNECT_NAS_BENCHMARK = ['none', 'skip_connect', 'nor_conv_3x3']
NAS_BENCH_201 = ['none', 'skip_connect', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3']
DARTS_SPACE = ['none', 'skip_connect', 'dua_sepc_3x3', 'dua_sepc_5x5', 'dil_sepc_3x3', 'dil_sepc_5x5', 'avg_pool_3x3', 'max_pool_3x3']
SearchSpaceNames = {'connect-nas' : CONNECT_NAS_BENCHMARK,
'nas-bench-201': NAS_BENCH_201,
'darts' : DARTS_SPACE}
class ReLUConvBN(nn.Module):
def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation, affine, track_running_stats=True):
super(ReLUConvBN, self).__init__()
self.op = nn.Sequential(
nn.ReLU(inplace=False),
nn.Conv2d(C_in, C_out, kernel_size, stride=stride, padding=padding, dilation=dilation, bias=False),
nn.BatchNorm2d(C_out, affine=affine, track_running_stats=track_running_stats)
)
def forward(self, x):
return self.op(x)
class SepConv(nn.Module):
def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation, affine, track_running_stats=True):
super(SepConv, self).__init__()
self.op = nn.Sequential(
nn.ReLU(inplace=False),
nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=C_in, bias=False),
nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False),
nn.BatchNorm2d(C_out, affine=affine, track_running_stats=track_running_stats),
)
def forward(self, x):
return self.op(x)
class DualSepConv(nn.Module):
def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation, affine, track_running_stats=True):
super(DualSepConv, self).__init__()
self.op_a = SepConv(C_in, C_in , kernel_size, stride, padding, dilation, affine, track_running_stats)
self.op_b = SepConv(C_in, C_out, kernel_size, 1, padding, dilation, affine, track_running_stats)
def forward(self, x):
x = self.op_a(x)
x = self.op_b(x)
return x
class ResNetBasicblock(nn.Module):
def __init__(self, inplanes, planes, stride, affine=True):
super(ResNetBasicblock, self).__init__()
assert stride == 1 or stride == 2, 'invalid stride {:}'.format(stride)
self.conv_a = ReLUConvBN(inplanes, planes, 3, stride, 1, 1, affine)
self.conv_b = ReLUConvBN( planes, planes, 3, 1, 1, 1, affine)
if stride == 2:
self.downsample = nn.Sequential(
nn.AvgPool2d(kernel_size=2, stride=2, padding=0),
nn.Conv2d(inplanes, planes, kernel_size=1, stride=1, padding=0, bias=False))
elif inplanes != planes:
self.downsample = ReLUConvBN(inplanes, planes, 1, 1, 0, 1, affine)
else:
self.downsample = None
self.in_dim = inplanes
self.out_dim = planes
self.stride = stride
self.num_conv = 2
def extra_repr(self):
string = '{name}(inC={in_dim}, outC={out_dim}, stride={stride})'.format(name=self.__class__.__name__, **self.__dict__)
return string
def forward(self, inputs):
basicblock = self.conv_a(inputs)
basicblock = self.conv_b(basicblock)
if self.downsample is not None:
residual = self.downsample(inputs)
else:
residual = inputs
return residual + basicblock
class POOLING(nn.Module):
def __init__(self, C_in, C_out, stride, mode, affine=True, track_running_stats=True):
super(POOLING, self).__init__()
if C_in == C_out:
self.preprocess = None
else:
self.preprocess = ReLUConvBN(C_in, C_out, 1, 1, 0, 1, affine, track_running_stats)
if mode == 'avg' : self.op = nn.AvgPool2d(3, stride=stride, padding=1, count_include_pad=False)
elif mode == 'max': self.op = nn.MaxPool2d(3, stride=stride, padding=1)
else : raise ValueError('Invalid mode={:} in POOLING'.format(mode))
def forward(self, inputs):
if self.preprocess: x = self.preprocess(inputs)
else : x = inputs
return self.op(x)
class Identity(nn.Module):
def __init__(self):
super(Identity, self).__init__()
def forward(self, x):
return x
class Zero(nn.Module):
def __init__(self, C_in, C_out, stride):
super(Zero, self).__init__()
self.C_in = C_in
self.C_out = C_out
self.stride = stride
self.is_zero = True
def forward(self, x):
if self.C_in == self.C_out:
if self.stride == 1: return x.mul(0.)
else : return x[:,:,::self.stride,::self.stride].mul(0.)
else:
shape = list(x.shape)
shape[1] = self.C_out
zeros = x.new_zeros(shape, dtype=x.dtype, device=x.device)
return zeros
def extra_repr(self):
return 'C_in={C_in}, C_out={C_out}, stride={stride}'.format(**self.__dict__)
class FactorizedReduce(nn.Module):
def __init__(self, C_in, C_out, stride, affine, track_running_stats):
super(FactorizedReduce, self).__init__()
self.stride = stride
self.C_in = C_in
self.C_out = C_out
self.relu = nn.ReLU(inplace=False)
if stride == 2:
#assert C_out % 2 == 0, 'C_out : {:}'.format(C_out)
C_outs = [C_out // 2, C_out - C_out // 2]
self.convs = nn.ModuleList()
for i in range(2):
self.convs.append( nn.Conv2d(C_in, C_outs[i], 1, stride=stride, padding=0, bias=False) )
self.pad = nn.ConstantPad2d((0, 1, 0, 1), 0)
elif stride == 1:
self.conv = nn.Conv2d(C_in, C_out, 1, stride=stride, padding=0, bias=False)
else:
raise ValueError('Invalid stride : {:}'.format(stride))
self.bn = nn.BatchNorm2d(C_out, affine=affine, track_running_stats=track_running_stats)
def forward(self, x):
if self.stride == 2:
x = self.relu(x)
y = self.pad(x)
out = torch.cat([self.convs[0](x), self.convs[1](y[:,:,1:,1:])], dim=1)
else:
out = self.conv(x)
out = self.bn(out)
return out
def extra_repr(self):
return 'C_in={C_in}, C_out={C_out}, stride={stride}'.format(**self.__dict__)
# Auto-ReID: Searching for a Part-Aware ConvNet for Person Re-Identification, ICCV 2019
class PartAwareOp(nn.Module):
def __init__(self, C_in, C_out, stride, part=4):
super().__init__()
    self.part = part
self.hidden = C_in // 3
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.local_conv_list = nn.ModuleList()
for i in range(self.part):
self.local_conv_list.append(
nn.Sequential(nn.ReLU(), nn.Conv2d(C_in, self.hidden, 1), nn.BatchNorm2d(self.hidden, affine=True))
)
self.W_K = nn.Linear(self.hidden, self.hidden)
self.W_Q = nn.Linear(self.hidden, self.hidden)
    if stride == 2 : self.last = FactorizedReduce(C_in + self.hidden, C_out, 2, True, True)
    elif stride == 1: self.last = FactorizedReduce(C_in + self.hidden, C_out, 1, True, True)
else: raise ValueError('Invalid Stride : {:}'.format(stride))
def forward(self, x):
batch, C, H, W = x.size()
assert H >= self.part, 'input size too small : {:} vs {:}'.format(x.shape, self.part)
IHs = [0]
for i in range(self.part): IHs.append( min(H, int((i+1)*(float(H)/self.part))) )
local_feat_list = []
for i in range(self.part):
feature = x[:, :, IHs[i]:IHs[i+1], :]
xfeax = self.avg_pool(feature)
xfea = self.local_conv_list[i]( xfeax )
local_feat_list.append( xfea )
part_feature = torch.cat(local_feat_list, dim=2).view(batch, -1, self.part)
part_feature = part_feature.transpose(1,2).contiguous()
part_K = self.W_K(part_feature)
part_Q = self.W_Q(part_feature).transpose(1,2).contiguous()
weight_att = torch.bmm(part_K, part_Q)
attention = torch.softmax(weight_att, dim=2)
aggreateF = torch.bmm(attention, part_feature).transpose(1,2).contiguous()
features = []
for i in range(self.part):
feature = aggreateF[:, :, i:i+1].expand(batch, self.hidden, IHs[i+1]-IHs[i])
feature = feature.view(batch, self.hidden, IHs[i+1]-IHs[i], 1)
features.append( feature )
features = torch.cat(features, dim=2).expand(batch, self.hidden, H, W)
final_fea = torch.cat((x,features), dim=1)
outputs = self.last( final_fea )
return outputs
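# NOTE: GDAS_Reduction_Cell.forward below calls drop_path, which is neither defined
# nor imported in this file; the helper below is an assumed standard DARTS-style
# implementation, added only so the training-time branch is runnable.
def drop_path(x, drop_prob):
  if drop_prob > 0.:
    keep_prob = 1. - drop_prob
    mask = x.new_empty(x.size(0), 1, 1, 1).bernoulli_(keep_prob)
    x = x.div(keep_prob).mul(mask)
  return x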
# Searching for A Robust Neural Architecture in Four GPU Hours
class GDAS_Reduction_Cell(nn.Module):
def __init__(self, C_prev_prev, C_prev, C, reduction_prev, multiplier, affine, track_running_stats):
super(GDAS_Reduction_Cell, self).__init__()
if reduction_prev:
self.preprocess0 = FactorizedReduce(C_prev_prev, C, 2, affine, track_running_stats)
else:
self.preprocess0 = ReLUConvBN(C_prev_prev, C, 1, 1, 0, 1, affine, track_running_stats)
self.preprocess1 = ReLUConvBN(C_prev, C, 1, 1, 0, 1, affine, track_running_stats)
self.multiplier = multiplier
self.reduction = True
self.ops1 = nn.ModuleList(
[nn.Sequential(
nn.ReLU(inplace=False),
nn.Conv2d(C, C, (1, 3), stride=(1, 2), padding=(0, 1), groups=8, bias=False),
nn.Conv2d(C, C, (3, 1), stride=(2, 1), padding=(1, 0), groups=8, bias=False),
nn.BatchNorm2d(C, affine=True),
nn.ReLU(inplace=False),
nn.Conv2d(C, C, 1, stride=1, padding=0, bias=False),
nn.BatchNorm2d(C, affine=True)),
nn.Sequential(
nn.ReLU(inplace=False),
nn.Conv2d(C, C, (1, 3), stride=(1, 2), padding=(0, 1), groups=8, bias=False),
nn.Conv2d(C, C, (3, 1), stride=(2, 1), padding=(1, 0), groups=8, bias=False),
nn.BatchNorm2d(C, affine=True),
nn.ReLU(inplace=False),
nn.Conv2d(C, C, 1, stride=1, padding=0, bias=False),
nn.BatchNorm2d(C, affine=True))])
self.ops2 = nn.ModuleList(
[nn.Sequential(
nn.MaxPool2d(3, stride=1, padding=1),
nn.BatchNorm2d(C, affine=True)),
nn.Sequential(
nn.MaxPool2d(3, stride=2, padding=1),
nn.BatchNorm2d(C, affine=True))])
def forward(self, s0, s1, drop_prob = -1):
s0 = self.preprocess0(s0)
s1 = self.preprocess1(s1)
X0 = self.ops1[0] (s0)
X1 = self.ops1[1] (s1)
if self.training and drop_prob > 0.:
X0, X1 = drop_path(X0, drop_prob), drop_path(X1, drop_prob)
#X2 = self.ops2[0] (X0+X1)
X2 = self.ops2[0] (s0)
X3 = self.ops2[1] (s1)
if self.training and drop_prob > 0.:
X2, X3 = drop_path(X2, drop_prob), drop_path(X3, drop_prob)
return torch.cat([X0, X1, X2, X3], dim=1)
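if __name__ == '__main__':
  # Minimal smoke test (not part of the original file): every entry in OPS is a
  # factory taking (C_in, C_out, stride, affine, track_running_stats) and
  # returning an nn.Module; with stride 1 and C_in == C_out the spatial size
  # and channel count are preserved for all NAS-Bench-201 operations.
  x = torch.randn(2, 16, 32, 32)
  for name in NAS_BENCH_201:
    op = OPS[name](16, 16, 1, True, True)
    assert op(x).shape == (2, 16, 32, 32), '{:} changed the feature shape'.format(name)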

View File

@ -0,0 +1,24 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
# The macro structure is defined in NAS-Bench-201
from .search_model_darts import TinyNetworkDarts
from .search_model_gdas import TinyNetworkGDAS
from .search_model_setn import TinyNetworkSETN
from .search_model_enas import TinyNetworkENAS
from .search_model_random import TinyNetworkRANDOM
from .genotypes import Structure as CellStructure, architectures as CellArchitectures
# NASNet-based macro structure
from .search_model_gdas_nasnet import NASNetworkGDAS
from .search_model_darts_nasnet import NASNetworkDARTS
nas201_super_nets = {'DARTS-V1': TinyNetworkDarts,
"DARTS-V2": TinyNetworkDarts,
"GDAS": TinyNetworkGDAS,
"SETN": TinyNetworkSETN,
"ENAS": TinyNetworkENAS,
"RANDOM": TinyNetworkRANDOM}
nasnet_super_nets = {"GDAS": NASNetworkGDAS,
"DARTS": NASNetworkDARTS}

View File

@ -0,0 +1,12 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
import torch
from search_model_enas_utils import Controller
def main():
controller = Controller(6, 4)
predictions = controller()
if __name__ == '__main__':
main()

View File

@ -0,0 +1,199 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
from copy import deepcopy
def get_combination(space, num):
combs = []
for i in range(num):
if i == 0:
for func in space:
combs.append( [(func, i)] )
else:
new_combs = []
for string in combs:
for func in space:
xstring = string + [(func, i)]
new_combs.append( xstring )
combs = new_combs
return combs
class Structure:
def __init__(self, genotype):
assert isinstance(genotype, list) or isinstance(genotype, tuple), 'invalid class of genotype : {:}'.format(type(genotype))
self.node_num = len(genotype) + 1
self.nodes = []
self.node_N = []
for idx, node_info in enumerate(genotype):
assert isinstance(node_info, list) or isinstance(node_info, tuple), 'invalid class of node_info : {:}'.format(type(node_info))
assert len(node_info) >= 1, 'invalid length : {:}'.format(len(node_info))
for node_in in node_info:
assert isinstance(node_in, list) or isinstance(node_in, tuple), 'invalid class of in-node : {:}'.format(type(node_in))
assert len(node_in) == 2 and node_in[1] <= idx, 'invalid in-node : {:}'.format(node_in)
self.node_N.append( len(node_info) )
self.nodes.append( tuple(deepcopy(node_info)) )
def tolist(self, remove_str):
    # convert this class to a list; if remove_str is given (e.g., 'none'), remove that operation
    # note that we re-order the input nodes in this function
    # return the genotype list and a success flag (False means the genotype has no valid connectivity)
genotypes = []
for node_info in self.nodes:
node_info = list( node_info )
node_info = sorted(node_info, key=lambda x: (x[1], x[0]))
node_info = tuple(filter(lambda x: x[0] != remove_str, node_info))
if len(node_info) == 0: return None, False
genotypes.append( node_info )
return genotypes, True
def node(self, index):
assert index > 0 and index <= len(self), 'invalid index={:} < {:}'.format(index, len(self))
return self.nodes[index]
def tostr(self):
strings = []
for node_info in self.nodes:
string = '|'.join([x[0]+'~{:}'.format(x[1]) for x in node_info])
string = '|{:}|'.format(string)
strings.append( string )
return '+'.join(strings)
def check_valid(self):
nodes = {0: True}
for i, node_info in enumerate(self.nodes):
sums = []
for op, xin in node_info:
if op == 'none' or nodes[xin] is False: x = False
else: x = True
sums.append( x )
nodes[i+1] = sum(sums) > 0
return nodes[len(self.nodes)]
def to_unique_str(self, consider_zero=False):
    # this is used to identify isomorphic cells, which requires prior knowledge of the operations
# two operations are special, i.e., none and skip_connect
nodes = {0: '0'}
for i_node, node_info in enumerate(self.nodes):
cur_node = []
for op, xin in node_info:
if consider_zero is None:
x = '('+nodes[xin]+')' + '@{:}'.format(op)
elif consider_zero:
if op == 'none' or nodes[xin] == '#': x = '#' # zero
elif op == 'skip_connect': x = nodes[xin]
else: x = '('+nodes[xin]+')' + '@{:}'.format(op)
else:
if op == 'skip_connect': x = nodes[xin]
else: x = '('+nodes[xin]+')' + '@{:}'.format(op)
cur_node.append(x)
nodes[i_node+1] = '+'.join( sorted(cur_node) )
return nodes[ len(self.nodes) ]
def check_valid_op(self, op_names):
for node_info in self.nodes:
for inode_edge in node_info:
#assert inode_edge[0] in op_names, 'invalid op-name : {:}'.format(inode_edge[0])
if inode_edge[0] not in op_names: return False
return True
def __repr__(self):
return ('{name}({node_num} nodes with {node_info})'.format(name=self.__class__.__name__, node_info=self.tostr(), **self.__dict__))
def __len__(self):
return len(self.nodes) + 1
def __getitem__(self, index):
return self.nodes[index]
@staticmethod
def str2structure(xstr):
assert isinstance(xstr, str), 'must take string (not {:}) as input'.format(type(xstr))
nodestrs = xstr.split('+')
genotypes = []
for i, node_str in enumerate(nodestrs):
inputs = list(filter(lambda x: x != '', node_str.split('|')))
for xinput in inputs: assert len(xinput.split('~')) == 2, 'invalid input length : {:}'.format(xinput)
inputs = ( xi.split('~') for xi in inputs )
input_infos = tuple( (op, int(IDX)) for (op, IDX) in inputs)
genotypes.append( input_infos )
return Structure( genotypes )
@staticmethod
def str2fullstructure(xstr, default_name='none'):
assert isinstance(xstr, str), 'must take string (not {:}) as input'.format(type(xstr))
nodestrs = xstr.split('+')
genotypes = []
for i, node_str in enumerate(nodestrs):
inputs = list(filter(lambda x: x != '', node_str.split('|')))
for xinput in inputs: assert len(xinput.split('~')) == 2, 'invalid input length : {:}'.format(xinput)
inputs = ( xi.split('~') for xi in inputs )
input_infos = list( (op, int(IDX)) for (op, IDX) in inputs)
all_in_nodes= list(x[1] for x in input_infos)
for j in range(i):
if j not in all_in_nodes: input_infos.append((default_name, j))
node_info = sorted(input_infos, key=lambda x: (x[1], x[0]))
genotypes.append( tuple(node_info) )
return Structure( genotypes )
@staticmethod
def gen_all(search_space, num, return_ori):
assert isinstance(search_space, list) or isinstance(search_space, tuple), 'invalid class of search-space : {:}'.format(type(search_space))
assert num >= 2, 'There should be at least two nodes in a neural cell instead of {:}'.format(num)
all_archs = get_combination(search_space, 1)
for i, arch in enumerate(all_archs):
all_archs[i] = [ tuple(arch) ]
for inode in range(2, num):
cur_nodes = get_combination(search_space, inode)
new_all_archs = []
for previous_arch in all_archs:
for cur_node in cur_nodes:
new_all_archs.append( previous_arch + [tuple(cur_node)] )
all_archs = new_all_archs
if return_ori:
return all_archs
else:
return [Structure(x) for x in all_archs]
ResNet_CODE = Structure(
[(('nor_conv_3x3', 0), ), # node-1
(('nor_conv_3x3', 1), ), # node-2
(('skip_connect', 0), ('skip_connect', 2))] # node-3
)
AllConv3x3_CODE = Structure(
[(('nor_conv_3x3', 0), ), # node-1
(('nor_conv_3x3', 0), ('nor_conv_3x3', 1)), # node-2
(('nor_conv_3x3', 0), ('nor_conv_3x3', 1), ('nor_conv_3x3', 2))] # node-3
)
AllFull_CODE = Structure(
[(('skip_connect', 0), ('nor_conv_1x1', 0), ('nor_conv_3x3', 0), ('avg_pool_3x3', 0)), # node-1
(('skip_connect', 0), ('nor_conv_1x1', 0), ('nor_conv_3x3', 0), ('avg_pool_3x3', 0), ('skip_connect', 1), ('nor_conv_1x1', 1), ('nor_conv_3x3', 1), ('avg_pool_3x3', 1)), # node-2
(('skip_connect', 0), ('nor_conv_1x1', 0), ('nor_conv_3x3', 0), ('avg_pool_3x3', 0), ('skip_connect', 1), ('nor_conv_1x1', 1), ('nor_conv_3x3', 1), ('avg_pool_3x3', 1), ('skip_connect', 2), ('nor_conv_1x1', 2), ('nor_conv_3x3', 2), ('avg_pool_3x3', 2))] # node-3
)
AllConv1x1_CODE = Structure(
[(('nor_conv_1x1', 0), ), # node-1
(('nor_conv_1x1', 0), ('nor_conv_1x1', 1)), # node-2
(('nor_conv_1x1', 0), ('nor_conv_1x1', 1), ('nor_conv_1x1', 2))] # node-3
)
AllIdentity_CODE = Structure(
[(('skip_connect', 0), ), # node-1
(('skip_connect', 0), ('skip_connect', 1)), # node-2
(('skip_connect', 0), ('skip_connect', 1), ('skip_connect', 2))] # node-3
)
architectures = {'resnet' : ResNet_CODE,
'all_c3x3': AllConv3x3_CODE,
'all_c1x1': AllConv1x1_CODE,
'all_idnt': AllIdentity_CODE,
'all_full': AllFull_CODE}
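if __name__ == '__main__':
  # Minimal smoke test (not part of the original file): NAS-Bench-201 architecture
  # strings have one '|op~input|...' group per node, joined by '+'; str2structure
  # and tostr are exact inverses for well-formed strings.
  xstr = '|nor_conv_3x3~0|+|skip_connect~0|nor_conv_1x1~1|+|skip_connect~0|none~1|nor_conv_3x3~2|'
  arch = Structure.str2structure(xstr)
  assert arch.tostr() == xstr and len(arch) == 4 and arch.check_valid()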

View File

@ -0,0 +1,197 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
import math, random, torch
import warnings
import torch.nn as nn
import torch.nn.functional as F
from copy import deepcopy
from ..cell_operations import OPS
# This module is used for NAS-Bench-201; it represents a small search space with a complete DAG
class NAS201SearchCell(nn.Module):
def __init__(self, C_in, C_out, stride, max_nodes, op_names, affine=False, track_running_stats=True):
super(NAS201SearchCell, self).__init__()
self.op_names = deepcopy(op_names)
self.edges = nn.ModuleDict()
self.max_nodes = max_nodes
self.in_dim = C_in
self.out_dim = C_out
for i in range(1, max_nodes):
for j in range(i):
node_str = '{:}<-{:}'.format(i, j)
if j == 0:
xlists = [OPS[op_name](C_in , C_out, stride, affine, track_running_stats) for op_name in op_names]
else:
xlists = [OPS[op_name](C_in , C_out, 1, affine, track_running_stats) for op_name in op_names]
self.edges[ node_str ] = nn.ModuleList( xlists )
self.edge_keys = sorted(list(self.edges.keys()))
self.edge2index = {key:i for i, key in enumerate(self.edge_keys)}
self.num_edges = len(self.edges)
def extra_repr(self):
string = 'info :: {max_nodes} nodes, inC={in_dim}, outC={out_dim}'.format(**self.__dict__)
return string
def forward(self, inputs, weightss):
nodes = [inputs]
for i in range(1, self.max_nodes):
inter_nodes = []
for j in range(i):
node_str = '{:}<-{:}'.format(i, j)
weights = weightss[ self.edge2index[node_str] ]
inter_nodes.append( sum( layer(nodes[j]) * w for layer, w in zip(self.edges[node_str], weights) ) )
nodes.append( sum(inter_nodes) )
return nodes[-1]
# GDAS
def forward_gdas(self, inputs, hardwts, index):
nodes = [inputs]
for i in range(1, self.max_nodes):
inter_nodes = []
for j in range(i):
node_str = '{:}<-{:}'.format(i, j)
weights = hardwts[ self.edge2index[node_str] ]
argmaxs = index[ self.edge2index[node_str] ].item()
weigsum = sum( weights[_ie] * edge(nodes[j]) if _ie == argmaxs else weights[_ie] for _ie, edge in enumerate(self.edges[node_str]) )
inter_nodes.append( weigsum )
nodes.append( sum(inter_nodes) )
return nodes[-1]
# joint
def forward_joint(self, inputs, weightss):
nodes = [inputs]
for i in range(1, self.max_nodes):
inter_nodes = []
for j in range(i):
node_str = '{:}<-{:}'.format(i, j)
weights = weightss[ self.edge2index[node_str] ]
#aggregation = sum( layer(nodes[j]) * w for layer, w in zip(self.edges[node_str], weights) ) / weights.numel()
aggregation = sum( layer(nodes[j]) * w for layer, w in zip(self.edges[node_str], weights) )
inter_nodes.append( aggregation )
nodes.append( sum(inter_nodes) )
return nodes[-1]
# uniform random sampling per iteration, SETN
def forward_urs(self, inputs):
nodes = [inputs]
for i in range(1, self.max_nodes):
while True: # to avoid select zero for all ops
sops, has_non_zero = [], False
for j in range(i):
node_str = '{:}<-{:}'.format(i, j)
candidates = self.edges[node_str]
select_op = random.choice(candidates)
sops.append( select_op )
if not hasattr(select_op, 'is_zero') or select_op.is_zero is False: has_non_zero=True
if has_non_zero: break
inter_nodes = []
for j, select_op in enumerate(sops):
inter_nodes.append( select_op(nodes[j]) )
nodes.append( sum(inter_nodes) )
return nodes[-1]
# select the argmax
def forward_select(self, inputs, weightss):
nodes = [inputs]
for i in range(1, self.max_nodes):
inter_nodes = []
for j in range(i):
node_str = '{:}<-{:}'.format(i, j)
weights = weightss[ self.edge2index[node_str] ]
inter_nodes.append( self.edges[node_str][ weights.argmax().item() ]( nodes[j] ) )
#inter_nodes.append( sum( layer(nodes[j]) * w for layer, w in zip(self.edges[node_str], weights) ) )
nodes.append( sum(inter_nodes) )
return nodes[-1]
# forward with a specific structure
def forward_dynamic(self, inputs, structure):
nodes = [inputs]
for i in range(1, self.max_nodes):
cur_op_node = structure.nodes[i-1]
inter_nodes = []
for op_name, j in cur_op_node:
node_str = '{:}<-{:}'.format(i, j)
op_index = self.op_names.index( op_name )
inter_nodes.append( self.edges[node_str][op_index]( nodes[j] ) )
nodes.append( sum(inter_nodes) )
return nodes[-1]
class MixedOp(nn.Module):
def __init__(self, space, C, stride, affine, track_running_stats):
super(MixedOp, self).__init__()
self._ops = nn.ModuleList()
for primitive in space:
op = OPS[primitive](C, C, stride, affine, track_running_stats)
self._ops.append(op)
def forward_gdas(self, x, weights, index):
return self._ops[index](x) * weights[index]
def forward_darts(self, x, weights):
return sum(w * op(x) for w, op in zip(weights, self._ops))
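# -----------------------------------------------------------------------------
# Illustrative sketch (not part of the original file): forward_darts mixes every
# candidate op with its (softmaxed) architecture weight, while forward_gdas only
# executes the op selected by the sampled index.
#
#   space = ['none', 'skip_connect', 'nor_conv_3x3']        # any subset of OPS keys
#   mixed = MixedOp(space, C=16, stride=1, affine=False, track_running_stats=True)
#   x     = torch.randn(2, 16, 32, 32)
#   alpha = torch.softmax(torch.randn(len(space)), dim=0)   # one weight per op
#   y_darts = mixed.forward_darts(x, alpha)
#   y_gdas  = mixed.forward_gdas(x, alpha, alpha.argmax().item())
# -----------------------------------------------------------------------------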
# Learning Transferable Architectures for Scalable Image Recognition, CVPR 2018
class NASNetSearchCell(nn.Module):
def __init__(self, space, steps, multiplier, C_prev_prev, C_prev, C, reduction, reduction_prev, affine, track_running_stats):
super(NASNetSearchCell, self).__init__()
self.reduction = reduction
self.op_names = deepcopy(space)
if reduction_prev: self.preprocess0 = OPS['skip_connect'](C_prev_prev, C, 2, affine, track_running_stats)
else : self.preprocess0 = OPS['nor_conv_1x1'](C_prev_prev, C, 1, affine, track_running_stats)
self.preprocess1 = OPS['nor_conv_1x1'](C_prev, C, 1, affine, track_running_stats)
self._steps = steps
self._multiplier = multiplier
self._ops = nn.ModuleList()
self.edges = nn.ModuleDict()
for i in range(self._steps):
for j in range(2+i):
node_str = '{:}<-{:}'.format(i, j) # indicate the edge from node-(j) to node-(i+2)
stride = 2 if reduction and j < 2 else 1
op = MixedOp(space, C, stride, affine, track_running_stats)
self.edges[ node_str ] = op
self.edge_keys = sorted(list(self.edges.keys()))
self.edge2index = {key:i for i, key in enumerate(self.edge_keys)}
self.num_edges = len(self.edges)
def forward_gdas(self, s0, s1, weightss, indexs):
s0 = self.preprocess0(s0)
s1 = self.preprocess1(s1)
states = [s0, s1]
for i in range(self._steps):
clist = []
for j, h in enumerate(states):
node_str = '{:}<-{:}'.format(i, j)
op = self.edges[ node_str ]
weights = weightss[ self.edge2index[node_str] ]
index = indexs[ self.edge2index[node_str] ].item()
clist.append( op.forward_gdas(h, weights, index) )
states.append( sum(clist) )
return torch.cat(states[-self._multiplier:], dim=1)
def forward_darts(self, s0, s1, weightss):
s0 = self.preprocess0(s0)
s1 = self.preprocess1(s1)
states = [s0, s1]
for i in range(self._steps):
clist = []
for j, h in enumerate(states):
node_str = '{:}<-{:}'.format(i, j)
op = self.edges[ node_str ]
weights = weightss[ self.edge2index[node_str] ]
clist.append( op.forward_darts(h, weights) )
states.append( sum(clist) )
return torch.cat(states[-self._multiplier:], dim=1)

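The different `forward_*` variants above only differ in how each edge combines its candidate operations. A minimal sketch of that difference, using toy ops and made-up shapes (none of the concrete ops or sizes below are taken from the repository):

```python
import torch
import torch.nn as nn

# Three stand-in candidate ops for one edge (the names in comments are only
# illustrative analogues of the real search-space primitives).
candidate_ops = nn.ModuleList([
    nn.Identity(),                              # ~ 'skip_connect'
    nn.Conv2d(8, 8, 3, padding=1, bias=False),  # ~ 'nor_conv_3x3'
    nn.AvgPool2d(3, stride=1, padding=1),       # ~ 'avg_pool_3x3'
])

x = torch.randn(2, 8, 16, 16)
alphas = torch.softmax(torch.randn(len(candidate_ops)), dim=0)

# forward_darts-style: evaluate every candidate and blend by its weight.
mixed = sum(w * op(x) for w, op in zip(alphas, candidate_ops))

# forward_select / forward_gdas-style: commit to a single op per edge.
chosen = candidate_ops[alphas.argmax().item()](x)

print(mixed.shape, chosen.shape)  # both torch.Size([2, 8, 16, 16])
```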

@ -0,0 +1,97 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
########################################################
# DARTS: Differentiable Architecture Search, ICLR 2019 #
########################################################
import torch
import torch.nn as nn
from copy import deepcopy
from ..cell_operations import ResNetBasicblock
from .search_cells import NAS201SearchCell as SearchCell
from .genotypes import Structure
class TinyNetworkDarts(nn.Module):
def __init__(self, C, N, max_nodes, num_classes, search_space, affine, track_running_stats):
super(TinyNetworkDarts, self).__init__()
self._C = C
self._layerN = N
self.max_nodes = max_nodes
self.stem = nn.Sequential(
nn.Conv2d(3, C, kernel_size=3, padding=1, bias=False),
nn.BatchNorm2d(C))
layer_channels = [C ] * N + [C*2 ] + [C*2 ] * N + [C*4 ] + [C*4 ] * N
layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N
C_prev, num_edge, edge2index = C, None, None
self.cells = nn.ModuleList()
for index, (C_curr, reduction) in enumerate(zip(layer_channels, layer_reductions)):
if reduction:
cell = ResNetBasicblock(C_prev, C_curr, 2)
else:
cell = SearchCell(C_prev, C_curr, 1, max_nodes, search_space, affine, track_running_stats)
if num_edge is None: num_edge, edge2index = cell.num_edges, cell.edge2index
else: assert num_edge == cell.num_edges and edge2index == cell.edge2index, 'invalid {:} vs. {:}.'.format(num_edge, cell.num_edges)
self.cells.append( cell )
C_prev = cell.out_dim
self.op_names = deepcopy( search_space )
self._Layer = len(self.cells)
self.edge2index = edge2index
self.lastact = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True))
self.global_pooling = nn.AdaptiveAvgPool2d(1)
self.classifier = nn.Linear(C_prev, num_classes)
self.arch_parameters = nn.Parameter( 1e-3*torch.randn(num_edge, len(search_space)) )
def get_weights(self):
xlist = list( self.stem.parameters() ) + list( self.cells.parameters() )
xlist+= list( self.lastact.parameters() ) + list( self.global_pooling.parameters() )
xlist+= list( self.classifier.parameters() )
return xlist
def get_alphas(self):
return [self.arch_parameters]
def show_alphas(self):
with torch.no_grad():
return 'arch-parameters :\n{:}'.format( nn.functional.softmax(self.arch_parameters, dim=-1).cpu() )
def get_message(self):
string = self.extra_repr()
for i, cell in enumerate(self.cells):
string += '\n {:02d}/{:02d} :: {:}'.format(i, len(self.cells), cell.extra_repr())
return string
def extra_repr(self):
return ('{name}(C={_C}, Max-Nodes={max_nodes}, N={_layerN}, L={_Layer})'.format(name=self.__class__.__name__, **self.__dict__))
def genotype(self):
genotypes = []
for i in range(1, self.max_nodes):
xlist = []
for j in range(i):
node_str = '{:}<-{:}'.format(i, j)
with torch.no_grad():
weights = self.arch_parameters[ self.edge2index[node_str] ]
op_name = self.op_names[ weights.argmax().item() ]
xlist.append((op_name, j))
genotypes.append( tuple(xlist) )
return Structure( genotypes )
def forward(self, inputs):
alphas = nn.functional.softmax(self.arch_parameters, dim=-1)
feature = self.stem(inputs)
for i, cell in enumerate(self.cells):
if isinstance(cell, SearchCell):
feature = cell(feature, alphas)
else:
feature = cell(feature)
out = self.lastact(feature)
out = self.global_pooling( out )
out = out.view(out.size(0), -1)
logits = self.classifier(out)
return out, logits

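To make the `genotype()` logic above concrete, here is a hedged, self-contained sketch of how a per-edge argmax over `arch_parameters` becomes the discrete cell description. The edge keys follow the `'{to}<-{from}'` convention used throughout these files; the op list and alpha values are assumptions:

```python
import torch

op_names  = ['none', 'skip_connect', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3']
max_nodes = 4

# Same edge indexing as the search cells: one row of alphas per edge.
edge2index = {}
for i in range(1, max_nodes):
    for j in range(i):
        edge2index['{:}<-{:}'.format(i, j)] = len(edge2index)

arch_parameters = torch.randn(len(edge2index), len(op_names))  # 6 x 5

genotypes = []
for i in range(1, max_nodes):
    xlist = []
    for j in range(i):
        weights = arch_parameters[edge2index['{:}<-{:}'.format(i, j)]]
        xlist.append((op_names[weights.argmax().item()], j))
    genotypes.append(tuple(xlist))
print(genotypes)  # the real method wraps this list in Structure(genotypes)
```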

@ -0,0 +1,108 @@
####################
# DARTS, ICLR 2019 #
####################
import torch
import torch.nn as nn
from copy import deepcopy
from typing import List, Text, Dict
from .search_cells import NASNetSearchCell as SearchCell
# The macro structure is based on NASNet
class NASNetworkDARTS(nn.Module):
def __init__(self, C: int, N: int, steps: int, multiplier: int, stem_multiplier: int,
num_classes: int, search_space: List[Text], affine: bool, track_running_stats: bool):
super(NASNetworkDARTS, self).__init__()
self._C = C
self._layerN = N
self._steps = steps
self._multiplier = multiplier
self.stem = nn.Sequential(
nn.Conv2d(3, C*stem_multiplier, kernel_size=3, padding=1, bias=False),
nn.BatchNorm2d(C*stem_multiplier))
# config for each layer
layer_channels = [C ] * N + [C*2 ] + [C*2 ] * (N-1) + [C*4 ] + [C*4 ] * (N-1)
layer_reductions = [False] * N + [True] + [False] * (N-1) + [True] + [False] * (N-1)
num_edge, edge2index = None, None
C_prev_prev, C_prev, C_curr, reduction_prev = C*stem_multiplier, C*stem_multiplier, C, False
self.cells = nn.ModuleList()
for index, (C_curr, reduction) in enumerate(zip(layer_channels, layer_reductions)):
cell = SearchCell(search_space, steps, multiplier, C_prev_prev, C_prev, C_curr, reduction, reduction_prev, affine, track_running_stats)
if num_edge is None: num_edge, edge2index = cell.num_edges, cell.edge2index
else: assert num_edge == cell.num_edges and edge2index == cell.edge2index, 'invalid {:} vs. {:}.'.format(num_edge, cell.num_edges)
self.cells.append( cell )
C_prev_prev, C_prev, reduction_prev = C_prev, multiplier*C_curr, reduction
self.op_names = deepcopy( search_space )
self._Layer = len(self.cells)
self.edge2index = edge2index
self.lastact = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True))
self.global_pooling = nn.AdaptiveAvgPool2d(1)
self.classifier = nn.Linear(C_prev, num_classes)
self.arch_normal_parameters = nn.Parameter( 1e-3*torch.randn(num_edge, len(search_space)) )
self.arch_reduce_parameters = nn.Parameter( 1e-3*torch.randn(num_edge, len(search_space)) )
def get_weights(self) -> List[torch.nn.Parameter]:
xlist = list( self.stem.parameters() ) + list( self.cells.parameters() )
xlist+= list( self.lastact.parameters() ) + list( self.global_pooling.parameters() )
xlist+= list( self.classifier.parameters() )
return xlist
def get_alphas(self) -> List[torch.nn.Parameter]:
return [self.arch_normal_parameters, self.arch_reduce_parameters]
def show_alphas(self) -> Text:
with torch.no_grad():
A = 'arch-normal-parameters :\n{:}'.format( nn.functional.softmax(self.arch_normal_parameters, dim=-1).cpu() )
B = 'arch-reduce-parameters :\n{:}'.format( nn.functional.softmax(self.arch_reduce_parameters, dim=-1).cpu() )
return '{:}\n{:}'.format(A, B)
def get_message(self) -> Text:
string = self.extra_repr()
for i, cell in enumerate(self.cells):
string += '\n {:02d}/{:02d} :: {:}'.format(i, len(self.cells), cell.extra_repr())
return string
def extra_repr(self) -> Text:
return ('{name}(C={_C}, N={_layerN}, steps={_steps}, multiplier={_multiplier}, L={_Layer})'.format(name=self.__class__.__name__, **self.__dict__))
def genotype(self) -> Dict[Text, List]:
def _parse(weights):
gene = []
for i in range(self._steps):
edges = []
for j in range(2+i):
node_str = '{:}<-{:}'.format(i, j)
ws = weights[ self.edge2index[node_str] ]
for k, op_name in enumerate(self.op_names):
if op_name == 'none': continue
edges.append( (op_name, j, ws[k]) )
edges = sorted(edges, key=lambda x: -x[-1])
selected_edges = edges[:2]
gene.append( tuple(selected_edges) )
return gene
with torch.no_grad():
gene_normal = _parse(torch.softmax(self.arch_normal_parameters, dim=-1).cpu().numpy())
gene_reduce = _parse(torch.softmax(self.arch_reduce_parameters, dim=-1).cpu().numpy())
return {'normal': gene_normal, 'normal_concat': list(range(2+self._steps-self._multiplier, self._steps+2)),
'reduce': gene_reduce, 'reduce_concat': list(range(2+self._steps-self._multiplier, self._steps+2))}
def forward(self, inputs):
normal_w = nn.functional.softmax(self.arch_normal_parameters, dim=1)
reduce_w = nn.functional.softmax(self.arch_reduce_parameters, dim=1)
s0 = s1 = self.stem(inputs)
for i, cell in enumerate(self.cells):
if cell.reduction: ww = reduce_w
else : ww = normal_w
s0, s1 = s1, cell.forward_darts(s0, s1, ww)
out = self.lastact(s1)
out = self.global_pooling( out )
out = out.view(out.size(0), -1)
logits = self.classifier(out)
return out, logits

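The `_parse` helper above pools every non-'none' (op, edge, weight) triple for a node and keeps the two strongest; note that, as written, the two selected entries may share the same incoming edge `j`, since candidates are pooled across edges before sorting. A small illustrative sketch with made-up weights:

```python
import numpy as np

op_names = ['none', 'skip_connect', 'sep_conv_3x3', 'dil_conv_5x5']
steps = 4

edge2index, k = {}, 0
for i in range(steps):
    for j in range(2 + i):
        edge2index['{:}<-{:}'.format(i, j)] = k
        k += 1

weights = np.random.dirichlet(np.ones(len(op_names)), size=k)  # rows sum to 1

gene = []
for i in range(steps):
    edges = []
    for j in range(2 + i):
        ws = weights[edge2index['{:}<-{:}'.format(i, j)]]
        for idx, op_name in enumerate(op_names):
            if op_name == 'none':
                continue
            edges.append((op_name, j, ws[idx]))
    edges = sorted(edges, key=lambda x: -x[-1])
    gene.append(tuple(edges[:2]))  # two strongest candidates for this node
print(gene[0])
```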

@ -0,0 +1,94 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##########################################################################
# Efficient Neural Architecture Search via Parameters Sharing, ICML 2018 #
##########################################################################
import torch
import torch.nn as nn
from copy import deepcopy
from ..cell_operations import ResNetBasicblock
from .search_cells import NAS201SearchCell as SearchCell
from .genotypes import Structure
from .search_model_enas_utils import Controller
class TinyNetworkENAS(nn.Module):
def __init__(self, C, N, max_nodes, num_classes, search_space, affine, track_running_stats):
super(TinyNetworkENAS, self).__init__()
self._C = C
self._layerN = N
self.max_nodes = max_nodes
self.stem = nn.Sequential(
nn.Conv2d(3, C, kernel_size=3, padding=1, bias=False),
nn.BatchNorm2d(C))
layer_channels = [C ] * N + [C*2 ] + [C*2 ] * N + [C*4 ] + [C*4 ] * N
layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N
C_prev, num_edge, edge2index = C, None, None
self.cells = nn.ModuleList()
for index, (C_curr, reduction) in enumerate(zip(layer_channels, layer_reductions)):
if reduction:
cell = ResNetBasicblock(C_prev, C_curr, 2)
else:
cell = SearchCell(C_prev, C_curr, 1, max_nodes, search_space, affine, track_running_stats)
if num_edge is None: num_edge, edge2index = cell.num_edges, cell.edge2index
else: assert num_edge == cell.num_edges and edge2index == cell.edge2index, 'invalid {:} vs. {:}.'.format(num_edge, cell.num_edges)
self.cells.append( cell )
C_prev = cell.out_dim
self.op_names = deepcopy( search_space )
self._Layer = len(self.cells)
self.edge2index = edge2index
self.lastact = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True))
self.global_pooling = nn.AdaptiveAvgPool2d(1)
self.classifier = nn.Linear(C_prev, num_classes)
# to maintain the sampled architecture
self.sampled_arch = None
def update_arch(self, _arch):
if _arch is None:
self.sampled_arch = None
elif isinstance(_arch, Structure):
self.sampled_arch = _arch
elif isinstance(_arch, (list, tuple)):
genotypes = []
for i in range(1, self.max_nodes):
xlist = []
for j in range(i):
node_str = '{:}<-{:}'.format(i, j)
op_index = _arch[ self.edge2index[node_str] ]
op_name = self.op_names[ op_index ]
xlist.append((op_name, j))
genotypes.append( tuple(xlist) )
self.sampled_arch = Structure(genotypes)
else:
raise ValueError('invalid type of input architecture : {:}'.format(_arch))
return self.sampled_arch
def create_controller(self):
return Controller(len(self.edge2index), len(self.op_names))
def get_message(self):
string = self.extra_repr()
for i, cell in enumerate(self.cells):
string += '\n {:02d}/{:02d} :: {:}'.format(i, len(self.cells), cell.extra_repr())
return string
def extra_repr(self):
return ('{name}(C={_C}, Max-Nodes={max_nodes}, N={_layerN}, L={_Layer})'.format(name=self.__class__.__name__, **self.__dict__))
def forward(self, inputs):
feature = self.stem(inputs)
for i, cell in enumerate(self.cells):
if isinstance(cell, SearchCell):
feature = cell.forward_dynamic(feature, self.sampled_arch)
else: feature = cell(feature)
out = self.lastact(feature)
out = self.global_pooling( out )
out = out.view(out.size(0), -1)
logits = self.classifier(out)
return out, logits


@ -0,0 +1,55 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##########################################################################
# Efficient Neural Architecture Search via Parameters Sharing, ICML 2018 #
##########################################################################
import torch
import torch.nn as nn
from torch.distributions.categorical import Categorical
class Controller(nn.Module):
# we refer to https://github.com/TDeVries/enas_pytorch/blob/master/models/controller.py
def __init__(self, num_edge, num_ops, lstm_size=32, lstm_num_layers=2, tanh_constant=2.5, temperature=5.0):
super(Controller, self).__init__()
# assign the attributes
self.num_edge = num_edge
self.num_ops = num_ops
self.lstm_size = lstm_size
self.lstm_N = lstm_num_layers
self.tanh_constant = tanh_constant
self.temperature = temperature
# create parameters
self.register_parameter('input_vars', nn.Parameter(torch.Tensor(1, 1, lstm_size)))
self.w_lstm = nn.LSTM(input_size=self.lstm_size, hidden_size=self.lstm_size, num_layers=self.lstm_N)
self.w_embd = nn.Embedding(self.num_ops, self.lstm_size)
self.w_pred = nn.Linear(self.lstm_size, self.num_ops)
nn.init.uniform_(self.input_vars , -0.1, 0.1)
nn.init.uniform_(self.w_lstm.weight_hh_l0, -0.1, 0.1)
nn.init.uniform_(self.w_lstm.weight_ih_l0, -0.1, 0.1)
nn.init.uniform_(self.w_embd.weight , -0.1, 0.1)
nn.init.uniform_(self.w_pred.weight , -0.1, 0.1)
def forward(self):
inputs, h0 = self.input_vars, None
log_probs, entropys, sampled_arch = [], [], []
for iedge in range(self.num_edge):
outputs, h0 = self.w_lstm(inputs, h0)
logits = self.w_pred(outputs)
logits = logits / self.temperature
logits = self.tanh_constant * torch.tanh(logits)
# distribution
op_distribution = Categorical(logits=logits)
op_index = op_distribution.sample()
sampled_arch.append( op_index.item() )
op_log_prob = op_distribution.log_prob(op_index)
log_probs.append( op_log_prob.view(-1) )
op_entropy = op_distribution.entropy()
entropys.append( op_entropy.view(-1) )
# obtain the input embedding for the next step
inputs = self.w_embd(op_index)
return torch.sum(torch.cat(log_probs)), torch.sum(torch.cat(entropys)), sampled_arch

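A hedged sketch of how this Controller is typically driven with REINFORCE. The `reward_fn` below is a stand-in (real code would push the sampled architecture into the shared-weight network via `update_arch` and measure validation accuracy); the hyper-parameters are illustrative, not taken from this repository:

```python
import torch

def reward_fn(sampled_arch):          # stand-in scorer, not the real pipeline
    return float(len(set(sampled_arch))) / 5.0

controller = Controller(num_edge=6, num_ops=5)     # class defined above
optimizer  = torch.optim.Adam(controller.parameters(), lr=3.5e-4)
baseline   = None

for step in range(50):
    log_prob, entropy, sampled_arch = controller()
    reward   = reward_fn(sampled_arch)
    baseline = reward if baseline is None else 0.95 * baseline + 0.05 * reward
    # maximize reward-weighted log-likelihood plus an entropy bonus
    loss = -log_prob * (reward - baseline) - 1e-4 * entropy
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
```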

@ -0,0 +1,111 @@
###########################################################################
# Searching for A Robust Neural Architecture in Four GPU Hours, CVPR 2019 #
###########################################################################
import torch
import torch.nn as nn
from copy import deepcopy
from ..cell_operations import ResNetBasicblock
from .search_cells import NAS201SearchCell as SearchCell
from .genotypes import Structure
class TinyNetworkGDAS(nn.Module):
#def __init__(self, C, N, max_nodes, num_classes, search_space, affine=False, track_running_stats=True):
def __init__(self, C, N, max_nodes, num_classes, search_space, affine, track_running_stats):
super(TinyNetworkGDAS, self).__init__()
self._C = C
self._layerN = N
self.max_nodes = max_nodes
self.stem = nn.Sequential(
nn.Conv2d(3, C, kernel_size=3, padding=1, bias=False),
nn.BatchNorm2d(C))
layer_channels = [C ] * N + [C*2 ] + [C*2 ] * N + [C*4 ] + [C*4 ] * N
layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N
C_prev, num_edge, edge2index = C, None, None
self.cells = nn.ModuleList()
for index, (C_curr, reduction) in enumerate(zip(layer_channels, layer_reductions)):
if reduction:
cell = ResNetBasicblock(C_prev, C_curr, 2)
else:
cell = SearchCell(C_prev, C_curr, 1, max_nodes, search_space, affine, track_running_stats)
if num_edge is None: num_edge, edge2index = cell.num_edges, cell.edge2index
else: assert num_edge == cell.num_edges and edge2index == cell.edge2index, 'invalid {:} vs. {:}.'.format(num_edge, cell.num_edges)
self.cells.append( cell )
C_prev = cell.out_dim
self.op_names = deepcopy( search_space )
self._Layer = len(self.cells)
self.edge2index = edge2index
self.lastact = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True))
self.global_pooling = nn.AdaptiveAvgPool2d(1)
self.classifier = nn.Linear(C_prev, num_classes)
self.arch_parameters = nn.Parameter( 1e-3*torch.randn(num_edge, len(search_space)) )
self.tau = 10
def get_weights(self):
xlist = list( self.stem.parameters() ) + list( self.cells.parameters() )
xlist+= list( self.lastact.parameters() ) + list( self.global_pooling.parameters() )
xlist+= list( self.classifier.parameters() )
return xlist
def set_tau(self, tau):
self.tau = tau
def get_tau(self):
return self.tau
def get_alphas(self):
return [self.arch_parameters]
def show_alphas(self):
with torch.no_grad():
return 'arch-parameters :\n{:}'.format( nn.functional.softmax(self.arch_parameters, dim=-1).cpu() )
def get_message(self):
string = self.extra_repr()
for i, cell in enumerate(self.cells):
string += '\n {:02d}/{:02d} :: {:}'.format(i, len(self.cells), cell.extra_repr())
return string
def extra_repr(self):
return ('{name}(C={_C}, Max-Nodes={max_nodes}, N={_layerN}, L={_Layer})'.format(name=self.__class__.__name__, **self.__dict__))
def genotype(self):
genotypes = []
for i in range(1, self.max_nodes):
xlist = []
for j in range(i):
node_str = '{:}<-{:}'.format(i, j)
with torch.no_grad():
weights = self.arch_parameters[ self.edge2index[node_str] ]
op_name = self.op_names[ weights.argmax().item() ]
xlist.append((op_name, j))
genotypes.append( tuple(xlist) )
return Structure( genotypes )
def forward(self, inputs):
while True:
gumbels = -torch.empty_like(self.arch_parameters).exponential_().log()
logits = (self.arch_parameters.log_softmax(dim=1) + gumbels) / self.tau
probs = nn.functional.softmax(logits, dim=1)
index = probs.max(-1, keepdim=True)[1]
one_h = torch.zeros_like(logits).scatter_(-1, index, 1.0)
hardwts = one_h - probs.detach() + probs
if (torch.isinf(gumbels).any()) or (torch.isinf(probs).any()) or (torch.isnan(probs).any()):
continue
else: break
feature = self.stem(inputs)
for i, cell in enumerate(self.cells):
if isinstance(cell, SearchCell):
feature = cell.forward_gdas(feature, hardwts, index)
else:
feature = cell(feature)
out = self.lastact(feature)
out = self.global_pooling( out )
out = out.view(out.size(0), -1)
logits = self.classifier(out)
return out, logits

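The `while True` block in `forward()` above is the straight-through Gumbel-softmax trick: the value of `hardwts` is exactly one-hot, but its gradient is that of the soft probabilities. A minimal stand-alone sketch with arbitrary logits:

```python
import torch

tau = 10.0
arch_parameters = torch.randn(6, 5, requires_grad=True)   # 6 edges, 5 ops

gumbels = -torch.empty_like(arch_parameters).exponential_().log()
logits  = (arch_parameters.log_softmax(dim=1) + gumbels) / tau
probs   = torch.softmax(logits, dim=1)
index   = probs.max(-1, keepdim=True)[1]
one_h   = torch.zeros_like(logits).scatter_(-1, index, 1.0)
hardwts = one_h - probs.detach() + probs

print(hardwts.sum(dim=1))                      # exactly 1.0 per edge: one-hot values
(hardwts * torch.randn(6, 5)).sum().backward()
print(arch_parameters.grad.abs().sum() > 0)    # gradients reach the soft alphas
```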

@ -0,0 +1,125 @@
###########################################################################
# Searching for A Robust Neural Architecture in Four GPU Hours, CVPR 2019 #
###########################################################################
import torch
import torch.nn as nn
from copy import deepcopy
from .search_cells import NASNetSearchCell as SearchCell
# The macro structure is based on NASNet
class NASNetworkGDAS(nn.Module):
def __init__(self, C, N, steps, multiplier, stem_multiplier, num_classes, search_space, affine, track_running_stats):
super(NASNetworkGDAS, self).__init__()
self._C = C
self._layerN = N
self._steps = steps
self._multiplier = multiplier
self.stem = nn.Sequential(
nn.Conv2d(3, C*stem_multiplier, kernel_size=3, padding=1, bias=False),
nn.BatchNorm2d(C*stem_multiplier))
# config for each layer
layer_channels = [C ] * N + [C*2 ] + [C*2 ] * (N-1) + [C*4 ] + [C*4 ] * (N-1)
layer_reductions = [False] * N + [True] + [False] * (N-1) + [True] + [False] * (N-1)
num_edge, edge2index = None, None
C_prev_prev, C_prev, C_curr, reduction_prev = C*stem_multiplier, C*stem_multiplier, C, False
self.cells = nn.ModuleList()
for index, (C_curr, reduction) in enumerate(zip(layer_channels, layer_reductions)):
cell = SearchCell(search_space, steps, multiplier, C_prev_prev, C_prev, C_curr, reduction, reduction_prev, affine, track_running_stats)
if num_edge is None: num_edge, edge2index = cell.num_edges, cell.edge2index
else: assert num_edge == cell.num_edges and edge2index == cell.edge2index, 'invalid {:} vs. {:}.'.format(num_edge, cell.num_edges)
self.cells.append( cell )
C_prev_prev, C_prev, reduction_prev = C_prev, multiplier*C_curr, reduction
self.op_names = deepcopy( search_space )
self._Layer = len(self.cells)
self.edge2index = edge2index
self.lastact = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True))
self.global_pooling = nn.AdaptiveAvgPool2d(1)
self.classifier = nn.Linear(C_prev, num_classes)
self.arch_normal_parameters = nn.Parameter( 1e-3*torch.randn(num_edge, len(search_space)) )
self.arch_reduce_parameters = nn.Parameter( 1e-3*torch.randn(num_edge, len(search_space)) )
self.tau = 10
def get_weights(self):
xlist = list( self.stem.parameters() ) + list( self.cells.parameters() )
xlist+= list( self.lastact.parameters() ) + list( self.global_pooling.parameters() )
xlist+= list( self.classifier.parameters() )
return xlist
def set_tau(self, tau):
self.tau = tau
def get_tau(self):
return self.tau
def get_alphas(self):
return [self.arch_normal_parameters, self.arch_reduce_parameters]
def show_alphas(self):
with torch.no_grad():
A = 'arch-normal-parameters :\n{:}'.format( nn.functional.softmax(self.arch_normal_parameters, dim=-1).cpu() )
B = 'arch-reduce-parameters :\n{:}'.format( nn.functional.softmax(self.arch_reduce_parameters, dim=-1).cpu() )
return '{:}\n{:}'.format(A, B)
def get_message(self):
string = self.extra_repr()
for i, cell in enumerate(self.cells):
string += '\n {:02d}/{:02d} :: {:}'.format(i, len(self.cells), cell.extra_repr())
return string
def extra_repr(self):
return ('{name}(C={_C}, N={_layerN}, steps={_steps}, multiplier={_multiplier}, L={_Layer})'.format(name=self.__class__.__name__, **self.__dict__))
def genotype(self):
def _parse(weights):
gene = []
for i in range(self._steps):
edges = []
for j in range(2+i):
node_str = '{:}<-{:}'.format(i, j)
ws = weights[ self.edge2index[node_str] ]
for k, op_name in enumerate(self.op_names):
if op_name == 'none': continue
edges.append( (op_name, j, ws[k]) )
edges = sorted(edges, key=lambda x: -x[-1])
selected_edges = edges[:2]
gene.append( tuple(selected_edges) )
return gene
with torch.no_grad():
gene_normal = _parse(torch.softmax(self.arch_normal_parameters, dim=-1).cpu().numpy())
gene_reduce = _parse(torch.softmax(self.arch_reduce_parameters, dim=-1).cpu().numpy())
return {'normal': gene_normal, 'normal_concat': list(range(2+self._steps-self._multiplier, self._steps+2)),
'reduce': gene_reduce, 'reduce_concat': list(range(2+self._steps-self._multiplier, self._steps+2))}
def forward(self, inputs):
def get_gumbel_prob(xins):
while True:
gumbels = -torch.empty_like(xins).exponential_().log()
logits = (xins.log_softmax(dim=1) + gumbels) / self.tau
probs = nn.functional.softmax(logits, dim=1)
index = probs.max(-1, keepdim=True)[1]
one_h = torch.zeros_like(logits).scatter_(-1, index, 1.0)
hardwts = one_h - probs.detach() + probs
if (torch.isinf(gumbels).any()) or (torch.isinf(probs).any()) or (torch.isnan(probs).any()):
continue
else: break
return hardwts, index
normal_hardwts, normal_index = get_gumbel_prob(self.arch_normal_parameters)
reduce_hardwts, reduce_index = get_gumbel_prob(self.arch_reduce_parameters)
s0 = s1 = self.stem(inputs)
for i, cell in enumerate(self.cells):
if cell.reduction: hardwts, index = reduce_hardwts, reduce_index
else : hardwts, index = normal_hardwts, normal_index
s0, s1 = s1, cell.forward_gdas(s0, s1, hardwts, index)
out = self.lastact(s1)
out = self.global_pooling( out )
out = out.view(out.size(0), -1)
logits = self.classifier(out)
return out, logits

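GDAS anneals the Gumbel temperature during search through `set_tau`. A hedged sketch of a linear schedule (the 10 -> 0.1 endpoints and the epoch count are assumptions, not read from this repository's configs):

```python
tau_max, tau_min, total_epochs = 10.0, 0.1, 250

def tau_at(epoch):
    return tau_max - (tau_max - tau_min) * epoch / (total_epochs - 1)

# inside the search loop:
# for epoch in range(total_epochs):
#     network.set_tau(tau_at(epoch))
#     ...  # alternating weight / alpha updates for one epoch
print([round(tau_at(e), 2) for e in (0, 125, 249)])   # [10.0, 5.03, 0.1]
```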

@ -0,0 +1,81 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##############################################################################
# Random Search and Reproducibility for Neural Architecture Search, UAI 2019 #
##############################################################################
import torch, random
import torch.nn as nn
from copy import deepcopy
from ..cell_operations import ResNetBasicblock
from .search_cells import NAS201SearchCell as SearchCell
from .genotypes import Structure
class TinyNetworkRANDOM(nn.Module):
def __init__(self, C, N, max_nodes, num_classes, search_space, affine, track_running_stats):
super(TinyNetworkRANDOM, self).__init__()
self._C = C
self._layerN = N
self.max_nodes = max_nodes
self.stem = nn.Sequential(
nn.Conv2d(3, C, kernel_size=3, padding=1, bias=False),
nn.BatchNorm2d(C))
layer_channels = [C ] * N + [C*2 ] + [C*2 ] * N + [C*4 ] + [C*4 ] * N
layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N
C_prev, num_edge, edge2index = C, None, None
self.cells = nn.ModuleList()
for index, (C_curr, reduction) in enumerate(zip(layer_channels, layer_reductions)):
if reduction:
cell = ResNetBasicblock(C_prev, C_curr, 2)
else:
cell = SearchCell(C_prev, C_curr, 1, max_nodes, search_space, affine, track_running_stats)
if num_edge is None: num_edge, edge2index = cell.num_edges, cell.edge2index
else: assert num_edge == cell.num_edges and edge2index == cell.edge2index, 'invalid {:} vs. {:}.'.format(num_edge, cell.num_edges)
self.cells.append( cell )
C_prev = cell.out_dim
self.op_names = deepcopy( search_space )
self._Layer = len(self.cells)
self.edge2index = edge2index
self.lastact = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True))
self.global_pooling = nn.AdaptiveAvgPool2d(1)
self.classifier = nn.Linear(C_prev, num_classes)
self.arch_cache = None
def get_message(self):
string = self.extra_repr()
for i, cell in enumerate(self.cells):
string += '\n {:02d}/{:02d} :: {:}'.format(i, len(self.cells), cell.extra_repr())
return string
def extra_repr(self):
return ('{name}(C={_C}, Max-Nodes={max_nodes}, N={_layerN}, L={_Layer})'.format(name=self.__class__.__name__, **self.__dict__))
def random_genotype(self, set_cache):
genotypes = []
for i in range(1, self.max_nodes):
xlist = []
for j in range(i):
node_str = '{:}<-{:}'.format(i, j)
op_name = random.choice( self.op_names )
xlist.append((op_name, j))
genotypes.append( tuple(xlist) )
arch = Structure( genotypes )
if set_cache: self.arch_cache = arch
return arch
def forward(self, inputs):
feature = self.stem(inputs)
for i, cell in enumerate(self.cells):
if isinstance(cell, SearchCell):
feature = cell.forward_dynamic(feature, self.arch_cache)
else: feature = cell(feature)
out = self.lastact(feature)
out = self.global_pooling( out )
out = out.view(out.size(0), -1)
logits = self.classifier(out)
return out, logits

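A hedged sketch of the random-search loop this model supports: sample a genotype into the cache, score the network, keep the best. `score_fn` is a stand-in for whatever metric is used (validation accuracy, or a training-free score):

```python
import random

def score_fn(network):                 # stand-in scorer
    return random.random()

def random_search(network, n_samples=100):
    best_arch, best_score = None, float('-inf')
    for _ in range(n_samples):
        arch  = network.random_genotype(set_cache=True)   # also sets arch_cache
        score = score_fn(network)                         # forward() now uses it
        if score > best_score:
            best_arch, best_score = arch, score
    return best_arch, best_score
```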

@ -0,0 +1,152 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
######################################################################################
# One-Shot Neural Architecture Search via Self-Evaluated Template Network, ICCV 2019 #
######################################################################################
import torch, random
import torch.nn as nn
from copy import deepcopy
from ..cell_operations import ResNetBasicblock
from .search_cells import NAS201SearchCell as SearchCell
from .genotypes import Structure
class TinyNetworkSETN(nn.Module):
def __init__(self, C, N, max_nodes, num_classes, search_space, affine, track_running_stats):
super(TinyNetworkSETN, self).__init__()
self._C = C
self._layerN = N
self.max_nodes = max_nodes
self.stem = nn.Sequential(
nn.Conv2d(3, C, kernel_size=3, padding=1, bias=False),
nn.BatchNorm2d(C))
layer_channels = [C ] * N + [C*2 ] + [C*2 ] * N + [C*4 ] + [C*4 ] * N
layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N
C_prev, num_edge, edge2index = C, None, None
self.cells = nn.ModuleList()
for index, (C_curr, reduction) in enumerate(zip(layer_channels, layer_reductions)):
if reduction:
cell = ResNetBasicblock(C_prev, C_curr, 2)
else:
cell = SearchCell(C_prev, C_curr, 1, max_nodes, search_space, affine, track_running_stats)
if num_edge is None: num_edge, edge2index = cell.num_edges, cell.edge2index
else: assert num_edge == cell.num_edges and edge2index == cell.edge2index, 'invalid {:} vs. {:}.'.format(num_edge, cell.num_edges)
self.cells.append( cell )
C_prev = cell.out_dim
self.op_names = deepcopy( search_space )
self._Layer = len(self.cells)
self.edge2index = edge2index
self.lastact = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True))
self.global_pooling = nn.AdaptiveAvgPool2d(1)
self.classifier = nn.Linear(C_prev, num_classes)
self.arch_parameters = nn.Parameter( 1e-3*torch.randn(num_edge, len(search_space)) )
self.mode = 'urs'
self.dynamic_cell = None
def set_cal_mode(self, mode, dynamic_cell=None):
assert mode in ['urs', 'joint', 'select', 'dynamic']
self.mode = mode
if mode == 'dynamic': self.dynamic_cell = deepcopy( dynamic_cell )
else : self.dynamic_cell = None
def get_cal_mode(self):
return self.mode
def get_weights(self):
xlist = list( self.stem.parameters() ) + list( self.cells.parameters() )
xlist+= list( self.lastact.parameters() ) + list( self.global_pooling.parameters() )
xlist+= list( self.classifier.parameters() )
return xlist
def get_alphas(self):
return [self.arch_parameters]
def get_message(self):
string = self.extra_repr()
for i, cell in enumerate(self.cells):
string += '\n {:02d}/{:02d} :: {:}'.format(i, len(self.cells), cell.extra_repr())
return string
def extra_repr(self):
return ('{name}(C={_C}, Max-Nodes={max_nodes}, N={_layerN}, L={_Layer})'.format(name=self.__class__.__name__, **self.__dict__))
def genotype(self):
genotypes = []
for i in range(1, self.max_nodes):
xlist = []
for j in range(i):
node_str = '{:}<-{:}'.format(i, j)
with torch.no_grad():
weights = self.arch_parameters[ self.edge2index[node_str] ]
op_name = self.op_names[ weights.argmax().item() ]
xlist.append((op_name, j))
genotypes.append( tuple(xlist) )
return Structure( genotypes )
def dync_genotype(self, use_random=False):
genotypes = []
with torch.no_grad():
alphas_cpu = nn.functional.softmax(self.arch_parameters, dim=-1)
for i in range(1, self.max_nodes):
xlist = []
for j in range(i):
node_str = '{:}<-{:}'.format(i, j)
if use_random:
op_name = random.choice(self.op_names)
else:
weights = alphas_cpu[ self.edge2index[node_str] ]
op_index = torch.multinomial(weights, 1).item()
op_name = self.op_names[ op_index ]
xlist.append((op_name, j))
genotypes.append( tuple(xlist) )
return Structure( genotypes )
def get_log_prob(self, arch):
with torch.no_grad():
logits = nn.functional.log_softmax(self.arch_parameters, dim=-1)
select_logits = []
for i, node_info in enumerate(arch.nodes):
for op, xin in node_info:
node_str = '{:}<-{:}'.format(i+1, xin)
op_index = self.op_names.index(op)
select_logits.append( logits[self.edge2index[node_str], op_index] )
return sum(select_logits).item()
def return_topK(self, K):
archs = Structure.gen_all(self.op_names, self.max_nodes, False)
pairs = [(self.get_log_prob(arch), arch) for arch in archs]
if K < 0 or K >= len(archs): K = len(archs)
sorted_pairs = sorted(pairs, key=lambda x: -x[0])
return_pairs = [sorted_pairs[_][1] for _ in range(K)]
return return_pairs
def forward(self, inputs):
alphas = nn.functional.softmax(self.arch_parameters, dim=-1)
with torch.no_grad():
alphas_cpu = alphas.detach().cpu()
feature = self.stem(inputs)
for i, cell in enumerate(self.cells):
if isinstance(cell, SearchCell):
if self.mode == 'urs':
feature = cell.forward_urs(feature)
elif self.mode == 'select':
feature = cell.forward_select(feature, alphas_cpu)
elif self.mode == 'joint':
feature = cell.forward_joint(feature, alphas)
elif self.mode == 'dynamic':
feature = cell.forward_dynamic(feature, self.dynamic_cell)
else: raise ValueError('invalid mode={:}'.format(self.mode))
else: feature = cell(feature)
out = self.lastact(feature)
out = self.global_pooling( out )
out = out.view(out.size(0), -1)
logits = self.classifier(out)
return out, logits

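A hedged usage sketch of the SETN calculation modes above: train the template network under uniform random sampling ('urs'), then rank the most probable architectures with `return_topK` and evaluate each in 'dynamic' mode. `evaluate` is an assumed stand-in scorer:

```python
import random

def evaluate(network):                  # stand-in scorer
    return random.random()

def setn_select(network, topk=10):
    network.set_cal_mode('urs')         # uniformly random ops while training
    # ... shared-weight / alpha updates omitted ...
    best_arch, best_score = None, float('-inf')
    for arch in network.return_topK(topk):        # most probable candidates
        network.set_cal_mode('dynamic', arch)     # run exactly this Structure
        score = evaluate(network)
        if score > best_score:
            best_arch, best_score = arch, score
    return best_arch
```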

@ -0,0 +1,139 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
######################################################################################
# One-Shot Neural Architecture Search via Self-Evaluated Template Network, ICCV 2019 #
######################################################################################
import torch
import torch.nn as nn
from copy import deepcopy
from typing import List, Text, Dict
from .search_cells import NASNetSearchCell as SearchCell
# The macro structure is based on NASNet
class NASNetworkSETN(nn.Module):
def __init__(self, C: int, N: int, steps: int, multiplier: int, stem_multiplier: int,
num_classes: int, search_space: List[Text], affine: bool, track_running_stats: bool):
super(NASNetworkSETN, self).__init__()
self._C = C
self._layerN = N
self._steps = steps
self._multiplier = multiplier
self.stem = nn.Sequential(
nn.Conv2d(3, C*stem_multiplier, kernel_size=3, padding=1, bias=False),
nn.BatchNorm2d(C*stem_multiplier))
# config for each layer
layer_channels = [C ] * N + [C*2 ] + [C*2 ] * (N-1) + [C*4 ] + [C*4 ] * (N-1)
layer_reductions = [False] * N + [True] + [False] * (N-1) + [True] + [False] * (N-1)
num_edge, edge2index = None, None
C_prev_prev, C_prev, C_curr, reduction_prev = C*stem_multiplier, C*stem_multiplier, C, False
self.cells = nn.ModuleList()
for index, (C_curr, reduction) in enumerate(zip(layer_channels, layer_reductions)):
cell = SearchCell(search_space, steps, multiplier, C_prev_prev, C_prev, C_curr, reduction, reduction_prev, affine, track_running_stats)
if num_edge is None: num_edge, edge2index = cell.num_edges, cell.edge2index
else: assert num_edge == cell.num_edges and edge2index == cell.edge2index, 'invalid {:} vs. {:}.'.format(num_edge, cell.num_edges)
self.cells.append( cell )
C_prev_prev, C_prev, reduction_prev = C_prev, multiplier*C_curr, reduction
self.op_names = deepcopy( search_space )
self._Layer = len(self.cells)
self.edge2index = edge2index
self.lastact = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True))
self.global_pooling = nn.AdaptiveAvgPool2d(1)
self.classifier = nn.Linear(C_prev, num_classes)
self.arch_normal_parameters = nn.Parameter( 1e-3*torch.randn(num_edge, len(search_space)) )
self.arch_reduce_parameters = nn.Parameter( 1e-3*torch.randn(num_edge, len(search_space)) )
self.mode = 'urs'
self.dynamic_cell = None
def set_cal_mode(self, mode, dynamic_cell=None):
assert mode in ['urs', 'joint', 'select', 'dynamic']
self.mode = mode
if mode == 'dynamic':
self.dynamic_cell = deepcopy(dynamic_cell)
else:
self.dynamic_cell = None
def get_weights(self):
xlist = list( self.stem.parameters() ) + list( self.cells.parameters() )
xlist+= list( self.lastact.parameters() ) + list( self.global_pooling.parameters() )
xlist+= list( self.classifier.parameters() )
return xlist
def get_alphas(self):
return [self.arch_normal_parameters, self.arch_reduce_parameters]
def show_alphas(self):
with torch.no_grad():
A = 'arch-normal-parameters :\n{:}'.format( nn.functional.softmax(self.arch_normal_parameters, dim=-1).cpu() )
B = 'arch-reduce-parameters :\n{:}'.format( nn.functional.softmax(self.arch_reduce_parameters, dim=-1).cpu() )
return '{:}\n{:}'.format(A, B)
def get_message(self):
string = self.extra_repr()
for i, cell in enumerate(self.cells):
string += '\n {:02d}/{:02d} :: {:}'.format(i, len(self.cells), cell.extra_repr())
return string
def extra_repr(self):
return ('{name}(C={_C}, N={_layerN}, steps={_steps}, multiplier={_multiplier}, L={_Layer})'.format(name=self.__class__.__name__, **self.__dict__))
def dync_genotype(self, use_random=False):
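# NOTE: carried over from the NAS201 TinyNetworkSETN; this class defines
# arch_normal_parameters / arch_reduce_parameters rather than arch_parameters,
# has no max_nodes attribute, and this file imports neither `random` nor
# `Structure`, so calling this method as written will fail.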
genotypes = []
with torch.no_grad():
alphas_cpu = nn.functional.softmax(self.arch_parameters, dim=-1)
for i in range(1, self.max_nodes):
xlist = []
for j in range(i):
node_str = '{:}<-{:}'.format(i, j)
if use_random:
op_name = random.choice(self.op_names)
else:
weights = alphas_cpu[ self.edge2index[node_str] ]
op_index = torch.multinomial(weights, 1).item()
op_name = self.op_names[ op_index ]
xlist.append((op_name, j))
genotypes.append( tuple(xlist) )
return Structure( genotypes )
def genotype(self):
def _parse(weights):
gene = []
for i in range(self._steps):
edges = []
for j in range(2+i):
node_str = '{:}<-{:}'.format(i, j)
ws = weights[ self.edge2index[node_str] ]
for k, op_name in enumerate(self.op_names):
if op_name == 'none': continue
edges.append( (op_name, j, ws[k]) )
edges = sorted(edges, key=lambda x: -x[-1])
selected_edges = edges[:2]
gene.append( tuple(selected_edges) )
return gene
with torch.no_grad():
gene_normal = _parse(torch.softmax(self.arch_normal_parameters, dim=-1).cpu().numpy())
gene_reduce = _parse(torch.softmax(self.arch_reduce_parameters, dim=-1).cpu().numpy())
return {'normal': gene_normal, 'normal_concat': list(range(2+self._steps-self._multiplier, self._steps+2)),
'reduce': gene_reduce, 'reduce_concat': list(range(2+self._steps-self._multiplier, self._steps+2))}
def forward(self, inputs):
normal_hardwts = nn.functional.softmax(self.arch_normal_parameters, dim=-1)
reduce_hardwts = nn.functional.softmax(self.arch_reduce_parameters, dim=-1)
s0 = s1 = self.stem(inputs)
for i, cell in enumerate(self.cells):
# [TODO]
raise NotImplementedError
if cell.reduction: hardwts, index = reduce_hardwts, reduce_index
else : hardwts, index = normal_hardwts, normal_index
s0, s1 = s1, cell.forward_gdas(s0, s1, hardwts, index)
out = self.lastact(s1)
out = self.global_pooling( out )
out = out.view(out.size(0), -1)
logits = self.classifier(out)
return out, logits

62
models/clone_weights.py Normal file

@ -0,0 +1,62 @@
import torch
import torch.nn as nn
def copy_conv(module, init):
assert isinstance(module, nn.Conv2d), 'invalid module : {:}'.format(module)
assert isinstance(init , nn.Conv2d), 'invalid module : {:}'.format(init)
new_i, new_o = module.in_channels, module.out_channels
module.weight.copy_( init.weight.detach()[:new_o, :new_i] )
if module.bias is not None:
module.bias.copy_( init.bias.detach()[:new_o] )
def copy_bn (module, init):
assert isinstance(module, nn.BatchNorm2d), 'invalid module : {:}'.format(module)
assert isinstance(init , nn.BatchNorm2d), 'invalid module : {:}'.format(init)
num_features = module.num_features
if module.weight is not None:
module.weight.copy_( init.weight.detach()[:num_features] )
if module.bias is not None:
module.bias.copy_( init.bias.detach()[:num_features] )
if module.running_mean is not None:
module.running_mean.copy_( init.running_mean.detach()[:num_features] )
if module.running_var is not None:
module.running_var.copy_( init.running_var.detach()[:num_features] )
def copy_fc (module, init):
assert isinstance(module, nn.Linear), 'invalid module : {:}'.format(module)
assert isinstance(init , nn.Linear), 'invalid module : {:}'.format(init)
new_i, new_o = module.in_features, module.out_features
module.weight.copy_( init.weight.detach()[:new_o, :new_i] )
if module.bias is not None:
module.bias.copy_( init.bias.detach()[:new_o] )
def copy_base(module, init):
assert type(module).__name__ in ['ConvBNReLU', 'Downsample'], 'invalid module : {:}'.format(module)
assert type( init).__name__ in ['ConvBNReLU', 'Downsample'], 'invalid module : {:}'.format( init)
if module.conv is not None:
copy_conv(module.conv, init.conv)
if module.bn is not None:
copy_bn (module.bn, init.bn)
def copy_basic(module, init):
copy_base(module.conv_a, init.conv_a)
copy_base(module.conv_b, init.conv_b)
if module.downsample is not None:
if init.downsample is not None:
copy_base(module.downsample, init.downsample)
#else:
# import pdb; pdb.set_trace()
def init_from_model(network, init_model):
with torch.no_grad():
copy_fc(network.classifier, init_model.classifier)
for base, target in zip(init_model.layers, network.layers):
assert type(base).__name__ == type(target).__name__, 'invalid type : {:} vs {:}'.format(base, target)
if type(base).__name__ == 'ConvBNReLU':
copy_base(target, base)
elif type(base).__name__ == 'ResNetBasicblock':
copy_basic(target, base)
else:
raise ValueError('unknown type name : {:}'.format( type(base).__name__ ))

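The copy helpers above warm-start a pruned (narrower) network by taking the leading slice of the pretrained tensors. A minimal sketch of the slicing that `copy_conv` performs, with assumed channel counts:

```python
import torch
import torch.nn as nn

big   = nn.Conv2d(16, 32, kernel_size=3, padding=1, bias=True)   # pretrained
small = nn.Conv2d( 8, 12, kernel_size=3, padding=1, bias=True)   # pruned copy

with torch.no_grad():
    small.weight.copy_(big.weight.detach()[:12, :8])   # leading out/in channels
    small.bias.copy_(big.bias.detach()[:12])

print(torch.equal(small.weight, big.weight[:12, :8]))  # True
```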
18
models/initialization.py Normal file

@ -0,0 +1,18 @@
import torch
import torch.nn as nn
def initialize_resnet(m):
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.BatchNorm2d):
nn.init.constant_(m.weight, 1)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.Linear):
nn.init.normal_(m.weight, 0, 0.01)
nn.init.constant_(m.bias, 0)

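`initialize_resnet` is meant to be applied recursively with `nn.Module.apply`, which visits every submodule. A small usage sketch on an assumed toy model:

```python
import torch.nn as nn

toy = nn.Sequential(
    nn.Conv2d(3, 16, 3, padding=1, bias=False),
    nn.BatchNorm2d(16),
    nn.ReLU(inplace=True),
    nn.Flatten(),
    nn.Linear(16 * 32 * 32, 10),
)
toy.apply(initialize_resnet)               # function defined above
print(float(toy[1].weight.mean()))         # BatchNorm weights are now 1.0
```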

@ -0,0 +1,167 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
#####################################################
import torch.nn as nn
import torch.nn.functional as F
from ..initialization import initialize_resnet
class ConvBNReLU(nn.Module):
def __init__(self, nIn, nOut, kernel, stride, padding, bias, has_avg, has_bn, has_relu):
super(ConvBNReLU, self).__init__()
if has_avg : self.avg = nn.AvgPool2d(kernel_size=2, stride=2, padding=0)
else : self.avg = None
self.conv = nn.Conv2d(nIn, nOut, kernel_size=kernel, stride=stride, padding=padding, dilation=1, groups=1, bias=bias)
if has_bn : self.bn = nn.BatchNorm2d(nOut)
else : self.bn = None
if has_relu: self.relu = nn.ReLU(inplace=True)
else : self.relu = None
def forward(self, inputs):
if self.avg : out = self.avg( inputs )
else : out = inputs
conv = self.conv( out )
if self.bn : out = self.bn( conv )
else : out = conv
if self.relu: out = self.relu( out )
else : out = out
return out
class ResNetBasicblock(nn.Module):
num_conv = 2
expansion = 1
def __init__(self, iCs, stride):
super(ResNetBasicblock, self).__init__()
assert stride == 1 or stride == 2, 'invalid stride {:}'.format(stride)
assert isinstance(iCs, tuple) or isinstance(iCs, list), 'invalid type of iCs : {:}'.format( iCs )
assert len(iCs) == 3,'invalid lengths of iCs : {:}'.format(iCs)
self.conv_a = ConvBNReLU(iCs[0], iCs[1], 3, stride, 1, False, has_avg=False, has_bn=True, has_relu=True)
self.conv_b = ConvBNReLU(iCs[1], iCs[2], 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=False)
residual_in = iCs[0]
if stride == 2:
self.downsample = ConvBNReLU(iCs[0], iCs[2], 1, 1, 0, False, has_avg=True, has_bn=False, has_relu=False)
residual_in = iCs[2]
elif iCs[0] != iCs[2]:
self.downsample = ConvBNReLU(iCs[0], iCs[2], 1, 1, 0, False, has_avg=False,has_bn=True , has_relu=False)
else:
self.downsample = None
#self.out_dim = max(residual_in, iCs[2])
self.out_dim = iCs[2]
def forward(self, inputs):
basicblock = self.conv_a(inputs)
basicblock = self.conv_b(basicblock)
if self.downsample is not None:
residual = self.downsample(inputs)
else:
residual = inputs
out = residual + basicblock
return F.relu(out, inplace=True)
class ResNetBottleneck(nn.Module):
expansion = 4
num_conv = 3
def __init__(self, iCs, stride):
super(ResNetBottleneck, self).__init__()
assert stride == 1 or stride == 2, 'invalid stride {:}'.format(stride)
assert isinstance(iCs, tuple) or isinstance(iCs, list), 'invalid type of iCs : {:}'.format( iCs )
assert len(iCs) == 4,'invalid lengths of iCs : {:}'.format(iCs)
self.conv_1x1 = ConvBNReLU(iCs[0], iCs[1], 1, 1, 0, False, has_avg=False, has_bn=True, has_relu=True)
self.conv_3x3 = ConvBNReLU(iCs[1], iCs[2], 3, stride, 1, False, has_avg=False, has_bn=True, has_relu=True)
self.conv_1x4 = ConvBNReLU(iCs[2], iCs[3], 1, 1, 0, False, has_avg=False, has_bn=True, has_relu=False)
residual_in = iCs[0]
if stride == 2:
self.downsample = ConvBNReLU(iCs[0], iCs[3], 1, 1, 0, False, has_avg=True , has_bn=False, has_relu=False)
residual_in = iCs[3]
elif iCs[0] != iCs[3]:
self.downsample = ConvBNReLU(iCs[0], iCs[3], 1, 1, 0, False, has_avg=False, has_bn=False, has_relu=False)
residual_in = iCs[3]
else:
self.downsample = None
#self.out_dim = max(residual_in, iCs[3])
self.out_dim = iCs[3]
def forward(self, inputs):
bottleneck = self.conv_1x1(inputs)
bottleneck = self.conv_3x3(bottleneck)
bottleneck = self.conv_1x4(bottleneck)
if self.downsample is not None:
residual = self.downsample(inputs)
else:
residual = inputs
out = residual + bottleneck
return F.relu(out, inplace=True)
class InferCifarResNet(nn.Module):
def __init__(self, block_name, depth, xblocks, xchannels, num_classes, zero_init_residual):
super(InferCifarResNet, self).__init__()
#Model type specifies number of layers for CIFAR-10 and CIFAR-100 model
if block_name == 'ResNetBasicblock':
block = ResNetBasicblock
assert (depth - 2) % 6 == 0, 'depth should be one of 20, 32, 44, 56, 110'
layer_blocks = (depth - 2) // 6
elif block_name == 'ResNetBottleneck':
block = ResNetBottleneck
assert (depth - 2) % 9 == 0, 'depth should be one of 164'
layer_blocks = (depth - 2) // 9
else:
raise ValueError('invalid block : {:}'.format(block_name))
assert len(xblocks) == 3, 'invalid xblocks : {:}'.format(xblocks)
self.message = 'InferCifarResNet : Depth : {:} , Layers for each block : {:}'.format(depth, layer_blocks)
self.num_classes = num_classes
self.xchannels = xchannels
self.layers = nn.ModuleList( [ ConvBNReLU(xchannels[0], xchannels[1], 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=True) ] )
last_channel_idx = 1
for stage in range(3):
for iL in range(layer_blocks):
num_conv = block.num_conv
iCs = self.xchannels[last_channel_idx:last_channel_idx+num_conv+1]
stride = 2 if stage > 0 and iL == 0 else 1
module = block(iCs, stride)
last_channel_idx += num_conv
self.xchannels[last_channel_idx] = module.out_dim
self.layers.append ( module )
self.message += "\nstage={:}, ilayer={:02d}/{:02d}, block={:03d}, iCs={:}, oC={:3d}, stride={:}".format(stage, iL, layer_blocks, len(self.layers)-1, iCs, module.out_dim, stride)
if iL + 1 == xblocks[stage]: # reach the maximum depth
out_channel = module.out_dim
for iiL in range(iL+1, layer_blocks):
last_channel_idx += num_conv
self.xchannels[last_channel_idx] = module.out_dim
break
self.avgpool = nn.AvgPool2d(8)
self.classifier = nn.Linear(self.xchannels[-1], num_classes)
self.apply(initialize_resnet)
if zero_init_residual:
for m in self.modules():
if isinstance(m, ResNetBasicblock):
nn.init.constant_(m.conv_b.bn.weight, 0)
elif isinstance(m, ResNetBottleneck):
nn.init.constant_(m.conv_1x4.bn.weight, 0)
def get_message(self):
return self.message
def forward(self, inputs):
x = inputs
for i, layer in enumerate(self.layers):
x = layer( x )
features = self.avgpool(x)
features = features.view(features.size(0), -1)
logits = self.classifier(features)
return features, logits

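A hedged construction sketch for `InferCifarResNet`. For a depth-20 basic-block network (3 stages x 3 blocks), `xchannels` supplies one value for the stem input, one for the stem output, and one per block convolution, i.e. 2 + 9*2 = 20 entries, while `xblocks` can truncate each stage. All concrete numbers here are illustrative assumptions:

```python
xblocks   = [3, 3, 3]                                  # keep the full depth
xchannels = [3, 16] + [16] * 6 + [32] * 6 + [64] * 6   # 20 entries

net = InferCifarResNet('ResNetBasicblock', depth=20, xblocks=xblocks,
                       xchannels=xchannels, num_classes=10,
                       zero_init_residual=False)
# print(net.get_message())   # per-block channel / stride summary
```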

@ -0,0 +1,150 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
#####################################################
import torch.nn as nn
import torch.nn.functional as F
from ..initialization import initialize_resnet
class ConvBNReLU(nn.Module):
def __init__(self, nIn, nOut, kernel, stride, padding, bias, has_avg, has_bn, has_relu):
super(ConvBNReLU, self).__init__()
if has_avg : self.avg = nn.AvgPool2d(kernel_size=2, stride=2, padding=0)
else : self.avg = None
self.conv = nn.Conv2d(nIn, nOut, kernel_size=kernel, stride=stride, padding=padding, dilation=1, groups=1, bias=bias)
if has_bn : self.bn = nn.BatchNorm2d(nOut)
else : self.bn = None
if has_relu: self.relu = nn.ReLU(inplace=True)
else : self.relu = None
def forward(self, inputs):
if self.avg : out = self.avg( inputs )
else : out = inputs
conv = self.conv( out )
if self.bn : out = self.bn( conv )
else : out = conv
if self.relu: out = self.relu( out )
else : out = out
return out
class ResNetBasicblock(nn.Module):
num_conv = 2
expansion = 1
def __init__(self, inplanes, planes, stride):
super(ResNetBasicblock, self).__init__()
assert stride == 1 or stride == 2, 'invalid stride {:}'.format(stride)
self.conv_a = ConvBNReLU(inplanes, planes, 3, stride, 1, False, has_avg=False, has_bn=True, has_relu=True)
self.conv_b = ConvBNReLU( planes, planes, 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=False)
if stride == 2:
self.downsample = ConvBNReLU(inplanes, planes, 1, 1, 0, False, has_avg=True, has_bn=False, has_relu=False)
elif inplanes != planes:
self.downsample = ConvBNReLU(inplanes, planes, 1, 1, 0, False, has_avg=False,has_bn=True , has_relu=False)
else:
self.downsample = None
self.out_dim = planes
def forward(self, inputs):
basicblock = self.conv_a(inputs)
basicblock = self.conv_b(basicblock)
if self.downsample is not None:
residual = self.downsample(inputs)
else:
residual = inputs
out = residual + basicblock
return F.relu(out, inplace=True)
class ResNetBottleneck(nn.Module):
expansion = 4
num_conv = 3
def __init__(self, inplanes, planes, stride):
super(ResNetBottleneck, self).__init__()
assert stride == 1 or stride == 2, 'invalid stride {:}'.format(stride)
self.conv_1x1 = ConvBNReLU(inplanes, planes, 1, 1, 0, False, has_avg=False, has_bn=True, has_relu=True)
self.conv_3x3 = ConvBNReLU( planes, planes, 3, stride, 1, False, has_avg=False, has_bn=True, has_relu=True)
self.conv_1x4 = ConvBNReLU(planes, planes*self.expansion, 1, 1, 0, False, has_avg=False, has_bn=True, has_relu=False)
if stride == 2:
self.downsample = ConvBNReLU(inplanes, planes*self.expansion, 1, 1, 0, False, has_avg=True , has_bn=False, has_relu=False)
elif inplanes != planes*self.expansion:
self.downsample = ConvBNReLU(inplanes, planes*self.expansion, 1, 1, 0, False, has_avg=False, has_bn=False, has_relu=False)
else:
self.downsample = None
self.out_dim = planes*self.expansion
def forward(self, inputs):
bottleneck = self.conv_1x1(inputs)
bottleneck = self.conv_3x3(bottleneck)
bottleneck = self.conv_1x4(bottleneck)
if self.downsample is not None:
residual = self.downsample(inputs)
else:
residual = inputs
out = residual + bottleneck
return F.relu(out, inplace=True)
class InferDepthCifarResNet(nn.Module):
def __init__(self, block_name, depth, xblocks, num_classes, zero_init_residual):
super(InferDepthCifarResNet, self).__init__()
#Model type specifies number of layers for CIFAR-10 and CIFAR-100 model
if block_name == 'ResNetBasicblock':
block = ResNetBasicblock
assert (depth - 2) % 6 == 0, 'depth should be one of 20, 32, 44, 56, 110'
layer_blocks = (depth - 2) // 6
elif block_name == 'ResNetBottleneck':
block = ResNetBottleneck
assert (depth - 2) % 9 == 0, 'depth should be one of 164'
layer_blocks = (depth - 2) // 9
else:
raise ValueError('invalid block : {:}'.format(block_name))
assert len(xblocks) == 3, 'invalid xblocks : {:}'.format(xblocks)
self.message = 'InferDepthCifarResNet : Depth : {:} , Layers for each block : {:}'.format(depth, layer_blocks)
self.num_classes = num_classes
self.layers = nn.ModuleList( [ ConvBNReLU(3, 16, 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=True) ] )
self.channels = [16]
for stage in range(3):
for iL in range(layer_blocks):
iC = self.channels[-1]
planes = 16 * (2**stage)
stride = 2 if stage > 0 and iL == 0 else 1
module = block(iC, planes, stride)
self.channels.append( module.out_dim )
self.layers.append ( module )
self.message += "\nstage={:}, ilayer={:02d}/{:02d}, block={:03d}, iC={:}, oC={:3d}, stride={:}".format(stage, iL, layer_blocks, len(self.layers)-1, planes, module.out_dim, stride)
if iL + 1 == xblocks[stage]: # reach the maximum depth
break
self.avgpool = nn.AvgPool2d(8)
self.classifier = nn.Linear(self.channels[-1], num_classes)
self.apply(initialize_resnet)
if zero_init_residual:
for m in self.modules():
if isinstance(m, ResNetBasicblock):
nn.init.constant_(m.conv_b.bn.weight, 0)
elif isinstance(m, ResNetBottleneck):
nn.init.constant_(m.conv_1x4.bn.weight, 0)
def get_message(self):
return self.message
def forward(self, inputs):
x = inputs
for i, layer in enumerate(self.layers):
x = layer( x )
features = self.avgpool(x)
features = features.view(features.size(0), -1)
logits = self.classifier(features)
return features, logits

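By contrast, `InferDepthCifarResNet` only prunes depth, so it takes `xblocks` but no channel list; a hedged example keeping 2/3/1 blocks per stage of a depth-20 template (the numbers are illustrative):

```python
net = InferDepthCifarResNet('ResNetBasicblock', depth=20, xblocks=[2, 3, 1],
                            num_classes=10, zero_init_residual=False)
# print(net.get_message())
```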

@ -0,0 +1,160 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
#####################################################
import torch.nn as nn
import torch.nn.functional as F
from ..initialization import initialize_resnet
class ConvBNReLU(nn.Module):
def __init__(self, nIn, nOut, kernel, stride, padding, bias, has_avg, has_bn, has_relu):
super(ConvBNReLU, self).__init__()
if has_avg : self.avg = nn.AvgPool2d(kernel_size=2, stride=2, padding=0)
else : self.avg = None
self.conv = nn.Conv2d(nIn, nOut, kernel_size=kernel, stride=stride, padding=padding, dilation=1, groups=1, bias=bias)
if has_bn : self.bn = nn.BatchNorm2d(nOut)
else : self.bn = None
if has_relu: self.relu = nn.ReLU(inplace=True)
else : self.relu = None
def forward(self, inputs):
if self.avg : out = self.avg( inputs )
else : out = inputs
conv = self.conv( out )
if self.bn : out = self.bn( conv )
else : out = conv
if self.relu: out = self.relu( out )
else : out = out
return out
class ResNetBasicblock(nn.Module):
num_conv = 2
expansion = 1
def __init__(self, iCs, stride):
super(ResNetBasicblock, self).__init__()
assert stride == 1 or stride == 2, 'invalid stride {:}'.format(stride)
assert isinstance(iCs, tuple) or isinstance(iCs, list), 'invalid type of iCs : {:}'.format( iCs )
assert len(iCs) == 3,'invalid lengths of iCs : {:}'.format(iCs)
self.conv_a = ConvBNReLU(iCs[0], iCs[1], 3, stride, 1, False, has_avg=False, has_bn=True, has_relu=True)
self.conv_b = ConvBNReLU(iCs[1], iCs[2], 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=False)
residual_in = iCs[0]
if stride == 2:
self.downsample = ConvBNReLU(iCs[0], iCs[2], 1, 1, 0, False, has_avg=True, has_bn=False, has_relu=False)
residual_in = iCs[2]
elif iCs[0] != iCs[2]:
self.downsample = ConvBNReLU(iCs[0], iCs[2], 1, 1, 0, False, has_avg=False,has_bn=True , has_relu=False)
else:
self.downsample = None
#self.out_dim = max(residual_in, iCs[2])
self.out_dim = iCs[2]
def forward(self, inputs):
basicblock = self.conv_a(inputs)
basicblock = self.conv_b(basicblock)
if self.downsample is not None:
residual = self.downsample(inputs)
else:
residual = inputs
out = residual + basicblock
return F.relu(out, inplace=True)
class ResNetBottleneck(nn.Module):
expansion = 4
num_conv = 3
def __init__(self, iCs, stride):
super(ResNetBottleneck, self).__init__()
assert stride == 1 or stride == 2, 'invalid stride {:}'.format(stride)
assert isinstance(iCs, tuple) or isinstance(iCs, list), 'invalid type of iCs : {:}'.format( iCs )
assert len(iCs) == 4,'invalid lengths of iCs : {:}'.format(iCs)
self.conv_1x1 = ConvBNReLU(iCs[0], iCs[1], 1, 1, 0, False, has_avg=False, has_bn=True, has_relu=True)
self.conv_3x3 = ConvBNReLU(iCs[1], iCs[2], 3, stride, 1, False, has_avg=False, has_bn=True, has_relu=True)
self.conv_1x4 = ConvBNReLU(iCs[2], iCs[3], 1, 1, 0, False, has_avg=False, has_bn=True, has_relu=False)
residual_in = iCs[0]
if stride == 2:
self.downsample = ConvBNReLU(iCs[0], iCs[3], 1, 1, 0, False, has_avg=True , has_bn=False, has_relu=False)
residual_in = iCs[3]
elif iCs[0] != iCs[3]:
self.downsample = ConvBNReLU(iCs[0], iCs[3], 1, 1, 0, False, has_avg=False, has_bn=False, has_relu=False)
residual_in = iCs[3]
else:
self.downsample = None
#self.out_dim = max(residual_in, iCs[3])
self.out_dim = iCs[3]
def forward(self, inputs):
bottleneck = self.conv_1x1(inputs)
bottleneck = self.conv_3x3(bottleneck)
bottleneck = self.conv_1x4(bottleneck)
if self.downsample is not None:
residual = self.downsample(inputs)
else:
residual = inputs
out = residual + bottleneck
return F.relu(out, inplace=True)
class InferWidthCifarResNet(nn.Module):
def __init__(self, block_name, depth, xchannels, num_classes, zero_init_residual):
super(InferWidthCifarResNet, self).__init__()
#Model type specifies number of layers for CIFAR-10 and CIFAR-100 model
if block_name == 'ResNetBasicblock':
block = ResNetBasicblock
assert (depth - 2) % 6 == 0, 'depth should be one of 20, 32, 44, 56, 110'
layer_blocks = (depth - 2) // 6
elif block_name == 'ResNetBottleneck':
block = ResNetBottleneck
assert (depth - 2) % 9 == 0, 'depth should be one of 164'
layer_blocks = (depth - 2) // 9
else:
raise ValueError('invalid block : {:}'.format(block_name))
self.message = 'InferWidthCifarResNet : Depth : {:} , Layers for each block : {:}'.format(depth, layer_blocks)
self.num_classes = num_classes
self.xchannels = xchannels
self.layers = nn.ModuleList( [ ConvBNReLU(xchannels[0], xchannels[1], 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=True) ] )
last_channel_idx = 1
for stage in range(3):
for iL in range(layer_blocks):
num_conv = block.num_conv
iCs = self.xchannels[last_channel_idx:last_channel_idx+num_conv+1]
stride = 2 if stage > 0 and iL == 0 else 1
module = block(iCs, stride)
last_channel_idx += num_conv
self.xchannels[last_channel_idx] = module.out_dim
self.layers.append ( module )
self.message += "\nstage={:}, ilayer={:02d}/{:02d}, block={:03d}, iCs={:}, oC={:3d}, stride={:}".format(stage, iL, layer_blocks, len(self.layers)-1, iCs, module.out_dim, stride)
self.avgpool = nn.AvgPool2d(8)
self.classifier = nn.Linear(self.xchannels[-1], num_classes)
self.apply(initialize_resnet)
if zero_init_residual:
for m in self.modules():
if isinstance(m, ResNetBasicblock):
nn.init.constant_(m.conv_b.bn.weight, 0)
elif isinstance(m, ResNetBottleneck):
nn.init.constant_(m.conv_1x4.bn.weight, 0)
def get_message(self):
return self.message
def forward(self, inputs):
x = inputs
for i, layer in enumerate(self.layers):
x = layer( x )
features = self.avgpool(x)
features = features.view(features.size(0), -1)
logits = self.classifier(features)
return features, logits


@ -0,0 +1,170 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
#####################################################
import torch.nn as nn
import torch.nn.functional as F
from ..initialization import initialize_resnet
class ConvBNReLU(nn.Module):
num_conv = 1
def __init__(self, nIn, nOut, kernel, stride, padding, bias, has_avg, has_bn, has_relu):
super(ConvBNReLU, self).__init__()
if has_avg : self.avg = nn.AvgPool2d(kernel_size=2, stride=2, padding=0)
else : self.avg = None
self.conv = nn.Conv2d(nIn, nOut, kernel_size=kernel, stride=stride, padding=padding, dilation=1, groups=1, bias=bias)
if has_bn : self.bn = nn.BatchNorm2d(nOut)
else : self.bn = None
if has_relu: self.relu = nn.ReLU(inplace=True)
else : self.relu = None
def forward(self, inputs):
if self.avg : out = self.avg( inputs )
else : out = inputs
conv = self.conv( out )
if self.bn : out = self.bn( conv )
else : out = conv
if self.relu: out = self.relu( out )
else : out = out
return out
class ResNetBasicblock(nn.Module):
num_conv = 2
expansion = 1
def __init__(self, iCs, stride):
super(ResNetBasicblock, self).__init__()
assert stride == 1 or stride == 2, 'invalid stride {:}'.format(stride)
assert isinstance(iCs, tuple) or isinstance(iCs, list), 'invalid type of iCs : {:}'.format( iCs )
assert len(iCs) == 3,'invalid lengths of iCs : {:}'.format(iCs)
self.conv_a = ConvBNReLU(iCs[0], iCs[1], 3, stride, 1, False, has_avg=False, has_bn=True, has_relu=True)
self.conv_b = ConvBNReLU(iCs[1], iCs[2], 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=False)
residual_in = iCs[0]
if stride == 2:
self.downsample = ConvBNReLU(iCs[0], iCs[2], 1, 1, 0, False, has_avg=True, has_bn=True, has_relu=False)
residual_in = iCs[2]
elif iCs[0] != iCs[2]:
self.downsample = ConvBNReLU(iCs[0], iCs[2], 1, 1, 0, False, has_avg=False,has_bn=True , has_relu=False)
else:
self.downsample = None
#self.out_dim = max(residual_in, iCs[2])
self.out_dim = iCs[2]
def forward(self, inputs):
basicblock = self.conv_a(inputs)
basicblock = self.conv_b(basicblock)
if self.downsample is not None:
residual = self.downsample(inputs)
else:
residual = inputs
out = residual + basicblock
return F.relu(out, inplace=True)
class ResNetBottleneck(nn.Module):
expansion = 4
num_conv = 3
def __init__(self, iCs, stride):
super(ResNetBottleneck, self).__init__()
assert stride == 1 or stride == 2, 'invalid stride {:}'.format(stride)
assert isinstance(iCs, tuple) or isinstance(iCs, list), 'invalid type of iCs : {:}'.format( iCs )
assert len(iCs) == 4,'invalid lengths of iCs : {:}'.format(iCs)
self.conv_1x1 = ConvBNReLU(iCs[0], iCs[1], 1, 1, 0, False, has_avg=False, has_bn=True, has_relu=True)
self.conv_3x3 = ConvBNReLU(iCs[1], iCs[2], 3, stride, 1, False, has_avg=False, has_bn=True, has_relu=True)
self.conv_1x4 = ConvBNReLU(iCs[2], iCs[3], 1, 1, 0, False, has_avg=False, has_bn=True, has_relu=False)
residual_in = iCs[0]
if stride == 2:
self.downsample = ConvBNReLU(iCs[0], iCs[3], 1, 1, 0, False, has_avg=True , has_bn=True, has_relu=False)
residual_in = iCs[3]
elif iCs[0] != iCs[3]:
self.downsample = ConvBNReLU(iCs[0], iCs[3], 1, 1, 0, False, has_avg=False, has_bn=True, has_relu=False)
residual_in = iCs[3]
else:
self.downsample = None
#self.out_dim = max(residual_in, iCs[3])
self.out_dim = iCs[3]
def forward(self, inputs):
bottleneck = self.conv_1x1(inputs)
bottleneck = self.conv_3x3(bottleneck)
bottleneck = self.conv_1x4(bottleneck)
if self.downsample is not None:
residual = self.downsample(inputs)
else:
residual = inputs
out = residual + bottleneck
return F.relu(out, inplace=True)
class InferImagenetResNet(nn.Module):
def __init__(self, block_name, layers, xblocks, xchannels, deep_stem, num_classes, zero_init_residual):
super(InferImagenetResNet, self).__init__()
#Model type specifies number of layers for CIFAR-10 and CIFAR-100 model
if block_name == 'BasicBlock':
block = ResNetBasicblock
elif block_name == 'Bottleneck':
block = ResNetBottleneck
else:
raise ValueError('invalid block : {:}'.format(block_name))
assert len(xblocks) == len(layers), 'invalid layers : {:} vs xblocks : {:}'.format(layers, xblocks)
self.message = 'InferImagenetResNet : Depth : {:} -> {:}, Layers for each block : {:}'.format(sum(layers)*block.num_conv, sum(xblocks)*block.num_conv, xblocks)
self.num_classes = num_classes
self.xchannels = xchannels
if not deep_stem:
self.layers = nn.ModuleList( [ ConvBNReLU(xchannels[0], xchannels[1], 7, 2, 3, False, has_avg=False, has_bn=True, has_relu=True) ] )
last_channel_idx = 1
else:
self.layers = nn.ModuleList( [ ConvBNReLU(xchannels[0], xchannels[1], 3, 2, 1, False, has_avg=False, has_bn=True, has_relu=True)
,ConvBNReLU(xchannels[1], xchannels[2], 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=True) ] )
last_channel_idx = 2
self.layers.append( nn.MaxPool2d(kernel_size=3, stride=2, padding=1) )
for stage, layer_blocks in enumerate(layers):
for iL in range(layer_blocks):
num_conv = block.num_conv
iCs = self.xchannels[last_channel_idx:last_channel_idx+num_conv+1]
stride = 2 if stage > 0 and iL == 0 else 1
module = block(iCs, stride)
last_channel_idx += num_conv
self.xchannels[last_channel_idx] = module.out_dim
self.layers.append ( module )
self.message += "\nstage={:}, ilayer={:02d}/{:02d}, block={:03d}, iCs={:}, oC={:3d}, stride={:}".format(stage, iL, layer_blocks, len(self.layers)-1, iCs, module.out_dim, stride)
if iL + 1 == xblocks[stage]: # reach the maximum depth
out_channel = module.out_dim
for iiL in range(iL+1, layer_blocks):
last_channel_idx += num_conv
self.xchannels[last_channel_idx] = module.out_dim
break
assert last_channel_idx + 1 == len(self.xchannels), '{:} vs {:}'.format(last_channel_idx, len(self.xchannels))
self.avgpool = nn.AdaptiveAvgPool2d((1,1))
self.classifier = nn.Linear(self.xchannels[-1], num_classes)
self.apply(initialize_resnet)
if zero_init_residual:
for m in self.modules():
if isinstance(m, ResNetBasicblock):
nn.init.constant_(m.conv_b.bn.weight, 0)
elif isinstance(m, ResNetBottleneck):
nn.init.constant_(m.conv_1x4.bn.weight, 0)
def get_message(self):
return self.message
def forward(self, inputs):
x = inputs
for i, layer in enumerate(self.layers):
x = layer( x )
features = self.avgpool(x)
features = features.view(features.size(0), -1)
logits = self.classifier(features)
return features, logits
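# Note (added commentary): `xblocks[stage]` caps how many residual blocks of each stage are
# actually instantiated; the inner `iiL` loop only advances `last_channel_idx` and rewrites
# the skipped entries of `xchannels` with `module.out_dim`, so the channel bookkeeping stays
# aligned with the truncated (shallower) architecture.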

View File

@ -0,0 +1,122 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
#####################################################
# MobileNetV2: Inverted Residuals and Linear Bottlenecks, CVPR 2018
from torch import nn
from ..initialization import initialize_resnet
from ..SharedUtils import parse_channel_info
class ConvBNReLU(nn.Module):
def __init__(self, in_planes, out_planes, kernel_size, stride, groups, has_bn=True, has_relu=True):
super(ConvBNReLU, self).__init__()
padding = (kernel_size - 1) // 2
self.conv = nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, bias=False)
if has_bn: self.bn = nn.BatchNorm2d(out_planes)
else : self.bn = None
if has_relu: self.relu = nn.ReLU6(inplace=True)
else : self.relu = None
def forward(self, x):
out = self.conv( x )
if self.bn: out = self.bn ( out )
if self.relu: out = self.relu( out )
return out
class InvertedResidual(nn.Module):
def __init__(self, channels, stride, expand_ratio, additive):
super(InvertedResidual, self).__init__()
self.stride = stride
assert stride in [1, 2], 'invalid stride : {:}'.format(stride)
assert len(channels) in [2, 3], 'invalid channels : {:}'.format(channels)
if len(channels) == 2:
layers = []
else:
layers = [ConvBNReLU(channels[0], channels[1], 1, 1, 1)]
layers.extend([
# dw
ConvBNReLU(channels[-2], channels[-2], 3, stride, channels[-2]),
# pw-linear
ConvBNReLU(channels[-2], channels[-1], 1, 1, 1, True, False),
])
self.conv = nn.Sequential(*layers)
self.additive = additive
if self.additive and channels[0] != channels[-1]:
self.shortcut = ConvBNReLU(channels[0], channels[-1], 1, 1, 1, True, False)
else:
self.shortcut = None
self.out_dim = channels[-1]
def forward(self, x):
out = self.conv(x)
# if self.additive: return additive_func(out, x)
if self.shortcut: return out + self.shortcut(x)
else : return out
class InferMobileNetV2(nn.Module):
def __init__(self, num_classes, xchannels, xblocks, dropout):
super(InferMobileNetV2, self).__init__()
block = InvertedResidual
inverted_residual_setting = [
# t, c, n, s
[1, 16 , 1, 1],
[6, 24 , 2, 2],
[6, 32 , 3, 2],
[6, 64 , 4, 2],
[6, 96 , 3, 1],
[6, 160, 3, 2],
[6, 320, 1, 1],
]
assert len(inverted_residual_setting) == len(xblocks), 'invalid number of layers : {:} vs {:}'.format(len(inverted_residual_setting), len(xblocks))
for block_num, ir_setting in zip(xblocks, inverted_residual_setting):
assert block_num <= ir_setting[2], '{:} vs {:}'.format(block_num, ir_setting)
xchannels = parse_channel_info(xchannels)
#for i, chs in enumerate(xchannels):
# if i > 0: assert chs[0] == xchannels[i-1][-1], 'Layer[{:}] is invalid {:} vs {:}'.format(i, xchannels[i-1], chs)
self.xchannels = xchannels
self.message = 'InferMobileNetV2 : xblocks={:}'.format(xblocks)
# building first layer
features = [ConvBNReLU(xchannels[0][0], xchannels[0][1], 3, 2, 1)]
last_channel_idx = 1
# building inverted residual blocks
for stage, (t, c, n, s) in enumerate(inverted_residual_setting):
for i in range(n):
stride = s if i == 0 else 1
additv = i > 0
module = block(self.xchannels[last_channel_idx], stride, t, additv)
features.append(module)
self.message += "\nstage={:}, ilayer={:02d}/{:02d}, block={:03d}, Cs={:}, stride={:}, expand={:}, original-C={:}".format(stage, i, n, len(features), self.xchannels[last_channel_idx], stride, t, c)
last_channel_idx += 1
if i + 1 == xblocks[stage]:
out_channel = module.out_dim
for iiL in range(i+1, n):
last_channel_idx += 1
self.xchannels[last_channel_idx][0] = module.out_dim
break
# building last several layers
features.append(ConvBNReLU(self.xchannels[last_channel_idx][0], self.xchannels[last_channel_idx][1], 1, 1, 1))
assert last_channel_idx + 2 == len(self.xchannels), '{:} vs {:}'.format(last_channel_idx, len(self.xchannels))
# make it nn.Sequential
self.features = nn.Sequential(*features)
# building classifier
self.classifier = nn.Sequential(
nn.Dropout(dropout),
nn.Linear(self.xchannels[last_channel_idx][1], num_classes),
)
# weight initialization
self.apply( initialize_resnet )
def get_message(self):
return self.message
def forward(self, inputs):
features = self.features(inputs)
vectors = features.mean([2, 3])
predicts = self.classifier(vectors)
return features, predicts
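# Note (added commentary): each row of `inverted_residual_setting` follows the MobileNetV2
# convention (t, c, n, s) = (expansion ratio, output channels, repeats, first stride);
# `xblocks[stage] <= n` truncates the repeats, and `xchannels` carries the searched widths as
# parsed by parse_channel_info, e.g. a (hypothetical) string '3-32 32-16' becomes [[3, 32], [32, 16]].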

View File

@ -0,0 +1,58 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
#####################################################
from typing import List, Text, Any
import torch.nn as nn
from models.cell_operations import ResNetBasicblock
from models.cell_infers.cells import InferCell
class DynamicShapeTinyNet(nn.Module):
def __init__(self, channels: List[int], genotype: Any, num_classes: int):
super(DynamicShapeTinyNet, self).__init__()
self._channels = channels
if len(channels) % 3 != 2:
raise ValueError('invalid number of layers : {:}'.format(len(channels)))
self._num_stage = N = len(channels) // 3
self.stem = nn.Sequential(
nn.Conv2d(3, channels[0], kernel_size=3, padding=1, bias=False),
nn.BatchNorm2d(channels[0]))
# layer_channels = [C ] * N + [C*2 ] + [C*2 ] * N + [C*4 ] + [C*4 ] * N
layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N
c_prev = channels[0]
self.cells = nn.ModuleList()
for index, (c_curr, reduction) in enumerate(zip(channels, layer_reductions)):
if reduction : cell = ResNetBasicblock(c_prev, c_curr, 2, True)
else : cell = InferCell(genotype, c_prev, c_curr, 1)
self.cells.append( cell )
c_prev = cell.out_dim
self._num_layer = len(self.cells)
self.lastact = nn.Sequential(nn.BatchNorm2d(c_prev), nn.ReLU(inplace=True))
self.global_pooling = nn.AdaptiveAvgPool2d(1)
self.classifier = nn.Linear(c_prev, num_classes)
def get_message(self) -> Text:
string = self.extra_repr()
for i, cell in enumerate(self.cells):
string += '\n {:02d}/{:02d} :: {:}'.format(i, len(self.cells), cell.extra_repr())
return string
def extra_repr(self):
return ('{name}(C={_channels}, N={_num_stage}, L={_num_layer})'.format(name=self.__class__.__name__, **self.__dict__))
def forward(self, inputs):
feature = self.stem(inputs)
for i, cell in enumerate(self.cells):
feature = cell(feature)
out = self.lastact(feature)
out = self.global_pooling( out )
out = out.view(out.size(0), -1)
logits = self.classifier(out)
return out, logits
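# Note (added commentary): `channels` lists per-cell output widths for the three stages plus
# the two reduction cells between them, hence the requirement len(channels) % 3 == 2
# (i.e. 3*N + 2 entries); the two True entries in `layer_reductions` mark where the stride-2
# ResNetBasicblock reduction cells are placed.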

View File

@ -0,0 +1,9 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
#####################################################
from .InferCifarResNet_width import InferWidthCifarResNet
from .InferImagenetResNet import InferImagenetResNet
from .InferCifarResNet_depth import InferDepthCifarResNet
from .InferCifarResNet import InferCifarResNet
from .InferMobileNetV2 import InferMobileNetV2
from .InferTinyCellNet import DynamicShapeTinyNet

View File

@ -0,0 +1,5 @@
def parse_channel_info(xstring):
blocks = xstring.split(' ')
blocks = [x.split('-') for x in blocks]
blocks = [[int(_) for _ in x] for x in blocks]
return blocks
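# Example (added illustration): the string is split on spaces, then on '-', and cast to int, so
#   parse_channel_info('3-16 16-32 32-64')  ->  [[3, 16], [16, 32], [32, 64]]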

View File

@ -0,0 +1,502 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
import math, torch
from collections import OrderedDict
from bisect import bisect_right
import torch.nn as nn
from ..initialization import initialize_resnet
from ..SharedUtils import additive_func
from .SoftSelect import select2withP, ChannelWiseInter
from .SoftSelect import linear_forward
from .SoftSelect import get_width_choices
def get_depth_choices(nDepth, return_num):
if nDepth == 2:
choices = (1, 2)
elif nDepth == 3:
choices = (1, 2, 3)
elif nDepth > 3:
choices = list(range(1, nDepth+1, 2))
if choices[-1] < nDepth: choices.append(nDepth)
else:
raise ValueError('invalid nDepth : {:}'.format(nDepth))
if return_num: return len(choices)
else : return choices
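# Example (added illustration): get_depth_choices(5, False) -> [1, 3, 5] and
# get_depth_choices(6, False) -> [1, 3, 5, 6] (the full depth is always appended as the last
# choice); with return_num=True only the number of choices is returned.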
def conv_forward(inputs, conv, choices):
iC = conv.in_channels
fill_size = list(inputs.size())
fill_size[1] = iC - fill_size[1]
filled = torch.zeros(fill_size, device=inputs.device)
xinputs = torch.cat((inputs, filled), dim=1)
outputs = conv(xinputs)
selecteds = [outputs[:,:oC] for oC in choices]
return selecteds
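# Note (added commentary): conv_forward zero-pads the input along the channel dimension up to
# conv.in_channels, runs the convolution once, and slices the first oC output channels for
# every candidate width in `choices`; the zero channels contribute nothing, so a single
# forward pass serves all width candidates.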
class ConvBNReLU(nn.Module):
num_conv = 1
def __init__(self, nIn, nOut, kernel, stride, padding, bias, has_avg, has_bn, has_relu):
super(ConvBNReLU, self).__init__()
self.InShape = None
self.OutShape = None
self.choices = get_width_choices(nOut)
self.register_buffer('choices_tensor', torch.Tensor( self.choices ))
if has_avg : self.avg = nn.AvgPool2d(kernel_size=2, stride=2, padding=0)
else : self.avg = None
self.conv = nn.Conv2d(nIn, nOut, kernel_size=kernel, stride=stride, padding=padding, dilation=1, groups=1, bias=bias)
#if has_bn : self.bn = nn.BatchNorm2d(nOut)
#else : self.bn = None
self.has_bn = has_bn
self.BNs = nn.ModuleList()
for i, _out in enumerate(self.choices):
self.BNs.append(nn.BatchNorm2d(_out))
if has_relu: self.relu = nn.ReLU(inplace=True)
else : self.relu = None
self.in_dim = nIn
self.out_dim = nOut
self.search_mode = 'basic'
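# Note (added commentary): instead of a single BatchNorm2d, one BN is kept per candidate
# width so that each sliced output of conv_forward (defined above) gets its own statistics;
# basic_forward always uses the widest BN, self.BNs[-1].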
def get_flops(self, channels, check_range=True, divide=1):
iC, oC = channels
if check_range: assert iC <= self.conv.in_channels and oC <= self.conv.out_channels, '{:} vs {:} | {:} vs {:}'.format(iC, self.conv.in_channels, oC, self.conv.out_channels)
assert isinstance(self.InShape, tuple) and len(self.InShape) == 2, 'invalid in-shape : {:}'.format(self.InShape)
assert isinstance(self.OutShape, tuple) and len(self.OutShape) == 2, 'invalid out-shape : {:}'.format(self.OutShape)
#conv_per_position_flops = self.conv.kernel_size[0] * self.conv.kernel_size[1] * iC * oC / self.conv.groups
conv_per_position_flops = (self.conv.kernel_size[0] * self.conv.kernel_size[1] * 1.0 / self.conv.groups)
all_positions = self.OutShape[0] * self.OutShape[1]
flops = (conv_per_position_flops * all_positions / divide) * iC * oC
if self.conv.bias is not None: flops += all_positions / divide
return flops
def get_range(self):
return [self.choices]
def forward(self, inputs):
if self.search_mode == 'basic':
return self.basic_forward(inputs)
elif self.search_mode == 'search':
return self.search_forward(inputs)
else:
raise ValueError('invalid search_mode = {:}'.format(self.search_mode))
def search_forward(self, tuple_inputs):
assert isinstance(tuple_inputs, tuple) and len(tuple_inputs) == 5, 'invalid type input : {:}'.format( type(tuple_inputs) )
inputs, expected_inC, probability, index, prob = tuple_inputs
index, prob = torch.squeeze(index).tolist(), torch.squeeze(prob)
probability = torch.squeeze(probability)
assert len(index) == 2, 'invalid length : {:}'.format(index)
# compute expected flop
#coordinates = torch.arange(self.x_range[0], self.x_range[1]+1).type_as(probability)
expected_outC = (self.choices_tensor * probability).sum()
expected_flop = self.get_flops([expected_inC, expected_outC], False, 1e6)
if self.avg : out = self.avg( inputs )
else : out = inputs
# convolutional layer
out_convs = conv_forward(out, self.conv, [self.choices[i] for i in index])
out_bns = [self.BNs[idx](out_conv) for idx, out_conv in zip(index, out_convs)]
# merge
out_channel = max([x.size(1) for x in out_bns])
outA = ChannelWiseInter(out_bns[0], out_channel)
outB = ChannelWiseInter(out_bns[1], out_channel)
out = outA * prob[0] + outB * prob[1]
#out = additive_func(out_bns[0]*prob[0], out_bns[1]*prob[1])
if self.relu: out = self.relu( out )
else : out = out
return out, expected_outC, expected_flop
def basic_forward(self, inputs):
if self.avg : out = self.avg( inputs )
else : out = inputs
conv = self.conv( out )
if self.has_bn: out = self.BNs[-1]( conv )
else : out = conv
if self.relu: out = self.relu( out )
else : out = out
if self.InShape is None:
self.InShape = (inputs.size(-2), inputs.size(-1))
self.OutShape = (out.size(-2) , out.size(-1))
return out
class ResNetBasicblock(nn.Module):
expansion = 1
num_conv = 2
def __init__(self, inplanes, planes, stride):
super(ResNetBasicblock, self).__init__()
assert stride == 1 or stride == 2, 'invalid stride {:}'.format(stride)
self.conv_a = ConvBNReLU(inplanes, planes, 3, stride, 1, False, has_avg=False, has_bn=True, has_relu=True)
self.conv_b = ConvBNReLU( planes, planes, 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=False)
if stride == 2:
self.downsample = ConvBNReLU(inplanes, planes, 1, 1, 0, False, has_avg=True, has_bn=False, has_relu=False)
elif inplanes != planes:
self.downsample = ConvBNReLU(inplanes, planes, 1, 1, 0, False, has_avg=False,has_bn=True , has_relu=False)
else:
self.downsample = None
self.out_dim = planes
self.search_mode = 'basic'
def get_range(self):
return self.conv_a.get_range() + self.conv_b.get_range()
def get_flops(self, channels):
assert len(channels) == 3, 'invalid channels : {:}'.format(channels)
flop_A = self.conv_a.get_flops([channels[0], channels[1]])
flop_B = self.conv_b.get_flops([channels[1], channels[2]])
if hasattr(self.downsample, 'get_flops'):
flop_C = self.downsample.get_flops([channels[0], channels[-1]])
else:
flop_C = 0
if channels[0] != channels[-1] and self.downsample is None: # this short-cut will be added during the infer-train
flop_C = channels[0] * channels[-1] * self.conv_b.OutShape[0] * self.conv_b.OutShape[1]
return flop_A + flop_B + flop_C
def forward(self, inputs):
if self.search_mode == 'basic' : return self.basic_forward(inputs)
elif self.search_mode == 'search': return self.search_forward(inputs)
else: raise ValueError('invalid search_mode = {:}'.format(self.search_mode))
def search_forward(self, tuple_inputs):
assert isinstance(tuple_inputs, tuple) and len(tuple_inputs) == 5, 'invalid type input : {:}'.format( type(tuple_inputs) )
inputs, expected_inC, probability, indexes, probs = tuple_inputs
assert indexes.size(0) == 2 and probs.size(0) == 2 and probability.size(0) == 2
out_a, expected_inC_a, expected_flop_a = self.conv_a( (inputs, expected_inC , probability[0], indexes[0], probs[0]) )
out_b, expected_inC_b, expected_flop_b = self.conv_b( (out_a , expected_inC_a, probability[1], indexes[1], probs[1]) )
if self.downsample is not None:
residual, _, expected_flop_c = self.downsample( (inputs, expected_inC , probability[1], indexes[1], probs[1]) )
else:
residual, expected_flop_c = inputs, 0
out = additive_func(residual, out_b)
return nn.functional.relu(out, inplace=True), expected_inC_b, sum([expected_flop_a, expected_flop_b, expected_flop_c])
def basic_forward(self, inputs):
basicblock = self.conv_a(inputs)
basicblock = self.conv_b(basicblock)
if self.downsample is not None: residual = self.downsample(inputs)
else : residual = inputs
out = additive_func(residual, basicblock)
return nn.functional.relu(out, inplace=True)
class ResNetBottleneck(nn.Module):
expansion = 4
num_conv = 3
def __init__(self, inplanes, planes, stride):
super(ResNetBottleneck, self).__init__()
assert stride == 1 or stride == 2, 'invalid stride {:}'.format(stride)
self.conv_1x1 = ConvBNReLU(inplanes, planes, 1, 1, 0, False, has_avg=False, has_bn=True, has_relu=True)
self.conv_3x3 = ConvBNReLU( planes, planes, 3, stride, 1, False, has_avg=False, has_bn=True, has_relu=True)
self.conv_1x4 = ConvBNReLU(planes, planes*self.expansion, 1, 1, 0, False, has_avg=False, has_bn=True, has_relu=False)
if stride == 2:
self.downsample = ConvBNReLU(inplanes, planes*self.expansion, 1, 1, 0, False, has_avg=True, has_bn=False, has_relu=False)
elif inplanes != planes*self.expansion:
self.downsample = ConvBNReLU(inplanes, planes*self.expansion, 1, 1, 0, False, has_avg=False,has_bn=True , has_relu=False)
else:
self.downsample = None
self.out_dim = planes * self.expansion
self.search_mode = 'basic'
def get_range(self):
return self.conv_1x1.get_range() + self.conv_3x3.get_range() + self.conv_1x4.get_range()
def get_flops(self, channels):
assert len(channels) == 4, 'invalid channels : {:}'.format(channels)
flop_A = self.conv_1x1.get_flops([channels[0], channels[1]])
flop_B = self.conv_3x3.get_flops([channels[1], channels[2]])
flop_C = self.conv_1x4.get_flops([channels[2], channels[3]])
if hasattr(self.downsample, 'get_flops'):
flop_D = self.downsample.get_flops([channels[0], channels[-1]])
else:
flop_D = 0
if channels[0] != channels[-1] and self.downsample is None: # this short-cut will be added during the infer-train
flop_D = channels[0] * channels[-1] * self.conv_1x4.OutShape[0] * self.conv_1x4.OutShape[1]
return flop_A + flop_B + flop_C + flop_D
def forward(self, inputs):
if self.search_mode == 'basic' : return self.basic_forward(inputs)
elif self.search_mode == 'search': return self.search_forward(inputs)
else: raise ValueError('invalid search_mode = {:}'.format(self.search_mode))
def basic_forward(self, inputs):
bottleneck = self.conv_1x1(inputs)
bottleneck = self.conv_3x3(bottleneck)
bottleneck = self.conv_1x4(bottleneck)
if self.downsample is not None: residual = self.downsample(inputs)
else : residual = inputs
out = additive_func(residual, bottleneck)
return nn.functional.relu(out, inplace=True)
def search_forward(self, tuple_inputs):
assert isinstance(tuple_inputs, tuple) and len(tuple_inputs) == 5, 'invalid type input : {:}'.format( type(tuple_inputs) )
inputs, expected_inC, probability, indexes, probs = tuple_inputs
assert indexes.size(0) == 3 and probs.size(0) == 3 and probability.size(0) == 3
out_1x1, expected_inC_1x1, expected_flop_1x1 = self.conv_1x1( (inputs, expected_inC , probability[0], indexes[0], probs[0]) )
out_3x3, expected_inC_3x3, expected_flop_3x3 = self.conv_3x3( (out_1x1,expected_inC_1x1, probability[1], indexes[1], probs[1]) )
out_1x4, expected_inC_1x4, expected_flop_1x4 = self.conv_1x4( (out_3x3,expected_inC_3x3, probability[2], indexes[2], probs[2]) )
if self.downsample is not None:
residual, _, expected_flop_c = self.downsample( (inputs, expected_inC , probability[2], indexes[2], probs[2]) )
else:
residual, expected_flop_c = inputs, 0
out = additive_func(residual, out_1x4)
return nn.functional.relu(out, inplace=True), expected_inC_1x4, sum([expected_flop_1x1, expected_flop_3x3, expected_flop_1x4, expected_flop_c])
class SearchShapeCifarResNet(nn.Module):
def __init__(self, block_name, depth, num_classes):
super(SearchShapeCifarResNet, self).__init__()
#Model type specifies number of layers for CIFAR-10 and CIFAR-100 model
if block_name == 'ResNetBasicblock':
block = ResNetBasicblock
assert (depth - 2) % 6 == 0, 'depth should be one of 20, 32, 44, 56, 110'
layer_blocks = (depth - 2) // 6
elif block_name == 'ResNetBottleneck':
block = ResNetBottleneck
assert (depth - 2) % 9 == 0, 'depth should be one of 164'
layer_blocks = (depth - 2) // 9
else:
raise ValueError('invalid block : {:}'.format(block_name))
self.message = 'SearchShapeCifarResNet : Depth : {:} , Layers for each block : {:}'.format(depth, layer_blocks)
self.num_classes = num_classes
self.channels = [16]
self.layers = nn.ModuleList( [ ConvBNReLU(3, 16, 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=True) ] )
self.InShape = None
self.depth_info = OrderedDict()
self.depth_at_i = OrderedDict()
for stage in range(3):
cur_block_choices = get_depth_choices(layer_blocks, False)
assert cur_block_choices[-1] == layer_blocks, 'stage={:}, {:} vs {:}'.format(stage, cur_block_choices, layer_blocks)
self.message += "\nstage={:} ::: depth-block-choices={:} for {:} blocks.".format(stage, cur_block_choices, layer_blocks)
block_choices, xstart = [], len(self.layers)
for iL in range(layer_blocks):
iC = self.channels[-1]
planes = 16 * (2**stage)
stride = 2 if stage > 0 and iL == 0 else 1
module = block(iC, planes, stride)
self.channels.append( module.out_dim )
self.layers.append ( module )
self.message += "\nstage={:}, ilayer={:02d}/{:02d}, block={:03d}, iC={:3d}, oC={:3d}, stride={:}".format(stage, iL, layer_blocks, len(self.layers)-1, iC, module.out_dim, stride)
# added for depth
layer_index = len(self.layers) - 1
if iL + 1 in cur_block_choices: block_choices.append( layer_index )
if iL + 1 == layer_blocks:
self.depth_info[layer_index] = {'choices': block_choices,
'stage' : stage,
'xstart' : xstart}
self.depth_info_list = []
for xend, info in self.depth_info.items():
self.depth_info_list.append( (xend, info) )
xstart, xstage = info['xstart'], info['stage']
for ilayer in range(xstart, xend+1):
idx = bisect_right(info['choices'], ilayer-1)
self.depth_at_i[ilayer] = (xstage, idx)
self.avgpool = nn.AvgPool2d(8)
self.classifier = nn.Linear(module.out_dim, num_classes)
self.InShape = None
self.tau = -1
self.search_mode = 'basic'
#assert sum(x.num_conv for x in self.layers) + 1 == depth, 'invalid depth check {:} vs {:}'.format(sum(x.num_conv for x in self.layers)+1, depth)
# parameters for width
self.Ranges = []
self.layer2indexRange = []
for i, layer in enumerate(self.layers):
start_index = len(self.Ranges)
self.Ranges += layer.get_range()
self.layer2indexRange.append( (start_index, len(self.Ranges)) )
assert len(self.Ranges) + 1 == depth, 'invalid depth check {:} vs {:}'.format(len(self.Ranges) + 1, depth)
self.register_parameter('width_attentions', nn.Parameter(torch.Tensor(len(self.Ranges), get_width_choices(None))))
self.register_parameter('depth_attentions', nn.Parameter(torch.Tensor(3, get_depth_choices(layer_blocks, True))))
nn.init.normal_(self.width_attentions, 0, 0.01)
nn.init.normal_(self.depth_attentions, 0, 0.01)
self.apply(initialize_resnet)
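# Note (added commentary): `width_attentions` holds one logit row per searchable convolution
# (len(self.Ranges) rows, one column per width choice) and `depth_attentions` one row per
# stage (3 rows, one column per depth choice); both are turned into soft selections by
# select2withP using the temperature set via set_tau below.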
def arch_parameters(self, LR=None):
if LR is None:
return [self.width_attentions, self.depth_attentions]
else:
return [
{"params": self.width_attentions, "lr": LR},
{"params": self.depth_attentions, "lr": LR},
]
def base_parameters(self):
return list(self.layers.parameters()) + list(self.avgpool.parameters()) + list(self.classifier.parameters())
def get_flop(self, mode, config_dict, extra_info):
if config_dict is not None: config_dict = config_dict.copy()
# select channels
channels = [3]
for i, weight in enumerate(self.width_attentions):
if mode == 'genotype':
with torch.no_grad():
probe = nn.functional.softmax(weight, dim=0)
C = self.Ranges[i][ torch.argmax(probe).item() ]
elif mode == 'max':
C = self.Ranges[i][-1]
elif mode == 'fix':
C = int( math.sqrt( extra_info ) * self.Ranges[i][-1] )
elif mode == 'random':
assert isinstance(extra_info, float), 'invalid extra_info : {:}'.format(extra_info)
with torch.no_grad():
prob = nn.functional.softmax(weight, dim=0)
approximate_C = int( math.sqrt( extra_info ) * self.Ranges[i][-1] )
for j in range(prob.size(0)):
prob[j] = 1 / (abs(j - (approximate_C-self.Ranges[i][j])) + 0.2)
C = self.Ranges[i][ torch.multinomial(prob, 1, False).item() ]
else:
raise ValueError('invalid mode : {:}'.format(mode))
channels.append( C )
# select depth
if mode == 'genotype':
with torch.no_grad():
depth_probs = nn.functional.softmax(self.depth_attentions, dim=1)
choices = torch.argmax(depth_probs, dim=1).cpu().tolist()
elif mode == 'max' or mode == 'fix':
choices = [self.depth_attentions.size(1)-1 for _ in range(self.depth_attentions.size(0))]  # depth_probs is undefined in this branch, so read the shape from depth_attentions
elif mode == 'random':
with torch.no_grad():
depth_probs = nn.functional.softmax(self.depth_attentions, dim=1)
choices = torch.multinomial(depth_probs, 1, False).cpu().tolist()
else:
raise ValueError('invalid mode : {:}'.format(mode))
selected_layers = []
for choice, xvalue in zip(choices, self.depth_info_list):
xtemp = xvalue[1]['choices'][choice] - xvalue[1]['xstart'] + 1
selected_layers.append(xtemp)
flop = 0
for i, layer in enumerate(self.layers):
s, e = self.layer2indexRange[i]
xchl = tuple( channels[s:e+1] )
if i in self.depth_at_i:
xstagei, xatti = self.depth_at_i[i]
if xatti <= choices[xstagei]: # leave this depth
flop+= layer.get_flops(xchl)
else:
flop+= 0 # do not use this layer
else:
flop+= layer.get_flops(xchl)
# the last fc layer
flop += channels[-1] * self.classifier.out_features
if config_dict is None:
return flop / 1e6
else:
config_dict['xchannels'] = channels
config_dict['xblocks'] = selected_layers
config_dict['super_type'] = 'infer-shape'
config_dict['estimated_FLOP'] = flop / 1e6
return flop / 1e6, config_dict
def get_arch_info(self):
string = "for depth and width, there are {:} + {:} attention probabilities.".format(len(self.depth_attentions), len(self.width_attentions))
string+= '\n{:}'.format(self.depth_info)
discrepancy = []
with torch.no_grad():
for i, att in enumerate(self.depth_attentions):
prob = nn.functional.softmax(att, dim=0)
prob = prob.cpu() ; selc = prob.argmax().item() ; prob = prob.tolist()
prob = ['{:.3f}'.format(x) for x in prob]
xstring = '{:03d}/{:03d}-th : {:}'.format(i, len(self.depth_attentions), ' '.join(prob))
logt = ['{:.4f}'.format(x) for x in att.cpu().tolist()]
xstring += ' || {:17s}'.format(' '.join(logt))
prob = sorted( [float(x) for x in prob] )
disc = prob[-1] - prob[-2]
xstring += ' || discrepancy={:.2f} || select={:}/{:}'.format(disc, selc, len(prob))
discrepancy.append( disc )
string += '\n{:}'.format(xstring)
string += '\n-----------------------------------------------'
for i, att in enumerate(self.width_attentions):
prob = nn.functional.softmax(att, dim=0)
prob = prob.cpu() ; selc = prob.argmax().item() ; prob = prob.tolist()
prob = ['{:.3f}'.format(x) for x in prob]
xstring = '{:03d}/{:03d}-th : {:}'.format(i, len(self.width_attentions), ' '.join(prob))
logt = ['{:.3f}'.format(x) for x in att.cpu().tolist()]
xstring += ' || {:52s}'.format(' '.join(logt))
prob = sorted( [float(x) for x in prob] )
disc = prob[-1] - prob[-2]
xstring += ' || dis={:.2f} || select={:}/{:}'.format(disc, selc, len(prob))
discrepancy.append( disc )
string += '\n{:}'.format(xstring)
return string, discrepancy
def set_tau(self, tau_max, tau_min, epoch_ratio):
assert epoch_ratio >= 0 and epoch_ratio <= 1, 'invalid epoch-ratio : {:}'.format(epoch_ratio)
tau = tau_min + (tau_max-tau_min) * (1 + math.cos(math.pi * epoch_ratio)) / 2
self.tau = tau
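# Note (added commentary): this is a cosine schedule, tau = tau_min + (tau_max - tau_min) *
# (1 + cos(pi * epoch_ratio)) / 2, so tau starts at tau_max when epoch_ratio = 0 and anneals
# to tau_min when epoch_ratio = 1.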
def get_message(self):
return self.message
def forward(self, inputs):
if self.search_mode == 'basic':
return self.basic_forward(inputs)
elif self.search_mode == 'search':
return self.search_forward(inputs)
else:
raise ValueError('invalid search_mode = {:}'.format(self.search_mode))
def search_forward(self, inputs):
flop_width_probs = nn.functional.softmax(self.width_attentions, dim=1)
flop_depth_probs = nn.functional.softmax(self.depth_attentions, dim=1)
flop_depth_probs = torch.flip( torch.cumsum( torch.flip(flop_depth_probs, [1]), 1 ), [1] )
selected_widths, selected_width_probs = select2withP(self.width_attentions, self.tau)
selected_depth_probs = select2withP(self.depth_attentions, self.tau, True)
with torch.no_grad():
selected_widths = selected_widths.cpu()
x, last_channel_idx, expected_inC, flops = inputs, 0, 3, []
feature_maps = []
for i, layer in enumerate(self.layers):
selected_w_index = selected_widths [last_channel_idx: last_channel_idx+layer.num_conv]
selected_w_probs = selected_width_probs[last_channel_idx: last_channel_idx+layer.num_conv]
layer_prob = flop_width_probs [last_channel_idx: last_channel_idx+layer.num_conv]
x, expected_inC, expected_flop = layer( (x, expected_inC, layer_prob, selected_w_index, selected_w_probs) )
feature_maps.append( x )
last_channel_idx += layer.num_conv
if i in self.depth_info: # aggregate the information
choices = self.depth_info[i]['choices']
xstagei = self.depth_info[i]['stage']
#print ('iL={:}, choices={:}, stage={:}, probs={:}'.format(i, choices, xstagei, selected_depth_probs[xstagei].cpu().tolist()))
#for A, W in zip(choices, selected_depth_probs[xstagei]):
# print('Size = {:}, W = {:}'.format(feature_maps[A].size(), W))
possible_tensors = []
max_C = max( feature_maps[A].size(1) for A in choices )
for tempi, A in enumerate(choices):
xtensor = ChannelWiseInter(feature_maps[A], max_C)
#drop_ratio = 1-(tempi+1.0)/len(choices)
#xtensor = drop_path(xtensor, drop_ratio)
possible_tensors.append( xtensor )
weighted_sum = sum( xtensor * W for xtensor, W in zip(possible_tensors, selected_depth_probs[xstagei]) )
x = weighted_sum
if i in self.depth_at_i:
xstagei, xatti = self.depth_at_i[i]
x_expected_flop = flop_depth_probs[xstagei, xatti] * expected_flop
else:
x_expected_flop = expected_flop
flops.append( x_expected_flop )
flops.append( expected_inC * (self.classifier.out_features*1.0/1e6) )
features = self.avgpool(x)
features = features.view(features.size(0), -1)
logits = linear_forward(features, self.classifier)
return logits, torch.stack( [sum(flops)] )
def basic_forward(self, inputs):
if self.InShape is None: self.InShape = (inputs.size(-2), inputs.size(-1))
x = inputs
for i, layer in enumerate(self.layers):
x = layer( x )
features = self.avgpool(x)
features = features.view(features.size(0), -1)
logits = self.classifier(features)
return features, logits
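# Illustrative usage sketch (hypothetical, not part of the original file, assuming
# `import torch`): a forward pass in the default 'basic' mode first records InShape/OutShape
# for the FLOP estimates; afterwards every module with a `search_mode` attribute can be
# switched to 'search', so the forward pass returns (logits, expected_flops) instead of
# (features, logits):
#   net = SearchShapeCifarResNet('ResNetBasicblock', 20, 10)
#   net(torch.randn(2, 3, 32, 32))                      # warm-up in 'basic' mode
#   net.set_tau(10, 0.1, epoch_ratio=0.0)
#   for m in net.modules():
#       if hasattr(m, 'search_mode'): m.search_mode = 'search'
#   logits, expected_flops = net(torch.randn(2, 3, 32, 32))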

View File

@ -0,0 +1,340 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
import math, torch
from collections import OrderedDict
from bisect import bisect_right
import torch.nn as nn
from ..initialization import initialize_resnet
from ..SharedUtils import additive_func
from .SoftSelect import select2withP, ChannelWiseInter
from .SoftSelect import linear_forward
from .SoftSelect import get_width_choices
def get_depth_choices(nDepth, return_num):
if nDepth == 2:
choices = (1, 2)
elif nDepth == 3:
choices = (1, 2, 3)
elif nDepth > 3:
choices = list(range(1, nDepth+1, 2))
if choices[-1] < nDepth: choices.append(nDepth)
else:
raise ValueError('invalid nDepth : {:}'.format(nDepth))
if return_num: return len(choices)
else : return choices
class ConvBNReLU(nn.Module):
num_conv = 1
def __init__(self, nIn, nOut, kernel, stride, padding, bias, has_avg, has_bn, has_relu):
super(ConvBNReLU, self).__init__()
self.InShape = None
self.OutShape = None
self.choices = get_width_choices(nOut)
self.register_buffer('choices_tensor', torch.Tensor( self.choices ))
if has_avg : self.avg = nn.AvgPool2d(kernel_size=2, stride=2, padding=0)
else : self.avg = None
self.conv = nn.Conv2d(nIn, nOut, kernel_size=kernel, stride=stride, padding=padding, dilation=1, groups=1, bias=bias)
if has_bn : self.bn = nn.BatchNorm2d(nOut)
else : self.bn = None
if has_relu: self.relu = nn.ReLU(inplace=False)
else : self.relu = None
self.in_dim = nIn
self.out_dim = nOut
def get_flops(self, divide=1):
iC, oC = self.in_dim, self.out_dim
assert iC <= self.conv.in_channels and oC <= self.conv.out_channels, '{:} vs {:} | {:} vs {:}'.format(iC, self.conv.in_channels, oC, self.conv.out_channels)
assert isinstance(self.InShape, tuple) and len(self.InShape) == 2, 'invalid in-shape : {:}'.format(self.InShape)
assert isinstance(self.OutShape, tuple) and len(self.OutShape) == 2, 'invalid out-shape : {:}'.format(self.OutShape)
#conv_per_position_flops = self.conv.kernel_size[0] * self.conv.kernel_size[1] * iC * oC / self.conv.groups
conv_per_position_flops = (self.conv.kernel_size[0] * self.conv.kernel_size[1] * 1.0 / self.conv.groups)
all_positions = self.OutShape[0] * self.OutShape[1]
flops = (conv_per_position_flops * all_positions / divide) * iC * oC
if self.conv.bias is not None: flops += all_positions / divide
return flops
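# Example (added illustration): for a 3x3 convolution with in_dim=16, out_dim=32 and a 32x32
# output feature map, get_flops(divide=1) gives (3*3/1) * (32*32) * 16 * 32 = 4,718,592
# multiply-accumulates (no bias term).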
def forward(self, inputs):
if self.avg : out = self.avg( inputs )
else : out = inputs
conv = self.conv( out )
if self.bn : out = self.bn( conv )
else : out = conv
if self.relu: out = self.relu( out )
else : out = out
if self.InShape is None:
self.InShape = (inputs.size(-2), inputs.size(-1))
self.OutShape = (out.size(-2) , out.size(-1))
return out
class ResNetBasicblock(nn.Module):
expansion = 1
num_conv = 2
def __init__(self, inplanes, planes, stride):
super(ResNetBasicblock, self).__init__()
assert stride == 1 or stride == 2, 'invalid stride {:}'.format(stride)
self.conv_a = ConvBNReLU(inplanes, planes, 3, stride, 1, False, has_avg=False, has_bn=True, has_relu=True)
self.conv_b = ConvBNReLU( planes, planes, 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=False)
if stride == 2:
self.downsample = ConvBNReLU(inplanes, planes, 1, 1, 0, False, has_avg=True, has_bn=False, has_relu=False)
elif inplanes != planes:
self.downsample = ConvBNReLU(inplanes, planes, 1, 1, 0, False, has_avg=False,has_bn=True , has_relu=False)
else:
self.downsample = None
self.out_dim = planes
self.search_mode = 'basic'
def get_flops(self, divide=1):
flop_A = self.conv_a.get_flops(divide)
flop_B = self.conv_b.get_flops(divide)
if hasattr(self.downsample, 'get_flops'):
flop_C = self.downsample.get_flops(divide)
else:
flop_C = 0
return flop_A + flop_B + flop_C
def forward(self, inputs):
basicblock = self.conv_a(inputs)
basicblock = self.conv_b(basicblock)
if self.downsample is not None: residual = self.downsample(inputs)
else : residual = inputs
out = additive_func(residual, basicblock)
return nn.functional.relu(out, inplace=True)
class ResNetBottleneck(nn.Module):
expansion = 4
num_conv = 3
def __init__(self, inplanes, planes, stride):
super(ResNetBottleneck, self).__init__()
assert stride == 1 or stride == 2, 'invalid stride {:}'.format(stride)
self.conv_1x1 = ConvBNReLU(inplanes, planes, 1, 1, 0, False, has_avg=False, has_bn=True, has_relu=True)
self.conv_3x3 = ConvBNReLU( planes, planes, 3, stride, 1, False, has_avg=False, has_bn=True, has_relu=True)
self.conv_1x4 = ConvBNReLU(planes, planes*self.expansion, 1, 1, 0, False, has_avg=False, has_bn=True, has_relu=False)
if stride == 2:
self.downsample = ConvBNReLU(inplanes, planes*self.expansion, 1, 1, 0, False, has_avg=True, has_bn=False, has_relu=False)
elif inplanes != planes*self.expansion:
self.downsample = ConvBNReLU(inplanes, planes*self.expansion, 1, 1, 0, False, has_avg=False,has_bn=True , has_relu=False)
else:
self.downsample = None
self.out_dim = planes * self.expansion
self.search_mode = 'basic'
def get_range(self):
return self.conv_1x1.get_range() + self.conv_3x3.get_range() + self.conv_1x4.get_range()
def get_flops(self, divide):
flop_A = self.conv_1x1.get_flops(divide)
flop_B = self.conv_3x3.get_flops(divide)
flop_C = self.conv_1x4.get_flops(divide)
if hasattr(self.downsample, 'get_flops'):
flop_D = self.downsample.get_flops(divide)
else:
flop_D = 0
return flop_A + flop_B + flop_C + flop_D
def forward(self, inputs):
bottleneck = self.conv_1x1(inputs)
bottleneck = self.conv_3x3(bottleneck)
bottleneck = self.conv_1x4(bottleneck)
if self.downsample is not None: residual = self.downsample(inputs)
else : residual = inputs
out = additive_func(residual, bottleneck)
return nn.functional.relu(out, inplace=True)
class SearchDepthCifarResNet(nn.Module):
def __init__(self, block_name, depth, num_classes):
super(SearchDepthCifarResNet, self).__init__()
#Model type specifies number of layers for CIFAR-10 and CIFAR-100 model
if block_name == 'ResNetBasicblock':
block = ResNetBasicblock
assert (depth - 2) % 6 == 0, 'depth should be one of 20, 32, 44, 56, 110'
layer_blocks = (depth - 2) // 6
elif block_name == 'ResNetBottleneck':
block = ResNetBottleneck
assert (depth - 2) % 9 == 0, 'depth should be one of 164'
layer_blocks = (depth - 2) // 9
else:
raise ValueError('invalid block : {:}'.format(block_name))
self.message = 'SearchDepthCifarResNet : Depth : {:} , Layers for each block : {:}'.format(depth, layer_blocks)
self.num_classes = num_classes
self.channels = [16]
self.layers = nn.ModuleList( [ ConvBNReLU(3, 16, 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=True) ] )
self.InShape = None
self.depth_info = OrderedDict()
self.depth_at_i = OrderedDict()
for stage in range(3):
cur_block_choices = get_depth_choices(layer_blocks, False)
assert cur_block_choices[-1] == layer_blocks, 'stage={:}, {:} vs {:}'.format(stage, cur_block_choices, layer_blocks)
self.message += "\nstage={:} ::: depth-block-choices={:} for {:} blocks.".format(stage, cur_block_choices, layer_blocks)
block_choices, xstart = [], len(self.layers)
for iL in range(layer_blocks):
iC = self.channels[-1]
planes = 16 * (2**stage)
stride = 2 if stage > 0 and iL == 0 else 1
module = block(iC, planes, stride)
self.channels.append( module.out_dim )
self.layers.append ( module )
self.message += "\nstage={:}, ilayer={:02d}/{:02d}, block={:03d}, iC={:3d}, oC={:3d}, stride={:}".format(stage, iL, layer_blocks, len(self.layers)-1, iC, module.out_dim, stride)
# added for depth
layer_index = len(self.layers) - 1
if iL + 1 in cur_block_choices: block_choices.append( layer_index )
if iL + 1 == layer_blocks:
self.depth_info[layer_index] = {'choices': block_choices,
'stage' : stage,
'xstart' : xstart}
self.depth_info_list = []
for xend, info in self.depth_info.items():
self.depth_info_list.append( (xend, info) )
xstart, xstage = info['xstart'], info['stage']
for ilayer in range(xstart, xend+1):
idx = bisect_right(info['choices'], ilayer-1)
self.depth_at_i[ilayer] = (xstage, idx)
self.avgpool = nn.AvgPool2d(8)
self.classifier = nn.Linear(module.out_dim, num_classes)
self.InShape = None
self.tau = -1
self.search_mode = 'basic'
#assert sum(x.num_conv for x in self.layers) + 1 == depth, 'invalid depth check {:} vs {:}'.format(sum(x.num_conv for x in self.layers)+1, depth)
self.register_parameter('depth_attentions', nn.Parameter(torch.Tensor(3, get_depth_choices(layer_blocks, True))))
nn.init.normal_(self.depth_attentions, 0, 0.01)
self.apply(initialize_resnet)
def arch_parameters(self):
return [self.depth_attentions]
def base_parameters(self):
return list(self.layers.parameters()) + list(self.avgpool.parameters()) + list(self.classifier.parameters())
def get_flop(self, mode, config_dict, extra_info):
if config_dict is not None: config_dict = config_dict.copy()
# select depth
if mode == 'genotype':
with torch.no_grad():
depth_probs = nn.functional.softmax(self.depth_attentions, dim=1)
choices = torch.argmax(depth_probs, dim=1).cpu().tolist()
elif mode == 'max':
choices = [self.depth_attentions.size(1)-1 for _ in range(self.depth_attentions.size(0))]  # depth_probs is only defined in the other branches, so read the shape from depth_attentions
elif mode == 'random':
with torch.no_grad():
depth_probs = nn.functional.softmax(self.depth_attentions, dim=1)
choices = torch.multinomial(depth_probs, 1, False).cpu().tolist()
else:
raise ValueError('invalid mode : {:}'.format(mode))
selected_layers = []
for choice, xvalue in zip(choices, self.depth_info_list):
xtemp = xvalue[1]['choices'][choice] - xvalue[1]['xstart'] + 1
selected_layers.append(xtemp)
flop = 0
for i, layer in enumerate(self.layers):
if i in self.depth_at_i:
xstagei, xatti = self.depth_at_i[i]
if xatti <= choices[xstagei]: # leave this depth
flop+= layer.get_flops()
else:
flop+= 0 # do not use this layer
else:
flop+= layer.get_flops()
# the last fc layer
flop += self.classifier.in_features * self.classifier.out_features
if config_dict is None:
return flop / 1e6
else:
config_dict['xblocks'] = selected_layers
config_dict['super_type'] = 'infer-depth'
config_dict['estimated_FLOP'] = flop / 1e6
return flop / 1e6, config_dict
def get_arch_info(self):
string = "for depth, there are {:} attention probabilities.".format(len(self.depth_attentions))
string+= '\n{:}'.format(self.depth_info)
discrepancy = []
with torch.no_grad():
for i, att in enumerate(self.depth_attentions):
prob = nn.functional.softmax(att, dim=0)
prob = prob.cpu() ; selc = prob.argmax().item() ; prob = prob.tolist()
prob = ['{:.3f}'.format(x) for x in prob]
xstring = '{:03d}/{:03d}-th : {:}'.format(i, len(self.depth_attentions), ' '.join(prob))
logt = ['{:.4f}'.format(x) for x in att.cpu().tolist()]
xstring += ' || {:17s}'.format(' '.join(logt))
prob = sorted( [float(x) for x in prob] )
disc = prob[-1] - prob[-2]
xstring += ' || discrepancy={:.2f} || select={:}/{:}'.format(disc, selc, len(prob))
discrepancy.append( disc )
string += '\n{:}'.format(xstring)
return string, discrepancy
def set_tau(self, tau_max, tau_min, epoch_ratio):
assert epoch_ratio >= 0 and epoch_ratio <= 1, 'invalid epoch-ratio : {:}'.format(epoch_ratio)
tau = tau_min + (tau_max-tau_min) * (1 + math.cos(math.pi * epoch_ratio)) / 2
self.tau = tau
def get_message(self):
return self.message
def forward(self, inputs):
if self.search_mode == 'basic':
return self.basic_forward(inputs)
elif self.search_mode == 'search':
return self.search_forward(inputs)
else:
raise ValueError('invalid search_mode = {:}'.format(self.search_mode))
def search_forward(self, inputs):
flop_depth_probs = nn.functional.softmax(self.depth_attentions, dim=1)
flop_depth_probs = torch.flip( torch.cumsum( torch.flip(flop_depth_probs, [1]), 1 ), [1] )
selected_depth_probs = select2withP(self.depth_attentions, self.tau, True)
x, flops = inputs, []
feature_maps = []
for i, layer in enumerate(self.layers):
layer_i = layer( x )
feature_maps.append( layer_i )
if i in self.depth_info: # aggregate the information
choices = self.depth_info[i]['choices']
xstagei = self.depth_info[i]['stage']
possible_tensors = []
for tempi, A in enumerate(choices):
xtensor = feature_maps[A]
possible_tensors.append( xtensor )
weighted_sum = sum( xtensor * W for xtensor, W in zip(possible_tensors, selected_depth_probs[xstagei]) )
x = weighted_sum
else:
x = layer_i
if i in self.depth_at_i:
xstagei, xatti = self.depth_at_i[i]
#print ('layer-{:03d}, stage={:}, att={:}, prob={:}, flop={:}'.format(i, xstagei, xatti, flop_depth_probs[xstagei, xatti].item(), layer.get_flops(1e6)))
x_expected_flop = flop_depth_probs[xstagei, xatti] * layer.get_flops(1e6)
else:
x_expected_flop = layer.get_flops(1e6)
flops.append( x_expected_flop )
flops.append( (self.classifier.in_features * self.classifier.out_features*1.0/1e6) )
features = self.avgpool(x)
features = features.view(features.size(0), -1)
logits = linear_forward(features, self.classifier)
return logits, torch.stack( [sum(flops)] )
def basic_forward(self, inputs):
if self.InShape is None: self.InShape = (inputs.size(-2), inputs.size(-1))
x = inputs
for i, layer in enumerate(self.layers):
x = layer( x )
features = self.avgpool(x)
features = features.view(features.size(0), -1)
logits = self.classifier(features)
return features, logits
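# Note (added commentary): compared with the joint width-and-depth search above, this variant
# only learns `depth_attentions`; get_flop(mode, config_dict, extra_info) returns the estimated
# FLOPs in millions and, when config_dict is given, also records the selected per-stage block
# counts in config_dict['xblocks'] with super_type 'infer-depth' for the subsequent inference
# training.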

View File

@ -0,0 +1,393 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
import math, torch
import torch.nn as nn
from ..initialization import initialize_resnet
from ..SharedUtils import additive_func
from .SoftSelect import select2withP, ChannelWiseInter
from .SoftSelect import linear_forward
from .SoftSelect import get_width_choices as get_choices
def conv_forward(inputs, conv, choices):
iC = conv.in_channels
fill_size = list(inputs.size())
fill_size[1] = iC - fill_size[1]
filled = torch.zeros(fill_size, device=inputs.device)
xinputs = torch.cat((inputs, filled), dim=1)
outputs = conv(xinputs)
selecteds = [outputs[:,:oC] for oC in choices]
return selecteds
class ConvBNReLU(nn.Module):
num_conv = 1
def __init__(self, nIn, nOut, kernel, stride, padding, bias, has_avg, has_bn, has_relu):
super(ConvBNReLU, self).__init__()
self.InShape = None
self.OutShape = None
self.choices = get_choices(nOut)
self.register_buffer('choices_tensor', torch.Tensor( self.choices ))
if has_avg : self.avg = nn.AvgPool2d(kernel_size=2, stride=2, padding=0)
else : self.avg = None
self.conv = nn.Conv2d(nIn, nOut, kernel_size=kernel, stride=stride, padding=padding, dilation=1, groups=1, bias=bias)
#if has_bn : self.bn = nn.BatchNorm2d(nOut)
#else : self.bn = None
self.has_bn = has_bn
self.BNs = nn.ModuleList()
for i, _out in enumerate(self.choices):
self.BNs.append(nn.BatchNorm2d(_out))
if has_relu: self.relu = nn.ReLU(inplace=True)
else : self.relu = None
self.in_dim = nIn
self.out_dim = nOut
self.search_mode = 'basic'
def get_flops(self, channels, check_range=True, divide=1):
iC, oC = channels
if check_range: assert iC <= self.conv.in_channels and oC <= self.conv.out_channels, '{:} vs {:} | {:} vs {:}'.format(iC, self.conv.in_channels, oC, self.conv.out_channels)
assert isinstance(self.InShape, tuple) and len(self.InShape) == 2, 'invalid in-shape : {:}'.format(self.InShape)
assert isinstance(self.OutShape, tuple) and len(self.OutShape) == 2, 'invalid out-shape : {:}'.format(self.OutShape)
#conv_per_position_flops = self.conv.kernel_size[0] * self.conv.kernel_size[1] * iC * oC / self.conv.groups
conv_per_position_flops = (self.conv.kernel_size[0] * self.conv.kernel_size[1] * 1.0 / self.conv.groups)
all_positions = self.OutShape[0] * self.OutShape[1]
flops = (conv_per_position_flops * all_positions / divide) * iC * oC
if self.conv.bias is not None: flops += all_positions / divide
return flops
def get_range(self):
return [self.choices]
def forward(self, inputs):
if self.search_mode == 'basic':
return self.basic_forward(inputs)
elif self.search_mode == 'search':
return self.search_forward(inputs)
else:
raise ValueError('invalid search_mode = {:}'.format(self.search_mode))
def search_forward(self, tuple_inputs):
assert isinstance(tuple_inputs, tuple) and len(tuple_inputs) == 5, 'invalid type input : {:}'.format( type(tuple_inputs) )
inputs, expected_inC, probability, index, prob = tuple_inputs
index, prob = torch.squeeze(index).tolist(), torch.squeeze(prob)
probability = torch.squeeze(probability)
assert len(index) == 2, 'invalid length : {:}'.format(index)
# compute expected flop
#coordinates = torch.arange(self.x_range[0], self.x_range[1]+1).type_as(probability)
expected_outC = (self.choices_tensor * probability).sum()
expected_flop = self.get_flops([expected_inC, expected_outC], False, 1e6)
if self.avg : out = self.avg( inputs )
else : out = inputs
# convolutional layer
out_convs = conv_forward(out, self.conv, [self.choices[i] for i in index])
out_bns = [self.BNs[idx](out_conv) for idx, out_conv in zip(index, out_convs)]
# merge
out_channel = max([x.size(1) for x in out_bns])
outA = ChannelWiseInter(out_bns[0], out_channel)
outB = ChannelWiseInter(out_bns[1], out_channel)
out = outA * prob[0] + outB * prob[1]
#out = additive_func(out_bns[0]*prob[0], out_bns[1]*prob[1])
if self.relu: out = self.relu( out )
else : out = out
return out, expected_outC, expected_flop
def basic_forward(self, inputs):
if self.avg : out = self.avg( inputs )
else : out = inputs
conv = self.conv( out )
if self.has_bn: out = self.BNs[-1]( conv )
else : out = conv
if self.relu: out = self.relu( out )
else : out = out
if self.InShape is None:
self.InShape = (inputs.size(-2), inputs.size(-1))
self.OutShape = (out.size(-2) , out.size(-1))
return out
class ResNetBasicblock(nn.Module):
expansion = 1
num_conv = 2
def __init__(self, inplanes, planes, stride):
super(ResNetBasicblock, self).__init__()
assert stride == 1 or stride == 2, 'invalid stride {:}'.format(stride)
self.conv_a = ConvBNReLU(inplanes, planes, 3, stride, 1, False, has_avg=False, has_bn=True, has_relu=True)
self.conv_b = ConvBNReLU( planes, planes, 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=False)
if stride == 2:
self.downsample = ConvBNReLU(inplanes, planes, 1, 1, 0, False, has_avg=True, has_bn=False, has_relu=False)
elif inplanes != planes:
self.downsample = ConvBNReLU(inplanes, planes, 1, 1, 0, False, has_avg=False,has_bn=True , has_relu=False)
else:
self.downsample = None
self.out_dim = planes
self.search_mode = 'basic'
def get_range(self):
return self.conv_a.get_range() + self.conv_b.get_range()
def get_flops(self, channels):
assert len(channels) == 3, 'invalid channels : {:}'.format(channels)
flop_A = self.conv_a.get_flops([channels[0], channels[1]])
flop_B = self.conv_b.get_flops([channels[1], channels[2]])
if hasattr(self.downsample, 'get_flops'):
flop_C = self.downsample.get_flops([channels[0], channels[-1]])
else:
flop_C = 0
if channels[0] != channels[-1] and self.downsample is None: # this short-cut will be added during the infer-train
flop_C = channels[0] * channels[-1] * self.conv_b.OutShape[0] * self.conv_b.OutShape[1]
return flop_A + flop_B + flop_C
def forward(self, inputs):
if self.search_mode == 'basic' : return self.basic_forward(inputs)
elif self.search_mode == 'search': return self.search_forward(inputs)
else: raise ValueError('invalid search_mode = {:}'.format(self.search_mode))
def search_forward(self, tuple_inputs):
assert isinstance(tuple_inputs, tuple) and len(tuple_inputs) == 5, 'invalid type input : {:}'.format( type(tuple_inputs) )
inputs, expected_inC, probability, indexes, probs = tuple_inputs
assert indexes.size(0) == 2 and probs.size(0) == 2 and probability.size(0) == 2
out_a, expected_inC_a, expected_flop_a = self.conv_a( (inputs, expected_inC , probability[0], indexes[0], probs[0]) )
out_b, expected_inC_b, expected_flop_b = self.conv_b( (out_a , expected_inC_a, probability[1], indexes[1], probs[1]) )
if self.downsample is not None:
residual, _, expected_flop_c = self.downsample( (inputs, expected_inC , probability[1], indexes[1], probs[1]) )
else:
residual, expected_flop_c = inputs, 0
out = additive_func(residual, out_b)
return nn.functional.relu(out, inplace=True), expected_inC_b, sum([expected_flop_a, expected_flop_b, expected_flop_c])
def basic_forward(self, inputs):
basicblock = self.conv_a(inputs)
basicblock = self.conv_b(basicblock)
if self.downsample is not None: residual = self.downsample(inputs)
else : residual = inputs
out = additive_func(residual, basicblock)
return nn.functional.relu(out, inplace=True)
class ResNetBottleneck(nn.Module):
expansion = 4
num_conv = 3
def __init__(self, inplanes, planes, stride):
super(ResNetBottleneck, self).__init__()
assert stride == 1 or stride == 2, 'invalid stride {:}'.format(stride)
self.conv_1x1 = ConvBNReLU(inplanes, planes, 1, 1, 0, False, has_avg=False, has_bn=True, has_relu=True)
self.conv_3x3 = ConvBNReLU( planes, planes, 3, stride, 1, False, has_avg=False, has_bn=True, has_relu=True)
self.conv_1x4 = ConvBNReLU(planes, planes*self.expansion, 1, 1, 0, False, has_avg=False, has_bn=True, has_relu=False)
if stride == 2:
self.downsample = ConvBNReLU(inplanes, planes*self.expansion, 1, 1, 0, False, has_avg=True, has_bn=False, has_relu=False)
elif inplanes != planes*self.expansion:
self.downsample = ConvBNReLU(inplanes, planes*self.expansion, 1, 1, 0, False, has_avg=False,has_bn=True , has_relu=False)
else:
self.downsample = None
self.out_dim = planes * self.expansion
self.search_mode = 'basic'
def get_range(self):
return self.conv_1x1.get_range() + self.conv_3x3.get_range() + self.conv_1x4.get_range()
def get_flops(self, channels):
assert len(channels) == 4, 'invalid channels : {:}'.format(channels)
flop_A = self.conv_1x1.get_flops([channels[0], channels[1]])
flop_B = self.conv_3x3.get_flops([channels[1], channels[2]])
flop_C = self.conv_1x4.get_flops([channels[2], channels[3]])
if hasattr(self.downsample, 'get_flops'):
flop_D = self.downsample.get_flops([channels[0], channels[-1]])
else:
flop_D = 0
if channels[0] != channels[-1] and self.downsample is None: # this short-cut will be added during the infer-train
flop_D = channels[0] * channels[-1] * self.conv_1x4.OutShape[0] * self.conv_1x4.OutShape[1]
return flop_A + flop_B + flop_C + flop_D
def forward(self, inputs):
if self.search_mode == 'basic' : return self.basic_forward(inputs)
elif self.search_mode == 'search': return self.search_forward(inputs)
else: raise ValueError('invalid search_mode = {:}'.format(self.search_mode))
def basic_forward(self, inputs):
bottleneck = self.conv_1x1(inputs)
bottleneck = self.conv_3x3(bottleneck)
bottleneck = self.conv_1x4(bottleneck)
if self.downsample is not None: residual = self.downsample(inputs)
else : residual = inputs
out = additive_func(residual, bottleneck)
return nn.functional.relu(out, inplace=True)
def search_forward(self, tuple_inputs):
assert isinstance(tuple_inputs, tuple) and len(tuple_inputs) == 5, 'invalid type input : {:}'.format( type(tuple_inputs) )
inputs, expected_inC, probability, indexes, probs = tuple_inputs
assert indexes.size(0) == 3 and probs.size(0) == 3 and probability.size(0) == 3
out_1x1, expected_inC_1x1, expected_flop_1x1 = self.conv_1x1( (inputs, expected_inC , probability[0], indexes[0], probs[0]) )
out_3x3, expected_inC_3x3, expected_flop_3x3 = self.conv_3x3( (out_1x1,expected_inC_1x1, probability[1], indexes[1], probs[1]) )
out_1x4, expected_inC_1x4, expected_flop_1x4 = self.conv_1x4( (out_3x3,expected_inC_3x3, probability[2], indexes[2], probs[2]) )
if self.downsample is not None:
residual, _, expected_flop_c = self.downsample( (inputs, expected_inC , probability[2], indexes[2], probs[2]) )
else:
residual, expected_flop_c = inputs, 0
out = additive_func(residual, out_1x4)
return nn.functional.relu(out, inplace=True), expected_inC_1x4, sum([expected_flop_1x1, expected_flop_3x3, expected_flop_1x4, expected_flop_c])
class SearchWidthCifarResNet(nn.Module):
def __init__(self, block_name, depth, num_classes):
super(SearchWidthCifarResNet, self).__init__()
#Model type specifies number of layers for CIFAR-10 and CIFAR-100 model
if block_name == 'ResNetBasicblock':
block = ResNetBasicblock
assert (depth - 2) % 6 == 0, 'depth should be one of 20, 32, 44, 56, 110'
layer_blocks = (depth - 2) // 6
elif block_name == 'ResNetBottleneck':
block = ResNetBottleneck
assert (depth - 2) % 9 == 0, 'depth should be one of 164'
layer_blocks = (depth - 2) // 9
else:
raise ValueError('invalid block : {:}'.format(block_name))
self.message = 'SearchWidthCifarResNet : Depth : {:} , Layers for each block : {:}'.format(depth, layer_blocks)
self.num_classes = num_classes
self.channels = [16]
self.layers = nn.ModuleList( [ ConvBNReLU(3, 16, 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=True) ] )
self.InShape = None
for stage in range(3):
for iL in range(layer_blocks):
iC = self.channels[-1]
planes = 16 * (2**stage)
stride = 2 if stage > 0 and iL == 0 else 1
module = block(iC, planes, stride)
self.channels.append( module.out_dim )
self.layers.append ( module )
self.message += "\nstage={:}, ilayer={:02d}/{:02d}, block={:03d}, iC={:3d}, oC={:3d}, stride={:}".format(stage, iL, layer_blocks, len(self.layers)-1, iC, module.out_dim, stride)
self.avgpool = nn.AvgPool2d(8)
self.classifier = nn.Linear(module.out_dim, num_classes)
self.InShape = None
self.tau = -1
self.search_mode = 'basic'
#assert sum(x.num_conv for x in self.layers) + 1 == depth, 'invalid depth check {:} vs {:}'.format(sum(x.num_conv for x in self.layers)+1, depth)
# parameters for width
self.Ranges = []
self.layer2indexRange = []
for i, layer in enumerate(self.layers):
start_index = len(self.Ranges)
self.Ranges += layer.get_range()
self.layer2indexRange.append( (start_index, len(self.Ranges)) )
assert len(self.Ranges) + 1 == depth, 'invalid depth check {:} vs {:}'.format(len(self.Ranges) + 1, depth)
self.register_parameter('width_attentions', nn.Parameter(torch.Tensor(len(self.Ranges), get_choices(None))))
nn.init.normal_(self.width_attentions, 0, 0.01)
self.apply(initialize_resnet)
def arch_parameters(self):
return [self.width_attentions]
def base_parameters(self):
return list(self.layers.parameters()) + list(self.avgpool.parameters()) + list(self.classifier.parameters())
def get_flop(self, mode, config_dict, extra_info):
if config_dict is not None: config_dict = config_dict.copy()
#weights = [F.softmax(x, dim=0) for x in self.width_attentions]
channels = [3]
for i, weight in enumerate(self.width_attentions):
if mode == 'genotype':
with torch.no_grad():
probe = nn.functional.softmax(weight, dim=0)
C = self.Ranges[i][ torch.argmax(probe).item() ]
elif mode == 'max':
C = self.Ranges[i][-1]
elif mode == 'fix':
C = int( math.sqrt( extra_info ) * self.Ranges[i][-1] )
elif mode == 'random':
assert isinstance(extra_info, float), 'invalid extra_info : {:}'.format(extra_info)
with torch.no_grad():
prob = nn.functional.softmax(weight, dim=0)
approximate_C = int( math.sqrt( extra_info ) * self.Ranges[i][-1] )
for j in range(prob.size(0)):
prob[j] = 1 / (abs(j - (approximate_C-self.Ranges[i][j])) + 0.2)
C = self.Ranges[i][ torch.multinomial(prob, 1, False).item() ]
else:
raise ValueError('invalid mode : {:}'.format(mode))
channels.append( C )
flop = 0
for i, layer in enumerate(self.layers):
s, e = self.layer2indexRange[i]
xchl = tuple( channels[s:e+1] )
flop+= layer.get_flops(xchl)
# the last fc layer
flop += channels[-1] * self.classifier.out_features
if config_dict is None:
return flop / 1e6
else:
config_dict['xchannels'] = channels
config_dict['super_type'] = 'infer-width'
config_dict['estimated_FLOP'] = flop / 1e6
return flop / 1e6, config_dict
def get_arch_info(self):
string = "for width, there are {:} attention probabilities.".format(len(self.width_attentions))
discrepancy = []
with torch.no_grad():
for i, att in enumerate(self.width_attentions):
prob = nn.functional.softmax(att, dim=0)
prob = prob.cpu() ; selc = prob.argmax().item() ; prob = prob.tolist()
prob = ['{:.3f}'.format(x) for x in prob]
xstring = '{:03d}/{:03d}-th : {:}'.format(i, len(self.width_attentions), ' '.join(prob))
logt = ['{:.3f}'.format(x) for x in att.cpu().tolist()]
xstring += ' || {:52s}'.format(' '.join(logt))
prob = sorted( [float(x) for x in prob] )
disc = prob[-1] - prob[-2]
xstring += ' || dis={:.2f} || select={:}/{:}'.format(disc, selc, len(prob))
discrepancy.append( disc )
string += '\n{:}'.format(xstring)
return string, discrepancy
def set_tau(self, tau_max, tau_min, epoch_ratio):
assert epoch_ratio >= 0 and epoch_ratio <= 1, 'invalid epoch-ratio : {:}'.format(epoch_ratio)
tau = tau_min + (tau_max-tau_min) * (1 + math.cos(math.pi * epoch_ratio)) / 2
self.tau = tau
def get_message(self):
return self.message
def forward(self, inputs):
if self.search_mode == 'basic':
return self.basic_forward(inputs)
elif self.search_mode == 'search':
return self.search_forward(inputs)
else:
raise ValueError('invalid search_mode = {:}'.format(self.search_mode))
def search_forward(self, inputs):
flop_probs = nn.functional.softmax(self.width_attentions, dim=1)
selected_widths, selected_probs = select2withP(self.width_attentions, self.tau)
with torch.no_grad():
selected_widths = selected_widths.cpu()
x, last_channel_idx, expected_inC, flops = inputs, 0, 3, []
for i, layer in enumerate(self.layers):
selected_w_index = selected_widths[last_channel_idx: last_channel_idx+layer.num_conv]
selected_w_probs = selected_probs[last_channel_idx: last_channel_idx+layer.num_conv]
layer_prob = flop_probs[last_channel_idx: last_channel_idx+layer.num_conv]
x, expected_inC, expected_flop = layer( (x, expected_inC, layer_prob, selected_w_index, selected_w_probs) )
last_channel_idx += layer.num_conv
flops.append( expected_flop )
flops.append( expected_inC * (self.classifier.out_features*1.0/1e6) )
features = self.avgpool(x)
features = features.view(features.size(0), -1)
logits = linear_forward(features, self.classifier)
return logits, torch.stack( [sum(flops)] )
def basic_forward(self, inputs):
if self.InShape is None: self.InShape = (inputs.size(-2), inputs.size(-1))
x = inputs
for i, layer in enumerate(self.layers):
x = layer( x )
features = self.avgpool(x)
features = features.view(features.size(0), -1)
logits = self.classifier(features)
return features, logits

View File

@ -0,0 +1,482 @@
import math, torch
from collections import OrderedDict
from bisect import bisect_right
import torch.nn as nn
from ..initialization import initialize_resnet
from ..SharedUtils import additive_func
from .SoftSelect import select2withP, ChannelWiseInter
from .SoftSelect import linear_forward
from .SoftSelect import get_width_choices
def get_depth_choices(layers):
min_depth = min(layers)
info = {'num': min_depth}
for i, depth in enumerate(layers):
choices = []
for j in range(1, min_depth+1):
choices.append( int( float(depth)*j/min_depth ) )
info[i] = choices
return info
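# get_depth_choices builds, for every stage, the candidate block counts the depth search can
# pick from. Illustrative example: layers = [3, 4, 6, 3] gives min_depth = 3 and per-stage
# choices [1, 2, 3], [1, 2, 4], [2, 4, 6] and [1, 2, 3].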
def conv_forward(inputs, conv, choices):
iC = conv.in_channels
fill_size = list(inputs.size())
fill_size[1] = iC - fill_size[1]
filled = torch.zeros(fill_size, device=inputs.device)
xinputs = torch.cat((inputs, filled), dim=1)
outputs = conv(xinputs)
selecteds = [outputs[:,:oC] for oC in choices]
return selecteds
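# conv_forward zero-pads a (possibly narrower) input back to the convolution's full in_channels,
# runs the full convolution once, and slices the output to each requested candidate width, so the
# sampled width candidates share a single forward pass.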
class ConvBNReLU(nn.Module):
num_conv = 1
def __init__(self, nIn, nOut, kernel, stride, padding, bias, has_avg, has_bn, has_relu, last_max_pool=False):
super(ConvBNReLU, self).__init__()
self.InShape = None
self.OutShape = None
self.choices = get_width_choices(nOut)
self.register_buffer('choices_tensor', torch.Tensor( self.choices ))
if has_avg : self.avg = nn.AvgPool2d(kernel_size=2, stride=2, padding=0)
else : self.avg = None
self.conv = nn.Conv2d(nIn, nOut, kernel_size=kernel, stride=stride, padding=padding, dilation=1, groups=1, bias=bias)
#if has_bn : self.bn = nn.BatchNorm2d(nOut)
#else : self.bn = None
self.has_bn = has_bn
self.BNs = nn.ModuleList()
for i, _out in enumerate(self.choices):
self.BNs.append(nn.BatchNorm2d(_out))
if has_relu: self.relu = nn.ReLU(inplace=True)
else : self.relu = None
if last_max_pool: self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
else : self.maxpool = None
self.in_dim = nIn
self.out_dim = nOut
self.search_mode = 'basic'
def get_flops(self, channels, check_range=True, divide=1):
iC, oC = channels
if check_range: assert iC <= self.conv.in_channels and oC <= self.conv.out_channels, '{:} vs {:} | {:} vs {:}'.format(iC, self.conv.in_channels, oC, self.conv.out_channels)
assert isinstance(self.InShape, tuple) and len(self.InShape) == 2, 'invalid in-shape : {:}'.format(self.InShape)
assert isinstance(self.OutShape, tuple) and len(self.OutShape) == 2, 'invalid out-shape : {:}'.format(self.OutShape)
#conv_per_position_flops = self.conv.kernel_size[0] * self.conv.kernel_size[1] * iC * oC / self.conv.groups
conv_per_position_flops = (self.conv.kernel_size[0] * self.conv.kernel_size[1] * 1.0 / self.conv.groups)
all_positions = self.OutShape[0] * self.OutShape[1]
flops = (conv_per_position_flops * all_positions / divide) * iC * oC
if self.conv.bias is not None: flops += all_positions / divide
return flops
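# FLOPs are estimated as (kH * kW / groups) * (H_out * W_out) * C_in * C_out, plus one addition
# per output position when the convolution has a bias; `divide` lets callers report the count in
# MFLOPs (divide=1e6).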
def get_range(self):
return [self.choices]
def forward(self, inputs):
if self.search_mode == 'basic':
return self.basic_forward(inputs)
elif self.search_mode == 'search':
return self.search_forward(inputs)
else:
raise ValueError('invalid search_mode = {:}'.format(self.search_mode))
def search_forward(self, tuple_inputs):
assert isinstance(tuple_inputs, tuple) and len(tuple_inputs) == 5, 'invalid type input : {:}'.format( type(tuple_inputs) )
inputs, expected_inC, probability, index, prob = tuple_inputs
index, prob = torch.squeeze(index).tolist(), torch.squeeze(prob)
probability = torch.squeeze(probability)
assert len(index) == 2, 'invalid length : {:}'.format(index)
# compute expected flop
#coordinates = torch.arange(self.x_range[0], self.x_range[1]+1).type_as(probability)
expected_outC = (self.choices_tensor * probability).sum()
expected_flop = self.get_flops([expected_inC, expected_outC], False, 1e6)
if self.avg : out = self.avg( inputs )
else : out = inputs
# convolutional layer
out_convs = conv_forward(out, self.conv, [self.choices[i] for i in index])
out_bns = [self.BNs[idx](out_conv) for idx, out_conv in zip(index, out_convs)]
# merge
out_channel = max([x.size(1) for x in out_bns])
outA = ChannelWiseInter(out_bns[0], out_channel)
outB = ChannelWiseInter(out_bns[1], out_channel)
out = outA * prob[0] + outB * prob[1]
#out = additive_func(out_bns[0]*prob[0], out_bns[1]*prob[1])
if self.relu : out = self.relu( out )
if self.maxpool: out = self.maxpool(out)
return out, expected_outC, expected_flop
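# search_forward evaluates the two sampled width candidates, aligns them to a common channel
# count with ChannelWiseInter, and blends them with their sampled selection probabilities; the
# expected output width and expected FLOP count are passed back so the parent network can
# accumulate a differentiable FLOP estimate.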
def basic_forward(self, inputs):
if self.avg : out = self.avg( inputs )
else : out = inputs
conv = self.conv( out )
if self.has_bn:out= self.BNs[-1]( conv )
else : out = conv
if self.relu: out = self.relu( out )
else : out = out
if self.InShape is None:
self.InShape = (inputs.size(-2), inputs.size(-1))
self.OutShape = (out.size(-2) , out.size(-1))
if self.maxpool: out = self.maxpool(out)
return out
class ResNetBasicblock(nn.Module):
expansion = 1
num_conv = 2
def __init__(self, inplanes, planes, stride):
super(ResNetBasicblock, self).__init__()
assert stride == 1 or stride == 2, 'invalid stride {:}'.format(stride)
self.conv_a = ConvBNReLU(inplanes, planes, 3, stride, 1, False, has_avg=False, has_bn=True, has_relu=True)
self.conv_b = ConvBNReLU( planes, planes, 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=False)
if stride == 2:
self.downsample = ConvBNReLU(inplanes, planes, 1, 1, 0, False, has_avg=True, has_bn=True, has_relu=False)
elif inplanes != planes:
self.downsample = ConvBNReLU(inplanes, planes, 1, 1, 0, False, has_avg=False,has_bn=True, has_relu=False)
else:
self.downsample = None
self.out_dim = planes
self.search_mode = 'basic'
def get_range(self):
return self.conv_a.get_range() + self.conv_b.get_range()
def get_flops(self, channels):
assert len(channels) == 3, 'invalid channels : {:}'.format(channels)
flop_A = self.conv_a.get_flops([channels[0], channels[1]])
flop_B = self.conv_b.get_flops([channels[1], channels[2]])
if hasattr(self.downsample, 'get_flops'):
flop_C = self.downsample.get_flops([channels[0], channels[-1]])
else:
flop_C = 0
if channels[0] != channels[-1] and self.downsample is None: # this short-cut will be added during the infer-train
flop_C = channels[0] * channels[-1] * self.conv_b.OutShape[0] * self.conv_b.OutShape[1]
return flop_A + flop_B + flop_C
def forward(self, inputs):
if self.search_mode == 'basic' : return self.basic_forward(inputs)
elif self.search_mode == 'search': return self.search_forward(inputs)
else: raise ValueError('invalid search_mode = {:}'.format(self.search_mode))
def search_forward(self, tuple_inputs):
assert isinstance(tuple_inputs, tuple) and len(tuple_inputs) == 5, 'invalid type input : {:}'.format( type(tuple_inputs) )
inputs, expected_inC, probability, indexes, probs = tuple_inputs
assert indexes.size(0) == 2 and probs.size(0) == 2 and probability.size(0) == 2
#import pdb; pdb.set_trace()
out_a, expected_inC_a, expected_flop_a = self.conv_a( (inputs, expected_inC , probability[0], indexes[0], probs[0]) )
out_b, expected_inC_b, expected_flop_b = self.conv_b( (out_a , expected_inC_a, probability[1], indexes[1], probs[1]) )
if self.downsample is not None:
residual, _, expected_flop_c = self.downsample( (inputs, expected_inC , probability[1], indexes[1], probs[1]) )
else:
residual, expected_flop_c = inputs, 0
out = additive_func(residual, out_b)
return nn.functional.relu(out, inplace=True), expected_inC_b, sum([expected_flop_a, expected_flop_b, expected_flop_c])
def basic_forward(self, inputs):
basicblock = self.conv_a(inputs)
basicblock = self.conv_b(basicblock)
if self.downsample is not None: residual = self.downsample(inputs)
else : residual = inputs
out = additive_func(residual, basicblock)
return nn.functional.relu(out, inplace=True)
class ResNetBottleneck(nn.Module):
expansion = 4
num_conv = 3
def __init__(self, inplanes, planes, stride):
super(ResNetBottleneck, self).__init__()
assert stride == 1 or stride == 2, 'invalid stride {:}'.format(stride)
self.conv_1x1 = ConvBNReLU(inplanes, planes, 1, 1, 0, False, has_avg=False, has_bn=True, has_relu=True)
self.conv_3x3 = ConvBNReLU( planes, planes, 3, stride, 1, False, has_avg=False, has_bn=True, has_relu=True)
self.conv_1x4 = ConvBNReLU(planes, planes*self.expansion, 1, 1, 0, False, has_avg=False, has_bn=True, has_relu=False)
if stride == 2:
self.downsample = ConvBNReLU(inplanes, planes*self.expansion, 1, 1, 0, False, has_avg=True, has_bn=True, has_relu=False)
elif inplanes != planes*self.expansion:
self.downsample = ConvBNReLU(inplanes, planes*self.expansion, 1, 1, 0, False, has_avg=False,has_bn=True, has_relu=False)
else:
self.downsample = None
self.out_dim = planes * self.expansion
self.search_mode = 'basic'
def get_range(self):
return self.conv_1x1.get_range() + self.conv_3x3.get_range() + self.conv_1x4.get_range()
def get_flops(self, channels):
assert len(channels) == 4, 'invalid channels : {:}'.format(channels)
flop_A = self.conv_1x1.get_flops([channels[0], channels[1]])
flop_B = self.conv_3x3.get_flops([channels[1], channels[2]])
flop_C = self.conv_1x4.get_flops([channels[2], channels[3]])
if hasattr(self.downsample, 'get_flops'):
flop_D = self.downsample.get_flops([channels[0], channels[-1]])
else:
flop_D = 0
if channels[0] != channels[-1] and self.downsample is None: # this short-cut will be added during the infer-train
flop_D = channels[0] * channels[-1] * self.conv_1x4.OutShape[0] * self.conv_1x4.OutShape[1]
return flop_A + flop_B + flop_C + flop_D
def forward(self, inputs):
if self.search_mode == 'basic' : return self.basic_forward(inputs)
elif self.search_mode == 'search': return self.search_forward(inputs)
else: raise ValueError('invalid search_mode = {:}'.format(self.search_mode))
def basic_forward(self, inputs):
bottleneck = self.conv_1x1(inputs)
bottleneck = self.conv_3x3(bottleneck)
bottleneck = self.conv_1x4(bottleneck)
if self.downsample is not None: residual = self.downsample(inputs)
else : residual = inputs
out = additive_func(residual, bottleneck)
return nn.functional.relu(out, inplace=True)
def search_forward(self, tuple_inputs):
assert isinstance(tuple_inputs, tuple) and len(tuple_inputs) == 5, 'invalid type input : {:}'.format( type(tuple_inputs) )
inputs, expected_inC, probability, indexes, probs = tuple_inputs
assert indexes.size(0) == 3 and probs.size(0) == 3 and probability.size(0) == 3
out_1x1, expected_inC_1x1, expected_flop_1x1 = self.conv_1x1( (inputs, expected_inC , probability[0], indexes[0], probs[0]) )
out_3x3, expected_inC_3x3, expected_flop_3x3 = self.conv_3x3( (out_1x1,expected_inC_1x1, probability[1], indexes[1], probs[1]) )
out_1x4, expected_inC_1x4, expected_flop_1x4 = self.conv_1x4( (out_3x3,expected_inC_3x3, probability[2], indexes[2], probs[2]) )
if self.downsample is not None:
residual, _, expected_flop_c = self.downsample( (inputs, expected_inC , probability[2], indexes[2], probs[2]) )
else:
residual, expected_flop_c = inputs, 0
out = additive_func(residual, out_1x4)
return nn.functional.relu(out, inplace=True), expected_inC_1x4, sum([expected_flop_1x1, expected_flop_3x3, expected_flop_1x4, expected_flop_c])
class SearchShapeImagenetResNet(nn.Module):
def __init__(self, block_name, layers, deep_stem, num_classes):
super(SearchShapeImagenetResNet, self).__init__()
# block_name selects the residual block type; layers gives the number of blocks per stage of the ImageNet model
if block_name == 'BasicBlock':
block = ResNetBasicblock
elif block_name == 'Bottleneck':
block = ResNetBottleneck
else:
raise ValueError('invalid block : {:}'.format(block_name))
self.message = 'SearchShapeImagenetResNet : Depth : {:} , Layers for each block : {:}'.format(sum(layers)*block.num_conv, layers)
self.num_classes = num_classes
if not deep_stem:
self.layers = nn.ModuleList( [ ConvBNReLU(3, 64, 7, 2, 3, False, has_avg=False, has_bn=True, has_relu=True, last_max_pool=True) ] )
self.channels = [64]
else:
self.layers = nn.ModuleList( [ ConvBNReLU(3, 32, 3, 2, 1, False, has_avg=False, has_bn=True, has_relu=True)
,ConvBNReLU(32,64, 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=True, last_max_pool=True) ] )
self.channels = [32, 64]
meta_depth_info = get_depth_choices(layers)
self.InShape = None
self.depth_info = OrderedDict()
self.depth_at_i = OrderedDict()
for stage, layer_blocks in enumerate(layers):
cur_block_choices = meta_depth_info[stage]
assert cur_block_choices[-1] == layer_blocks, 'stage={:}, {:} vs {:}'.format(stage, cur_block_choices, layer_blocks)
block_choices, xstart = [], len(self.layers)
for iL in range(layer_blocks):
iC = self.channels[-1]
planes = 64 * (2**stage)
stride = 2 if stage > 0 and iL == 0 else 1
module = block(iC, planes, stride)
self.channels.append( module.out_dim )
self.layers.append ( module )
self.message += "\nstage={:}, ilayer={:02d}/{:02d}, block={:03d}, iC={:3d}, oC={:3d}, stride={:}".format(stage, iL, layer_blocks, len(self.layers)-1, iC, module.out_dim, stride)
# added for depth
layer_index = len(self.layers) - 1
if iL + 1 in cur_block_choices: block_choices.append( layer_index )
if iL + 1 == layer_blocks:
self.depth_info[layer_index] = {'choices': block_choices,
'stage' : stage,
'xstart' : xstart}
self.depth_info_list = []
for xend, info in self.depth_info.items():
self.depth_info_list.append( (xend, info) )
xstart, xstage = info['xstart'], info['stage']
for ilayer in range(xstart, xend+1):
idx = bisect_right(info['choices'], ilayer-1)
self.depth_at_i[ilayer] = (xstage, idx)
self.avgpool = nn.AdaptiveAvgPool2d((1,1))
self.classifier = nn.Linear(module.out_dim, num_classes)
self.InShape = None
self.tau = -1
self.search_mode = 'basic'
#assert sum(x.num_conv for x in self.layers) + 1 == depth, 'invalid depth check {:} vs {:}'.format(sum(x.num_conv for x in self.layers)+1, depth)
# parameters for width
self.Ranges = []
self.layer2indexRange = []
for i, layer in enumerate(self.layers):
start_index = len(self.Ranges)
self.Ranges += layer.get_range()
self.layer2indexRange.append( (start_index, len(self.Ranges)) )
self.register_parameter('width_attentions', nn.Parameter(torch.Tensor(len(self.Ranges), get_width_choices(None))))
self.register_parameter('depth_attentions', nn.Parameter(torch.Tensor(len(layers), meta_depth_info['num'])))
nn.init.normal_(self.width_attentions, 0, 0.01)
nn.init.normal_(self.depth_attentions, 0, 0.01)
self.apply(initialize_resnet)
def arch_parameters(self, LR=None):
if LR is None:
return [self.width_attentions, self.depth_attentions]
else:
return [
{"params": self.width_attentions, "lr": LR},
{"params": self.depth_attentions, "lr": LR},
]
def base_parameters(self):
return list(self.layers.parameters()) + list(self.avgpool.parameters()) + list(self.classifier.parameters())
def get_flop(self, mode, config_dict, extra_info):
if config_dict is not None: config_dict = config_dict.copy()
# select channels
channels = [3]
for i, weight in enumerate(self.width_attentions):
if mode == 'genotype':
with torch.no_grad():
probe = nn.functional.softmax(weight, dim=0)
C = self.Ranges[i][ torch.argmax(probe).item() ]
else:
raise ValueError('invalid mode : {:}'.format(mode))
channels.append( C )
# select depth
if mode == 'genotype':
with torch.no_grad():
depth_probs = nn.functional.softmax(self.depth_attentions, dim=1)
choices = torch.argmax(depth_probs, dim=1).cpu().tolist()
else:
raise ValueError('invalid mode : {:}'.format(mode))
selected_layers = []
for choice, xvalue in zip(choices, self.depth_info_list):
xtemp = xvalue[1]['choices'][choice] - xvalue[1]['xstart'] + 1
selected_layers.append(xtemp)
flop = 0
for i, layer in enumerate(self.layers):
s, e = self.layer2indexRange[i]
xchl = tuple( channels[s:e+1] )
if i in self.depth_at_i:
xstagei, xatti = self.depth_at_i[i]
if xatti <= choices[xstagei]: # leave this depth
flop+= layer.get_flops(xchl)
else:
flop+= 0 # do not use this layer
else:
flop+= layer.get_flops(xchl)
# the last fc layer
flop += channels[-1] * self.classifier.out_features
if config_dict is None:
return flop / 1e6
else:
config_dict['xchannels'] = channels
config_dict['xblocks'] = selected_layers
config_dict['super_type'] = 'infer-shape'
config_dict['estimated_FLOP'] = flop / 1e6
return flop / 1e6, config_dict
def get_arch_info(self):
string = "for depth and width, there are {:} + {:} attention probabilities.".format(len(self.depth_attentions), len(self.width_attentions))
string+= '\n{:}'.format(self.depth_info)
discrepancy = []
with torch.no_grad():
for i, att in enumerate(self.depth_attentions):
prob = nn.functional.softmax(att, dim=0)
prob = prob.cpu() ; selc = prob.argmax().item() ; prob = prob.tolist()
prob = ['{:.3f}'.format(x) for x in prob]
xstring = '{:03d}/{:03d}-th : {:}'.format(i, len(self.depth_attentions), ' '.join(prob))
logt = ['{:.4f}'.format(x) for x in att.cpu().tolist()]
xstring += ' || {:17s}'.format(' '.join(logt))
prob = sorted( [float(x) for x in prob] )
disc = prob[-1] - prob[-2]
xstring += ' || discrepancy={:.2f} || select={:}/{:}'.format(disc, selc, len(prob))
discrepancy.append( disc )
string += '\n{:}'.format(xstring)
string += '\n-----------------------------------------------'
for i, att in enumerate(self.width_attentions):
prob = nn.functional.softmax(att, dim=0)
prob = prob.cpu() ; selc = prob.argmax().item() ; prob = prob.tolist()
prob = ['{:.3f}'.format(x) for x in prob]
xstring = '{:03d}/{:03d}-th : {:}'.format(i, len(self.width_attentions), ' '.join(prob))
logt = ['{:.3f}'.format(x) for x in att.cpu().tolist()]
xstring += ' || {:52s}'.format(' '.join(logt))
prob = sorted( [float(x) for x in prob] )
disc = prob[-1] - prob[-2]
xstring += ' || dis={:.2f} || select={:}/{:}'.format(disc, selc, len(prob))
discrepancy.append( disc )
string += '\n{:}'.format(xstring)
return string, discrepancy
def set_tau(self, tau_max, tau_min, epoch_ratio):
assert epoch_ratio >= 0 and epoch_ratio <= 1, 'invalid epoch-ratio : {:}'.format(epoch_ratio)
tau = tau_min + (tau_max-tau_min) * (1 + math.cos(math.pi * epoch_ratio)) / 2
self.tau = tau
def get_message(self):
return self.message
def forward(self, inputs):
if self.search_mode == 'basic':
return self.basic_forward(inputs)
elif self.search_mode == 'search':
return self.search_forward(inputs)
else:
raise ValueError('invalid search_mode = {:}'.format(self.search_mode))
def search_forward(self, inputs):
flop_width_probs = nn.functional.softmax(self.width_attentions, dim=1)
flop_depth_probs = nn.functional.softmax(self.depth_attentions, dim=1)
flop_depth_probs = torch.flip( torch.cumsum( torch.flip(flop_depth_probs, [1]), 1 ), [1] )
selected_widths, selected_width_probs = select2withP(self.width_attentions, self.tau)
selected_depth_probs = select2withP(self.depth_attentions, self.tau, True)
with torch.no_grad():
selected_widths = selected_widths.cpu()
x, last_channel_idx, expected_inC, flops = inputs, 0, 3, []
feature_maps = []
for i, layer in enumerate(self.layers):
selected_w_index = selected_widths [last_channel_idx: last_channel_idx+layer.num_conv]
selected_w_probs = selected_width_probs[last_channel_idx: last_channel_idx+layer.num_conv]
layer_prob = flop_width_probs [last_channel_idx: last_channel_idx+layer.num_conv]
x, expected_inC, expected_flop = layer( (x, expected_inC, layer_prob, selected_w_index, selected_w_probs) )
feature_maps.append( x )
last_channel_idx += layer.num_conv
if i in self.depth_info: # aggregate the information
choices = self.depth_info[i]['choices']
xstagei = self.depth_info[i]['stage']
#print ('iL={:}, choices={:}, stage={:}, probs={:}'.format(i, choices, xstagei, selected_depth_probs[xstagei].cpu().tolist()))
#for A, W in zip(choices, selected_depth_probs[xstagei]):
# print('Size = {:}, W = {:}'.format(feature_maps[A].size(), W))
possible_tensors = []
max_C = max( feature_maps[A].size(1) for A in choices )
for tempi, A in enumerate(choices):
xtensor = ChannelWiseInter(feature_maps[A], max_C)
possible_tensors.append( xtensor )
weighted_sum = sum( xtensor * W for xtensor, W in zip(possible_tensors, selected_depth_probs[xstagei]) )
x = weighted_sum
if i in self.depth_at_i:
xstagei, xatti = self.depth_at_i[i]
x_expected_flop = flop_depth_probs[xstagei, xatti] * expected_flop
else:
x_expected_flop = expected_flop
flops.append( x_expected_flop )
flops.append( expected_inC * (self.classifier.out_features*1.0/1e6) )
features = self.avgpool(x)
features = features.view(features.size(0), -1)
logits = linear_forward(features, self.classifier)
return logits, torch.stack( [sum(flops)] )
def basic_forward(self, inputs):
if self.InShape is None: self.InShape = (inputs.size(-2), inputs.size(-1))
x = inputs
for i, layer in enumerate(self.layers):
x = layer( x )
features = self.avgpool(x)
features = features.view(features.size(0), -1)
logits = self.classifier(features)
return features, logits

View File

@ -0,0 +1,316 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
import math, torch
import torch.nn as nn
from ..initialization import initialize_resnet
from ..SharedUtils import additive_func
from .SoftSelect import select2withP, ChannelWiseInter
from .SoftSelect import linear_forward
from .SoftSelect import get_width_choices as get_choices
def conv_forward(inputs, conv, choices):
iC = conv.in_channels
fill_size = list(inputs.size())
fill_size[1] = iC - fill_size[1]
filled = torch.zeros(fill_size, device=inputs.device)
xinputs = torch.cat((inputs, filled), dim=1)
outputs = conv(xinputs)
selecteds = [outputs[:,:oC] for oC in choices]
return selecteds
class ConvBNReLU(nn.Module):
num_conv = 1
def __init__(self, nIn, nOut, kernel, stride, padding, bias, has_avg, has_bn, has_relu):
super(ConvBNReLU, self).__init__()
self.InShape = None
self.OutShape = None
self.choices = get_choices(nOut)
self.register_buffer('choices_tensor', torch.Tensor( self.choices ))
if has_avg : self.avg = nn.AvgPool2d(kernel_size=2, stride=2, padding=0)
else : self.avg = None
self.conv = nn.Conv2d(nIn, nOut, kernel_size=kernel, stride=stride, padding=padding, dilation=1, groups=1, bias=bias)
#if has_bn : self.bn = nn.BatchNorm2d(nOut)
#else : self.bn = None
self.has_bn = has_bn
self.BNs = nn.ModuleList()
for i, _out in enumerate(self.choices):
self.BNs.append(nn.BatchNorm2d(_out))
if has_relu: self.relu = nn.ReLU(inplace=True)
else : self.relu = None
self.in_dim = nIn
self.out_dim = nOut
self.search_mode = 'basic'
def get_flops(self, channels, check_range=True, divide=1):
iC, oC = channels
if check_range: assert iC <= self.conv.in_channels and oC <= self.conv.out_channels, '{:} vs {:} | {:} vs {:}'.format(iC, self.conv.in_channels, oC, self.conv.out_channels)
assert isinstance(self.InShape, tuple) and len(self.InShape) == 2, 'invalid in-shape : {:}'.format(self.InShape)
assert isinstance(self.OutShape, tuple) and len(self.OutShape) == 2, 'invalid out-shape : {:}'.format(self.OutShape)
#conv_per_position_flops = self.conv.kernel_size[0] * self.conv.kernel_size[1] * iC * oC / self.conv.groups
conv_per_position_flops = (self.conv.kernel_size[0] * self.conv.kernel_size[1] * 1.0 / self.conv.groups)
all_positions = self.OutShape[0] * self.OutShape[1]
flops = (conv_per_position_flops * all_positions / divide) * iC * oC
if self.conv.bias is not None: flops += all_positions / divide
return flops
def get_range(self):
return [self.choices]
def forward(self, inputs):
if self.search_mode == 'basic':
return self.basic_forward(inputs)
elif self.search_mode == 'search':
return self.search_forward(inputs)
else:
raise ValueError('invalid search_mode = {:}'.format(self.search_mode))
def search_forward(self, tuple_inputs):
assert isinstance(tuple_inputs, tuple) and len(tuple_inputs) == 5, 'invalid type input : {:}'.format( type(tuple_inputs) )
inputs, expected_inC, probability, index, prob = tuple_inputs
index, prob = torch.squeeze(index).tolist(), torch.squeeze(prob)
probability = torch.squeeze(probability)
assert len(index) == 2, 'invalid length : {:}'.format(index)
# compute expected flop
#coordinates = torch.arange(self.x_range[0], self.x_range[1]+1).type_as(probability)
expected_outC = (self.choices_tensor * probability).sum()
expected_flop = self.get_flops([expected_inC, expected_outC], False, 1e6)
if self.avg : out = self.avg( inputs )
else : out = inputs
# convolutional layer
out_convs = conv_forward(out, self.conv, [self.choices[i] for i in index])
out_bns = [self.BNs[idx](out_conv) for idx, out_conv in zip(index, out_convs)]
# merge
out_channel = max([x.size(1) for x in out_bns])
outA = ChannelWiseInter(out_bns[0], out_channel)
outB = ChannelWiseInter(out_bns[1], out_channel)
out = outA * prob[0] + outB * prob[1]
#out = additive_func(out_bns[0]*prob[0], out_bns[1]*prob[1])
if self.relu: out = self.relu( out )
else : out = out
return out, expected_outC, expected_flop
def basic_forward(self, inputs):
if self.avg : out = self.avg( inputs )
else : out = inputs
conv = self.conv( out )
if self.has_bn:out= self.BNs[-1]( conv )
else : out = conv
if self.relu: out = self.relu( out )
else : out = out
if self.InShape is None:
self.InShape = (inputs.size(-2), inputs.size(-1))
self.OutShape = (out.size(-2) , out.size(-1))
return out
class SimBlock(nn.Module):
expansion = 1
num_conv = 1
def __init__(self, inplanes, planes, stride):
super(SimBlock, self).__init__()
assert stride == 1 or stride == 2, 'invalid stride {:}'.format(stride)
self.conv = ConvBNReLU(inplanes, planes, 3, stride, 1, False, has_avg=False, has_bn=True, has_relu=True)
if stride == 2:
self.downsample = ConvBNReLU(inplanes, planes, 1, 1, 0, False, has_avg=True, has_bn=False, has_relu=False)
elif inplanes != planes:
self.downsample = ConvBNReLU(inplanes, planes, 1, 1, 0, False, has_avg=False,has_bn=True , has_relu=False)
else:
self.downsample = None
self.out_dim = planes
self.search_mode = 'basic'
def get_range(self):
return self.conv.get_range()
def get_flops(self, channels):
assert len(channels) == 2, 'invalid channels : {:}'.format(channels)
flop_A = self.conv.get_flops([channels[0], channels[1]])
if hasattr(self.downsample, 'get_flops'):
flop_C = self.downsample.get_flops([channels[0], channels[-1]])
else:
flop_C = 0
if channels[0] != channels[-1] and self.downsample is None: # this short-cut will be added during the infer-train
flop_C = channels[0] * channels[-1] * self.conv.OutShape[0] * self.conv.OutShape[1]
return flop_A + flop_C
def forward(self, inputs):
if self.search_mode == 'basic' : return self.basic_forward(inputs)
elif self.search_mode == 'search': return self.search_forward(inputs)
else: raise ValueError('invalid search_mode = {:}'.format(self.search_mode))
def search_forward(self, tuple_inputs):
assert isinstance(tuple_inputs, tuple) and len(tuple_inputs) == 5, 'invalid type input : {:}'.format( type(tuple_inputs) )
inputs, expected_inC, probability, indexes, probs = tuple_inputs
assert indexes.size(0) == 1 and probs.size(0) == 1 and probability.size(0) == 1, 'invalid size : {:}, {:}, {:}'.format(indexes.size(), probs.size(), probability.size())
out, expected_next_inC, expected_flop = self.conv( (inputs, expected_inC , probability[0], indexes[0], probs[0]) )
if self.downsample is not None:
residual, _, expected_flop_c = self.downsample( (inputs, expected_inC , probability[-1], indexes[-1], probs[-1]) )
else:
residual, expected_flop_c = inputs, 0
out = additive_func(residual, out)
return nn.functional.relu(out, inplace=True), expected_next_inC, sum([expected_flop, expected_flop_c])
def basic_forward(self, inputs):
basicblock = self.conv(inputs)
if self.downsample is not None: residual = self.downsample(inputs)
else : residual = inputs
out = additive_func(residual, basicblock)
return nn.functional.relu(out, inplace=True)
class SearchWidthSimResNet(nn.Module):
def __init__(self, depth, num_classes):
super(SearchWidthSimResNet, self).__init__()
assert (depth - 2) % 3 == 0, 'depth should be one of 5, 8, 11, 14, ... instead of {:}'.format(depth)
layer_blocks = (depth - 2) // 3
self.message = 'SearchWidthSimResNet : Depth : {:} , Layers for each block : {:}'.format(depth, layer_blocks)
self.num_classes = num_classes
self.channels = [16]
self.layers = nn.ModuleList( [ ConvBNReLU(3, 16, 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=True) ] )
self.InShape = None
for stage in range(3):
for iL in range(layer_blocks):
iC = self.channels[-1]
planes = 16 * (2**stage)
stride = 2 if stage > 0 and iL == 0 else 1
module = SimBlock(iC, planes, stride)
self.channels.append( module.out_dim )
self.layers.append ( module )
self.message += "\nstage={:}, ilayer={:02d}/{:02d}, block={:03d}, iC={:3d}, oC={:3d}, stride={:}".format(stage, iL, layer_blocks, len(self.layers)-1, iC, module.out_dim, stride)
self.avgpool = nn.AvgPool2d(8)
self.classifier = nn.Linear(module.out_dim, num_classes)
self.InShape = None
self.tau = -1
self.search_mode = 'basic'
#assert sum(x.num_conv for x in self.layers) + 1 == depth, 'invalid depth check {:} vs {:}'.format(sum(x.num_conv for x in self.layers)+1, depth)
# parameters for width
self.Ranges = []
self.layer2indexRange = []
for i, layer in enumerate(self.layers):
start_index = len(self.Ranges)
self.Ranges += layer.get_range()
self.layer2indexRange.append( (start_index, len(self.Ranges)) )
assert len(self.Ranges) + 1 == depth, 'invalid depth check {:} vs {:}'.format(len(self.Ranges) + 1, depth)
self.register_parameter('width_attentions', nn.Parameter(torch.Tensor(len(self.Ranges), get_choices(None))))
nn.init.normal_(self.width_attentions, 0, 0.01)
self.apply(initialize_resnet)
def arch_parameters(self):
return [self.width_attentions]
def base_parameters(self):
return list(self.layers.parameters()) + list(self.avgpool.parameters()) + list(self.classifier.parameters())
def get_flop(self, mode, config_dict, extra_info):
if config_dict is not None: config_dict = config_dict.copy()
#weights = [F.softmax(x, dim=0) for x in self.width_attentions]
channels = [3]
for i, weight in enumerate(self.width_attentions):
if mode == 'genotype':
with torch.no_grad():
probe = nn.functional.softmax(weight, dim=0)
C = self.Ranges[i][ torch.argmax(probe).item() ]
elif mode == 'max':
C = self.Ranges[i][-1]
elif mode == 'fix':
C = int( math.sqrt( extra_info ) * self.Ranges[i][-1] )
elif mode == 'random':
assert isinstance(extra_info, float), 'invalid extra_info : {:}'.format(extra_info)
with torch.no_grad():
prob = nn.functional.softmax(weight, dim=0)
approximate_C = int( math.sqrt( extra_info ) * self.Ranges[i][-1] )
for j in range(prob.size(0)):
prob[j] = 1 / (abs(j - (approximate_C-self.Ranges[i][j])) + 0.2)
C = self.Ranges[i][ torch.multinomial(prob, 1, False).item() ]
else:
raise ValueError('invalid mode : {:}'.format(mode))
channels.append( C )
flop = 0
for i, layer in enumerate(self.layers):
s, e = self.layer2indexRange[i]
xchl = tuple( channels[s:e+1] )
flop+= layer.get_flops(xchl)
# the last fc layer
flop += channels[-1] * self.classifier.out_features
if config_dict is None:
return flop / 1e6
else:
config_dict['xchannels'] = channels
config_dict['super_type'] = 'infer-width'
config_dict['estimated_FLOP'] = flop / 1e6
return flop / 1e6, config_dict
def get_arch_info(self):
string = "for width, there are {:} attention probabilities.".format(len(self.width_attentions))
discrepancy = []
with torch.no_grad():
for i, att in enumerate(self.width_attentions):
prob = nn.functional.softmax(att, dim=0)
prob = prob.cpu() ; selc = prob.argmax().item() ; prob = prob.tolist()
prob = ['{:.3f}'.format(x) for x in prob]
xstring = '{:03d}/{:03d}-th : {:}'.format(i, len(self.width_attentions), ' '.join(prob))
logt = ['{:.3f}'.format(x) for x in att.cpu().tolist()]
xstring += ' || {:52s}'.format(' '.join(logt))
prob = sorted( [float(x) for x in prob] )
disc = prob[-1] - prob[-2]
xstring += ' || dis={:.2f} || select={:}/{:}'.format(disc, selc, len(prob))
discrepancy.append( disc )
string += '\n{:}'.format(xstring)
return string, discrepancy
def set_tau(self, tau_max, tau_min, epoch_ratio):
assert epoch_ratio >= 0 and epoch_ratio <= 1, 'invalid epoch-ratio : {:}'.format(epoch_ratio)
tau = tau_min + (tau_max-tau_min) * (1 + math.cos(math.pi * epoch_ratio)) / 2
self.tau = tau
def get_message(self):
return self.message
def forward(self, inputs):
if self.search_mode == 'basic':
return self.basic_forward(inputs)
elif self.search_mode == 'search':
return self.search_forward(inputs)
else:
raise ValueError('invalid search_mode = {:}'.format(self.search_mode))
def search_forward(self, inputs):
flop_probs = nn.functional.softmax(self.width_attentions, dim=1)
selected_widths, selected_probs = select2withP(self.width_attentions, self.tau)
with torch.no_grad():
selected_widths = selected_widths.cpu()
x, last_channel_idx, expected_inC, flops = inputs, 0, 3, []
for i, layer in enumerate(self.layers):
selected_w_index = selected_widths[last_channel_idx: last_channel_idx+layer.num_conv]
selected_w_probs = selected_probs[last_channel_idx: last_channel_idx+layer.num_conv]
layer_prob = flop_probs[last_channel_idx: last_channel_idx+layer.num_conv]
x, expected_inC, expected_flop = layer( (x, expected_inC, layer_prob, selected_w_index, selected_w_probs) )
last_channel_idx += layer.num_conv
flops.append( expected_flop )
flops.append( expected_inC * (self.classifier.out_features*1.0/1e6) )
features = self.avgpool(x)
features = features.view(features.size(0), -1)
logits = linear_forward(features, self.classifier)
return logits, torch.stack( [sum(flops)] )
def basic_forward(self, inputs):
if self.InShape is None: self.InShape = (inputs.size(-2), inputs.size(-1))
x = inputs
for i, layer in enumerate(self.layers):
x = layer( x )
features = self.avgpool(x)
features = features.view(features.size(0), -1)
logits = self.classifier(features)
return features, logits

View File

@ -0,0 +1,111 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
import math, torch
import torch.nn as nn
def select2withP(logits, tau, just_prob=False, num=2, eps=1e-7):
if tau <= 0:
new_logits = logits
probs = nn.functional.softmax(new_logits, dim=1)
else :
while True: # a trick to avoid the gumbels bug
gumbels = -torch.empty_like(logits).exponential_().log()
new_logits = (logits.log_softmax(dim=1) + gumbels) / tau
probs = nn.functional.softmax(new_logits, dim=1)
if (not torch.isinf(gumbels).any()) and (not torch.isinf(probs).any()) and (not torch.isnan(probs).any()): break
if just_prob: return probs
#with torch.no_grad(): # add eps for unexpected torch error
# probs = nn.functional.softmax(new_logits, dim=1)
# selected_index = torch.multinomial(probs + eps, 2, False)
with torch.no_grad(): # add eps for unexpected torch error
probs = probs.cpu()
selected_index = torch.multinomial(probs + eps, num, False).to(logits.device)
selected_logit = torch.gather(new_logits, 1, selected_index)
selected_probs = nn.functional.softmax(selected_logit, dim=1)
return selected_index, selected_probs
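# select2withP performs (Gumbel-)softmax sampling over the architecture logits: with tau > 0 the
# log-probabilities are perturbed by Gumbel noise at temperature tau (retrying whenever the noise
# yields inf/nan), with tau <= 0 a plain softmax is used. Unless just_prob is set, `num` distinct
# indices are then drawn per row and their logits re-normalised, so callers receive both the
# sampled choices and the relative weight of each.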
def ChannelWiseInter(inputs, oC, mode='v2'):
if mode == 'v1':
return ChannelWiseInterV1(inputs, oC)
elif mode == 'v2':
return ChannelWiseInterV2(inputs, oC)
else:
raise ValueError('invalid mode : {:}'.format(mode))
def ChannelWiseInterV1(inputs, oC):
assert inputs.dim() == 4, 'invalid dimension : {:}'.format(inputs.size())
def start_index(a, b, c):
return int( math.floor(float(a * c) / b) )
def end_index(a, b, c):
return int( math.ceil(float((a + 1) * c) / b) )
batch, iC, H, W = inputs.size()
outputs = torch.zeros((batch, oC, H, W), dtype=inputs.dtype, device=inputs.device)
if iC == oC: return inputs
for ot in range(oC):
istartT, iendT = start_index(ot, oC, iC), end_index(ot, oC, iC)
values = inputs[:, istartT:iendT].mean(dim=1)
outputs[:, ot, :, :] = values
return outputs
def ChannelWiseInterV2(inputs, oC):
assert inputs.dim() == 4, 'invalid dimension : {:}'.format(inputs.size())
batch, C, H, W = inputs.size()
if C == oC: return inputs
else : return nn.functional.adaptive_avg_pool3d(inputs, (oC,H,W))
#inputs_5D = inputs.view(batch, 1, C, H, W)
#otputs_5D = nn.functional.interpolate(inputs_5D, (oC,H,W), None, 'area', None)
#otputs = otputs_5D.view(batch, oC, H, W)
#otputs_5D = nn.functional.interpolate(inputs_5D, (oC,H,W), None, 'trilinear', False)
#return otputs
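# ChannelWiseInter resizes a feature map along the channel dimension only: V1 explicitly averages
# the input channels that map onto each output channel, while V2 achieves a very similar result
# with adaptive_avg_pool3d over (C, H, W). This is what allows feature maps produced under
# different sampled widths to be blended together.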
def linear_forward(inputs, linear):
if linear is None: return inputs
iC = inputs.size(1)
weight = linear.weight[:, :iC]
if linear.bias is None: bias = None
else : bias = linear.bias
return nn.functional.linear(inputs, weight, bias)
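# linear_forward lets one classifier serve features of any sampled width by slicing its weight
# matrix to the incoming feature dimension before applying the linear map.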
def get_width_choices(nOut):
xsrange = [0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]
if nOut is None:
return len(xsrange)
else:
Xs = [int(nOut * i) for i in xsrange]
#xs = [ int(nOut * i // 10) for i in range(2, 11)]
#Xs = [x for i, x in enumerate(xs) if i+1 == len(xs) or xs[i+1] > x+1]
Xs = sorted( list( set(Xs) ) )
return tuple(Xs)
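# get_width_choices(None) returns the number of width options (8); otherwise it returns the
# de-duplicated widths at 30%-100% of nOut, e.g. get_width_choices(64) -> (19, 25, 32, 38, 44, 51, 57, 64).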
def get_depth_choices(nDepth):
if nDepth is None:
return 3
else:
assert nDepth >= 3, 'nDepth should be greater than 2 vs {:}'.format(nDepth)
if nDepth == 1 : return (1, 1, 1)
elif nDepth == 2: return (1, 1, 2)
elif nDepth >= 3:
return (nDepth//3, nDepth*2//3, nDepth)
else:
raise ValueError('invalid Depth : {:}'.format(nDepth))
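# get_depth_choices(None) returns the number of depth options (3); otherwise it returns candidate
# depths at roughly one third, two thirds and the full depth, e.g. get_depth_choices(5) -> (1, 3, 5).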
def drop_path(x, drop_prob):
if drop_prob > 0.:
keep_prob = 1. - drop_prob
mask = x.new_zeros(x.size(0), 1, 1, 1)
mask = mask.bernoulli_(keep_prob)
x = x * (mask / keep_prob)
#x.div_(keep_prob)
#x.mul_(mask)
return x
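# drop_path implements stochastic depth: each sample in the batch is zeroed with probability
# drop_prob and survivors are rescaled by 1/keep_prob, so the expected activation is unchanged.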

View File

@ -0,0 +1,8 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
from .SearchCifarResNet_width import SearchWidthCifarResNet
from .SearchCifarResNet_depth import SearchDepthCifarResNet
from .SearchCifarResNet import SearchShapeCifarResNet
from .SearchSimResNet_width import SearchWidthSimResNet
from .SearchImagenetResNet import SearchShapeImagenetResNet

View File

@ -0,0 +1,20 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
import torch
import torch.nn as nn
from SoftSelect import ChannelWiseInter
if __name__ == '__main__':
tensors = torch.rand((16, 128, 7, 7))
for oc in range(200, 210):
out_v1 = ChannelWiseInter(tensors, oc, 'v1')
out_v2 = ChannelWiseInter(tensors, oc, 'v2')
assert (out_v1 == out_v2).any().item() == 1
for oc in range(48, 160):
out_v1 = ChannelWiseInter(tensors, oc, 'v1')
out_v2 = ChannelWiseInter(tensors, oc, 'v2')
assert (out_v1 == out_v2).any().item() == 1

4
reproduce.sh Normal file
View File

@ -0,0 +1,4 @@
python search.py --dataset cifar10
python search.py --dataset cifar10 --trainval
python search.py --dataset cifar100
python search.py --dataset ImageNet16-120

156
search.py Normal file
View File

@ -0,0 +1,156 @@
import os
import time
import argparse
import random
import numpy as np
from tqdm import trange
from statistics import mean
parser = argparse.ArgumentParser(description='NAS Without Training')
parser.add_argument('--data_loc', default='../datasets/cifar', type=str, help='dataset folder')
parser.add_argument('--api_loc', default='NAS-Bench-201-v1_1-096897.pth',
type=str, help='path to API')
parser.add_argument('--save_loc', default='results', type=str, help='folder to save results')
parser.add_argument('--batch_size', default=256, type=int)
parser.add_argument('--GPU', default='0', type=str)
parser.add_argument('--seed', default=1, type=int)
parser.add_argument('--trainval', action='store_true')
parser.add_argument('--dataset', default='cifar10', type=str)
parser.add_argument('--n_samples', default=100, type=int)
parser.add_argument('--n_runs', default=500, type=int)
args = parser.parse_args()
os.environ['CUDA_VISIBLE_DEVICES'] = args.GPU
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torchvision.datasets as datasets
import torch.optim as optim
from models import get_cell_based_tiny_net
# Reproducibility
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
import torchvision.transforms as transforms
from datasets import get_datasets
from nas_201_api import NASBench201API as API
from config_utils import load_config  # required for the --trainval CIFAR-10 split below
def get_batch_jacobian(net, x, target, to, device, args=None):
net.zero_grad()
x.requires_grad_(True)
_, y = net(x)
y.backward(torch.ones_like(y))
jacob = x.grad.detach()
return jacob, target.detach()
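# get_batch_jacobian backpropagates a vector of ones through the network output, so x.grad holds
# d(output)/d(input) for every image in the batch; with the single-output networks constructed
# below (num_classes is set to 1) this is the per-example Jacobian consumed by the scoring function.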
def evidenceapprox_eval_score(jacob, labels=None):
corrs = np.corrcoef(jacob)
v, _ = np.linalg.eig(corrs)
k = 1e-5
return -np.sum(np.log(v + k) + 1./(v + k))
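# The score treats the flattened per-image Jacobians as observations, forms their correlation
# matrix, and sums -(log(v + k) + 1/(v + k)) over its eigenvalues v, with k a small jitter for
# numerical stability. Architectures whose input gradients are less correlated across images tend
# to score higher, and the search below keeps the highest-scoring candidate via np.nanargmax.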
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)
THE_START = time.time()
api = API(args.api_loc)
os.makedirs(args.save_loc, exist_ok=True)
train_data, valid_data, xshape, class_num = get_datasets(args.dataset, args.data_loc, cutout=0)
if args.dataset == 'cifar10':
acc_type = 'ori-test'
val_acc_type = 'x-valid'
else:
acc_type = 'x-test'
val_acc_type = 'x-valid'
if args.trainval:
cifar_split = load_config('config_utils/cifar-split.txt', None, None)
train_split, valid_split = cifar_split.train, cifar_split.valid
train_loader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size,
num_workers=0, pin_memory=True, sampler= torch.utils.data.sampler.SubsetRandomSampler(train_split))
else:
train_loader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size, shuffle=True,
num_workers=0, pin_memory=True)
times = []
chosen = []
acc = []
val_acc = []
topscores = []
dset = args.dataset if not args.trainval else 'cifar10-valid'
order_fn = np.nanargmax
runs = trange(args.n_runs, desc='acc: ')
for N in runs:
start = time.time()
indices = np.random.randint(0,15625,args.n_samples)
scores = []
for arch in indices:
data_iterator = iter(train_loader)
x, target = next(data_iterator)
x, target = x.to(device), target.to(device)
config = api.get_net_config(arch, args.dataset)
config['num_classes'] = 1
network = get_cell_based_tiny_net(config) # create the network from configuration
network = network.to(device)
network.eval()
jacobs, labels= get_batch_jacobian(network, x, target, 1, device, args)
jacobs = jacobs.reshape(jacobs.size(0), -1).cpu().numpy()
try:
s = evidenceapprox_eval_score(jacobs, labels)
except Exception as e:
print(e)
s = np.nan
scores.append(s)
best_arch = indices[order_fn(scores)]
info = api.query_by_index(best_arch)
topscores.append(scores[order_fn(scores)])
chosen.append(best_arch)
acc.append(info.get_metrics(dset, acc_type)['accuracy'])
if not args.dataset == 'cifar10' or args.trainval:
val_acc.append(info.get_metrics(dset, val_acc_type)['accuracy'])
times.append(time.time()-start)
runs.set_description(f"acc: {mean(acc if not args.trainval else val_acc):.2f}%")
print(f"Final mean test accuracy: {np.mean(acc)}")
if len(val_acc) > 1:
print(f"Final mean validation accuracy: {np.mean(val_acc)}")
state = {'accs': acc,
'val_accs': val_acc,
'chosen': chosen,
'times': times,
'topscores': topscores,
}
dset = args.dataset if not args.trainval else 'cifar10-valid'
fname = f"{args.save_loc}/{dset}_{args.n_runs}_{args.n_samples}_{args.mc_samples}_{args.alpha}_{args.seed}.t7"
torch.save(state, fname)