update affines for NAS
This commit is contained in:
@@ -47,6 +47,7 @@ def procedure(xloader, network, criterion, scheduler, optimizer, mode):
|
||||
elif mode == 'valid': network.eval()
|
||||
else: raise ValueError("The mode is not right : {:}".format(mode))
|
||||
|
||||
batch_time, end = AverageMeter(), time.time()
|
||||
for i, (inputs, targets) in enumerate(xloader):
|
||||
if mode == 'train': scheduler.update(None, 1.0 * i / len(xloader))
|
||||
|
||||
@@ -64,7 +65,10 @@ def procedure(xloader, network, criterion, scheduler, optimizer, mode):
|
||||
losses.update(loss.item(), inputs.size(0))
|
||||
top1.update (prec1.item(), inputs.size(0))
|
||||
top5.update (prec5.item(), inputs.size(0))
|
||||
return losses.avg, top1.avg, top5.avg
|
||||
# count time
|
||||
batch_time.update(time.time() - end)
|
||||
end = time.time()
|
||||
return losses.avg, top1.avg, top5.avg, batch_time.sum
|
||||
|
||||
|
||||
|
||||
@@ -87,18 +91,21 @@ def evaluate_for_seed(arch_config, config, arch, train_loader, valid_loader, see
|
||||
# start training
|
||||
start_time, epoch_time, total_epoch = time.time(), AverageMeter(), config.epochs + config.warmup
|
||||
train_losses, train_acc1es, train_acc5es, valid_losses, valid_acc1es, valid_acc5es = {}, {}, {}, {}, {}, {}
|
||||
train_times , valid_times = {}, {}
|
||||
for epoch in range(total_epoch):
|
||||
scheduler.update(epoch, 0.0)
|
||||
|
||||
train_loss, train_acc1, train_acc5 = procedure(train_loader, network, criterion, scheduler, optimizer, 'train')
|
||||
train_loss, train_acc1, train_acc5, train_tm = procedure(train_loader, network, criterion, scheduler, optimizer, 'train')
|
||||
with torch.no_grad():
|
||||
valid_loss, valid_acc1, valid_acc5 = procedure(valid_loader, network, criterion, None, None, 'valid')
|
||||
valid_loss, valid_acc1, valid_acc5, valid_tm = procedure(valid_loader, network, criterion, None, None, 'valid')
|
||||
train_losses[epoch] = train_loss
|
||||
train_acc1es[epoch] = train_acc1
|
||||
train_acc5es[epoch] = train_acc5
|
||||
valid_losses[epoch] = valid_loss
|
||||
valid_acc1es[epoch] = valid_acc1
|
||||
valid_acc5es[epoch] = valid_acc5
|
||||
train_times [epoch] = train_tm
|
||||
valid_times [epoch] = valid_tm
|
||||
|
||||
# measure elapsed time
|
||||
epoch_time.update(time.time() - start_time)
|
||||
@@ -114,9 +121,11 @@ def evaluate_for_seed(arch_config, config, arch, train_loader, valid_loader, see
|
||||
'train_losses': train_losses,
|
||||
'train_acc1es': train_acc1es,
|
||||
'train_acc5es': train_acc5es,
|
||||
'train_times' : train_times,
|
||||
'valid_losses': valid_losses,
|
||||
'valid_acc1es': valid_acc1es,
|
||||
'valid_acc5es': valid_acc5es,
|
||||
'valid_times' : valid_times,
|
||||
'net_state_dict': net.state_dict(),
|
||||
'net_string' : '{:}'.format(net),
|
||||
'finish-train': True
|
||||
|
Reference in New Issue
Block a user