v2
This commit is contained in:
		
							
								
								
									
										25
									
								
								autodl/procedures/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										25
									
								
								autodl/procedures/__init__.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,25 @@ | ||||
| ################################################## | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 # | ||||
| ################################################## | ||||
| from .starts     import prepare_seed, prepare_logger, get_machine_info, save_checkpoint, copy_checkpoint | ||||
| from .optimizers import get_optim_scheduler | ||||
| from .funcs_nasbench import evaluate_for_seed as bench_evaluate_for_seed | ||||
| from .funcs_nasbench import pure_evaluate as bench_pure_evaluate | ||||
| from .funcs_nasbench import get_nas_bench_loaders | ||||
|  | ||||
| def get_procedures(procedure): | ||||
|   from .basic_main     import basic_train, basic_valid | ||||
|   from .search_main    import search_train, search_valid | ||||
|   from .search_main_v2 import search_train_v2 | ||||
|   from .simple_KD_main import simple_KD_train, simple_KD_valid | ||||
|  | ||||
|   train_funcs = {'basic' : basic_train, \ | ||||
|                  'search': search_train,'Simple-KD': simple_KD_train, \ | ||||
|                  'search-v2': search_train_v2} | ||||
|   valid_funcs = {'basic' : basic_valid, \ | ||||
|                  'search': search_valid,'Simple-KD': simple_KD_valid, \ | ||||
|                  'search-v2': search_valid} | ||||
|    | ||||
|   train_func  = train_funcs[procedure] | ||||
|   valid_func  = valid_funcs[procedure] | ||||
|   return train_func, valid_func | ||||
							
								
								
									
										75
									
								
								autodl/procedures/basic_main.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										75
									
								
								autodl/procedures/basic_main.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,75 @@ | ||||
| ################################################## | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 # | ||||
| ################################################## | ||||
| import os, sys, time, torch | ||||
| from log_utils import AverageMeter, time_string | ||||
| from utils     import obtain_accuracy | ||||
|  | ||||
|  | ||||
| def basic_train(xloader, network, criterion, scheduler, optimizer, optim_config, extra_info, print_freq, logger): | ||||
|   loss, acc1, acc5 = procedure(xloader, network, criterion, scheduler, optimizer, 'train', optim_config, extra_info, print_freq, logger) | ||||
|   return loss, acc1, acc5 | ||||
|  | ||||
|  | ||||
| def basic_valid(xloader, network, criterion, optim_config, extra_info, print_freq, logger): | ||||
|   with torch.no_grad(): | ||||
|     loss, acc1, acc5 = procedure(xloader, network, criterion, None, None, 'valid', None, extra_info, print_freq, logger) | ||||
|   return loss, acc1, acc5 | ||||
|  | ||||
|  | ||||
| def procedure(xloader, network, criterion, scheduler, optimizer, mode, config, extra_info, print_freq, logger): | ||||
|   data_time, batch_time, losses, top1, top5 = AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter() | ||||
|   if mode == 'train': | ||||
|     network.train() | ||||
|   elif mode == 'valid': | ||||
|     network.eval() | ||||
|   else: raise ValueError("The mode is not right : {:}".format(mode)) | ||||
|    | ||||
|   #logger.log('[{:5s}] config ::  auxiliary={:}, message={:}'.format(mode, config.auxiliary if hasattr(config, 'auxiliary') else -1, network.module.get_message())) | ||||
|   logger.log('[{:5s}] config ::  auxiliary={:}'.format(mode, config.auxiliary if hasattr(config, 'auxiliary') else -1)) | ||||
|   end = time.time() | ||||
|   for i, (inputs, targets) in enumerate(xloader): | ||||
|     if mode == 'train': scheduler.update(None, 1.0 * i / len(xloader)) | ||||
|     # measure data loading time | ||||
|     data_time.update(time.time() - end) | ||||
|     # calculate prediction and loss | ||||
|     targets = targets.cuda(non_blocking=True) | ||||
|  | ||||
|     if mode == 'train': optimizer.zero_grad() | ||||
|  | ||||
|     features, logits = network(inputs) | ||||
|     if isinstance(logits, list): | ||||
|       assert len(logits) == 2, 'logits must has {:} items instead of {:}'.format(2, len(logits)) | ||||
|       logits, logits_aux = logits | ||||
|     else: | ||||
|       logits, logits_aux = logits, None | ||||
|     loss             = criterion(logits, targets) | ||||
|     if config is not None and hasattr(config, 'auxiliary') and config.auxiliary > 0: | ||||
|       loss_aux = criterion(logits_aux, targets) | ||||
|       loss += config.auxiliary * loss_aux | ||||
|      | ||||
|     if mode == 'train': | ||||
|       loss.backward() | ||||
|       optimizer.step() | ||||
|  | ||||
|     # record | ||||
|     prec1, prec5 = obtain_accuracy(logits.data, targets.data, topk=(1, 5)) | ||||
|     losses.update(loss.item(),  inputs.size(0)) | ||||
|     top1.update  (prec1.item(), inputs.size(0)) | ||||
|     top5.update  (prec5.item(), inputs.size(0)) | ||||
|  | ||||
|     # measure elapsed time | ||||
|     batch_time.update(time.time() - end) | ||||
|     end = time.time() | ||||
|  | ||||
|     if i % print_freq == 0 or (i+1) == len(xloader): | ||||
|       Sstr = ' {:5s} '.format(mode.upper()) + time_string() + ' [{:}][{:03d}/{:03d}]'.format(extra_info, i, len(xloader)) | ||||
|       if scheduler is not None: | ||||
|         Sstr += ' {:}'.format(scheduler.get_min_info()) | ||||
|       Tstr = 'Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})'.format(batch_time=batch_time, data_time=data_time) | ||||
|       Lstr = 'Loss {loss.val:.3f} ({loss.avg:.3f})  Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})'.format(loss=losses, top1=top1, top5=top5) | ||||
|       Istr = 'Size={:}'.format(list(inputs.size())) | ||||
|       logger.log(Sstr + ' ' + Tstr + ' ' + Lstr + ' ' + Istr) | ||||
|  | ||||
|   logger.log(' **{mode:5s}** Prec@1 {top1.avg:.2f} Prec@5 {top5.avg:.2f} Error@1 {error1:.2f} Error@5 {error5:.2f} Loss:{loss:.3f}'.format(mode=mode.upper(), top1=top1, top5=top5, error1=100-top1.avg, error5=100-top5.avg, loss=losses.avg)) | ||||
|   return losses.avg, top1.avg, top5.avg | ||||
							
								
								
									
										203
									
								
								autodl/procedures/funcs_nasbench.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										203
									
								
								autodl/procedures/funcs_nasbench.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,203 @@ | ||||
| ##################################################### | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.08 # | ||||
| ##################################################### | ||||
| import os, time, copy, torch, pathlib | ||||
|  | ||||
| import datasets | ||||
| from config_utils import load_config | ||||
| from autodl.procedures   import prepare_seed, get_optim_scheduler | ||||
| from autodl.utils        import get_model_infos, obtain_accuracy | ||||
| from autodl.log_utils    import AverageMeter, time_string, convert_secs2time | ||||
| from models       import get_cell_based_tiny_net | ||||
|  | ||||
|  | ||||
| __all__ = ['evaluate_for_seed', 'pure_evaluate', 'get_nas_bench_loaders'] | ||||
|  | ||||
|  | ||||
| def pure_evaluate(xloader, network, criterion=torch.nn.CrossEntropyLoss()): | ||||
|   data_time, batch_time, batch = AverageMeter(), AverageMeter(), None | ||||
|   losses, top1, top5 = AverageMeter(), AverageMeter(), AverageMeter() | ||||
|   latencies, device = [], torch.cuda.current_device() | ||||
|   network.eval() | ||||
|   with torch.no_grad(): | ||||
|     end = time.time() | ||||
|     for i, (inputs, targets) in enumerate(xloader): | ||||
|       targets = targets.cuda(device=device, non_blocking=True) | ||||
|       inputs  = inputs.cuda(device=device, non_blocking=True) | ||||
|       data_time.update(time.time() - end) | ||||
|       # forward | ||||
|       features, logits = network(inputs) | ||||
|       loss             = criterion(logits, targets) | ||||
|       batch_time.update(time.time() - end) | ||||
|       if batch is None or batch == inputs.size(0): | ||||
|         batch = inputs.size(0) | ||||
|         latencies.append( batch_time.val - data_time.val ) | ||||
|       # record loss and accuracy | ||||
|       prec1, prec5 = obtain_accuracy(logits.data, targets.data, topk=(1, 5)) | ||||
|       losses.update(loss.item(),  inputs.size(0)) | ||||
|       top1.update  (prec1.item(), inputs.size(0)) | ||||
|       top5.update  (prec5.item(), inputs.size(0)) | ||||
|       end = time.time() | ||||
|   if len(latencies) > 2: latencies = latencies[1:] | ||||
|   return losses.avg, top1.avg, top5.avg, latencies | ||||
|  | ||||
|  | ||||
|  | ||||
| def procedure(xloader, network, criterion, scheduler, optimizer, mode: str): | ||||
|   losses, top1, top5 = AverageMeter(), AverageMeter(), AverageMeter() | ||||
|   if mode == 'train'  : network.train() | ||||
|   elif mode == 'valid': network.eval() | ||||
|   else: raise ValueError("The mode is not right : {:}".format(mode)) | ||||
|   device = torch.cuda.current_device() | ||||
|   data_time, batch_time, end = AverageMeter(), AverageMeter(), time.time() | ||||
|   for i, (inputs, targets) in enumerate(xloader): | ||||
|     if mode == 'train': scheduler.update(None, 1.0 * i / len(xloader)) | ||||
|  | ||||
|     targets = targets.cuda(device=device, non_blocking=True) | ||||
|     if mode == 'train': optimizer.zero_grad() | ||||
|     # forward | ||||
|     features, logits = network(inputs) | ||||
|     loss             = criterion(logits, targets) | ||||
|     # backward | ||||
|     if mode == 'train': | ||||
|       loss.backward() | ||||
|       optimizer.step() | ||||
|     # record loss and accuracy | ||||
|     prec1, prec5 = obtain_accuracy(logits.data, targets.data, topk=(1, 5)) | ||||
|     losses.update(loss.item(),  inputs.size(0)) | ||||
|     top1.update  (prec1.item(), inputs.size(0)) | ||||
|     top5.update  (prec5.item(), inputs.size(0)) | ||||
|     # count time | ||||
|     batch_time.update(time.time() - end) | ||||
|     end = time.time() | ||||
|   return losses.avg, top1.avg, top5.avg, batch_time.sum | ||||
|  | ||||
|  | ||||
| def evaluate_for_seed(arch_config, opt_config, train_loader, valid_loaders, seed: int, logger): | ||||
|  | ||||
|   prepare_seed(seed) # random seed | ||||
|   net = get_cell_based_tiny_net(arch_config) | ||||
|   #net = TinyNetwork(arch_config['channel'], arch_config['num_cells'], arch, config.class_num) | ||||
|   flop, param  = get_model_infos(net, opt_config.xshape) | ||||
|   logger.log('Network : {:}'.format(net.get_message()), False) | ||||
|   logger.log('{:} Seed-------------------------- {:} --------------------------'.format(time_string(), seed)) | ||||
|   logger.log('FLOP = {:} MB, Param = {:} MB'.format(flop, param)) | ||||
|   # train and valid | ||||
|   optimizer, scheduler, criterion = get_optim_scheduler(net.parameters(), opt_config) | ||||
|   default_device = torch.cuda.current_device() | ||||
|   network = torch.nn.DataParallel(net, device_ids=[default_device]).cuda(device=default_device) | ||||
|   criterion = criterion.cuda(device=default_device) | ||||
|   # start training | ||||
|   start_time, epoch_time, total_epoch = time.time(), AverageMeter(), opt_config.epochs + opt_config.warmup | ||||
|   train_losses, train_acc1es, train_acc5es, valid_losses, valid_acc1es, valid_acc5es = {}, {}, {}, {}, {}, {} | ||||
|   train_times , valid_times, lrs = {}, {}, {} | ||||
|   for epoch in range(total_epoch): | ||||
|     scheduler.update(epoch, 0.0) | ||||
|     lr = min(scheduler.get_lr()) | ||||
|     train_loss, train_acc1, train_acc5, train_tm = procedure(train_loader, network, criterion, scheduler, optimizer, 'train') | ||||
|     train_losses[epoch] = train_loss | ||||
|     train_acc1es[epoch] = train_acc1  | ||||
|     train_acc5es[epoch] = train_acc5 | ||||
|     train_times [epoch] = train_tm | ||||
|     lrs[epoch] = lr | ||||
|     with torch.no_grad(): | ||||
|       for key, xloder in valid_loaders.items(): | ||||
|         valid_loss, valid_acc1, valid_acc5, valid_tm = procedure(xloder  , network, criterion,      None,      None, 'valid') | ||||
|         valid_losses['{:}@{:}'.format(key,epoch)] = valid_loss | ||||
|         valid_acc1es['{:}@{:}'.format(key,epoch)] = valid_acc1  | ||||
|         valid_acc5es['{:}@{:}'.format(key,epoch)] = valid_acc5 | ||||
|         valid_times ['{:}@{:}'.format(key,epoch)] = valid_tm | ||||
|  | ||||
|     # measure elapsed time | ||||
|     epoch_time.update(time.time() - start_time) | ||||
|     start_time = time.time() | ||||
|     need_time = 'Time Left: {:}'.format( convert_secs2time(epoch_time.avg * (total_epoch-epoch-1), True) ) | ||||
|     logger.log('{:} {:} epoch={:03d}/{:03d} :: Train [loss={:.5f}, acc@1={:.2f}%, acc@5={:.2f}%] Valid [loss={:.5f}, acc@1={:.2f}%, acc@5={:.2f}%], lr={:}'.format(time_string(), need_time, epoch, total_epoch, train_loss, train_acc1, train_acc5, valid_loss, valid_acc1, valid_acc5, lr)) | ||||
|   info_seed = {'flop' : flop, | ||||
|                'param': param, | ||||
|                'arch_config' : arch_config._asdict(), | ||||
|                'opt_config'  : opt_config._asdict(), | ||||
|                'total_epoch' : total_epoch , | ||||
|                'train_losses': train_losses, | ||||
|                'train_acc1es': train_acc1es, | ||||
|                'train_acc5es': train_acc5es, | ||||
|                'train_times' : train_times, | ||||
|                'valid_losses': valid_losses, | ||||
|                'valid_acc1es': valid_acc1es, | ||||
|                'valid_acc5es': valid_acc5es, | ||||
|                'valid_times' : valid_times, | ||||
|                'learning_rates': lrs, | ||||
|                'net_state_dict': net.state_dict(), | ||||
|                'net_string'  : '{:}'.format(net), | ||||
|                'finish-train': True | ||||
|               } | ||||
|   return info_seed | ||||
|  | ||||
|  | ||||
| def get_nas_bench_loaders(workers): | ||||
|  | ||||
|   torch.set_num_threads(workers) | ||||
|  | ||||
|   root_dir  = (pathlib.Path(__file__).parent / '..' / '..').resolve() | ||||
|   torch_dir = pathlib.Path(os.environ['TORCH_HOME']) | ||||
|   # cifar | ||||
|   cifar_config_path = root_dir / 'configs' / 'nas-benchmark' / 'CIFAR.config' | ||||
|   cifar_config = load_config(cifar_config_path, None, None) | ||||
|   get_datasets = datasets.get_datasets  # a function to return the dataset | ||||
|   break_line = '-' * 150 | ||||
|   print ('{:} Create data-loader for all datasets'.format(time_string())) | ||||
|   print (break_line) | ||||
|   TRAIN_CIFAR10, VALID_CIFAR10, xshape, class_num = get_datasets('cifar10', str(torch_dir/'cifar.python'), -1) | ||||
|   print ('original CIFAR-10 : {:} training images and {:} test images : {:} input shape : {:} number of classes'.format(len(TRAIN_CIFAR10), len(VALID_CIFAR10), xshape, class_num)) | ||||
|   cifar10_splits = load_config(root_dir / 'configs' / 'nas-benchmark' / 'cifar-split.txt', None, None) | ||||
|   assert cifar10_splits.train[:10] == [0, 5, 7, 11, 13, 15, 16, 17, 20, 24] and cifar10_splits.valid[:10] == [1, 2, 3, 4, 6, 8, 9, 10, 12, 14] | ||||
|   temp_dataset = copy.deepcopy(TRAIN_CIFAR10) | ||||
|   temp_dataset.transform = VALID_CIFAR10.transform | ||||
|   # data loader | ||||
|   trainval_cifar10_loader = torch.utils.data.DataLoader(TRAIN_CIFAR10, batch_size=cifar_config.batch_size, shuffle=True , num_workers=workers, pin_memory=True) | ||||
|   train_cifar10_loader    = torch.utils.data.DataLoader(TRAIN_CIFAR10, batch_size=cifar_config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(cifar10_splits.train), num_workers=workers, pin_memory=True) | ||||
|   valid_cifar10_loader    = torch.utils.data.DataLoader(temp_dataset , batch_size=cifar_config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(cifar10_splits.valid), num_workers=workers, pin_memory=True) | ||||
|   test__cifar10_loader    = torch.utils.data.DataLoader(VALID_CIFAR10, batch_size=cifar_config.batch_size, shuffle=False, num_workers=workers, pin_memory=True) | ||||
|   print ('CIFAR-10  : trval-loader has {:3d} batch with {:} per batch'.format(len(trainval_cifar10_loader), cifar_config.batch_size)) | ||||
|   print ('CIFAR-10  : train-loader has {:3d} batch with {:} per batch'.format(len(train_cifar10_loader), cifar_config.batch_size)) | ||||
|   print ('CIFAR-10  : valid-loader has {:3d} batch with {:} per batch'.format(len(valid_cifar10_loader), cifar_config.batch_size)) | ||||
|   print ('CIFAR-10  : test--loader has {:3d} batch with {:} per batch'.format(len(test__cifar10_loader), cifar_config.batch_size)) | ||||
|   print (break_line) | ||||
|   # CIFAR-100 | ||||
|   TRAIN_CIFAR100, VALID_CIFAR100, xshape, class_num = get_datasets('cifar100', str(torch_dir/'cifar.python'), -1) | ||||
|   print ('original CIFAR-100: {:} training images and {:} test images : {:} input shape : {:} number of classes'.format(len(TRAIN_CIFAR100), len(VALID_CIFAR100), xshape, class_num)) | ||||
|   cifar100_splits = load_config(root_dir / 'configs' / 'nas-benchmark' / 'cifar100-test-split.txt', None, None) | ||||
|   assert cifar100_splits.xvalid[:10] == [1, 3, 4, 5, 8, 10, 13, 14, 15, 16] and cifar100_splits.xtest[:10] == [0, 2, 6, 7, 9, 11, 12, 17, 20, 24] | ||||
|   train_cifar100_loader = torch.utils.data.DataLoader(TRAIN_CIFAR100, batch_size=cifar_config.batch_size, shuffle=True, num_workers=workers, pin_memory=True) | ||||
|   valid_cifar100_loader = torch.utils.data.DataLoader(VALID_CIFAR100, batch_size=cifar_config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(cifar100_splits.xvalid), num_workers=workers, pin_memory=True) | ||||
|   test__cifar100_loader = torch.utils.data.DataLoader(VALID_CIFAR100, batch_size=cifar_config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(cifar100_splits.xtest) , num_workers=workers, pin_memory=True) | ||||
|   print ('CIFAR-100  : train-loader has {:3d} batch'.format(len(train_cifar100_loader))) | ||||
|   print ('CIFAR-100  : valid-loader has {:3d} batch'.format(len(valid_cifar100_loader))) | ||||
|   print ('CIFAR-100  : test--loader has {:3d} batch'.format(len(test__cifar100_loader))) | ||||
|   print (break_line) | ||||
|  | ||||
|   imagenet16_config_path = 'configs/nas-benchmark/ImageNet-16.config' | ||||
|   imagenet16_config = load_config(imagenet16_config_path, None, None) | ||||
|   TRAIN_ImageNet16_120, VALID_ImageNet16_120, xshape, class_num = get_datasets('ImageNet16-120', str(torch_dir/'cifar.python'/'ImageNet16'), -1) | ||||
|   print ('original TRAIN_ImageNet16_120: {:} training images and {:} test images : {:} input shape : {:} number of classes'.format(len(TRAIN_ImageNet16_120), len(VALID_ImageNet16_120), xshape, class_num)) | ||||
|   imagenet_splits = load_config(root_dir / 'configs' / 'nas-benchmark' / 'imagenet-16-120-test-split.txt', None, None) | ||||
|   assert imagenet_splits.xvalid[:10] == [1, 2, 3, 6, 7, 8, 9, 12, 16, 18] and imagenet_splits.xtest[:10] == [0, 4, 5, 10, 11, 13, 14, 15, 17, 20] | ||||
|   train_imagenet_loader = torch.utils.data.DataLoader(TRAIN_ImageNet16_120, batch_size=imagenet16_config.batch_size, shuffle=True, num_workers=workers, pin_memory=True) | ||||
|   valid_imagenet_loader = torch.utils.data.DataLoader(VALID_ImageNet16_120, batch_size=imagenet16_config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(imagenet_splits.xvalid), num_workers=workers, pin_memory=True) | ||||
|   test__imagenet_loader = torch.utils.data.DataLoader(VALID_ImageNet16_120, batch_size=imagenet16_config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(imagenet_splits.xtest) , num_workers=workers, pin_memory=True) | ||||
|   print ('ImageNet-16-120  : train-loader has {:3d} batch with {:} per batch'.format(len(train_imagenet_loader), imagenet16_config.batch_size)) | ||||
|   print ('ImageNet-16-120  : valid-loader has {:3d} batch with {:} per batch'.format(len(valid_imagenet_loader), imagenet16_config.batch_size)) | ||||
|   print ('ImageNet-16-120  : test--loader has {:3d} batch with {:} per batch'.format(len(test__imagenet_loader), imagenet16_config.batch_size)) | ||||
|  | ||||
|   # 'cifar10', 'cifar100', 'ImageNet16-120' | ||||
|   loaders = {'cifar10@trainval': trainval_cifar10_loader, | ||||
|              'cifar10@train'   : train_cifar10_loader, | ||||
|              'cifar10@valid'   : valid_cifar10_loader, | ||||
|              'cifar10@test'    : test__cifar10_loader, | ||||
|              'cifar100@train'  : train_cifar100_loader, | ||||
|              'cifar100@valid'  : valid_cifar100_loader, | ||||
|              'cifar100@test'   : test__cifar100_loader, | ||||
|              'ImageNet16-120@train': train_imagenet_loader, | ||||
|              'ImageNet16-120@valid': valid_imagenet_loader, | ||||
|              'ImageNet16-120@test' : test__imagenet_loader} | ||||
|   return loaders | ||||
							
								
								
									
										204
									
								
								autodl/procedures/optimizers.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										204
									
								
								autodl/procedures/optimizers.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,204 @@ | ||||
| ##################################################### | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 # | ||||
| ##################################################### | ||||
| import math, torch | ||||
| import torch.nn as nn | ||||
| from bisect import bisect_right | ||||
| from torch.optim import Optimizer | ||||
|  | ||||
|  | ||||
| class _LRScheduler(object): | ||||
|  | ||||
|   def __init__(self, optimizer, warmup_epochs, epochs): | ||||
|     if not isinstance(optimizer, Optimizer): | ||||
|       raise TypeError('{:} is not an Optimizer'.format(type(optimizer).__name__)) | ||||
|     self.optimizer = optimizer | ||||
|     for group in optimizer.param_groups: | ||||
|       group.setdefault('initial_lr', group['lr']) | ||||
|     self.base_lrs = list(map(lambda group: group['initial_lr'], optimizer.param_groups)) | ||||
|     self.max_epochs = epochs | ||||
|     self.warmup_epochs  = warmup_epochs | ||||
|     self.current_epoch  = 0 | ||||
|     self.current_iter   = 0 | ||||
|  | ||||
|   def extra_repr(self): | ||||
|     return '' | ||||
|  | ||||
|   def __repr__(self): | ||||
|     return ('{name}(warmup={warmup_epochs}, max-epoch={max_epochs}, current::epoch={current_epoch}, iter={current_iter:.2f}'.format(name=self.__class__.__name__, **self.__dict__) | ||||
|               + ', {:})'.format(self.extra_repr())) | ||||
|  | ||||
|   def state_dict(self): | ||||
|     return {key: value for key, value in self.__dict__.items() if key != 'optimizer'} | ||||
|  | ||||
|   def load_state_dict(self, state_dict): | ||||
|     self.__dict__.update(state_dict) | ||||
|  | ||||
|   def get_lr(self): | ||||
|     raise NotImplementedError | ||||
|  | ||||
|   def get_min_info(self): | ||||
|     lrs = self.get_lr() | ||||
|     return '#LR=[{:.6f}~{:.6f}] epoch={:03d}, iter={:4.2f}#'.format(min(lrs), max(lrs), self.current_epoch, self.current_iter) | ||||
|  | ||||
|   def get_min_lr(self): | ||||
|     return min( self.get_lr() ) | ||||
|  | ||||
|   def update(self, cur_epoch, cur_iter): | ||||
|     if cur_epoch is not None: | ||||
|       assert isinstance(cur_epoch, int) and cur_epoch>=0, 'invalid cur-epoch : {:}'.format(cur_epoch) | ||||
|       self.current_epoch = cur_epoch | ||||
|     if cur_iter is not None: | ||||
|       assert isinstance(cur_iter, float) and cur_iter>=0, 'invalid cur-iter : {:}'.format(cur_iter) | ||||
|       self.current_iter  = cur_iter | ||||
|     for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()): | ||||
|       param_group['lr'] = lr | ||||
|  | ||||
|  | ||||
|  | ||||
| class CosineAnnealingLR(_LRScheduler): | ||||
|  | ||||
|   def __init__(self, optimizer, warmup_epochs, epochs, T_max, eta_min): | ||||
|     self.T_max = T_max | ||||
|     self.eta_min = eta_min | ||||
|     super(CosineAnnealingLR, self).__init__(optimizer, warmup_epochs, epochs) | ||||
|  | ||||
|   def extra_repr(self): | ||||
|     return 'type={:}, T-max={:}, eta-min={:}'.format('cosine', self.T_max, self.eta_min) | ||||
|  | ||||
|   def get_lr(self): | ||||
|     lrs = [] | ||||
|     for base_lr in self.base_lrs: | ||||
|       if self.current_epoch >= self.warmup_epochs and self.current_epoch < self.max_epochs: | ||||
|         last_epoch = self.current_epoch - self.warmup_epochs | ||||
|         #if last_epoch < self.T_max: | ||||
|         #if last_epoch < self.max_epochs: | ||||
|         lr = self.eta_min + (base_lr - self.eta_min) * (1 + math.cos(math.pi * last_epoch / self.T_max)) / 2 | ||||
|         #else: | ||||
|         #  lr = self.eta_min + (base_lr - self.eta_min) * (1 + math.cos(math.pi * (self.T_max-1.0) / self.T_max)) / 2 | ||||
|       elif self.current_epoch >= self.max_epochs: | ||||
|         lr = self.eta_min | ||||
|       else: | ||||
|         lr = (self.current_epoch / self.warmup_epochs + self.current_iter / self.warmup_epochs) * base_lr | ||||
|       lrs.append( lr ) | ||||
|     return lrs | ||||
|  | ||||
|  | ||||
|  | ||||
| class MultiStepLR(_LRScheduler): | ||||
|  | ||||
|   def __init__(self, optimizer, warmup_epochs, epochs, milestones, gammas): | ||||
|     assert len(milestones) == len(gammas), 'invalid {:} vs {:}'.format(len(milestones), len(gammas)) | ||||
|     self.milestones = milestones | ||||
|     self.gammas     = gammas | ||||
|     super(MultiStepLR, self).__init__(optimizer, warmup_epochs, epochs) | ||||
|  | ||||
|   def extra_repr(self): | ||||
|     return 'type={:}, milestones={:}, gammas={:}, base-lrs={:}'.format('multistep', self.milestones, self.gammas, self.base_lrs) | ||||
|  | ||||
|   def get_lr(self): | ||||
|     lrs = [] | ||||
|     for base_lr in self.base_lrs: | ||||
|       if self.current_epoch >= self.warmup_epochs: | ||||
|         last_epoch = self.current_epoch - self.warmup_epochs | ||||
|         idx = bisect_right(self.milestones, last_epoch) | ||||
|         lr = base_lr | ||||
|         for x in self.gammas[:idx]: lr *= x | ||||
|       else: | ||||
|         lr = (self.current_epoch / self.warmup_epochs + self.current_iter / self.warmup_epochs) * base_lr | ||||
|       lrs.append( lr ) | ||||
|     return lrs | ||||
|  | ||||
|  | ||||
| class ExponentialLR(_LRScheduler): | ||||
|  | ||||
|   def __init__(self, optimizer, warmup_epochs, epochs, gamma): | ||||
|     self.gamma      = gamma | ||||
|     super(ExponentialLR, self).__init__(optimizer, warmup_epochs, epochs) | ||||
|  | ||||
|   def extra_repr(self): | ||||
|     return 'type={:}, gamma={:}, base-lrs={:}'.format('exponential', self.gamma, self.base_lrs) | ||||
|  | ||||
|   def get_lr(self): | ||||
|     lrs = [] | ||||
|     for base_lr in self.base_lrs: | ||||
|       if self.current_epoch >= self.warmup_epochs: | ||||
|         last_epoch = self.current_epoch - self.warmup_epochs | ||||
|         assert last_epoch >= 0, 'invalid last_epoch : {:}'.format(last_epoch) | ||||
|         lr = base_lr * (self.gamma ** last_epoch) | ||||
|       else: | ||||
|         lr = (self.current_epoch / self.warmup_epochs + self.current_iter / self.warmup_epochs) * base_lr | ||||
|       lrs.append( lr ) | ||||
|     return lrs | ||||
|  | ||||
|  | ||||
| class LinearLR(_LRScheduler): | ||||
|  | ||||
|   def __init__(self, optimizer, warmup_epochs, epochs, max_LR, min_LR): | ||||
|     self.max_LR = max_LR | ||||
|     self.min_LR = min_LR | ||||
|     super(LinearLR, self).__init__(optimizer, warmup_epochs, epochs) | ||||
|  | ||||
|   def extra_repr(self): | ||||
|     return 'type={:}, max_LR={:}, min_LR={:}, base-lrs={:}'.format('LinearLR', self.max_LR, self.min_LR, self.base_lrs) | ||||
|  | ||||
|   def get_lr(self): | ||||
|     lrs = [] | ||||
|     for base_lr in self.base_lrs: | ||||
|       if self.current_epoch >= self.warmup_epochs: | ||||
|         last_epoch = self.current_epoch - self.warmup_epochs | ||||
|         assert last_epoch >= 0, 'invalid last_epoch : {:}'.format(last_epoch) | ||||
|         ratio = (self.max_LR - self.min_LR) * last_epoch / self.max_epochs / self.max_LR | ||||
|         lr = base_lr * (1-ratio) | ||||
|       else: | ||||
|         lr = (self.current_epoch / self.warmup_epochs + self.current_iter / self.warmup_epochs) * base_lr | ||||
|       lrs.append( lr ) | ||||
|     return lrs | ||||
|  | ||||
|  | ||||
|  | ||||
| class CrossEntropyLabelSmooth(nn.Module): | ||||
|  | ||||
|   def __init__(self, num_classes, epsilon): | ||||
|     super(CrossEntropyLabelSmooth, self).__init__() | ||||
|     self.num_classes = num_classes | ||||
|     self.epsilon = epsilon | ||||
|     self.logsoftmax = nn.LogSoftmax(dim=1) | ||||
|  | ||||
|   def forward(self, inputs, targets): | ||||
|     log_probs = self.logsoftmax(inputs) | ||||
|     targets = torch.zeros_like(log_probs).scatter_(1, targets.unsqueeze(1), 1) | ||||
|     targets = (1 - self.epsilon) * targets + self.epsilon / self.num_classes | ||||
|     loss = (-targets * log_probs).mean(0).sum() | ||||
|     return loss | ||||
|  | ||||
|  | ||||
|  | ||||
| def get_optim_scheduler(parameters, config): | ||||
|   assert hasattr(config, 'optim') and hasattr(config, 'scheduler') and hasattr(config, 'criterion'), 'config must have optim / scheduler / criterion keys instead of {:}'.format(config) | ||||
|   if config.optim == 'SGD': | ||||
|     optim = torch.optim.SGD(parameters, config.LR, momentum=config.momentum, weight_decay=config.decay, nesterov=config.nesterov) | ||||
|   elif config.optim == 'RMSprop': | ||||
|     optim = torch.optim.RMSprop(parameters, config.LR, momentum=config.momentum, weight_decay=config.decay) | ||||
|   else: | ||||
|     raise ValueError('invalid optim : {:}'.format(config.optim)) | ||||
|  | ||||
|   if config.scheduler == 'cos': | ||||
|     T_max = getattr(config, 'T_max', config.epochs) | ||||
|     scheduler = CosineAnnealingLR(optim, config.warmup, config.epochs, T_max, config.eta_min) | ||||
|   elif config.scheduler == 'multistep': | ||||
|     scheduler = MultiStepLR(optim, config.warmup, config.epochs, config.milestones, config.gammas) | ||||
|   elif config.scheduler == 'exponential': | ||||
|     scheduler = ExponentialLR(optim, config.warmup, config.epochs, config.gamma) | ||||
|   elif config.scheduler == 'linear': | ||||
|     scheduler = LinearLR(optim, config.warmup, config.epochs, config.LR, config.LR_min) | ||||
|   else: | ||||
|     raise ValueError('invalid scheduler : {:}'.format(config.scheduler)) | ||||
|  | ||||
|   if config.criterion == 'Softmax': | ||||
|     criterion = torch.nn.CrossEntropyLoss() | ||||
|   elif config.criterion == 'SmoothSoftmax': | ||||
|     criterion = CrossEntropyLabelSmooth(config.class_num, config.label_smooth) | ||||
|   else: | ||||
|     raise ValueError('invalid criterion : {:}'.format(config.criterion)) | ||||
|   return optim, scheduler, criterion | ||||
							
								
								
									
										126
									
								
								autodl/procedures/search_main.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										126
									
								
								autodl/procedures/search_main.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,126 @@ | ||||
| ################################################## | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 # | ||||
| ################################################## | ||||
| import os, sys, time, torch | ||||
| from log_utils import AverageMeter, time_string | ||||
| from utils     import obtain_accuracy | ||||
| from models    import change_key | ||||
|  | ||||
|  | ||||
| def get_flop_loss(expected_flop, flop_cur, flop_need, flop_tolerant): | ||||
|   expected_flop = torch.mean( expected_flop ) | ||||
|  | ||||
|   if flop_cur < flop_need - flop_tolerant:   # Too Small FLOP | ||||
|     loss = - torch.log( expected_flop ) | ||||
|   #elif flop_cur > flop_need + flop_tolerant: # Too Large FLOP | ||||
|   elif flop_cur > flop_need: # Too Large FLOP | ||||
|     loss = torch.log( expected_flop ) | ||||
|   else: # Required FLOP | ||||
|     loss = None | ||||
|   if loss is None: return 0, 0 | ||||
|   else           : return loss, loss.item() | ||||
|  | ||||
|  | ||||
| def search_train(search_loader, network, criterion, scheduler, base_optimizer, arch_optimizer, optim_config, extra_info, print_freq, logger): | ||||
|   data_time, batch_time = AverageMeter(), AverageMeter() | ||||
|   base_losses, arch_losses, top1, top5 = AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter() | ||||
|   arch_cls_losses, arch_flop_losses = AverageMeter(), AverageMeter() | ||||
|   epoch_str, flop_need, flop_weight, flop_tolerant = extra_info['epoch-str'], extra_info['FLOP-exp'], extra_info['FLOP-weight'], extra_info['FLOP-tolerant'] | ||||
|  | ||||
|   network.train() | ||||
|   logger.log('[Search] : {:}, FLOP-Require={:.2f} MB, FLOP-WEIGHT={:.2f}'.format(epoch_str, flop_need, flop_weight)) | ||||
|   end = time.time() | ||||
|   network.apply( change_key('search_mode', 'search') ) | ||||
|   for step, (base_inputs, base_targets, arch_inputs, arch_targets) in enumerate(search_loader): | ||||
|     scheduler.update(None, 1.0 * step / len(search_loader)) | ||||
|     # calculate prediction and loss | ||||
|     base_targets = base_targets.cuda(non_blocking=True) | ||||
|     arch_targets = arch_targets.cuda(non_blocking=True) | ||||
|     # measure data loading time | ||||
|     data_time.update(time.time() - end) | ||||
|      | ||||
|     # update the weights | ||||
|     base_optimizer.zero_grad() | ||||
|     logits, expected_flop = network(base_inputs) | ||||
|     #network.apply( change_key('search_mode', 'basic') ) | ||||
|     #features, logits = network(base_inputs) | ||||
|     base_loss = criterion(logits, base_targets) | ||||
|     base_loss.backward() | ||||
|     base_optimizer.step() | ||||
|     # record | ||||
|     prec1, prec5 = obtain_accuracy(logits.data, base_targets.data, topk=(1, 5)) | ||||
|     base_losses.update(base_loss.item(), base_inputs.size(0)) | ||||
|     top1.update       (prec1.item(), base_inputs.size(0)) | ||||
|     top5.update       (prec5.item(), base_inputs.size(0)) | ||||
|  | ||||
|     # update the architecture | ||||
|     arch_optimizer.zero_grad() | ||||
|     logits, expected_flop = network(arch_inputs) | ||||
|     flop_cur  = network.module.get_flop('genotype', None, None) | ||||
|     flop_loss, flop_loss_scale = get_flop_loss(expected_flop, flop_cur, flop_need, flop_tolerant) | ||||
|     acls_loss = criterion(logits, arch_targets) | ||||
|     arch_loss = acls_loss + flop_loss * flop_weight | ||||
|     arch_loss.backward() | ||||
|     arch_optimizer.step() | ||||
|    | ||||
|     # record | ||||
|     arch_losses.update(arch_loss.item(), arch_inputs.size(0)) | ||||
|     arch_flop_losses.update(flop_loss_scale, arch_inputs.size(0)) | ||||
|     arch_cls_losses.update (acls_loss.item(), arch_inputs.size(0)) | ||||
|      | ||||
|     # measure elapsed time | ||||
|     batch_time.update(time.time() - end) | ||||
|     end = time.time() | ||||
|     if step % print_freq == 0 or (step+1) == len(search_loader): | ||||
|       Sstr = '**TRAIN** ' + time_string() + ' [{:}][{:03d}/{:03d}]'.format(epoch_str, step, len(search_loader)) | ||||
|       Tstr = 'Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})'.format(batch_time=batch_time, data_time=data_time) | ||||
|       Lstr = 'Base-Loss {loss.val:.3f} ({loss.avg:.3f})  Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})'.format(loss=base_losses, top1=top1, top5=top5) | ||||
|       Vstr = 'Acls-loss {aloss.val:.3f} ({aloss.avg:.3f}) FLOP-Loss {floss.val:.3f} ({floss.avg:.3f}) Arch-Loss {loss.val:.3f} ({loss.avg:.3f})'.format(aloss=arch_cls_losses, floss=arch_flop_losses, loss=arch_losses) | ||||
|       logger.log(Sstr + ' ' + Tstr + ' ' + Lstr + ' ' + Vstr) | ||||
|       #Istr = 'Bsz={:} Asz={:}'.format(list(base_inputs.size()), list(arch_inputs.size())) | ||||
|       #logger.log(Sstr + ' ' + Tstr + ' ' + Lstr + ' ' + Vstr + ' ' + Istr) | ||||
|       #print(network.module.get_arch_info()) | ||||
|       #print(network.module.width_attentions[0]) | ||||
|       #print(network.module.width_attentions[1]) | ||||
|  | ||||
|   logger.log(' **TRAIN** Prec@1 {top1.avg:.2f} Prec@5 {top5.avg:.2f} Error@1 {error1:.2f} Error@5 {error5:.2f} Base-Loss:{baseloss:.3f}, Arch-Loss={archloss:.3f}'.format(top1=top1, top5=top5, error1=100-top1.avg, error5=100-top5.avg, baseloss=base_losses.avg, archloss=arch_losses.avg)) | ||||
|   return base_losses.avg, arch_losses.avg, top1.avg, top5.avg | ||||
|  | ||||
|  | ||||
|  | ||||
| def search_valid(xloader, network, criterion, extra_info, print_freq, logger): | ||||
|   data_time, batch_time, losses, top1, top5 = AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter() | ||||
|  | ||||
|   network.eval() | ||||
|   network.apply( change_key('search_mode', 'search') ) | ||||
|   end = time.time() | ||||
|   #logger.log('Starting evaluating {:}'.format(epoch_info)) | ||||
|   with torch.no_grad(): | ||||
|     for i, (inputs, targets) in enumerate(xloader): | ||||
|       # measure data loading time | ||||
|       data_time.update(time.time() - end) | ||||
|       # calculate prediction and loss | ||||
|       targets = targets.cuda(non_blocking=True) | ||||
|  | ||||
|       logits, expected_flop = network(inputs) | ||||
|       loss             = criterion(logits, targets) | ||||
|       # record | ||||
|       prec1, prec5 = obtain_accuracy(logits.data, targets.data, topk=(1, 5)) | ||||
|       losses.update(loss.item(),  inputs.size(0)) | ||||
|       top1.update  (prec1.item(), inputs.size(0)) | ||||
|       top5.update  (prec5.item(), inputs.size(0)) | ||||
|  | ||||
|       # measure elapsed time | ||||
|       batch_time.update(time.time() - end) | ||||
|       end = time.time() | ||||
|  | ||||
|       if i % print_freq == 0 or (i+1) == len(xloader): | ||||
|         Sstr = '**VALID** ' + time_string() + ' [{:}][{:03d}/{:03d}]'.format(extra_info, i, len(xloader)) | ||||
|         Tstr = 'Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})'.format(batch_time=batch_time, data_time=data_time) | ||||
|         Lstr = 'Loss {loss.val:.3f} ({loss.avg:.3f})  Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})'.format(loss=losses, top1=top1, top5=top5) | ||||
|         Istr = 'Size={:}'.format(list(inputs.size())) | ||||
|         logger.log(Sstr + ' ' + Tstr + ' ' + Lstr + ' ' + Istr) | ||||
|  | ||||
|   logger.log(' **VALID** Prec@1 {top1.avg:.2f} Prec@5 {top5.avg:.2f} Error@1 {error1:.2f} Error@5 {error5:.2f} Loss:{loss:.3f}'.format(top1=top1, top5=top5, error1=100-top1.avg, error5=100-top5.avg, loss=losses.avg)) | ||||
|   | ||||
|   return losses.avg, top1.avg, top5.avg | ||||
							
								
								
									
										87
									
								
								autodl/procedures/search_main_v2.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										87
									
								
								autodl/procedures/search_main_v2.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,87 @@ | ||||
| ################################################## | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 # | ||||
| ################################################## | ||||
| import os, sys, time, torch | ||||
| from log_utils import AverageMeter, time_string | ||||
| from utils     import obtain_accuracy | ||||
| from models    import change_key | ||||
|  | ||||
|  | ||||
| def get_flop_loss(expected_flop, flop_cur, flop_need, flop_tolerant): | ||||
|   expected_flop = torch.mean( expected_flop ) | ||||
|  | ||||
|   if flop_cur < flop_need - flop_tolerant:   # Too Small FLOP | ||||
|     loss = - torch.log( expected_flop ) | ||||
|   #elif flop_cur > flop_need + flop_tolerant: # Too Large FLOP | ||||
|   elif flop_cur > flop_need: # Too Large FLOP | ||||
|     loss = torch.log( expected_flop ) | ||||
|   else: # Required FLOP | ||||
|     loss = None | ||||
|   if loss is None: return 0, 0 | ||||
|   else           : return loss, loss.item() | ||||
|  | ||||
|  | ||||
| def search_train_v2(search_loader, network, criterion, scheduler, base_optimizer, arch_optimizer, optim_config, extra_info, print_freq, logger): | ||||
|   data_time, batch_time = AverageMeter(), AverageMeter() | ||||
|   base_losses, arch_losses, top1, top5 = AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter() | ||||
|   arch_cls_losses, arch_flop_losses = AverageMeter(), AverageMeter() | ||||
|   epoch_str, flop_need, flop_weight, flop_tolerant = extra_info['epoch-str'], extra_info['FLOP-exp'], extra_info['FLOP-weight'], extra_info['FLOP-tolerant'] | ||||
|  | ||||
|   network.train() | ||||
|   logger.log('[Search] : {:}, FLOP-Require={:.2f} MB, FLOP-WEIGHT={:.2f}'.format(epoch_str, flop_need, flop_weight)) | ||||
|   end = time.time() | ||||
|   network.apply( change_key('search_mode', 'search') ) | ||||
|   for step, (base_inputs, base_targets, arch_inputs, arch_targets) in enumerate(search_loader): | ||||
|     scheduler.update(None, 1.0 * step / len(search_loader)) | ||||
|     # calculate prediction and loss | ||||
|     base_targets = base_targets.cuda(non_blocking=True) | ||||
|     arch_targets = arch_targets.cuda(non_blocking=True) | ||||
|     # measure data loading time | ||||
|     data_time.update(time.time() - end) | ||||
|      | ||||
|     # update the weights | ||||
|     base_optimizer.zero_grad() | ||||
|     logits, expected_flop = network(base_inputs) | ||||
|     base_loss = criterion(logits, base_targets) | ||||
|     base_loss.backward() | ||||
|     base_optimizer.step() | ||||
|     # record | ||||
|     prec1, prec5 = obtain_accuracy(logits.data, base_targets.data, topk=(1, 5)) | ||||
|     base_losses.update(base_loss.item(), base_inputs.size(0)) | ||||
|     top1.update       (prec1.item(), base_inputs.size(0)) | ||||
|     top5.update       (prec5.item(), base_inputs.size(0)) | ||||
|  | ||||
|     # update the architecture | ||||
|     arch_optimizer.zero_grad() | ||||
|     logits, expected_flop = network(arch_inputs) | ||||
|     flop_cur  = network.module.get_flop('genotype', None, None) | ||||
|     flop_loss, flop_loss_scale = get_flop_loss(expected_flop, flop_cur, flop_need, flop_tolerant) | ||||
|     acls_loss = criterion(logits, arch_targets) | ||||
|     arch_loss = acls_loss + flop_loss * flop_weight | ||||
|     arch_loss.backward() | ||||
|     arch_optimizer.step() | ||||
|    | ||||
|     # record | ||||
|     arch_losses.update(arch_loss.item(), arch_inputs.size(0)) | ||||
|     arch_flop_losses.update(flop_loss_scale, arch_inputs.size(0)) | ||||
|     arch_cls_losses.update (acls_loss.item(), arch_inputs.size(0)) | ||||
|      | ||||
|     # measure elapsed time | ||||
|     batch_time.update(time.time() - end) | ||||
|     end = time.time() | ||||
|     if step % print_freq == 0 or (step+1) == len(search_loader): | ||||
|       Sstr = '**TRAIN** ' + time_string() + ' [{:}][{:03d}/{:03d}]'.format(epoch_str, step, len(search_loader)) | ||||
|       Tstr = 'Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})'.format(batch_time=batch_time, data_time=data_time) | ||||
|       Lstr = 'Base-Loss {loss.val:.3f} ({loss.avg:.3f})  Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})'.format(loss=base_losses, top1=top1, top5=top5) | ||||
|       Vstr = 'Acls-loss {aloss.val:.3f} ({aloss.avg:.3f}) FLOP-Loss {floss.val:.3f} ({floss.avg:.3f}) Arch-Loss {loss.val:.3f} ({loss.avg:.3f})'.format(aloss=arch_cls_losses, floss=arch_flop_losses, loss=arch_losses) | ||||
|       logger.log(Sstr + ' ' + Tstr + ' ' + Lstr + ' ' + Vstr) | ||||
|       #num_bytes = torch.cuda.max_memory_allocated( next(network.parameters()).device ) * 1.0 | ||||
|       #logger.log(Sstr + ' ' + Tstr + ' ' + Lstr + ' ' + Vstr + ' GPU={:.2f}MB'.format(num_bytes/1e6)) | ||||
|       #Istr = 'Bsz={:} Asz={:}'.format(list(base_inputs.size()), list(arch_inputs.size())) | ||||
|       #logger.log(Sstr + ' ' + Tstr + ' ' + Lstr + ' ' + Vstr + ' ' + Istr) | ||||
|       #print(network.module.get_arch_info()) | ||||
|       #print(network.module.width_attentions[0]) | ||||
|       #print(network.module.width_attentions[1]) | ||||
|  | ||||
|   logger.log(' **TRAIN** Prec@1 {top1.avg:.2f} Prec@5 {top5.avg:.2f} Error@1 {error1:.2f} Error@5 {error5:.2f} Base-Loss:{baseloss:.3f}, Arch-Loss={archloss:.3f}'.format(top1=top1, top5=top5, error1=100-top1.avg, error5=100-top5.avg, baseloss=base_losses.avg, archloss=arch_losses.avg)) | ||||
|   return base_losses.avg, arch_losses.avg, top1.avg, top5.avg | ||||
							
								
								
									
										94
									
								
								autodl/procedures/simple_KD_main.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										94
									
								
								autodl/procedures/simple_KD_main.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,94 @@ | ||||
| ##################################################### | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 # | ||||
| ##################################################### | ||||
| import os, sys, time, torch | ||||
| import torch.nn.functional as F | ||||
| # our modules | ||||
| from log_utils import AverageMeter, time_string | ||||
| from utils     import obtain_accuracy | ||||
|  | ||||
|  | ||||
| def simple_KD_train(xloader, teacher, network, criterion, scheduler, optimizer, optim_config, extra_info, print_freq, logger): | ||||
|   loss, acc1, acc5 = procedure(xloader, teacher, network, criterion, scheduler, optimizer, 'train', optim_config, extra_info, print_freq, logger) | ||||
|   return loss, acc1, acc5 | ||||
|  | ||||
| def simple_KD_valid(xloader, teacher, network, criterion, optim_config, extra_info, print_freq, logger): | ||||
|   with torch.no_grad(): | ||||
|     loss, acc1, acc5 = procedure(xloader, teacher, network, criterion, None, None, 'valid', optim_config, extra_info, print_freq, logger) | ||||
|   return loss, acc1, acc5 | ||||
|  | ||||
|  | ||||
| def loss_KD_fn(criterion, student_logits, teacher_logits, studentFeatures, teacherFeatures, targets, alpha, temperature): | ||||
|   basic_loss = criterion(student_logits, targets) * (1. - alpha) | ||||
|   log_student= F.log_softmax(student_logits / temperature, dim=1) | ||||
|   sof_teacher= F.softmax    (teacher_logits / temperature, dim=1) | ||||
|   KD_loss    = F.kl_div(log_student, sof_teacher, reduction='batchmean') * (alpha * temperature * temperature) | ||||
|   return basic_loss + KD_loss | ||||
|  | ||||
|  | ||||
| def procedure(xloader, teacher, network, criterion, scheduler, optimizer, mode, config, extra_info, print_freq, logger): | ||||
|   data_time, batch_time, losses, top1, top5 = AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter() | ||||
|   Ttop1, Ttop5 = AverageMeter(), AverageMeter() | ||||
|   if mode == 'train': | ||||
|     network.train() | ||||
|   elif mode == 'valid': | ||||
|     network.eval() | ||||
|   else: raise ValueError("The mode is not right : {:}".format(mode)) | ||||
|   teacher.eval() | ||||
|    | ||||
|   logger.log('[{:5s}] config :: auxiliary={:}, KD :: [alpha={:.2f}, temperature={:.2f}]'.format(mode, config.auxiliary if hasattr(config, 'auxiliary') else -1, config.KD_alpha, config.KD_temperature)) | ||||
|   end = time.time() | ||||
|   for i, (inputs, targets) in enumerate(xloader): | ||||
|     if mode == 'train': scheduler.update(None, 1.0 * i / len(xloader)) | ||||
|     # measure data loading time | ||||
|     data_time.update(time.time() - end) | ||||
|     # calculate prediction and loss | ||||
|     targets = targets.cuda(non_blocking=True) | ||||
|  | ||||
|     if mode == 'train': optimizer.zero_grad() | ||||
|  | ||||
|     student_f, logits = network(inputs) | ||||
|     if isinstance(logits, list): | ||||
|       assert len(logits) == 2, 'logits must has {:} items instead of {:}'.format(2, len(logits)) | ||||
|       logits, logits_aux = logits | ||||
|     else: | ||||
|       logits, logits_aux = logits, None | ||||
|     with torch.no_grad(): | ||||
|       teacher_f, teacher_logits = teacher(inputs) | ||||
|  | ||||
|     loss             = loss_KD_fn(criterion, logits, teacher_logits, student_f, teacher_f, targets, config.KD_alpha, config.KD_temperature) | ||||
|     if config is not None and hasattr(config, 'auxiliary') and config.auxiliary > 0: | ||||
|       loss_aux = criterion(logits_aux, targets) | ||||
|       loss += config.auxiliary * loss_aux | ||||
|      | ||||
|     if mode == 'train': | ||||
|       loss.backward() | ||||
|       optimizer.step() | ||||
|  | ||||
|     # record | ||||
|     sprec1, sprec5 = obtain_accuracy(logits.data, targets.data, topk=(1, 5)) | ||||
|     losses.update(loss.item(),   inputs.size(0)) | ||||
|     top1.update  (sprec1.item(), inputs.size(0)) | ||||
|     top5.update  (sprec5.item(), inputs.size(0)) | ||||
|     # teacher | ||||
|     tprec1, tprec5 = obtain_accuracy(teacher_logits.data, targets.data, topk=(1, 5)) | ||||
|     Ttop1.update (tprec1.item(), inputs.size(0)) | ||||
|     Ttop5.update (tprec5.item(), inputs.size(0)) | ||||
|  | ||||
|     # measure elapsed time | ||||
|     batch_time.update(time.time() - end) | ||||
|     end = time.time() | ||||
|  | ||||
|     if i % print_freq == 0 or (i+1) == len(xloader): | ||||
|       Sstr = ' {:5s} '.format(mode.upper()) + time_string() + ' [{:}][{:03d}/{:03d}]'.format(extra_info, i, len(xloader)) | ||||
|       if scheduler is not None: | ||||
|         Sstr += ' {:}'.format(scheduler.get_min_info()) | ||||
|       Tstr = 'Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})'.format(batch_time=batch_time, data_time=data_time) | ||||
|       Lstr = 'Loss {loss.val:.3f} ({loss.avg:.3f})  Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})'.format(loss=losses, top1=top1, top5=top5) | ||||
|       Lstr+= ' Teacher : acc@1={:.2f}, acc@5={:.2f}'.format(Ttop1.avg, Ttop5.avg) | ||||
|       Istr = 'Size={:}'.format(list(inputs.size())) | ||||
|       logger.log(Sstr + ' ' + Tstr + ' ' + Lstr + ' ' + Istr) | ||||
|  | ||||
|   logger.log(' **{:5s}** accuracy drop :: @1={:.2f}, @5={:.2f}'.format(mode.upper(), Ttop1.avg - top1.avg, Ttop5.avg - top5.avg)) | ||||
|   logger.log(' **{mode:5s}** Prec@1 {top1.avg:.2f} Prec@5 {top5.avg:.2f} Error@1 {error1:.2f} Error@5 {error5:.2f} Loss:{loss:.3f}'.format(mode=mode.upper(), top1=top1, top5=top5, error1=100-top1.avg, error5=100-top5.avg, loss=losses.avg)) | ||||
|   return losses.avg, top1.avg, top5.avg | ||||
							
								
								
									
										64
									
								
								autodl/procedures/starts.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										64
									
								
								autodl/procedures/starts.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,64 @@ | ||||
| ################################################## | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 # | ||||
| ################################################## | ||||
| import os, sys, torch, random, PIL, copy, numpy as np | ||||
| from os import path as osp | ||||
| from shutil  import copyfile | ||||
|  | ||||
|  | ||||
| def prepare_seed(rand_seed): | ||||
|   random.seed(rand_seed) | ||||
|   np.random.seed(rand_seed) | ||||
|   torch.manual_seed(rand_seed) | ||||
|   torch.cuda.manual_seed(rand_seed) | ||||
|   torch.cuda.manual_seed_all(rand_seed) | ||||
|  | ||||
|  | ||||
| def prepare_logger(xargs): | ||||
|   args = copy.deepcopy( xargs ) | ||||
|   from autodl.log_utils import Logger | ||||
|   logger = Logger(args.save_dir, args.rand_seed) | ||||
|   logger.log('Main Function with logger : {:}'.format(logger)) | ||||
|   logger.log('Arguments : -------------------------------') | ||||
|   for name, value in args._get_kwargs(): | ||||
|     logger.log('{:16} : {:}'.format(name, value)) | ||||
|   logger.log("Python  Version  : {:}".format(sys.version.replace('\n', ' '))) | ||||
|   logger.log("Pillow  Version  : {:}".format(PIL.__version__)) | ||||
|   logger.log("PyTorch Version  : {:}".format(torch.__version__)) | ||||
|   logger.log("cuDNN   Version  : {:}".format(torch.backends.cudnn.version())) | ||||
|   logger.log("CUDA available   : {:}".format(torch.cuda.is_available())) | ||||
|   logger.log("CUDA GPU numbers : {:}".format(torch.cuda.device_count())) | ||||
|   logger.log("CUDA_VISIBLE_DEVICES : {:}".format(os.environ['CUDA_VISIBLE_DEVICES'] if 'CUDA_VISIBLE_DEVICES' in os.environ else 'None')) | ||||
|   return logger | ||||
|  | ||||
|  | ||||
| def get_machine_info(): | ||||
|   info = "Python  Version  : {:}".format(sys.version.replace('\n', ' ')) | ||||
|   info+= "\nPillow  Version  : {:}".format(PIL.__version__) | ||||
|   info+= "\nPyTorch Version  : {:}".format(torch.__version__) | ||||
|   info+= "\ncuDNN   Version  : {:}".format(torch.backends.cudnn.version()) | ||||
|   info+= "\nCUDA available   : {:}".format(torch.cuda.is_available()) | ||||
|   info+= "\nCUDA GPU numbers : {:}".format(torch.cuda.device_count()) | ||||
|   if 'CUDA_VISIBLE_DEVICES' in os.environ: | ||||
|     info+= "\nCUDA_VISIBLE_DEVICES={:}".format(os.environ['CUDA_VISIBLE_DEVICES']) | ||||
|   else: | ||||
|     info+= "\nDoes not set CUDA_VISIBLE_DEVICES" | ||||
|   return info | ||||
|  | ||||
|  | ||||
| def save_checkpoint(state, filename, logger): | ||||
|   if osp.isfile(filename): | ||||
|     if hasattr(logger, 'log'): logger.log('Find {:} exist, delete is at first before saving'.format(filename)) | ||||
|     os.remove(filename) | ||||
|   torch.save(state, filename) | ||||
|   assert osp.isfile(filename), 'save filename : {:} failed, which is not found.'.format(filename) | ||||
|   if hasattr(logger, 'log'): logger.log('save checkpoint into {:}'.format(filename)) | ||||
|   return filename | ||||
|  | ||||
|  | ||||
| def copy_checkpoint(src, dst, logger): | ||||
|   if osp.isfile(dst): | ||||
|     if hasattr(logger, 'log'): logger.log('Find {:} exist, delete is at first before saving'.format(dst)) | ||||
|     os.remove(dst) | ||||
|   copyfile(src, dst) | ||||
|   if hasattr(logger, 'log'): logger.log('copy the file from {:} into {:}'.format(src, dst)) | ||||
		Reference in New Issue
	
	Block a user