Update WW
This commit is contained in:
		| @@ -3,111 +3,19 @@ | |||||||
| ######################################################## | ######################################################## | ||||||
| # python exps/NAS-Bench-201/test-correlation.py --api_path $HOME/.torch/NAS-Bench-201-v1_0-e61699.pth | # python exps/NAS-Bench-201/test-correlation.py --api_path $HOME/.torch/NAS-Bench-201-v1_0-e61699.pth | ||||||
| ######################################################## | ######################################################## | ||||||
| import os, sys, time, glob, random, argparse | import sys, argparse | ||||||
| import numpy as np | import numpy as np | ||||||
| from copy import deepcopy | from copy import deepcopy | ||||||
| from tqdm import tqdm | from tqdm import tqdm | ||||||
| import torch | import torch | ||||||
| import torch.nn as nn |  | ||||||
| from pathlib import Path | from pathlib import Path | ||||||
| lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve() | lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve() | ||||||
| if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) | if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) | ||||||
| from config_utils import load_config, dict2config, configure2str | from log_utils    import time_string | ||||||
| from datasets     import get_datasets, SearchDataset | from models       import CellStructure | ||||||
| from procedures   import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint, get_optim_scheduler |  | ||||||
| from utils        import get_model_infos, obtain_accuracy |  | ||||||
| from log_utils    import AverageMeter, time_string, convert_secs2time |  | ||||||
| from models       import get_cell_based_tiny_net, get_search_spaces, CellStructure |  | ||||||
| from nas_201_api  import NASBench201API as API | from nas_201_api  import NASBench201API as API | ||||||
|  |  | ||||||
|  |  | ||||||
| def valid_func(xloader, network, criterion): |  | ||||||
|   data_time, batch_time = AverageMeter(), AverageMeter() |  | ||||||
|   arch_losses, arch_top1, arch_top5 = AverageMeter(), AverageMeter(), AverageMeter() |  | ||||||
|   network.eval() |  | ||||||
|   end = time.time() |  | ||||||
|   with torch.no_grad(): |  | ||||||
|     for step, (arch_inputs, arch_targets) in enumerate(xloader): |  | ||||||
|       arch_targets = arch_targets.cuda(non_blocking=True) |  | ||||||
|       # measure data loading time |  | ||||||
|       data_time.update(time.time() - end) |  | ||||||
|       # prediction |  | ||||||
|       _, logits = network(arch_inputs) |  | ||||||
|       arch_loss = criterion(logits, arch_targets) |  | ||||||
|       # record |  | ||||||
|       arch_prec1, arch_prec5 = obtain_accuracy(logits.data, arch_targets.data, topk=(1, 5)) |  | ||||||
|       arch_losses.update(arch_loss.item(),  arch_inputs.size(0)) |  | ||||||
|       arch_top1.update  (arch_prec1.item(), arch_inputs.size(0)) |  | ||||||
|       arch_top5.update  (arch_prec5.item(), arch_inputs.size(0)) |  | ||||||
|       # measure elapsed time |  | ||||||
|       batch_time.update(time.time() - end) |  | ||||||
|       end = time.time() |  | ||||||
|   return arch_losses.avg, arch_top1.avg, arch_top5.avg |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def main(xargs): |  | ||||||
|   assert torch.cuda.is_available(), 'CUDA is not available.' |  | ||||||
|   torch.backends.cudnn.enabled   = True |  | ||||||
|   torch.backends.cudnn.benchmark = False |  | ||||||
|   torch.backends.cudnn.deterministic = True |  | ||||||
|   torch.set_num_threads( xargs.workers ) |  | ||||||
|   prepare_seed(xargs.rand_seed) |  | ||||||
|   logger = prepare_logger(args) |  | ||||||
|  |  | ||||||
|   train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1) |  | ||||||
|   if xargs.dataset == 'cifar10' or xargs.dataset == 'cifar100': |  | ||||||
|     split_Fpath = 'configs/nas-benchmark/cifar-split.txt' |  | ||||||
|     cifar_split = load_config(split_Fpath, None, None) |  | ||||||
|     train_split, valid_split = cifar_split.train, cifar_split.valid |  | ||||||
|     logger.log('Load split file from {:}'.format(split_Fpath)) |  | ||||||
|   elif xargs.dataset.startswith('ImageNet16'): |  | ||||||
|     split_Fpath = 'configs/nas-benchmark/{:}-split.txt'.format(xargs.dataset) |  | ||||||
|     imagenet16_split = load_config(split_Fpath, None, None) |  | ||||||
|     train_split, valid_split = imagenet16_split.train, imagenet16_split.valid |  | ||||||
|     logger.log('Load split file from {:}'.format(split_Fpath)) |  | ||||||
|   else: |  | ||||||
|     raise ValueError('invalid dataset : {:}'.format(xargs.dataset)) |  | ||||||
|   config_path = 'configs/nas-benchmark/algos/DARTS.config' |  | ||||||
|   config = load_config(config_path, {'class_num': class_num, 'xshape': xshape}, logger) |  | ||||||
|   # To split data |  | ||||||
|   train_data_v2 = deepcopy(train_data) |  | ||||||
|   train_data_v2.transform = valid_data.transform |  | ||||||
|   valid_data    = train_data_v2 |  | ||||||
|   search_data   = SearchDataset(xargs.dataset, train_data, train_split, valid_split) |  | ||||||
|   # data loader |  | ||||||
|   search_loader = torch.utils.data.DataLoader(search_data, batch_size=config.batch_size, shuffle=True , num_workers=xargs.workers, pin_memory=True) |  | ||||||
|   valid_loader  = torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(valid_split), num_workers=xargs.workers, pin_memory=True) |  | ||||||
|   logger.log('||||||| {:10s} ||||||| Search-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'.format(xargs.dataset, len(search_loader), len(valid_loader), config.batch_size)) |  | ||||||
|   logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config)) |  | ||||||
|  |  | ||||||
|   search_space = get_search_spaces('cell', xargs.search_space_name) |  | ||||||
|   model_config = dict2config({'name': 'DARTS-V2', 'C': xargs.channel, 'N': xargs.num_cells, |  | ||||||
|                               'max_nodes': xargs.max_nodes, 'num_classes': class_num, |  | ||||||
|                               'space'    : search_space}, None) |  | ||||||
|   search_model = get_cell_based_tiny_net(model_config) |  | ||||||
|   logger.log('search-model :\n{:}'.format(search_model)) |  | ||||||
|    |  | ||||||
|   w_optimizer, w_scheduler, criterion = get_optim_scheduler(search_model.get_weights(), config) |  | ||||||
|   a_optimizer = torch.optim.Adam(search_model.get_alphas(), lr=xargs.arch_learning_rate, betas=(0.5, 0.999), weight_decay=xargs.arch_weight_decay) |  | ||||||
|   logger.log('w-optimizer : {:}'.format(w_optimizer)) |  | ||||||
|   logger.log('a-optimizer : {:}'.format(a_optimizer)) |  | ||||||
|   logger.log('w-scheduler : {:}'.format(w_scheduler)) |  | ||||||
|   logger.log('criterion   : {:}'.format(criterion)) |  | ||||||
|   flop, param  = get_model_infos(search_model, xshape) |  | ||||||
|   #logger.log('{:}'.format(search_model)) |  | ||||||
|   logger.log('FLOP = {:.2f} M, Params = {:.2f} MB'.format(flop, param)) |  | ||||||
|   if xargs.arch_nas_dataset is None: |  | ||||||
|     api = None |  | ||||||
|   else: |  | ||||||
|     api = API(xargs.arch_nas_dataset) |  | ||||||
|   logger.log('{:} create API = {:} done'.format(time_string(), api)) |  | ||||||
|  |  | ||||||
|   last_info, model_base_path, model_best_path = logger.path('info'), logger.path('model'), logger.path('best') |  | ||||||
|   network, criterion = torch.nn.DataParallel(search_model).cuda(), criterion.cuda() |  | ||||||
|  |  | ||||||
|   logger.close() |  | ||||||
|    |  | ||||||
|  |  | ||||||
| def check_unique_arch(meta_file): | def check_unique_arch(meta_file): | ||||||
|   api = API(str(meta_file)) |   api = API(str(meta_file)) | ||||||
|   arch_strs = deepcopy(api.meta_archs) |   arch_strs = deepcopy(api.meta_archs) | ||||||
|   | |||||||
							
								
								
									
										36
									
								
								exps/NAS-Bench-201/test-weights.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										36
									
								
								exps/NAS-Bench-201/test-weights.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,36 @@ | |||||||
|  | ##################################################### | ||||||
|  | # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.08 # | ||||||
|  | ######################################################## | ||||||
|  | # python exps/NAS-Bench-201/test-weights.py --api_path $HOME/.torch/NAS-Bench-201-v1_0-e61699.pth | ||||||
|  | ######################################################## | ||||||
|  | import os, sys, time, glob, random, argparse | ||||||
|  | import numpy as np | ||||||
|  | import torch | ||||||
|  | from pathlib import Path | ||||||
|  | lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve() | ||||||
|  | if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) | ||||||
|  | from procedures   import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint, get_optim_scheduler | ||||||
|  | from nas_201_api  import NASBench201API as API | ||||||
|  | from utils import weight_watcher | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def main(meta_file, weight_dir, save_dir): | ||||||
|  |   import pdb; | ||||||
|  |   pdb.set_trace() | ||||||
|  |  | ||||||
|  |  | ||||||
|  | if __name__ == '__main__': | ||||||
|  |   parser = argparse.ArgumentParser("Analysis of NAS-Bench-201") | ||||||
|  |   parser.add_argument('--save_dir',   type=str, default='./output/search-cell-nas-bench-201/visuals', help='The base-name of folder to save checkpoints and log.') | ||||||
|  |   parser.add_argument('--api_path',   type=str, default=None, help='The path to the NAS-Bench-201 benchmark file.') | ||||||
|  |   parser.add_argument('--weight_dir', type=str, default=None, help='The directory path to the weights of every NAS-Bench-201 architecture.') | ||||||
|  |   args = parser.parse_args() | ||||||
|  |  | ||||||
|  |   save_dir = Path(args.save_dir) | ||||||
|  |   save_dir.mkdir(parents=True, exist_ok=True) | ||||||
|  |   meta_file = Path(args.api_path) | ||||||
|  |   weight_dir = Path(args.weight_dir) | ||||||
|  |   assert meta_file.exists(), 'invalid path for api : {:}'.format(meta_file) | ||||||
|  |  | ||||||
|  |   main(meta_file, weight_dir, save_dir) | ||||||
|  |  | ||||||
| @@ -9,12 +9,23 @@ from utils import weight_watcher | |||||||
|  |  | ||||||
|  |  | ||||||
| def main(): | def main(): | ||||||
|   model = models.vgg19_bn(pretrained=True) |   # model = models.vgg19_bn(pretrained=True) | ||||||
|   _, summary = weight_watcher.analyze(model, alphas=False) |   # _, summary = weight_watcher.analyze(model, alphas=False) | ||||||
|   # print(summary) |   # for key, value in summary.items(): | ||||||
|   for key, value in summary.items(): |   #   print('{:10s} : {:}'.format(key, value)) | ||||||
|     print('{:10s} : {:}'.format(key, value)) |  | ||||||
|   # import pdb; pdb.set_trace() |   _, summary = weight_watcher.analyze(models.vgg13(pretrained=True), alphas=False) | ||||||
|  |   print('vgg-13 : {:}'.format(summary['lognorm'])) | ||||||
|  |   _, summary = weight_watcher.analyze(models.vgg13_bn(pretrained=True), alphas=False) | ||||||
|  |   print('vgg-13-BN : {:}'.format(summary['lognorm'])) | ||||||
|  |   _, summary = weight_watcher.analyze(models.vgg16(pretrained=True), alphas=False) | ||||||
|  |   print('vgg-16 : {:}'.format(summary['lognorm'])) | ||||||
|  |   _, summary = weight_watcher.analyze(models.vgg16_bn(pretrained=True), alphas=False) | ||||||
|  |   print('vgg-16-BN : {:}'.format(summary['lognorm'])) | ||||||
|  |   _, summary = weight_watcher.analyze(models.vgg19(pretrained=True), alphas=False) | ||||||
|  |   print('vgg-19 : {:}'.format(summary['lognorm'])) | ||||||
|  |   _, summary = weight_watcher.analyze(models.vgg19_bn(pretrained=True), alphas=False) | ||||||
|  |   print('vgg-19-BN : {:}'.format(summary['lognorm'])) | ||||||
|  |  | ||||||
|  |  | ||||||
| if __name__ == '__main__': | if __name__ == '__main__': | ||||||
|   | |||||||
| @@ -304,7 +304,7 @@ def analyze(model: nn.Module, min_size=50, max_size=0, | |||||||
|     if isinstance(module, available_module_types()): |     if isinstance(module, available_module_types()): | ||||||
|       names.append(name) |       names.append(name) | ||||||
|       modules.append(module) |       modules.append(module) | ||||||
|   print('There are {:} layers to be analyzed in this model.'.format(len(modules))) |   # print('There are {:} layers to be analyzed in this model.'.format(len(modules))) | ||||||
|   all_results = OrderedDict() |   all_results = OrderedDict() | ||||||
|   for index, module in enumerate(modules): |   for index, module in enumerate(modules): | ||||||
|     if isinstance(module, nn.Linear): |     if isinstance(module, nn.Linear): | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user