diff --git a/exps/NAS-Bench-201/test-correlation.py b/exps/NAS-Bench-201/test-correlation.py index 7b3202b..0a49634 100644 --- a/exps/NAS-Bench-201/test-correlation.py +++ b/exps/NAS-Bench-201/test-correlation.py @@ -3,110 +3,18 @@ ######################################################## # python exps/NAS-Bench-201/test-correlation.py --api_path $HOME/.torch/NAS-Bench-201-v1_0-e61699.pth ######################################################## -import os, sys, time, glob, random, argparse +import sys, argparse import numpy as np from copy import deepcopy from tqdm import tqdm import torch -import torch.nn as nn from pathlib import Path lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve() if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) -from config_utils import load_config, dict2config, configure2str -from datasets import get_datasets, SearchDataset -from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint, get_optim_scheduler -from utils import get_model_infos, obtain_accuracy -from log_utils import AverageMeter, time_string, convert_secs2time -from models import get_cell_based_tiny_net, get_search_spaces, CellStructure +from log_utils import time_string +from models import CellStructure from nas_201_api import NASBench201API as API - -def valid_func(xloader, network, criterion): - data_time, batch_time = AverageMeter(), AverageMeter() - arch_losses, arch_top1, arch_top5 = AverageMeter(), AverageMeter(), AverageMeter() - network.eval() - end = time.time() - with torch.no_grad(): - for step, (arch_inputs, arch_targets) in enumerate(xloader): - arch_targets = arch_targets.cuda(non_blocking=True) - # measure data loading time - data_time.update(time.time() - end) - # prediction - _, logits = network(arch_inputs) - arch_loss = criterion(logits, arch_targets) - # record - arch_prec1, arch_prec5 = obtain_accuracy(logits.data, arch_targets.data, topk=(1, 5)) - arch_losses.update(arch_loss.item(), arch_inputs.size(0)) - arch_top1.update (arch_prec1.item(), arch_inputs.size(0)) - arch_top5.update (arch_prec5.item(), arch_inputs.size(0)) - # measure elapsed time - batch_time.update(time.time() - end) - end = time.time() - return arch_losses.avg, arch_top1.avg, arch_top5.avg - - -def main(xargs): - assert torch.cuda.is_available(), 'CUDA is not available.' - torch.backends.cudnn.enabled = True - torch.backends.cudnn.benchmark = False - torch.backends.cudnn.deterministic = True - torch.set_num_threads( xargs.workers ) - prepare_seed(xargs.rand_seed) - logger = prepare_logger(args) - - train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1) - if xargs.dataset == 'cifar10' or xargs.dataset == 'cifar100': - split_Fpath = 'configs/nas-benchmark/cifar-split.txt' - cifar_split = load_config(split_Fpath, None, None) - train_split, valid_split = cifar_split.train, cifar_split.valid - logger.log('Load split file from {:}'.format(split_Fpath)) - elif xargs.dataset.startswith('ImageNet16'): - split_Fpath = 'configs/nas-benchmark/{:}-split.txt'.format(xargs.dataset) - imagenet16_split = load_config(split_Fpath, None, None) - train_split, valid_split = imagenet16_split.train, imagenet16_split.valid - logger.log('Load split file from {:}'.format(split_Fpath)) - else: - raise ValueError('invalid dataset : {:}'.format(xargs.dataset)) - config_path = 'configs/nas-benchmark/algos/DARTS.config' - config = load_config(config_path, {'class_num': class_num, 'xshape': xshape}, logger) - # To split data - train_data_v2 = deepcopy(train_data) - train_data_v2.transform = valid_data.transform - valid_data = train_data_v2 - search_data = SearchDataset(xargs.dataset, train_data, train_split, valid_split) - # data loader - search_loader = torch.utils.data.DataLoader(search_data, batch_size=config.batch_size, shuffle=True , num_workers=xargs.workers, pin_memory=True) - valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(valid_split), num_workers=xargs.workers, pin_memory=True) - logger.log('||||||| {:10s} ||||||| Search-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'.format(xargs.dataset, len(search_loader), len(valid_loader), config.batch_size)) - logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config)) - - search_space = get_search_spaces('cell', xargs.search_space_name) - model_config = dict2config({'name': 'DARTS-V2', 'C': xargs.channel, 'N': xargs.num_cells, - 'max_nodes': xargs.max_nodes, 'num_classes': class_num, - 'space' : search_space}, None) - search_model = get_cell_based_tiny_net(model_config) - logger.log('search-model :\n{:}'.format(search_model)) - - w_optimizer, w_scheduler, criterion = get_optim_scheduler(search_model.get_weights(), config) - a_optimizer = torch.optim.Adam(search_model.get_alphas(), lr=xargs.arch_learning_rate, betas=(0.5, 0.999), weight_decay=xargs.arch_weight_decay) - logger.log('w-optimizer : {:}'.format(w_optimizer)) - logger.log('a-optimizer : {:}'.format(a_optimizer)) - logger.log('w-scheduler : {:}'.format(w_scheduler)) - logger.log('criterion : {:}'.format(criterion)) - flop, param = get_model_infos(search_model, xshape) - #logger.log('{:}'.format(search_model)) - logger.log('FLOP = {:.2f} M, Params = {:.2f} MB'.format(flop, param)) - if xargs.arch_nas_dataset is None: - api = None - else: - api = API(xargs.arch_nas_dataset) - logger.log('{:} create API = {:} done'.format(time_string(), api)) - - last_info, model_base_path, model_best_path = logger.path('info'), logger.path('model'), logger.path('best') - network, criterion = torch.nn.DataParallel(search_model).cuda(), criterion.cuda() - - logger.close() - def check_unique_arch(meta_file): api = API(str(meta_file)) diff --git a/exps/NAS-Bench-201/test-weights.py b/exps/NAS-Bench-201/test-weights.py new file mode 100644 index 0000000..b8182c6 --- /dev/null +++ b/exps/NAS-Bench-201/test-weights.py @@ -0,0 +1,36 @@ +##################################################### +# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.08 # +######################################################## +# python exps/NAS-Bench-201/test-weights.py --api_path $HOME/.torch/NAS-Bench-201-v1_0-e61699.pth +######################################################## +import os, sys, time, glob, random, argparse +import numpy as np +import torch +from pathlib import Path +lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve() +if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) +from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint, get_optim_scheduler +from nas_201_api import NASBench201API as API +from utils import weight_watcher + + +def main(meta_file, weight_dir, save_dir): + import pdb; + pdb.set_trace() + + +if __name__ == '__main__': + parser = argparse.ArgumentParser("Analysis of NAS-Bench-201") + parser.add_argument('--save_dir', type=str, default='./output/search-cell-nas-bench-201/visuals', help='The base-name of folder to save checkpoints and log.') + parser.add_argument('--api_path', type=str, default=None, help='The path to the NAS-Bench-201 benchmark file.') + parser.add_argument('--weight_dir', type=str, default=None, help='The directory path to the weights of every NAS-Bench-201 architecture.') + args = parser.parse_args() + + save_dir = Path(args.save_dir) + save_dir.mkdir(parents=True, exist_ok=True) + meta_file = Path(args.api_path) + weight_dir = Path(args.weight_dir) + assert meta_file.exists(), 'invalid path for api : {:}'.format(meta_file) + + main(meta_file, weight_dir, save_dir) + diff --git a/exps/experimental/test-ww.py b/exps/experimental/test-ww.py index 5d2813f..dec4e34 100644 --- a/exps/experimental/test-ww.py +++ b/exps/experimental/test-ww.py @@ -9,12 +9,23 @@ from utils import weight_watcher def main(): - model = models.vgg19_bn(pretrained=True) - _, summary = weight_watcher.analyze(model, alphas=False) - # print(summary) - for key, value in summary.items(): - print('{:10s} : {:}'.format(key, value)) - # import pdb; pdb.set_trace() + # model = models.vgg19_bn(pretrained=True) + # _, summary = weight_watcher.analyze(model, alphas=False) + # for key, value in summary.items(): + # print('{:10s} : {:}'.format(key, value)) + + _, summary = weight_watcher.analyze(models.vgg13(pretrained=True), alphas=False) + print('vgg-13 : {:}'.format(summary['lognorm'])) + _, summary = weight_watcher.analyze(models.vgg13_bn(pretrained=True), alphas=False) + print('vgg-13-BN : {:}'.format(summary['lognorm'])) + _, summary = weight_watcher.analyze(models.vgg16(pretrained=True), alphas=False) + print('vgg-16 : {:}'.format(summary['lognorm'])) + _, summary = weight_watcher.analyze(models.vgg16_bn(pretrained=True), alphas=False) + print('vgg-16-BN : {:}'.format(summary['lognorm'])) + _, summary = weight_watcher.analyze(models.vgg19(pretrained=True), alphas=False) + print('vgg-19 : {:}'.format(summary['lognorm'])) + _, summary = weight_watcher.analyze(models.vgg19_bn(pretrained=True), alphas=False) + print('vgg-19-BN : {:}'.format(summary['lognorm'])) if __name__ == '__main__': diff --git a/lib/utils/weight_watcher.py b/lib/utils/weight_watcher.py index f675c72..a132c44 100644 --- a/lib/utils/weight_watcher.py +++ b/lib/utils/weight_watcher.py @@ -304,7 +304,7 @@ def analyze(model: nn.Module, min_size=50, max_size=0, if isinstance(module, available_module_types()): names.append(name) modules.append(module) - print('There are {:} layers to be analyzed in this model.'.format(len(modules))) + # print('There are {:} layers to be analyzed in this model.'.format(len(modules))) all_results = OrderedDict() for index, module in enumerate(modules): if isinstance(module, nn.Linear):