NAS-sharing-parameters support 3 datasets / update ops / update pypi
This commit is contained in:
		| @@ -1,6 +1,6 @@ | |||||||
| MIT License | MIT License | ||||||
|  |  | ||||||
| Copyright (c) 2019 Xuanyi Dong (GitHub: https://github.com/D-X-Y) | Copyright (c) 2018-2020 Xuanyi Dong (GitHub: https://github.com/D-X-Y) | ||||||
|  |  | ||||||
| Permission is hereby granted, free of charge, to any person obtaining a copy | Permission is hereby granted, free of charge, to any person obtaining a copy | ||||||
| of this software and associated documentation files (the "Software"), to deal | of this software and associated documentation files (the "Software"), to deal | ||||||
|   | |||||||
| @@ -12,6 +12,10 @@ In this Markdown file, we provide: | |||||||
|  |  | ||||||
| Note: please use `PyTorch >= 1.2.0` and `Python >= 3.6.0`. | Note: please use `PyTorch >= 1.2.0` and `Python >= 3.6.0`. | ||||||
|  |  | ||||||
|  | Simply type `pip install nas-bench-102` to install our api. | ||||||
|  |  | ||||||
|  | If you have any questions or issues, please post it at [here](https://github.com/D-X-Y/NAS-Projects/issues) or email me. | ||||||
|  |  | ||||||
| ### Preparation and Download | ### Preparation and Download | ||||||
|  |  | ||||||
| The benchmark file of NAS-Bench-102 can be downloaded from [Google Drive](https://drive.google.com/open?id=1SKW0Cu0u8-gb18zDpaAGi0f74UdXeGKs) or [Baidu-Wangpan (code:6u5d)](https://pan.baidu.com/s/1CiaNH6C12zuZf7q-Ilm09w). | The benchmark file of NAS-Bench-102 can be downloaded from [Google Drive](https://drive.google.com/open?id=1SKW0Cu0u8-gb18zDpaAGi0f74UdXeGKs) or [Baidu-Wangpan (code:6u5d)](https://pan.baidu.com/s/1CiaNH6C12zuZf7q-Ilm09w). | ||||||
| @@ -179,3 +183,17 @@ If researchers can provide better results with different hyper-parameters, we ar | |||||||
| - [8] `bash ./scripts-search/algos/Random.sh -1` | - [8] `bash ./scripts-search/algos/Random.sh -1` | ||||||
| - [9] `bash ./scripts-search/algos/REINFORCE.sh -1` | - [9] `bash ./scripts-search/algos/REINFORCE.sh -1` | ||||||
| - [10] `bash ./scripts-search/algos/BOHB.sh -1` | - [10] `bash ./scripts-search/algos/BOHB.sh -1` | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
|  | # Citation | ||||||
|  |  | ||||||
|  | If you find that NAS-Bench-102 helps your research, please consider citing it: | ||||||
|  | ``` | ||||||
|  | @inproceedings{dong2020nasbench102, | ||||||
|  |   title     = {NAS-Bench-102: Extending the Scope of Reproducible Neural Architecture Search}, | ||||||
|  |   author    = {Dong, Xuanyi and Yang, Yi}, | ||||||
|  |   booktitle = {International Conference on Learning Representations (ICLR)}, | ||||||
|  |   year      = {2020} | ||||||
|  | } | ||||||
|  | ``` | ||||||
|   | |||||||
| @@ -34,6 +34,8 @@ We build a new benchmark for neural architecture search, please see more details | |||||||
|  |  | ||||||
| The benchmark data file (v1.0) is `NAS-Bench-102-v1_0-e61699.pth`, which can be downloaded from [Google Drive](https://drive.google.com/open?id=1SKW0Cu0u8-gb18zDpaAGi0f74UdXeGKs). | The benchmark data file (v1.0) is `NAS-Bench-102-v1_0-e61699.pth`, which can be downloaded from [Google Drive](https://drive.google.com/open?id=1SKW0Cu0u8-gb18zDpaAGi0f74UdXeGKs). | ||||||
|  |  | ||||||
|  | Now you can simply use our API by `pip install nas-bench-102`. | ||||||
|  |  | ||||||
| ## [Network Pruning via Transformable Architecture Search](https://arxiv.org/abs/1905.09717) | ## [Network Pruning via Transformable Architecture Search](https://arxiv.org/abs/1905.09717) | ||||||
| [](https://paperswithcode.com/sota/network-pruning-on-cifar-100?p=network-pruning-via-transformable) | [](https://paperswithcode.com/sota/network-pruning-on-cifar-100?p=network-pruning-via-transformable) | ||||||
|  |  | ||||||
|   | |||||||
							
								
								
									
										28
									
								
								exps/NAS-Bench-102/dist-setup.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										28
									
								
								exps/NAS-Bench-102/dist-setup.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,28 @@ | |||||||
|  | import os | ||||||
|  | from setuptools import setup | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def read(fname='README.md'): | ||||||
|  |   with open(os.path.join(os.path.dirname(__file__), fname), encoding='utf-8') as cfile: | ||||||
|  |     return cfile.read() | ||||||
|  |  | ||||||
|  |  | ||||||
|  | setup( | ||||||
|  |     name = "nas_bench_102", | ||||||
|  |     version = "1.0", | ||||||
|  |     author = "Xuanyi Dong", | ||||||
|  |     author_email = "dongxuanyi888@gmail.com", | ||||||
|  |     description = "API for NAS-Bench-102 (a benchmark for neural architecture search).", | ||||||
|  |     license = "MIT", | ||||||
|  |     keywords = "NAS Dataset API DeepLearning", | ||||||
|  |     url = "https://github.com/D-X-Y/NAS-Projects", | ||||||
|  |     packages=['nas_102_api'], | ||||||
|  |     long_description=read('README.md'), | ||||||
|  |     long_description_content_type='text/markdown', | ||||||
|  |     classifiers=[ | ||||||
|  |         "Programming Language :: Python", | ||||||
|  |         "Topic :: Database", | ||||||
|  |         "Topic :: Scientific/Engineering :: Artificial Intelligence", | ||||||
|  |         "License :: OSI Approved :: MIT License", | ||||||
|  |     ], | ||||||
|  | ) | ||||||
| @@ -1,8 +1,8 @@ | |||||||
| ################################################## | ############################################################### | ||||||
| # NAS-Bench-102 ################################## | # NAS-Bench-102, ICLR 2020 (https://arxiv.org/abs/2001.00326) # | ||||||
| ################################################## | ############################################################### | ||||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 # | # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019-2020         # | ||||||
| ################################################## | ############################################################### | ||||||
| import os, sys, time, torch, random, argparse | import os, sys, time, torch, random, argparse | ||||||
| from PIL     import ImageFile | from PIL     import ImageFile | ||||||
| ImageFile.LOAD_TRUNCATED_IMAGES = True | ImageFile.LOAD_TRUNCATED_IMAGES = True | ||||||
|   | |||||||
| @@ -12,7 +12,7 @@ from pathlib import Path | |||||||
| lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve() | lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve() | ||||||
| if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) | if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) | ||||||
| from config_utils import load_config, dict2config, configure2str | from config_utils import load_config, dict2config, configure2str | ||||||
| from datasets     import get_datasets, SearchDataset | from datasets     import get_datasets, get_nas_search_loaders | ||||||
| from procedures   import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint, get_optim_scheduler | from procedures   import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint, get_optim_scheduler | ||||||
| from utils        import get_model_infos, obtain_accuracy | from utils        import get_model_infos, obtain_accuracy | ||||||
| from log_utils    import AverageMeter, time_string, convert_secs2time | from log_utils    import AverageMeter, time_string, convert_secs2time | ||||||
| @@ -107,35 +107,7 @@ def main(xargs): | |||||||
|   train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1) |   train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1) | ||||||
|   #config_path = 'configs/nas-benchmark/algos/DARTS.config' |   #config_path = 'configs/nas-benchmark/algos/DARTS.config' | ||||||
|   config = load_config(xargs.config_path, {'class_num': class_num, 'xshape': xshape}, logger) |   config = load_config(xargs.config_path, {'class_num': class_num, 'xshape': xshape}, logger) | ||||||
|   if xargs.dataset == 'cifar10': |   search_loader, _, valid_loader = get_nas_search_loaders(train_data, valid_data, xargs.dataset, 'configs/nas-benchmark/', config.batch_size, xargs.workers) | ||||||
|     split_Fpath = 'configs/nas-benchmark/cifar-split.txt' |  | ||||||
|     cifar_split = load_config(split_Fpath, None, None) |  | ||||||
|     train_split, valid_split = cifar_split.train, cifar_split.valid # search over the proposed training and validation set |  | ||||||
|     logger.log('Load split file from {:}'.format(split_Fpath))      # they are two disjoint groups in the original CIFAR-10 training set |  | ||||||
|     # To split data |  | ||||||
|     train_data_v2 = deepcopy(train_data) |  | ||||||
|     train_data_v2.transform = valid_data.transform |  | ||||||
|     valid_data    = train_data_v2 |  | ||||||
|     search_data   = SearchDataset(xargs.dataset, train_data, train_split, valid_split) |  | ||||||
|     # data loader |  | ||||||
|     search_loader = torch.utils.data.DataLoader(search_data, batch_size=config.batch_size, shuffle=True , num_workers=xargs.workers, pin_memory=True) |  | ||||||
|     valid_loader  = torch.utils.data.DataLoader(valid_data , batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(valid_split), num_workers=xargs.workers, pin_memory=True) |  | ||||||
|   elif xargs.dataset == 'cifar100': |  | ||||||
|     cifar100_test_split = load_config('configs/nas-benchmark/cifar100-test-split.txt', None, None) |  | ||||||
|     search_train_data = train_data |  | ||||||
|     search_valid_data = deepcopy(valid_data) ; search_valid_data.transform = train_data.transform |  | ||||||
|     search_data   = SearchDataset(xargs.dataset, [search_train_data,search_valid_data], list(range(len(search_train_data))), cifar100_test_split.xvalid) |  | ||||||
|     search_loader = torch.utils.data.DataLoader(search_data, batch_size=config.batch_size, shuffle=True , num_workers=xargs.workers, pin_memory=True) |  | ||||||
|     valid_loader  = torch.utils.data.DataLoader(valid_data , batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(cifar100_test_split.xvalid), num_workers=xargs.workers, pin_memory=True) |  | ||||||
|   elif xargs.dataset == 'ImageNet16-120': |  | ||||||
|     imagenet_test_split = load_config('configs/nas-benchmark/imagenet-16-120-test-split.txt', None, None) |  | ||||||
|     search_train_data = train_data |  | ||||||
|     search_valid_data = deepcopy(valid_data) ; search_valid_data.transform = train_data.transform |  | ||||||
|     search_data   = SearchDataset(xargs.dataset, [search_train_data,search_valid_data], list(range(len(search_train_data))), imagenet_test_split.xvalid) |  | ||||||
|     search_loader = torch.utils.data.DataLoader(search_data, batch_size=config.batch_size, shuffle=True , num_workers=xargs.workers, pin_memory=True) |  | ||||||
|     valid_loader  = torch.utils.data.DataLoader(valid_data , batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(imagenet_test_split.xvalid), num_workers=xargs.workers, pin_memory=True) |  | ||||||
|   else: |  | ||||||
|     raise ValueError('invalid dataset : {:}'.format(xargs.dataset)) |  | ||||||
|   logger.log('||||||| {:10s} ||||||| Search-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'.format(xargs.dataset, len(search_loader), len(valid_loader), config.batch_size)) |   logger.log('||||||| {:10s} ||||||| Search-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'.format(xargs.dataset, len(search_loader), len(valid_loader), config.batch_size)) | ||||||
|   logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config)) |   logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config)) | ||||||
|  |  | ||||||
|   | |||||||
| @@ -12,7 +12,7 @@ from pathlib import Path | |||||||
| lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve() | lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve() | ||||||
| if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) | if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) | ||||||
| from config_utils import load_config, dict2config, configure2str | from config_utils import load_config, dict2config, configure2str | ||||||
| from datasets     import get_datasets, SearchDataset | from datasets     import get_datasets, get_nas_search_loaders | ||||||
| from procedures   import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint, get_optim_scheduler | from procedures   import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint, get_optim_scheduler | ||||||
| from utils        import get_model_infos, obtain_accuracy | from utils        import get_model_infos, obtain_accuracy | ||||||
| from log_utils    import AverageMeter, time_string, convert_secs2time | from log_utils    import AverageMeter, time_string, convert_secs2time | ||||||
| @@ -169,28 +169,8 @@ def main(xargs): | |||||||
|   logger = prepare_logger(args) |   logger = prepare_logger(args) | ||||||
|  |  | ||||||
|   train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1) |   train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1) | ||||||
|   if xargs.dataset == 'cifar10' or xargs.dataset == 'cifar100': |  | ||||||
|     split_Fpath = 'configs/nas-benchmark/cifar-split.txt' |  | ||||||
|     cifar_split = load_config(split_Fpath, None, None) |  | ||||||
|     train_split, valid_split = cifar_split.train, cifar_split.valid |  | ||||||
|     logger.log('Load split file from {:}'.format(split_Fpath)) |  | ||||||
|   elif xargs.dataset.startswith('ImageNet16'): |  | ||||||
|     split_Fpath = 'configs/nas-benchmark/{:}-split.txt'.format(xargs.dataset) |  | ||||||
|     imagenet16_split = load_config(split_Fpath, None, None) |  | ||||||
|     train_split, valid_split = imagenet16_split.train, imagenet16_split.valid |  | ||||||
|     logger.log('Load split file from {:}'.format(split_Fpath)) |  | ||||||
|   else: |  | ||||||
|     raise ValueError('invalid dataset : {:}'.format(xargs.dataset)) |  | ||||||
|   #config_path = 'configs/nas-benchmark/algos/DARTS.config' |  | ||||||
|   config = load_config(xargs.config_path, {'class_num': class_num, 'xshape': xshape}, logger) |   config = load_config(xargs.config_path, {'class_num': class_num, 'xshape': xshape}, logger) | ||||||
|   # To split data |   search_loader, _, valid_loader = get_nas_search_loaders(train_data, valid_data, xargs.dataset, 'configs/nas-benchmark/', config.batch_size, xargs.workers) | ||||||
|   train_data_v2 = deepcopy(train_data) |  | ||||||
|   train_data_v2.transform = valid_data.transform |  | ||||||
|   valid_data    = train_data_v2 |  | ||||||
|   search_data   = SearchDataset(xargs.dataset, train_data, train_split, valid_split) |  | ||||||
|   # data loader |  | ||||||
|   search_loader = torch.utils.data.DataLoader(search_data, batch_size=config.batch_size, shuffle=True , num_workers=xargs.workers, pin_memory=True) |  | ||||||
|   valid_loader  = torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(valid_split), num_workers=xargs.workers, pin_memory=True) |  | ||||||
|   logger.log('||||||| {:10s} ||||||| Search-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'.format(xargs.dataset, len(search_loader), len(valid_loader), config.batch_size)) |   logger.log('||||||| {:10s} ||||||| Search-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'.format(xargs.dataset, len(search_loader), len(valid_loader), config.batch_size)) | ||||||
|   logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config)) |   logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config)) | ||||||
|  |  | ||||||
|   | |||||||
| @@ -10,7 +10,7 @@ from pathlib import Path | |||||||
| lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve() | lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve() | ||||||
| if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) | if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) | ||||||
| from config_utils import load_config, dict2config, configure2str | from config_utils import load_config, dict2config, configure2str | ||||||
| from datasets     import get_datasets, SearchDataset | from datasets     import get_datasets, get_nas_search_loaders | ||||||
| from procedures   import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint, get_optim_scheduler | from procedures   import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint, get_optim_scheduler | ||||||
| from utils        import get_model_infos, obtain_accuracy | from utils        import get_model_infos, obtain_accuracy | ||||||
| from log_utils    import AverageMeter, time_string, convert_secs2time | from log_utils    import AverageMeter, time_string, convert_secs2time | ||||||
| @@ -184,29 +184,14 @@ def main(xargs): | |||||||
|   logger = prepare_logger(args) |   logger = prepare_logger(args) | ||||||
|  |  | ||||||
|   train_data, test_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1) |   train_data, test_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1) | ||||||
|   assert xargs.dataset == 'cifar10', 'currently only support CIFAR-10' |  | ||||||
|   if xargs.dataset == 'cifar10' or xargs.dataset == 'cifar100': |  | ||||||
|     split_Fpath = 'configs/nas-benchmark/cifar-split.txt' |  | ||||||
|     cifar_split = load_config(split_Fpath, None, None) |  | ||||||
|     train_split, valid_split = cifar_split.train, cifar_split.valid |  | ||||||
|     logger.log('Load split file from {:}'.format(split_Fpath)) |  | ||||||
|   elif xargs.dataset.startswith('ImageNet16'): |  | ||||||
|     split_Fpath = 'configs/nas-benchmark/{:}-split.txt'.format(xargs.dataset) |  | ||||||
|     imagenet16_split = load_config(split_Fpath, None, None) |  | ||||||
|     train_split, valid_split = imagenet16_split.train, imagenet16_split.valid |  | ||||||
|     logger.log('Load split file from {:}'.format(split_Fpath)) |  | ||||||
|   else: |  | ||||||
|     raise ValueError('invalid dataset : {:}'.format(xargs.dataset)) |  | ||||||
|   logger.log('use config from : {:}'.format(xargs.config_path)) |   logger.log('use config from : {:}'.format(xargs.config_path)) | ||||||
|   config = load_config(xargs.config_path, {'class_num': class_num, 'xshape': xshape}, logger) |   config = load_config(xargs.config_path, {'class_num': class_num, 'xshape': xshape}, logger) | ||||||
|   logger.log('config: {:}'.format(config)) |   _, train_loader, valid_loader = get_nas_search_loaders(train_data, test_data, xargs.dataset, 'configs/nas-benchmark/', config.batch_size, xargs.workers) | ||||||
|   # To split data |   # since ENAS will train the controller on valid-loader, we need to use train transformation for valid-loader | ||||||
|   train_data_v2 = deepcopy(train_data) |   valid_loader.dataset.transform = deepcopy(train_loader.dataset.transform) | ||||||
|   train_data_v2.transform = test_data.transform |   if hasattr(valid_loader.dataset, 'transforms'): | ||||||
|   valid_data    = train_data_v2 |     valid_loader.dataset.transforms = deepcopy(train_loader.dataset.transforms) | ||||||
|   # data loader |   # data loader | ||||||
|   train_loader  = torch.utils.data.DataLoader(train_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(train_split), num_workers=xargs.workers, pin_memory=True) |  | ||||||
|   valid_loader  = torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(valid_split), num_workers=xargs.workers, pin_memory=True) |  | ||||||
|   logger.log('||||||| {:10s} ||||||| Train-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'.format(xargs.dataset, len(train_loader), len(valid_loader), config.batch_size)) |   logger.log('||||||| {:10s} ||||||| Train-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'.format(xargs.dataset, len(train_loader), len(valid_loader), config.batch_size)) | ||||||
|   logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config)) |   logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config)) | ||||||
|  |  | ||||||
|   | |||||||
| @@ -12,7 +12,7 @@ from pathlib import Path | |||||||
| lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve() | lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve() | ||||||
| if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) | if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) | ||||||
| from config_utils import load_config, dict2config | from config_utils import load_config, dict2config | ||||||
| from datasets     import get_datasets, SearchDataset | from datasets     import get_datasets, get_nas_search_loaders | ||||||
| from procedures   import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint, get_optim_scheduler | from procedures   import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint, get_optim_scheduler | ||||||
| from utils        import get_model_infos, obtain_accuracy | from utils        import get_model_infos, obtain_accuracy | ||||||
| from log_utils    import AverageMeter, time_string, convert_secs2time | from log_utils    import AverageMeter, time_string, convert_secs2time | ||||||
| @@ -80,25 +80,10 @@ def main(xargs): | |||||||
|   prepare_seed(xargs.rand_seed) |   prepare_seed(xargs.rand_seed) | ||||||
|   logger = prepare_logger(args) |   logger = prepare_logger(args) | ||||||
|  |  | ||||||
|   train_data, _, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1) |   train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1) | ||||||
|   assert xargs.dataset == 'cifar10', 'currently only support CIFAR-10' |  | ||||||
|   if xargs.dataset == 'cifar10' or xargs.dataset == 'cifar100': |  | ||||||
|     split_Fpath = 'configs/nas-benchmark/cifar-split.txt' |  | ||||||
|     cifar_split = load_config(split_Fpath, None, None) |  | ||||||
|     train_split, valid_split = cifar_split.train, cifar_split.valid |  | ||||||
|     logger.log('Load split file from {:}'.format(split_Fpath)) |  | ||||||
|   elif xargs.dataset.startswith('ImageNet16'): |  | ||||||
|     split_Fpath = 'configs/nas-benchmark/{:}-split.txt'.format(xargs.dataset) |  | ||||||
|     imagenet16_split = load_config(split_Fpath, None, None) |  | ||||||
|     train_split, valid_split = imagenet16_split.train, imagenet16_split.valid |  | ||||||
|     logger.log('Load split file from {:}'.format(split_Fpath)) |  | ||||||
|   else: |  | ||||||
|     raise ValueError('invalid dataset : {:}'.format(xargs.dataset)) |  | ||||||
|   #config_path = 'configs/nas-benchmark/algos/GDAS.config' |   #config_path = 'configs/nas-benchmark/algos/GDAS.config' | ||||||
|   config = load_config(xargs.config_path, {'class_num': class_num, 'xshape': xshape}, logger) |   config = load_config(xargs.config_path, {'class_num': class_num, 'xshape': xshape}, logger) | ||||||
|   search_data   = SearchDataset(xargs.dataset, train_data, train_split, valid_split) |   search_loader, _, valid_loader = get_nas_search_loaders(train_data, valid_data, xargs.dataset, 'configs/nas-benchmark/', config.batch_size, xargs.workers) | ||||||
|   # data loader |  | ||||||
|   search_loader = torch.utils.data.DataLoader(search_data, batch_size=config.batch_size, shuffle=True , num_workers=xargs.workers, pin_memory=True) |  | ||||||
|   logger.log('||||||| {:10s} ||||||| Search-Loader-Num={:}, batch size={:}'.format(xargs.dataset, len(search_loader), config.batch_size)) |   logger.log('||||||| {:10s} ||||||| Search-Loader-Num={:}, batch size={:}'.format(xargs.dataset, len(search_loader), config.batch_size)) | ||||||
|   logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config)) |   logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config)) | ||||||
|  |  | ||||||
| @@ -143,7 +128,7 @@ def main(xargs): | |||||||
|     logger.log("=> loading checkpoint of the last-info '{:}' start with {:}-th epoch.".format(last_info, start_epoch)) |     logger.log("=> loading checkpoint of the last-info '{:}' start with {:}-th epoch.".format(last_info, start_epoch)) | ||||||
|   else: |   else: | ||||||
|     logger.log("=> do not find the last-info file : {:}".format(last_info)) |     logger.log("=> do not find the last-info file : {:}".format(last_info)) | ||||||
|     start_epoch, valid_accuracies, genotypes = 0, {'best': -1}, {} |     start_epoch, valid_accuracies, genotypes = 0, {'best': -1}, {-1: search_model.genotype()} | ||||||
|  |  | ||||||
|   # start training |   # start training | ||||||
|   start_time, search_time, epoch_time, total_epoch = time.time(), AverageMeter(), AverageMeter(), config.epochs + config.warmup |   start_time, search_time, epoch_time, total_epoch = time.time(), AverageMeter(), AverageMeter(), config.epochs + config.warmup | ||||||
|   | |||||||
| @@ -10,7 +10,7 @@ from pathlib import Path | |||||||
| lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve() | lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve() | ||||||
| if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) | if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) | ||||||
| from config_utils import load_config, dict2config, configure2str | from config_utils import load_config, dict2config, configure2str | ||||||
| from datasets     import get_datasets, SearchDataset | from datasets     import get_datasets, get_nas_search_loaders | ||||||
| from procedures   import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint, get_optim_scheduler | from procedures   import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint, get_optim_scheduler | ||||||
| from utils        import get_model_infos, obtain_accuracy | from utils        import get_model_infos, obtain_accuracy | ||||||
| from log_utils    import AverageMeter, time_string, convert_secs2time | from log_utils    import AverageMeter, time_string, convert_secs2time | ||||||
| @@ -117,32 +117,9 @@ def main(xargs): | |||||||
|   logger = prepare_logger(args) |   logger = prepare_logger(args) | ||||||
|  |  | ||||||
|   train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1) |   train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1) | ||||||
|   if xargs.dataset == 'cifar10' or xargs.dataset == 'cifar100': |  | ||||||
|     split_Fpath = 'configs/nas-benchmark/cifar-split.txt' |  | ||||||
|     cifar_split = load_config(split_Fpath, None, None) |  | ||||||
|     train_split, valid_split = cifar_split.train, cifar_split.valid |  | ||||||
|     logger.log('Load split file from {:}'.format(split_Fpath)) |  | ||||||
|   #elif xargs.dataset.startswith('ImageNet16'): |  | ||||||
|   #  # all_indexes = list(range(len(train_data))) ; random.seed(111) ; random.shuffle(all_indexes) |  | ||||||
|   #  # train_split, valid_split = sorted(all_indexes[: len(train_data)//2]), sorted(all_indexes[len(train_data)//2 :]) |  | ||||||
|   #  # imagenet16_split = dict2config({'train': train_split, 'valid': valid_split}, None) |  | ||||||
|   #  # _ = configure2str(imagenet16_split, 'temp.txt') |  | ||||||
|   #  split_Fpath = 'configs/nas-benchmark/{:}-split.txt'.format(xargs.dataset) |  | ||||||
|   #  imagenet16_split = load_config(split_Fpath, None, None) |  | ||||||
|   #  train_split, valid_split = imagenet16_split.train, imagenet16_split.valid |  | ||||||
|   #  logger.log('Load split file from {:}'.format(split_Fpath)) |  | ||||||
|   else: |  | ||||||
|     raise ValueError('invalid dataset : {:}'.format(xargs.dataset)) |  | ||||||
|   config = load_config(xargs.config_path, {'class_num': class_num, 'xshape': xshape}, logger) |   config = load_config(xargs.config_path, {'class_num': class_num, 'xshape': xshape}, logger) | ||||||
|   logger.log('config : {:}'.format(config)) |   search_loader, _, valid_loader = get_nas_search_loaders(train_data, valid_data, xargs.dataset, 'configs/nas-benchmark/', \ | ||||||
|   # To split data |                                         (config.batch_size, config.test_batch_size), xargs.workers) | ||||||
|   train_data_v2 = deepcopy(train_data) |  | ||||||
|   train_data_v2.transform = valid_data.transform |  | ||||||
|   valid_data    = train_data_v2 |  | ||||||
|   search_data   = SearchDataset(xargs.dataset, train_data, train_split, valid_split) |  | ||||||
|   # data loader |  | ||||||
|   search_loader = torch.utils.data.DataLoader(search_data, batch_size=config.batch_size, shuffle=True , num_workers=xargs.workers, pin_memory=True) |  | ||||||
|   valid_loader  = torch.utils.data.DataLoader(valid_data, batch_size=config.test_batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(valid_split), num_workers=xargs.workers, pin_memory=True) |  | ||||||
|   logger.log('||||||| {:10s} ||||||| Search-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'.format(xargs.dataset, len(search_loader), len(valid_loader), config.batch_size)) |   logger.log('||||||| {:10s} ||||||| Search-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'.format(xargs.dataset, len(search_loader), len(valid_loader), config.batch_size)) | ||||||
|   logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config)) |   logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config)) | ||||||
|  |  | ||||||
|   | |||||||
| @@ -12,7 +12,7 @@ from pathlib import Path | |||||||
| lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve() | lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve() | ||||||
| if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) | if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) | ||||||
| from config_utils import load_config, dict2config, configure2str | from config_utils import load_config, dict2config, configure2str | ||||||
| from datasets     import get_datasets, SearchDataset | from datasets     import get_datasets, get_nas_search_loaders | ||||||
| from procedures   import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint, get_optim_scheduler | from procedures   import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint, get_optim_scheduler | ||||||
| from utils        import get_model_infos, obtain_accuracy | from utils        import get_model_infos, obtain_accuracy | ||||||
| from log_utils    import AverageMeter, time_string, convert_secs2time | from log_utils    import AverageMeter, time_string, convert_secs2time | ||||||
| @@ -135,29 +135,9 @@ def main(xargs): | |||||||
|   logger = prepare_logger(args) |   logger = prepare_logger(args) | ||||||
|  |  | ||||||
|   train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1) |   train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1) | ||||||
|   assert xargs.dataset == 'cifar10', 'currently only support CIFAR-10' |  | ||||||
|   if xargs.dataset == 'cifar10' or xargs.dataset == 'cifar100': |  | ||||||
|     split_Fpath = 'configs/nas-benchmark/cifar-split.txt' |  | ||||||
|     cifar_split = load_config(split_Fpath, None, None) |  | ||||||
|     train_split, valid_split = cifar_split.train, cifar_split.valid |  | ||||||
|     logger.log('Load split file from {:}'.format(split_Fpath)) |  | ||||||
|   elif xargs.dataset.startswith('ImageNet16'): |  | ||||||
|     split_Fpath = 'configs/nas-benchmark/{:}-split.txt'.format(xargs.dataset) |  | ||||||
|     imagenet16_split = load_config(split_Fpath, None, None) |  | ||||||
|     train_split, valid_split = imagenet16_split.train, imagenet16_split.valid |  | ||||||
|     logger.log('Load split file from {:}'.format(split_Fpath)) |  | ||||||
|   else: |  | ||||||
|     raise ValueError('invalid dataset : {:}'.format(xargs.dataset)) |  | ||||||
|   #config_path = 'configs/nas-benchmark/algos/SETN.config' |  | ||||||
|   config = load_config(xargs.config_path, {'class_num': class_num, 'xshape': xshape}, logger) |   config = load_config(xargs.config_path, {'class_num': class_num, 'xshape': xshape}, logger) | ||||||
|   # To split data |   search_loader, _, valid_loader = get_nas_search_loaders(train_data, valid_data, xargs.dataset, 'configs/nas-benchmark/', \ | ||||||
|   train_data_v2 = deepcopy(train_data) |                                         (config.batch_size, config.test_batch_size), xargs.workers) | ||||||
|   train_data_v2.transform = valid_data.transform |  | ||||||
|   valid_data    = train_data_v2 |  | ||||||
|   search_data   = SearchDataset(xargs.dataset, train_data, train_split, valid_split) |  | ||||||
|   # data loader |  | ||||||
|   search_loader = torch.utils.data.DataLoader(search_data, batch_size=config.batch_size, shuffle=True , num_workers=xargs.workers, pin_memory=True) |  | ||||||
|   valid_loader  = torch.utils.data.DataLoader(valid_data,  batch_size=config.test_batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(valid_split), num_workers=xargs.workers, pin_memory=True) |  | ||||||
|   logger.log('||||||| {:10s} ||||||| Search-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'.format(xargs.dataset, len(search_loader), len(valid_loader), config.batch_size)) |   logger.log('||||||| {:10s} ||||||| Search-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'.format(xargs.dataset, len(search_loader), len(valid_loader), config.batch_size)) | ||||||
|   logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config)) |   logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config)) | ||||||
|  |  | ||||||
| @@ -202,7 +182,8 @@ def main(xargs): | |||||||
|     logger.log("=> loading checkpoint of the last-info '{:}' start with {:}-th epoch.".format(last_info, start_epoch)) |     logger.log("=> loading checkpoint of the last-info '{:}' start with {:}-th epoch.".format(last_info, start_epoch)) | ||||||
|   else: |   else: | ||||||
|     logger.log("=> do not find the last-info file : {:}".format(last_info)) |     logger.log("=> do not find the last-info file : {:}".format(last_info)) | ||||||
|     start_epoch, valid_accuracies, genotypes = 0, {'best': -1}, {} |     init_genotype, _ = get_best_arch(valid_loader, network, xargs.select_num) | ||||||
|  |     start_epoch, valid_accuracies, genotypes = 0, {'best': -1}, {-1: init_genotype} | ||||||
|  |  | ||||||
|   # start training |   # start training | ||||||
|   start_time, search_time, epoch_time, total_epoch = time.time(), AverageMeter(), AverageMeter(), config.epochs + config.warmup |   start_time, search_time, epoch_time, total_epoch = time.time(), AverageMeter(), AverageMeter(), config.epochs + config.warmup | ||||||
|   | |||||||
| @@ -1,5 +1,5 @@ | |||||||
| ################################################## | ################################################## | ||||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 # | # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 # | ||||||
| ################################################## | ################################################## | ||||||
| from .get_dataset_with_transform import get_datasets | from .get_dataset_with_transform import get_datasets, get_nas_search_loaders | ||||||
| from .SearchDatasetWrap import SearchDataset | from .SearchDatasetWrap import SearchDataset | ||||||
|   | |||||||
| @@ -6,8 +6,12 @@ import os.path as osp | |||||||
| import numpy as np | import numpy as np | ||||||
| import torchvision.datasets as dset | import torchvision.datasets as dset | ||||||
| import torchvision.transforms as transforms | import torchvision.transforms as transforms | ||||||
|  | from copy import deepcopy | ||||||
| from PIL import Image | from PIL import Image | ||||||
|  |  | ||||||
| from .DownsampledImageNet import ImageNet16 | from .DownsampledImageNet import ImageNet16 | ||||||
|  | from .SearchDatasetWrap import SearchDataset | ||||||
|  | from config_utils import load_config | ||||||
|  |  | ||||||
|  |  | ||||||
| Dataset2Class = {'cifar10' : 10, | Dataset2Class = {'cifar10' : 10, | ||||||
| @@ -177,6 +181,47 @@ def get_datasets(name, root, cutout): | |||||||
|   class_num = Dataset2Class[name] |   class_num = Dataset2Class[name] | ||||||
|   return train_data, test_data, xshape, class_num |   return train_data, test_data, xshape, class_num | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def get_nas_search_loaders(train_data, valid_data, dataset, config_root, batch_size, workers): | ||||||
|  |   if isinstance(batch_size, (list,tuple)): | ||||||
|  |     batch, test_batch = batch_size | ||||||
|  |   else: | ||||||
|  |     batch, test_batch = batch_size, batch_size | ||||||
|  |   if dataset == 'cifar10': | ||||||
|  |     #split_Fpath = 'configs/nas-benchmark/cifar-split.txt' | ||||||
|  |     cifar_split = load_config('{:}/cifar-split.txt'.format(config_root), None, None) | ||||||
|  |     train_split, valid_split = cifar_split.train, cifar_split.valid # search over the proposed training and validation set | ||||||
|  |     #logger.log('Load split file from {:}'.format(split_Fpath))      # they are two disjoint groups in the original CIFAR-10 training set | ||||||
|  |     # To split data | ||||||
|  |     xvalid_data  = deepcopy(train_data) | ||||||
|  |     if hasattr(xvalid_data, 'transforms'): # to avoid a print issue | ||||||
|  |       xvalid_data.transforms = valid_data.transform | ||||||
|  |     xvalid_data.transform  = deepcopy( valid_data.transform ) | ||||||
|  |     search_data   = SearchDataset(dataset, train_data, train_split, valid_split) | ||||||
|  |     # data loader | ||||||
|  |     search_loader = torch.utils.data.DataLoader(search_data, batch_size=batch, shuffle=True , num_workers=workers, pin_memory=True) | ||||||
|  |     train_loader  = torch.utils.data.DataLoader(train_data , batch_size=batch, sampler=torch.utils.data.sampler.SubsetRandomSampler(train_split), num_workers=workers, pin_memory=True) | ||||||
|  |     valid_loader  = torch.utils.data.DataLoader(xvalid_data, batch_size=test_batch, sampler=torch.utils.data.sampler.SubsetRandomSampler(valid_split), num_workers=workers, pin_memory=True) | ||||||
|  |   elif dataset == 'cifar100': | ||||||
|  |     cifar100_test_split = load_config('{:}/cifar100-test-split.txt'.format(config_root), None, None) | ||||||
|  |     search_train_data = train_data | ||||||
|  |     search_valid_data = deepcopy(valid_data) ; search_valid_data.transform = train_data.transform | ||||||
|  |     search_data   = SearchDataset(dataset, [search_train_data,search_valid_data], list(range(len(search_train_data))), cifar100_test_split.xvalid) | ||||||
|  |     search_loader = torch.utils.data.DataLoader(search_data, batch_size=batch, shuffle=True , num_workers=workers, pin_memory=True) | ||||||
|  |     train_loader  = torch.utils.data.DataLoader(train_data , batch_size=batch, shuffle=True , num_workers=workers, pin_memory=True) | ||||||
|  |     valid_loader  = torch.utils.data.DataLoader(valid_data , batch_size=test_batch, sampler=torch.utils.data.sampler.SubsetRandomSampler(cifar100_test_split.xvalid), num_workers=workers, pin_memory=True) | ||||||
|  |   elif dataset == 'ImageNet16-120': | ||||||
|  |     imagenet_test_split = load_config('{:}/imagenet-16-120-test-split.txt'.format(config_root), None, None) | ||||||
|  |     search_train_data = train_data | ||||||
|  |     search_valid_data = deepcopy(valid_data) ; search_valid_data.transform = train_data.transform | ||||||
|  |     search_data   = SearchDataset(dataset, [search_train_data,search_valid_data], list(range(len(search_train_data))), imagenet_test_split.xvalid) | ||||||
|  |     search_loader = torch.utils.data.DataLoader(search_data, batch_size=batch, shuffle=True , num_workers=workers, pin_memory=True) | ||||||
|  |     train_loader  = torch.utils.data.DataLoader(train_data , batch_size=batch, shuffle=True , num_workers=workers, pin_memory=True) | ||||||
|  |     valid_loader  = torch.utils.data.DataLoader(valid_data , batch_size=test_batch, sampler=torch.utils.data.sampler.SubsetRandomSampler(imagenet_test_split.xvalid), num_workers=workers, pin_memory=True) | ||||||
|  |   else: | ||||||
|  |     raise ValueError('invalid dataset : {:}'.format(dataset)) | ||||||
|  |   return search_loader, train_loader, valid_loader | ||||||
|  |  | ||||||
| #if __name__ == '__main__': | #if __name__ == '__main__': | ||||||
| #  train_data, test_data, xshape, class_num = dataset = get_datasets('cifar10', '/data02/dongxuanyi/.torch/cifar.python/', -1) | #  train_data, test_data, xshape, class_num = dataset = get_datasets('cifar10', '/data02/dongxuanyi/.torch/cifar.python/', -1) | ||||||
| #  import pdb; pdb.set_trace() | #  import pdb; pdb.set_trace() | ||||||
|   | |||||||
| @@ -13,16 +13,22 @@ OPS = { | |||||||
|   'nor_conv_7x7' : lambda C_in, C_out, stride, affine, track_running_stats: ReLUConvBN(C_in, C_out, (7,7), (stride,stride), (3,3), (1,1), affine, track_running_stats), |   'nor_conv_7x7' : lambda C_in, C_out, stride, affine, track_running_stats: ReLUConvBN(C_in, C_out, (7,7), (stride,stride), (3,3), (1,1), affine, track_running_stats), | ||||||
|   'nor_conv_3x3' : lambda C_in, C_out, stride, affine, track_running_stats: ReLUConvBN(C_in, C_out, (3,3), (stride,stride), (1,1), (1,1), affine, track_running_stats), |   'nor_conv_3x3' : lambda C_in, C_out, stride, affine, track_running_stats: ReLUConvBN(C_in, C_out, (3,3), (stride,stride), (1,1), (1,1), affine, track_running_stats), | ||||||
|   'nor_conv_1x1' : lambda C_in, C_out, stride, affine, track_running_stats: ReLUConvBN(C_in, C_out, (1,1), (stride,stride), (0,0), (1,1), affine, track_running_stats), |   'nor_conv_1x1' : lambda C_in, C_out, stride, affine, track_running_stats: ReLUConvBN(C_in, C_out, (1,1), (stride,stride), (0,0), (1,1), affine, track_running_stats), | ||||||
|  |   'dua_sepc_3x3' : lambda C_in, C_out, stride, affine, track_running_stats: DualSepConv(C_in, C_out, (3,3), (stride,stride), (1,1), (1,1), affine, track_running_stats), | ||||||
|  |   'dua_sepc_5x5' : lambda C_in, C_out, stride, affine, track_running_stats: DualSepConv(C_in, C_out, (5,5), (stride,stride), (2,2), (1,1), affine, track_running_stats), | ||||||
|  |   'dil_sepc_3x3' : lambda C_in, C_out, stride, affine, track_running_stats:     SepConv(C_in, C_out, (3,3), (stride,stride), (2,2), (2,2), affine, track_running_stats), | ||||||
|  |   'dil_sepc_5x5' : lambda C_in, C_out, stride, affine, track_running_stats:     SepConv(C_in, C_out, (5,5), (stride,stride), (4,4), (2,2), affine, track_running_stats), | ||||||
|   'skip_connect' : lambda C_in, C_out, stride, affine, track_running_stats: Identity() if stride == 1 and C_in == C_out else FactorizedReduce(C_in, C_out, stride, affine, track_running_stats), |   'skip_connect' : lambda C_in, C_out, stride, affine, track_running_stats: Identity() if stride == 1 and C_in == C_out else FactorizedReduce(C_in, C_out, stride, affine, track_running_stats), | ||||||
| } | } | ||||||
|  |  | ||||||
| CONNECT_NAS_BENCHMARK = ['none', 'skip_connect', 'nor_conv_3x3'] | CONNECT_NAS_BENCHMARK = ['none', 'skip_connect', 'nor_conv_3x3'] | ||||||
| NAS_BENCH_102         = ['none', 'skip_connect', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3'] | NAS_BENCH_102         = ['none', 'skip_connect', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3'] | ||||||
|  | DARTS_SPACE           = ['none', 'skip_connect', 'dua_sepc_3x3', 'dua_sepc_5x5', 'dil_sepc_3x3', 'dil_sepc_5x5', 'avg_pool_3x3', 'max_pool_3x3'] | ||||||
|  |  | ||||||
| SearchSpaceNames = {'connect-nas'  : CONNECT_NAS_BENCHMARK, | SearchSpaceNames = {'connect-nas'  : CONNECT_NAS_BENCHMARK, | ||||||
|                     'aa-nas'       : NAS_BENCH_102, |                     'aa-nas'       : NAS_BENCH_102, | ||||||
|                     'nas-bench-102': NAS_BENCH_102, |                     'nas-bench-102': NAS_BENCH_102, | ||||||
|                     'full'         : sorted(list(OPS.keys()))} |                     'darts'        : DARTS_SPACE} | ||||||
|  |                     #'full'         : sorted(list(OPS.keys()))} | ||||||
|  |  | ||||||
|  |  | ||||||
| class ReLUConvBN(nn.Module): | class ReLUConvBN(nn.Module): | ||||||
| @@ -39,6 +45,34 @@ class ReLUConvBN(nn.Module): | |||||||
|     return self.op(x) |     return self.op(x) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class SepConv(nn.Module): | ||||||
|  |      | ||||||
|  |   def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation, affine, track_running_stats=True): | ||||||
|  |     super(SepConv, self).__init__() | ||||||
|  |     self.op = nn.Sequential( | ||||||
|  |       nn.ReLU(inplace=False), | ||||||
|  |       nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=C_in, bias=False), | ||||||
|  |       nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False), | ||||||
|  |       nn.BatchNorm2d(C_out, affine=affine, track_running_stats=track_running_stats), | ||||||
|  |       ) | ||||||
|  |  | ||||||
|  |   def forward(self, x): | ||||||
|  |     return self.op(x) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class DualSepConv(nn.Module): | ||||||
|  |      | ||||||
|  |   def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation, affine, track_running_stats=True): | ||||||
|  |     super(DualSepConv, self).__init__() | ||||||
|  |     self.op_a = SepConv(C_in, C_in , kernel_size, stride, padding, dilation, affine, track_running_stats) | ||||||
|  |     self.op_b = SepConv(C_in, C_out, kernel_size, 1, padding, dilation, affine, track_running_stats) | ||||||
|  |  | ||||||
|  |   def forward(self, x): | ||||||
|  |     x = self.op_a(x) | ||||||
|  |     x = self.op_b(x) | ||||||
|  |     return x | ||||||
|  |  | ||||||
|  |  | ||||||
| class ResNetBasicblock(nn.Module): | class ResNetBasicblock(nn.Module): | ||||||
|  |  | ||||||
|   def __init__(self, inplanes, planes, stride, affine=True): |   def __init__(self, inplanes, planes, stride, affine=True): | ||||||
|   | |||||||
| @@ -3,3 +3,5 @@ | |||||||
| ################################################## | ################################################## | ||||||
| from .api import NASBench102API | from .api import NASBench102API | ||||||
| from .api import ArchResults, ResultsCount | from .api import ArchResults, ResultsCount | ||||||
|  |  | ||||||
|  | NAS_BENCH_102_API_VERSION="v1.0" | ||||||
|   | |||||||
| @@ -1,8 +1,12 @@ | |||||||
| ################################################## | ################################################## | ||||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 # | # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 # | ||||||
| ################################################################################# | ############################################################################################ | ||||||
| # NAS-Bench-102: Extending the Scope of Reproducible Neural Architecture Search # | # NAS-Bench-102: Extending the Scope of Reproducible Neural Architecture Search, ICLR 2020 # | ||||||
| ################################################################################# | ############################################################################################ | ||||||
|  | # NAS-Bench-102-v1_0-e61699.pth : 6219 architectures are trained once, 1621 architectures are trained twice, 7785 architectures are trained three times. `LESS` only supports CIFAR10-VALID. | ||||||
|  | # | ||||||
|  | # | ||||||
|  | # | ||||||
| import os, sys, copy, random, torch, numpy as np | import os, sys, copy, random, torch, numpy as np | ||||||
| from collections import OrderedDict, defaultdict | from collections import OrderedDict, defaultdict | ||||||
|  |  | ||||||
|   | |||||||
							
								
								
									
										26
									
								
								scripts-search/NAS-Bench-102/build.sh
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										26
									
								
								scripts-search/NAS-Bench-102/build.sh
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,26 @@ | |||||||
|  | #!/bin/bash | ||||||
|  | # bash scripts-search/NAS-Bench-102/build.sh | ||||||
|  | echo script name: $0 | ||||||
|  | echo $# arguments | ||||||
|  | if [ "$#" -ne 0 ] ;then | ||||||
|  |   echo "Input illegal number of parameters " $# | ||||||
|  |   echo "Need 0 parameters" | ||||||
|  |   exit 1 | ||||||
|  | fi | ||||||
|  |  | ||||||
|  | save_dir=./output/nas_bench_102_package | ||||||
|  | echo "Prepare to build the package in ${save_dir}" | ||||||
|  | rm -rf ${save_dir} | ||||||
|  | mkdir -p ${save_dir} | ||||||
|  |  | ||||||
|  | #cp NAS-Bench-102.md ${save_dir}/README.md | ||||||
|  | sed '125,187d' NAS-Bench-102.md > ${save_dir}/README.md | ||||||
|  | cp LICENSE.md ${save_dir}/LICENSE.md | ||||||
|  | cp -r lib/nas_102_api ${save_dir}/ | ||||||
|  | rm -rf ${save_dir}/nas_102_api/__pycache__ | ||||||
|  | cp exps/NAS-Bench-102/dist-setup.py ${save_dir}/setup.py | ||||||
|  |  | ||||||
|  | cd ${save_dir} | ||||||
|  | # python setup.py sdist bdist_wheel | ||||||
|  | # twine upload --repository-url https://test.pypi.org/legacy/ dist/* | ||||||
|  | # twine upload dist/* | ||||||
		Reference in New Issue
	
	Block a user