update NAS-Bench-102 baselines
This commit is contained in:
parent
af4212b4db
commit
44a0d51449
@ -1,10 +1,10 @@
|
||||
{
|
||||
"scheduler": ["str", "cos"],
|
||||
"LR" : ["float", "0.025"],
|
||||
"eta_min" : ["float", "0.001"],
|
||||
"epochs" : ["int", "50"],
|
||||
"warmup" : ["int", "0"],
|
||||
"optim" : ["str", "SGD"],
|
||||
"LR" : ["float", "0.025"],
|
||||
"decay" : ["float", "0.0005"],
|
||||
"momentum" : ["float", "0.9"],
|
||||
"nesterov" : ["bool", "1"],
|
||||
|
@ -2,7 +2,7 @@
|
||||
"scheduler": ["str", "cos"],
|
||||
"LR" : ["float", "0.05"],
|
||||
"eta_min" : ["float", "0.0005"],
|
||||
"epochs" : ["int", "310"],
|
||||
"epochs" : ["int", "250"],
|
||||
"T_max" : ["int", "10"],
|
||||
"warmup" : ["int", "0"],
|
||||
"optim" : ["str", "SGD"],
|
||||
|
@ -1,10 +1,10 @@
|
||||
{
|
||||
"scheduler": ["str", "cos"],
|
||||
"LR" : ["float", "0.025"],
|
||||
"eta_min" : ["float", "0.001"],
|
||||
"epochs" : ["int", "250"],
|
||||
"warmup" : ["int", "0"],
|
||||
"optim" : ["str", "SGD"],
|
||||
"LR" : ["float", "0.025"],
|
||||
"decay" : ["float", "0.0005"],
|
||||
"momentum" : ["float", "0.9"],
|
||||
"nesterov" : ["bool", "1"],
|
||||
|
@ -1,10 +1,10 @@
|
||||
{
|
||||
"scheduler": ["str", "cos"],
|
||||
"LR" : ["float", "0.025"],
|
||||
"eta_min" : ["float", "0.001"],
|
||||
"epochs" : ["int", "250"],
|
||||
"warmup" : ["int", "0"],
|
||||
"optim" : ["str", "SGD"],
|
||||
"LR" : ["float", "0.025"],
|
||||
"decay" : ["float", "0.0005"],
|
||||
"momentum" : ["float", "0.9"],
|
||||
"nesterov" : ["bool", "1"],
|
||||
|
@ -1,10 +1,10 @@
|
||||
{
|
||||
"scheduler": ["str", "cos"],
|
||||
"LR" : ["float", "0.025"],
|
||||
"eta_min" : ["float", "0.001"],
|
||||
"epochs" : ["int", "250"],
|
||||
"warmup" : ["int", "0"],
|
||||
"optim" : ["str", "SGD"],
|
||||
"LR" : ["float", "0.025"],
|
||||
"decay" : ["float", "0.0005"],
|
||||
"momentum" : ["float", "0.9"],
|
||||
"nesterov" : ["bool", "1"],
|
||||
|
@ -15,6 +15,7 @@ from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_che
|
||||
from utils import get_model_infos, obtain_accuracy
|
||||
from log_utils import AverageMeter, time_string, convert_secs2time
|
||||
from models import get_cell_based_tiny_net, get_search_spaces
|
||||
from nas_102_api import NASBench102API as API
|
||||
|
||||
|
||||
def train_shared_cnn(xloader, shared_cnn, controller, criterion, scheduler, optimizer, epoch_str, print_freq, logger):
|
||||
@ -224,6 +225,12 @@ def main(xargs):
|
||||
#flop, param = get_model_infos(shared_cnn, xshape)
|
||||
#logger.log('{:}'.format(shared_cnn))
|
||||
#logger.log('FLOP = {:.2f} M, Params = {:.2f} MB'.format(flop, param))
|
||||
logger.log('search-space : {:}'.format(search_space))
|
||||
if xargs.arch_nas_dataset is None:
|
||||
api = None
|
||||
else:
|
||||
api = API(xargs.arch_nas_dataset)
|
||||
logger.log('{:} create API = {:} done'.format(time_string(), api))
|
||||
shared_cnn, controller, criterion = torch.nn.DataParallel(shared_cnn).cuda(), controller.cuda(), criterion.cuda()
|
||||
|
||||
last_info, model_base_path, model_best_path = logger.path('info'), logger.path('model'), logger.path('best')
|
||||
@ -247,7 +254,7 @@ def main(xargs):
|
||||
start_epoch, valid_accuracies, genotypes, baseline = 0, {'best': -1}, {}, None
|
||||
|
||||
# start training
|
||||
start_time, epoch_time, total_epoch = time.time(), AverageMeter(), config.epochs + config.warmup
|
||||
start_time, search_time, epoch_time, total_epoch = time.time(), AverageMeter(), AverageMeter(), config.epochs + config.warmup
|
||||
for epoch in range(start_epoch, total_epoch):
|
||||
w_scheduler.update(epoch, 0.0)
|
||||
need_time = 'Time Left: {:}'.format( convert_secs2time(epoch_time.val * (total_epoch-epoch), True) )
|
||||
@ -263,7 +270,8 @@ def main(xargs):
|
||||
'ctl_entropy_w': xargs.controller_entropy_weight,
|
||||
'ctl_bl_dec' : xargs.controller_bl_dec}, None), \
|
||||
epoch_str, xargs.print_freq, logger)
|
||||
logger.log('[{:}] controller : loss={:.2f}, accuracy={:.2f}%, baseline={:.2f}, reward={:.2f}, current-baseline={:.4f}'.format(epoch_str, ctl_loss, ctl_acc, ctl_baseline, ctl_reward, baseline))
|
||||
search_time.update(time.time() - start_time)
|
||||
logger.log('[{:}] controller : loss={:.2f}, accuracy={:.2f}%, baseline={:.2f}, reward={:.2f}, current-baseline={:.4f}, time-cost={:.1f} s'.format(epoch_str, ctl_loss, ctl_acc, ctl_baseline, ctl_reward, baseline, search_time.sum))
|
||||
best_arch, _ = get_best_arch(controller, shared_cnn, valid_loader)
|
||||
shared_cnn.module.update_arch(best_arch)
|
||||
_, best_valid_acc, _ = valid_func(valid_loader, shared_cnn, criterion)
|
||||
@ -298,6 +306,7 @@ def main(xargs):
|
||||
if find_best:
|
||||
logger.log('<<<--->>> The {:}-th epoch : find the highest validation accuracy : {:.2f}%.'.format(epoch_str, best_valid_acc))
|
||||
copy_checkpoint(model_base_path, model_best_path, logger)
|
||||
if api is not None: logger.log('{:}'.format(api.query_by_arch( genotypes[epoch] )))
|
||||
# measure elapsed time
|
||||
epoch_time.update(time.time() - start_time)
|
||||
start_time = time.time()
|
||||
@ -306,27 +315,15 @@ def main(xargs):
|
||||
logger.log('During searching, the best architecture is {:}'.format(genotypes['best']))
|
||||
logger.log('Its accuracy is {:.2f}%'.format(valid_accuracies['best']))
|
||||
logger.log('Randomly select {:} architectures and select the best.'.format(xargs.controller_num_samples))
|
||||
start_time = time.time()
|
||||
final_arch, _ = get_best_arch(controller, shared_cnn, valid_loader, xargs.controller_num_samples)
|
||||
search_time.update(time.time() - start_time)
|
||||
shared_cnn.module.update_arch(final_arch)
|
||||
final_loss, final_top1, final_top5 = valid_func(valid_loader, shared_cnn, criterion)
|
||||
logger.log('The Selected Final Architecture : {:}'.format(final_arch))
|
||||
logger.log('Loss={:.3f}, Accuracy@1={:.2f}%, Accuracy@5={:.2f}%'.format(final_loss, final_top1, final_top5))
|
||||
# check the performance from the architecture dataset
|
||||
#if xargs.arch_nas_dataset is None or not os.path.isfile(xargs.arch_nas_dataset):
|
||||
# logger.log('Can not find the architecture dataset : {:}.'.format(xargs.arch_nas_dataset))
|
||||
#else:
|
||||
# nas_bench = NASBenchmarkAPI(xargs.arch_nas_dataset)
|
||||
# geno = genotypes[total_epoch-1]
|
||||
# logger.log('The last model is {:}'.format(geno))
|
||||
# info = nas_bench.query_by_arch( geno )
|
||||
# if info is None: logger.log('Did not find this architecture : {:}.'.format(geno))
|
||||
# else : logger.log('{:}'.format(info))
|
||||
# logger.log('-'*100)
|
||||
# geno = genotypes['best']
|
||||
# logger.log('The best model is {:}'.format(geno))
|
||||
# info = nas_bench.query_by_arch( geno )
|
||||
# if info is None: logger.log('Did not find this architecture : {:}.'.format(geno))
|
||||
# else : logger.log('{:}'.format(info))
|
||||
logger.log('ENAS : run {:} epochs, cost {:.1f} s, last-geno is {:}.'.format(total_epoch, search_time.sum, final_arch))
|
||||
if api is not None: logger.log('{:}'.format( api.query_by_arch(final_arch) ))
|
||||
logger.close()
|
||||
|
||||
|
||||
|
@ -93,8 +93,8 @@ def main(xargs):
|
||||
logger.log('Load split file from {:}'.format(split_Fpath))
|
||||
else:
|
||||
raise ValueError('invalid dataset : {:}'.format(xargs.dataset))
|
||||
config_path = 'configs/nas-benchmark/algos/GDAS.config'
|
||||
config = load_config(config_path, {'class_num': class_num, 'xshape': xshape}, logger)
|
||||
#config_path = 'configs/nas-benchmark/algos/GDAS.config'
|
||||
config = load_config(xargs.config_path, {'class_num': class_num, 'xshape': xshape}, logger)
|
||||
search_data = SearchDataset(xargs.dataset, train_data, train_split, valid_split)
|
||||
# data loader
|
||||
search_loader = torch.utils.data.DataLoader(search_data, batch_size=config.batch_size, shuffle=True , num_workers=xargs.workers, pin_memory=True)
|
||||
@ -105,7 +105,7 @@ def main(xargs):
|
||||
model_config = dict2config({'name': 'GDAS', 'C': xargs.channel, 'N': xargs.num_cells,
|
||||
'max_nodes': xargs.max_nodes, 'num_classes': class_num,
|
||||
'space' : search_space,
|
||||
'affine' : False, 'track_running_stats': True}, None)
|
||||
'affine' : False, 'track_running_stats': bool(xargs.track_running_stats)}, None)
|
||||
search_model = get_cell_based_tiny_net(model_config)
|
||||
logger.log('search-model :\n{:}'.format(search_model))
|
||||
|
||||
@ -156,7 +156,7 @@ def main(xargs):
|
||||
search_w_loss, search_w_top1, search_w_top5, valid_a_loss , valid_a_top1 , valid_a_top5 \
|
||||
= search_func(search_loader, network, criterion, w_scheduler, w_optimizer, a_optimizer, epoch_str, xargs.print_freq, logger)
|
||||
search_time.update(time.time() - start_time)
|
||||
logger.log('[{:}] searching : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'.format(epoch_str, search_w_loss, search_w_top1, search_w_top5))
|
||||
logger.log('[{:}] searching : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%, time-cost={:.1f} s'.format(epoch_str, search_w_loss, search_w_top1, search_w_top5, search_time.sum))
|
||||
logger.log('[{:}] evaluate : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'.format(epoch_str, valid_a_loss , valid_a_top1 , valid_a_top5 ))
|
||||
# check the best accuracy
|
||||
valid_accuracies[epoch] = valid_a_top1
|
||||
@ -210,6 +210,8 @@ if __name__ == '__main__':
|
||||
parser.add_argument('--max_nodes', type=int, help='The maximum number of nodes.')
|
||||
parser.add_argument('--channel', type=int, help='The number of channels.')
|
||||
parser.add_argument('--num_cells', type=int, help='The number of cells in one stage.')
|
||||
parser.add_argument('--track_running_stats',type=int, choices=[0,1],help='Whether use track_running_stats or not in the BN layer.')
|
||||
parser.add_argument('--config_path', type=str, help='The path of the configuration.')
|
||||
# architecture leraning rate
|
||||
parser.add_argument('--arch_learning_rate', type=float, default=3e-4, help='learning rate for arch encoding')
|
||||
parser.add_argument('--arch_weight_decay', type=float, default=1e-3, help='weight decay for arch encoding')
|
||||
|
@ -15,6 +15,7 @@ from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_che
|
||||
from utils import get_model_infos, obtain_accuracy
|
||||
from log_utils import AverageMeter, time_string, convert_secs2time
|
||||
from models import get_cell_based_tiny_net, get_search_spaces
|
||||
from nas_102_api import NASBench102API as API
|
||||
|
||||
|
||||
def search_func(xloader, network, criterion, scheduler, w_optimizer, epoch_str, print_freq, logger):
|
||||
@ -130,6 +131,9 @@ def main(xargs):
|
||||
logger.log('w-optimizer : {:}'.format(w_optimizer))
|
||||
logger.log('w-scheduler : {:}'.format(w_scheduler))
|
||||
logger.log('criterion : {:}'.format(criterion))
|
||||
if xargs.arch_nas_dataset is None: api = None
|
||||
else : api = API(xargs.arch_nas_dataset)
|
||||
logger.log('{:} create API = {:} done'.format(time_string(), api))
|
||||
|
||||
last_info, model_base_path, model_best_path = logger.path('info'), logger.path('model'), logger.path('best')
|
||||
network, criterion = torch.nn.DataParallel(search_model).cuda(), criterion.cuda()
|
||||
@ -149,7 +153,7 @@ def main(xargs):
|
||||
start_epoch, valid_accuracies = 0, {'best': -1}
|
||||
|
||||
# start training
|
||||
start_time, epoch_time, total_epoch = time.time(), AverageMeter(), config.epochs + config.warmup
|
||||
start_time, search_time, epoch_time, total_epoch = time.time(), AverageMeter(), AverageMeter(), config.epochs + config.warmup
|
||||
for epoch in range(start_epoch, total_epoch):
|
||||
w_scheduler.update(epoch, 0.0)
|
||||
need_time = 'Time Left: {:}'.format( convert_secs2time(epoch_time.val * (total_epoch-epoch), True) )
|
||||
@ -157,7 +161,8 @@ def main(xargs):
|
||||
logger.log('\n[Search the {:}-th epoch] {:}, LR={:}'.format(epoch_str, need_time, min(w_scheduler.get_lr())))
|
||||
|
||||
search_w_loss, search_w_top1, search_w_top5 = search_func(search_loader, network, criterion, w_scheduler, w_optimizer, epoch_str, xargs.print_freq, logger)
|
||||
logger.log('[{:}] searching : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'.format(epoch_str, search_w_loss, search_w_top1, search_w_top5))
|
||||
search_time.update(time.time() - start_time)
|
||||
logger.log('[{:}] searching : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%, time-cost={:.1f} s'.format(epoch_str, search_w_loss, search_w_top1, search_w_top5, search_time.sum))
|
||||
valid_a_loss , valid_a_top1 , valid_a_top5 = valid_func(valid_loader, network, criterion)
|
||||
logger.log('[{:}] evaluate : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'.format(epoch_str, valid_a_loss, valid_a_top1, valid_a_top5))
|
||||
# check the best accuracy
|
||||
@ -188,7 +193,8 @@ def main(xargs):
|
||||
start_time = time.time()
|
||||
|
||||
logger.log('\n' + '-'*200)
|
||||
|
||||
logger.log('Pre-searching costs {:.1f} s'.format(search_time.sum))
|
||||
start_time = time.time()
|
||||
best_arch, best_acc = None, -1
|
||||
for iarch in range(xargs.select_num):
|
||||
arch = search_model.random_genotype( True )
|
||||
@ -197,24 +203,10 @@ def main(xargs):
|
||||
if best_arch is None or best_acc < valid_a_top1:
|
||||
best_arch, best_acc = arch, valid_a_top1
|
||||
|
||||
logger.log('Find the best one : {:} with accuracy={:.2f}%'.format(best_arch, best_acc))
|
||||
|
||||
logger.log('\n' + '-'*100)
|
||||
"""
|
||||
# check the performance from the architecture dataset
|
||||
if xargs.arch_nas_dataset is None or not os.path.isfile(xargs.arch_nas_dataset):
|
||||
logger.log('Can not find the architecture dataset : {:}.'.format(xargs.arch_nas_dataset))
|
||||
else:
|
||||
nas_bench = TinyNASBenchmarkAPI(xargs.arch_nas_dataset)
|
||||
geno = best_arch
|
||||
logger.log('The last model is {:}'.format(geno))
|
||||
info = nas_bench.query_by_arch( geno )
|
||||
if info is None: logger.log('Did not find this architecture : {:}.'.format(geno))
|
||||
else : logger.log('{:}'.format(info))
|
||||
logger.log('-'*100)
|
||||
search_time.update(time.time() - start_time)
|
||||
logger.log('RANDOM-NAS finds the best one : {:} with accuracy={:.2f}%, with {:.1f} s.'.format(best_arch, best_acc, search_time.sum))
|
||||
if api is not None: logger.log('{:}'.format( api.query_by_arch(best_arch) ))
|
||||
logger.close()
|
||||
"""
|
||||
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
@ -52,14 +52,18 @@ def main(xargs, nas_bench):
|
||||
random_arch = random_architecture_func(xargs.max_nodes, search_space)
|
||||
#x =random_arch() ; y = mutate_arch(x)
|
||||
logger.log('{:} use nas_bench : {:}'.format(time_string(), nas_bench))
|
||||
best_arch, best_acc = None, -1
|
||||
for idx in range(xargs.random_num):
|
||||
best_arch, best_acc, total_time_cost, history = None, -1, 0, []
|
||||
#for idx in range(xargs.random_num):
|
||||
while total_time_cost < xargs.time_budget:
|
||||
arch = random_arch()
|
||||
accuracy = train_and_eval(arch, nas_bench, extra_info)
|
||||
accuracy, cost_time = train_and_eval(arch, nas_bench, extra_info)
|
||||
if total_time_cost + cost_time > xargs.time_budget: break
|
||||
else: total_time_cost += cost_time
|
||||
history.append(arch)
|
||||
if best_arch is None or best_acc < accuracy:
|
||||
best_acc, best_arch = accuracy, arch
|
||||
logger.log('[{:03d}/{:03d}] : {:} : accuracy = {:.2f}%'.format(idx, xargs.random_num, arch, accuracy))
|
||||
logger.log('{:} best arch is {:}, accuracy = {:.2f}%'.format(time_string(), best_arch, best_acc))
|
||||
logger.log('[{:03d}] : {:} : accuracy = {:.2f}%'.format(len(history), arch, accuracy))
|
||||
logger.log('{:} best arch is {:}, accuracy = {:.2f}%, visit {:} archs with {:.1f} s.'.format(time_string(), best_arch, best_acc, len(history), total_time_cost))
|
||||
|
||||
info = nas_bench.query_by_arch( best_arch )
|
||||
if info is None: logger.log('Did not find this architecture : {:}.'.format(best_arch))
|
||||
@ -79,7 +83,8 @@ if __name__ == '__main__':
|
||||
parser.add_argument('--max_nodes', type=int, help='The maximum number of nodes.')
|
||||
parser.add_argument('--channel', type=int, help='The number of channels.')
|
||||
parser.add_argument('--num_cells', type=int, help='The number of cells in one stage.')
|
||||
parser.add_argument('--random_num', type=int, help='The number of random selected architectures.')
|
||||
#parser.add_argument('--random_num', type=int, help='The number of random selected architectures.')
|
||||
parser.add_argument('--time_budget', type=int, help='The total time cost budge for searching (in seconds).')
|
||||
# log
|
||||
parser.add_argument('--workers', type=int, default=2, help='number of data loading workers (default: 2)')
|
||||
parser.add_argument('--save_dir', type=str, help='Folder to save checkpoints and log.')
|
||||
|
@ -60,12 +60,12 @@ def train_and_eval(arch, nas_bench, extra_info):
|
||||
arch_index = nas_bench.query_index_by_arch( arch )
|
||||
assert arch_index >= 0, 'can not find this arch : {:}'.format(arch)
|
||||
info = nas_bench.get_more_info(arch_index, 'cifar10-valid', True)
|
||||
import pdb; pdb.set_trace()
|
||||
valid_acc, time_cost = info['valid-accuracy'], info['train-all-time'] + info['valid-per-time']
|
||||
#_, valid_acc = info.get_metrics('cifar10-valid', 'x-valid' , 25, True) # use the validation accuracy after 25 training epochs
|
||||
else:
|
||||
# train a model from scratch.
|
||||
raise ValueError('NOT IMPLEMENT YET')
|
||||
return valid_acc
|
||||
return valid_acc, time_cost
|
||||
|
||||
|
||||
def random_architecture_func(max_nodes, op_names):
|
||||
@ -101,7 +101,7 @@ def mutate_arch_func(op_names):
|
||||
return mutate_arch_func
|
||||
|
||||
|
||||
def regularized_evolution(cycles, population_size, sample_size, random_arch, mutate_arch, nas_bench, extra_info):
|
||||
def regularized_evolution(cycles, population_size, sample_size, time_budget, random_arch, mutate_arch, nas_bench, extra_info):
|
||||
"""Algorithm for regularized evolution (i.e. aging evolution).
|
||||
|
||||
Follows "Algorithm 1" in Real et al. "Regularized Evolution for Image
|
||||
@ -111,27 +111,30 @@ def regularized_evolution(cycles, population_size, sample_size, random_arch, mut
|
||||
cycles: the number of cycles the algorithm should run for.
|
||||
population_size: the number of individuals to keep in the population.
|
||||
sample_size: the number of individuals that should participate in each tournament.
|
||||
time_budget: the upper bound of searching cost
|
||||
|
||||
Returns:
|
||||
history: a list of `Model` instances, representing all the models computed
|
||||
during the evolution experiment.
|
||||
"""
|
||||
population = collections.deque()
|
||||
history = [] # Not used by the algorithm, only used to report results.
|
||||
history, total_time_cost = [], 0 # Not used by the algorithm, only used to report results.
|
||||
|
||||
# Initialize the population with random models.
|
||||
while len(population) < population_size:
|
||||
model = Model()
|
||||
model.arch = random_arch()
|
||||
model.accuracy = train_and_eval(model.arch, nas_bench, extra_info)
|
||||
model.accuracy, time_cost = train_and_eval(model.arch, nas_bench, extra_info)
|
||||
population.append(model)
|
||||
history.append(model)
|
||||
total_time_cost += time_cost
|
||||
|
||||
# Carry out evolution in cycles. Each cycle produces a model and removes
|
||||
# another.
|
||||
while len(history) < cycles:
|
||||
#while len(history) < cycles:
|
||||
while total_time_cost < time_budget:
|
||||
# Sample randomly chosen models from the current population.
|
||||
sample = []
|
||||
start_time, sample = time.time(), []
|
||||
while len(sample) < sample_size:
|
||||
# Inefficient, but written this way for clarity. In the case of neural
|
||||
# nets, the efficiency of this line is irrelevant because training neural
|
||||
@ -145,13 +148,18 @@ def regularized_evolution(cycles, population_size, sample_size, random_arch, mut
|
||||
# Create the child model and store it.
|
||||
child = Model()
|
||||
child.arch = mutate_arch(parent.arch)
|
||||
child.accuracy = train_and_eval(child.arch, nas_bench, extra_info)
|
||||
total_time_cost += time.time() - start_time
|
||||
child.accuracy, time_cost = train_and_eval(child.arch, nas_bench, extra_info)
|
||||
if total_time_cost + time_cost > time_budget: # return
|
||||
return history, total_time_cost
|
||||
else:
|
||||
total_time_cost += time_cost
|
||||
population.append(child)
|
||||
history.append(child)
|
||||
|
||||
# Remove the oldest model.
|
||||
population.popleft()
|
||||
return history
|
||||
return history, total_time_cost
|
||||
|
||||
|
||||
def main(xargs, nas_bench):
|
||||
@ -188,8 +196,9 @@ def main(xargs, nas_bench):
|
||||
mutate_arch = mutate_arch_func(search_space)
|
||||
#x =random_arch() ; y = mutate_arch(x)
|
||||
logger.log('{:} use nas_bench : {:}'.format(time_string(), nas_bench))
|
||||
history = regularized_evolution(xargs.ea_cycles, xargs.ea_population, xargs.ea_sample_size, random_arch, mutate_arch, nas_bench if args.ea_fast_by_api else None, extra_info)
|
||||
logger.log('{:} regularized_evolution finish with history of {:} arch.'.format(time_string(), len(history)))
|
||||
logger.log('-'*30 + ' start searching with the time budget of {:} s'.format(xargs.time_budget))
|
||||
history, total_cost = regularized_evolution(xargs.ea_cycles, xargs.ea_population, xargs.ea_sample_size, xargs.time_budget, random_arch, mutate_arch, nas_bench if args.ea_fast_by_api else None, extra_info)
|
||||
logger.log('{:} regularized_evolution finish with history of {:} arch with {:.1f} s.'.format(time_string(), len(history), total_cost))
|
||||
best_arch = max(history, key=lambda i: i.accuracy)
|
||||
best_arch = best_arch.arch
|
||||
logger.log('{:} best arch is {:}'.format(time_string(), best_arch))
|
||||
@ -216,6 +225,7 @@ if __name__ == '__main__':
|
||||
parser.add_argument('--ea_population', type=int, help='The population size in EA.')
|
||||
parser.add_argument('--ea_sample_size', type=int, help='The sample size in EA.')
|
||||
parser.add_argument('--ea_fast_by_api', type=int, help='Use our API to speed up the experiments or not.')
|
||||
parser.add_argument('--time_budget', type=int, help='The total time cost budge for searching (in seconds).')
|
||||
# log
|
||||
parser.add_argument('--workers', type=int, default=2, help='number of data loading workers (default: 2)')
|
||||
parser.add_argument('--save_dir', type=str, help='Folder to save checkpoints and log.')
|
||||
|
@ -17,6 +17,7 @@ from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_che
|
||||
from utils import get_model_infos, obtain_accuracy
|
||||
from log_utils import AverageMeter, time_string, convert_secs2time
|
||||
from models import get_cell_based_tiny_net, get_search_spaces
|
||||
from nas_102_api import NASBench102API as API
|
||||
|
||||
|
||||
def search_func(xloader, network, criterion, scheduler, w_optimizer, a_optimizer, epoch_str, print_freq, logger):
|
||||
@ -162,7 +163,8 @@ def main(xargs):
|
||||
search_space = get_search_spaces('cell', xargs.search_space_name)
|
||||
model_config = dict2config({'name': 'SETN', 'C': xargs.channel, 'N': xargs.num_cells,
|
||||
'max_nodes': xargs.max_nodes, 'num_classes': class_num,
|
||||
'space' : search_space}, None)
|
||||
'space' : search_space,
|
||||
'affine' : False, 'track_running_stats': bool(xargs.track_running_stats)}, None)
|
||||
logger.log('search space : {:}'.format(search_space))
|
||||
search_model = get_cell_based_tiny_net(model_config)
|
||||
|
||||
@ -175,6 +177,12 @@ def main(xargs):
|
||||
flop, param = get_model_infos(search_model, xshape)
|
||||
#logger.log('{:}'.format(search_model))
|
||||
logger.log('FLOP = {:.2f} M, Params = {:.2f} MB'.format(flop, param))
|
||||
logger.log('search-space : {:}'.format(search_space))
|
||||
if xargs.arch_nas_dataset is None:
|
||||
api = None
|
||||
else:
|
||||
api = API(xargs.arch_nas_dataset)
|
||||
logger.log('{:} create API = {:} done'.format(time_string(), api))
|
||||
|
||||
last_info, model_base_path, model_best_path = logger.path('info'), logger.path('model'), logger.path('best')
|
||||
network, criterion = torch.nn.DataParallel(search_model).cuda(), criterion.cuda()
|
||||
@ -196,7 +204,7 @@ def main(xargs):
|
||||
start_epoch, valid_accuracies, genotypes = 0, {'best': -1}, {}
|
||||
|
||||
# start training
|
||||
start_time, epoch_time, total_epoch = time.time(), AverageMeter(), config.epochs + config.warmup
|
||||
start_time, search_time, epoch_time, total_epoch = time.time(), AverageMeter(), AverageMeter(), config.epochs + config.warmup
|
||||
for epoch in range(start_epoch, total_epoch):
|
||||
w_scheduler.update(epoch, 0.0)
|
||||
need_time = 'Time Left: {:}'.format( convert_secs2time(epoch_time.val * (total_epoch-epoch), True) )
|
||||
@ -205,7 +213,8 @@ def main(xargs):
|
||||
|
||||
search_w_loss, search_w_top1, search_w_top5, search_a_loss, search_a_top1, search_a_top5 \
|
||||
= search_func(search_loader, network, criterion, w_scheduler, w_optimizer, a_optimizer, epoch_str, xargs.print_freq, logger)
|
||||
logger.log('[{:}] search [base] : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'.format(epoch_str, search_w_loss, search_w_top1, search_w_top5))
|
||||
search_time.update(time.time() - start_time)
|
||||
logger.log('[{:}] search [base] : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%, time-cost={:.1f} s'.format(epoch_str, search_w_loss, search_w_top1, search_w_top5, search_time.sum))
|
||||
logger.log('[{:}] search [arch] : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'.format(epoch_str, search_a_loss, search_a_top1, search_a_top5))
|
||||
|
||||
genotype, temp_accuracy = get_best_arch(valid_loader, network, xargs.select_num)
|
||||
@ -243,52 +252,23 @@ def main(xargs):
|
||||
}, logger.path('info'), logger)
|
||||
with torch.no_grad():
|
||||
logger.log('arch-parameters :\n{:}'.format( nn.functional.softmax(search_model.arch_parameters, dim=-1).cpu() ))
|
||||
if api is not None: logger.log('{:}'.format(api.query_by_arch( genotypes[epoch] )))
|
||||
# measure elapsed time
|
||||
epoch_time.update(time.time() - start_time)
|
||||
start_time = time.time()
|
||||
|
||||
#logger.log('During searching, the best gentotype is : {:} , with the validation accuracy of {:.3f}%.'.format(genotypes['best'], valid_accuracies['best']))
|
||||
# the final post procedure : count the time
|
||||
start_time = time.time()
|
||||
genotype, temp_accuracy = get_best_arch(valid_loader, network, xargs.select_num)
|
||||
search_time.update(time.time() - start_time)
|
||||
network.module.set_cal_mode('dynamic', genotype)
|
||||
valid_a_loss , valid_a_top1 , valid_a_top5 = valid_func(valid_loader, network, criterion)
|
||||
logger.log('Last : the gentotype is : {:}, with the validation accuracy of {:.3f}%.'.format(genotype, valid_a_top1))
|
||||
# sampling
|
||||
"""
|
||||
with torch.no_grad():
|
||||
logger.log('arch-parameters :\n{:}'.format( nn.functional.softmax(search_model.arch_parameters, dim=-1).cpu() ))
|
||||
selected_archs = set()
|
||||
while len(selected_archs) < xargs.select_num:
|
||||
architecture = search_model.dync_genotype()
|
||||
selected_archs.add( architecture )
|
||||
logger.log('select {:} architectures based on the learned arch-parameters'.format( len(selected_archs) ))
|
||||
|
||||
best_arch, best_acc = None, -1
|
||||
state_dict = deepcopy( network.state_dict() )
|
||||
for index, arch in enumerate(selected_archs):
|
||||
with torch.no_grad():
|
||||
search_model.set_cal_mode('dynamic', arch)
|
||||
network.load_state_dict( deepcopy(state_dict) )
|
||||
valid_a_loss , valid_a_top1 , valid_a_top5 = valid_func(valid_loader, network, criterion)
|
||||
logger.log('{:} [{:03d}/{:03d}] : {:125s}, loss={:.3f}, accuracy={:.3f}%'.format(time_string(), index, len(selected_archs), str(arch), valid_a_loss , valid_a_top1))
|
||||
if best_arch is None or best_acc < valid_a_top1:
|
||||
best_arch, best_acc = arch, valid_a_top1
|
||||
logger.log('Find the best one : {:} with accuracy={:.2f}%'.format(best_arch, best_acc))
|
||||
"""
|
||||
|
||||
logger.log('\n' + '-'*100)
|
||||
# check the performance from the architecture dataset
|
||||
"""
|
||||
if xargs.arch_nas_dataset is None or not os.path.isfile(xargs.arch_nas_dataset):
|
||||
logger.log('Can not find the architecture dataset : {:}.'.format(xargs.arch_nas_dataset))
|
||||
else:
|
||||
nas_bench = TinyNASBenchmarkAPI(xargs.arch_nas_dataset)
|
||||
geno = best_arch
|
||||
logger.log('The last model is {:}'.format(geno))
|
||||
info = nas_bench.query_by_arch( geno )
|
||||
if info is None: logger.log('Did not find this architecture : {:}.'.format(geno))
|
||||
else : logger.log('{:}'.format(info))
|
||||
logger.log('-'*100)
|
||||
"""
|
||||
logger.log('SETN : run {:} epochs, cost {:.1f} s, last-geno is {:}.'.format(total_epoch, search_time.sum, genotype))
|
||||
if api is not None: logger.log('{:}'.format( api.query_by_arch(genotype) ))
|
||||
logger.close()
|
||||
|
||||
|
||||
@ -303,7 +283,8 @@ if __name__ == '__main__':
|
||||
parser.add_argument('--channel', type=int, help='The number of channels.')
|
||||
parser.add_argument('--num_cells', type=int, help='The number of cells in one stage.')
|
||||
parser.add_argument('--select_num', type=int, help='The number of selected architectures to evaluate.')
|
||||
parser.add_argument('--config_path', type=str, help='.')
|
||||
parser.add_argument('--track_running_stats',type=int, choices=[0,1],help='Whether use track_running_stats or not in the BN layer.')
|
||||
parser.add_argument('--config_path', type=str, help='The path of the configuration.')
|
||||
# architecture leraning rate
|
||||
parser.add_argument('--arch_learning_rate', type=float, default=3e-4, help='learning rate for arch encoding')
|
||||
parser.add_argument('--arch_weight_decay', type=float, default=1e-3, help='weight decay for arch encoding')
|
||||
|
@ -20,6 +20,9 @@ def get_cell_based_tiny_net(config):
|
||||
group_names = ['DARTS-V1', 'DARTS-V2', 'GDAS', 'SETN', 'ENAS', 'RANDOM']
|
||||
if super_type == 'basic' and config.name in group_names:
|
||||
from .cell_searchs import nas_super_nets
|
||||
try:
|
||||
return nas_super_nets[config.name](config.C, config.N, config.max_nodes, config.num_classes, config.space, config.affine, config.track_running_stats)
|
||||
except:
|
||||
return nas_super_nets[config.name](config.C, config.N, config.max_nodes, config.num_classes, config.space)
|
||||
elif super_type == 'l2s-base' and config.name in group_names:
|
||||
from .l2s_cell_searchs import nas_super_nets
|
||||
|
@ -11,7 +11,8 @@ from .genotypes import Structure
|
||||
|
||||
class TinyNetworkGDAS(nn.Module):
|
||||
|
||||
def __init__(self, C, N, max_nodes, num_classes, search_space, affine=False, track_running_stats=True):
|
||||
#def __init__(self, C, N, max_nodes, num_classes, search_space, affine=False, track_running_stats=True):
|
||||
def __init__(self, C, N, max_nodes, num_classes, search_space, affine, track_running_stats):
|
||||
super(TinyNetworkGDAS, self).__init__()
|
||||
self._C = C
|
||||
self._layerN = N
|
||||
|
@ -13,7 +13,7 @@ from .genotypes import Structure
|
||||
|
||||
class TinyNetworkSETN(nn.Module):
|
||||
|
||||
def __init__(self, C, N, max_nodes, num_classes, search_space):
|
||||
def __init__(self, C, N, max_nodes, num_classes, search_space, affine, track_running_stats):
|
||||
super(TinyNetworkSETN, self).__init__()
|
||||
self._C = C
|
||||
self._layerN = N
|
||||
@ -31,7 +31,7 @@ class TinyNetworkSETN(nn.Module):
|
||||
if reduction:
|
||||
cell = ResNetBasicblock(C_prev, C_curr, 2)
|
||||
else:
|
||||
cell = SearchCell(C_prev, C_curr, 1, max_nodes, search_space)
|
||||
cell = SearchCell(C_prev, C_curr, 1, max_nodes, search_space, affine, track_running_stats)
|
||||
if num_edge is None: num_edge, edge2index = cell.num_edges, cell.edge2index
|
||||
else: assert num_edge == cell.num_edges and edge2index == cell.edge2index, 'invalid {:} vs. {:}.'.format(num_edge, cell.num_edges)
|
||||
self.cells.append( cell )
|
||||
|
@ -34,6 +34,7 @@ OMP_NUM_THREADS=4 python ./exps/algos/GDAS.py \
|
||||
--dataset ${dataset} --data_path ${data_path} \
|
||||
--search_space_name ${space} \
|
||||
--arch_nas_dataset ${TORCH_HOME}/NAS-Bench-102-v1_0-e61699.pth \
|
||||
--tau_max 10 --tau_min 0.1 \
|
||||
--config_path configs/nas-benchmark/algos/GDAS.config \
|
||||
--tau_max 10 --tau_min 0.1 --track_running_stats 1 \
|
||||
--arch_learning_rate 0.0003 --arch_weight_decay 0.001 \
|
||||
--workers 4 --print_freq 200 --rand_seed ${seed}
|
||||
|
@ -35,5 +35,6 @@ OMP_NUM_THREADS=4 python ./exps/algos/R_EA.py \
|
||||
--dataset ${dataset} --data_path ${data_path} \
|
||||
--search_space_name ${space} \
|
||||
--arch_nas_dataset ${TORCH_HOME}/NAS-Bench-102-v1_0-e61699.pth \
|
||||
--ea_cycles 30 --ea_population 10 --ea_sample_size 3 --ea_fast_by_api 1 \
|
||||
--time_budget 12000 \
|
||||
--ea_cycles 100 --ea_population 10 --ea_sample_size 3 --ea_fast_by_api 1 \
|
||||
--workers 4 --print_freq 200 --rand_seed ${seed}
|
||||
|
@ -34,5 +34,6 @@ OMP_NUM_THREADS=4 python ./exps/algos/RANDOM.py \
|
||||
--dataset ${dataset} --data_path ${data_path} \
|
||||
--search_space_name ${space} \
|
||||
--arch_nas_dataset ${TORCH_HOME}/NAS-Bench-102-v1_0-e61699.pth \
|
||||
--random_num 100 \
|
||||
--time_budget 12000 \
|
||||
--workers 4 --print_freq 200 --rand_seed ${seed}
|
||||
# --random_num 100 \
|
||||
|
@ -36,6 +36,7 @@ OMP_NUM_THREADS=4 python ./exps/algos/SETN.py \
|
||||
--search_space_name ${space} \
|
||||
--arch_nas_dataset ${TORCH_HOME}/NAS-Bench-102-v1_0-e61699.pth \
|
||||
--config_path configs/nas-benchmark/algos/SETN.config \
|
||||
--track_running_stats 1 \
|
||||
--arch_learning_rate 0.0003 --arch_weight_decay 0.001 \
|
||||
--select_num 100 \
|
||||
--workers 4 --print_freq 200 --rand_seed ${seed}
|
||||
|
Loading…
Reference in New Issue
Block a user