Fix the potential memory leak in NAS-Bench-201 clear_params
This commit is contained in:
		| @@ -1,7 +1,7 @@ | ||||
| ##################################################### | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.08 # | ||||
| ##################################################### | ||||
| import os, sys, time, torch | ||||
| import time, torch | ||||
| from procedures   import prepare_seed, get_optim_scheduler | ||||
| from utils        import get_model_infos, obtain_accuracy | ||||
| from config_utils import dict2config | ||||
| @@ -9,11 +9,9 @@ from log_utils    import AverageMeter, time_string, convert_secs2time | ||||
| from models       import get_cell_based_tiny_net | ||||
|  | ||||
|  | ||||
|  | ||||
| __all__ = ['evaluate_for_seed', 'pure_evaluate'] | ||||
|  | ||||
|  | ||||
|  | ||||
| def pure_evaluate(xloader, network, criterion=torch.nn.CrossEntropyLoss()): | ||||
|   data_time, batch_time, batch = AverageMeter(), AverageMeter(), None | ||||
|   losses, top1, top5 = AverageMeter(), AverageMeter(), AverageMeter() | ||||
|   | ||||
| @@ -28,7 +28,7 @@ def evaluate_all_datasets(arch, datasets, xpaths, splits, use_less, seed, arch_c | ||||
|   for dataset, xpath, split in zip(datasets, xpaths, splits): | ||||
|     # train valid data | ||||
|     train_data, valid_data, xshape, class_num = get_datasets(dataset, xpath, -1) | ||||
|     # load the configurature | ||||
|     # load the configuration | ||||
|     if dataset == 'cifar10' or dataset == 'cifar100': | ||||
|       if use_less: config_path = 'configs/nas-benchmark/LESS.config' | ||||
|       else       : config_path = 'configs/nas-benchmark/CIFAR.config' | ||||
|   | ||||
| @@ -3,7 +3,7 @@ | ||||
| ################################################################################################ | ||||
| # python exps/NAS-Bench-201/show-best.py --api_path $HOME/.torch/NAS-Bench-201-v1_0-e61699.pth # | ||||
| ################################################################################################ | ||||
| import os, sys, time, glob, random, argparse | ||||
| import sys, argparse | ||||
| from pathlib import Path | ||||
| lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve() | ||||
| if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) | ||||
|   | ||||
| @@ -6,7 +6,7 @@ | ||||
| # python exps/NAS-Bench-201/test-weights.py --base_path $HOME/.torch/NAS-Bench-201-v1_1-096897 --dataset cifar10-valid --use_12 1 --use_valid 1 | ||||
| # bash ./scripts-search/NAS-Bench-201/test-weights.sh cifar10-valid 1 | ||||
| ############################################################################################### | ||||
| import os, gc, sys, time, glob, random, argparse | ||||
| import os, gc, sys, argparse, psutil | ||||
| import numpy as np | ||||
| import torch | ||||
| from pathlib import Path | ||||
| @@ -33,7 +33,7 @@ def tostr(accdict, norms): | ||||
|  | ||||
| def evaluate(api, weight_dir, data: str, use_12epochs_result: bool): | ||||
|   print('\nEvaluate dataset={:}'.format(data)) | ||||
|   norms = [] | ||||
|   norms, process = [], psutil.Process(os.getpid()) | ||||
|   final_val_accs = OrderedDict({'cifar10': [], 'cifar100': [], 'ImageNet16-120': []}) | ||||
|   final_test_accs = OrderedDict({'cifar10': [], 'cifar100': [], 'ImageNet16-120': []}) | ||||
|   for idx in range(len(api)): | ||||
| @@ -56,16 +56,17 @@ def evaluate(api, weight_dir, data: str, use_12epochs_result: bool): | ||||
|       with torch.no_grad(): | ||||
|         net.load_state_dict(param) | ||||
|         _, summary = weight_watcher.analyze(net, alphas=False) | ||||
|         cur_norms.append( summary['lognorm'] ) | ||||
|         cur_norms.append(summary['lognorm']) | ||||
|     norms.append( float(np.mean(cur_norms)) ) | ||||
|     api.clear_params(idx, use_12epochs_result) | ||||
|     api.clear_params(idx, None) | ||||
|     if idx % 200 == 199 or idx + 1 == len(api): | ||||
|       head = '{:05d}/{:05d}'.format(idx, len(api)) | ||||
|       stem_val = tostr(final_val_accs, norms) | ||||
|       stem_test = tostr(final_test_accs, norms) | ||||
|       print('{:} {:} {:} with {:} epochs on {:} : the correlation is {:.3f}'.format(time_string(), head, data, 12 if use_12epochs_result else 200)) | ||||
|       print('    -->>  {:}  ||  {:}'.format(stem_val, stem_test)) | ||||
|       torch.cuda.empty_cache() ; gc.collect() | ||||
|       print('{:} {:} {:} with {:} epochs ({:.2f} MB memory)'.format(time_string(), head, data, 12 if use_12epochs_result else 200, process.memory_info().rss / 1e6)) | ||||
|       print('  [Valid] -->>  {:}'.format(stem_val)) | ||||
|       print('  [Test.] -->>  {:}'.format(stem_test)) | ||||
|       gc.collect() | ||||
|  | ||||
|  | ||||
| def main(meta_file: str, weight_dir, save_dir, xdata, use_12epochs_result): | ||||
|   | ||||
| @@ -3,7 +3,7 @@ | ||||
| ##################################################### | ||||
| # python exps/NAS-Bench-201/visualize.py --api_path $HOME/.torch/NAS-Bench-201-v1_0-e61699.pth | ||||
| ##################################################### | ||||
| import os, sys, time, argparse, collections | ||||
| import sys, argparse | ||||
| from tqdm import tqdm | ||||
| from collections import OrderedDict | ||||
| import numpy as np | ||||
|   | ||||
| @@ -24,11 +24,11 @@ def evaluate_all_datasets(channels: Text, datasets: List[Text], xpaths: List[Tex | ||||
|   machine_info = get_machine_info() | ||||
|   all_infos = {'info': machine_info} | ||||
|   all_dataset_keys = [] | ||||
|   # look all the datasets | ||||
|   # look all the dataset | ||||
|   for dataset, xpath, split in zip(datasets, xpaths, splits): | ||||
|     # train valid data | ||||
|     # the train and valid data | ||||
|     train_data, valid_data, xshape, class_num = get_datasets(dataset, xpath, -1) | ||||
|     # load the configurature | ||||
|     # load the configuration | ||||
|     if dataset == 'cifar10' or dataset == 'cifar100': | ||||
|       split_info  = load_config('configs/nas-benchmark/cifar-split.txt', None, None) | ||||
|     elif dataset.startswith('ImageNet16'): | ||||
| @@ -36,7 +36,7 @@ def evaluate_all_datasets(channels: Text, datasets: List[Text], xpaths: List[Tex | ||||
|     else: | ||||
|       raise ValueError('invalid dataset : {:}'.format(dataset)) | ||||
|     config = load_config(config_path, dict(class_num=class_num, xshape=xshape), logger) | ||||
|     # check whether use splited validation set | ||||
|     # check whether use the splitted validation set | ||||
|     if bool(split): | ||||
|       assert dataset == 'cifar10' | ||||
|       ValLoaders = {'ori-test': torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size, shuffle=False, num_workers=workers, pin_memory=True)} | ||||
| @@ -92,7 +92,7 @@ def main(save_dir: Path, workers: int, datasets: List[Text], xpaths: List[Text], | ||||
|  | ||||
|   log_dir = save_dir / 'logs' | ||||
|   log_dir.mkdir(parents=True, exist_ok=True) | ||||
|   logger = Logger(str(log_dir), 0, False) | ||||
|   logger = Logger(str(log_dir), os.getpid(), False) | ||||
|  | ||||
|   logger.log('xargs : seeds      = {:}'.format(seeds)) | ||||
|   logger.log('xargs : cover_mode = {:}'.format(cover_mode)) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user