Fix the potential memory leak in NAS-Bench-201 clear_param
This commit is contained in:
		
							
								
								
									
										2
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										2
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							| @@ -121,3 +121,5 @@ lib/NAS-Bench-*-v1_0.pth | ||||
| others/TF | ||||
| scripts-search/l2s-algos | ||||
| TEMP-L.sh | ||||
|  | ||||
| .nfs00* | ||||
|   | ||||
| @@ -1,7 +1,7 @@ | ||||
| ##################################################### | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.08 # | ||||
| ##################################################### | ||||
| import os, sys, time, torch | ||||
| import time, torch | ||||
| from procedures   import prepare_seed, get_optim_scheduler | ||||
| from utils        import get_model_infos, obtain_accuracy | ||||
| from config_utils import dict2config | ||||
| @@ -9,11 +9,9 @@ from log_utils    import AverageMeter, time_string, convert_secs2time | ||||
| from models       import get_cell_based_tiny_net | ||||
|  | ||||
|  | ||||
|  | ||||
| __all__ = ['evaluate_for_seed', 'pure_evaluate'] | ||||
|  | ||||
|  | ||||
|  | ||||
| def pure_evaluate(xloader, network, criterion=torch.nn.CrossEntropyLoss()): | ||||
|   data_time, batch_time, batch = AverageMeter(), AverageMeter(), None | ||||
|   losses, top1, top5 = AverageMeter(), AverageMeter(), AverageMeter() | ||||
|   | ||||
| @@ -28,7 +28,7 @@ def evaluate_all_datasets(arch, datasets, xpaths, splits, use_less, seed, arch_c | ||||
|   for dataset, xpath, split in zip(datasets, xpaths, splits): | ||||
|     # train valid data | ||||
|     train_data, valid_data, xshape, class_num = get_datasets(dataset, xpath, -1) | ||||
|     # load the configurature | ||||
|     # load the configuration | ||||
|     if dataset == 'cifar10' or dataset == 'cifar100': | ||||
|       if use_less: config_path = 'configs/nas-benchmark/LESS.config' | ||||
|       else       : config_path = 'configs/nas-benchmark/CIFAR.config' | ||||
|   | ||||
| @@ -3,7 +3,7 @@ | ||||
| ################################################################################################ | ||||
| # python exps/NAS-Bench-201/show-best.py --api_path $HOME/.torch/NAS-Bench-201-v1_0-e61699.pth # | ||||
| ################################################################################################ | ||||
| import os, sys, time, glob, random, argparse | ||||
| import sys, argparse | ||||
| from pathlib import Path | ||||
| lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve() | ||||
| if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) | ||||
|   | ||||
| @@ -6,7 +6,7 @@ | ||||
| # python exps/NAS-Bench-201/test-weights.py --base_path $HOME/.torch/NAS-Bench-201-v1_1-096897 --dataset cifar10-valid --use_12 1 --use_valid 1 | ||||
| # bash ./scripts-search/NAS-Bench-201/test-weights.sh cifar10-valid 1 | ||||
| ############################################################################################### | ||||
| import os, gc, sys, time, glob, random, argparse | ||||
| import os, gc, sys, argparse, psutil | ||||
| import numpy as np | ||||
| import torch | ||||
| from pathlib import Path | ||||
| @@ -33,7 +33,7 @@ def tostr(accdict, norms): | ||||
|  | ||||
| def evaluate(api, weight_dir, data: str, use_12epochs_result: bool): | ||||
|   print('\nEvaluate dataset={:}'.format(data)) | ||||
|   norms = [] | ||||
|   norms, process = [], psutil.Process(os.getpid()) | ||||
|   final_val_accs = OrderedDict({'cifar10': [], 'cifar100': [], 'ImageNet16-120': []}) | ||||
|   final_test_accs = OrderedDict({'cifar10': [], 'cifar100': [], 'ImageNet16-120': []}) | ||||
|   for idx in range(len(api)): | ||||
| @@ -56,16 +56,17 @@ def evaluate(api, weight_dir, data: str, use_12epochs_result: bool): | ||||
|       with torch.no_grad(): | ||||
|         net.load_state_dict(param) | ||||
|         _, summary = weight_watcher.analyze(net, alphas=False) | ||||
|         cur_norms.append( summary['lognorm'] ) | ||||
|         cur_norms.append(summary['lognorm']) | ||||
|     norms.append( float(np.mean(cur_norms)) ) | ||||
|     api.clear_params(idx, use_12epochs_result) | ||||
|     api.clear_params(idx, None) | ||||
|     if idx % 200 == 199 or idx + 1 == len(api): | ||||
|       head = '{:05d}/{:05d}'.format(idx, len(api)) | ||||
|       stem_val = tostr(final_val_accs, norms) | ||||
|       stem_test = tostr(final_test_accs, norms) | ||||
|       print('{:} {:} {:} with {:} epochs on {:} : the correlation is {:.3f}'.format(time_string(), head, data, 12 if use_12epochs_result else 200)) | ||||
|       print('    -->>  {:}  ||  {:}'.format(stem_val, stem_test)) | ||||
|       torch.cuda.empty_cache() ; gc.collect() | ||||
|       print('{:} {:} {:} with {:} epochs ({:.2f} MB memory)'.format(time_string(), head, data, 12 if use_12epochs_result else 200, process.memory_info().rss / 1e6)) | ||||
|       print('  [Valid] -->>  {:}'.format(stem_val)) | ||||
|       print('  [Test.] -->>  {:}'.format(stem_test)) | ||||
|       gc.collect() | ||||
|  | ||||
|  | ||||
| def main(meta_file: str, weight_dir, save_dir, xdata, use_12epochs_result): | ||||
|   | ||||
| @@ -3,7 +3,7 @@ | ||||
| ##################################################### | ||||
| # python exps/NAS-Bench-201/visualize.py --api_path $HOME/.torch/NAS-Bench-201-v1_0-e61699.pth | ||||
| ##################################################### | ||||
| import os, sys, time, argparse, collections | ||||
| import sys, argparse | ||||
| from tqdm import tqdm | ||||
| from collections import OrderedDict | ||||
| import numpy as np | ||||
|   | ||||
| @@ -24,11 +24,11 @@ def evaluate_all_datasets(channels: Text, datasets: List[Text], xpaths: List[Tex | ||||
|   machine_info = get_machine_info() | ||||
|   all_infos = {'info': machine_info} | ||||
|   all_dataset_keys = [] | ||||
|   # look all the datasets | ||||
|   # look all the dataset | ||||
|   for dataset, xpath, split in zip(datasets, xpaths, splits): | ||||
|     # train valid data | ||||
|     # the train and valid data | ||||
|     train_data, valid_data, xshape, class_num = get_datasets(dataset, xpath, -1) | ||||
|     # load the configurature | ||||
|     # load the configuration | ||||
|     if dataset == 'cifar10' or dataset == 'cifar100': | ||||
|       split_info  = load_config('configs/nas-benchmark/cifar-split.txt', None, None) | ||||
|     elif dataset.startswith('ImageNet16'): | ||||
| @@ -36,7 +36,7 @@ def evaluate_all_datasets(channels: Text, datasets: List[Text], xpaths: List[Tex | ||||
|     else: | ||||
|       raise ValueError('invalid dataset : {:}'.format(dataset)) | ||||
|     config = load_config(config_path, dict(class_num=class_num, xshape=xshape), logger) | ||||
|     # check whether use splited validation set | ||||
|     # check whether use the splitted validation set | ||||
|     if bool(split): | ||||
|       assert dataset == 'cifar10' | ||||
|       ValLoaders = {'ori-test': torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size, shuffle=False, num_workers=workers, pin_memory=True)} | ||||
| @@ -92,7 +92,7 @@ def main(save_dir: Path, workers: int, datasets: List[Text], xpaths: List[Text], | ||||
|  | ||||
|   log_dir = save_dir / 'logs' | ||||
|   log_dir.mkdir(parents=True, exist_ok=True) | ||||
|   logger = Logger(str(log_dir), 0, False) | ||||
|   logger = Logger(str(log_dir), os.getpid(), False) | ||||
|  | ||||
|   logger.log('xargs : seeds      = {:}'.format(seeds)) | ||||
|   logger.log('xargs : cover_mode = {:}'.format(cover_mode)) | ||||
|   | ||||
| @@ -114,15 +114,27 @@ class NASBench201API(object): | ||||
|     assert os.path.isfile(xfile_path), 'invalid data path : {:}'.format(xfile_path) | ||||
|     xdata = torch.load(xfile_path, map_location='cpu') | ||||
|     assert isinstance(xdata, dict) and 'full' in xdata and 'less' in xdata, 'invalid format of data in {:}'.format(xfile_path) | ||||
|     if index in self.arch2infos_less: del self.arch2infos_less[index] | ||||
|     if index in self.arch2infos_full: del self.arch2infos_full[index] | ||||
|     self.arch2infos_less[index] = ArchResults.create_from_state_dict( xdata['less'] ) | ||||
|     self.arch2infos_full[index] = ArchResults.create_from_state_dict( xdata['full'] ) | ||||
|  | ||||
|   def clear_params(self, index: int, use_12epochs_result: bool): | ||||
|     """Remove the architecture's weights to save memory.""" | ||||
|     if use_12epochs_result: arch2infos = self.arch2infos_less | ||||
|     else                  : arch2infos = self.arch2infos_full | ||||
|     archresult = arch2infos[index] | ||||
|     archresult.clear_params() | ||||
|   def clear_params(self, index: int, use_12epochs_result: Union[bool, None]): | ||||
|     """Remove the architecture's weights to save memory. | ||||
|     :arg | ||||
|       index: the index of the target architecture | ||||
|       use_12epochs_result: a flag to controll how to clear the parameters. | ||||
|         -- None: clear all the weights in both `less` and `full`, which indicates the training hyper-parameters. | ||||
|         -- True: clear all the weights in arch2infos_less, which by default is 12-epoch-training result. | ||||
|         -- False: clear all the weights in arch2infos_full, which by default is 200-epoch-training result. | ||||
|     """ | ||||
|     if use_12epochs_result is None: | ||||
|       self.arch2infos_less[index].clear_params() | ||||
|       self.arch2infos_full[index].clear_params() | ||||
|     else: | ||||
|       if use_12epochs_result: arch2infos = self.arch2infos_less | ||||
|       else                  : arch2infos = self.arch2infos_full | ||||
|       arch2infos[index].clear_params() | ||||
|    | ||||
|   # This function is used to query the information of a specific archiitecture | ||||
|   # 'arch' can be an architecture index or an architecture string | ||||
| @@ -193,7 +205,6 @@ class NASBench201API(object): | ||||
|         best_index, highest_accuracy = idx, accuracy | ||||
|     return best_index, highest_accuracy | ||||
|  | ||||
|  | ||||
|   def arch(self, index: int): | ||||
|     """Return the topology structure of the `index`-th architecture.""" | ||||
|     assert 0 <= index < len(self.meta_archs), 'invalid index : {:} vs. {:}.'.format(index, len(self.meta_archs)) | ||||
| @@ -213,7 +224,6 @@ class NASBench201API(object): | ||||
|     else: arch2infos = self.arch2infos_full | ||||
|     arch_result = arch2infos[index] | ||||
|     return arch_result.get_net_param(dataset, seed) | ||||
|    | ||||
|  | ||||
|   def get_net_config(self, index: int, dataset: Text): | ||||
|     """ | ||||
| @@ -235,7 +245,6 @@ class NASBench201API(object): | ||||
|       #print ('SEED [{:}] : {:}'.format(seed, result)) | ||||
|     raise ValueError('Impossible to reach here!') | ||||
|  | ||||
|  | ||||
|   def get_cost_info(self, index: int, dataset: Text, use_12epochs_result: bool = False) -> Dict[Text, float]: | ||||
|     """To obtain the cost metric for the `index`-th architecture on a dataset.""" | ||||
|     if use_12epochs_result: arch2infos = self.arch2infos_less | ||||
| @@ -243,7 +252,6 @@ class NASBench201API(object): | ||||
|     arch_result = arch2infos[index] | ||||
|     return arch_result.get_compute_costs(dataset) | ||||
|  | ||||
|  | ||||
|   def get_latency(self, index: int, dataset: Text, use_12epochs_result: bool = False) -> float: | ||||
|     """ | ||||
|     To obtain the latency of the network (by default it will return the latency with the batch size of 256). | ||||
| @@ -254,7 +262,6 @@ class NASBench201API(object): | ||||
|     cost_dict = self.get_cost_info(index, dataset, use_12epochs_result) | ||||
|     return cost_dict['latency'] | ||||
|  | ||||
|  | ||||
|   # obtain the metric for the `index`-th architecture | ||||
|   # `dataset` indicates the dataset: | ||||
|   #   'cifar10-valid'  : using the proposed train set of CIFAR-10 as the training set | ||||
| @@ -388,7 +395,6 @@ class NASBench201API(object): | ||||
|       return xifo | ||||
|   """ | ||||
|  | ||||
|  | ||||
|   def show(self, index: int = -1) -> None: | ||||
|     """ | ||||
|     This function will print the information of a specific (or all) architecture(s). | ||||
| @@ -423,7 +429,6 @@ class NASBench201API(object): | ||||
|       else: | ||||
|         print('This index ({:}) is out of range (0~{:}).'.format(index, len(self.meta_archs))) | ||||
|  | ||||
|  | ||||
|   def statistics(self, dataset: Text, use_12epochs_result: bool) -> Dict[int, int]: | ||||
|     """ | ||||
|     This function will count the number of total trials. | ||||
| @@ -443,7 +448,6 @@ class NASBench201API(object): | ||||
|         nums[len(dataset_seed[dataset])] += 1 | ||||
|     return dict(nums) | ||||
|  | ||||
|  | ||||
|   @staticmethod | ||||
|   def str2lists(arch_str: Text) -> List[tuple]: | ||||
|     """ | ||||
| @@ -471,7 +475,6 @@ class NASBench201API(object): | ||||
|       genotypes.append( input_infos ) | ||||
|     return genotypes | ||||
|  | ||||
|  | ||||
|   @staticmethod | ||||
|   def str2matrix(arch_str: Text, | ||||
|                  search_space: List[Text] = ['none', 'skip_connect', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3']) -> np.ndarray: | ||||
| @@ -511,7 +514,6 @@ class NASBench201API(object): | ||||
|     return matrix | ||||
|  | ||||
|  | ||||
|  | ||||
| class ArchResults(object): | ||||
|  | ||||
|   def __init__(self, arch_index, arch_str): | ||||
| @@ -752,7 +754,6 @@ class ArchResults(object): | ||||
|  | ||||
|   def __repr__(self): | ||||
|     return ('{name}(arch-index={index}, arch={arch}, {num} runs, clear={clear})'.format(name=self.__class__.__name__, index=self.arch_index, arch=self.arch_str, num=len(self.all_results), clear=self.clear_net_done)) | ||||
|      | ||||
|  | ||||
|  | ||||
| """ | ||||
| @@ -872,8 +873,8 @@ class ResultsCount(object): | ||||
|             'cur_time': xtime, | ||||
|             'all_time': atime} | ||||
|  | ||||
|   # get the evaluation information ; there could be multiple evaluation sets (identified by the 'name' argument). | ||||
|   def get_eval(self, name, iepoch=None): | ||||
|     """Get the evaluation information ; there could be multiple evaluation sets (identified by the 'name' argument).""" | ||||
|     if iepoch is None: iepoch = self.epochs-1 | ||||
|     assert 0 <= iepoch < self.epochs, 'invalid iepoch={:} < {:}'.format(iepoch, self.epochs) | ||||
|     if isinstance(self.eval_times,dict) and len(self.eval_times) > 0: | ||||
| @@ -890,8 +891,8 @@ class ResultsCount(object): | ||||
|     if clone: return copy.deepcopy(self.net_state_dict) | ||||
|     else: return self.net_state_dict | ||||
|  | ||||
|   # This function is used to obtain the config dict for this architecture. | ||||
|   def get_config(self, str2structure): | ||||
|     """This function is used to obtain the config dict for this architecture.""" | ||||
|     if str2structure is None: | ||||
|       return {'name': 'infer.tiny', 'C': self.arch_config['channel'], | ||||
|               'N'   : self.arch_config['num_cells'], | ||||
|   | ||||
| @@ -15,7 +15,7 @@ else | ||||
|   echo "TORCH_HOME : $TORCH_HOME" | ||||
| fi | ||||
|  | ||||
| OMP_NUM_THREADS=4 python exps/NAS-Bench-201/test-weights.py \ | ||||
| CUDA_VISIBLE_DEVICES='' OMP_NUM_THREADS=4 python exps/NAS-Bench-201/test-weights.py \ | ||||
| 	--base_path $HOME/.torch/NAS-Bench-201-v1_1-096897 \ | ||||
| 	--dataset $1 \ | ||||
| 	--use_12 $2 | ||||
|   | ||||
		Reference in New Issue
	
	Block a user