Fix the potential memory leak in NAS-Bench-201 clear_param

This commit is contained in:
D-X-Y 2020-03-21 01:33:07 -07:00
parent b702ddf5a2
commit 22025887f1
9 changed files with 40 additions and 38 deletions

2
.gitignore vendored
View File

@ -121,3 +121,5 @@ lib/NAS-Bench-*-v1_0.pth
others/TF others/TF
scripts-search/l2s-algos scripts-search/l2s-algos
TEMP-L.sh TEMP-L.sh
.nfs00*

View File

@ -1,7 +1,7 @@
##################################################### #####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.08 # # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.08 #
##################################################### #####################################################
import os, sys, time, torch import time, torch
from procedures import prepare_seed, get_optim_scheduler from procedures import prepare_seed, get_optim_scheduler
from utils import get_model_infos, obtain_accuracy from utils import get_model_infos, obtain_accuracy
from config_utils import dict2config from config_utils import dict2config
@ -9,11 +9,9 @@ from log_utils import AverageMeter, time_string, convert_secs2time
from models import get_cell_based_tiny_net from models import get_cell_based_tiny_net
__all__ = ['evaluate_for_seed', 'pure_evaluate'] __all__ = ['evaluate_for_seed', 'pure_evaluate']
def pure_evaluate(xloader, network, criterion=torch.nn.CrossEntropyLoss()): def pure_evaluate(xloader, network, criterion=torch.nn.CrossEntropyLoss()):
data_time, batch_time, batch = AverageMeter(), AverageMeter(), None data_time, batch_time, batch = AverageMeter(), AverageMeter(), None
losses, top1, top5 = AverageMeter(), AverageMeter(), AverageMeter() losses, top1, top5 = AverageMeter(), AverageMeter(), AverageMeter()

View File

@ -28,7 +28,7 @@ def evaluate_all_datasets(arch, datasets, xpaths, splits, use_less, seed, arch_c
for dataset, xpath, split in zip(datasets, xpaths, splits): for dataset, xpath, split in zip(datasets, xpaths, splits):
# train valid data # train valid data
train_data, valid_data, xshape, class_num = get_datasets(dataset, xpath, -1) train_data, valid_data, xshape, class_num = get_datasets(dataset, xpath, -1)
# load the configurature # load the configuration
if dataset == 'cifar10' or dataset == 'cifar100': if dataset == 'cifar10' or dataset == 'cifar100':
if use_less: config_path = 'configs/nas-benchmark/LESS.config' if use_less: config_path = 'configs/nas-benchmark/LESS.config'
else : config_path = 'configs/nas-benchmark/CIFAR.config' else : config_path = 'configs/nas-benchmark/CIFAR.config'

View File

@ -3,7 +3,7 @@
################################################################################################ ################################################################################################
# python exps/NAS-Bench-201/show-best.py --api_path $HOME/.torch/NAS-Bench-201-v1_0-e61699.pth # # python exps/NAS-Bench-201/show-best.py --api_path $HOME/.torch/NAS-Bench-201-v1_0-e61699.pth #
################################################################################################ ################################################################################################
import os, sys, time, glob, random, argparse import sys, argparse
from pathlib import Path from pathlib import Path
lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve() lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve()
if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))

View File

@ -6,7 +6,7 @@
# python exps/NAS-Bench-201/test-weights.py --base_path $HOME/.torch/NAS-Bench-201-v1_1-096897 --dataset cifar10-valid --use_12 1 --use_valid 1 # python exps/NAS-Bench-201/test-weights.py --base_path $HOME/.torch/NAS-Bench-201-v1_1-096897 --dataset cifar10-valid --use_12 1 --use_valid 1
# bash ./scripts-search/NAS-Bench-201/test-weights.sh cifar10-valid 1 # bash ./scripts-search/NAS-Bench-201/test-weights.sh cifar10-valid 1
############################################################################################### ###############################################################################################
import os, gc, sys, time, glob, random, argparse import os, gc, sys, argparse, psutil
import numpy as np import numpy as np
import torch import torch
from pathlib import Path from pathlib import Path
@ -33,7 +33,7 @@ def tostr(accdict, norms):
def evaluate(api, weight_dir, data: str, use_12epochs_result: bool): def evaluate(api, weight_dir, data: str, use_12epochs_result: bool):
print('\nEvaluate dataset={:}'.format(data)) print('\nEvaluate dataset={:}'.format(data))
norms = [] norms, process = [], psutil.Process(os.getpid())
final_val_accs = OrderedDict({'cifar10': [], 'cifar100': [], 'ImageNet16-120': []}) final_val_accs = OrderedDict({'cifar10': [], 'cifar100': [], 'ImageNet16-120': []})
final_test_accs = OrderedDict({'cifar10': [], 'cifar100': [], 'ImageNet16-120': []}) final_test_accs = OrderedDict({'cifar10': [], 'cifar100': [], 'ImageNet16-120': []})
for idx in range(len(api)): for idx in range(len(api)):
@ -56,16 +56,17 @@ def evaluate(api, weight_dir, data: str, use_12epochs_result: bool):
with torch.no_grad(): with torch.no_grad():
net.load_state_dict(param) net.load_state_dict(param)
_, summary = weight_watcher.analyze(net, alphas=False) _, summary = weight_watcher.analyze(net, alphas=False)
cur_norms.append( summary['lognorm'] ) cur_norms.append(summary['lognorm'])
norms.append( float(np.mean(cur_norms)) ) norms.append( float(np.mean(cur_norms)) )
api.clear_params(idx, use_12epochs_result) api.clear_params(idx, None)
if idx % 200 == 199 or idx + 1 == len(api): if idx % 200 == 199 or idx + 1 == len(api):
head = '{:05d}/{:05d}'.format(idx, len(api)) head = '{:05d}/{:05d}'.format(idx, len(api))
stem_val = tostr(final_val_accs, norms) stem_val = tostr(final_val_accs, norms)
stem_test = tostr(final_test_accs, norms) stem_test = tostr(final_test_accs, norms)
print('{:} {:} {:} with {:} epochs on {:} : the correlation is {:.3f}'.format(time_string(), head, data, 12 if use_12epochs_result else 200)) print('{:} {:} {:} with {:} epochs ({:.2f} MB memory)'.format(time_string(), head, data, 12 if use_12epochs_result else 200, process.memory_info().rss / 1e6))
print(' -->> {:} || {:}'.format(stem_val, stem_test)) print(' [Valid] -->> {:}'.format(stem_val))
torch.cuda.empty_cache() ; gc.collect() print(' [Test.] -->> {:}'.format(stem_test))
gc.collect()
def main(meta_file: str, weight_dir, save_dir, xdata, use_12epochs_result): def main(meta_file: str, weight_dir, save_dir, xdata, use_12epochs_result):

View File

@ -3,7 +3,7 @@
##################################################### #####################################################
# python exps/NAS-Bench-201/visualize.py --api_path $HOME/.torch/NAS-Bench-201-v1_0-e61699.pth # python exps/NAS-Bench-201/visualize.py --api_path $HOME/.torch/NAS-Bench-201-v1_0-e61699.pth
##################################################### #####################################################
import os, sys, time, argparse, collections import sys, argparse
from tqdm import tqdm from tqdm import tqdm
from collections import OrderedDict from collections import OrderedDict
import numpy as np import numpy as np

View File

@ -24,11 +24,11 @@ def evaluate_all_datasets(channels: Text, datasets: List[Text], xpaths: List[Tex
machine_info = get_machine_info() machine_info = get_machine_info()
all_infos = {'info': machine_info} all_infos = {'info': machine_info}
all_dataset_keys = [] all_dataset_keys = []
# look all the datasets # look all the dataset
for dataset, xpath, split in zip(datasets, xpaths, splits): for dataset, xpath, split in zip(datasets, xpaths, splits):
# train valid data # the train and valid data
train_data, valid_data, xshape, class_num = get_datasets(dataset, xpath, -1) train_data, valid_data, xshape, class_num = get_datasets(dataset, xpath, -1)
# load the configurature # load the configuration
if dataset == 'cifar10' or dataset == 'cifar100': if dataset == 'cifar10' or dataset == 'cifar100':
split_info = load_config('configs/nas-benchmark/cifar-split.txt', None, None) split_info = load_config('configs/nas-benchmark/cifar-split.txt', None, None)
elif dataset.startswith('ImageNet16'): elif dataset.startswith('ImageNet16'):
@ -36,7 +36,7 @@ def evaluate_all_datasets(channels: Text, datasets: List[Text], xpaths: List[Tex
else: else:
raise ValueError('invalid dataset : {:}'.format(dataset)) raise ValueError('invalid dataset : {:}'.format(dataset))
config = load_config(config_path, dict(class_num=class_num, xshape=xshape), logger) config = load_config(config_path, dict(class_num=class_num, xshape=xshape), logger)
# check whether use splited validation set # check whether use the splitted validation set
if bool(split): if bool(split):
assert dataset == 'cifar10' assert dataset == 'cifar10'
ValLoaders = {'ori-test': torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size, shuffle=False, num_workers=workers, pin_memory=True)} ValLoaders = {'ori-test': torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size, shuffle=False, num_workers=workers, pin_memory=True)}
@ -92,7 +92,7 @@ def main(save_dir: Path, workers: int, datasets: List[Text], xpaths: List[Text],
log_dir = save_dir / 'logs' log_dir = save_dir / 'logs'
log_dir.mkdir(parents=True, exist_ok=True) log_dir.mkdir(parents=True, exist_ok=True)
logger = Logger(str(log_dir), 0, False) logger = Logger(str(log_dir), os.getpid(), False)
logger.log('xargs : seeds = {:}'.format(seeds)) logger.log('xargs : seeds = {:}'.format(seeds))
logger.log('xargs : cover_mode = {:}'.format(cover_mode)) logger.log('xargs : cover_mode = {:}'.format(cover_mode))

View File

@ -114,15 +114,27 @@ class NASBench201API(object):
assert os.path.isfile(xfile_path), 'invalid data path : {:}'.format(xfile_path) assert os.path.isfile(xfile_path), 'invalid data path : {:}'.format(xfile_path)
xdata = torch.load(xfile_path, map_location='cpu') xdata = torch.load(xfile_path, map_location='cpu')
assert isinstance(xdata, dict) and 'full' in xdata and 'less' in xdata, 'invalid format of data in {:}'.format(xfile_path) assert isinstance(xdata, dict) and 'full' in xdata and 'less' in xdata, 'invalid format of data in {:}'.format(xfile_path)
if index in self.arch2infos_less: del self.arch2infos_less[index]
if index in self.arch2infos_full: del self.arch2infos_full[index]
self.arch2infos_less[index] = ArchResults.create_from_state_dict( xdata['less'] ) self.arch2infos_less[index] = ArchResults.create_from_state_dict( xdata['less'] )
self.arch2infos_full[index] = ArchResults.create_from_state_dict( xdata['full'] ) self.arch2infos_full[index] = ArchResults.create_from_state_dict( xdata['full'] )
def clear_params(self, index: int, use_12epochs_result: bool): def clear_params(self, index: int, use_12epochs_result: Union[bool, None]):
"""Remove the architecture's weights to save memory.""" """Remove the architecture's weights to save memory.
if use_12epochs_result: arch2infos = self.arch2infos_less :arg
else : arch2infos = self.arch2infos_full index: the index of the target architecture
archresult = arch2infos[index] use_12epochs_result: a flag to controll how to clear the parameters.
archresult.clear_params() -- None: clear all the weights in both `less` and `full`, which indicates the training hyper-parameters.
-- True: clear all the weights in arch2infos_less, which by default is 12-epoch-training result.
-- False: clear all the weights in arch2infos_full, which by default is 200-epoch-training result.
"""
if use_12epochs_result is None:
self.arch2infos_less[index].clear_params()
self.arch2infos_full[index].clear_params()
else:
if use_12epochs_result: arch2infos = self.arch2infos_less
else : arch2infos = self.arch2infos_full
arch2infos[index].clear_params()
# This function is used to query the information of a specific archiitecture # This function is used to query the information of a specific archiitecture
# 'arch' can be an architecture index or an architecture string # 'arch' can be an architecture index or an architecture string
@ -193,7 +205,6 @@ class NASBench201API(object):
best_index, highest_accuracy = idx, accuracy best_index, highest_accuracy = idx, accuracy
return best_index, highest_accuracy return best_index, highest_accuracy
def arch(self, index: int): def arch(self, index: int):
"""Return the topology structure of the `index`-th architecture.""" """Return the topology structure of the `index`-th architecture."""
assert 0 <= index < len(self.meta_archs), 'invalid index : {:} vs. {:}.'.format(index, len(self.meta_archs)) assert 0 <= index < len(self.meta_archs), 'invalid index : {:} vs. {:}.'.format(index, len(self.meta_archs))
@ -214,7 +225,6 @@ class NASBench201API(object):
arch_result = arch2infos[index] arch_result = arch2infos[index]
return arch_result.get_net_param(dataset, seed) return arch_result.get_net_param(dataset, seed)
def get_net_config(self, index: int, dataset: Text): def get_net_config(self, index: int, dataset: Text):
""" """
This function is used to obtain the configuration for the `index`-th architecture on `dataset`. This function is used to obtain the configuration for the `index`-th architecture on `dataset`.
@ -235,7 +245,6 @@ class NASBench201API(object):
#print ('SEED [{:}] : {:}'.format(seed, result)) #print ('SEED [{:}] : {:}'.format(seed, result))
raise ValueError('Impossible to reach here!') raise ValueError('Impossible to reach here!')
def get_cost_info(self, index: int, dataset: Text, use_12epochs_result: bool = False) -> Dict[Text, float]: def get_cost_info(self, index: int, dataset: Text, use_12epochs_result: bool = False) -> Dict[Text, float]:
"""To obtain the cost metric for the `index`-th architecture on a dataset.""" """To obtain the cost metric for the `index`-th architecture on a dataset."""
if use_12epochs_result: arch2infos = self.arch2infos_less if use_12epochs_result: arch2infos = self.arch2infos_less
@ -243,7 +252,6 @@ class NASBench201API(object):
arch_result = arch2infos[index] arch_result = arch2infos[index]
return arch_result.get_compute_costs(dataset) return arch_result.get_compute_costs(dataset)
def get_latency(self, index: int, dataset: Text, use_12epochs_result: bool = False) -> float: def get_latency(self, index: int, dataset: Text, use_12epochs_result: bool = False) -> float:
""" """
To obtain the latency of the network (by default it will return the latency with the batch size of 256). To obtain the latency of the network (by default it will return the latency with the batch size of 256).
@ -254,7 +262,6 @@ class NASBench201API(object):
cost_dict = self.get_cost_info(index, dataset, use_12epochs_result) cost_dict = self.get_cost_info(index, dataset, use_12epochs_result)
return cost_dict['latency'] return cost_dict['latency']
# obtain the metric for the `index`-th architecture # obtain the metric for the `index`-th architecture
# `dataset` indicates the dataset: # `dataset` indicates the dataset:
# 'cifar10-valid' : using the proposed train set of CIFAR-10 as the training set # 'cifar10-valid' : using the proposed train set of CIFAR-10 as the training set
@ -388,7 +395,6 @@ class NASBench201API(object):
return xifo return xifo
""" """
def show(self, index: int = -1) -> None: def show(self, index: int = -1) -> None:
""" """
This function will print the information of a specific (or all) architecture(s). This function will print the information of a specific (or all) architecture(s).
@ -423,7 +429,6 @@ class NASBench201API(object):
else: else:
print('This index ({:}) is out of range (0~{:}).'.format(index, len(self.meta_archs))) print('This index ({:}) is out of range (0~{:}).'.format(index, len(self.meta_archs)))
def statistics(self, dataset: Text, use_12epochs_result: bool) -> Dict[int, int]: def statistics(self, dataset: Text, use_12epochs_result: bool) -> Dict[int, int]:
""" """
This function will count the number of total trials. This function will count the number of total trials.
@ -443,7 +448,6 @@ class NASBench201API(object):
nums[len(dataset_seed[dataset])] += 1 nums[len(dataset_seed[dataset])] += 1
return dict(nums) return dict(nums)
@staticmethod @staticmethod
def str2lists(arch_str: Text) -> List[tuple]: def str2lists(arch_str: Text) -> List[tuple]:
""" """
@ -471,7 +475,6 @@ class NASBench201API(object):
genotypes.append( input_infos ) genotypes.append( input_infos )
return genotypes return genotypes
@staticmethod @staticmethod
def str2matrix(arch_str: Text, def str2matrix(arch_str: Text,
search_space: List[Text] = ['none', 'skip_connect', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3']) -> np.ndarray: search_space: List[Text] = ['none', 'skip_connect', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3']) -> np.ndarray:
@ -511,7 +514,6 @@ class NASBench201API(object):
return matrix return matrix
class ArchResults(object): class ArchResults(object):
def __init__(self, arch_index, arch_str): def __init__(self, arch_index, arch_str):
@ -754,7 +756,6 @@ class ArchResults(object):
return ('{name}(arch-index={index}, arch={arch}, {num} runs, clear={clear})'.format(name=self.__class__.__name__, index=self.arch_index, arch=self.arch_str, num=len(self.all_results), clear=self.clear_net_done)) return ('{name}(arch-index={index}, arch={arch}, {num} runs, clear={clear})'.format(name=self.__class__.__name__, index=self.arch_index, arch=self.arch_str, num=len(self.all_results), clear=self.clear_net_done))
""" """
This class (ResultsCount) is used to save the information of one trial for a single architecture. This class (ResultsCount) is used to save the information of one trial for a single architecture.
I did not write much comment for this class, because it is the lowest-level class in NAS-Bench-201 API, which will be rarely called. I did not write much comment for this class, because it is the lowest-level class in NAS-Bench-201 API, which will be rarely called.
@ -872,8 +873,8 @@ class ResultsCount(object):
'cur_time': xtime, 'cur_time': xtime,
'all_time': atime} 'all_time': atime}
# get the evaluation information ; there could be multiple evaluation sets (identified by the 'name' argument).
def get_eval(self, name, iepoch=None): def get_eval(self, name, iepoch=None):
"""Get the evaluation information ; there could be multiple evaluation sets (identified by the 'name' argument)."""
if iepoch is None: iepoch = self.epochs-1 if iepoch is None: iepoch = self.epochs-1
assert 0 <= iepoch < self.epochs, 'invalid iepoch={:} < {:}'.format(iepoch, self.epochs) assert 0 <= iepoch < self.epochs, 'invalid iepoch={:} < {:}'.format(iepoch, self.epochs)
if isinstance(self.eval_times,dict) and len(self.eval_times) > 0: if isinstance(self.eval_times,dict) and len(self.eval_times) > 0:
@ -890,8 +891,8 @@ class ResultsCount(object):
if clone: return copy.deepcopy(self.net_state_dict) if clone: return copy.deepcopy(self.net_state_dict)
else: return self.net_state_dict else: return self.net_state_dict
# This function is used to obtain the config dict for this architecture.
def get_config(self, str2structure): def get_config(self, str2structure):
"""This function is used to obtain the config dict for this architecture."""
if str2structure is None: if str2structure is None:
return {'name': 'infer.tiny', 'C': self.arch_config['channel'], return {'name': 'infer.tiny', 'C': self.arch_config['channel'],
'N' : self.arch_config['num_cells'], 'N' : self.arch_config['num_cells'],

View File

@ -15,7 +15,7 @@ else
echo "TORCH_HOME : $TORCH_HOME" echo "TORCH_HOME : $TORCH_HOME"
fi fi
OMP_NUM_THREADS=4 python exps/NAS-Bench-201/test-weights.py \ CUDA_VISIBLE_DEVICES='' OMP_NUM_THREADS=4 python exps/NAS-Bench-201/test-weights.py \
--base_path $HOME/.torch/NAS-Bench-201-v1_1-096897 \ --base_path $HOME/.torch/NAS-Bench-201-v1_1-096897 \
--dataset $1 \ --dataset $1 \
--use_12 $2 --use_12 $2