NAS-sharing-parameters support 3 datasets / update ops / update pypi
This commit is contained in:
parent
96152a9904
commit
c66afa4df8
@ -1,6 +1,6 @@
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2019 Xuanyi Dong (GitHub: https://github.com/D-X-Y)
|
||||
Copyright (c) 2018-2020 Xuanyi Dong (GitHub: https://github.com/D-X-Y)
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
|
@ -12,6 +12,10 @@ In this Markdown file, we provide:
|
||||
|
||||
Note: please use `PyTorch >= 1.2.0` and `Python >= 3.6.0`.
|
||||
|
||||
Simply type `pip install nas-bench-102` to install our api.
|
||||
|
||||
If you have any questions or issues, please post it at [here](https://github.com/D-X-Y/NAS-Projects/issues) or email me.
|
||||
|
||||
### Preparation and Download
|
||||
|
||||
The benchmark file of NAS-Bench-102 can be downloaded from [Google Drive](https://drive.google.com/open?id=1SKW0Cu0u8-gb18zDpaAGi0f74UdXeGKs) or [Baidu-Wangpan (code:6u5d)](https://pan.baidu.com/s/1CiaNH6C12zuZf7q-Ilm09w).
|
||||
@ -179,3 +183,17 @@ If researchers can provide better results with different hyper-parameters, we ar
|
||||
- [8] `bash ./scripts-search/algos/Random.sh -1`
|
||||
- [9] `bash ./scripts-search/algos/REINFORCE.sh -1`
|
||||
- [10] `bash ./scripts-search/algos/BOHB.sh -1`
|
||||
|
||||
|
||||
|
||||
# Citation
|
||||
|
||||
If you find that NAS-Bench-102 helps your research, please consider citing it:
|
||||
```
|
||||
@inproceedings{dong2020nasbench102,
|
||||
title = {NAS-Bench-102: Extending the Scope of Reproducible Neural Architecture Search},
|
||||
author = {Dong, Xuanyi and Yang, Yi},
|
||||
booktitle = {International Conference on Learning Representations (ICLR)},
|
||||
year = {2020}
|
||||
}
|
||||
```
|
||||
|
@ -34,6 +34,8 @@ We build a new benchmark for neural architecture search, please see more details
|
||||
|
||||
The benchmark data file (v1.0) is `NAS-Bench-102-v1_0-e61699.pth`, which can be downloaded from [Google Drive](https://drive.google.com/open?id=1SKW0Cu0u8-gb18zDpaAGi0f74UdXeGKs).
|
||||
|
||||
Now you can simply use our API by `pip install nas-bench-102`.
|
||||
|
||||
## [Network Pruning via Transformable Architecture Search](https://arxiv.org/abs/1905.09717)
|
||||
[](https://paperswithcode.com/sota/network-pruning-on-cifar-100?p=network-pruning-via-transformable)
|
||||
|
||||
|
28
exps/NAS-Bench-102/dist-setup.py
Normal file
28
exps/NAS-Bench-102/dist-setup.py
Normal file
@ -0,0 +1,28 @@
|
||||
import os
|
||||
from setuptools import setup
|
||||
|
||||
|
||||
def read(fname='README.md'):
|
||||
with open(os.path.join(os.path.dirname(__file__), fname), encoding='utf-8') as cfile:
|
||||
return cfile.read()
|
||||
|
||||
|
||||
setup(
|
||||
name = "nas_bench_102",
|
||||
version = "1.0",
|
||||
author = "Xuanyi Dong",
|
||||
author_email = "dongxuanyi888@gmail.com",
|
||||
description = "API for NAS-Bench-102 (a benchmark for neural architecture search).",
|
||||
license = "MIT",
|
||||
keywords = "NAS Dataset API DeepLearning",
|
||||
url = "https://github.com/D-X-Y/NAS-Projects",
|
||||
packages=['nas_102_api'],
|
||||
long_description=read('README.md'),
|
||||
long_description_content_type='text/markdown',
|
||||
classifiers=[
|
||||
"Programming Language :: Python",
|
||||
"Topic :: Database",
|
||||
"Topic :: Scientific/Engineering :: Artificial Intelligence",
|
||||
"License :: OSI Approved :: MIT License",
|
||||
],
|
||||
)
|
@ -1,8 +1,8 @@
|
||||
##################################################
|
||||
# NAS-Bench-102 ##################################
|
||||
##################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
|
||||
##################################################
|
||||
###############################################################
|
||||
# NAS-Bench-102, ICLR 2020 (https://arxiv.org/abs/2001.00326) #
|
||||
###############################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019-2020 #
|
||||
###############################################################
|
||||
import os, sys, time, torch, random, argparse
|
||||
from PIL import ImageFile
|
||||
ImageFile.LOAD_TRUNCATED_IMAGES = True
|
||||
|
@ -12,7 +12,7 @@ from pathlib import Path
|
||||
lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve()
|
||||
if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
|
||||
from config_utils import load_config, dict2config, configure2str
|
||||
from datasets import get_datasets, SearchDataset
|
||||
from datasets import get_datasets, get_nas_search_loaders
|
||||
from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint, get_optim_scheduler
|
||||
from utils import get_model_infos, obtain_accuracy
|
||||
from log_utils import AverageMeter, time_string, convert_secs2time
|
||||
@ -107,35 +107,7 @@ def main(xargs):
|
||||
train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1)
|
||||
#config_path = 'configs/nas-benchmark/algos/DARTS.config'
|
||||
config = load_config(xargs.config_path, {'class_num': class_num, 'xshape': xshape}, logger)
|
||||
if xargs.dataset == 'cifar10':
|
||||
split_Fpath = 'configs/nas-benchmark/cifar-split.txt'
|
||||
cifar_split = load_config(split_Fpath, None, None)
|
||||
train_split, valid_split = cifar_split.train, cifar_split.valid # search over the proposed training and validation set
|
||||
logger.log('Load split file from {:}'.format(split_Fpath)) # they are two disjoint groups in the original CIFAR-10 training set
|
||||
# To split data
|
||||
train_data_v2 = deepcopy(train_data)
|
||||
train_data_v2.transform = valid_data.transform
|
||||
valid_data = train_data_v2
|
||||
search_data = SearchDataset(xargs.dataset, train_data, train_split, valid_split)
|
||||
# data loader
|
||||
search_loader = torch.utils.data.DataLoader(search_data, batch_size=config.batch_size, shuffle=True , num_workers=xargs.workers, pin_memory=True)
|
||||
valid_loader = torch.utils.data.DataLoader(valid_data , batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(valid_split), num_workers=xargs.workers, pin_memory=True)
|
||||
elif xargs.dataset == 'cifar100':
|
||||
cifar100_test_split = load_config('configs/nas-benchmark/cifar100-test-split.txt', None, None)
|
||||
search_train_data = train_data
|
||||
search_valid_data = deepcopy(valid_data) ; search_valid_data.transform = train_data.transform
|
||||
search_data = SearchDataset(xargs.dataset, [search_train_data,search_valid_data], list(range(len(search_train_data))), cifar100_test_split.xvalid)
|
||||
search_loader = torch.utils.data.DataLoader(search_data, batch_size=config.batch_size, shuffle=True , num_workers=xargs.workers, pin_memory=True)
|
||||
valid_loader = torch.utils.data.DataLoader(valid_data , batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(cifar100_test_split.xvalid), num_workers=xargs.workers, pin_memory=True)
|
||||
elif xargs.dataset == 'ImageNet16-120':
|
||||
imagenet_test_split = load_config('configs/nas-benchmark/imagenet-16-120-test-split.txt', None, None)
|
||||
search_train_data = train_data
|
||||
search_valid_data = deepcopy(valid_data) ; search_valid_data.transform = train_data.transform
|
||||
search_data = SearchDataset(xargs.dataset, [search_train_data,search_valid_data], list(range(len(search_train_data))), imagenet_test_split.xvalid)
|
||||
search_loader = torch.utils.data.DataLoader(search_data, batch_size=config.batch_size, shuffle=True , num_workers=xargs.workers, pin_memory=True)
|
||||
valid_loader = torch.utils.data.DataLoader(valid_data , batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(imagenet_test_split.xvalid), num_workers=xargs.workers, pin_memory=True)
|
||||
else:
|
||||
raise ValueError('invalid dataset : {:}'.format(xargs.dataset))
|
||||
search_loader, _, valid_loader = get_nas_search_loaders(train_data, valid_data, xargs.dataset, 'configs/nas-benchmark/', config.batch_size, xargs.workers)
|
||||
logger.log('||||||| {:10s} ||||||| Search-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'.format(xargs.dataset, len(search_loader), len(valid_loader), config.batch_size))
|
||||
logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config))
|
||||
|
||||
|
@ -12,7 +12,7 @@ from pathlib import Path
|
||||
lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve()
|
||||
if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
|
||||
from config_utils import load_config, dict2config, configure2str
|
||||
from datasets import get_datasets, SearchDataset
|
||||
from datasets import get_datasets, get_nas_search_loaders
|
||||
from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint, get_optim_scheduler
|
||||
from utils import get_model_infos, obtain_accuracy
|
||||
from log_utils import AverageMeter, time_string, convert_secs2time
|
||||
@ -169,28 +169,8 @@ def main(xargs):
|
||||
logger = prepare_logger(args)
|
||||
|
||||
train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1)
|
||||
if xargs.dataset == 'cifar10' or xargs.dataset == 'cifar100':
|
||||
split_Fpath = 'configs/nas-benchmark/cifar-split.txt'
|
||||
cifar_split = load_config(split_Fpath, None, None)
|
||||
train_split, valid_split = cifar_split.train, cifar_split.valid
|
||||
logger.log('Load split file from {:}'.format(split_Fpath))
|
||||
elif xargs.dataset.startswith('ImageNet16'):
|
||||
split_Fpath = 'configs/nas-benchmark/{:}-split.txt'.format(xargs.dataset)
|
||||
imagenet16_split = load_config(split_Fpath, None, None)
|
||||
train_split, valid_split = imagenet16_split.train, imagenet16_split.valid
|
||||
logger.log('Load split file from {:}'.format(split_Fpath))
|
||||
else:
|
||||
raise ValueError('invalid dataset : {:}'.format(xargs.dataset))
|
||||
#config_path = 'configs/nas-benchmark/algos/DARTS.config'
|
||||
config = load_config(xargs.config_path, {'class_num': class_num, 'xshape': xshape}, logger)
|
||||
# To split data
|
||||
train_data_v2 = deepcopy(train_data)
|
||||
train_data_v2.transform = valid_data.transform
|
||||
valid_data = train_data_v2
|
||||
search_data = SearchDataset(xargs.dataset, train_data, train_split, valid_split)
|
||||
# data loader
|
||||
search_loader = torch.utils.data.DataLoader(search_data, batch_size=config.batch_size, shuffle=True , num_workers=xargs.workers, pin_memory=True)
|
||||
valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(valid_split), num_workers=xargs.workers, pin_memory=True)
|
||||
search_loader, _, valid_loader = get_nas_search_loaders(train_data, valid_data, xargs.dataset, 'configs/nas-benchmark/', config.batch_size, xargs.workers)
|
||||
logger.log('||||||| {:10s} ||||||| Search-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'.format(xargs.dataset, len(search_loader), len(valid_loader), config.batch_size))
|
||||
logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config))
|
||||
|
||||
|
@ -10,7 +10,7 @@ from pathlib import Path
|
||||
lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve()
|
||||
if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
|
||||
from config_utils import load_config, dict2config, configure2str
|
||||
from datasets import get_datasets, SearchDataset
|
||||
from datasets import get_datasets, get_nas_search_loaders
|
||||
from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint, get_optim_scheduler
|
||||
from utils import get_model_infos, obtain_accuracy
|
||||
from log_utils import AverageMeter, time_string, convert_secs2time
|
||||
@ -184,29 +184,14 @@ def main(xargs):
|
||||
logger = prepare_logger(args)
|
||||
|
||||
train_data, test_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1)
|
||||
assert xargs.dataset == 'cifar10', 'currently only support CIFAR-10'
|
||||
if xargs.dataset == 'cifar10' or xargs.dataset == 'cifar100':
|
||||
split_Fpath = 'configs/nas-benchmark/cifar-split.txt'
|
||||
cifar_split = load_config(split_Fpath, None, None)
|
||||
train_split, valid_split = cifar_split.train, cifar_split.valid
|
||||
logger.log('Load split file from {:}'.format(split_Fpath))
|
||||
elif xargs.dataset.startswith('ImageNet16'):
|
||||
split_Fpath = 'configs/nas-benchmark/{:}-split.txt'.format(xargs.dataset)
|
||||
imagenet16_split = load_config(split_Fpath, None, None)
|
||||
train_split, valid_split = imagenet16_split.train, imagenet16_split.valid
|
||||
logger.log('Load split file from {:}'.format(split_Fpath))
|
||||
else:
|
||||
raise ValueError('invalid dataset : {:}'.format(xargs.dataset))
|
||||
logger.log('use config from : {:}'.format(xargs.config_path))
|
||||
config = load_config(xargs.config_path, {'class_num': class_num, 'xshape': xshape}, logger)
|
||||
logger.log('config: {:}'.format(config))
|
||||
# To split data
|
||||
train_data_v2 = deepcopy(train_data)
|
||||
train_data_v2.transform = test_data.transform
|
||||
valid_data = train_data_v2
|
||||
_, train_loader, valid_loader = get_nas_search_loaders(train_data, test_data, xargs.dataset, 'configs/nas-benchmark/', config.batch_size, xargs.workers)
|
||||
# since ENAS will train the controller on valid-loader, we need to use train transformation for valid-loader
|
||||
valid_loader.dataset.transform = deepcopy(train_loader.dataset.transform)
|
||||
if hasattr(valid_loader.dataset, 'transforms'):
|
||||
valid_loader.dataset.transforms = deepcopy(train_loader.dataset.transforms)
|
||||
# data loader
|
||||
train_loader = torch.utils.data.DataLoader(train_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(train_split), num_workers=xargs.workers, pin_memory=True)
|
||||
valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(valid_split), num_workers=xargs.workers, pin_memory=True)
|
||||
logger.log('||||||| {:10s} ||||||| Train-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'.format(xargs.dataset, len(train_loader), len(valid_loader), config.batch_size))
|
||||
logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config))
|
||||
|
||||
|
@ -12,7 +12,7 @@ from pathlib import Path
|
||||
lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve()
|
||||
if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
|
||||
from config_utils import load_config, dict2config
|
||||
from datasets import get_datasets, SearchDataset
|
||||
from datasets import get_datasets, get_nas_search_loaders
|
||||
from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint, get_optim_scheduler
|
||||
from utils import get_model_infos, obtain_accuracy
|
||||
from log_utils import AverageMeter, time_string, convert_secs2time
|
||||
@ -80,25 +80,10 @@ def main(xargs):
|
||||
prepare_seed(xargs.rand_seed)
|
||||
logger = prepare_logger(args)
|
||||
|
||||
train_data, _, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1)
|
||||
assert xargs.dataset == 'cifar10', 'currently only support CIFAR-10'
|
||||
if xargs.dataset == 'cifar10' or xargs.dataset == 'cifar100':
|
||||
split_Fpath = 'configs/nas-benchmark/cifar-split.txt'
|
||||
cifar_split = load_config(split_Fpath, None, None)
|
||||
train_split, valid_split = cifar_split.train, cifar_split.valid
|
||||
logger.log('Load split file from {:}'.format(split_Fpath))
|
||||
elif xargs.dataset.startswith('ImageNet16'):
|
||||
split_Fpath = 'configs/nas-benchmark/{:}-split.txt'.format(xargs.dataset)
|
||||
imagenet16_split = load_config(split_Fpath, None, None)
|
||||
train_split, valid_split = imagenet16_split.train, imagenet16_split.valid
|
||||
logger.log('Load split file from {:}'.format(split_Fpath))
|
||||
else:
|
||||
raise ValueError('invalid dataset : {:}'.format(xargs.dataset))
|
||||
train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1)
|
||||
#config_path = 'configs/nas-benchmark/algos/GDAS.config'
|
||||
config = load_config(xargs.config_path, {'class_num': class_num, 'xshape': xshape}, logger)
|
||||
search_data = SearchDataset(xargs.dataset, train_data, train_split, valid_split)
|
||||
# data loader
|
||||
search_loader = torch.utils.data.DataLoader(search_data, batch_size=config.batch_size, shuffle=True , num_workers=xargs.workers, pin_memory=True)
|
||||
search_loader, _, valid_loader = get_nas_search_loaders(train_data, valid_data, xargs.dataset, 'configs/nas-benchmark/', config.batch_size, xargs.workers)
|
||||
logger.log('||||||| {:10s} ||||||| Search-Loader-Num={:}, batch size={:}'.format(xargs.dataset, len(search_loader), config.batch_size))
|
||||
logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config))
|
||||
|
||||
@ -143,7 +128,7 @@ def main(xargs):
|
||||
logger.log("=> loading checkpoint of the last-info '{:}' start with {:}-th epoch.".format(last_info, start_epoch))
|
||||
else:
|
||||
logger.log("=> do not find the last-info file : {:}".format(last_info))
|
||||
start_epoch, valid_accuracies, genotypes = 0, {'best': -1}, {}
|
||||
start_epoch, valid_accuracies, genotypes = 0, {'best': -1}, {-1: search_model.genotype()}
|
||||
|
||||
# start training
|
||||
start_time, search_time, epoch_time, total_epoch = time.time(), AverageMeter(), AverageMeter(), config.epochs + config.warmup
|
||||
|
@ -10,7 +10,7 @@ from pathlib import Path
|
||||
lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve()
|
||||
if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
|
||||
from config_utils import load_config, dict2config, configure2str
|
||||
from datasets import get_datasets, SearchDataset
|
||||
from datasets import get_datasets, get_nas_search_loaders
|
||||
from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint, get_optim_scheduler
|
||||
from utils import get_model_infos, obtain_accuracy
|
||||
from log_utils import AverageMeter, time_string, convert_secs2time
|
||||
@ -117,32 +117,9 @@ def main(xargs):
|
||||
logger = prepare_logger(args)
|
||||
|
||||
train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1)
|
||||
if xargs.dataset == 'cifar10' or xargs.dataset == 'cifar100':
|
||||
split_Fpath = 'configs/nas-benchmark/cifar-split.txt'
|
||||
cifar_split = load_config(split_Fpath, None, None)
|
||||
train_split, valid_split = cifar_split.train, cifar_split.valid
|
||||
logger.log('Load split file from {:}'.format(split_Fpath))
|
||||
#elif xargs.dataset.startswith('ImageNet16'):
|
||||
# # all_indexes = list(range(len(train_data))) ; random.seed(111) ; random.shuffle(all_indexes)
|
||||
# # train_split, valid_split = sorted(all_indexes[: len(train_data)//2]), sorted(all_indexes[len(train_data)//2 :])
|
||||
# # imagenet16_split = dict2config({'train': train_split, 'valid': valid_split}, None)
|
||||
# # _ = configure2str(imagenet16_split, 'temp.txt')
|
||||
# split_Fpath = 'configs/nas-benchmark/{:}-split.txt'.format(xargs.dataset)
|
||||
# imagenet16_split = load_config(split_Fpath, None, None)
|
||||
# train_split, valid_split = imagenet16_split.train, imagenet16_split.valid
|
||||
# logger.log('Load split file from {:}'.format(split_Fpath))
|
||||
else:
|
||||
raise ValueError('invalid dataset : {:}'.format(xargs.dataset))
|
||||
config = load_config(xargs.config_path, {'class_num': class_num, 'xshape': xshape}, logger)
|
||||
logger.log('config : {:}'.format(config))
|
||||
# To split data
|
||||
train_data_v2 = deepcopy(train_data)
|
||||
train_data_v2.transform = valid_data.transform
|
||||
valid_data = train_data_v2
|
||||
search_data = SearchDataset(xargs.dataset, train_data, train_split, valid_split)
|
||||
# data loader
|
||||
search_loader = torch.utils.data.DataLoader(search_data, batch_size=config.batch_size, shuffle=True , num_workers=xargs.workers, pin_memory=True)
|
||||
valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=config.test_batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(valid_split), num_workers=xargs.workers, pin_memory=True)
|
||||
search_loader, _, valid_loader = get_nas_search_loaders(train_data, valid_data, xargs.dataset, 'configs/nas-benchmark/', \
|
||||
(config.batch_size, config.test_batch_size), xargs.workers)
|
||||
logger.log('||||||| {:10s} ||||||| Search-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'.format(xargs.dataset, len(search_loader), len(valid_loader), config.batch_size))
|
||||
logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config))
|
||||
|
||||
|
@ -12,7 +12,7 @@ from pathlib import Path
|
||||
lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve()
|
||||
if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
|
||||
from config_utils import load_config, dict2config, configure2str
|
||||
from datasets import get_datasets, SearchDataset
|
||||
from datasets import get_datasets, get_nas_search_loaders
|
||||
from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint, get_optim_scheduler
|
||||
from utils import get_model_infos, obtain_accuracy
|
||||
from log_utils import AverageMeter, time_string, convert_secs2time
|
||||
@ -135,29 +135,9 @@ def main(xargs):
|
||||
logger = prepare_logger(args)
|
||||
|
||||
train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1)
|
||||
assert xargs.dataset == 'cifar10', 'currently only support CIFAR-10'
|
||||
if xargs.dataset == 'cifar10' or xargs.dataset == 'cifar100':
|
||||
split_Fpath = 'configs/nas-benchmark/cifar-split.txt'
|
||||
cifar_split = load_config(split_Fpath, None, None)
|
||||
train_split, valid_split = cifar_split.train, cifar_split.valid
|
||||
logger.log('Load split file from {:}'.format(split_Fpath))
|
||||
elif xargs.dataset.startswith('ImageNet16'):
|
||||
split_Fpath = 'configs/nas-benchmark/{:}-split.txt'.format(xargs.dataset)
|
||||
imagenet16_split = load_config(split_Fpath, None, None)
|
||||
train_split, valid_split = imagenet16_split.train, imagenet16_split.valid
|
||||
logger.log('Load split file from {:}'.format(split_Fpath))
|
||||
else:
|
||||
raise ValueError('invalid dataset : {:}'.format(xargs.dataset))
|
||||
#config_path = 'configs/nas-benchmark/algos/SETN.config'
|
||||
config = load_config(xargs.config_path, {'class_num': class_num, 'xshape': xshape}, logger)
|
||||
# To split data
|
||||
train_data_v2 = deepcopy(train_data)
|
||||
train_data_v2.transform = valid_data.transform
|
||||
valid_data = train_data_v2
|
||||
search_data = SearchDataset(xargs.dataset, train_data, train_split, valid_split)
|
||||
# data loader
|
||||
search_loader = torch.utils.data.DataLoader(search_data, batch_size=config.batch_size, shuffle=True , num_workers=xargs.workers, pin_memory=True)
|
||||
valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=config.test_batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(valid_split), num_workers=xargs.workers, pin_memory=True)
|
||||
search_loader, _, valid_loader = get_nas_search_loaders(train_data, valid_data, xargs.dataset, 'configs/nas-benchmark/', \
|
||||
(config.batch_size, config.test_batch_size), xargs.workers)
|
||||
logger.log('||||||| {:10s} ||||||| Search-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'.format(xargs.dataset, len(search_loader), len(valid_loader), config.batch_size))
|
||||
logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config))
|
||||
|
||||
@ -202,7 +182,8 @@ def main(xargs):
|
||||
logger.log("=> loading checkpoint of the last-info '{:}' start with {:}-th epoch.".format(last_info, start_epoch))
|
||||
else:
|
||||
logger.log("=> do not find the last-info file : {:}".format(last_info))
|
||||
start_epoch, valid_accuracies, genotypes = 0, {'best': -1}, {}
|
||||
init_genotype, _ = get_best_arch(valid_loader, network, xargs.select_num)
|
||||
start_epoch, valid_accuracies, genotypes = 0, {'best': -1}, {-1: init_genotype}
|
||||
|
||||
# start training
|
||||
start_time, search_time, epoch_time, total_epoch = time.time(), AverageMeter(), AverageMeter(), config.epochs + config.warmup
|
||||
|
@ -1,5 +1,5 @@
|
||||
##################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
|
||||
##################################################
|
||||
from .get_dataset_with_transform import get_datasets
|
||||
from .get_dataset_with_transform import get_datasets, get_nas_search_loaders
|
||||
from .SearchDatasetWrap import SearchDataset
|
||||
|
@ -6,8 +6,12 @@ import os.path as osp
|
||||
import numpy as np
|
||||
import torchvision.datasets as dset
|
||||
import torchvision.transforms as transforms
|
||||
from copy import deepcopy
|
||||
from PIL import Image
|
||||
|
||||
from .DownsampledImageNet import ImageNet16
|
||||
from .SearchDatasetWrap import SearchDataset
|
||||
from config_utils import load_config
|
||||
|
||||
|
||||
Dataset2Class = {'cifar10' : 10,
|
||||
@ -177,6 +181,47 @@ def get_datasets(name, root, cutout):
|
||||
class_num = Dataset2Class[name]
|
||||
return train_data, test_data, xshape, class_num
|
||||
|
||||
|
||||
def get_nas_search_loaders(train_data, valid_data, dataset, config_root, batch_size, workers):
|
||||
if isinstance(batch_size, (list,tuple)):
|
||||
batch, test_batch = batch_size
|
||||
else:
|
||||
batch, test_batch = batch_size, batch_size
|
||||
if dataset == 'cifar10':
|
||||
#split_Fpath = 'configs/nas-benchmark/cifar-split.txt'
|
||||
cifar_split = load_config('{:}/cifar-split.txt'.format(config_root), None, None)
|
||||
train_split, valid_split = cifar_split.train, cifar_split.valid # search over the proposed training and validation set
|
||||
#logger.log('Load split file from {:}'.format(split_Fpath)) # they are two disjoint groups in the original CIFAR-10 training set
|
||||
# To split data
|
||||
xvalid_data = deepcopy(train_data)
|
||||
if hasattr(xvalid_data, 'transforms'): # to avoid a print issue
|
||||
xvalid_data.transforms = valid_data.transform
|
||||
xvalid_data.transform = deepcopy( valid_data.transform )
|
||||
search_data = SearchDataset(dataset, train_data, train_split, valid_split)
|
||||
# data loader
|
||||
search_loader = torch.utils.data.DataLoader(search_data, batch_size=batch, shuffle=True , num_workers=workers, pin_memory=True)
|
||||
train_loader = torch.utils.data.DataLoader(train_data , batch_size=batch, sampler=torch.utils.data.sampler.SubsetRandomSampler(train_split), num_workers=workers, pin_memory=True)
|
||||
valid_loader = torch.utils.data.DataLoader(xvalid_data, batch_size=test_batch, sampler=torch.utils.data.sampler.SubsetRandomSampler(valid_split), num_workers=workers, pin_memory=True)
|
||||
elif dataset == 'cifar100':
|
||||
cifar100_test_split = load_config('{:}/cifar100-test-split.txt'.format(config_root), None, None)
|
||||
search_train_data = train_data
|
||||
search_valid_data = deepcopy(valid_data) ; search_valid_data.transform = train_data.transform
|
||||
search_data = SearchDataset(dataset, [search_train_data,search_valid_data], list(range(len(search_train_data))), cifar100_test_split.xvalid)
|
||||
search_loader = torch.utils.data.DataLoader(search_data, batch_size=batch, shuffle=True , num_workers=workers, pin_memory=True)
|
||||
train_loader = torch.utils.data.DataLoader(train_data , batch_size=batch, shuffle=True , num_workers=workers, pin_memory=True)
|
||||
valid_loader = torch.utils.data.DataLoader(valid_data , batch_size=test_batch, sampler=torch.utils.data.sampler.SubsetRandomSampler(cifar100_test_split.xvalid), num_workers=workers, pin_memory=True)
|
||||
elif dataset == 'ImageNet16-120':
|
||||
imagenet_test_split = load_config('{:}/imagenet-16-120-test-split.txt'.format(config_root), None, None)
|
||||
search_train_data = train_data
|
||||
search_valid_data = deepcopy(valid_data) ; search_valid_data.transform = train_data.transform
|
||||
search_data = SearchDataset(dataset, [search_train_data,search_valid_data], list(range(len(search_train_data))), imagenet_test_split.xvalid)
|
||||
search_loader = torch.utils.data.DataLoader(search_data, batch_size=batch, shuffle=True , num_workers=workers, pin_memory=True)
|
||||
train_loader = torch.utils.data.DataLoader(train_data , batch_size=batch, shuffle=True , num_workers=workers, pin_memory=True)
|
||||
valid_loader = torch.utils.data.DataLoader(valid_data , batch_size=test_batch, sampler=torch.utils.data.sampler.SubsetRandomSampler(imagenet_test_split.xvalid), num_workers=workers, pin_memory=True)
|
||||
else:
|
||||
raise ValueError('invalid dataset : {:}'.format(dataset))
|
||||
return search_loader, train_loader, valid_loader
|
||||
|
||||
#if __name__ == '__main__':
|
||||
# train_data, test_data, xshape, class_num = dataset = get_datasets('cifar10', '/data02/dongxuanyi/.torch/cifar.python/', -1)
|
||||
# import pdb; pdb.set_trace()
|
||||
|
@ -13,16 +13,22 @@ OPS = {
|
||||
'nor_conv_7x7' : lambda C_in, C_out, stride, affine, track_running_stats: ReLUConvBN(C_in, C_out, (7,7), (stride,stride), (3,3), (1,1), affine, track_running_stats),
|
||||
'nor_conv_3x3' : lambda C_in, C_out, stride, affine, track_running_stats: ReLUConvBN(C_in, C_out, (3,3), (stride,stride), (1,1), (1,1), affine, track_running_stats),
|
||||
'nor_conv_1x1' : lambda C_in, C_out, stride, affine, track_running_stats: ReLUConvBN(C_in, C_out, (1,1), (stride,stride), (0,0), (1,1), affine, track_running_stats),
|
||||
'dua_sepc_3x3' : lambda C_in, C_out, stride, affine, track_running_stats: DualSepConv(C_in, C_out, (3,3), (stride,stride), (1,1), (1,1), affine, track_running_stats),
|
||||
'dua_sepc_5x5' : lambda C_in, C_out, stride, affine, track_running_stats: DualSepConv(C_in, C_out, (5,5), (stride,stride), (2,2), (1,1), affine, track_running_stats),
|
||||
'dil_sepc_3x3' : lambda C_in, C_out, stride, affine, track_running_stats: SepConv(C_in, C_out, (3,3), (stride,stride), (2,2), (2,2), affine, track_running_stats),
|
||||
'dil_sepc_5x5' : lambda C_in, C_out, stride, affine, track_running_stats: SepConv(C_in, C_out, (5,5), (stride,stride), (4,4), (2,2), affine, track_running_stats),
|
||||
'skip_connect' : lambda C_in, C_out, stride, affine, track_running_stats: Identity() if stride == 1 and C_in == C_out else FactorizedReduce(C_in, C_out, stride, affine, track_running_stats),
|
||||
}
|
||||
|
||||
CONNECT_NAS_BENCHMARK = ['none', 'skip_connect', 'nor_conv_3x3']
|
||||
NAS_BENCH_102 = ['none', 'skip_connect', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3']
|
||||
DARTS_SPACE = ['none', 'skip_connect', 'dua_sepc_3x3', 'dua_sepc_5x5', 'dil_sepc_3x3', 'dil_sepc_5x5', 'avg_pool_3x3', 'max_pool_3x3']
|
||||
|
||||
SearchSpaceNames = {'connect-nas' : CONNECT_NAS_BENCHMARK,
|
||||
'aa-nas' : NAS_BENCH_102,
|
||||
'nas-bench-102': NAS_BENCH_102,
|
||||
'full' : sorted(list(OPS.keys()))}
|
||||
'darts' : DARTS_SPACE}
|
||||
#'full' : sorted(list(OPS.keys()))}
|
||||
|
||||
|
||||
class ReLUConvBN(nn.Module):
|
||||
@ -39,6 +45,34 @@ class ReLUConvBN(nn.Module):
|
||||
return self.op(x)
|
||||
|
||||
|
||||
class SepConv(nn.Module):
|
||||
|
||||
def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation, affine, track_running_stats=True):
|
||||
super(SepConv, self).__init__()
|
||||
self.op = nn.Sequential(
|
||||
nn.ReLU(inplace=False),
|
||||
nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=C_in, bias=False),
|
||||
nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False),
|
||||
nn.BatchNorm2d(C_out, affine=affine, track_running_stats=track_running_stats),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.op(x)
|
||||
|
||||
|
||||
class DualSepConv(nn.Module):
|
||||
|
||||
def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation, affine, track_running_stats=True):
|
||||
super(DualSepConv, self).__init__()
|
||||
self.op_a = SepConv(C_in, C_in , kernel_size, stride, padding, dilation, affine, track_running_stats)
|
||||
self.op_b = SepConv(C_in, C_out, kernel_size, 1, padding, dilation, affine, track_running_stats)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.op_a(x)
|
||||
x = self.op_b(x)
|
||||
return x
|
||||
|
||||
|
||||
class ResNetBasicblock(nn.Module):
|
||||
|
||||
def __init__(self, inplanes, planes, stride, affine=True):
|
||||
|
@ -3,3 +3,5 @@
|
||||
##################################################
|
||||
from .api import NASBench102API
|
||||
from .api import ArchResults, ResultsCount
|
||||
|
||||
NAS_BENCH_102_API_VERSION="v1.0"
|
||||
|
@ -1,8 +1,12 @@
|
||||
##################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
|
||||
#################################################################################
|
||||
# NAS-Bench-102: Extending the Scope of Reproducible Neural Architecture Search #
|
||||
#################################################################################
|
||||
############################################################################################
|
||||
# NAS-Bench-102: Extending the Scope of Reproducible Neural Architecture Search, ICLR 2020 #
|
||||
############################################################################################
|
||||
# NAS-Bench-102-v1_0-e61699.pth : 6219 architectures are trained once, 1621 architectures are trained twice, 7785 architectures are trained three times. `LESS` only supports CIFAR10-VALID.
|
||||
#
|
||||
#
|
||||
#
|
||||
import os, sys, copy, random, torch, numpy as np
|
||||
from collections import OrderedDict, defaultdict
|
||||
|
||||
|
26
scripts-search/NAS-Bench-102/build.sh
Normal file
26
scripts-search/NAS-Bench-102/build.sh
Normal file
@ -0,0 +1,26 @@
|
||||
#!/bin/bash
|
||||
# bash scripts-search/NAS-Bench-102/build.sh
|
||||
echo script name: $0
|
||||
echo $# arguments
|
||||
if [ "$#" -ne 0 ] ;then
|
||||
echo "Input illegal number of parameters " $#
|
||||
echo "Need 0 parameters"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
save_dir=./output/nas_bench_102_package
|
||||
echo "Prepare to build the package in ${save_dir}"
|
||||
rm -rf ${save_dir}
|
||||
mkdir -p ${save_dir}
|
||||
|
||||
#cp NAS-Bench-102.md ${save_dir}/README.md
|
||||
sed '125,187d' NAS-Bench-102.md > ${save_dir}/README.md
|
||||
cp LICENSE.md ${save_dir}/LICENSE.md
|
||||
cp -r lib/nas_102_api ${save_dir}/
|
||||
rm -rf ${save_dir}/nas_102_api/__pycache__
|
||||
cp exps/NAS-Bench-102/dist-setup.py ${save_dir}/setup.py
|
||||
|
||||
cd ${save_dir}
|
||||
# python setup.py sdist bdist_wheel
|
||||
# twine upload --repository-url https://test.pypi.org/legacy/ dist/*
|
||||
# twine upload dist/*
|
Loading…
Reference in New Issue
Block a user