update code styles
This commit is contained in:
		| @@ -8,7 +8,6 @@ from tqdm import tqdm | ||||
| from collections import OrderedDict | ||||
| import numpy as np | ||||
| import torch | ||||
| import torch.nn as nn | ||||
| from pathlib import Path | ||||
| from collections import defaultdict | ||||
| import matplotlib | ||||
| @@ -498,6 +497,8 @@ def show_nas_sharing_w(api, dataset, subset, vis_save_dir, file_name, y_lims, x_ | ||||
|  | ||||
|   def get_accs(xdata): | ||||
|     epochs, xresults = xdata['epoch'], [] | ||||
|     metrics = api.arch2infos_full[ api.random() ].get_metrics(dataset, subset, None, False) | ||||
|     xresults.append( metrics['accuracy'] ) | ||||
|     for iepoch in range(epochs): | ||||
|       genotype = xdata['genotypes'][iepoch] | ||||
|       index = api.query_index_by_arch(genotype) | ||||
| @@ -547,7 +548,6 @@ if __name__ == '__main__': | ||||
|   #visualize_relative_ranking(vis_save_dir) | ||||
|  | ||||
|   api = API(args.api_path) | ||||
|   """ | ||||
|   for x_maxs in [50, 250]: | ||||
|     show_nas_sharing_w(api, 'cifar10-valid' , 'x-valid' , vis_save_dir, 'nas-plot.pdf', (0, 100,10), x_maxs) | ||||
|     show_nas_sharing_w(api, 'cifar10'       , 'ori-test', vis_save_dir, 'nas-plot.pdf', (0, 100,10), x_maxs) | ||||
| @@ -555,11 +555,12 @@ if __name__ == '__main__': | ||||
|     show_nas_sharing_w(api, 'cifar100'      , 'x-test'  , vis_save_dir, 'nas-plot.pdf', (0, 100,10), x_maxs) | ||||
|     show_nas_sharing_w(api, 'ImageNet16-120', 'x-valid' , vis_save_dir, 'nas-plot.pdf', (0, 100,10), x_maxs) | ||||
|     show_nas_sharing_w(api, 'ImageNet16-120', 'x-test'  , vis_save_dir, 'nas-plot.pdf', (0, 100,10), x_maxs) | ||||
|   just_show(api) | ||||
|   """ | ||||
|   just_show(api) | ||||
|   plot_results_nas(api, 'cifar10-valid' , 'x-valid' , vis_save_dir, 'nas-com.pdf', (85,95, 1)) | ||||
|   plot_results_nas(api, 'cifar10'       , 'ori-test', vis_save_dir, 'nas-com.pdf', (85,95, 1)) | ||||
|   plot_results_nas(api, 'cifar100'      , 'x-valid' , vis_save_dir, 'nas-com.pdf', (55,75, 3)) | ||||
|   plot_results_nas(api, 'cifar100'      , 'x-test'  , vis_save_dir, 'nas-com.pdf', (55,75, 3)) | ||||
|   plot_results_nas(api, 'ImageNet16-120', 'x-valid' , vis_save_dir, 'nas-com.pdf', (35,50, 3)) | ||||
|   plot_results_nas(api, 'ImageNet16-120', 'x-test'  , vis_save_dir, 'nas-com.pdf', (35,50, 3)) | ||||
|   """ | ||||
|   | ||||
| @@ -10,7 +10,6 @@ from copy import deepcopy | ||||
| from pathlib import Path | ||||
| import torch | ||||
| import torch.nn as nn | ||||
| from torch.distributions import Categorical | ||||
| lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve() | ||||
| if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) | ||||
| from config_utils import load_config, dict2config, configure2str | ||||
|   | ||||
| @@ -121,9 +121,19 @@ def main(xargs): | ||||
|     search_loader = torch.utils.data.DataLoader(search_data, batch_size=config.batch_size, shuffle=True , num_workers=xargs.workers, pin_memory=True) | ||||
|     valid_loader  = torch.utils.data.DataLoader(valid_data , batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(valid_split), num_workers=xargs.workers, pin_memory=True) | ||||
|   elif xargs.dataset == 'cifar100': | ||||
|     raise ValueError('not support yet : {:}'.format(xargs.dataset)) | ||||
|   elif xargs.dataset.startswith('ImageNet16'): | ||||
|     raise ValueError('not support yet : {:}'.format(xargs.dataset)) | ||||
|     cifar100_test_split = load_config('configs/nas-benchmark/cifar100-test-split.txt', None, None) | ||||
|     search_train_data = train_data | ||||
|     search_valid_data = deepcopy(valid_data) ; search_valid_data.transform = train_data.transform | ||||
|     search_data   = SearchDataset(xargs.dataset, [search_train_data,search_valid_data], list(range(len(search_train_data))), cifar100_test_split.xvalid) | ||||
|     search_loader = torch.utils.data.DataLoader(search_data, batch_size=config.batch_size, shuffle=True , num_workers=xargs.workers, pin_memory=True) | ||||
|     valid_loader  = torch.utils.data.DataLoader(valid_data , batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(cifar100_test_split.xvalid), num_workers=xargs.workers, pin_memory=True) | ||||
|   elif xargs.dataset == 'ImageNet16-120': | ||||
|     imagenet_test_split = load_config('configs/nas-benchmark/imagenet-16-120-test-split.txt', None, None) | ||||
|     search_train_data = train_data | ||||
|     search_valid_data = deepcopy(valid_data) ; search_valid_data.transform = train_data.transform | ||||
|     search_data   = SearchDataset(xargs.dataset, [search_train_data,search_valid_data], list(range(len(search_train_data))), imagenet_test_split.xvalid) | ||||
|     search_loader = torch.utils.data.DataLoader(search_data, batch_size=config.batch_size, shuffle=True , num_workers=xargs.workers, pin_memory=True) | ||||
|     valid_loader  = torch.utils.data.DataLoader(valid_data , batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(imagenet_test_split.xvalid), num_workers=xargs.workers, pin_memory=True) | ||||
|   else: | ||||
|     raise ValueError('invalid dataset : {:}'.format(xargs.dataset)) | ||||
|   logger.log('||||||| {:10s} ||||||| Search-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'.format(xargs.dataset, len(search_loader), len(valid_loader), config.batch_size)) | ||||
| @@ -168,7 +178,7 @@ def main(xargs): | ||||
|     logger.log("=> loading checkpoint of the last-info '{:}' start with {:}-th epoch.".format(last_info, start_epoch)) | ||||
|   else: | ||||
|     logger.log("=> do not find the last-info file : {:}".format(last_info)) | ||||
|     start_epoch, valid_accuracies, genotypes = 0, {'best': -1}, {} | ||||
|     start_epoch, valid_accuracies, genotypes = 0, {'best': -1}, {-1: search_model.genotype()} | ||||
|  | ||||
|   # start training | ||||
|   start_time, search_time, epoch_time, total_epoch = time.time(), AverageMeter(), AverageMeter(), config.epochs + config.warmup | ||||
| @@ -230,7 +240,7 @@ if __name__ == '__main__': | ||||
|   parser.add_argument('--data_path',          type=str,   help='Path to dataset') | ||||
|   parser.add_argument('--dataset',            type=str,   choices=['cifar10', 'cifar100', 'ImageNet16-120'], help='Choose between Cifar10/100 and ImageNet-16.') | ||||
|   # channels and number-of-cells | ||||
|   parser.add_argument('--config_path',        type=str,   help='The config paths.') | ||||
|   parser.add_argument('--config_path',        type=str,   help='The config path.') | ||||
|   parser.add_argument('--search_space_name',  type=str,   help='The search space name.') | ||||
|   parser.add_argument('--max_nodes',          type=int,   help='The maximum number of nodes.') | ||||
|   parser.add_argument('--channel',            type=int,   help='The number of channels.') | ||||
|   | ||||
| @@ -181,8 +181,8 @@ def main(xargs): | ||||
|     logger.log('Load split file from {:}'.format(split_Fpath)) | ||||
|   else: | ||||
|     raise ValueError('invalid dataset : {:}'.format(xargs.dataset)) | ||||
|   config_path = 'configs/nas-benchmark/algos/DARTS.config' | ||||
|   config = load_config(config_path, {'class_num': class_num, 'xshape': xshape}, logger) | ||||
|   #config_path = 'configs/nas-benchmark/algos/DARTS.config' | ||||
|   config = load_config(xargs.config_path, {'class_num': class_num, 'xshape': xshape}, logger) | ||||
|   # To split data | ||||
|   train_data_v2 = deepcopy(train_data) | ||||
|   train_data_v2.transform = valid_data.transform | ||||
| @@ -233,7 +233,7 @@ def main(xargs): | ||||
|     logger.log("=> loading checkpoint of the last-info '{:}' start with {:}-th epoch.".format(last_info, start_epoch)) | ||||
|   else: | ||||
|     logger.log("=> do not find the last-info file : {:}".format(last_info)) | ||||
|     start_epoch, valid_accuracies, genotypes = 0, {'best': -1}, {} | ||||
|     start_epoch, valid_accuracies, genotypes = 0, {'best': -1}, {-1: search_model.genotype()} | ||||
|  | ||||
|   # start training | ||||
|   start_time, search_time, epoch_time, total_epoch = time.time(), AverageMeter(), AverageMeter(), config.epochs + config.warmup | ||||
| @@ -297,6 +297,7 @@ if __name__ == '__main__': | ||||
|   parser.add_argument('--data_path',          type=str,   help='Path to dataset') | ||||
|   parser.add_argument('--dataset',            type=str,   choices=['cifar10', 'cifar100', 'ImageNet16-120'], help='Choose between Cifar10/100 and ImageNet-16.') | ||||
|   # channels and number-of-cells | ||||
|   parser.add_argument('--config_path',        type=str,   help='The config path.') | ||||
|   parser.add_argument('--search_space_name',  type=str,   help='The search space name.') | ||||
|   parser.add_argument('--max_nodes',          type=int,   help='The maximum number of nodes.') | ||||
|   parser.add_argument('--channel',            type=int,   help='The number of channels.') | ||||
|   | ||||
| @@ -3,7 +3,7 @@ | ||||
| ########################################################################### | ||||
| # Searching for A Robust Neural Architecture in Four GPU Hours, CVPR 2019 # | ||||
| ########################################################################### | ||||
| import os, sys, time, glob, random, argparse | ||||
| import os, sys, time, random, argparse | ||||
| import numpy as np | ||||
| from copy import deepcopy | ||||
| import torch | ||||
| @@ -11,7 +11,7 @@ import torch.nn as nn | ||||
| from pathlib import Path | ||||
| lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve() | ||||
| if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) | ||||
| from config_utils import load_config, dict2config, configure2str | ||||
| from config_utils import load_config, dict2config | ||||
| from datasets     import get_datasets, SearchDataset | ||||
| from procedures   import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint, get_optim_scheduler | ||||
| from utils        import get_model_infos, obtain_accuracy | ||||
|   | ||||
| @@ -1,12 +1,14 @@ | ||||
| # python ./exps/vis/test.py | ||||
| import os, sys, random | ||||
| from pathlib import Path | ||||
| from copy import deepcopy | ||||
| import torch | ||||
| import numpy as np | ||||
| from collections import OrderedDict | ||||
| lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve() | ||||
| if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) | ||||
|  | ||||
| from nas_102_api import NASBench102API as API | ||||
|  | ||||
| def test_nas_api(): | ||||
|   from nas_102_api import ArchResults | ||||
| @@ -72,7 +74,40 @@ def test_auto_grad(): | ||||
|     s_grads = torch.autograd.grad(grads, net.parameters()) | ||||
|     second_order_grads.append( s_grads ) | ||||
|  | ||||
|  | ||||
| def test_one_shot_model(ckpath, use_train): | ||||
|   from models import get_cell_based_tiny_net, get_search_spaces | ||||
|   from datasets import get_datasets, SearchDataset | ||||
|   from config_utils import load_config, dict2config | ||||
|   from utils.nas_utils import evaluate_one_shot | ||||
|   use_train = int(use_train) > 0 | ||||
|   #ckpath = 'output/search-cell-nas-bench-102/DARTS-V1-cifar10/checkpoint/seed-11416-basic.pth' | ||||
|   #ckpath = 'output/search-cell-nas-bench-102/DARTS-V1-cifar10/checkpoint/seed-28640-basic.pth' | ||||
|   print ('ckpath : {:}'.format(ckpath)) | ||||
|   ckp = torch.load(ckpath) | ||||
|   xargs = ckp['args'] | ||||
|   train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1) | ||||
|   config = load_config(xargs.config_path, {'class_num': class_num, 'xshape': xshape}, None) | ||||
|   if xargs.dataset == 'cifar10': | ||||
|     cifar_split = load_config('configs/nas-benchmark/cifar-split.txt', None, None) | ||||
|     xvalid_data = deepcopy(train_data) | ||||
|     xvalid_data.transform = valid_data.transform | ||||
|     valid_loader= torch.utils.data.DataLoader(xvalid_data, batch_size=2048, sampler=torch.utils.data.sampler.SubsetRandomSampler(cifar_split.valid), num_workers=12, pin_memory=True) | ||||
|   else: raise ValueError('invalid dataset : {:}'.format(xargs.dataseet)) | ||||
|   search_space = get_search_spaces('cell', xargs.search_space_name) | ||||
|   model_config = dict2config({'name': 'SETN', 'C': xargs.channel, 'N': xargs.num_cells, | ||||
|                               'max_nodes': xargs.max_nodes, 'num_classes': class_num, | ||||
|                               'space'    : search_space, | ||||
|                               'affine'   : False, 'track_running_stats': True}, None) | ||||
|   search_model = get_cell_based_tiny_net(model_config) | ||||
|   search_model.load_state_dict( ckp['search_model'] ) | ||||
|   search_model = search_model.cuda() | ||||
|   api = API('/home/dxy/.torch/NAS-Bench-102-v1_0-e61699.pth') | ||||
|   archs, probs, accuracies = evaluate_one_shot(search_model, valid_loader, api, use_train) | ||||
|  | ||||
|  | ||||
| if __name__ == '__main__': | ||||
|   #test_nas_api() | ||||
|   #for i in range(200): plot('{:04d}'.format(i)) | ||||
|   test_auto_grad() | ||||
|   #test_auto_grad() | ||||
|   test_one_shot_model(sys.argv[1], sys.argv[2]) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user