diff --git a/.gitignore b/.gitignore index fce4354..9d3d2e9 100644 --- a/.gitignore +++ b/.gitignore @@ -118,3 +118,5 @@ cx.sh NAS-Bench-102-v1_0.pth lib/NAS-Bench-102-v1_0.pth +others/TF +scripts-search/l2s-algos diff --git a/NAS-Bench-102.md b/NAS-Bench-102.md index f8783c1..2b8112b 100644 --- a/NAS-Bench-102.md +++ b/NAS-Bench-102.md @@ -14,11 +14,11 @@ Note: please use `PyTorch >= 1.2.0` and `Python >= 3.6.0`. ### Preparation and Download -The benchmark file of NAS-Bench-102 can be downloaded from [Google Drive](https://drive.google.com/open?id=1SKW0Cu0u8-gb18zDpaAGi0f74UdXeGKs) or [Baidu-Wangpan]. +The benchmark file of NAS-Bench-102 can be downloaded from [Google Drive](https://drive.google.com/open?id=1SKW0Cu0u8-gb18zDpaAGi0f74UdXeGKs) or [Baidu-Wangpan (code:6u5d)](https://pan.baidu.com/s/1CiaNH6C12zuZf7q-Ilm09w). You can move it to anywhere you want and send its path to our API for initialization. - v1.0: `NAS-Bench-102-v1_0-e61699.pth`, where `e61699` is the last six digits for this file. -The training and evaluation data used in NAS-Bench-102 can be downloaded from [Google Drive](https://drive.google.com/open?id=1L0Lzq8rWpZLPfiQGd6QR8q5xLV88emU7) or [Baidu-Wangpan]. +The training and evaluation data used in NAS-Bench-102 can be downloaded from [Google Drive](https://drive.google.com/open?id=1L0Lzq8rWpZLPfiQGd6QR8q5xLV88emU7) or [Baidu-Wangpan (code:4fg7)](https://pan.baidu.com/s/1XAzavPKq3zcat1yBA1L2tQ). It is recommended to put these data into `$TORCH_HOME` (`~/.torch/` by default). If you want to generate NAS-Bench-102 or similar NAS datasets or training models by yourself, you need these data. ## How to Use NAS-Bench-102 diff --git a/README.md b/README.md index fd4d060..aa3099d 100644 --- a/README.md +++ b/README.md @@ -13,7 +13,7 @@ More NAS resources can be found in [Awesome-NAS](https://github.com/D-X-Y/Awesom ## Requirements and Preparation -Please install `PyTorch>=1.1.0`, `Python>=3.6`, and `opencv`. +Please install `PyTorch>=1.2.0`, `Python>=3.6`, and `opencv`. The CIFAR and ImageNet should be downloaded and extracted into `$TORCH_HOME`. Some methods use knowledge distillation (KD), which require pre-trained models. Please download these models from [Google Driver](https://drive.google.com/open?id=1ANmiYEGX-IQZTfH8w0aSpj-Wypg-0DR-) (or train by yourself) and save into `.latent-data`. diff --git a/exps/NAS-Bench-102/main.py b/exps/NAS-Bench-102/main.py index 9cde574..e8124cf 100644 --- a/exps/NAS-Bench-102/main.py +++ b/exps/NAS-Bench-102/main.py @@ -213,7 +213,7 @@ def train_single_model(save_dir, workers, datasets, xpaths, splits, use_less, se def generate_meta_info(save_dir, max_node, divide=40): - aa_nas_bench_ss = get_search_spaces('cell', 'aa-nas') + aa_nas_bench_ss = get_search_spaces('cell', 'nas-bench-102') archs = CellStructure.gen_all(aa_nas_bench_ss, max_node, False) print ('There are {:} archs vs {:}.'.format(len(archs), len(aa_nas_bench_ss) ** ((max_node-1)*max_node/2))) diff --git a/exps/algos/DARTS-V1.py b/exps/algos/DARTS-V1.py index 4838c0e..64037e9 100644 --- a/exps/algos/DARTS-V1.py +++ b/exps/algos/DARTS-V1.py @@ -17,6 +17,7 @@ from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_che from utils import get_model_infos, obtain_accuracy from log_utils import AverageMeter, time_string, convert_secs2time from models import get_cell_based_tiny_net, get_search_spaces +from nas_102_api import NASBench102API as API def search_func(xloader, network, criterion, scheduler, w_optimizer, a_optimizer, epoch_str, print_freq, logger): @@ -144,6 +145,11 @@ def main(xargs): flop, param = get_model_infos(search_model, xshape) #logger.log('{:}'.format(search_model)) logger.log('FLOP = {:.2f} M, Params = {:.2f} MB'.format(flop, param)) + if xargs.arch_nas_dataset is None: + api = None + else: + api = API(xargs.arch_nas_dataset) + logger.log('{:} create API = {:} done'.format(time_string(), api)) last_info, model_base_path, model_best_path = logger.path('info'), logger.path('model'), logger.path('best') network, criterion = torch.nn.DataParallel(search_model).cuda(), criterion.cuda() @@ -165,7 +171,7 @@ def main(xargs): start_epoch, valid_accuracies, genotypes = 0, {'best': -1}, {} # start training - start_time, epoch_time, total_epoch = time.time(), AverageMeter(), config.epochs + config.warmup + start_time, search_time, epoch_time, total_epoch = time.time(), AverageMeter(), AverageMeter(), config.epochs + config.warmup for epoch in range(start_epoch, total_epoch): w_scheduler.update(epoch, 0.0) need_time = 'Time Left: {:}'.format( convert_secs2time(epoch_time.val * (total_epoch-epoch), True) ) @@ -173,7 +179,8 @@ def main(xargs): logger.log('\n[Search the {:}-th epoch] {:}, LR={:}'.format(epoch_str, need_time, min(w_scheduler.get_lr()))) search_w_loss, search_w_top1, search_w_top5 = search_func(search_loader, network, criterion, w_scheduler, w_optimizer, a_optimizer, epoch_str, xargs.print_freq, logger) - logger.log('[{:}] searching : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'.format(epoch_str, search_w_loss, search_w_top1, search_w_top5)) + search_time.update(time.time() - start_time) + logger.log('[{:}] searching : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%, time-cost={:.1f} s'.format(epoch_str, search_w_loss, search_w_top1, search_w_top5, search_time.sum)) valid_a_loss , valid_a_top1 , valid_a_top5 = valid_func(valid_loader, network, criterion) logger.log('[{:}] evaluate : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'.format(epoch_str, valid_a_loss, valid_a_top1, valid_a_top5)) # check the best accuracy @@ -204,6 +211,8 @@ def main(xargs): if find_best: logger.log('<<<--->>> The {:}-th epoch : find the highest validation accuracy : {:.2f}%.'.format(epoch_str, valid_a_top1)) copy_checkpoint(model_base_path, model_best_path, logger) + if api is not None: + logger.log('{:}'.format(api.query_by_arch( genotypes[epoch] ))) with torch.no_grad(): logger.log('arch-parameters :\n{:}'.format( nn.functional.softmax(search_model.arch_parameters, dim=-1).cpu() )) # measure elapsed time @@ -211,22 +220,8 @@ def main(xargs): start_time = time.time() logger.log('\n' + '-'*100) - # check the performance from the architecture dataset - #if xargs.arch_nas_dataset is None or not os.path.isfile(xargs.arch_nas_dataset): - # logger.log('Can not find the architecture dataset : {:}.'.format(xargs.arch_nas_dataset)) - #else: - # nas_bench = NASBenchmarkAPI(xargs.arch_nas_dataset) - # geno = genotypes[total_epoch-1] - # logger.log('The last model is {:}'.format(geno)) - # info = nas_bench.query_by_arch( geno ) - # if info is None: logger.log('Did not find this architecture : {:}.'.format(geno)) - # else : logger.log('{:}'.format(info)) - # logger.log('-'*100) - # geno = genotypes['best'] - # logger.log('The best model is {:}'.format(geno)) - # info = nas_bench.query_by_arch( geno ) - # if info is None: logger.log('Did not find this architecture : {:}.'.format(geno)) - # else : logger.log('{:}'.format(info)) + logger.log('DARTS-V1 : run {:} epochs, cost {:.1f} s, last-geno is {:}.'.format(total_epoch, search_time.sum, genotypes[total_epoch-1])) + if api is not None: logger.log('{:}'.format( api.query_by_arch(genotypes[total_epoch-1]) )) logger.close() diff --git a/exps/algos/R_EA.py b/exps/algos/R_EA.py index 03a38ed..e64d9a6 100644 --- a/exps/algos/R_EA.py +++ b/exps/algos/R_EA.py @@ -59,9 +59,9 @@ def train_and_eval(arch, nas_bench, extra_info): if nas_bench is not None: arch_index = nas_bench.query_index_by_arch( arch ) assert arch_index >= 0, 'can not find this arch : {:}'.format(arch) - info = nas_bench.arch2infos[ arch_index ] - _, valid_acc = info.get_metrics('cifar10-valid', 'x-valid' , 25, True) # use the validation accuracy after 25 training epochs - #import pdb; pdb.set_trace() + info = nas_bench.get_more_info(arch_index, 'cifar10-valid', True) + import pdb; pdb.set_trace() + #_, valid_acc = info.get_metrics('cifar10-valid', 'x-valid' , 25, True) # use the validation accuracy after 25 training epochs else: # train a model from scratch. raise ValueError('NOT IMPLEMENT YET') diff --git a/lib/models/__init__.py b/lib/models/__init__.py index 5acd8ec..50a6b36 100644 --- a/lib/models/__init__.py +++ b/lib/models/__init__.py @@ -36,6 +36,7 @@ def get_cell_based_tiny_net(config): def get_search_spaces(xtype, name): if xtype == 'cell': from .cell_operations import SearchSpaceNames + assert name in SearchSpaceNames, 'invalid name [{:}] in {:}'.format(name, SearchSpaceNames.keys()) return SearchSpaceNames[name] else: raise ValueError('invalid search-space type is {:}'.format(xtype)) diff --git a/lib/models/cell_operations.py b/lib/models/cell_operations.py index 5f39a99..5e2b779 100644 --- a/lib/models/cell_operations.py +++ b/lib/models/cell_operations.py @@ -16,12 +16,13 @@ OPS = { 'skip_connect' : lambda C_in, C_out, stride, affine: Identity() if stride == 1 and C_in == C_out else FactorizedReduce(C_in, C_out, stride, affine), } -CONNECT_NAS_BENCHMARK = ['none', 'skip_connect', 'nor_conv_3x3'] -AA_NAS_BENCHMARK = ['none', 'skip_connect', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3'] +CONNECT_NAS_BENCHMARK = ['none', 'skip_connect', 'nor_conv_3x3'] +NAS_BENCH_102 = ['none', 'skip_connect', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3'] -SearchSpaceNames = {'connect-nas' : CONNECT_NAS_BENCHMARK, - 'aa-nas' : AA_NAS_BENCHMARK, - 'full' : sorted(list(OPS.keys()))} +SearchSpaceNames = {'connect-nas' : CONNECT_NAS_BENCHMARK, + 'aa-nas' : NAS_BENCH_102, + 'nas-bench-102': NAS_BENCH_102, + 'full' : sorted(list(OPS.keys()))} class ReLUConvBN(nn.Module): diff --git a/lib/nas_102_api/api.py b/lib/nas_102_api/api.py index 4e0053c..e6b2832 100644 --- a/lib/nas_102_api/api.py +++ b/lib/nas_102_api/api.py @@ -129,6 +129,27 @@ class NASBench102API(object): assert 0 <= index < len(self.meta_archs), 'invalid index : {:} vs. {:}.'.format(index, len(self.meta_archs)) return copy.deepcopy(self.meta_archs[index]) + def get_more_info(self, index, dataset, use_12epochs_result=False): + if use_12epochs_result: basestr, arch2infos = '12epochs' , self.arch2infos_less + else : basestr, arch2infos = '200epochs', self.arch2infos_full + archresult = arch2infos[index] + if dataset == 'cifar10-valid': + train_info = archresult.get_metrics(dataset, 'train', is_random=True) + valid_info = archresult.get_metrics(dataset, 'x-valid', is_random=True) + test__info = archresult.get_metrics(dataset, 'ori-test', is_random=True) + total = train_info['iepoch'] + 1 + return {'train-loss' : train_info['loss'], + 'train-accuracy': train_info['accuracy'], + 'train-all-time': train_info['all_time'], + 'valid-loss' : valid_info['loss'], + 'valid-accuracy': valid_info['accuracy'], + 'valid-all-time': valid_info['all_time'], + 'valid-per-time': valid_info['all_time'] / total, + 'test-loss' : test__info['loss'], + 'test-accuracy' : test__info['accuracy']} + else: + raise ValueError('coming soon...') + def show(self, index=-1): if index < 0: # show all architectures print(self) @@ -367,23 +388,28 @@ class ResultsCount(object): def get_train(self, iepoch=None): if iepoch is None: iepoch = self.epochs-1 assert 0 <= iepoch < self.epochs, 'invalid iepoch={:} < {:}'.format(iepoch, self.epochs) - if self.train_times is not None: xtime = self.train_times[iepoch] - else : xtime = None + if self.train_times is not None: + xtime = self.train_times[iepoch] + atime = sum([self.train_times[i] for i in range(iepoch+1)]) + else: xtime, atime = None, None return {'iepoch' : iepoch, 'loss' : self.train_losses[iepoch], 'accuracy': self.train_acc1es[iepoch], - 'time' : xtime} + 'cur_time': xtime, + 'all_time': atime} def get_eval(self, name, iepoch=None): if iepoch is None: iepoch = self.epochs-1 assert 0 <= iepoch < self.epochs, 'invalid iepoch={:} < {:}'.format(iepoch, self.epochs) if isinstance(self.eval_times,dict) and len(self.eval_times) > 0: xtime = self.eval_times['{:}@{:}'.format(name,iepoch)] - else: xtime = None + atime = sum([self.eval_times['{:}@{:}'.format(name,i)] for i in range(iepoch+1)]) + else: xtime, atime = None, None return {'iepoch' : iepoch, 'loss' : self.eval_losses['{:}@{:}'.format(name,iepoch)], 'accuracy': self.eval_acc1es['{:}@{:}'.format(name,iepoch)], - 'time' : xtime} + 'cur_time': xtime, + 'all_time': atime} def get_net_param(self): return self.net_state_dict diff --git a/others/paddlepaddle/.gitignore b/others/paddlepaddle/.gitignore deleted file mode 100644 index ed615b6..0000000 --- a/others/paddlepaddle/.gitignore +++ /dev/null @@ -1,3 +0,0 @@ -.DS_Store -*.whl -snapshots diff --git a/others/paddlepaddle/README.md b/others/paddlepaddle/README.md deleted file mode 100644 index b6a1b01..0000000 --- a/others/paddlepaddle/README.md +++ /dev/null @@ -1,118 +0,0 @@ -# Image Classification based on NAS-Searched Models - -This directory contains 10 image classification models. -Nine of them are automatically searched models from different Neural Architecture Search (NAS) algorithms. The other is the residual network. -We provide codes and scripts to train these models on both CIFAR-10 and CIFAR-100. -We use the standard data augmentation, i.e., random crop, random flip, and normalization. - ---- -## Table of Contents -- [Installation](#installation) -- [Data Preparation](#data-preparation) -- [Training Models](#training-models) -- [Project Structure](#project-structure) -- [Citation](#citation) - - -### Installation -This project has the following requirements: -- Python = 3.6 -- PadddlePaddle Fluid >= v0.15.0 - - -### Data Preparation -Please download [CIFAR-10](https://dataset.bj.bcebos.com/cifar/cifar-10-python.tar.gz) and [CIFAR-100](https://dataset.bj.bcebos.com/cifar/cifar-100-python.tar.gz) before running the codes. -Note that the MD5 of CIFAR-10-Python compressed file is `c58f30108f718f92721af3b95e74349a` and the MD5 of CIFAR-100-Python compressed file is `eb9058c3a382ffc7106e4002c42a8d85`. -Please save the file into `${TORCH_HOME}/cifar.python`. -After data preparation, there should be two files `${TORCH_HOME}/cifar.python/cifar-10-python.tar.gz` and `${TORCH_HOME}/cifar.python/cifar-100-python.tar.gz`. - - -### Training Models - -After setting up the environment and preparing the data, one can train the model. The main function entrance is `train_cifar.py`. We also provide some scripts for easy usage. -``` -bash ./scripts/base-train.sh 0 cifar-10 ResNet110 -bash ./scripts/train-nas.sh 0 cifar-10 GDAS_V1 -bash ./scripts/train-nas.sh 0 cifar-10 GDAS_V2 -bash ./scripts/train-nas.sh 0 cifar-10 SETN -bash ./scripts/train-nas.sh 0 cifar-10 NASNet -bash ./scripts/train-nas.sh 0 cifar-10 ENASNet -bash ./scripts/train-nas.sh 0 cifar-10 AmoebaNet -bash ./scripts/train-nas.sh 0 cifar-10 PNASNet -bash ./scripts/train-nas.sh 0 cifar-100 SETN -``` -The first argument is the GPU-ID to train your program, the second argument is the dataset name, and the last one is the model name. -Please use `./scripts/base-train.sh` for ResNet and use `./scripts/train-nas.sh` for NAS-searched models. - - -### Project Structure -``` -. -├──train_cifar.py [Training CNN models] -├──lib [Library for dataset, models, and others] -│ └──models -│ ├──__init__.py [Import useful Classes and Functions in models] -│ ├──resnet.py [Define the ResNet models] -│ ├──operations.py [Define the atomic operation in NAS search space] -│ ├──genotypes.py [Define the topological structure of different NAS-searched models] -│ └──nas_net.py [Define the macro structure of NAS models] -│ └──utils -│ ├──__init__.py [Import useful Classes and Functions in utils] -│ ├──meter.py [Define the AverageMeter class to count the accuracy and loss] -│ ├──time_utils.py [Define some functions to print date or convert seconds into hours] -│ └──data_utils.py [Define data augmentation functions and dataset reader for CIFAR] -└──scripts [Scripts for running] -``` - - -### Citation -If you find that this project helps your research, please consider citing these papers: -``` -@inproceedings{dong2019one, - title = {One-Shot Neural Architecture Search via Self-Evaluated Template Network}, - author = {Dong, Xuanyi and Yang, Yi}, - booktitle = {Proceedings of the IEEE International Conference on Computer Vision (ICCV)}, - year = {2019} -} -@inproceedings{dong2019search, - title = {Searching for A Robust Neural Architecture in Four GPU Hours}, - author = {Dong, Xuanyi and Yang, Yi}, - booktitle = {Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (CVPR)}, - pages = {1761--1770}, - year = {2019} -} -@inproceedings{liu2018darts, - title = {Darts: Differentiable architecture search}, - author = {Liu, Hanxiao and Simonyan, Karen and Yang, Yiming}, - booktitle = {ICLR}, - year = {2018} -} -@inproceedings{pham2018efficient, - title = {Efficient Neural Architecture Search via Parameter Sharing}, - author = {Pham, Hieu and Guan, Melody and Zoph, Barret and Le, Quoc and Dean, Jeff}, - booktitle = {International Conference on Machine Learning (ICML)}, - pages = {4092--4101}, - year = {2018} -} -@inproceedings{liu2018progressive, - title = {Progressive neural architecture search}, - author = {Liu, Chenxi and Zoph, Barret and Neumann, Maxim and Shlens, Jonathon and Hua, Wei and Li, Li-Jia and Fei-Fei, Li and Yuille, Alan and Huang, Jonathan and Murphy, Kevin}, - booktitle = {Proceedings of the European Conference on Computer Vision (ECCV)}, - pages = {19--34}, - year = {2018} -} -@inproceedings{zoph2018learning, - title = {Learning transferable architectures for scalable image recognition}, - author = {Zoph, Barret and Vasudevan, Vijay and Shlens, Jonathon and Le, Quoc V}, - booktitle = {Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (CVPR)}, - pages = {8697--8710}, - year = {2018} -} -@inproceedings{real2019regularized, - title = {Regularized evolution for image classifier architecture search}, - author = {Real, Esteban and Aggarwal, Alok and Huang, Yanping and Le, Quoc V}, - booktitle = {Proceedings of the AAAI Conference on Artificial Intelligence}, - pages = {4780--4789}, - year = {2019} -} -``` diff --git a/others/paddlepaddle/lib/models/__init__.py b/others/paddlepaddle/lib/models/__init__.py deleted file mode 100644 index 0bebe0b..0000000 --- a/others/paddlepaddle/lib/models/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .genotypes import Networks -from .nas_net import NASCifarNet -from .resnet import resnet_cifar diff --git a/others/paddlepaddle/lib/models/genotypes.py b/others/paddlepaddle/lib/models/genotypes.py deleted file mode 100644 index 08f145f..0000000 --- a/others/paddlepaddle/lib/models/genotypes.py +++ /dev/null @@ -1,175 +0,0 @@ -################################################## -# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 # -################################################## -from collections import namedtuple - -Genotype = namedtuple('Genotype', 'normal normal_concat reduce reduce_concat') - - -# Learning Transferable Architectures for Scalable Image Recognition, CVPR 2018 -NASNet = Genotype( - normal = [ - (('sep_conv_5x5', 1), ('sep_conv_3x3', 0)), - (('sep_conv_5x5', 0), ('sep_conv_3x3', 0)), - (('avg_pool_3x3', 1), ('skip_connect', 0)), - (('avg_pool_3x3', 0), ('avg_pool_3x3', 0)), - (('sep_conv_3x3', 1), ('skip_connect', 1)), - ], - normal_concat = [2, 3, 4, 5, 6], - reduce = [ - (('sep_conv_5x5', 1), ('sep_conv_7x7', 0)), - (('max_pool_3x3', 1), ('sep_conv_7x7', 0)), - (('avg_pool_3x3', 1), ('sep_conv_5x5', 0)), - (('skip_connect', 3), ('avg_pool_3x3', 2)), - (('sep_conv_3x3', 2), ('max_pool_3x3', 1)), - ], - reduce_concat = [4, 5, 6], -) - - -# Progressive Neural Architecture Search, ECCV 2018 -PNASNet = Genotype( - normal = [ - (('sep_conv_5x5', 0), ('max_pool_3x3', 0)), - (('sep_conv_7x7', 1), ('max_pool_3x3', 1)), - (('sep_conv_5x5', 1), ('sep_conv_3x3', 1)), - (('sep_conv_3x3', 4), ('max_pool_3x3', 1)), - (('sep_conv_3x3', 0), ('skip_connect', 1)), - ], - normal_concat = [2, 3, 4, 5, 6], - reduce = [ - (('sep_conv_5x5', 0), ('max_pool_3x3', 0)), - (('sep_conv_7x7', 1), ('max_pool_3x3', 1)), - (('sep_conv_5x5', 1), ('sep_conv_3x3', 1)), - (('sep_conv_3x3', 4), ('max_pool_3x3', 1)), - (('sep_conv_3x3', 0), ('skip_connect', 1)), - ], - reduce_concat = [2, 3, 4, 5, 6], -) - - -# Regularized Evolution for Image Classifier Architecture Search, AAAI 2019 -AmoebaNet = Genotype( - normal = [ - (('avg_pool_3x3', 0), ('max_pool_3x3', 1)), - (('sep_conv_3x3', 0), ('sep_conv_5x5', 2)), - (('sep_conv_3x3', 0), ('avg_pool_3x3', 3)), - (('sep_conv_3x3', 1), ('skip_connect', 1)), - (('skip_connect', 0), ('avg_pool_3x3', 1)), - ], - normal_concat = [4, 5, 6], - reduce = [ - (('avg_pool_3x3', 0), ('sep_conv_3x3', 1)), - (('max_pool_3x3', 0), ('sep_conv_7x7', 2)), - (('sep_conv_7x7', 0), ('avg_pool_3x3', 1)), - (('max_pool_3x3', 0), ('max_pool_3x3', 1)), - (('conv_7x1_1x7', 0), ('sep_conv_3x3', 5)), - ], - reduce_concat = [3, 4, 6] -) - - -# Efficient Neural Architecture Search via Parameter Sharing, ICML 2018 -ENASNet = Genotype( - normal = [ - (('sep_conv_3x3', 1), ('skip_connect', 1)), - (('sep_conv_5x5', 1), ('skip_connect', 0)), - (('avg_pool_3x3', 0), ('sep_conv_3x3', 1)), - (('sep_conv_3x3', 0), ('avg_pool_3x3', 1)), - (('sep_conv_5x5', 1), ('avg_pool_3x3', 0)), - ], - normal_concat = [2, 3, 4, 5, 6], - reduce = [ - (('sep_conv_5x5', 0), ('sep_conv_3x3', 1)), # 2 - (('sep_conv_3x3', 1), ('avg_pool_3x3', 1)), # 3 - (('sep_conv_3x3', 1), ('avg_pool_3x3', 1)), # 4 - (('avg_pool_3x3', 1), ('sep_conv_5x5', 4)), # 5 - (('sep_conv_3x3', 5), ('sep_conv_5x5', 0)), - ], - reduce_concat = [2, 3, 4, 5, 6], -) - - -# DARTS: Differentiable Architecture Search, ICLR 2019 -DARTS_V1 = Genotype( - normal=[ - (('sep_conv_3x3', 1), ('sep_conv_3x3', 0)), # step 1 - (('skip_connect', 0), ('sep_conv_3x3', 1)), # step 2 - (('skip_connect', 0), ('sep_conv_3x3', 1)), # step 3 - (('sep_conv_3x3', 0), ('skip_connect', 2)) # step 4 - ], - normal_concat=[2, 3, 4, 5], - reduce=[ - (('max_pool_3x3', 0), ('max_pool_3x3', 1)), # step 1 - (('skip_connect', 2), ('max_pool_3x3', 0)), # step 2 - (('max_pool_3x3', 0), ('skip_connect', 2)), # step 3 - (('skip_connect', 2), ('avg_pool_3x3', 0)) # step 4 - ], - reduce_concat=[2, 3, 4, 5], -) - - -# DARTS: Differentiable Architecture Search, ICLR 2019 -DARTS_V2 = Genotype( - normal=[ - (('sep_conv_3x3', 0), ('sep_conv_3x3', 1)), # step 1 - (('sep_conv_3x3', 0), ('sep_conv_3x3', 1)), # step 2 - (('sep_conv_3x3', 1), ('skip_connect', 0)), # step 3 - (('skip_connect', 0), ('dil_conv_3x3', 2)) # step 4 - ], - normal_concat=[2, 3, 4, 5], - reduce=[ - (('max_pool_3x3', 0), ('max_pool_3x3', 1)), # step 1 - (('skip_connect', 2), ('max_pool_3x3', 1)), # step 2 - (('max_pool_3x3', 0), ('skip_connect', 2)), # step 3 - (('skip_connect', 2), ('max_pool_3x3', 1)) # step 4 - ], - reduce_concat=[2, 3, 4, 5], -) - - - -# One-Shot Neural Architecture Search via Self-Evaluated Template Network, ICCV 2019 -SETN = Genotype( - normal=[ - (('skip_connect', 0), ('sep_conv_5x5', 1)), - (('sep_conv_5x5', 0), ('sep_conv_3x3', 1)), - (('sep_conv_5x5', 1), ('sep_conv_5x5', 3)), - (('max_pool_3x3', 1), ('conv_3x1_1x3', 4))], - normal_concat=[2, 3, 4, 5], - reduce=[ - (('sep_conv_3x3', 0), ('sep_conv_5x5', 1)), - (('avg_pool_3x3', 0), ('sep_conv_5x5', 1)), - (('avg_pool_3x3', 0), ('sep_conv_5x5', 1)), - (('avg_pool_3x3', 0), ('skip_connect', 1))], - reduce_concat=[2, 3, 4, 5], -) - - -# Searching for A Robust Neural Architecture in Four GPU Hours, CVPR 2019 -GDAS_V1 = Genotype( - normal=[ - (('skip_connect', 0), ('skip_connect', 1)), - (('skip_connect', 0), ('sep_conv_5x5', 2)), - (('sep_conv_3x3', 3), ('skip_connect', 0)), - (('sep_conv_5x5', 4), ('sep_conv_3x3', 3))], - normal_concat=[2, 3, 4, 5], - reduce=[ - (('sep_conv_5x5', 0), ('sep_conv_3x3', 1)), - (('sep_conv_5x5', 2), ('sep_conv_5x5', 1)), - (('dil_conv_5x5', 2), ('sep_conv_3x3', 1)), - (('sep_conv_5x5', 0), ('sep_conv_5x5', 1))], - reduce_concat=[2, 3, 4, 5], -) - - -Networks = {'DARTS_V1' : DARTS_V1, - 'DARTS_V2' : DARTS_V2, - 'DARTS' : DARTS_V2, - 'NASNet' : NASNet, - 'ENASNet' : ENASNet, - 'AmoebaNet': AmoebaNet, - 'GDAS_V1' : GDAS_V1, - 'PNASNet' : PNASNet, - 'SETN' : SETN, - } diff --git a/others/paddlepaddle/lib/models/nas_net.py b/others/paddlepaddle/lib/models/nas_net.py deleted file mode 100644 index 10815c7..0000000 --- a/others/paddlepaddle/lib/models/nas_net.py +++ /dev/null @@ -1,79 +0,0 @@ -import paddle -import paddle.fluid as fluid -from .operations import OPS - - -def AuxiliaryHeadCIFAR(inputs, C, class_num): - print ('AuxiliaryHeadCIFAR : inputs-shape : {:}'.format(inputs.shape)) - temp = fluid.layers.relu(inputs) - temp = fluid.layers.pool2d(temp, pool_size=5, pool_stride=3, pool_padding=0, pool_type='avg') - temp = fluid.layers.conv2d(temp, filter_size=1, num_filters=128, stride=1, padding=0, act=None, bias_attr=False) - temp = fluid.layers.batch_norm(input=temp, act='relu', bias_attr=None) - temp = fluid.layers.conv2d(temp, filter_size=1, num_filters=768, stride=2, padding=0, act=None, bias_attr=False) - temp = fluid.layers.batch_norm(input=temp, act='relu', bias_attr=None) - print ('AuxiliaryHeadCIFAR : last---shape : {:}'.format(temp.shape)) - predict = fluid.layers.fc(input=temp, size=class_num, act='softmax') - return predict - - -def InferCell(name, inputs_prev_prev, inputs_prev, genotype, C_prev_prev, C_prev, C, reduction, reduction_prev): - print ('[{:}] C_prev_prev={:} C_prev={:}, C={:}, reduction_prev={:}, reduction={:}'.format(name, C_prev_prev, C_prev, C, reduction_prev, reduction)) - print ('inputs_prev_prev : {:}'.format(inputs_prev_prev.shape)) - print ('inputs_prev : {:}'.format(inputs_prev.shape)) - inputs_prev_prev = OPS['skip_connect'](inputs_prev_prev, C_prev_prev, C, 2 if reduction_prev else 1) - inputs_prev = OPS['skip_connect'](inputs_prev, C_prev, C, 1) - print ('inputs_prev_prev : {:}'.format(inputs_prev_prev.shape)) - print ('inputs_prev : {:}'.format(inputs_prev.shape)) - if reduction: step_ops, concat = genotype.reduce, genotype.reduce_concat - else : step_ops, concat = genotype.normal, genotype.normal_concat - states = [inputs_prev_prev, inputs_prev] - for istep, operations in enumerate(step_ops): - op_a, op_b = operations - # the first operation - #print ('-->>[{:}/{:}] [{:}] + [{:}]'.format(istep, len(step_ops), op_a, op_b)) - stride = 2 if reduction and op_a[1] < 2 else 1 - tensor1 = OPS[ op_a[0] ](states[op_a[1]], C, C, stride) - stride = 2 if reduction and op_b[1] < 2 else 1 - tensor2 = OPS[ op_b[0] ](states[op_b[1]], C, C, stride) - state = fluid.layers.elementwise_add(x=tensor1, y=tensor2, act=None) - assert tensor1.shape == tensor2.shape, 'invalid shape {:} vs. {:}'.format(tensor1.shape, tensor2.shape) - print ('-->>[{:}/{:}] tensor={:} from {:} + {:}'.format(istep, len(step_ops), state.shape, tensor1.shape, tensor2.shape)) - states.append( state ) - states_to_cat = [states[x] for x in concat] - outputs = fluid.layers.concat(states_to_cat, axis=1) - print ('-->> output-shape : {:} from concat={:}'.format(outputs.shape, concat)) - return outputs - - - -# NASCifarNet(inputs, 36, 6, 3, 10, 'xxx', True) -def NASCifarNet(ipt, C, N, stem_multiplier, class_num, genotype, auxiliary): - # cifar head module - C_curr = stem_multiplier * C - stem = fluid.layers.conv2d(ipt, filter_size=3, num_filters=C_curr, stride=1, padding=1, act=None, bias_attr=False) - stem = fluid.layers.batch_norm(input=stem, act=None, bias_attr=None) - print ('stem-shape : {:}'.format(stem.shape)) - # N + 1 + N + 1 + N cells - layer_channels = [C ] * N + [C*2 ] + [C*2 ] * N + [C*4 ] + [C*4 ] * N - layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N - C_prev_prev, C_prev, C_curr = C_curr, C_curr, C - reduction_prev = False - auxiliary_pred = None - - cell_results = [stem, stem] - for index, (C_curr, reduction) in enumerate(zip(layer_channels, layer_reductions)): - xstr = '{:02d}/{:02d}'.format(index, len(layer_channels)) - cell_result = InferCell(xstr, cell_results[-2], cell_results[-1], genotype, C_prev_prev, C_prev, C_curr, reduction, reduction_prev) - reduction_prev = reduction - C_prev_prev, C_prev = C_prev, cell_result.shape[1] - cell_results.append( cell_result ) - if auxiliary and reduction and C_curr == C*4: - auxiliary_pred = AuxiliaryHeadCIFAR(cell_result, C_prev, class_num) - - global_P = fluid.layers.pool2d(input=cell_results[-1], pool_size=8, pool_type='avg', pool_stride=1) - predicts = fluid.layers.fc(input=global_P, size=class_num, act='softmax') - print ('predict-shape : {:}'.format(predicts.shape)) - if auxiliary_pred is None: - return predicts - else: - return [predicts, auxiliary_pred] diff --git a/others/paddlepaddle/lib/models/operations.py b/others/paddlepaddle/lib/models/operations.py deleted file mode 100644 index cbfe2b3..0000000 --- a/others/paddlepaddle/lib/models/operations.py +++ /dev/null @@ -1,91 +0,0 @@ -import paddle -import paddle.fluid as fluid - - -OPS = { - 'none' : lambda inputs, C_in, C_out, stride: ZERO(inputs, stride), - 'avg_pool_3x3' : lambda inputs, C_in, C_out, stride: POOL_3x3(inputs, C_in, C_out, stride, 'avg'), - 'max_pool_3x3' : lambda inputs, C_in, C_out, stride: POOL_3x3(inputs, C_in, C_out, stride, 'max'), - 'skip_connect' : lambda inputs, C_in, C_out, stride: Identity(inputs, C_in, C_out, stride), - 'sep_conv_3x3' : lambda inputs, C_in, C_out, stride: SepConv(inputs, C_in, C_out, 3, stride, 1), - 'sep_conv_5x5' : lambda inputs, C_in, C_out, stride: SepConv(inputs, C_in, C_out, 5, stride, 2), - 'sep_conv_7x7' : lambda inputs, C_in, C_out, stride: SepConv(inputs, C_in, C_out, 7, stride, 3), - 'dil_conv_3x3' : lambda inputs, C_in, C_out, stride: DilConv(inputs, C_in, C_out, 3, stride, 2, 2), - 'dil_conv_5x5' : lambda inputs, C_in, C_out, stride: DilConv(inputs, C_in, C_out, 5, stride, 4, 2), - 'conv_3x1_1x3' : lambda inputs, C_in, C_out, stride: Conv313(inputs, C_in, C_out, stride), - 'conv_7x1_1x7' : lambda inputs, C_in, C_out, stride: Conv717(inputs, C_in, C_out, stride), -} - - -def ReLUConvBN(inputs, C_in, C_out, kernel, stride, padding): - temp = fluid.layers.relu(inputs) - temp = fluid.layers.conv2d(temp, filter_size=kernel, num_filters=C_out, stride=stride, padding=padding, act=None, bias_attr=False) - temp = fluid.layers.batch_norm(input=temp, act=None, bias_attr=None) - return temp - - -def ZERO(inputs, stride): - if stride == 1: - return inputs * 0 - elif stride == 2: - return fluid.layers.pool2d(inputs, filter_size=2, pool_stride=2, pool_padding=0, pool_type='avg') * 0 - else: - raise ValueError('invalid stride of {:} not [1, 2]'.format(stride)) - - -def Identity(inputs, C_in, C_out, stride): - if C_in == C_out and stride == 1: - return inputs - elif stride == 1: - return ReLUConvBN(inputs, C_in, C_out, 1, 1, 0) - else: - temp1 = fluid.layers.relu(inputs) - temp2 = fluid.layers.pad2d(input=temp1, paddings=[0, 1, 0, 1], mode='reflect') - temp2 = fluid.layers.slice(temp2, axes=[0, 1, 2, 3], starts=[0, 0, 1, 1], ends=[999, 999, 999, 999]) - temp1 = fluid.layers.conv2d(temp1, filter_size=1, num_filters=C_out//2, stride=stride, padding=0, act=None, bias_attr=False) - temp2 = fluid.layers.conv2d(temp2, filter_size=1, num_filters=C_out-C_out//2, stride=stride, padding=0, act=None, bias_attr=False) - temp = fluid.layers.concat([temp1,temp2], axis=1) - return fluid.layers.batch_norm(input=temp, act=None, bias_attr=None) - - -def POOL_3x3(inputs, C_in, C_out, stride, mode): - if C_in == C_out: - xinputs = inputs - else: - xinputs = ReLUConvBN(inputs, C_in, C_out, 1, 1, 0) - return fluid.layers.pool2d(xinputs, pool_size=3, pool_stride=stride, pool_padding=1, pool_type=mode) - - -def SepConv(inputs, C_in, C_out, kernel, stride, padding): - temp = fluid.layers.relu(inputs) - temp = fluid.layers.conv2d(temp, filter_size=kernel, num_filters=C_in , stride=stride, padding=padding, act=None, bias_attr=False) - temp = fluid.layers.conv2d(temp, filter_size= 1, num_filters=C_in , stride= 1, padding= 0, act=None, bias_attr=False) - temp = fluid.layers.batch_norm(input=temp, act='relu', bias_attr=None) - temp = fluid.layers.conv2d(temp, filter_size=kernel, num_filters=C_in , stride= 1, padding=padding, act=None, bias_attr=False) - temp = fluid.layers.conv2d(temp, filter_size= 1, num_filters=C_out, stride= 1, padding= 0, act=None, bias_attr=False) - temp = fluid.layers.batch_norm(input=temp, act=None , bias_attr=None) - return temp - - -def DilConv(inputs, C_in, C_out, kernel, stride, padding, dilation): - temp = fluid.layers.relu(inputs) - temp = fluid.layers.conv2d(temp, filter_size=kernel, num_filters=C_in , stride=stride, padding=padding, dilation=dilation, act=None, bias_attr=False) - temp = fluid.layers.conv2d(temp, filter_size= 1, num_filters=C_out, stride= 1, padding= 0, act=None, bias_attr=False) - temp = fluid.layers.batch_norm(input=temp, act=None, bias_attr=None) - return temp - - -def Conv313(inputs, C_in, C_out, stride): - temp = fluid.layers.relu(inputs) - temp = fluid.layers.conv2d(temp, filter_size=(1,3), num_filters=C_out, stride=(1,stride), padding=(0,1), act=None, bias_attr=False) - temp = fluid.layers.conv2d(temp, filter_size=(3,1), num_filters=C_out, stride=(stride,1), padding=(1,0), act=None, bias_attr=False) - temp = fluid.layers.batch_norm(input=temp, act=None, bias_attr=None) - return temp - - -def Conv717(inputs, C_in, C_out, stride): - temp = fluid.layers.relu(inputs) - temp = fluid.layers.conv2d(temp, filter_size=(1,7), num_filters=C_out, stride=(1,stride), padding=(0,3), act=None, bias_attr=False) - temp = fluid.layers.conv2d(temp, filter_size=(7,1), num_filters=C_out, stride=(stride,1), padding=(3,0), act=None, bias_attr=False) - temp = fluid.layers.batch_norm(input=temp, act=None, bias_attr=None) - return temp diff --git a/others/paddlepaddle/lib/models/resnet.py b/others/paddlepaddle/lib/models/resnet.py deleted file mode 100644 index 5c15fab..0000000 --- a/others/paddlepaddle/lib/models/resnet.py +++ /dev/null @@ -1,65 +0,0 @@ -import paddle -import paddle.fluid as fluid - - -def conv_bn_layer(input, - ch_out, - filter_size, - stride, - padding, - act='relu', - bias_attr=False): - tmp = fluid.layers.conv2d( - input=input, - filter_size=filter_size, - num_filters=ch_out, - stride=stride, - padding=padding, - act=None, - bias_attr=bias_attr) - return fluid.layers.batch_norm(input=tmp, act=act) - - -def shortcut(input, ch_in, ch_out, stride): - if stride == 2: - temp = fluid.layers.pool2d(input, pool_size=2, pool_type='avg', pool_stride=2) - temp = fluid.layers.conv2d(temp , filter_size=1, num_filters=ch_out, stride=1, padding=0, act=None, bias_attr=None) - return temp - elif ch_in != ch_out: - return conv_bn_layer(input, ch_out, 1, stride, 0, None, None) - else: - return input - - -def basicblock(input, ch_in, ch_out, stride): - tmp = conv_bn_layer(input, ch_out, 3, stride, 1) - tmp = conv_bn_layer(tmp, ch_out, 3, 1, 1, act=None, bias_attr=True) - short = shortcut(input, ch_in, ch_out, stride) - return fluid.layers.elementwise_add(x=tmp, y=short, act='relu') - - -def layer_warp(block_func, input, ch_in, ch_out, count, stride): - tmp = block_func(input, ch_in, ch_out, stride) - for i in range(1, count): - tmp = block_func(tmp, ch_out, ch_out, 1) - return tmp - - -def resnet_cifar(ipt, depth, class_num): - # depth should be one of 20, 32, 44, 56, 110, 1202 - assert (depth - 2) % 6 == 0 - n = (depth - 2) // 6 - print('[resnet] depth : {:}, class_num : {:}'.format(depth, class_num)) - conv1 = conv_bn_layer(ipt, ch_out=16, filter_size=3, stride=1, padding=1) - print('conv-1 : shape = {:}'.format(conv1.shape)) - res1 = layer_warp(basicblock, conv1, 16, 16, n, 1) - print('res--1 : shape = {:}'.format(res1.shape)) - res2 = layer_warp(basicblock, res1 , 16, 32, n, 2) - print('res--2 : shape = {:}'.format(res2.shape)) - res3 = layer_warp(basicblock, res2 , 32, 64, n, 2) - print('res--3 : shape = {:}'.format(res3.shape)) - pool = fluid.layers.pool2d(input=res3, pool_size=8, pool_type='avg', pool_stride=1) - print('pool : shape = {:}'.format(pool.shape)) - predict = fluid.layers.fc(input=pool, size=class_num, act='softmax') - print('predict: shape = {:}'.format(predict.shape)) - return predict diff --git a/others/paddlepaddle/lib/utils/__init__.py b/others/paddlepaddle/lib/utils/__init__.py deleted file mode 100644 index 2c02373..0000000 --- a/others/paddlepaddle/lib/utils/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -################################################## -# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 # -################################################## -from .meter import AverageMeter -from .time_utils import time_for_file, time_string, time_string_short, time_print, convert_size2str, convert_secs2time -from .data_utils import reader_creator diff --git a/others/paddlepaddle/lib/utils/data_utils.py b/others/paddlepaddle/lib/utils/data_utils.py deleted file mode 100644 index 305c0e7..0000000 --- a/others/paddlepaddle/lib/utils/data_utils.py +++ /dev/null @@ -1,64 +0,0 @@ -import random, tarfile -import numpy, six -from six.moves import cPickle as pickle -from PIL import Image, ImageOps - - -def train_cifar_augmentation(image): - # flip - if random.random() < 0.5: image1 = image.transpose(Image.FLIP_LEFT_RIGHT) - else: image1 = image - # random crop - image2 = ImageOps.expand(image1, border=4, fill=0) - i = random.randint(0, 40 - 32) - j = random.randint(0, 40 - 32) - image3 = image2.crop((j,i,j+32,i+32)) - # to numpy - image3 = numpy.array(image3) / 255.0 - mean = numpy.array([x / 255 for x in [125.3, 123.0, 113.9]]).reshape(1, 1, 3) - std = numpy.array([x / 255 for x in [63.0, 62.1, 66.7]]).reshape(1, 1, 3) - return (image3 - mean) / std - - -def valid_cifar_augmentation(image): - image3 = numpy.array(image) / 255.0 - mean = numpy.array([x / 255 for x in [125.3, 123.0, 113.9]]).reshape(1, 1, 3) - std = numpy.array([x / 255 for x in [63.0, 62.1, 66.7]]).reshape(1, 1, 3) - return (image3 - mean) / std - - -def reader_creator(filename, sub_name, is_train, cycle=False): - def read_batch(batch): - data = batch[six.b('data')] - labels = batch.get( - six.b('labels'), batch.get(six.b('fine_labels'), None)) - assert labels is not None - for sample, label in six.moves.zip(data, labels): - sample = sample.reshape(3, 32, 32) - sample = sample.transpose((1, 2, 0)) - image = Image.fromarray(sample) - if is_train: - ximage = train_cifar_augmentation(image) - else: - ximage = valid_cifar_augmentation(image) - ximage = ximage.transpose((2, 0, 1)) - yield ximage.astype(numpy.float32), int(label) - - def reader(): - with tarfile.open(filename, mode='r') as f: - names = (each_item.name for each_item in f - if sub_name in each_item.name) - - while True: - for name in names: - if six.PY2: - batch = pickle.load(f.extractfile(name)) - else: - batch = pickle.load( - f.extractfile(name), encoding='bytes') - for item in read_batch(batch): - yield item - if not cycle: - break - - return reader diff --git a/others/paddlepaddle/lib/utils/meter.py b/others/paddlepaddle/lib/utils/meter.py deleted file mode 100644 index 1e3d02d..0000000 --- a/others/paddlepaddle/lib/utils/meter.py +++ /dev/null @@ -1,26 +0,0 @@ -################################################## -# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 # -################################################## -import time, sys -import numpy as np - - -class AverageMeter(object): - """Computes and stores the average and current value""" - def __init__(self): - self.reset() - - def reset(self): - self.val = 0.0 - self.avg = 0.0 - self.sum = 0.0 - self.count = 0.0 - - def update(self, val, n=1): - self.val = val - self.sum += val * n - self.count += n - self.avg = self.sum / self.count - - def __repr__(self): - return ('{name}(val={val}, avg={avg}, count={count})'.format(name=self.__class__.__name__, **self.__dict__)) diff --git a/others/paddlepaddle/lib/utils/time_utils.py b/others/paddlepaddle/lib/utils/time_utils.py deleted file mode 100644 index 7886fcc..0000000 --- a/others/paddlepaddle/lib/utils/time_utils.py +++ /dev/null @@ -1,52 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. -# -import time, sys -import numpy as np - -def time_for_file(): - ISOTIMEFORMAT='%d-%h-at-%H-%M-%S' - return '{}'.format(time.strftime( ISOTIMEFORMAT, time.gmtime(time.time()) )) - -def time_string(): - ISOTIMEFORMAT='%Y-%m-%d %X' - string = '[{}]'.format(time.strftime( ISOTIMEFORMAT, time.gmtime(time.time()) )) - return string - -def time_string_short(): - ISOTIMEFORMAT='%Y%m%d' - string = '{}'.format(time.strftime( ISOTIMEFORMAT, time.gmtime(time.time()) )) - return string - -def time_print(string, is_print=True): - if (is_print): - print('{} : {}'.format(time_string(), string)) - -def convert_size2str(torch_size): - dims = len(torch_size) - string = '[' - for idim in range(dims): - string = string + ' {}'.format(torch_size[idim]) - return string + ']' - -def convert_secs2time(epoch_time, return_str=False): - need_hour = int(epoch_time / 3600) - need_mins = int((epoch_time - 3600*need_hour) / 60) - need_secs = int(epoch_time - 3600*need_hour - 60*need_mins) - if return_str: - str = '[{:02d}:{:02d}:{:02d}]'.format(need_hour, need_mins, need_secs) - return str - else: - return need_hour, need_mins, need_secs - -def print_log(print_string, log): - #if isinstance(log, Logger): log.log('{:}'.format(print_string)) - if hasattr(log, 'log'): log.log('{:}'.format(print_string)) - else: - print("{:}".format(print_string)) - if log is not None: - log.write('{:}\n'.format(print_string)) - log.flush() diff --git a/others/paddlepaddle/scripts/base-train.sh b/others/paddlepaddle/scripts/base-train.sh deleted file mode 100644 index f4eed75..0000000 --- a/others/paddlepaddle/scripts/base-train.sh +++ /dev/null @@ -1,31 +0,0 @@ -#!/bin/bash -# bash ./scripts/base-train.sh 0 cifar-10 ResNet110 -echo script name: $0 -echo $# arguments -if [ "$#" -ne 3 ] ;then - echo "Input illegal number of parameters " $# - echo "Need 3 parameters for GPU and dataset and the-model-name" - exit 1 -fi -if [ "$TORCH_HOME" = "" ]; then - echo "Must set TORCH_HOME envoriment variable for data dir saving" - exit 1 -else - echo "TORCH_HOME : $TORCH_HOME" -fi - -GPU=$1 -dataset=$2 -model=$3 - -save_dir=snapshots/${dataset}-${model} - -export FLAGS_fraction_of_gpu_memory_to_use="0.005" -export FLAGS_free_idle_memory=True - -CUDA_VISIBLE_DEVICES=${GPU} python train_cifar.py \ - --data_path $TORCH_HOME/cifar.python/${dataset}-python.tar.gz \ - --log_dir ${save_dir} \ - --dataset ${dataset} \ - --model_name ${model} \ - --lr 0.1 --epochs 300 --batch_size 256 --step_each_epoch 196 diff --git a/others/paddlepaddle/scripts/train-nas.sh b/others/paddlepaddle/scripts/train-nas.sh deleted file mode 100644 index e2bdde1..0000000 --- a/others/paddlepaddle/scripts/train-nas.sh +++ /dev/null @@ -1,31 +0,0 @@ -#!/bin/bash -# bash ./scripts/base-train.sh 0 cifar-10 ResNet110 -echo script name: $0 -echo $# arguments -if [ "$#" -ne 3 ] ;then - echo "Input illegal number of parameters " $# - echo "Need 3 parameters for GPU and dataset and the-model-name" - exit 1 -fi -if [ "$TORCH_HOME" = "" ]; then - echo "Must set TORCH_HOME envoriment variable for data dir saving" - exit 1 -else - echo "TORCH_HOME : $TORCH_HOME" -fi - -GPU=$1 -dataset=$2 -model=$3 - -save_dir=snapshots/${dataset}-${model} - -export FLAGS_fraction_of_gpu_memory_to_use="0.005" -export FLAGS_free_idle_memory=True - -CUDA_VISIBLE_DEVICES=${GPU} python train_cifar.py \ - --data_path $TORCH_HOME/cifar.python/${dataset}-python.tar.gz \ - --log_dir ${save_dir} \ - --dataset ${dataset} \ - --model_name ${model} \ - --lr 0.025 --epochs 600 --batch_size 96 --step_each_epoch 521 diff --git a/others/paddlepaddle/train_cifar.py b/others/paddlepaddle/train_cifar.py deleted file mode 100644 index b501380..0000000 --- a/others/paddlepaddle/train_cifar.py +++ /dev/null @@ -1,189 +0,0 @@ -import os, sys, numpy as np, argparse -from pathlib import Path -import paddle.fluid as fluid -import math, time, paddle -import paddle.fluid.layers.ops as ops -#from tb_paddle import SummaryWriter - -lib_dir = (Path(__file__).parent / 'lib').resolve() -if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) -from models import resnet_cifar, NASCifarNet, Networks -from utils import AverageMeter, time_for_file, time_string, convert_secs2time -from utils import reader_creator - - -def inference_program(model_name, num_class): - # The image is 32 * 32 with RGB representation. - data_shape = [3, 32, 32] - images = fluid.layers.data(name='pixel', shape=data_shape, dtype='float32') - - if model_name == 'ResNet20': - predict = resnet_cifar(images, 20, num_class) - elif model_name == 'ResNet32': - predict = resnet_cifar(images, 32, num_class) - elif model_name == 'ResNet110': - predict = resnet_cifar(images, 110, num_class) - else: - predict = NASCifarNet(images, 36, 6, 3, num_class, Networks[model_name], True) - return predict - - -def train_program(predict): - label = fluid.layers.data(name='label', shape=[1], dtype='int64') - if isinstance(predict, (list, tuple)): - predict, aux_predict = predict - x_losses = fluid.layers.cross_entropy(input=predict, label=label) - aux_losses = fluid.layers.cross_entropy(input=aux_predict, label=label) - x_loss = fluid.layers.mean(x_losses) - aux_loss = fluid.layers.mean(aux_losses) - loss = x_loss + aux_loss * 0.4 - accuracy = fluid.layers.accuracy(input=predict, label=label) - else: - losses = fluid.layers.cross_entropy(input=predict, label=label) - loss = fluid.layers.mean(losses) - accuracy = fluid.layers.accuracy(input=predict, label=label) - return [loss, accuracy] - - -# For training test cost -def evaluation(program, reader, fetch_list, place): - feed_var_list = [program.global_block().var('pixel'), program.global_block().var('label')] - feeder_test = fluid.DataFeeder(feed_list=feed_var_list, place=place) - test_exe = fluid.Executor(place) - losses, accuracies = AverageMeter(), AverageMeter() - for tid, test_data in enumerate(reader()): - loss, acc = test_exe.run(program=program, feed=feeder_test.feed(test_data), fetch_list=fetch_list) - losses.update(float(loss), len(test_data)) - accuracies.update(float(acc)*100, len(test_data)) - return losses.avg, accuracies.avg - - -def cosine_decay_with_warmup(learning_rate, step_each_epoch, epochs=120): - """Applies cosine decay to the learning rate. - lr = 0.05 * (math.cos(epoch * (math.pi / 120)) + 1) - decrease lr for every mini-batch and start with warmup. - """ - from paddle.fluid.layers.learning_rate_scheduler import _decay_step_counter - from paddle.fluid.initializer import init_on_cpu - global_step = _decay_step_counter() - lr = fluid.layers.tensor.create_global_var( - shape=[1], - value=0.0, - dtype='float32', - persistable=True, - name="learning_rate") - - warmup_epoch = fluid.layers.fill_constant( - shape=[1], dtype='float32', value=float(5), force_cpu=True) - - with init_on_cpu(): - epoch = ops.floor(global_step / step_each_epoch) - with fluid.layers.control_flow.Switch() as switch: - with switch.case(epoch < warmup_epoch): - decayed_lr = learning_rate * (global_step / (step_each_epoch * warmup_epoch)) - fluid.layers.tensor.assign(input=decayed_lr, output=lr) - with switch.default(): - decayed_lr = learning_rate * \ - (ops.cos((global_step - warmup_epoch * step_each_epoch) * (math.pi / (epochs * step_each_epoch))) + 1)/2 - fluid.layers.tensor.assign(input=decayed_lr, output=lr) - return lr - - -def main(xargs): - - save_dir = Path(xargs.log_dir) / time_for_file() - save_dir.mkdir(parents=True, exist_ok=True) - - print ('save dir : {:}'.format(save_dir)) - print ('xargs : {:}'.format(xargs)) - - if xargs.dataset == 'cifar-10': - train_data = reader_creator(xargs.data_path, 'data_batch', True , False) - test__data = reader_creator(xargs.data_path, 'test_batch', False, False) - class_num = 10 - print ('create cifar-10 dataset') - elif xargs.dataset == 'cifar-100': - train_data = reader_creator(xargs.data_path, 'train', True , False) - test__data = reader_creator(xargs.data_path, 'test' , False, False) - class_num = 100 - print ('create cifar-100 dataset') - else: - raise ValueError('invalid dataset : {:}'.format(xargs.dataset)) - - train_reader = paddle.batch( - paddle.reader.shuffle(train_data, buf_size=5000), - batch_size=xargs.batch_size) - - # Reader for testing. A separated data set for testing. - test_reader = paddle.batch(test__data, batch_size=xargs.batch_size) - - place = fluid.CUDAPlace(0) - - main_program = fluid.default_main_program() - star_program = fluid.default_startup_program() - - # programs - predict = inference_program(xargs.model_name, class_num) - [loss, accuracy] = train_program(predict) - print ('training program setup done') - test_program = main_program.clone(for_test=True) - print ('testing program setup done') - - #infer_writer = SummaryWriter( str(save_dir / 'infer') ) - #infer_writer.add_paddle_graph(fluid_program=fluid.default_main_program(), verbose=True) - #infer_writer.close() - #print(test_program.to_string(True)) - - #learning_rate = fluid.layers.cosine_decay(learning_rate=xargs.lr, step_each_epoch=xargs.step_each_epoch, epochs=xargs.epochs) - #learning_rate = fluid.layers.cosine_decay(learning_rate=0.1, step_each_epoch=196, epochs=300) - learning_rate = cosine_decay_with_warmup(xargs.lr, xargs.step_each_epoch, xargs.epochs) - optimizer = fluid.optimizer.Momentum( - learning_rate=learning_rate, - momentum=0.9, - regularization=fluid.regularizer.L2Decay(0.0005), - use_nesterov=True) - optimizer.minimize( loss ) - - exe = fluid.Executor(place) - - feed_var_list_loop = [main_program.global_block().var('pixel'), main_program.global_block().var('label')] - feeder = fluid.DataFeeder(feed_list=feed_var_list_loop, place=place) - exe.run(star_program) - - start_time, epoch_time = time.time(), AverageMeter() - for iepoch in range(xargs.epochs): - losses, accuracies, steps = AverageMeter(), AverageMeter(), 0 - for step_id, train_data in enumerate(train_reader()): - tloss, tacc, xlr = exe.run(main_program, feed=feeder.feed(train_data), fetch_list=[loss, accuracy, learning_rate]) - tloss, tacc, xlr = float(tloss), float(tacc) * 100, float(xlr) - steps += 1 - losses.update(tloss, len(train_data)) - accuracies.update(tacc, len(train_data)) - if step_id % 100 == 0: - print('{:} [{:03d}/{:03d}] [{:03d}] lr = {:.7f}, loss = {:.4f} ({:.4f}), accuracy = {:.2f} ({:.2f}), error={:.2f}'.format(time_string(), iepoch, xargs.epochs, step_id, xlr, tloss, losses.avg, tacc, accuracies.avg, 100-accuracies.avg)) - test_loss, test_acc = evaluation(test_program, test_reader, [loss, accuracy], place) - need_time = 'Time Left: {:}'.format( convert_secs2time(epoch_time.avg * (xargs.epochs-iepoch), True) ) - print('{:}x[{:03d}/{:03d}] {:} train-loss = {:.4f}, train-accuracy = {:.2f}, test-loss = {:.4f}, test-accuracy = {:.2f} test-error = {:.2f} [{:} steps per epoch]\n'.format(time_string(), iepoch, xargs.epochs, need_time, losses.avg, accuracies.avg, test_loss, test_acc, 100-test_acc, steps)) - if isinstance(predict, list): - fluid.io.save_inference_model(str(save_dir / 'inference_model'), ["pixel"], predict, exe) - else: - fluid.io.save_inference_model(str(save_dir / 'inference_model'), ["pixel"], [predict], exe) - # measure elapsed time - epoch_time.update(time.time() - start_time) - start_time = time.time() - - print('finish training and evaluation with {:} epochs in {:}'.format(xargs.epochs, convert_secs2time(epoch_time.sum, True))) - - -if __name__ == '__main__': - parser = argparse.ArgumentParser(description='Train.', formatter_class=argparse.ArgumentDefaultsHelpFormatter) - parser.add_argument('--log_dir' , type=str, help='Save dir.') - parser.add_argument('--dataset', type=str, help='The dataset name.') - parser.add_argument('--data_path', type=str, help='The dataset path.') - parser.add_argument('--model_name', type=str, help='The model name.') - parser.add_argument('--lr', type=float, help='The learning rate.') - parser.add_argument('--batch_size', type=int, help='The batch size.') - parser.add_argument('--step_each_epoch',type=int, help='The batch size.') - parser.add_argument('--epochs' , type=int, help='The total training epochs.') - args = parser.parse_args() - main(args) diff --git a/scripts-search/algos/DARTS-V1.sh b/scripts-search/algos/DARTS-V1.sh index 68e3080..2104bda 100644 --- a/scripts-search/algos/DARTS-V1.sh +++ b/scripts-search/algos/DARTS-V1.sh @@ -19,6 +19,7 @@ seed=$2 channel=16 num_cells=5 max_nodes=4 +space=nas-bench-102 if [ "$dataset" == "cifar10" ] || [ "$dataset" == "cifar100" ]; then data_path="$TORCH_HOME/cifar.python" @@ -26,11 +27,12 @@ else data_path="$TORCH_HOME/cifar.python/ImageNet16" fi -save_dir=./output/cell-search-tiny/DARTS-V1-${dataset} +save_dir=./output/search-cell-${space}/DARTS-V1-${dataset} OMP_NUM_THREADS=4 python ./exps/algos/DARTS-V1.py \ --save_dir ${save_dir} --max_nodes ${max_nodes} --channel ${channel} --num_cells ${num_cells} \ --dataset ${dataset} --data_path ${data_path} \ - --search_space_name aa-nas \ + --search_space_name ${space} \ + --arch_nas_dataset ${TORCH_HOME}/NAS-Bench-102-v1_0-e61699.pth \ --arch_learning_rate 0.0003 --arch_weight_decay 0.001 \ --workers 4 --print_freq 200 --rand_seed ${seed} diff --git a/scripts-search/algos/R-EA.sh b/scripts-search/algos/R-EA.sh index 078d0c7..83c6f2b 100644 --- a/scripts-search/algos/R-EA.sh +++ b/scripts-search/algos/R-EA.sh @@ -32,7 +32,7 @@ save_dir=./output/cell-search-tiny/R-EA-${dataset} OMP_NUM_THREADS=4 python ./exps/algos/R_EA.py \ --save_dir ${save_dir} --max_nodes ${max_nodes} --channel ${channel} --num_cells ${num_cells} \ --dataset ${dataset} --data_path ${data_path} \ - --search_space_name aa-nas \ - --arch_nas_dataset ./output/AA-NAS-BENCH-4/simplifies/C16-N5-final-infos.pth \ + --search_space_name nas-bench-102 \ + --arch_nas_dataset ${TORCH_HOME}/NAS-Bench-102-v1_0-e61699.pth \ --ea_cycles 30 --ea_population 10 --ea_sample_size 3 --ea_fast_by_api 1 \ --workers 4 --print_freq 200 --rand_seed ${seed}