Fix the potential memory leak in NAS-Bench-201 clear_param
This commit is contained in:
parent
b702ddf5a2
commit
22025887f1
2
.gitignore
vendored
2
.gitignore
vendored
@ -121,3 +121,5 @@ lib/NAS-Bench-*-v1_0.pth
|
||||
others/TF
|
||||
scripts-search/l2s-algos
|
||||
TEMP-L.sh
|
||||
|
||||
.nfs00*
|
||||
|
@ -1,7 +1,7 @@
|
||||
#####################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.08 #
|
||||
#####################################################
|
||||
import os, sys, time, torch
|
||||
import time, torch
|
||||
from procedures import prepare_seed, get_optim_scheduler
|
||||
from utils import get_model_infos, obtain_accuracy
|
||||
from config_utils import dict2config
|
||||
@ -9,11 +9,9 @@ from log_utils import AverageMeter, time_string, convert_secs2time
|
||||
from models import get_cell_based_tiny_net
|
||||
|
||||
|
||||
|
||||
__all__ = ['evaluate_for_seed', 'pure_evaluate']
|
||||
|
||||
|
||||
|
||||
def pure_evaluate(xloader, network, criterion=torch.nn.CrossEntropyLoss()):
|
||||
data_time, batch_time, batch = AverageMeter(), AverageMeter(), None
|
||||
losses, top1, top5 = AverageMeter(), AverageMeter(), AverageMeter()
|
||||
|
@ -28,7 +28,7 @@ def evaluate_all_datasets(arch, datasets, xpaths, splits, use_less, seed, arch_c
|
||||
for dataset, xpath, split in zip(datasets, xpaths, splits):
|
||||
# train valid data
|
||||
train_data, valid_data, xshape, class_num = get_datasets(dataset, xpath, -1)
|
||||
# load the configurature
|
||||
# load the configuration
|
||||
if dataset == 'cifar10' or dataset == 'cifar100':
|
||||
if use_less: config_path = 'configs/nas-benchmark/LESS.config'
|
||||
else : config_path = 'configs/nas-benchmark/CIFAR.config'
|
||||
|
@ -3,7 +3,7 @@
|
||||
################################################################################################
|
||||
# python exps/NAS-Bench-201/show-best.py --api_path $HOME/.torch/NAS-Bench-201-v1_0-e61699.pth #
|
||||
################################################################################################
|
||||
import os, sys, time, glob, random, argparse
|
||||
import sys, argparse
|
||||
from pathlib import Path
|
||||
lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve()
|
||||
if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
|
||||
|
@ -6,7 +6,7 @@
|
||||
# python exps/NAS-Bench-201/test-weights.py --base_path $HOME/.torch/NAS-Bench-201-v1_1-096897 --dataset cifar10-valid --use_12 1 --use_valid 1
|
||||
# bash ./scripts-search/NAS-Bench-201/test-weights.sh cifar10-valid 1
|
||||
###############################################################################################
|
||||
import os, gc, sys, time, glob, random, argparse
|
||||
import os, gc, sys, argparse, psutil
|
||||
import numpy as np
|
||||
import torch
|
||||
from pathlib import Path
|
||||
@ -33,7 +33,7 @@ def tostr(accdict, norms):
|
||||
|
||||
def evaluate(api, weight_dir, data: str, use_12epochs_result: bool):
|
||||
print('\nEvaluate dataset={:}'.format(data))
|
||||
norms = []
|
||||
norms, process = [], psutil.Process(os.getpid())
|
||||
final_val_accs = OrderedDict({'cifar10': [], 'cifar100': [], 'ImageNet16-120': []})
|
||||
final_test_accs = OrderedDict({'cifar10': [], 'cifar100': [], 'ImageNet16-120': []})
|
||||
for idx in range(len(api)):
|
||||
@ -56,16 +56,17 @@ def evaluate(api, weight_dir, data: str, use_12epochs_result: bool):
|
||||
with torch.no_grad():
|
||||
net.load_state_dict(param)
|
||||
_, summary = weight_watcher.analyze(net, alphas=False)
|
||||
cur_norms.append( summary['lognorm'] )
|
||||
cur_norms.append(summary['lognorm'])
|
||||
norms.append( float(np.mean(cur_norms)) )
|
||||
api.clear_params(idx, use_12epochs_result)
|
||||
api.clear_params(idx, None)
|
||||
if idx % 200 == 199 or idx + 1 == len(api):
|
||||
head = '{:05d}/{:05d}'.format(idx, len(api))
|
||||
stem_val = tostr(final_val_accs, norms)
|
||||
stem_test = tostr(final_test_accs, norms)
|
||||
print('{:} {:} {:} with {:} epochs on {:} : the correlation is {:.3f}'.format(time_string(), head, data, 12 if use_12epochs_result else 200))
|
||||
print(' -->> {:} || {:}'.format(stem_val, stem_test))
|
||||
torch.cuda.empty_cache() ; gc.collect()
|
||||
print('{:} {:} {:} with {:} epochs ({:.2f} MB memory)'.format(time_string(), head, data, 12 if use_12epochs_result else 200, process.memory_info().rss / 1e6))
|
||||
print(' [Valid] -->> {:}'.format(stem_val))
|
||||
print(' [Test.] -->> {:}'.format(stem_test))
|
||||
gc.collect()
|
||||
|
||||
|
||||
def main(meta_file: str, weight_dir, save_dir, xdata, use_12epochs_result):
|
||||
|
@ -3,7 +3,7 @@
|
||||
#####################################################
|
||||
# python exps/NAS-Bench-201/visualize.py --api_path $HOME/.torch/NAS-Bench-201-v1_0-e61699.pth
|
||||
#####################################################
|
||||
import os, sys, time, argparse, collections
|
||||
import sys, argparse
|
||||
from tqdm import tqdm
|
||||
from collections import OrderedDict
|
||||
import numpy as np
|
||||
|
@ -24,11 +24,11 @@ def evaluate_all_datasets(channels: Text, datasets: List[Text], xpaths: List[Tex
|
||||
machine_info = get_machine_info()
|
||||
all_infos = {'info': machine_info}
|
||||
all_dataset_keys = []
|
||||
# look all the datasets
|
||||
# look all the dataset
|
||||
for dataset, xpath, split in zip(datasets, xpaths, splits):
|
||||
# train valid data
|
||||
# the train and valid data
|
||||
train_data, valid_data, xshape, class_num = get_datasets(dataset, xpath, -1)
|
||||
# load the configurature
|
||||
# load the configuration
|
||||
if dataset == 'cifar10' or dataset == 'cifar100':
|
||||
split_info = load_config('configs/nas-benchmark/cifar-split.txt', None, None)
|
||||
elif dataset.startswith('ImageNet16'):
|
||||
@ -36,7 +36,7 @@ def evaluate_all_datasets(channels: Text, datasets: List[Text], xpaths: List[Tex
|
||||
else:
|
||||
raise ValueError('invalid dataset : {:}'.format(dataset))
|
||||
config = load_config(config_path, dict(class_num=class_num, xshape=xshape), logger)
|
||||
# check whether use splited validation set
|
||||
# check whether use the splitted validation set
|
||||
if bool(split):
|
||||
assert dataset == 'cifar10'
|
||||
ValLoaders = {'ori-test': torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size, shuffle=False, num_workers=workers, pin_memory=True)}
|
||||
@ -92,7 +92,7 @@ def main(save_dir: Path, workers: int, datasets: List[Text], xpaths: List[Text],
|
||||
|
||||
log_dir = save_dir / 'logs'
|
||||
log_dir.mkdir(parents=True, exist_ok=True)
|
||||
logger = Logger(str(log_dir), 0, False)
|
||||
logger = Logger(str(log_dir), os.getpid(), False)
|
||||
|
||||
logger.log('xargs : seeds = {:}'.format(seeds))
|
||||
logger.log('xargs : cover_mode = {:}'.format(cover_mode))
|
||||
|
@ -114,15 +114,27 @@ class NASBench201API(object):
|
||||
assert os.path.isfile(xfile_path), 'invalid data path : {:}'.format(xfile_path)
|
||||
xdata = torch.load(xfile_path, map_location='cpu')
|
||||
assert isinstance(xdata, dict) and 'full' in xdata and 'less' in xdata, 'invalid format of data in {:}'.format(xfile_path)
|
||||
if index in self.arch2infos_less: del self.arch2infos_less[index]
|
||||
if index in self.arch2infos_full: del self.arch2infos_full[index]
|
||||
self.arch2infos_less[index] = ArchResults.create_from_state_dict( xdata['less'] )
|
||||
self.arch2infos_full[index] = ArchResults.create_from_state_dict( xdata['full'] )
|
||||
|
||||
def clear_params(self, index: int, use_12epochs_result: bool):
|
||||
"""Remove the architecture's weights to save memory."""
|
||||
if use_12epochs_result: arch2infos = self.arch2infos_less
|
||||
else : arch2infos = self.arch2infos_full
|
||||
archresult = arch2infos[index]
|
||||
archresult.clear_params()
|
||||
def clear_params(self, index: int, use_12epochs_result: Union[bool, None]):
|
||||
"""Remove the architecture's weights to save memory.
|
||||
:arg
|
||||
index: the index of the target architecture
|
||||
use_12epochs_result: a flag to controll how to clear the parameters.
|
||||
-- None: clear all the weights in both `less` and `full`, which indicates the training hyper-parameters.
|
||||
-- True: clear all the weights in arch2infos_less, which by default is 12-epoch-training result.
|
||||
-- False: clear all the weights in arch2infos_full, which by default is 200-epoch-training result.
|
||||
"""
|
||||
if use_12epochs_result is None:
|
||||
self.arch2infos_less[index].clear_params()
|
||||
self.arch2infos_full[index].clear_params()
|
||||
else:
|
||||
if use_12epochs_result: arch2infos = self.arch2infos_less
|
||||
else : arch2infos = self.arch2infos_full
|
||||
arch2infos[index].clear_params()
|
||||
|
||||
# This function is used to query the information of a specific archiitecture
|
||||
# 'arch' can be an architecture index or an architecture string
|
||||
@ -193,7 +205,6 @@ class NASBench201API(object):
|
||||
best_index, highest_accuracy = idx, accuracy
|
||||
return best_index, highest_accuracy
|
||||
|
||||
|
||||
def arch(self, index: int):
|
||||
"""Return the topology structure of the `index`-th architecture."""
|
||||
assert 0 <= index < len(self.meta_archs), 'invalid index : {:} vs. {:}.'.format(index, len(self.meta_archs))
|
||||
@ -213,7 +224,6 @@ class NASBench201API(object):
|
||||
else: arch2infos = self.arch2infos_full
|
||||
arch_result = arch2infos[index]
|
||||
return arch_result.get_net_param(dataset, seed)
|
||||
|
||||
|
||||
def get_net_config(self, index: int, dataset: Text):
|
||||
"""
|
||||
@ -235,7 +245,6 @@ class NASBench201API(object):
|
||||
#print ('SEED [{:}] : {:}'.format(seed, result))
|
||||
raise ValueError('Impossible to reach here!')
|
||||
|
||||
|
||||
def get_cost_info(self, index: int, dataset: Text, use_12epochs_result: bool = False) -> Dict[Text, float]:
|
||||
"""To obtain the cost metric for the `index`-th architecture on a dataset."""
|
||||
if use_12epochs_result: arch2infos = self.arch2infos_less
|
||||
@ -243,7 +252,6 @@ class NASBench201API(object):
|
||||
arch_result = arch2infos[index]
|
||||
return arch_result.get_compute_costs(dataset)
|
||||
|
||||
|
||||
def get_latency(self, index: int, dataset: Text, use_12epochs_result: bool = False) -> float:
|
||||
"""
|
||||
To obtain the latency of the network (by default it will return the latency with the batch size of 256).
|
||||
@ -254,7 +262,6 @@ class NASBench201API(object):
|
||||
cost_dict = self.get_cost_info(index, dataset, use_12epochs_result)
|
||||
return cost_dict['latency']
|
||||
|
||||
|
||||
# obtain the metric for the `index`-th architecture
|
||||
# `dataset` indicates the dataset:
|
||||
# 'cifar10-valid' : using the proposed train set of CIFAR-10 as the training set
|
||||
@ -388,7 +395,6 @@ class NASBench201API(object):
|
||||
return xifo
|
||||
"""
|
||||
|
||||
|
||||
def show(self, index: int = -1) -> None:
|
||||
"""
|
||||
This function will print the information of a specific (or all) architecture(s).
|
||||
@ -423,7 +429,6 @@ class NASBench201API(object):
|
||||
else:
|
||||
print('This index ({:}) is out of range (0~{:}).'.format(index, len(self.meta_archs)))
|
||||
|
||||
|
||||
def statistics(self, dataset: Text, use_12epochs_result: bool) -> Dict[int, int]:
|
||||
"""
|
||||
This function will count the number of total trials.
|
||||
@ -443,7 +448,6 @@ class NASBench201API(object):
|
||||
nums[len(dataset_seed[dataset])] += 1
|
||||
return dict(nums)
|
||||
|
||||
|
||||
@staticmethod
|
||||
def str2lists(arch_str: Text) -> List[tuple]:
|
||||
"""
|
||||
@ -471,7 +475,6 @@ class NASBench201API(object):
|
||||
genotypes.append( input_infos )
|
||||
return genotypes
|
||||
|
||||
|
||||
@staticmethod
|
||||
def str2matrix(arch_str: Text,
|
||||
search_space: List[Text] = ['none', 'skip_connect', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3']) -> np.ndarray:
|
||||
@ -511,7 +514,6 @@ class NASBench201API(object):
|
||||
return matrix
|
||||
|
||||
|
||||
|
||||
class ArchResults(object):
|
||||
|
||||
def __init__(self, arch_index, arch_str):
|
||||
@ -752,7 +754,6 @@ class ArchResults(object):
|
||||
|
||||
def __repr__(self):
|
||||
return ('{name}(arch-index={index}, arch={arch}, {num} runs, clear={clear})'.format(name=self.__class__.__name__, index=self.arch_index, arch=self.arch_str, num=len(self.all_results), clear=self.clear_net_done))
|
||||
|
||||
|
||||
|
||||
"""
|
||||
@ -872,8 +873,8 @@ class ResultsCount(object):
|
||||
'cur_time': xtime,
|
||||
'all_time': atime}
|
||||
|
||||
# get the evaluation information ; there could be multiple evaluation sets (identified by the 'name' argument).
|
||||
def get_eval(self, name, iepoch=None):
|
||||
"""Get the evaluation information ; there could be multiple evaluation sets (identified by the 'name' argument)."""
|
||||
if iepoch is None: iepoch = self.epochs-1
|
||||
assert 0 <= iepoch < self.epochs, 'invalid iepoch={:} < {:}'.format(iepoch, self.epochs)
|
||||
if isinstance(self.eval_times,dict) and len(self.eval_times) > 0:
|
||||
@ -890,8 +891,8 @@ class ResultsCount(object):
|
||||
if clone: return copy.deepcopy(self.net_state_dict)
|
||||
else: return self.net_state_dict
|
||||
|
||||
# This function is used to obtain the config dict for this architecture.
|
||||
def get_config(self, str2structure):
|
||||
"""This function is used to obtain the config dict for this architecture."""
|
||||
if str2structure is None:
|
||||
return {'name': 'infer.tiny', 'C': self.arch_config['channel'],
|
||||
'N' : self.arch_config['num_cells'],
|
||||
|
@ -15,7 +15,7 @@ else
|
||||
echo "TORCH_HOME : $TORCH_HOME"
|
||||
fi
|
||||
|
||||
OMP_NUM_THREADS=4 python exps/NAS-Bench-201/test-weights.py \
|
||||
CUDA_VISIBLE_DEVICES='' OMP_NUM_THREADS=4 python exps/NAS-Bench-201/test-weights.py \
|
||||
--base_path $HOME/.torch/NAS-Bench-201-v1_1-096897 \
|
||||
--dataset $1 \
|
||||
--use_12 $2
|
||||
|
Loading…
Reference in New Issue
Block a user