update NAS-Bench-102 baselines
This commit is contained in:
parent
44a0d51449
commit
1d5e8debad
@ -62,11 +62,12 @@ class MyWorker(Worker):
|
||||
|
||||
def compute(self, config, budget, **kwargs):
|
||||
structure = self.convert_func( config )
|
||||
reward = train_and_eval(structure, self.nas_bench, None)
|
||||
reward, time_cost = train_and_eval(structure, self.nas_bench, None)
|
||||
import pdb; pdb.set_trace()
|
||||
self.test_time += 1
|
||||
return ({
|
||||
'loss': float(100-reward),
|
||||
'info': None})
|
||||
'info': time_cost})
|
||||
|
||||
|
||||
def main(xargs, nas_bench):
|
||||
@ -121,7 +122,7 @@ def main(xargs, nas_bench):
|
||||
|
||||
bohb = BOHB(configspace=cs,
|
||||
run_id=hb_run_id,
|
||||
eta=3, min_budget=3, max_budget=108,
|
||||
eta=3, min_budget=3, max_budget=xargs.time_budget,
|
||||
nameserver=ns_host,
|
||||
nameserver_port=ns_port,
|
||||
num_samples=xargs.num_samples,
|
||||
@ -130,6 +131,7 @@ def main(xargs, nas_bench):
|
||||
# optimization_strategy=xargs.strategy, num_samples=xargs.num_samples,
|
||||
|
||||
results = bohb.run(xargs.n_iters, min_n_workers=num_workers)
|
||||
import pdb; pdb.set_trace()
|
||||
|
||||
bohb.shutdown(shutdown_workers=True)
|
||||
NS.shutdown()
|
||||
@ -160,6 +162,7 @@ if __name__ == '__main__':
|
||||
parser.add_argument('--max_nodes', type=int, help='The maximum number of nodes.')
|
||||
parser.add_argument('--channel', type=int, help='The number of channels.')
|
||||
parser.add_argument('--num_cells', type=int, help='The number of cells in one stage.')
|
||||
parser.add_argument('--time_budget', type=int, help='The total time cost budge for searching (in seconds).')
|
||||
# BOHB
|
||||
parser.add_argument('--strategy', default="sampling", type=str, nargs='?', help='optimization strategy for the acquisition function')
|
||||
parser.add_argument('--min_bandwidth', default=.3, type=float, nargs='?', help='minimum bandwidth for KDE')
|
||||
|
@ -82,6 +82,16 @@ def valid_func(xloader, network, criterion):
|
||||
return arch_losses.avg, arch_top1.avg, arch_top5.avg
|
||||
|
||||
|
||||
def search_find_best(valid_loader, network, criterion, select_num):
|
||||
best_arch, best_acc = None, -1
|
||||
for iarch in range(select_num):
|
||||
arch = network.module.random_genotype( True )
|
||||
valid_a_loss, valid_a_top1, valid_a_top5 = valid_func(valid_loader, network, criterion)
|
||||
if best_arch is None or best_acc < valid_a_top1:
|
||||
best_arch, best_acc = arch, valid_a_top1
|
||||
return best_arch
|
||||
|
||||
|
||||
def main(xargs):
|
||||
assert torch.cuda.is_available(), 'CUDA is not available.'
|
||||
torch.backends.cudnn.enabled = True
|
||||
@ -143,6 +153,7 @@ def main(xargs):
|
||||
last_info = torch.load(last_info)
|
||||
start_epoch = last_info['epoch']
|
||||
checkpoint = torch.load(last_info['last_checkpoint'])
|
||||
genotypes = checkpoint['genotypes']
|
||||
valid_accuracies = checkpoint['valid_accuracies']
|
||||
search_model.load_state_dict( checkpoint['search_model'] )
|
||||
w_scheduler.load_state_dict ( checkpoint['w_scheduler'] )
|
||||
@ -150,7 +161,7 @@ def main(xargs):
|
||||
logger.log("=> loading checkpoint of the last-info '{:}' start with {:}-th epoch.".format(last_info, start_epoch))
|
||||
else:
|
||||
logger.log("=> do not find the last-info file : {:}".format(last_info))
|
||||
start_epoch, valid_accuracies = 0, {'best': -1}
|
||||
start_epoch, valid_accuracies, genotypes = 0, {'best': -1}, {}
|
||||
|
||||
# start training
|
||||
start_time, search_time, epoch_time, total_epoch = time.time(), AverageMeter(), AverageMeter(), config.epochs + config.warmup
|
||||
@ -160,11 +171,14 @@ def main(xargs):
|
||||
epoch_str = '{:03d}-{:03d}'.format(epoch, total_epoch)
|
||||
logger.log('\n[Search the {:}-th epoch] {:}, LR={:}'.format(epoch_str, need_time, min(w_scheduler.get_lr())))
|
||||
|
||||
# selected_arch = search_find_best(valid_loader, network, criterion, xargs.select_num)
|
||||
search_w_loss, search_w_top1, search_w_top5 = search_func(search_loader, network, criterion, w_scheduler, w_optimizer, epoch_str, xargs.print_freq, logger)
|
||||
search_time.update(time.time() - start_time)
|
||||
logger.log('[{:}] searching : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%, time-cost={:.1f} s'.format(epoch_str, search_w_loss, search_w_top1, search_w_top5, search_time.sum))
|
||||
valid_a_loss , valid_a_top1 , valid_a_top5 = valid_func(valid_loader, network, criterion)
|
||||
logger.log('[{:}] evaluate : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'.format(epoch_str, valid_a_loss, valid_a_top1, valid_a_top5))
|
||||
cur_arch = search_find_best(valid_loader, network, criterion, xargs.select_num)
|
||||
genotypes[epoch] = cur_arch
|
||||
# check the best accuracy
|
||||
valid_accuracies[epoch] = valid_a_top1
|
||||
if valid_a_top1 > valid_accuracies['best']:
|
||||
@ -178,6 +192,7 @@ def main(xargs):
|
||||
'search_model': search_model.state_dict(),
|
||||
'w_optimizer' : w_optimizer.state_dict(),
|
||||
'w_scheduler' : w_scheduler.state_dict(),
|
||||
'genotypes' : genotypes,
|
||||
'valid_accuracies' : valid_accuracies},
|
||||
model_base_path, logger)
|
||||
last_info = save_checkpoint({
|
||||
@ -188,6 +203,7 @@ def main(xargs):
|
||||
if find_best:
|
||||
logger.log('<<<--->>> The {:}-th epoch : find the highest validation accuracy : {:.2f}%.'.format(epoch_str, valid_a_top1))
|
||||
copy_checkpoint(model_base_path, model_best_path, logger)
|
||||
if api is not None: logger.log('{:}'.format(api.query_by_arch( genotypes[epoch] )))
|
||||
# measure elapsed time
|
||||
epoch_time.update(time.time() - start_time)
|
||||
start_time = time.time()
|
||||
@ -202,7 +218,6 @@ def main(xargs):
|
||||
logger.log('final evaluation [{:02d}/{:02d}] : {:} : accuracy={:.2f}%, loss={:.3f}'.format(iarch, xargs.select_num, arch, valid_a_top1, valid_a_loss))
|
||||
if best_arch is None or best_acc < valid_a_top1:
|
||||
best_arch, best_acc = arch, valid_a_top1
|
||||
|
||||
search_time.update(time.time() - start_time)
|
||||
logger.log('RANDOM-NAS finds the best one : {:} with accuracy={:.2f}%, with {:.1f} s.'.format(best_arch, best_acc, search_time.sum))
|
||||
if api is not None: logger.log('{:}'.format( api.query_by_arch(best_arch) ))
|
||||
|
@ -17,7 +17,7 @@ from datasets import get_datasets, SearchDataset
|
||||
from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint, get_optim_scheduler
|
||||
from utils import get_model_infos, obtain_accuracy
|
||||
from log_utils import AverageMeter, time_string, convert_secs2time
|
||||
from nas_102_api import NASBench102API
|
||||
from nas_102_api import NASBench102API as API
|
||||
from models import CellStructure, get_search_spaces
|
||||
from R_EA import train_and_eval
|
||||
|
||||
@ -132,10 +132,18 @@ def main(xargs, nas_bench):
|
||||
|
||||
# REINFORCE
|
||||
# attempts = 0
|
||||
for istep in range(xargs.RL_steps):
|
||||
logger.log('Will start searching with time budget of {:} s.'.format(xargs.time_budget))
|
||||
total_steps, total_costs = 0, 0
|
||||
#for istep in range(xargs.RL_steps):
|
||||
while total_costs < xargs.time_budget:
|
||||
start_time = time.time()
|
||||
log_prob, action = select_action( policy )
|
||||
arch = policy.generate_arch( action )
|
||||
reward = train_and_eval(arch, nas_bench, extra_info)
|
||||
reward, cost_time = train_and_eval(arch, nas_bench, extra_info)
|
||||
# accumulate time
|
||||
if total_costs + cost_time < xargs.time_budget:
|
||||
total_costs += cost_time
|
||||
else: break
|
||||
|
||||
baseline.update(reward)
|
||||
# calculate loss
|
||||
@ -143,13 +151,15 @@ def main(xargs, nas_bench):
|
||||
optimizer.zero_grad()
|
||||
policy_loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
logger.log('step [{:3d}/{:3d}] : average-reward={:.3f} : policy_loss={:.4f} : {:}'.format(istep, xargs.RL_steps, baseline.value(), policy_loss.item(), policy.genotype()))
|
||||
# accumulate time
|
||||
total_costs += time.time() - start_time
|
||||
total_steps += 1
|
||||
logger.log('step [{:3d}] : average-reward={:.3f} : policy_loss={:.4f} : {:}'.format(total_steps, baseline.value(), policy_loss.item(), policy.genotype()))
|
||||
#logger.log('----> {:}'.format(policy.arch_parameters))
|
||||
logger.log('')
|
||||
#logger.log('')
|
||||
|
||||
best_arch = policy.genotype()
|
||||
|
||||
logger.log('REINFORCE finish with {:} steps and {:.1f} s.'.format(total_steps, total_costs))
|
||||
info = nas_bench.query_by_arch( best_arch )
|
||||
if info is None: logger.log('Did not find this architecture : {:}.'.format(best_arch))
|
||||
else : logger.log('{:}'.format(info))
|
||||
@ -169,8 +179,9 @@ if __name__ == '__main__':
|
||||
parser.add_argument('--channel', type=int, help='The number of channels.')
|
||||
parser.add_argument('--num_cells', type=int, help='The number of cells in one stage.')
|
||||
parser.add_argument('--learning_rate', type=float, help='The learning rate for REINFORCE.')
|
||||
parser.add_argument('--RL_steps', type=int, help='The steps for REINFORCE.')
|
||||
#parser.add_argument('--RL_steps', type=int, help='The steps for REINFORCE.')
|
||||
parser.add_argument('--EMA_momentum', type=float, help='The momentum value for EMA.')
|
||||
parser.add_argument('--time_budget', type=int, help='The total time cost budge for searching (in seconds).')
|
||||
# log
|
||||
parser.add_argument('--workers', type=int, default=2, help='number of data loading workers (default: 2)')
|
||||
parser.add_argument('--save_dir', type=str, help='Folder to save checkpoints and log.')
|
||||
@ -183,7 +194,7 @@ if __name__ == '__main__':
|
||||
nas_bench = None
|
||||
else:
|
||||
print ('{:} build NAS-Benchmark-API from {:}'.format(time_string(), args.arch_nas_dataset))
|
||||
nas_bench = AANASBenchAPI(args.arch_nas_dataset)
|
||||
nas_bench = API(args.arch_nas_dataset)
|
||||
if args.rand_seed < 0:
|
||||
save_dir, all_indexes, num = None, [], 500
|
||||
for i in range(num):
|
||||
|
@ -34,5 +34,6 @@ OMP_NUM_THREADS=4 python ./exps/algos/BOHB.py \
|
||||
--dataset ${dataset} --data_path ${data_path} \
|
||||
--search_space_name ${space} \
|
||||
--arch_nas_dataset ${TORCH_HOME}/NAS-Bench-102-v1_0-e61699.pth \
|
||||
--n_iters 6 --num_samples 3 \
|
||||
--time_budget 12000 \
|
||||
--n_iters 100 --num_samples 4 --random_fraction 0 \
|
||||
--workers 4 --print_freq 200 --rand_seed ${seed}
|
||||
|
@ -34,5 +34,6 @@ OMP_NUM_THREADS=4 python ./exps/algos/reinforce.py \
|
||||
--dataset ${dataset} --data_path ${data_path} \
|
||||
--search_space_name ${space} \
|
||||
--arch_nas_dataset ${TORCH_HOME}/NAS-Bench-102-v1_0-e61699.pth \
|
||||
--learning_rate 0.001 --RL_steps 100 --EMA_momentum 0.9 \
|
||||
--time_budget 12000 \
|
||||
--learning_rate 0.001 --EMA_momentum 0.9 \
|
||||
--workers 4 --print_freq 200 --rand_seed ${seed}
|
||||
|
Loading…
Reference in New Issue
Block a user