simplify baselines
This commit is contained in:
parent
f8f44bfb31
commit
9ec25663f1
193
exps/NAS-Bench-102/test-correlation.py
Normal file
193
exps/NAS-Bench-102/test-correlation.py
Normal file
@ -0,0 +1,193 @@
|
|||||||
|
##################################################
|
||||||
|
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
|
||||||
|
########################################################
|
||||||
|
# python exps/NAS-Bench-102/test-correlation.py --api_path $HOME/.torch/NAS-Bench-102-v1_0-e61699.pth
|
||||||
|
########################################################
|
||||||
|
import os, sys, time, glob, random, argparse
|
||||||
|
import numpy as np
|
||||||
|
from copy import deepcopy
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from pathlib import Path
|
||||||
|
lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve()
|
||||||
|
if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
|
||||||
|
from config_utils import load_config, dict2config, configure2str
|
||||||
|
from datasets import get_datasets, SearchDataset
|
||||||
|
from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint, get_optim_scheduler
|
||||||
|
from utils import get_model_infos, obtain_accuracy
|
||||||
|
from log_utils import AverageMeter, time_string, convert_secs2time
|
||||||
|
from models import get_cell_based_tiny_net, get_search_spaces, CellStructure
|
||||||
|
from nas_102_api import NASBench102API as API
|
||||||
|
|
||||||
|
|
||||||
|
def valid_func(xloader, network, criterion):
|
||||||
|
data_time, batch_time = AverageMeter(), AverageMeter()
|
||||||
|
arch_losses, arch_top1, arch_top5 = AverageMeter(), AverageMeter(), AverageMeter()
|
||||||
|
network.eval()
|
||||||
|
end = time.time()
|
||||||
|
with torch.no_grad():
|
||||||
|
for step, (arch_inputs, arch_targets) in enumerate(xloader):
|
||||||
|
arch_targets = arch_targets.cuda(non_blocking=True)
|
||||||
|
# measure data loading time
|
||||||
|
data_time.update(time.time() - end)
|
||||||
|
# prediction
|
||||||
|
_, logits = network(arch_inputs)
|
||||||
|
arch_loss = criterion(logits, arch_targets)
|
||||||
|
# record
|
||||||
|
arch_prec1, arch_prec5 = obtain_accuracy(logits.data, arch_targets.data, topk=(1, 5))
|
||||||
|
arch_losses.update(arch_loss.item(), arch_inputs.size(0))
|
||||||
|
arch_top1.update (arch_prec1.item(), arch_inputs.size(0))
|
||||||
|
arch_top5.update (arch_prec5.item(), arch_inputs.size(0))
|
||||||
|
# measure elapsed time
|
||||||
|
batch_time.update(time.time() - end)
|
||||||
|
end = time.time()
|
||||||
|
return arch_losses.avg, arch_top1.avg, arch_top5.avg
|
||||||
|
|
||||||
|
|
||||||
|
def main(xargs):
|
||||||
|
assert torch.cuda.is_available(), 'CUDA is not available.'
|
||||||
|
torch.backends.cudnn.enabled = True
|
||||||
|
torch.backends.cudnn.benchmark = False
|
||||||
|
torch.backends.cudnn.deterministic = True
|
||||||
|
torch.set_num_threads( xargs.workers )
|
||||||
|
prepare_seed(xargs.rand_seed)
|
||||||
|
logger = prepare_logger(args)
|
||||||
|
|
||||||
|
train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1)
|
||||||
|
if xargs.dataset == 'cifar10' or xargs.dataset == 'cifar100':
|
||||||
|
split_Fpath = 'configs/nas-benchmark/cifar-split.txt'
|
||||||
|
cifar_split = load_config(split_Fpath, None, None)
|
||||||
|
train_split, valid_split = cifar_split.train, cifar_split.valid
|
||||||
|
logger.log('Load split file from {:}'.format(split_Fpath))
|
||||||
|
elif xargs.dataset.startswith('ImageNet16'):
|
||||||
|
split_Fpath = 'configs/nas-benchmark/{:}-split.txt'.format(xargs.dataset)
|
||||||
|
imagenet16_split = load_config(split_Fpath, None, None)
|
||||||
|
train_split, valid_split = imagenet16_split.train, imagenet16_split.valid
|
||||||
|
logger.log('Load split file from {:}'.format(split_Fpath))
|
||||||
|
else:
|
||||||
|
raise ValueError('invalid dataset : {:}'.format(xargs.dataset))
|
||||||
|
config_path = 'configs/nas-benchmark/algos/DARTS.config'
|
||||||
|
config = load_config(config_path, {'class_num': class_num, 'xshape': xshape}, logger)
|
||||||
|
# To split data
|
||||||
|
train_data_v2 = deepcopy(train_data)
|
||||||
|
train_data_v2.transform = valid_data.transform
|
||||||
|
valid_data = train_data_v2
|
||||||
|
search_data = SearchDataset(xargs.dataset, train_data, train_split, valid_split)
|
||||||
|
# data loader
|
||||||
|
search_loader = torch.utils.data.DataLoader(search_data, batch_size=config.batch_size, shuffle=True , num_workers=xargs.workers, pin_memory=True)
|
||||||
|
valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(valid_split), num_workers=xargs.workers, pin_memory=True)
|
||||||
|
logger.log('||||||| {:10s} ||||||| Search-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'.format(xargs.dataset, len(search_loader), len(valid_loader), config.batch_size))
|
||||||
|
logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config))
|
||||||
|
|
||||||
|
search_space = get_search_spaces('cell', xargs.search_space_name)
|
||||||
|
model_config = dict2config({'name': 'DARTS-V2', 'C': xargs.channel, 'N': xargs.num_cells,
|
||||||
|
'max_nodes': xargs.max_nodes, 'num_classes': class_num,
|
||||||
|
'space' : search_space}, None)
|
||||||
|
search_model = get_cell_based_tiny_net(model_config)
|
||||||
|
logger.log('search-model :\n{:}'.format(search_model))
|
||||||
|
|
||||||
|
w_optimizer, w_scheduler, criterion = get_optim_scheduler(search_model.get_weights(), config)
|
||||||
|
a_optimizer = torch.optim.Adam(search_model.get_alphas(), lr=xargs.arch_learning_rate, betas=(0.5, 0.999), weight_decay=xargs.arch_weight_decay)
|
||||||
|
logger.log('w-optimizer : {:}'.format(w_optimizer))
|
||||||
|
logger.log('a-optimizer : {:}'.format(a_optimizer))
|
||||||
|
logger.log('w-scheduler : {:}'.format(w_scheduler))
|
||||||
|
logger.log('criterion : {:}'.format(criterion))
|
||||||
|
flop, param = get_model_infos(search_model, xshape)
|
||||||
|
#logger.log('{:}'.format(search_model))
|
||||||
|
logger.log('FLOP = {:.2f} M, Params = {:.2f} MB'.format(flop, param))
|
||||||
|
if xargs.arch_nas_dataset is None:
|
||||||
|
api = None
|
||||||
|
else:
|
||||||
|
api = API(xargs.arch_nas_dataset)
|
||||||
|
logger.log('{:} create API = {:} done'.format(time_string(), api))
|
||||||
|
|
||||||
|
last_info, model_base_path, model_best_path = logger.path('info'), logger.path('model'), logger.path('best')
|
||||||
|
network, criterion = torch.nn.DataParallel(search_model).cuda(), criterion.cuda()
|
||||||
|
|
||||||
|
logger.close()
|
||||||
|
|
||||||
|
|
||||||
|
def check_unique_arch(meta_file):
|
||||||
|
api = API(str(meta_file))
|
||||||
|
arch_strs = deepcopy(api.meta_archs)
|
||||||
|
xarchs = [CellStructure.str2structure(x) for x in arch_strs]
|
||||||
|
def get_unique_matrix(archs, consider_zero):
|
||||||
|
UniquStrs = [arch.to_unique_str(consider_zero) for arch in archs]
|
||||||
|
print ('{:} create unique-string ({:}/{:}) done'.format(time_string(), len(set(UniquStrs)), len(UniquStrs)))
|
||||||
|
Unique2Index = dict()
|
||||||
|
for index, xstr in enumerate(UniquStrs):
|
||||||
|
if xstr not in Unique2Index: Unique2Index[xstr] = list()
|
||||||
|
Unique2Index[xstr].append( index )
|
||||||
|
sm_matrix = torch.eye(len(archs)).bool()
|
||||||
|
for _, xlist in Unique2Index.items():
|
||||||
|
for i in xlist:
|
||||||
|
for j in xlist:
|
||||||
|
sm_matrix[i,j] = True
|
||||||
|
unique_ids, unique_num = [-1 for _ in archs], 0
|
||||||
|
for i in range(len(unique_ids)):
|
||||||
|
if unique_ids[i] > -1: continue
|
||||||
|
neighbours = sm_matrix[i].nonzero().view(-1).tolist()
|
||||||
|
for nghb in neighbours:
|
||||||
|
assert unique_ids[nghb] == -1, 'impossible'
|
||||||
|
unique_ids[nghb] = unique_num
|
||||||
|
unique_num += 1
|
||||||
|
return sm_matrix, unique_ids, unique_num
|
||||||
|
|
||||||
|
print ('There are {:} valid-archs'.format( sum(arch.check_valid() for arch in xarchs) ))
|
||||||
|
sm_matrix, uniqueIDs, unique_num = get_unique_matrix(xarchs, None)
|
||||||
|
print ('{:} There are {:} unique architectures (considering nothing).'.format(time_string(), unique_num))
|
||||||
|
sm_matrix, uniqueIDs, unique_num = get_unique_matrix(xarchs, False)
|
||||||
|
print ('{:} There are {:} unique architectures (not considering zero).'.format(time_string(), unique_num))
|
||||||
|
sm_matrix, uniqueIDs, unique_num = get_unique_matrix(xarchs, True)
|
||||||
|
print ('{:} There are {:} unique architectures (considering zero).'.format(time_string(), unique_num))
|
||||||
|
|
||||||
|
|
||||||
|
def check_cor_for_bandit(meta_file, test_epoch, use_less_or_not, is_rand=True):
|
||||||
|
if isinstance(meta_file, API):
|
||||||
|
api = meta_file
|
||||||
|
else:
|
||||||
|
api = API(str(meta_file))
|
||||||
|
cifar10_valid = []
|
||||||
|
cifar10_test = []
|
||||||
|
cifar100_test = []
|
||||||
|
imagenet_test = []
|
||||||
|
for idx, arch in enumerate(api):
|
||||||
|
results = api.get_more_info(idx, 'cifar10-valid' , test_epoch-1, use_less_or_not, is_rand)
|
||||||
|
cifar10_valid.append( results['valid-accuracy'] )
|
||||||
|
results = api.get_more_info(idx, 'cifar10' , None, False, is_rand)
|
||||||
|
cifar10_test.append( results['test-accuracy'] )
|
||||||
|
results = api.get_more_info(idx, 'cifar100' , None, False, is_rand)
|
||||||
|
cifar100_test.append( results['test-accuracy'] )
|
||||||
|
results = api.get_more_info(idx, 'ImageNet16-120', None, False, is_rand)
|
||||||
|
imagenet_test.append( results['test-accuracy'] )
|
||||||
|
def get_cor(A, B):
|
||||||
|
return float(np.corrcoef(A, B)[0,1])
|
||||||
|
cors = []
|
||||||
|
for basestr, xlist in zip(['CIFAR-010', 'CIFAR-100', 'ImageNet16'], [cifar10_test,cifar100_test, imagenet_test]):
|
||||||
|
correlation = get_cor(cifar10_valid, xlist)
|
||||||
|
print ('With {:3d}/{:}-epochs-training, the correlation between cifar10-valid and {:} is : {:}'.format(less_epoch, '012' if use_less_or_not else '200', basestr, correlation))
|
||||||
|
cors.append( correlation )
|
||||||
|
#print ('With {:3d}/200-epochs-training, the correlation between cifar10-valid and {:} is : {:}'.format(test_epoch, basestr, get_cor(cifar10_valid_200, xlist)))
|
||||||
|
#print('-'*200)
|
||||||
|
#print('*'*230)
|
||||||
|
return cors
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
parser = argparse.ArgumentParser("Analysis of NAS-Bench-102")
|
||||||
|
parser.add_argument('--save_dir', type=str, default='./output/search-cell-nas-bench-102/visuals', help='The base-name of folder to save checkpoints and log.')
|
||||||
|
parser.add_argument('--api_path', type=str, default=None, help='The path to the NAS-Bench-102 benchmark file.')
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
vis_save_dir = Path(args.save_dir)
|
||||||
|
vis_save_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
meta_file = Path(args.api_path)
|
||||||
|
assert meta_file.exists(), 'invalid path for api : {:}'.format(meta_file)
|
||||||
|
|
||||||
|
#check_unique_arch(meta_file)
|
||||||
|
api = API(str(meta_file))
|
||||||
|
#for iepoch in [11, 25, 50, 100, 150, 175, 200]:
|
||||||
|
# check_cor_for_bandit(api, 6, iepoch)
|
||||||
|
# check_cor_for_bandit(api, 12, iepoch)
|
||||||
|
correlations = check_cor_for_bandit(api, 6, True, True)
|
||||||
|
import pdb; pdb.set_trace()
|
@ -370,17 +370,17 @@ def write_video(save_dir):
|
|||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|
||||||
parser = argparse.ArgumentParser(description='NAS-Bench-102', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
parser = argparse.ArgumentParser(description='NAS-Bench-102', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
||||||
parser.add_argument('--save_dir', type=str, default='./output/search-cell-nas-bench-102/visual', help='The base-name of folder to save checkpoints and log.')
|
parser.add_argument('--save_dir', type=str, default='./output/search-cell-nas-bench-102/visuals', help='The base-name of folder to save checkpoints and log.')
|
||||||
parser.add_argument('--api_path', type=str, default=None, help='The path to the NAS-Bench-102 benchmark file.')
|
parser.add_argument('--api_path', type=str, default=None, help='The path to the NAS-Bench-102 benchmark file.')
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
vis_save_dir = Path(args.save_dir) / 'visuals'
|
vis_save_dir = Path(args.save_dir)
|
||||||
vis_save_dir.mkdir(parents=True, exist_ok=True)
|
vis_save_dir.mkdir(parents=True, exist_ok=True)
|
||||||
meta_file = Path(args.api_path)
|
meta_file = Path(args.api_path)
|
||||||
assert meta_file.exists(), 'invalid path for api : {:}'.format(meta_file)
|
assert meta_file.exists(), 'invalid path for api : {:}'.format(meta_file)
|
||||||
visualize_rank_over_time(str(meta_file), vis_save_dir / 'over-time')
|
#visualize_rank_over_time(str(meta_file), vis_save_dir / 'over-time')
|
||||||
write_video(vis_save_dir / 'over-time')
|
#write_video(vis_save_dir / 'over-time')
|
||||||
visualize_info(str(meta_file), 'cifar10' , vis_save_dir)
|
#visualize_info(str(meta_file), 'cifar10' , vis_save_dir)
|
||||||
visualize_info(str(meta_file), 'cifar100', vis_save_dir)
|
#visualize_info(str(meta_file), 'cifar100', vis_save_dir)
|
||||||
visualize_info(str(meta_file), 'ImageNet16-120', vis_save_dir)
|
#visualize_info(str(meta_file), 'ImageNet16-120', vis_save_dir)
|
||||||
visualize_relative_ranking(vis_save_dir)
|
visualize_relative_ranking(vis_save_dir)
|
||||||
|
@ -110,25 +110,30 @@ def main(xargs, nas_bench):
|
|||||||
logger = prepare_logger(args)
|
logger = prepare_logger(args)
|
||||||
|
|
||||||
assert xargs.dataset == 'cifar10', 'currently only support CIFAR-10'
|
assert xargs.dataset == 'cifar10', 'currently only support CIFAR-10'
|
||||||
train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1)
|
if xargs.data_path is not None:
|
||||||
split_Fpath = 'configs/nas-benchmark/cifar-split.txt'
|
train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1)
|
||||||
cifar_split = load_config(split_Fpath, None, None)
|
split_Fpath = 'configs/nas-benchmark/cifar-split.txt'
|
||||||
train_split, valid_split = cifar_split.train, cifar_split.valid
|
cifar_split = load_config(split_Fpath, None, None)
|
||||||
logger.log('Load split file from {:}'.format(split_Fpath))
|
train_split, valid_split = cifar_split.train, cifar_split.valid
|
||||||
config_path = 'configs/nas-benchmark/algos/R-EA.config'
|
logger.log('Load split file from {:}'.format(split_Fpath))
|
||||||
config = load_config(config_path, {'class_num': class_num, 'xshape': xshape}, logger)
|
config_path = 'configs/nas-benchmark/algos/R-EA.config'
|
||||||
# To split data
|
config = load_config(config_path, {'class_num': class_num, 'xshape': xshape}, logger)
|
||||||
train_data_v2 = deepcopy(train_data)
|
# To split data
|
||||||
train_data_v2.transform = valid_data.transform
|
train_data_v2 = deepcopy(train_data)
|
||||||
valid_data = train_data_v2
|
train_data_v2.transform = valid_data.transform
|
||||||
search_data = SearchDataset(xargs.dataset, train_data, train_split, valid_split)
|
valid_data = train_data_v2
|
||||||
# data loader
|
search_data = SearchDataset(xargs.dataset, train_data, train_split, valid_split)
|
||||||
train_loader = torch.utils.data.DataLoader(train_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(train_split) , num_workers=xargs.workers, pin_memory=True)
|
# data loader
|
||||||
valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(valid_split), num_workers=xargs.workers, pin_memory=True)
|
train_loader = torch.utils.data.DataLoader(train_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(train_split) , num_workers=xargs.workers, pin_memory=True)
|
||||||
logger.log('||||||| {:10s} ||||||| Train-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'.format(xargs.dataset, len(train_loader), len(valid_loader), config.batch_size))
|
valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(valid_split), num_workers=xargs.workers, pin_memory=True)
|
||||||
logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config))
|
logger.log('||||||| {:10s} ||||||| Train-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'.format(xargs.dataset, len(train_loader), len(valid_loader), config.batch_size))
|
||||||
extra_info = {'config': config, 'train_loader': train_loader, 'valid_loader': valid_loader}
|
logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config))
|
||||||
|
extra_info = {'config': config, 'train_loader': train_loader, 'valid_loader': valid_loader}
|
||||||
|
else:
|
||||||
|
config_path = 'configs/nas-benchmark/algos/R-EA.config'
|
||||||
|
config = load_config(config_path, None, logger)
|
||||||
|
logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config))
|
||||||
|
extra_info = {'config': config, 'train_loader': None, 'valid_loader': None}
|
||||||
|
|
||||||
# nas dataset load
|
# nas dataset load
|
||||||
assert xargs.arch_nas_dataset is not None and os.path.isfile(xargs.arch_nas_dataset)
|
assert xargs.arch_nas_dataset is not None and os.path.isfile(xargs.arch_nas_dataset)
|
||||||
|
@ -29,25 +29,30 @@ def main(xargs, nas_bench):
|
|||||||
logger = prepare_logger(args)
|
logger = prepare_logger(args)
|
||||||
|
|
||||||
assert xargs.dataset == 'cifar10', 'currently only support CIFAR-10'
|
assert xargs.dataset == 'cifar10', 'currently only support CIFAR-10'
|
||||||
train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1)
|
if xargs.data_path is not None:
|
||||||
split_Fpath = 'configs/nas-benchmark/cifar-split.txt'
|
train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1)
|
||||||
cifar_split = load_config(split_Fpath, None, None)
|
split_Fpath = 'configs/nas-benchmark/cifar-split.txt'
|
||||||
train_split, valid_split = cifar_split.train, cifar_split.valid
|
cifar_split = load_config(split_Fpath, None, None)
|
||||||
logger.log('Load split file from {:}'.format(split_Fpath))
|
train_split, valid_split = cifar_split.train, cifar_split.valid
|
||||||
config_path = 'configs/nas-benchmark/algos/R-EA.config'
|
logger.log('Load split file from {:}'.format(split_Fpath))
|
||||||
config = load_config(config_path, {'class_num': class_num, 'xshape': xshape}, logger)
|
config_path = 'configs/nas-benchmark/algos/R-EA.config'
|
||||||
# To split data
|
config = load_config(config_path, {'class_num': class_num, 'xshape': xshape}, logger)
|
||||||
train_data_v2 = deepcopy(train_data)
|
# To split data
|
||||||
train_data_v2.transform = valid_data.transform
|
train_data_v2 = deepcopy(train_data)
|
||||||
valid_data = train_data_v2
|
train_data_v2.transform = valid_data.transform
|
||||||
search_data = SearchDataset(xargs.dataset, train_data, train_split, valid_split)
|
valid_data = train_data_v2
|
||||||
# data loader
|
search_data = SearchDataset(xargs.dataset, train_data, train_split, valid_split)
|
||||||
train_loader = torch.utils.data.DataLoader(train_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(train_split) , num_workers=xargs.workers, pin_memory=True)
|
# data loader
|
||||||
valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(valid_split), num_workers=xargs.workers, pin_memory=True)
|
train_loader = torch.utils.data.DataLoader(train_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(train_split) , num_workers=xargs.workers, pin_memory=True)
|
||||||
logger.log('||||||| {:10s} ||||||| Train-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'.format(xargs.dataset, len(train_loader), len(valid_loader), config.batch_size))
|
valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(valid_split), num_workers=xargs.workers, pin_memory=True)
|
||||||
logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config))
|
logger.log('||||||| {:10s} ||||||| Train-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'.format(xargs.dataset, len(train_loader), len(valid_loader), config.batch_size))
|
||||||
extra_info = {'config': config, 'train_loader': train_loader, 'valid_loader': valid_loader}
|
logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config))
|
||||||
|
extra_info = {'config': config, 'train_loader': train_loader, 'valid_loader': valid_loader}
|
||||||
|
else:
|
||||||
|
config_path = 'configs/nas-benchmark/algos/R-EA.config'
|
||||||
|
config = load_config(config_path, None, logger)
|
||||||
|
logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config))
|
||||||
|
extra_info = {'config': config, 'train_loader': None, 'valid_loader': None}
|
||||||
search_space = get_search_spaces('cell', xargs.search_space_name)
|
search_space = get_search_spaces('cell', xargs.search_space_name)
|
||||||
random_arch = random_architecture_func(xargs.max_nodes, search_space)
|
random_arch = random_architecture_func(xargs.max_nodes, search_space)
|
||||||
#x =random_arch() ; y = mutate_arch(x)
|
#x =random_arch() ; y = mutate_arch(x)
|
||||||
@ -71,7 +76,7 @@ def main(xargs, nas_bench):
|
|||||||
logger.log('-'*100)
|
logger.log('-'*100)
|
||||||
logger.close()
|
logger.close()
|
||||||
return logger.log_dir, nas_bench.query_index_by_arch( best_arch )
|
return logger.log_dir, nas_bench.query_index_by_arch( best_arch )
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
@ -172,24 +172,30 @@ def main(xargs, nas_bench):
|
|||||||
logger = prepare_logger(args)
|
logger = prepare_logger(args)
|
||||||
|
|
||||||
assert xargs.dataset == 'cifar10', 'currently only support CIFAR-10'
|
assert xargs.dataset == 'cifar10', 'currently only support CIFAR-10'
|
||||||
train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1)
|
if xargs.data_path is not None:
|
||||||
split_Fpath = 'configs/nas-benchmark/cifar-split.txt'
|
train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1)
|
||||||
cifar_split = load_config(split_Fpath, None, None)
|
split_Fpath = 'configs/nas-benchmark/cifar-split.txt'
|
||||||
train_split, valid_split = cifar_split.train, cifar_split.valid
|
cifar_split = load_config(split_Fpath, None, None)
|
||||||
logger.log('Load split file from {:}'.format(split_Fpath))
|
train_split, valid_split = cifar_split.train, cifar_split.valid
|
||||||
config_path = 'configs/nas-benchmark/algos/R-EA.config'
|
logger.log('Load split file from {:}'.format(split_Fpath))
|
||||||
config = load_config(config_path, {'class_num': class_num, 'xshape': xshape}, logger)
|
config_path = 'configs/nas-benchmark/algos/R-EA.config'
|
||||||
# To split data
|
config = load_config(config_path, {'class_num': class_num, 'xshape': xshape}, logger)
|
||||||
train_data_v2 = deepcopy(train_data)
|
# To split data
|
||||||
train_data_v2.transform = valid_data.transform
|
train_data_v2 = deepcopy(train_data)
|
||||||
valid_data = train_data_v2
|
train_data_v2.transform = valid_data.transform
|
||||||
search_data = SearchDataset(xargs.dataset, train_data, train_split, valid_split)
|
valid_data = train_data_v2
|
||||||
# data loader
|
search_data = SearchDataset(xargs.dataset, train_data, train_split, valid_split)
|
||||||
train_loader = torch.utils.data.DataLoader(train_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(train_split) , num_workers=xargs.workers, pin_memory=True)
|
# data loader
|
||||||
valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(valid_split), num_workers=xargs.workers, pin_memory=True)
|
train_loader = torch.utils.data.DataLoader(train_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(train_split) , num_workers=xargs.workers, pin_memory=True)
|
||||||
logger.log('||||||| {:10s} ||||||| Train-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'.format(xargs.dataset, len(train_loader), len(valid_loader), config.batch_size))
|
valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(valid_split), num_workers=xargs.workers, pin_memory=True)
|
||||||
logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config))
|
logger.log('||||||| {:10s} ||||||| Train-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'.format(xargs.dataset, len(train_loader), len(valid_loader), config.batch_size))
|
||||||
extra_info = {'config': config, 'train_loader': train_loader, 'valid_loader': valid_loader}
|
logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config))
|
||||||
|
extra_info = {'config': config, 'train_loader': train_loader, 'valid_loader': valid_loader}
|
||||||
|
else:
|
||||||
|
config_path = 'configs/nas-benchmark/algos/R-EA.config'
|
||||||
|
config = load_config(config_path, None, logger)
|
||||||
|
logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config))
|
||||||
|
extra_info = {'config': config, 'train_loader': None, 'valid_loader': None}
|
||||||
|
|
||||||
search_space = get_search_spaces('cell', xargs.search_space_name)
|
search_space = get_search_spaces('cell', xargs.search_space_name)
|
||||||
random_arch = random_architecture_func(xargs.max_nodes, search_space)
|
random_arch = random_architecture_func(xargs.max_nodes, search_space)
|
||||||
|
@ -99,24 +99,31 @@ def main(xargs, nas_bench):
|
|||||||
logger = prepare_logger(args)
|
logger = prepare_logger(args)
|
||||||
|
|
||||||
assert xargs.dataset == 'cifar10', 'currently only support CIFAR-10'
|
assert xargs.dataset == 'cifar10', 'currently only support CIFAR-10'
|
||||||
train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1)
|
if xargs.data_path is not None:
|
||||||
split_Fpath = 'configs/nas-benchmark/cifar-split.txt'
|
train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1)
|
||||||
cifar_split = load_config(split_Fpath, None, None)
|
split_Fpath = 'configs/nas-benchmark/cifar-split.txt'
|
||||||
train_split, valid_split = cifar_split.train, cifar_split.valid
|
cifar_split = load_config(split_Fpath, None, None)
|
||||||
logger.log('Load split file from {:}'.format(split_Fpath))
|
train_split, valid_split = cifar_split.train, cifar_split.valid
|
||||||
config_path = 'configs/nas-benchmark/algos/R-EA.config'
|
logger.log('Load split file from {:}'.format(split_Fpath))
|
||||||
config = load_config(config_path, {'class_num': class_num, 'xshape': xshape}, logger)
|
config_path = 'configs/nas-benchmark/algos/R-EA.config'
|
||||||
# To split data
|
config = load_config(config_path, {'class_num': class_num, 'xshape': xshape}, logger)
|
||||||
train_data_v2 = deepcopy(train_data)
|
# To split data
|
||||||
train_data_v2.transform = valid_data.transform
|
train_data_v2 = deepcopy(train_data)
|
||||||
valid_data = train_data_v2
|
train_data_v2.transform = valid_data.transform
|
||||||
search_data = SearchDataset(xargs.dataset, train_data, train_split, valid_split)
|
valid_data = train_data_v2
|
||||||
# data loader
|
search_data = SearchDataset(xargs.dataset, train_data, train_split, valid_split)
|
||||||
train_loader = torch.utils.data.DataLoader(train_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(train_split) , num_workers=xargs.workers, pin_memory=True)
|
# data loader
|
||||||
valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(valid_split), num_workers=xargs.workers, pin_memory=True)
|
train_loader = torch.utils.data.DataLoader(train_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(train_split) , num_workers=xargs.workers, pin_memory=True)
|
||||||
logger.log('||||||| {:10s} ||||||| Train-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'.format(xargs.dataset, len(train_loader), len(valid_loader), config.batch_size))
|
valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(valid_split), num_workers=xargs.workers, pin_memory=True)
|
||||||
logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config))
|
logger.log('||||||| {:10s} ||||||| Train-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'.format(xargs.dataset, len(train_loader), len(valid_loader), config.batch_size))
|
||||||
extra_info = {'config': config, 'train_loader': train_loader, 'valid_loader': valid_loader}
|
logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config))
|
||||||
|
extra_info = {'config': config, 'train_loader': train_loader, 'valid_loader': valid_loader}
|
||||||
|
else:
|
||||||
|
config_path = 'configs/nas-benchmark/algos/R-EA.config'
|
||||||
|
config = load_config(config_path, None, logger)
|
||||||
|
extra_info = {'config': config, 'train_loader': None, 'valid_loader': None}
|
||||||
|
logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config))
|
||||||
|
|
||||||
|
|
||||||
search_space = get_search_spaces('cell', xargs.search_space_name)
|
search_space = get_search_spaces('cell', xargs.search_space_name)
|
||||||
policy = Policy(xargs.max_nodes, search_space)
|
policy = Policy(xargs.max_nodes, search_space)
|
||||||
|
@ -74,15 +74,22 @@ class Structure:
|
|||||||
nodes[i+1] = sum(sums) > 0
|
nodes[i+1] = sum(sums) > 0
|
||||||
return nodes[len(self.nodes)]
|
return nodes[len(self.nodes)]
|
||||||
|
|
||||||
def to_unique_str(self):
|
def to_unique_str(self, consider_zero=False):
|
||||||
# this is used to identify the isomorphic cell, which rerquires the prior knowledge of operation
|
# this is used to identify the isomorphic cell, which rerquires the prior knowledge of operation
|
||||||
# two operations are special, i.e., none and skip_connect
|
# two operations are special, i.e., none and skip_connect
|
||||||
nodes = {0: '0'}
|
nodes = {0: '0'}
|
||||||
for i_node, node_info in enumerate(self.nodes):
|
for i_node, node_info in enumerate(self.nodes):
|
||||||
cur_node = []
|
cur_node = []
|
||||||
for op, xin in node_info:
|
for op, xin in node_info:
|
||||||
if op == 'skip_connect': x = nodes[xin]
|
if consider_zero is None:
|
||||||
else: x = '('+nodes[xin]+')' + '@{:}'.format(op)
|
x = '('+nodes[xin]+')' + '@{:}'.format(op)
|
||||||
|
elif consider_zero:
|
||||||
|
if op == 'none' or nodes[xin] == '#': x = '#' # zero
|
||||||
|
elif op == 'skip_connect': x = nodes[xin]
|
||||||
|
else: x = '('+nodes[xin]+')' + '@{:}'.format(op)
|
||||||
|
else:
|
||||||
|
if op == 'skip_connect': x = nodes[xin]
|
||||||
|
else: x = '('+nodes[xin]+')' + '@{:}'.format(op)
|
||||||
cur_node.append(x)
|
cur_node.append(x)
|
||||||
nodes[i_node+1] = '+'.join( sorted(cur_node) )
|
nodes[i_node+1] = '+'.join( sorted(cur_node) )
|
||||||
return nodes[ len(self.nodes) ]
|
return nodes[ len(self.nodes) ]
|
||||||
|
@ -41,8 +41,9 @@ class NASBench102API(object):
|
|||||||
if verbose: print('try to create the NAS-Bench-102 api from {:}'.format(file_path_or_dict))
|
if verbose: print('try to create the NAS-Bench-102 api from {:}'.format(file_path_or_dict))
|
||||||
assert os.path.isfile(file_path_or_dict), 'invalid path : {:}'.format(file_path_or_dict)
|
assert os.path.isfile(file_path_or_dict), 'invalid path : {:}'.format(file_path_or_dict)
|
||||||
file_path_or_dict = torch.load(file_path_or_dict)
|
file_path_or_dict = torch.load(file_path_or_dict)
|
||||||
else:
|
elif isinstance(file_path_or_dict, dict):
|
||||||
file_path_or_dict = copy.deepcopy( file_path_or_dict )
|
file_path_or_dict = copy.deepcopy( file_path_or_dict )
|
||||||
|
else: raise ValueError('invalid type : {:} not in [str, dict]'.format(type(file_path_or_dict)))
|
||||||
assert isinstance(file_path_or_dict, dict), 'It should be a dict instead of {:}'.format(type(file_path_or_dict))
|
assert isinstance(file_path_or_dict, dict), 'It should be a dict instead of {:}'.format(type(file_path_or_dict))
|
||||||
keys = ('meta_archs', 'arch2infos', 'evaluated_indexes')
|
keys = ('meta_archs', 'arch2infos', 'evaluated_indexes')
|
||||||
for key in keys: assert key in file_path_or_dict, 'Can not find key[{:}] in the dict'.format(key)
|
for key in keys: assert key in file_path_or_dict, 'Can not find key[{:}] in the dict'.format(key)
|
||||||
@ -152,26 +153,40 @@ class NASBench102API(object):
|
|||||||
archresult = arch2infos[index]
|
archresult = arch2infos[index]
|
||||||
return archresult.get_net_param(dataset, seed)
|
return archresult.get_net_param(dataset, seed)
|
||||||
|
|
||||||
def get_more_info(self, index, dataset, iepoch=None, use_12epochs_result=False):
|
# obtain the metric for the `index`-th architecture
|
||||||
|
def get_more_info(self, index, dataset, iepoch=None, use_12epochs_result=False, is_random=True):
|
||||||
if use_12epochs_result: basestr, arch2infos = '12epochs' , self.arch2infos_less
|
if use_12epochs_result: basestr, arch2infos = '12epochs' , self.arch2infos_less
|
||||||
else : basestr, arch2infos = '200epochs', self.arch2infos_full
|
else : basestr, arch2infos = '200epochs', self.arch2infos_full
|
||||||
archresult = arch2infos[index]
|
archresult = arch2infos[index]
|
||||||
if dataset == 'cifar10-valid':
|
if dataset == 'cifar10-valid':
|
||||||
train_info = archresult.get_metrics(dataset, 'train' , iepoch=iepoch, is_random=True)
|
train_info = archresult.get_metrics(dataset, 'train' , iepoch=iepoch, is_random=is_random)
|
||||||
valid_info = archresult.get_metrics(dataset, 'x-valid' , iepoch=iepoch, is_random=True)
|
valid_info = archresult.get_metrics(dataset, 'x-valid' , iepoch=iepoch, is_random=is_random)
|
||||||
test__info = archresult.get_metrics(dataset, 'ori-test', iepoch=iepoch, is_random=True)
|
try:
|
||||||
|
test__info = archresult.get_metrics(dataset, 'ori-test', iepoch=iepoch, is_random=is_random)
|
||||||
|
except:
|
||||||
|
test__info = None
|
||||||
total = train_info['iepoch'] + 1
|
total = train_info['iepoch'] + 1
|
||||||
return {'train-loss' : train_info['loss'],
|
xifo = {'train-loss' : train_info['loss'],
|
||||||
'train-accuracy': train_info['accuracy'],
|
'train-accuracy': train_info['accuracy'],
|
||||||
'train-all-time': train_info['all_time'],
|
'train-all-time': train_info['all_time'],
|
||||||
'valid-loss' : valid_info['loss'],
|
'valid-loss' : valid_info['loss'],
|
||||||
'valid-accuracy': valid_info['accuracy'],
|
'valid-accuracy': valid_info['accuracy'],
|
||||||
'valid-all-time': valid_info['all_time'],
|
'valid-all-time': valid_info['all_time'],
|
||||||
'valid-per-time': valid_info['all_time'] / total,
|
'valid-per-time': None if valid_info['all_time'] is None else valid_info['all_time'] / total}
|
||||||
|
if test__info is not None:
|
||||||
|
xifo['test-loss'] = test__info['loss']
|
||||||
|
xifo['test-accuracy'] = test__info['accuracy']
|
||||||
|
return xifo
|
||||||
|
else:
|
||||||
|
train_info = archresult.get_metrics(dataset, 'train' , iepoch=iepoch, is_random=is_random)
|
||||||
|
if dataset == 'cifar10':
|
||||||
|
test__info = archresult.get_metrics(dataset, 'ori-test', iepoch=iepoch, is_random=is_random)
|
||||||
|
else:
|
||||||
|
test__info = archresult.get_metrics(dataset, 'x-test', iepoch=iepoch, is_random=is_random)
|
||||||
|
return {'train-loss' : train_info['loss'],
|
||||||
|
'train-accuracy': train_info['accuracy'],
|
||||||
'test-loss' : test__info['loss'],
|
'test-loss' : test__info['loss'],
|
||||||
'test-accuracy' : test__info['accuracy']}
|
'test-accuracy' : test__info['accuracy']}
|
||||||
else:
|
|
||||||
raise ValueError('coming soon...')
|
|
||||||
|
|
||||||
def show(self, index=-1):
|
def show(self, index=-1):
|
||||||
if index < 0: # show all architectures
|
if index < 0: # show all architectures
|
||||||
@ -369,7 +384,7 @@ class ResultsCount(object):
|
|||||||
def update_latency(self, latency):
|
def update_latency(self, latency):
|
||||||
self.latency = copy.deepcopy( latency )
|
self.latency = copy.deepcopy( latency )
|
||||||
|
|
||||||
def update_eval(self, accs, losses, times): # old version
|
def update_eval(self, accs, losses, times): # new version
|
||||||
data_names = set([x.split('@')[0] for x in accs.keys()])
|
data_names = set([x.split('@')[0] for x in accs.keys()])
|
||||||
for data_name in data_names:
|
for data_name in data_names:
|
||||||
assert data_name not in self.eval_names, '{:} has already been added into eval-names'.format(data_name)
|
assert data_name not in self.eval_names, '{:} has already been added into eval-names'.format(data_name)
|
||||||
|
@ -21,17 +21,11 @@ num_cells=5
|
|||||||
max_nodes=4
|
max_nodes=4
|
||||||
space=nas-bench-102
|
space=nas-bench-102
|
||||||
|
|
||||||
if [ "$dataset" == "cifar10" ] || [ "$dataset" == "cifar100" ]; then
|
|
||||||
data_path="$TORCH_HOME/cifar.python"
|
|
||||||
else
|
|
||||||
data_path="$TORCH_HOME/cifar.python/ImageNet16"
|
|
||||||
fi
|
|
||||||
|
|
||||||
save_dir=./output/search-cell-${space}/BOHB-${dataset}
|
save_dir=./output/search-cell-${space}/BOHB-${dataset}
|
||||||
|
|
||||||
OMP_NUM_THREADS=4 python ./exps/algos/BOHB.py \
|
OMP_NUM_THREADS=4 python ./exps/algos/BOHB.py \
|
||||||
--save_dir ${save_dir} --max_nodes ${max_nodes} --channel ${channel} --num_cells ${num_cells} \
|
--save_dir ${save_dir} --max_nodes ${max_nodes} --channel ${channel} --num_cells ${num_cells} \
|
||||||
--dataset ${dataset} --data_path ${data_path} \
|
--dataset ${dataset} \
|
||||||
--search_space_name ${space} \
|
--search_space_name ${space} \
|
||||||
--arch_nas_dataset ${TORCH_HOME}/NAS-Bench-102-v1_0-e61699.pth \
|
--arch_nas_dataset ${TORCH_HOME}/NAS-Bench-102-v1_0-e61699.pth \
|
||||||
--time_budget 12000 \
|
--time_budget 12000 \
|
||||||
|
@ -22,17 +22,11 @@ num_cells=5
|
|||||||
max_nodes=4
|
max_nodes=4
|
||||||
space=nas-bench-102
|
space=nas-bench-102
|
||||||
|
|
||||||
if [ "$dataset" == "cifar10" ] || [ "$dataset" == "cifar100" ]; then
|
|
||||||
data_path="$TORCH_HOME/cifar.python"
|
|
||||||
else
|
|
||||||
data_path="$TORCH_HOME/cifar.python/ImageNet16"
|
|
||||||
fi
|
|
||||||
|
|
||||||
save_dir=./output/search-cell-${space}/R-EA-${dataset}
|
save_dir=./output/search-cell-${space}/R-EA-${dataset}
|
||||||
|
|
||||||
OMP_NUM_THREADS=4 python ./exps/algos/R_EA.py \
|
OMP_NUM_THREADS=4 python ./exps/algos/R_EA.py \
|
||||||
--save_dir ${save_dir} --max_nodes ${max_nodes} --channel ${channel} --num_cells ${num_cells} \
|
--save_dir ${save_dir} --max_nodes ${max_nodes} --channel ${channel} --num_cells ${num_cells} \
|
||||||
--dataset ${dataset} --data_path ${data_path} \
|
--dataset ${dataset} \
|
||||||
--search_space_name ${space} \
|
--search_space_name ${space} \
|
||||||
--arch_nas_dataset ${TORCH_HOME}/NAS-Bench-102-v1_0-e61699.pth \
|
--arch_nas_dataset ${TORCH_HOME}/NAS-Bench-102-v1_0-e61699.pth \
|
||||||
--time_budget 12000 \
|
--time_budget 12000 \
|
||||||
|
@ -21,17 +21,11 @@ num_cells=5
|
|||||||
max_nodes=4
|
max_nodes=4
|
||||||
space=nas-bench-102
|
space=nas-bench-102
|
||||||
|
|
||||||
if [ "$dataset" == "cifar10" ] || [ "$dataset" == "cifar100" ]; then
|
|
||||||
data_path="$TORCH_HOME/cifar.python"
|
|
||||||
else
|
|
||||||
data_path="$TORCH_HOME/cifar.python/ImageNet16"
|
|
||||||
fi
|
|
||||||
|
|
||||||
save_dir=./output/search-cell-${space}/REINFORCE-${dataset}
|
save_dir=./output/search-cell-${space}/REINFORCE-${dataset}
|
||||||
|
|
||||||
OMP_NUM_THREADS=4 python ./exps/algos/reinforce.py \
|
OMP_NUM_THREADS=4 python ./exps/algos/reinforce.py \
|
||||||
--save_dir ${save_dir} --max_nodes ${max_nodes} --channel ${channel} --num_cells ${num_cells} \
|
--save_dir ${save_dir} --max_nodes ${max_nodes} --channel ${channel} --num_cells ${num_cells} \
|
||||||
--dataset ${dataset} --data_path ${data_path} \
|
--dataset ${dataset} \
|
||||||
--search_space_name ${space} \
|
--search_space_name ${space} \
|
||||||
--arch_nas_dataset ${TORCH_HOME}/NAS-Bench-102-v1_0-e61699.pth \
|
--arch_nas_dataset ${TORCH_HOME}/NAS-Bench-102-v1_0-e61699.pth \
|
||||||
--time_budget 12000 \
|
--time_budget 12000 \
|
||||||
|
@ -21,17 +21,11 @@ num_cells=5
|
|||||||
max_nodes=4
|
max_nodes=4
|
||||||
space=nas-bench-102
|
space=nas-bench-102
|
||||||
|
|
||||||
if [ "$dataset" == "cifar10" ] || [ "$dataset" == "cifar100" ]; then
|
|
||||||
data_path="$TORCH_HOME/cifar.python"
|
|
||||||
else
|
|
||||||
data_path="$TORCH_HOME/cifar.python/ImageNet16"
|
|
||||||
fi
|
|
||||||
|
|
||||||
save_dir=./output/search-cell-${space}/RAND-${dataset}
|
save_dir=./output/search-cell-${space}/RAND-${dataset}
|
||||||
|
|
||||||
OMP_NUM_THREADS=4 python ./exps/algos/RANDOM.py \
|
OMP_NUM_THREADS=4 python ./exps/algos/RANDOM.py \
|
||||||
--save_dir ${save_dir} --max_nodes ${max_nodes} --channel ${channel} --num_cells ${num_cells} \
|
--save_dir ${save_dir} --max_nodes ${max_nodes} --channel ${channel} --num_cells ${num_cells} \
|
||||||
--dataset ${dataset} --data_path ${data_path} \
|
--dataset ${dataset} \
|
||||||
--search_space_name ${space} \
|
--search_space_name ${space} \
|
||||||
--arch_nas_dataset ${TORCH_HOME}/NAS-Bench-102-v1_0-e61699.pth \
|
--arch_nas_dataset ${TORCH_HOME}/NAS-Bench-102-v1_0-e61699.pth \
|
||||||
--time_budget 12000 \
|
--time_budget 12000 \
|
||||||
|
Loading…
Reference in New Issue
Block a user