update affines for NAS

This commit is contained in:
D-X-Y
2019-12-02 18:03:40 +11:00
parent 487fec21bf
commit d175a361bd
9 changed files with 78 additions and 41 deletions

View File

@@ -47,6 +47,7 @@ def procedure(xloader, network, criterion, scheduler, optimizer, mode):
elif mode == 'valid': network.eval()
else: raise ValueError("The mode is not right : {:}".format(mode))
batch_time, end = AverageMeter(), time.time()
for i, (inputs, targets) in enumerate(xloader):
if mode == 'train': scheduler.update(None, 1.0 * i / len(xloader))
@@ -64,7 +65,10 @@ def procedure(xloader, network, criterion, scheduler, optimizer, mode):
losses.update(loss.item(), inputs.size(0))
top1.update (prec1.item(), inputs.size(0))
top5.update (prec5.item(), inputs.size(0))
return losses.avg, top1.avg, top5.avg
# count time
batch_time.update(time.time() - end)
end = time.time()
return losses.avg, top1.avg, top5.avg, batch_time.sum
@@ -87,18 +91,21 @@ def evaluate_for_seed(arch_config, config, arch, train_loader, valid_loader, see
# start training
start_time, epoch_time, total_epoch = time.time(), AverageMeter(), config.epochs + config.warmup
train_losses, train_acc1es, train_acc5es, valid_losses, valid_acc1es, valid_acc5es = {}, {}, {}, {}, {}, {}
train_times , valid_times = {}, {}
for epoch in range(total_epoch):
scheduler.update(epoch, 0.0)
train_loss, train_acc1, train_acc5 = procedure(train_loader, network, criterion, scheduler, optimizer, 'train')
train_loss, train_acc1, train_acc5, train_tm = procedure(train_loader, network, criterion, scheduler, optimizer, 'train')
with torch.no_grad():
valid_loss, valid_acc1, valid_acc5 = procedure(valid_loader, network, criterion, None, None, 'valid')
valid_loss, valid_acc1, valid_acc5, valid_tm = procedure(valid_loader, network, criterion, None, None, 'valid')
train_losses[epoch] = train_loss
train_acc1es[epoch] = train_acc1
train_acc5es[epoch] = train_acc5
valid_losses[epoch] = valid_loss
valid_acc1es[epoch] = valid_acc1
valid_acc5es[epoch] = valid_acc5
train_times [epoch] = train_tm
valid_times [epoch] = valid_tm
# measure elapsed time
epoch_time.update(time.time() - start_time)
@@ -114,9 +121,11 @@ def evaluate_for_seed(arch_config, config, arch, train_loader, valid_loader, see
'train_losses': train_losses,
'train_acc1es': train_acc1es,
'train_acc5es': train_acc5es,
'train_times' : train_times,
'valid_losses': valid_losses,
'valid_acc1es': valid_acc1es,
'valid_acc5es': valid_acc5es,
'valid_times' : valid_times,
'net_state_dict': net.state_dict(),
'net_string' : '{:}'.format(net),
'finish-train': True