Fix the potential memory leak in NAS-Bench-201 clear_params
parent: b702ddf5a2
commit: 22025887f1
.gitignore (vendored): 2 changes
@@ -121,3 +121,5 @@ lib/NAS-Bench-*-v1_0.pth
 others/TF
 scripts-search/l2s-algos
 TEMP-L.sh
+
+.nfs00*
@@ -1,7 +1,7 @@
 #####################################################
 # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.08 #
 #####################################################
-import os, sys, time, torch
+import time, torch
 from procedures import prepare_seed, get_optim_scheduler
 from utils import get_model_infos, obtain_accuracy
 from config_utils import dict2config
@@ -9,11 +9,9 @@ from log_utils import AverageMeter, time_string, convert_secs2time
 from models import get_cell_based_tiny_net
 
-
 
 __all__ = ['evaluate_for_seed', 'pure_evaluate']
 
-
 
 def pure_evaluate(xloader, network, criterion=torch.nn.CrossEntropyLoss()):
   data_time, batch_time, batch = AverageMeter(), AverageMeter(), None
   losses, top1, top5 = AverageMeter(), AverageMeter(), AverageMeter()
@@ -28,7 +28,7 @@ def evaluate_all_datasets(arch, datasets, xpaths, splits, use_less, seed, arch_c
   for dataset, xpath, split in zip(datasets, xpaths, splits):
     # train valid data
     train_data, valid_data, xshape, class_num = get_datasets(dataset, xpath, -1)
-    # load the configurature
+    # load the configuration
     if dataset == 'cifar10' or dataset == 'cifar100':
       if use_less: config_path = 'configs/nas-benchmark/LESS.config'
       else       : config_path = 'configs/nas-benchmark/CIFAR.config'
@@ -3,7 +3,7 @@
 ################################################################################################
 # python exps/NAS-Bench-201/show-best.py --api_path $HOME/.torch/NAS-Bench-201-v1_0-e61699.pth #
 ################################################################################################
-import os, sys, time, glob, random, argparse
+import sys, argparse
 from pathlib import Path
 lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve()
 if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
@@ -6,7 +6,7 @@
 # python exps/NAS-Bench-201/test-weights.py --base_path $HOME/.torch/NAS-Bench-201-v1_1-096897 --dataset cifar10-valid --use_12 1 --use_valid 1
 # bash ./scripts-search/NAS-Bench-201/test-weights.sh cifar10-valid 1
 ###############################################################################################
-import os, gc, sys, time, glob, random, argparse
+import os, gc, sys, argparse, psutil
 import numpy as np
 import torch
 from pathlib import Path
@@ -33,7 +33,7 @@ def tostr(accdict, norms):
 
 def evaluate(api, weight_dir, data: str, use_12epochs_result: bool):
   print('\nEvaluate dataset={:}'.format(data))
-  norms = []
+  norms, process = [], psutil.Process(os.getpid())
   final_val_accs = OrderedDict({'cifar10': [], 'cifar100': [], 'ImageNet16-120': []})
   final_test_accs = OrderedDict({'cifar10': [], 'cifar100': [], 'ImageNet16-120': []})
   for idx in range(len(api)):
@@ -56,16 +56,17 @@ def evaluate(api, weight_dir, data: str, use_12epochs_result: bool):
       with torch.no_grad():
         net.load_state_dict(param)
         _, summary = weight_watcher.analyze(net, alphas=False)
-        cur_norms.append( summary['lognorm'] )
+        cur_norms.append(summary['lognorm'])
     norms.append( float(np.mean(cur_norms)) )
-    api.clear_params(idx, use_12epochs_result)
+    api.clear_params(idx, None)
     if idx % 200 == 199 or idx + 1 == len(api):
       head = '{:05d}/{:05d}'.format(idx, len(api))
       stem_val = tostr(final_val_accs, norms)
       stem_test = tostr(final_test_accs, norms)
-      print('{:} {:} {:} with {:} epochs on {:} : the correlation is {:.3f}'.format(time_string(), head, data, 12 if use_12epochs_result else 200))
-      print(' -->> {:} || {:}'.format(stem_val, stem_test))
-      torch.cuda.empty_cache() ; gc.collect()
+      print('{:} {:} {:} with {:} epochs ({:.2f} MB memory)'.format(time_string(), head, data, 12 if use_12epochs_result else 200, process.memory_info().rss / 1e6))
+      print(' [Valid] -->> {:}'.format(stem_val))
+      print(' [Test.] -->> {:}'.format(stem_test))
+      gc.collect()
 
 
 def main(meta_file: str, weight_dir, save_dir, xdata, use_12epochs_result):
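To make the memory behaviour of the new loop concrete, here is a minimal sketch of the pattern this hunk adopts, assuming an already-loaded NASBench201API instance named `api` and that `psutil` is installed; the function name `sweep_with_bounded_memory` and the elided analysis step are illustrative, not part of the repository.

import gc, os
import psutil


def sweep_with_bounded_memory(api):
  """Hedged sketch: iterate over every architecture while keeping resident memory flat."""
  process = psutil.Process(os.getpid())   # the same RSS probe the commit adds to evaluate()
  for idx in range(len(api)):
    # ... load the trained weights for `idx` and run whatever analysis is needed ...
    api.clear_params(idx, None)           # None -> drop the weights cached under BOTH 'less' and 'full'
    if idx % 200 == 199 or idx + 1 == len(api):
      gc.collect()                        # encourage Python to release the freed state dicts
      print('{:05d}/{:05d} : {:.2f} MB resident'.format(idx + 1, len(api), process.memory_info().rss / 1e6))

The essential difference from the old loop is the second argument of clear_params: passing None clears the weights cached for both the 12-epoch and the 200-epoch records, whereas the old call with use_12epochs_result cleared only one of the two copies and let the other accumulate across the sweep.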
@@ -3,7 +3,7 @@
 #####################################################
 # python exps/NAS-Bench-201/visualize.py --api_path $HOME/.torch/NAS-Bench-201-v1_0-e61699.pth
 #####################################################
-import os, sys, time, argparse, collections
+import sys, argparse
 from tqdm import tqdm
 from collections import OrderedDict
 import numpy as np
@@ -24,11 +24,11 @@ def evaluate_all_datasets(channels: Text, datasets: List[Text], xpaths: List[Tex
   machine_info = get_machine_info()
   all_infos = {'info': machine_info}
   all_dataset_keys = []
-  # look all the datasets
+  # look all the dataset
   for dataset, xpath, split in zip(datasets, xpaths, splits):
-    # train valid data
+    # the train and valid data
     train_data, valid_data, xshape, class_num = get_datasets(dataset, xpath, -1)
-    # load the configurature
+    # load the configuration
     if dataset == 'cifar10' or dataset == 'cifar100':
       split_info = load_config('configs/nas-benchmark/cifar-split.txt', None, None)
     elif dataset.startswith('ImageNet16'):
@@ -36,7 +36,7 @@ def evaluate_all_datasets(channels: Text, datasets: List[Text], xpaths: List[Tex
     else:
       raise ValueError('invalid dataset : {:}'.format(dataset))
     config = load_config(config_path, dict(class_num=class_num, xshape=xshape), logger)
-    # check whether use splited validation set
+    # check whether use the splitted validation set
     if bool(split):
       assert dataset == 'cifar10'
       ValLoaders = {'ori-test': torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size, shuffle=False, num_workers=workers, pin_memory=True)}
@@ -92,7 +92,7 @@ def main(save_dir: Path, workers: int, datasets: List[Text], xpaths: List[Text],
 
   log_dir = save_dir / 'logs'
   log_dir.mkdir(parents=True, exist_ok=True)
-  logger = Logger(str(log_dir), 0, False)
+  logger = Logger(str(log_dir), os.getpid(), False)
 
   logger.log('xargs : seeds = {:}'.format(seeds))
   logger.log('xargs : cover_mode = {:}'.format(cover_mode))
@@ -114,15 +114,27 @@ class NASBench201API(object):
     assert os.path.isfile(xfile_path), 'invalid data path : {:}'.format(xfile_path)
     xdata = torch.load(xfile_path, map_location='cpu')
     assert isinstance(xdata, dict) and 'full' in xdata and 'less' in xdata, 'invalid format of data in {:}'.format(xfile_path)
+    if index in self.arch2infos_less: del self.arch2infos_less[index]
+    if index in self.arch2infos_full: del self.arch2infos_full[index]
     self.arch2infos_less[index] = ArchResults.create_from_state_dict( xdata['less'] )
     self.arch2infos_full[index] = ArchResults.create_from_state_dict( xdata['full'] )
 
-  def clear_params(self, index: int, use_12epochs_result: bool):
-    """Remove the architecture's weights to save memory."""
-    if use_12epochs_result: arch2infos = self.arch2infos_less
-    else : arch2infos = self.arch2infos_full
-    archresult = arch2infos[index]
-    archresult.clear_params()
+  def clear_params(self, index: int, use_12epochs_result: Union[bool, None]):
+    """Remove the architecture's weights to save memory.
+    :arg
+      index: the index of the target architecture
+      use_12epochs_result: a flag to controll how to clear the parameters.
+      -- None: clear all the weights in both `less` and `full`, which indicates the training hyper-parameters.
+      -- True: clear all the weights in arch2infos_less, which by default is 12-epoch-training result.
+      -- False: clear all the weights in arch2infos_full, which by default is 200-epoch-training result.
+    """
+    if use_12epochs_result is None:
+      self.arch2infos_less[index].clear_params()
+      self.arch2infos_full[index].clear_params()
+    else:
+      if use_12epochs_result: arch2infos = self.arch2infos_less
+      else : arch2infos = self.arch2infos_full
+      arch2infos[index].clear_params()
 
   # This function is used to query the information of a specific archiitecture
   # 'arch' can be an architecture index or an architecture string
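For reference, a short hedged usage sketch of the extended signature, assuming `api` is a loaded NASBench201API and `idx` is a valid architecture index:

# Illustrative calls only; the semantics follow the docstring above.
api.clear_params(idx, None)    # clear the cached weights of BOTH records ('less' and 'full')
api.clear_params(idx, True)    # clear only arch2infos_less (the 12-epoch results)
api.clear_params(idx, False)   # clear only arch2infos_full (the 200-epoch results)

Note that the reload path above now deletes any existing ArchResults entry before re-creating it from the state dict, so stale objects, and the weights they hold, are not kept alive alongside the fresh ones.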
@@ -193,7 +205,6 @@ class NASBench201API(object):
         best_index, highest_accuracy = idx, accuracy
     return best_index, highest_accuracy
 
-
   def arch(self, index: int):
     """Return the topology structure of the `index`-th architecture."""
     assert 0 <= index < len(self.meta_archs), 'invalid index : {:} vs. {:}.'.format(index, len(self.meta_archs))
@@ -213,7 +224,6 @@ class NASBench201API(object):
     else: arch2infos = self.arch2infos_full
     arch_result = arch2infos[index]
     return arch_result.get_net_param(dataset, seed)
 
-
   def get_net_config(self, index: int, dataset: Text):
     """
@@ -235,7 +245,6 @@ class NASBench201API(object):
       #print ('SEED [{:}] : {:}'.format(seed, result))
     raise ValueError('Impossible to reach here!')
 
-
   def get_cost_info(self, index: int, dataset: Text, use_12epochs_result: bool = False) -> Dict[Text, float]:
     """To obtain the cost metric for the `index`-th architecture on a dataset."""
     if use_12epochs_result: arch2infos = self.arch2infos_less
@@ -243,7 +252,6 @@ class NASBench201API(object):
     arch_result = arch2infos[index]
     return arch_result.get_compute_costs(dataset)
 
-
   def get_latency(self, index: int, dataset: Text, use_12epochs_result: bool = False) -> float:
     """
     To obtain the latency of the network (by default it will return the latency with the batch size of 256).
@@ -254,7 +262,6 @@ class NASBench201API(object):
     cost_dict = self.get_cost_info(index, dataset, use_12epochs_result)
     return cost_dict['latency']
 
-
   # obtain the metric for the `index`-th architecture
   # `dataset` indicates the dataset:
   #   'cifar10-valid' : using the proposed train set of CIFAR-10 as the training set
@@ -388,7 +395,6 @@ class NASBench201API(object):
     return xifo
     """
 
-
   def show(self, index: int = -1) -> None:
     """
     This function will print the information of a specific (or all) architecture(s).
@@ -423,7 +429,6 @@ class NASBench201API(object):
     else:
       print('This index ({:}) is out of range (0~{:}).'.format(index, len(self.meta_archs)))
 
-
   def statistics(self, dataset: Text, use_12epochs_result: bool) -> Dict[int, int]:
     """
     This function will count the number of total trials.
@@ -443,7 +448,6 @@ class NASBench201API(object):
       nums[len(dataset_seed[dataset])] += 1
     return dict(nums)
 
-
   @staticmethod
   def str2lists(arch_str: Text) -> List[tuple]:
     """
@@ -471,7 +475,6 @@ class NASBench201API(object):
       genotypes.append( input_infos )
     return genotypes
 
-
   @staticmethod
   def str2matrix(arch_str: Text,
                  search_space: List[Text] = ['none', 'skip_connect', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3']) -> np.ndarray:
@@ -511,7 +514,6 @@ class NASBench201API(object):
     return matrix
 
 
-
 class ArchResults(object):
 
   def __init__(self, arch_index, arch_str):
@@ -752,7 +754,6 @@ class ArchResults(object):
 
   def __repr__(self):
     return ('{name}(arch-index={index}, arch={arch}, {num} runs, clear={clear})'.format(name=self.__class__.__name__, index=self.arch_index, arch=self.arch_str, num=len(self.all_results), clear=self.clear_net_done))
 
 
-
 """
@@ -872,8 +873,8 @@ class ResultsCount(object):
             'cur_time': xtime,
             'all_time': atime}
 
-  # get the evaluation information ; there could be multiple evaluation sets (identified by the 'name' argument).
   def get_eval(self, name, iepoch=None):
+    """Get the evaluation information ; there could be multiple evaluation sets (identified by the 'name' argument)."""
     if iepoch is None: iepoch = self.epochs-1
     assert 0 <= iepoch < self.epochs, 'invalid iepoch={:} < {:}'.format(iepoch, self.epochs)
     if isinstance(self.eval_times,dict) and len(self.eval_times) > 0:
@@ -890,8 +891,8 @@ class ResultsCount(object):
     if clone: return copy.deepcopy(self.net_state_dict)
     else: return self.net_state_dict
 
-  # This function is used to obtain the config dict for this architecture.
   def get_config(self, str2structure):
+    """This function is used to obtain the config dict for this architecture."""
     if str2structure is None:
       return {'name': 'infer.tiny', 'C': self.arch_config['channel'],
               'N' : self.arch_config['num_cells'],
|
@ -15,7 +15,7 @@ else
|
|||||||
echo "TORCH_HOME : $TORCH_HOME"
|
echo "TORCH_HOME : $TORCH_HOME"
|
||||||
fi
|
fi
|
||||||
|
|
||||||
OMP_NUM_THREADS=4 python exps/NAS-Bench-201/test-weights.py \
|
CUDA_VISIBLE_DEVICES='' OMP_NUM_THREADS=4 python exps/NAS-Bench-201/test-weights.py \
|
||||||
--base_path $HOME/.torch/NAS-Bench-201-v1_1-096897 \
|
--base_path $HOME/.torch/NAS-Bench-201-v1_1-096897 \
|
||||||
--dataset $1 \
|
--dataset $1 \
|
||||||
--use_12 $2
|
--use_12 $2
|
||||||
|
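The only change to the launcher is the CUDA_VISIBLE_DEVICES='' prefix, which hides every GPU from the process so the weight analysis runs purely on the CPU. A hedged Python equivalent, for cases where the variable cannot be set on the command line (it must be exported before anything initializes CUDA):

import os
os.environ['CUDA_VISIBLE_DEVICES'] = ''   # hide all GPUs; must run before CUDA is initialized

import torch
print(torch.cuda.is_available())          # expected: False, so no GPU memory is ever allocated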