fix small bugs in DARTS-V1 for NASNet-Space
This commit is contained in:
		| @@ -4,17 +4,22 @@ DARTS: Differentiable Architecture Search is accepted by ICLR 2019. | |||||||
| In this paper, Hanxiao proposed a differentiable neural architecture search method, named as DARTS. | In this paper, Hanxiao proposed a differentiable neural architecture search method, named as DARTS. | ||||||
| Recently, DARTS becomes very popular due to its simplicity and performance. | Recently, DARTS becomes very popular due to its simplicity and performance. | ||||||
|  |  | ||||||
| **Run DARTS on the NAS-Bench-201 search space**: | ## Run DARTS on the NAS-Bench-201 search space | ||||||
| ``` | ``` | ||||||
| CUDA_VISIBLE_DEVICES=0 bash ./scripts-search/algos/DARTS-V2.sh cifar10 1 -1 | CUDA_VISIBLE_DEVICES=0 bash ./scripts-search/algos/DARTS-V2.sh cifar10 1 -1 | ||||||
| CUDA_VISIBLE_DEVICES=0 bash ./scripts-search/algos/GDAS.sh     cifar10 1 -1 | CUDA_VISIBLE_DEVICES=0 bash ./scripts-search/algos/GDAS.sh     cifar10 1 -1 | ||||||
| ``` | ``` | ||||||
|  |  | ||||||
| **Run the first-order DARTS on the NASNet search space**: | ## Run the first-order DARTS on the NASNet/DARTS search space | ||||||
|  | This command will start to use the first-order DARTS to search architectures on the DARTS search space. | ||||||
| ``` | ``` | ||||||
| CUDA_VISIBLE_DEVICES=0 bash ./scripts-search/DARTS1V-search-NASNet-space.sh cifar10 -1 | CUDA_VISIBLE_DEVICES=0 bash ./scripts-search/DARTS1V-search-NASNet-space.sh cifar10 -1 | ||||||
| ``` | ``` | ||||||
|  |  | ||||||
|  | After searching, if you want to train the searched architecture found by the above scripts, you need to add the config of that architecture (will be printed in log) in [genotypes.py](https://github.com/D-X-Y/AutoDL-Projects/blob/master/lib/nas_infer_model/DXYs/genotypes.py). | ||||||
|  | In future, I will add a more eligent way to train the searched architecture from the DARTS search space. | ||||||
|  |  | ||||||
|  |  | ||||||
| # Citation | # Citation | ||||||
|  |  | ||||||
| ``` | ``` | ||||||
|   | |||||||
| @@ -199,7 +199,8 @@ def main(xargs): | |||||||
|       logger.log('<<<--->>> The {:}-th epoch : find the highest validation accuracy : {:.2f}%.'.format(epoch_str, valid_a_top1)) |       logger.log('<<<--->>> The {:}-th epoch : find the highest validation accuracy : {:.2f}%.'.format(epoch_str, valid_a_top1)) | ||||||
|       copy_checkpoint(model_base_path, model_best_path, logger) |       copy_checkpoint(model_base_path, model_best_path, logger) | ||||||
|     with torch.no_grad(): |     with torch.no_grad(): | ||||||
|       logger.log('arch-parameters :\n{:}'.format( nn.functional.softmax(search_model.arch_parameters, dim=-1).cpu() )) |       #logger.log('arch-parameters :\n{:}'.format( nn.functional.softmax(search_model.arch_parameters, dim=-1).cpu() )) | ||||||
|  |       logger.log('{:}'.format(search_model.show_alphas())) | ||||||
|     if api is not None: logger.log('{:}'.format(api.query_by_arch( genotypes[epoch] ))) |     if api is not None: logger.log('{:}'.format(api.query_by_arch( genotypes[epoch] ))) | ||||||
|     # measure elapsed time |     # measure elapsed time | ||||||
|     epoch_time.update(time.time() - start_time) |     epoch_time.update(time.time() - start_time) | ||||||
|   | |||||||
| @@ -53,6 +53,10 @@ class TinyNetworkDarts(nn.Module): | |||||||
|   def get_alphas(self): |   def get_alphas(self): | ||||||
|     return [self.arch_parameters] |     return [self.arch_parameters] | ||||||
|  |  | ||||||
|  |   def show_alphas(self): | ||||||
|  |     with torch.no_grad(): | ||||||
|  |       return 'arch-parameters :\n{:}'.format( nn.functional.softmax(self.arch_parameters, dim=-1).cpu() ) | ||||||
|  |  | ||||||
|   def get_message(self): |   def get_message(self): | ||||||
|     string = self.extra_repr() |     string = self.extra_repr() | ||||||
|     for i, cell in enumerate(self.cells): |     for i, cell in enumerate(self.cells): | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user