Update the query_by_arch function in API to be compatiable with the submission version of NAS-Bench-201
This commit is contained in:
		| @@ -3,6 +3,8 @@ | ||||
| ################################################################### | ||||
| # BOHB: Robust and Efficient Hyperparameter Optimization at Scale # | ||||
| # required to install hpbandster ################################## | ||||
| # pip install hpbandster         ################################## | ||||
| ################################################################### | ||||
| # bash ./scripts-search/algos/BOHB.sh -1         ################## | ||||
| ################################################################### | ||||
| import os, sys, time, random, argparse | ||||
| @@ -178,7 +180,7 @@ def main(xargs, nas_bench): | ||||
|   logger.log('Best found configuration: {:} within {:.3f} s'.format(id2config[incumbent]['config'], real_cost_time)) | ||||
|   best_arch = config2structure( id2config[incumbent]['config'] ) | ||||
|  | ||||
|   info = nas_bench.query_by_arch( best_arch ) | ||||
|   info = nas_bench.query_by_arch(best_arch, '200') | ||||
|   if info is None: logger.log('Did not find this architecture : {:}.'.format(best_arch)) | ||||
|   else           : logger.log('{:}'.format(info)) | ||||
|   logger.log('-'*100) | ||||
|   | ||||
| @@ -199,14 +199,14 @@ def main(xargs): | ||||
|     with torch.no_grad(): | ||||
|       #logger.log('arch-parameters :\n{:}'.format( nn.functional.softmax(search_model.arch_parameters, dim=-1).cpu() )) | ||||
|       logger.log('{:}'.format(search_model.show_alphas())) | ||||
|     if api is not None: logger.log('{:}'.format(api.query_by_arch( genotypes[epoch] ))) | ||||
|     if api is not None: logger.log('{:}'.format(api.query_by_arch(genotypes[epoch], '200'))) | ||||
|     # measure elapsed time | ||||
|     epoch_time.update(time.time() - start_time) | ||||
|     start_time = time.time() | ||||
|  | ||||
|   logger.log('\n' + '-'*100) | ||||
|   logger.log('DARTS-V1 : run {:} epochs, cost {:.1f} s, last-geno is {:}.'.format(total_epoch, search_time.sum, genotypes[total_epoch-1])) | ||||
|   if api is not None: logger.log('{:}'.format( api.query_by_arch(genotypes[total_epoch-1]) )) | ||||
|   if api is not None: logger.log('{:}'.format(api.query_by_arch(genotypes[total_epoch-1], '200'))) | ||||
|   logger.close() | ||||
|    | ||||
|  | ||||
|   | ||||
| @@ -260,7 +260,7 @@ def main(xargs): | ||||
|       copy_checkpoint(model_base_path, model_best_path, logger) | ||||
|     with torch.no_grad(): | ||||
|       logger.log('arch-parameters :\n{:}'.format( nn.functional.softmax(search_model.arch_parameters, dim=-1).cpu() )) | ||||
|     if api is not None: logger.log('{:}'.format(api.query_by_arch( genotypes[epoch] ))) | ||||
|     if api is not None: logger.log('{:}'.format(api.query_by_arch(genotypes[epoch], '200'))) | ||||
|     # measure elapsed time | ||||
|     epoch_time.update(time.time() - start_time) | ||||
|     start_time = time.time() | ||||
| @@ -268,7 +268,7 @@ def main(xargs): | ||||
|   logger.log('\n' + '-'*100) | ||||
|   # check the performance from the architecture dataset | ||||
|   logger.log('DARTS-V2 : run {:} epochs, cost {:.1f} s, last-geno is {:}.'.format(total_epoch, search_time.sum, genotypes[total_epoch-1])) | ||||
|   if api is not None: logger.log('{:}'.format( api.query_by_arch(genotypes[total_epoch-1]) )) | ||||
|   if api is not None: logger.log('{:}'.format(api.query_by_arch(genotypes[total_epoch-1]), '200')) | ||||
|   logger.close() | ||||
|    | ||||
|  | ||||
|   | ||||
| @@ -295,7 +295,7 @@ def main(xargs): | ||||
|     if find_best: | ||||
|       logger.log('<<<--->>> The {:}-th epoch : find the highest validation accuracy : {:.2f}%.'.format(epoch_str, best_valid_acc)) | ||||
|       copy_checkpoint(model_base_path, model_best_path, logger) | ||||
|     if api is not None: logger.log('{:}'.format(api.query_by_arch( genotypes[epoch] ))) | ||||
|     if api is not None: logger.log('{:}'.format(api.query_by_arch(genotypes[epoch], '200'))) | ||||
|     # measure elapsed time | ||||
|     epoch_time.update(time.time() - start_time) | ||||
|     start_time = time.time() | ||||
|   | ||||
| @@ -176,7 +176,7 @@ def main(xargs): | ||||
|       copy_checkpoint(model_base_path, model_best_path, logger) | ||||
|     with torch.no_grad(): | ||||
|       logger.log('{:}'.format(search_model.show_alphas())) | ||||
|     if api is not None: logger.log('{:}'.format(api.query_by_arch( genotypes[epoch] ))) | ||||
|     if api is not None: logger.log('{:}'.format(api.query_by_arch(genotypes[epoch], '200'))) | ||||
|     # measure elapsed time | ||||
|     epoch_time.update(time.time() - start_time) | ||||
|     start_time = time.time() | ||||
| @@ -184,7 +184,7 @@ def main(xargs): | ||||
|   logger.log('\n' + '-'*100) | ||||
|   # check the performance from the architecture dataset | ||||
|   logger.log('GDAS : run {:} epochs, cost {:.1f} s, last-geno is {:}.'.format(total_epoch, search_time.sum, genotypes[total_epoch-1])) | ||||
|   if api is not None: logger.log('{:}'.format( api.query_by_arch(genotypes[total_epoch-1]) )) | ||||
|   if api is not None: logger.log('{:}'.format(api.query_by_arch(genotypes[total_epoch-1], '200'))) | ||||
|   logger.close() | ||||
|    | ||||
|  | ||||
|   | ||||
| @@ -199,7 +199,7 @@ def main(xargs): | ||||
|     if find_best: | ||||
|       logger.log('<<<--->>> The {:}-th epoch : find the highest validation accuracy : {:.2f}%.'.format(epoch_str, valid_a_top1)) | ||||
|       copy_checkpoint(model_base_path, model_best_path, logger) | ||||
|     if api is not None: logger.log('{:}'.format(api.query_by_arch( genotypes[epoch] ))) | ||||
|     if api is not None: logger.log('{:}'.format(api.query_by_arch(genotypes[epoch], '200'))) | ||||
|     # measure elapsed time | ||||
|     epoch_time.update(time.time() - start_time) | ||||
|     start_time = time.time() | ||||
| @@ -210,7 +210,7 @@ def main(xargs): | ||||
|   best_arch, best_acc = search_find_best(valid_loader, network, xargs.select_num) | ||||
|   search_time.update(time.time() - start_time) | ||||
|   logger.log('RANDOM-NAS finds the best one : {:} with accuracy={:.2f}%, with {:.1f} s.'.format(best_arch, best_acc, search_time.sum)) | ||||
|   if api is not None: logger.log('{:}'.format( api.query_by_arch(best_arch) )) | ||||
|   if api is not None: logger.log('{:}'.format(api.query_by_arch(best_arch, '200'))) | ||||
|   logger.close() | ||||
|  | ||||
|  | ||||
|   | ||||
| @@ -74,7 +74,7 @@ def main(xargs, nas_bench): | ||||
|     logger.log('[{:03d}] : {:} : accuracy = {:.2f}%'.format(len(history), arch, accuracy)) | ||||
|   logger.log('{:} best arch is {:}, accuracy = {:.2f}%, visit {:} archs with {:.1f} s (real-cost = {:.3f} s).'.format(time_string(), best_arch, best_acc, len(history), total_time_cost, time.time()-x_start_time)) | ||||
|    | ||||
|   info = nas_bench.query_by_arch( best_arch ) | ||||
|   info = nas_bench.query_by_arch(best_arch, '200') | ||||
|   if info is None: logger.log('Did not find this architecture : {:}.'.format(best_arch)) | ||||
|   else           : logger.log('{:}'.format(info)) | ||||
|   logger.log('-'*100) | ||||
|   | ||||
							
								
								
									
										5
									
								
								exps/algos/README.md
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										5
									
								
								exps/algos/README.md
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,5 @@ | ||||
| # NAS Algorithms evaluated in NAS-Bench-201 | ||||
|  | ||||
| The Python files in this folder are used to re-produce the results in our NAS-Bench-201 paper. | ||||
|  | ||||
| We will upgrade the codes to be more general and extendable. The new codes are at [coming soon]. | ||||
| @@ -53,7 +53,7 @@ def train_and_eval(arch, nas_bench, extra_info, dataname='cifar10-valid', use_01 | ||||
|     assert arch_index >= 0, 'can not find this arch : {:}'.format(arch) | ||||
|     xoinfo = nas_bench.get_more_info(arch_index, 'cifar10-valid', iepoch=None, hp='12') | ||||
|     xocost = nas_bench.get_cost_info(arch_index, 'cifar10-valid', hp='200') | ||||
|     info = nas_bench.get_more_info(arch_index, dataname, nepoch, hp='200', True) # use the validation accuracy after 25 training epochs, which is used in our ICLR submission (not the camera ready). | ||||
|     info = nas_bench.get_more_info(arch_index, dataname, nepoch, hp='200', is_random=True) # use the validation accuracy after 25 training epochs, which is used in our ICLR submission (not the camera ready). | ||||
|     cost = nas_bench.get_cost_info(arch_index, dataname, hp='200') | ||||
|     # The following codes are used to estimate the time cost. | ||||
|     # When we build NAS-Bench-201, architectures are trained on different machines and we can not use that time record. | ||||
| @@ -218,7 +218,7 @@ def main(xargs, nas_bench): | ||||
|   best_arch = best_arch.arch | ||||
|   logger.log('{:} best arch is {:}'.format(time_string(), best_arch)) | ||||
|    | ||||
|   info = nas_bench.query_by_arch( best_arch ) | ||||
|   info = nas_bench.query_by_arch(best_arch, '200') | ||||
|   if info is None: logger.log('Did not find this architecture : {:}.'.format(best_arch)) | ||||
|   else           : logger.log('{:}'.format(info)) | ||||
|   logger.log('-'*100) | ||||
|   | ||||
| @@ -235,7 +235,7 @@ def main(xargs): | ||||
|           }, logger.path('info'), logger) | ||||
|     with torch.no_grad(): | ||||
|       logger.log('{:}'.format(search_model.show_alphas())) | ||||
|     if api is not None: logger.log('{:}'.format(api.query_by_arch( genotypes[epoch] ))) | ||||
|     if api is not None: logger.log('{:}'.format(api.query_by_arch(genotypes[epoch], '200'))) | ||||
|     # measure elapsed time | ||||
|     epoch_time.update(time.time() - start_time) | ||||
|     start_time = time.time() | ||||
| @@ -251,7 +251,7 @@ def main(xargs): | ||||
|   logger.log('\n' + '-'*100) | ||||
|   # check the performance from the architecture dataset | ||||
|   logger.log('SETN : run {:} epochs, cost {:.1f} s, last-geno is {:}.'.format(total_epoch, search_time.sum, genotype)) | ||||
|   if api is not None: logger.log('{:}'.format( api.query_by_arch(genotype) )) | ||||
|   if api is not None: logger.log('{:}'.format(api.query_by_arch(genotype, '200') )) | ||||
|   logger.close() | ||||
|    | ||||
|  | ||||
|   | ||||
| @@ -174,7 +174,7 @@ def main(xargs, nas_bench): | ||||
|   # best_arch = policy.genotype() # first version | ||||
|   best_arch = max(trace, key=lambda x: x[0])[1] | ||||
|   logger.log('REINFORCE finish with {:} steps and {:.1f} s (real cost={:.3f}).'.format(total_steps, total_costs, time.time()-x_start_time)) | ||||
|   info = nas_bench.query_by_arch( best_arch ) | ||||
|   info = nas_bench.query_by_arch(best_arch, '200') | ||||
|   if info is None: logger.log('Did not find this architecture : {:}.'.format(best_arch)) | ||||
|   else           : logger.log('{:}'.format(info)) | ||||
|   logger.log('-'*100) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user