simplify baselines
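
Let the NAS baselines run directly from the benchmark file. When --data_path is omitted, R-EA, REINFORCE, RANDOM, and BOHB now load only the R-EA optimization config and leave both data loaders as None, querying NAS-Bench-102 for accuracies instead of training; the four launch scripts drop --data_path accordingly. NASBench102API.get_more_info gains an is_random switch and test-set metrics beyond cifar10-valid, Structure.to_unique_str gains a consider_zero mode for isomorphism checks, and a new script exps/NAS-Bench-102/test-correlation.py measures how early-epoch validation accuracy correlates with final test accuracy.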

exps/NAS-Bench-102/test-correlation.py (new file, 193 lines)
@@ -0,0 +1,193 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
# python exps/NAS-Bench-102/test-correlation.py --api_path $HOME/.torch/NAS-Bench-102-v1_0-e61699.pth
##################################################
import os, sys, time, glob, random, argparse
import numpy as np
from copy import deepcopy
import torch
import torch.nn as nn
from pathlib import Path
lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve()
if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
from config_utils import load_config, dict2config, configure2str
from datasets     import get_datasets, SearchDataset
from procedures   import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint, get_optim_scheduler
from utils        import get_model_infos, obtain_accuracy
from log_utils    import AverageMeter, time_string, convert_secs2time
from models       import get_cell_based_tiny_net, get_search_spaces, CellStructure
from nas_102_api  import NASBench102API as API

def valid_func(xloader, network, criterion):
  data_time, batch_time = AverageMeter(), AverageMeter()
  arch_losses, arch_top1, arch_top5 = AverageMeter(), AverageMeter(), AverageMeter()
  network.eval()
  end = time.time()
  with torch.no_grad():
    for step, (arch_inputs, arch_targets) in enumerate(xloader):
      arch_targets = arch_targets.cuda(non_blocking=True)
      # measure data loading time
      data_time.update(time.time() - end)
      # prediction
      _, logits = network(arch_inputs)
      arch_loss = criterion(logits, arch_targets)
      # record
      arch_prec1, arch_prec5 = obtain_accuracy(logits.data, arch_targets.data, topk=(1, 5))
      arch_losses.update(arch_loss.item(),  arch_inputs.size(0))
      arch_top1.update  (arch_prec1.item(), arch_inputs.size(0))
      arch_top5.update  (arch_prec5.item(), arch_inputs.size(0))
      # measure elapsed time
      batch_time.update(time.time() - end)
      end = time.time()
  return arch_losses.avg, arch_top1.avg, arch_top5.avg

def main(xargs):
  assert torch.cuda.is_available(), 'CUDA is not available.'
  torch.backends.cudnn.enabled   = True
  torch.backends.cudnn.benchmark = False
  torch.backends.cudnn.deterministic = True
  torch.set_num_threads( xargs.workers )
  prepare_seed(xargs.rand_seed)
  logger = prepare_logger(args)

  train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1)
  if xargs.dataset == 'cifar10' or xargs.dataset == 'cifar100':
    split_Fpath = 'configs/nas-benchmark/cifar-split.txt'
    cifar_split = load_config(split_Fpath, None, None)
    train_split, valid_split = cifar_split.train, cifar_split.valid
    logger.log('Load split file from {:}'.format(split_Fpath))
  elif xargs.dataset.startswith('ImageNet16'):
    split_Fpath = 'configs/nas-benchmark/{:}-split.txt'.format(xargs.dataset)
    imagenet16_split = load_config(split_Fpath, None, None)
    train_split, valid_split = imagenet16_split.train, imagenet16_split.valid
    logger.log('Load split file from {:}'.format(split_Fpath))
  else:
    raise ValueError('invalid dataset : {:}'.format(xargs.dataset))
  config_path = 'configs/nas-benchmark/algos/DARTS.config'
  config = load_config(config_path, {'class_num': class_num, 'xshape': xshape}, logger)
  # To split data
  train_data_v2 = deepcopy(train_data)
  train_data_v2.transform = valid_data.transform
  valid_data    = train_data_v2
  search_data   = SearchDataset(xargs.dataset, train_data, train_split, valid_split)
  # data loader
  search_loader = torch.utils.data.DataLoader(search_data, batch_size=config.batch_size, shuffle=True , num_workers=xargs.workers, pin_memory=True)
  valid_loader  = torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(valid_split), num_workers=xargs.workers, pin_memory=True)
  logger.log('||||||| {:10s} ||||||| Search-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'.format(xargs.dataset, len(search_loader), len(valid_loader), config.batch_size))
  logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config))

  search_space = get_search_spaces('cell', xargs.search_space_name)
  model_config = dict2config({'name': 'DARTS-V2', 'C': xargs.channel, 'N': xargs.num_cells,
                              'max_nodes': xargs.max_nodes, 'num_classes': class_num,
                              'space'    : search_space}, None)
  search_model = get_cell_based_tiny_net(model_config)
  logger.log('search-model :\n{:}'.format(search_model))

  w_optimizer, w_scheduler, criterion = get_optim_scheduler(search_model.get_weights(), config)
  a_optimizer = torch.optim.Adam(search_model.get_alphas(), lr=xargs.arch_learning_rate, betas=(0.5, 0.999), weight_decay=xargs.arch_weight_decay)
  logger.log('w-optimizer : {:}'.format(w_optimizer))
  logger.log('a-optimizer : {:}'.format(a_optimizer))
  logger.log('w-scheduler : {:}'.format(w_scheduler))
  logger.log('criterion   : {:}'.format(criterion))
  flop, param  = get_model_infos(search_model, xshape)
  #logger.log('{:}'.format(search_model))
  logger.log('FLOP = {:.2f} M, Params = {:.2f} MB'.format(flop, param))
  if xargs.arch_nas_dataset is None:
    api = None
  else:
    api = API(xargs.arch_nas_dataset)
  logger.log('{:} create API = {:} done'.format(time_string(), api))

  last_info, model_base_path, model_best_path = logger.path('info'), logger.path('model'), logger.path('best')
  network, criterion = torch.nn.DataParallel(search_model).cuda(), criterion.cuda()

  logger.close()

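# NOTE: valid_func() and main() above are kept as scaffolding but are never
# invoked by the __main__ block at the bottom of this file; only
# check_cor_for_bandit() is actually called.
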
def check_unique_arch(meta_file):
  api = API(str(meta_file))
  arch_strs = deepcopy(api.meta_archs)
  xarchs = [CellStructure.str2structure(x) for x in arch_strs]
  def get_unique_matrix(archs, consider_zero):
    UniquStrs = [arch.to_unique_str(consider_zero) for arch in archs]
    print ('{:} create unique-string ({:}/{:}) done'.format(time_string(), len(set(UniquStrs)), len(UniquStrs)))
    Unique2Index = dict()
    for index, xstr in enumerate(UniquStrs):
      if xstr not in Unique2Index: Unique2Index[xstr] = list()
      Unique2Index[xstr].append( index )
    # sm_matrix[i, j] is True iff architectures i and j reduce to the same unique string
    sm_matrix = torch.eye(len(archs)).bool()
    for _, xlist in Unique2Index.items():
      for i in xlist:
        for j in xlist:
          sm_matrix[i,j] = True
    # assign one id per group of mutually isomorphic architectures
    unique_ids, unique_num = [-1 for _ in archs], 0
    for i in range(len(unique_ids)):
      if unique_ids[i] > -1: continue
      neighbours = sm_matrix[i].nonzero().view(-1).tolist()
      for nghb in neighbours:
        assert unique_ids[nghb] == -1, 'impossible'
        unique_ids[nghb] = unique_num
      unique_num += 1
    return sm_matrix, unique_ids, unique_num

  print ('There are {:} valid-archs'.format( sum(arch.check_valid() for arch in xarchs) ))
  sm_matrix, uniqueIDs, unique_num = get_unique_matrix(xarchs, None)
  print ('{:} There are {:} unique architectures (considering nothing).'.format(time_string(), unique_num))
  sm_matrix, uniqueIDs, unique_num = get_unique_matrix(xarchs, False)
  print ('{:} There are {:} unique architectures (not considering zero).'.format(time_string(), unique_num))
  sm_matrix, uniqueIDs, unique_num = get_unique_matrix(xarchs,  True)
  print ('{:} There are {:} unique architectures (considering zero).'.format(time_string(), unique_num))


def check_cor_for_bandit(meta_file, test_epoch, use_less_or_not, is_rand=True):
  if isinstance(meta_file, API):
    api = meta_file
  else:
    api = API(str(meta_file))
  cifar10_valid     = []
  cifar10_test      = []
  cifar100_test     = []
  imagenet_test     = []
  for idx, arch in enumerate(api):
    results = api.get_more_info(idx, 'cifar10-valid' , test_epoch-1, use_less_or_not, is_rand)
    cifar10_valid.append( results['valid-accuracy'] )
    results = api.get_more_info(idx, 'cifar10'       , None, False, is_rand)
    cifar10_test.append( results['test-accuracy'] )
    results = api.get_more_info(idx, 'cifar100'      , None, False, is_rand)
    cifar100_test.append( results['test-accuracy'] )
    results = api.get_more_info(idx, 'ImageNet16-120', None, False, is_rand)
    imagenet_test.append( results['test-accuracy'] )
  def get_cor(A, B):
    return float(np.corrcoef(A, B)[0,1])
  cors = []
  for basestr, xlist in zip(['CIFAR-010', 'CIFAR-100', 'ImageNet16'], [cifar10_test, cifar100_test, imagenet_test]):
    correlation = get_cor(cifar10_valid, xlist)
    print ('With {:3d}/{:}-epochs-training, the correlation between cifar10-valid and {:} is : {:}'.format(test_epoch, '012' if use_less_or_not else '200', basestr, correlation))
    cors.append( correlation )
    #print ('With {:3d}/200-epochs-training, the correlation between cifar10-valid and {:} is : {:}'.format(test_epoch, basestr, get_cor(cifar10_valid_200, xlist)))
    #print('-'*200)
  #print('*'*230)
  return cors


if __name__ == '__main__':
  parser = argparse.ArgumentParser("Analysis of NAS-Bench-102")
  parser.add_argument('--save_dir',  type=str, default='./output/search-cell-nas-bench-102/visuals', help='The base-name of folder to save checkpoints and log.')
  parser.add_argument('--api_path',  type=str, default=None,                                         help='The path to the NAS-Bench-102 benchmark file.')
  args = parser.parse_args()

  vis_save_dir = Path(args.save_dir)
  vis_save_dir.mkdir(parents=True, exist_ok=True)
  meta_file = Path(args.api_path)
  assert meta_file.exists(), 'invalid path for api : {:}'.format(meta_file)

  #check_unique_arch(meta_file)
  api = API(str(meta_file))
  #for iepoch in [11, 25, 50, 100, 150, 175, 200]:
  #  check_cor_for_bandit(api,  6, iepoch)
  #  check_cor_for_bandit(api, 12, iepoch)
  correlations = check_cor_for_bandit(api, 6, True, True)
  import pdb; pdb.set_trace()
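
The Pearson coefficient from np.corrcoef compares raw accuracies; for search, the ranking of architectures often matters more than the absolute values. A minimal sketch of a rank-correlation variant (not part of this commit; assumes scipy is installed and would sit next to get_cor above):

  from scipy import stats

  def get_rank_cor(A, B):
    # Kendall-tau rank correlation in [-1, 1]; invariant to monotone rescaling
    tau, _pvalue = stats.kendalltau(A, B)
    return float(tau)

  # e.g. get_rank_cor(cifar10_valid, cifar10_test)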
| @@ -370,17 +370,17 @@ def write_video(save_dir): | |||||||
| if __name__ == '__main__': | if __name__ == '__main__': | ||||||
|  |  | ||||||
|   parser = argparse.ArgumentParser(description='NAS-Bench-102', formatter_class=argparse.ArgumentDefaultsHelpFormatter) |   parser = argparse.ArgumentParser(description='NAS-Bench-102', formatter_class=argparse.ArgumentDefaultsHelpFormatter) | ||||||
|   parser.add_argument('--save_dir',  type=str, default='./output/search-cell-nas-bench-102/visual', help='The base-name of folder to save checkpoints and log.') |   parser.add_argument('--save_dir',  type=str, default='./output/search-cell-nas-bench-102/visuals', help='The base-name of folder to save checkpoints and log.') | ||||||
|   parser.add_argument('--api_path',  type=str, default=None,                                         help='The path to the NAS-Bench-102 benchmark file.') |   parser.add_argument('--api_path',  type=str, default=None,                                         help='The path to the NAS-Bench-102 benchmark file.') | ||||||
|   args = parser.parse_args() |   args = parser.parse_args() | ||||||
|    |    | ||||||
|   vis_save_dir = Path(args.save_dir) / 'visuals' |   vis_save_dir = Path(args.save_dir) | ||||||
|   vis_save_dir.mkdir(parents=True, exist_ok=True) |   vis_save_dir.mkdir(parents=True, exist_ok=True) | ||||||
|   meta_file = Path(args.api_path) |   meta_file = Path(args.api_path) | ||||||
|   assert meta_file.exists(), 'invalid path for api : {:}'.format(meta_file) |   assert meta_file.exists(), 'invalid path for api : {:}'.format(meta_file) | ||||||
|   visualize_rank_over_time(str(meta_file), vis_save_dir / 'over-time') |   #visualize_rank_over_time(str(meta_file), vis_save_dir / 'over-time') | ||||||
|   write_video(vis_save_dir / 'over-time') |   #write_video(vis_save_dir / 'over-time') | ||||||
|   visualize_info(str(meta_file), 'cifar10' , vis_save_dir) |   #visualize_info(str(meta_file), 'cifar10' , vis_save_dir) | ||||||
|   visualize_info(str(meta_file), 'cifar100', vis_save_dir) |   #visualize_info(str(meta_file), 'cifar100', vis_save_dir) | ||||||
|   visualize_info(str(meta_file), 'ImageNet16-120', vis_save_dir) |   #visualize_info(str(meta_file), 'ImageNet16-120', vis_save_dir) | ||||||
|   visualize_relative_ranking(vis_save_dir) |   visualize_relative_ranking(vis_save_dir) | ||||||
|   | |||||||
| @@ -110,6 +110,7 @@ def main(xargs, nas_bench): | |||||||
|   logger = prepare_logger(args) |   logger = prepare_logger(args) | ||||||
|  |  | ||||||
|   assert xargs.dataset == 'cifar10', 'currently only support CIFAR-10' |   assert xargs.dataset == 'cifar10', 'currently only support CIFAR-10' | ||||||
|  |   if xargs.data_path is not None: | ||||||
|     train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1) |     train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1) | ||||||
|     split_Fpath = 'configs/nas-benchmark/cifar-split.txt' |     split_Fpath = 'configs/nas-benchmark/cifar-split.txt' | ||||||
|     cifar_split = load_config(split_Fpath, None, None) |     cifar_split = load_config(split_Fpath, None, None) | ||||||
| @@ -128,7 +129,11 @@ def main(xargs, nas_bench): | |||||||
|     logger.log('||||||| {:10s} ||||||| Train-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'.format(xargs.dataset, len(train_loader), len(valid_loader), config.batch_size)) |     logger.log('||||||| {:10s} ||||||| Train-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'.format(xargs.dataset, len(train_loader), len(valid_loader), config.batch_size)) | ||||||
|     logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config)) |     logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config)) | ||||||
|     extra_info = {'config': config, 'train_loader': train_loader, 'valid_loader': valid_loader} |     extra_info = {'config': config, 'train_loader': train_loader, 'valid_loader': valid_loader} | ||||||
|    |   else: | ||||||
|  |     config_path = 'configs/nas-benchmark/algos/R-EA.config' | ||||||
|  |     config = load_config(config_path, None, logger) | ||||||
|  |     logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config)) | ||||||
|  |     extra_info = {'config': config, 'train_loader': None, 'valid_loader': None} | ||||||
|  |  | ||||||
|   # nas dataset load |   # nas dataset load | ||||||
|   assert xargs.arch_nas_dataset is not None and os.path.isfile(xargs.arch_nas_dataset) |   assert xargs.arch_nas_dataset is not None and os.path.isfile(xargs.arch_nas_dataset) | ||||||
|   | |||||||
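
The same "if xargs.data_path is not None:" guard is repeated in each baseline below: when no data path is given, only the R-EA optimization config is loaded and both loaders stay None, since these baselines can score a candidate architecture by querying NAS-Bench-102 instead of training it.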
| @@ -29,6 +29,7 @@ def main(xargs, nas_bench): | |||||||
|   logger = prepare_logger(args) |   logger = prepare_logger(args) | ||||||
|  |  | ||||||
|   assert xargs.dataset == 'cifar10', 'currently only support CIFAR-10' |   assert xargs.dataset == 'cifar10', 'currently only support CIFAR-10' | ||||||
|  |   if xargs.data_path is not None: | ||||||
|     train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1) |     train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1) | ||||||
|     split_Fpath = 'configs/nas-benchmark/cifar-split.txt' |     split_Fpath = 'configs/nas-benchmark/cifar-split.txt' | ||||||
|     cifar_split = load_config(split_Fpath, None, None) |     cifar_split = load_config(split_Fpath, None, None) | ||||||
| @@ -47,7 +48,11 @@ def main(xargs, nas_bench): | |||||||
|     logger.log('||||||| {:10s} ||||||| Train-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'.format(xargs.dataset, len(train_loader), len(valid_loader), config.batch_size)) |     logger.log('||||||| {:10s} ||||||| Train-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'.format(xargs.dataset, len(train_loader), len(valid_loader), config.batch_size)) | ||||||
|     logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config)) |     logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config)) | ||||||
|     extra_info = {'config': config, 'train_loader': train_loader, 'valid_loader': valid_loader} |     extra_info = {'config': config, 'train_loader': train_loader, 'valid_loader': valid_loader} | ||||||
|  |   else: | ||||||
|  |     config_path = 'configs/nas-benchmark/algos/R-EA.config' | ||||||
|  |     config = load_config(config_path, None, logger) | ||||||
|  |     logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config)) | ||||||
|  |     extra_info = {'config': config, 'train_loader': None, 'valid_loader': None} | ||||||
|   search_space = get_search_spaces('cell', xargs.search_space_name) |   search_space = get_search_spaces('cell', xargs.search_space_name) | ||||||
|   random_arch = random_architecture_func(xargs.max_nodes, search_space) |   random_arch = random_architecture_func(xargs.max_nodes, search_space) | ||||||
|   #x =random_arch() ; y = mutate_arch(x) |   #x =random_arch() ; y = mutate_arch(x) | ||||||
|   | |||||||
| @@ -172,6 +172,7 @@ def main(xargs, nas_bench): | |||||||
|   logger = prepare_logger(args) |   logger = prepare_logger(args) | ||||||
|  |  | ||||||
|   assert xargs.dataset == 'cifar10', 'currently only support CIFAR-10' |   assert xargs.dataset == 'cifar10', 'currently only support CIFAR-10' | ||||||
|  |   if xargs.data_path is not None: | ||||||
|     train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1) |     train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1) | ||||||
|     split_Fpath = 'configs/nas-benchmark/cifar-split.txt' |     split_Fpath = 'configs/nas-benchmark/cifar-split.txt' | ||||||
|     cifar_split = load_config(split_Fpath, None, None) |     cifar_split = load_config(split_Fpath, None, None) | ||||||
| @@ -190,6 +191,11 @@ def main(xargs, nas_bench): | |||||||
|     logger.log('||||||| {:10s} ||||||| Train-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'.format(xargs.dataset, len(train_loader), len(valid_loader), config.batch_size)) |     logger.log('||||||| {:10s} ||||||| Train-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'.format(xargs.dataset, len(train_loader), len(valid_loader), config.batch_size)) | ||||||
|     logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config)) |     logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config)) | ||||||
|     extra_info = {'config': config, 'train_loader': train_loader, 'valid_loader': valid_loader} |     extra_info = {'config': config, 'train_loader': train_loader, 'valid_loader': valid_loader} | ||||||
|  |   else: | ||||||
|  |     config_path = 'configs/nas-benchmark/algos/R-EA.config' | ||||||
|  |     config = load_config(config_path, None, logger) | ||||||
|  |     logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config)) | ||||||
|  |     extra_info = {'config': config, 'train_loader': None, 'valid_loader': None} | ||||||
|  |  | ||||||
|   search_space = get_search_spaces('cell', xargs.search_space_name) |   search_space = get_search_spaces('cell', xargs.search_space_name) | ||||||
|   random_arch = random_architecture_func(xargs.max_nodes, search_space) |   random_arch = random_architecture_func(xargs.max_nodes, search_space) | ||||||
|   | |||||||
| @@ -99,6 +99,7 @@ def main(xargs, nas_bench): | |||||||
|   logger = prepare_logger(args) |   logger = prepare_logger(args) | ||||||
|  |  | ||||||
|   assert xargs.dataset == 'cifar10', 'currently only support CIFAR-10' |   assert xargs.dataset == 'cifar10', 'currently only support CIFAR-10' | ||||||
|  |   if xargs.data_path is not None: | ||||||
|     train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1) |     train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1) | ||||||
|     split_Fpath = 'configs/nas-benchmark/cifar-split.txt' |     split_Fpath = 'configs/nas-benchmark/cifar-split.txt' | ||||||
|     cifar_split = load_config(split_Fpath, None, None) |     cifar_split = load_config(split_Fpath, None, None) | ||||||
| @@ -117,6 +118,12 @@ def main(xargs, nas_bench): | |||||||
|     logger.log('||||||| {:10s} ||||||| Train-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'.format(xargs.dataset, len(train_loader), len(valid_loader), config.batch_size)) |     logger.log('||||||| {:10s} ||||||| Train-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'.format(xargs.dataset, len(train_loader), len(valid_loader), config.batch_size)) | ||||||
|     logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config)) |     logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config)) | ||||||
|     extra_info = {'config': config, 'train_loader': train_loader, 'valid_loader': valid_loader} |     extra_info = {'config': config, 'train_loader': train_loader, 'valid_loader': valid_loader} | ||||||
|  |   else: | ||||||
|  |     config_path = 'configs/nas-benchmark/algos/R-EA.config' | ||||||
|  |     config = load_config(config_path, None, logger) | ||||||
|  |     extra_info = {'config': config, 'train_loader': None, 'valid_loader': None} | ||||||
|  |     logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config)) | ||||||
|  |    | ||||||
|    |    | ||||||
|   search_space = get_search_spaces('cell', xargs.search_space_name) |   search_space = get_search_spaces('cell', xargs.search_space_name) | ||||||
|   policy    = Policy(xargs.max_nodes, search_space) |   policy    = Policy(xargs.max_nodes, search_space) | ||||||
|   | |||||||
| @@ -74,13 +74,20 @@ class Structure: | |||||||
|       nodes[i+1] = sum(sums) > 0 |       nodes[i+1] = sum(sums) > 0 | ||||||
|     return nodes[len(self.nodes)] |     return nodes[len(self.nodes)] | ||||||
|  |  | ||||||
|   def to_unique_str(self): |   def to_unique_str(self, consider_zero=False): | ||||||
|     # this is used to identify the isomorphic cell, which rerquires the prior knowledge of operation |     # this is used to identify the isomorphic cell, which rerquires the prior knowledge of operation | ||||||
|     # two operations are special, i.e., none and skip_connect |     # two operations are special, i.e., none and skip_connect | ||||||
|     nodes = {0: '0'} |     nodes = {0: '0'} | ||||||
|     for i_node, node_info in enumerate(self.nodes): |     for i_node, node_info in enumerate(self.nodes): | ||||||
|       cur_node = [] |       cur_node = [] | ||||||
|       for op, xin in node_info: |       for op, xin in node_info: | ||||||
|  |         if consider_zero is None: | ||||||
|  |           x = '('+nodes[xin]+')' + '@{:}'.format(op) | ||||||
|  |         elif consider_zero: | ||||||
|  |           if op == 'none' or nodes[xin] == '#': x = '#' # zero | ||||||
|  |           elif op == 'skip_connect': x = nodes[xin] | ||||||
|  |           else: x = '('+nodes[xin]+')' + '@{:}'.format(op) | ||||||
|  |         else: | ||||||
|           if op == 'skip_connect': x = nodes[xin] |           if op == 'skip_connect': x = nodes[xin] | ||||||
|           else: x = '('+nodes[xin]+')' + '@{:}'.format(op) |           else: x = '('+nodes[xin]+')' + '@{:}'.format(op) | ||||||
|         cur_node.append(x) |         cur_node.append(x) | ||||||
|   | |||||||
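
For concreteness, a hedged sketch of the three modes (the architecture string is an arbitrary illustration in the benchmark's edge format, and the comments describe the intended collapsing, not verified output):

  from models import CellStructure  # assumes the repo's lib directory is on sys.path

  arch = CellStructure.str2structure('|none~0|+|skip_connect~0|nor_conv_3x3~1|')
  arch.to_unique_str(consider_zero=None)   # keep every op: pure topology comparison
  arch.to_unique_str(consider_zero=False)  # treat skip_connect as transparent (replaced by its input)
  arch.to_unique_str(consider_zero=True)   # additionally collapse 'none' edges into '#'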
| @@ -41,8 +41,9 @@ class NASBench102API(object): | |||||||
|       if verbose: print('try to create the NAS-Bench-102 api from {:}'.format(file_path_or_dict)) |       if verbose: print('try to create the NAS-Bench-102 api from {:}'.format(file_path_or_dict)) | ||||||
|       assert os.path.isfile(file_path_or_dict), 'invalid path : {:}'.format(file_path_or_dict) |       assert os.path.isfile(file_path_or_dict), 'invalid path : {:}'.format(file_path_or_dict) | ||||||
|       file_path_or_dict = torch.load(file_path_or_dict) |       file_path_or_dict = torch.load(file_path_or_dict) | ||||||
|     else: |     elif isinstance(file_path_or_dict, dict): | ||||||
|       file_path_or_dict = copy.deepcopy( file_path_or_dict ) |       file_path_or_dict = copy.deepcopy( file_path_or_dict ) | ||||||
|  |     else: raise ValueError('invalid type : {:} not in [str, dict]'.format(type(file_path_or_dict))) | ||||||
|     assert isinstance(file_path_or_dict, dict), 'It should be a dict instead of {:}'.format(type(file_path_or_dict)) |     assert isinstance(file_path_or_dict, dict), 'It should be a dict instead of {:}'.format(type(file_path_or_dict)) | ||||||
|     keys = ('meta_archs', 'arch2infos', 'evaluated_indexes') |     keys = ('meta_archs', 'arch2infos', 'evaluated_indexes') | ||||||
|     for key in keys: assert key in file_path_or_dict, 'Can not find key[{:}] in the dict'.format(key) |     for key in keys: assert key in file_path_or_dict, 'Can not find key[{:}] in the dict'.format(key) | ||||||
| @@ -152,26 +153,40 @@ class NASBench102API(object): | |||||||
|     archresult = arch2infos[index] |     archresult = arch2infos[index] | ||||||
|     return archresult.get_net_param(dataset, seed) |     return archresult.get_net_param(dataset, seed) | ||||||
|  |  | ||||||
|   def get_more_info(self, index, dataset, iepoch=None, use_12epochs_result=False): |   # obtain the metric for the `index`-th architecture | ||||||
|  |   def get_more_info(self, index, dataset, iepoch=None, use_12epochs_result=False, is_random=True): | ||||||
|     if use_12epochs_result: basestr, arch2infos = '12epochs' , self.arch2infos_less |     if use_12epochs_result: basestr, arch2infos = '12epochs' , self.arch2infos_less | ||||||
|     else                  : basestr, arch2infos = '200epochs', self.arch2infos_full |     else                  : basestr, arch2infos = '200epochs', self.arch2infos_full | ||||||
|     archresult = arch2infos[index] |     archresult = arch2infos[index] | ||||||
|     if dataset == 'cifar10-valid': |     if dataset == 'cifar10-valid': | ||||||
|       train_info = archresult.get_metrics(dataset, 'train'   , iepoch=iepoch, is_random=True) |       train_info = archresult.get_metrics(dataset, 'train'   , iepoch=iepoch, is_random=is_random) | ||||||
|       valid_info = archresult.get_metrics(dataset, 'x-valid' , iepoch=iepoch, is_random=True) |       valid_info = archresult.get_metrics(dataset, 'x-valid' , iepoch=iepoch, is_random=is_random) | ||||||
|       test__info = archresult.get_metrics(dataset, 'ori-test', iepoch=iepoch, is_random=True) |       try: | ||||||
|  |         test__info = archresult.get_metrics(dataset, 'ori-test', iepoch=iepoch, is_random=is_random) | ||||||
|  |       except: | ||||||
|  |         test__info = None | ||||||
|       total      = train_info['iepoch'] + 1 |       total      = train_info['iepoch'] + 1 | ||||||
|       return {'train-loss'    : train_info['loss'], |       xifo = {'train-loss'    : train_info['loss'], | ||||||
|               'train-accuracy': train_info['accuracy'], |               'train-accuracy': train_info['accuracy'], | ||||||
|               'train-all-time': train_info['all_time'], |               'train-all-time': train_info['all_time'], | ||||||
|               'valid-loss'    : valid_info['loss'], |               'valid-loss'    : valid_info['loss'], | ||||||
|               'valid-accuracy': valid_info['accuracy'], |               'valid-accuracy': valid_info['accuracy'], | ||||||
|               'valid-all-time': valid_info['all_time'], |               'valid-all-time': valid_info['all_time'], | ||||||
|               'valid-per-time': valid_info['all_time'] / total, |               'valid-per-time': None if valid_info['all_time'] is None else valid_info['all_time'] / total} | ||||||
|  |       if test__info is not None: | ||||||
|  |         xifo['test-loss']     = test__info['loss'] | ||||||
|  |         xifo['test-accuracy'] = test__info['accuracy'] | ||||||
|  |       return xifo | ||||||
|  |     else: | ||||||
|  |       train_info = archresult.get_metrics(dataset, 'train'   , iepoch=iepoch, is_random=is_random) | ||||||
|  |       if dataset == 'cifar10': | ||||||
|  |         test__info = archresult.get_metrics(dataset, 'ori-test', iepoch=iepoch, is_random=is_random) | ||||||
|  |       else: | ||||||
|  |         test__info = archresult.get_metrics(dataset, 'x-test', iepoch=iepoch, is_random=is_random) | ||||||
|  |       return {'train-loss'    : train_info['loss'], | ||||||
|  |               'train-accuracy': train_info['accuracy'], | ||||||
|               'test-loss'     : test__info['loss'], |               'test-loss'     : test__info['loss'], | ||||||
|               'test-accuracy' : test__info['accuracy']} |               'test-accuracy' : test__info['accuracy']} | ||||||
|     else: |  | ||||||
|       raise ValueError('coming soon...') |  | ||||||
|  |  | ||||||
|   def show(self, index=-1): |   def show(self, index=-1): | ||||||
|     if index < 0: # show all architectures |     if index < 0: # show all architectures | ||||||
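
A usage sketch for the extended signature (the path and index are placeholders, not from this diff; with is_random=False the underlying get_metrics is expected to aggregate over the recorded seeds instead of sampling one):

  api = NASBench102API('NAS-Bench-102-v1_0-e61699.pth')
  info = api.get_more_info(123, 'cifar100', iepoch=None, use_12epochs_result=False, is_random=False)
  print(info['train-accuracy'], info['test-accuracy'])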
| @@ -369,7 +384,7 @@ class ResultsCount(object): | |||||||
|   def update_latency(self, latency): |   def update_latency(self, latency): | ||||||
|     self.latency = copy.deepcopy( latency ) |     self.latency = copy.deepcopy( latency ) | ||||||
|  |  | ||||||
|   def update_eval(self, accs, losses, times): # old version |   def update_eval(self, accs, losses, times):  # new version | ||||||
|     data_names = set([x.split('@')[0] for x in accs.keys()]) |     data_names = set([x.split('@')[0] for x in accs.keys()]) | ||||||
|     for data_name in data_names: |     for data_name in data_names: | ||||||
|       assert data_name not in self.eval_names, '{:} has already been added into eval-names'.format(data_name) |       assert data_name not in self.eval_names, '{:} has already been added into eval-names'.format(data_name) | ||||||
|   | |||||||
| @@ -21,17 +21,11 @@ num_cells=5 | |||||||
| max_nodes=4 | max_nodes=4 | ||||||
| space=nas-bench-102 | space=nas-bench-102 | ||||||
|  |  | ||||||
| if [ "$dataset" == "cifar10" ] || [ "$dataset" == "cifar100" ]; then |  | ||||||
|   data_path="$TORCH_HOME/cifar.python" |  | ||||||
| else |  | ||||||
|   data_path="$TORCH_HOME/cifar.python/ImageNet16" |  | ||||||
| fi |  | ||||||
|  |  | ||||||
| save_dir=./output/search-cell-${space}/BOHB-${dataset} | save_dir=./output/search-cell-${space}/BOHB-${dataset} | ||||||
|  |  | ||||||
| OMP_NUM_THREADS=4 python ./exps/algos/BOHB.py \ | OMP_NUM_THREADS=4 python ./exps/algos/BOHB.py \ | ||||||
| 	--save_dir ${save_dir} --max_nodes ${max_nodes} --channel ${channel} --num_cells ${num_cells} \ | 	--save_dir ${save_dir} --max_nodes ${max_nodes} --channel ${channel} --num_cells ${num_cells} \ | ||||||
| 	--dataset ${dataset} --data_path ${data_path} \ | 	--dataset ${dataset} \ | ||||||
| 	--search_space_name ${space} \ | 	--search_space_name ${space} \ | ||||||
| 	--arch_nas_dataset ${TORCH_HOME}/NAS-Bench-102-v1_0-e61699.pth \ | 	--arch_nas_dataset ${TORCH_HOME}/NAS-Bench-102-v1_0-e61699.pth \ | ||||||
| 	--time_budget 12000  \ | 	--time_budget 12000  \ | ||||||
|   | |||||||
| @@ -22,17 +22,11 @@ num_cells=5 | |||||||
| max_nodes=4 | max_nodes=4 | ||||||
| space=nas-bench-102 | space=nas-bench-102 | ||||||
|  |  | ||||||
| if [ "$dataset" == "cifar10" ] || [ "$dataset" == "cifar100" ]; then |  | ||||||
|   data_path="$TORCH_HOME/cifar.python" |  | ||||||
| else |  | ||||||
|   data_path="$TORCH_HOME/cifar.python/ImageNet16" |  | ||||||
| fi |  | ||||||
|  |  | ||||||
| save_dir=./output/search-cell-${space}/R-EA-${dataset} | save_dir=./output/search-cell-${space}/R-EA-${dataset} | ||||||
|  |  | ||||||
| OMP_NUM_THREADS=4 python ./exps/algos/R_EA.py \ | OMP_NUM_THREADS=4 python ./exps/algos/R_EA.py \ | ||||||
| 	--save_dir ${save_dir} --max_nodes ${max_nodes} --channel ${channel} --num_cells ${num_cells} \ | 	--save_dir ${save_dir} --max_nodes ${max_nodes} --channel ${channel} --num_cells ${num_cells} \ | ||||||
| 	--dataset ${dataset} --data_path ${data_path} \ | 	--dataset ${dataset} \ | ||||||
| 	--search_space_name ${space} \ | 	--search_space_name ${space} \ | ||||||
| 	--arch_nas_dataset ${TORCH_HOME}/NAS-Bench-102-v1_0-e61699.pth \ | 	--arch_nas_dataset ${TORCH_HOME}/NAS-Bench-102-v1_0-e61699.pth \ | ||||||
| 	--time_budget 12000 \ | 	--time_budget 12000 \ | ||||||
|   | |||||||
| @@ -21,17 +21,11 @@ num_cells=5 | |||||||
| max_nodes=4 | max_nodes=4 | ||||||
| space=nas-bench-102 | space=nas-bench-102 | ||||||
|  |  | ||||||
| if [ "$dataset" == "cifar10" ] || [ "$dataset" == "cifar100" ]; then |  | ||||||
|   data_path="$TORCH_HOME/cifar.python" |  | ||||||
| else |  | ||||||
|   data_path="$TORCH_HOME/cifar.python/ImageNet16" |  | ||||||
| fi |  | ||||||
|  |  | ||||||
| save_dir=./output/search-cell-${space}/REINFORCE-${dataset} | save_dir=./output/search-cell-${space}/REINFORCE-${dataset} | ||||||
|  |  | ||||||
| OMP_NUM_THREADS=4 python ./exps/algos/reinforce.py \ | OMP_NUM_THREADS=4 python ./exps/algos/reinforce.py \ | ||||||
| 	--save_dir ${save_dir} --max_nodes ${max_nodes} --channel ${channel} --num_cells ${num_cells} \ | 	--save_dir ${save_dir} --max_nodes ${max_nodes} --channel ${channel} --num_cells ${num_cells} \ | ||||||
| 	--dataset ${dataset} --data_path ${data_path} \ | 	--dataset ${dataset} \ | ||||||
| 	--search_space_name ${space} \ | 	--search_space_name ${space} \ | ||||||
| 	--arch_nas_dataset ${TORCH_HOME}/NAS-Bench-102-v1_0-e61699.pth \ | 	--arch_nas_dataset ${TORCH_HOME}/NAS-Bench-102-v1_0-e61699.pth \ | ||||||
| 	--time_budget 12000 \ | 	--time_budget 12000 \ | ||||||
|   | |||||||
| @@ -21,17 +21,11 @@ num_cells=5 | |||||||
| max_nodes=4 | max_nodes=4 | ||||||
| space=nas-bench-102 | space=nas-bench-102 | ||||||
|  |  | ||||||
| if [ "$dataset" == "cifar10" ] || [ "$dataset" == "cifar100" ]; then |  | ||||||
|   data_path="$TORCH_HOME/cifar.python" |  | ||||||
| else |  | ||||||
|   data_path="$TORCH_HOME/cifar.python/ImageNet16" |  | ||||||
| fi |  | ||||||
|  |  | ||||||
| save_dir=./output/search-cell-${space}/RAND-${dataset} | save_dir=./output/search-cell-${space}/RAND-${dataset} | ||||||
|  |  | ||||||
| OMP_NUM_THREADS=4 python ./exps/algos/RANDOM.py \ | OMP_NUM_THREADS=4 python ./exps/algos/RANDOM.py \ | ||||||
| 	--save_dir ${save_dir} --max_nodes ${max_nodes} --channel ${channel} --num_cells ${num_cells} \ | 	--save_dir ${save_dir} --max_nodes ${max_nodes} --channel ${channel} --num_cells ${num_cells} \ | ||||||
| 	--dataset ${dataset} --data_path ${data_path} \ | 	--dataset ${dataset} \ | ||||||
| 	--search_space_name ${space} \ | 	--search_space_name ${space} \ | ||||||
| 	--arch_nas_dataset ${TORCH_HOME}/NAS-Bench-102-v1_0-e61699.pth \ | 	--arch_nas_dataset ${TORCH_HOME}/NAS-Bench-102-v1_0-e61699.pth \ | ||||||
| 	--time_budget 12000 \ | 	--time_budget 12000 \ | ||||||
|   | |||||||