Add more algorithms
exps/KD-main.py (new file, 159 lines)
							| @@ -0,0 +1,159 @@ | ||||
| import sys, time, torch, random, argparse | ||||
| from PIL     import ImageFile | ||||
| ImageFile.LOAD_TRUNCATED_IMAGES = True | ||||
| from copy    import deepcopy | ||||
| from pathlib import Path | ||||
|  | ||||
| lib_dir = (Path(__file__).parent / '..' / 'lib').resolve() | ||||
| if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) | ||||
| from config_utils import load_config, obtain_cls_kd_args as obtain_args | ||||
| from procedures   import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint | ||||
| from procedures   import get_optim_scheduler, get_procedures | ||||
| from datasets     import get_datasets | ||||
| from models       import obtain_model, load_net_from_checkpoint | ||||
| from utils        import get_model_infos | ||||
| from log_utils    import AverageMeter, time_string, convert_secs2time | ||||
|  | ||||
|  | ||||
| def main(args): | ||||
|   assert torch.cuda.is_available(), 'CUDA is not available.' | ||||
|   torch.backends.cudnn.enabled   = True | ||||
|   torch.backends.cudnn.benchmark = True | ||||
|   #torch.backends.cudnn.deterministic = True | ||||
|   torch.set_num_threads( args.workers ) | ||||
|    | ||||
|   prepare_seed(args.rand_seed) | ||||
|   logger = prepare_logger(args) | ||||
|    | ||||
|   train_data, valid_data, xshape, class_num = get_datasets(args.dataset, args.data_path, args.cutout_length) | ||||
|   train_loader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size, shuffle=True , num_workers=args.workers, pin_memory=True) | ||||
|   valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=args.batch_size, shuffle=False, num_workers=args.workers, pin_memory=True) | ||||
|   # get configures | ||||
|   model_config = load_config(args.model_config, {'class_num': class_num}, logger) | ||||
|   optim_config = load_config(args.optim_config, | ||||
|                                 {'class_num': class_num, 'KD_alpha': args.KD_alpha, 'KD_temperature': args.KD_temperature}, | ||||
|                                 logger) | ||||
|  | ||||
|   # load the pre-trained teacher network from the given KD checkpoint | ||||
|   teacher_base = load_net_from_checkpoint(args.KD_checkpoint) | ||||
|   teacher      = torch.nn.DataParallel(teacher_base).cuda() | ||||
|  | ||||
|   base_model   = obtain_model(model_config) | ||||
|   flop, param  = get_model_infos(base_model, xshape) | ||||
|   logger.log('Student ====>>>>:\n{:}'.format(base_model)) | ||||
|   logger.log('Teacher ====>>>>:\n{:}'.format(teacher_base)) | ||||
|   logger.log('model information : {:}'.format(base_model.get_message())) | ||||
|   logger.log('-'*50) | ||||
|   logger.log('Params={:.2f} MB, FLOPs={:.2f} M ... = {:.2f} G'.format(param, flop, flop/1e3)) | ||||
|   logger.log('-'*50) | ||||
|   logger.log('train_data : {:}'.format(train_data)) | ||||
|   logger.log('valid_data : {:}'.format(valid_data)) | ||||
|   optimizer, scheduler, criterion = get_optim_scheduler(base_model.parameters(), optim_config) | ||||
|   logger.log('optimizer  : {:}'.format(optimizer)) | ||||
|   logger.log('scheduler  : {:}'.format(scheduler)) | ||||
|   logger.log('criterion  : {:}'.format(criterion)) | ||||
|    | ||||
|   last_info, model_base_path, model_best_path = logger.path('info'), logger.path('model'), logger.path('best') | ||||
|   network, criterion = torch.nn.DataParallel(base_model).cuda(), criterion.cuda() | ||||
|  | ||||
|   if last_info.exists(): # automatically resume from previous checkpoint | ||||
|     logger.log("=> loading checkpoint of the last-info '{:}' start".format(last_info)) | ||||
|     last_info   = torch.load(last_info) | ||||
|     start_epoch = last_info['epoch'] + 1 | ||||
|     checkpoint  = torch.load(last_info['last_checkpoint']) | ||||
|     base_model.load_state_dict( checkpoint['base-model'] ) | ||||
|     scheduler.load_state_dict ( checkpoint['scheduler'] ) | ||||
|     optimizer.load_state_dict ( checkpoint['optimizer'] ) | ||||
|     valid_accuracies = checkpoint['valid_accuracies'] | ||||
|     max_bytes        = checkpoint['max_bytes'] | ||||
|     logger.log("=> loading checkpoint of the last-info '{:}' start with {:}-th epoch.".format(last_info, start_epoch)) | ||||
|   elif args.resume is not None: | ||||
|     assert Path(args.resume).exists(), 'Can not find the resume file : {:}'.format(args.resume) | ||||
|     checkpoint  = torch.load( args.resume ) | ||||
|     start_epoch = checkpoint['epoch'] + 1 | ||||
|     base_model.load_state_dict( checkpoint['base-model'] ) | ||||
|     scheduler.load_state_dict ( checkpoint['scheduler'] ) | ||||
|     optimizer.load_state_dict ( checkpoint['optimizer'] ) | ||||
|     valid_accuracies = checkpoint['valid_accuracies'] | ||||
|     max_bytes        = checkpoint['max_bytes'] | ||||
|     logger.log("=> loading checkpoint from '{:}' start with {:}-th epoch.".format(args.resume, start_epoch)) | ||||
|   elif args.init_model is not None: | ||||
|     assert Path(args.init_model).exists(), 'Can not find the initialization file : {:}'.format(args.init_model) | ||||
|     checkpoint  = torch.load( args.init_model ) | ||||
|     base_model.load_state_dict( checkpoint['base-model'] ) | ||||
|     start_epoch, valid_accuracies, max_bytes = 0, {'best': -1}, {} | ||||
|     logger.log('=> initialize the model from {:}'.format( args.init_model )) | ||||
|   else: | ||||
|     logger.log("=> do not find the last-info file : {:}".format(last_info)) | ||||
|     start_epoch, valid_accuracies, max_bytes = 0, {'best': -1}, {} | ||||
|  | ||||
|   train_func, valid_func = get_procedures(args.procedure) | ||||
|    | ||||
|   total_epoch = optim_config.epochs + optim_config.warmup | ||||
|   # Main Training and Evaluation Loop | ||||
|   start_time  = time.time() | ||||
|   epoch_time  = AverageMeter() | ||||
|   for epoch in range(start_epoch, total_epoch): | ||||
|     scheduler.update(epoch, 0.0) | ||||
|     need_time = 'Time Left: {:}'.format( convert_secs2time(epoch_time.avg * (total_epoch-epoch), True) ) | ||||
|     epoch_str = 'epoch={:03d}/{:03d}'.format(epoch, total_epoch) | ||||
|     LRs       = scheduler.get_lr() | ||||
|     find_best = False | ||||
|  | ||||
|     logger.log('\n***{:s}*** start {:s} {:s}, LR=[{:.6f} ~ {:.6f}], scheduler={:}'.format(time_string(), epoch_str, need_time, min(LRs), max(LRs), scheduler)) | ||||
|      | ||||
|     # train for one epoch | ||||
|     train_loss, train_acc1, train_acc5 = train_func(train_loader, teacher, network, criterion, scheduler, optimizer, optim_config, epoch_str, args.print_freq, logger) | ||||
|     # log the results     | ||||
|     logger.log('***{:s}*** TRAIN [{:}] loss = {:.6f}, accuracy-1 = {:.2f}, accuracy-5 = {:.2f}'.format(time_string(), epoch_str, train_loss, train_acc1, train_acc5)) | ||||
|  | ||||
|     # evaluate the performance | ||||
|     if (epoch % args.eval_frequency == 0) or (epoch + 1 == total_epoch): | ||||
|       logger.log('-'*150) | ||||
|       valid_loss, valid_acc1, valid_acc5 = valid_func(valid_loader, teacher, network, criterion, optim_config, epoch_str, args.print_freq_eval, logger) | ||||
|       valid_accuracies[epoch] = valid_acc1 | ||||
|       logger.log('***{:s}*** VALID [{:}] loss = {:.6f}, accuracy@1 = {:.2f}, accuracy@5 = {:.2f} | Best-Valid-Acc@1={:.2f}, Error@1={:.2f}'.format(time_string(), epoch_str, valid_loss, valid_acc1, valid_acc5, valid_accuracies['best'], 100-valid_accuracies['best'])) | ||||
|       if valid_acc1 > valid_accuracies['best']: | ||||
|         valid_accuracies['best'] = valid_acc1 | ||||
|         find_best                = True | ||||
|         logger.log('Currently, the best validation accuracy found at {:03d}-epoch :: acc@1={:.2f}, acc@5={:.2f}, error@1={:.2f}, error@5={:.2f}, save into {:}.'.format(epoch, valid_acc1, valid_acc5, 100-valid_acc1, 100-valid_acc5, model_best_path)) | ||||
|       num_bytes = torch.cuda.max_memory_cached( next(network.parameters()).device ) * 1.0 | ||||
|       logger.log('[GPU-Memory-Usage on {:} is {:} bytes, {:.2f} KB, {:.2f} MB, {:.2f} GB.]'.format(next(network.parameters()).device, int(num_bytes), num_bytes / 1e3, num_bytes / 1e6, num_bytes / 1e9)) | ||||
|       max_bytes[epoch] = num_bytes | ||||
|     if epoch % 10 == 0: torch.cuda.empty_cache() | ||||
|  | ||||
|     # save checkpoint | ||||
|     save_path = save_checkpoint({ | ||||
|           'epoch'        : epoch, | ||||
|           'args'         : deepcopy(args), | ||||
|           'max_bytes'    : deepcopy(max_bytes), | ||||
|           'FLOP'         : flop, | ||||
|           'PARAM'        : param, | ||||
|           'valid_accuracies': deepcopy(valid_accuracies), | ||||
|           'model-config' : model_config._asdict(), | ||||
|           'optim-config' : optim_config._asdict(), | ||||
|           'base-model'   : base_model.state_dict(), | ||||
|           'scheduler'    : scheduler.state_dict(), | ||||
|           'optimizer'    : optimizer.state_dict(), | ||||
|           }, model_base_path, logger) | ||||
|     if find_best: copy_checkpoint(model_base_path, model_best_path, logger) | ||||
|     last_info = save_checkpoint({ | ||||
|           'epoch': epoch, | ||||
|           'args' : deepcopy(args), | ||||
|           'last_checkpoint': save_path, | ||||
|           }, logger.path('info'), logger) | ||||
|  | ||||
|     # measure elapsed time | ||||
|     epoch_time.update(time.time() - start_time) | ||||
|     start_time = time.time() | ||||
|  | ||||
|   logger.log('\n' + '-'*200) | ||||
|   logger.log('||| Params={:.2f} MB, FLOPs={:.2f} M ... = {:.2f} G'.format(param, flop, flop/1e3)) | ||||
|   logger.log('Finish training/validation in {:} with Max-GPU-Memory of {:.2f} MB, and save final checkpoint into {:}'.format(convert_secs2time(epoch_time.sum, True), max(v for k, v in max_bytes.items()) / 1e6, logger.path('info'))) | ||||
|   logger.log('-'*200 + '\n') | ||||
|   logger.close() | ||||
|  | ||||
|  | ||||
| if __name__ == '__main__': | ||||
|   args = obtain_args() | ||||
|   main(args) | ||||
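The dict handed to save_checkpoint above is read back verbatim by torch.load in the resume branches, so the stored statistics can be inspected offline. A minimal sketch, assuming a checkpoint written by this script; the path below is hypothetical:

import torch
ckpt = torch.load('./output/KD/checkpoint/seed-777.pth', map_location='cpu')  # hypothetical path
print(ckpt['epoch'], ckpt['FLOP'], ckpt['PARAM'])           # last finished epoch, FLOPs (M) and params (MB)
print(ckpt['valid_accuracies']['best'])                     # best top-1 validation accuracy so far
print(max(v for k, v in ckpt['max_bytes'].items()) / 1e6)   # peak cached GPU memory in MB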
							
								
								
									
exps/basic-eval.py (new file, 68 lines)
							| @@ -0,0 +1,68 @@ | ||||
| ################################################## | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 # | ||||
| ################################################## | ||||
| import os, sys, time, torch, random, argparse | ||||
| from PIL     import ImageFile | ||||
| ImageFile.LOAD_TRUNCATED_IMAGES = True | ||||
| from copy    import deepcopy | ||||
| from pathlib import Path | ||||
|  | ||||
| lib_dir = (Path(__file__).parent / '..' / 'lib').resolve() | ||||
| if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) | ||||
| from config_utils import load_config, dict2config | ||||
| from procedures   import get_procedures, get_optim_scheduler | ||||
| from datasets     import get_datasets | ||||
| from models       import obtain_model | ||||
| from utils        import get_model_infos | ||||
| from log_utils    import PrintLogger, time_string | ||||
|  | ||||
|  | ||||
| assert torch.cuda.is_available(), 'torch.cuda is not available' | ||||
|  | ||||
|  | ||||
| def main(args): | ||||
|  | ||||
|   assert os.path.isdir ( args.data_path ) , 'invalid data-path : {:}'.format(args.data_path) | ||||
|   assert os.path.isfile( args.checkpoint ), 'invalid checkpoint : {:}'.format(args.checkpoint) | ||||
|  | ||||
|   checkpoint = torch.load( args.checkpoint ) | ||||
|   xargs      = checkpoint['args'] | ||||
|   train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, args.data_path, xargs.cutout_length) | ||||
|   valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=xargs.batch_size, shuffle=False, num_workers=xargs.workers, pin_memory=True) | ||||
|  | ||||
|   logger       = PrintLogger() | ||||
|   model_config = dict2config(checkpoint['model-config'], logger) | ||||
|   base_model   = obtain_model(model_config) | ||||
|   flop, param  = get_model_infos(base_model, xshape) | ||||
|   logger.log('model ====>>>>:\n{:}'.format(base_model)) | ||||
|   logger.log('model information : {:}'.format(base_model.get_message())) | ||||
|   logger.log('-'*50) | ||||
|   logger.log('Params={:.2f} MB, FLOPs={:.2f} M ... = {:.2f} G'.format(param, flop, flop/1e3)) | ||||
|   logger.log('-'*50) | ||||
|   logger.log('valid_data : {:}'.format(valid_data)) | ||||
|   optim_config = dict2config(checkpoint['optim-config'], logger) | ||||
|   _, _, criterion = get_optim_scheduler(base_model.parameters(), optim_config) | ||||
|   logger.log('criterion  : {:}'.format(criterion)) | ||||
|   base_model.load_state_dict( checkpoint['base-model'] ) | ||||
|   _, valid_func = get_procedures(xargs.procedure) | ||||
|   logger.log('CNN initialization done; evaluate it using {:}'.format(valid_func)) | ||||
|   network = torch.nn.DataParallel(base_model).cuda() | ||||
|    | ||||
|   try: | ||||
|     valid_loss, valid_acc1, valid_acc5 = valid_func(valid_loader, network, criterion, optim_config, 'pure-evaluation', xargs.print_freq_eval, logger) | ||||
|   except: | ||||
|     _, valid_func = get_procedures('basic') | ||||
|     valid_loss, valid_acc1, valid_acc5 = valid_func(valid_loader, network, criterion, optim_config, 'pure-evaluation', xargs.print_freq_eval, logger) | ||||
|    | ||||
|   num_bytes = torch.cuda.max_memory_cached( next(network.parameters()).device ) * 1.0 | ||||
|   logger.log('***{:s}*** EVALUATION loss = {:.6f}, accuracy@1 = {:.2f}, accuracy@5 = {:.2f}, error@1 = {:.2f}, error@5 = {:.2f}'.format(time_string(), valid_loss, valid_acc1, valid_acc5, 100-valid_acc1, 100-valid_acc5)) | ||||
|   logger.log('[GPU-Memory-Usage on {:} is {:} bytes, {:.2f} KB, {:.2f} MB, {:.2f} GB.]'.format(next(network.parameters()).device, int(num_bytes), num_bytes / 1e3, num_bytes / 1e6, num_bytes / 1e9)) | ||||
|   logger.close() | ||||
|  | ||||
|  | ||||
| if __name__ == '__main__': | ||||
|   parser = argparse.ArgumentParser("Evaluate-CNN") | ||||
|   parser.add_argument('--data_path',         type=str,   help='Path to dataset.') | ||||
| parser.add_argument('--checkpoint',        type=str,   help='Path to the checkpoint file to evaluate.') | ||||
|   args = parser.parse_args() | ||||
|   main(args) | ||||
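exps/basic-eval.py only exposes the two flags declared above; the dataset name, batch size, cutout length and evaluation procedure are all recovered from the args stored inside the checkpoint itself. An invocation sketch with hypothetical paths:

python exps/basic-eval.py --data_path $TORCH_HOME/ILSVRC2012 --checkpoint ./output/basic/checkpoint/seed-777-best.pth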
							
								
								
									
exps/basic-main.py (new file, 165 lines)
							| @@ -0,0 +1,165 @@ | ||||
| ################################################## | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 # | ||||
| ################################################## | ||||
| import sys, time, torch, random, argparse | ||||
| from PIL     import ImageFile | ||||
| ImageFile.LOAD_TRUNCATED_IMAGES = True | ||||
| from copy    import deepcopy | ||||
| from pathlib import Path | ||||
|  | ||||
| lib_dir = (Path(__file__).parent / '..' / 'lib').resolve() | ||||
| if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) | ||||
| from config_utils import load_config, obtain_basic_args as obtain_args | ||||
| from procedures   import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint | ||||
| from procedures   import get_optim_scheduler, get_procedures | ||||
| from datasets     import get_datasets | ||||
| from models       import obtain_model | ||||
| from nas_infer_model import obtain_nas_infer_model | ||||
| from utils        import get_model_infos | ||||
| from log_utils    import AverageMeter, time_string, convert_secs2time | ||||
|  | ||||
|  | ||||
| def main(args): | ||||
|   assert torch.cuda.is_available(), 'CUDA is not available.' | ||||
|   torch.backends.cudnn.enabled   = True | ||||
|   torch.backends.cudnn.benchmark = True | ||||
|   #torch.backends.cudnn.deterministic = True | ||||
|   torch.set_num_threads( args.workers ) | ||||
|    | ||||
|   prepare_seed(args.rand_seed) | ||||
|   logger = prepare_logger(args) | ||||
|    | ||||
|   train_data, valid_data, xshape, class_num = get_datasets(args.dataset, args.data_path, args.cutout_length) | ||||
|   train_loader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size, shuffle=True , num_workers=args.workers, pin_memory=True) | ||||
|   valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=args.batch_size, shuffle=False, num_workers=args.workers, pin_memory=True) | ||||
|   # get configures | ||||
|   model_config = load_config(args.model_config, {'class_num': class_num}, logger) | ||||
|   optim_config = load_config(args.optim_config, {'class_num': class_num}, logger) | ||||
|  | ||||
|   if args.model_source == 'normal': | ||||
|     base_model   = obtain_model(model_config) | ||||
|   elif args.model_source == 'nas': | ||||
|     base_model   = obtain_nas_infer_model(model_config) | ||||
|   else: | ||||
|     raise ValueError('invalid model-source : {:}'.format(args.model_source)) | ||||
|   flop, param  = get_model_infos(base_model, xshape) | ||||
|   logger.log('model ====>>>>:\n{:}'.format(base_model)) | ||||
|   logger.log('model information : {:}'.format(base_model.get_message())) | ||||
|   logger.log('-'*50) | ||||
|   logger.log('Params={:.2f} MB, FLOPs={:.2f} M ... = {:.2f} G'.format(param, flop, flop/1e3)) | ||||
|   logger.log('-'*50) | ||||
|   logger.log('train_data : {:}'.format(train_data)) | ||||
|   logger.log('valid_data : {:}'.format(valid_data)) | ||||
|   optimizer, scheduler, criterion = get_optim_scheduler(base_model.parameters(), optim_config) | ||||
|   logger.log('optimizer  : {:}'.format(optimizer)) | ||||
|   logger.log('scheduler  : {:}'.format(scheduler)) | ||||
|   logger.log('criterion  : {:}'.format(criterion)) | ||||
|    | ||||
|   last_info, model_base_path, model_best_path = logger.path('info'), logger.path('model'), logger.path('best') | ||||
|   network, criterion = torch.nn.DataParallel(base_model).cuda(), criterion.cuda() | ||||
|  | ||||
|   if last_info.exists(): # automatically resume from previous checkpoint | ||||
|     logger.log("=> loading checkpoint of the last-info '{:}' start".format(last_info)) | ||||
|     last_infox  = torch.load(last_info) | ||||
|     start_epoch = last_infox['epoch'] + 1 | ||||
|     last_checkpoint_path = last_infox['last_checkpoint'] | ||||
|     if not last_checkpoint_path.exists(): | ||||
|       logger.log('Did not find {:}, trying another path'.format(last_checkpoint_path)) | ||||
|       last_checkpoint_path = last_info.parent / last_checkpoint_path.parent.name / last_checkpoint_path.name | ||||
|     checkpoint  = torch.load( last_checkpoint_path ) | ||||
|     base_model.load_state_dict( checkpoint['base-model'] ) | ||||
|     scheduler.load_state_dict ( checkpoint['scheduler'] ) | ||||
|     optimizer.load_state_dict ( checkpoint['optimizer'] ) | ||||
|     valid_accuracies = checkpoint['valid_accuracies'] | ||||
|     max_bytes        = checkpoint['max_bytes'] | ||||
|     logger.log("=> loading checkpoint of the last-info '{:}' start with {:}-th epoch.".format(last_info, start_epoch)) | ||||
|   elif args.resume is not None: | ||||
|     assert Path(args.resume).exists(), 'Can not find the resume file : {:}'.format(args.resume) | ||||
|     checkpoint  = torch.load( args.resume ) | ||||
|     start_epoch = checkpoint['epoch'] + 1 | ||||
|     base_model.load_state_dict( checkpoint['base-model'] ) | ||||
|     scheduler.load_state_dict ( checkpoint['scheduler'] ) | ||||
|     optimizer.load_state_dict ( checkpoint['optimizer'] ) | ||||
|     valid_accuracies = checkpoint['valid_accuracies'] | ||||
|     max_bytes        = checkpoint['max_bytes'] | ||||
|     logger.log("=> loading checkpoint from '{:}' start with {:}-th epoch.".format(args.resume, start_epoch)) | ||||
|   elif args.init_model is not None: | ||||
|     assert Path(args.init_model).exists(), 'Can not find the initialization file : {:}'.format(args.init_model) | ||||
|     checkpoint  = torch.load( args.init_model ) | ||||
|     base_model.load_state_dict( checkpoint['base-model'] ) | ||||
|     start_epoch, valid_accuracies, max_bytes = 0, {'best': -1}, {} | ||||
|     logger.log('=> initialize the model from {:}'.format( args.init_model )) | ||||
|   else: | ||||
|     logger.log("=> do not find the last-info file : {:}".format(last_info)) | ||||
|     start_epoch, valid_accuracies, max_bytes = 0, {'best': -1}, {} | ||||
|  | ||||
|   train_func, valid_func = get_procedures(args.procedure) | ||||
|    | ||||
|   total_epoch = optim_config.epochs + optim_config.warmup | ||||
|   # Main Training and Evaluation Loop | ||||
|   start_time  = time.time() | ||||
|   epoch_time  = AverageMeter() | ||||
|   for epoch in range(start_epoch, total_epoch): | ||||
|     scheduler.update(epoch, 0.0) | ||||
|     need_time = 'Time Left: {:}'.format( convert_secs2time(epoch_time.avg * (total_epoch-epoch), True) ) | ||||
|     epoch_str = 'epoch={:03d}/{:03d}'.format(epoch, total_epoch) | ||||
|     LRs       = scheduler.get_lr() | ||||
|     find_best = False | ||||
|     # set-up drop-out ratio | ||||
|     if hasattr(base_model, 'update_drop_path'): base_model.update_drop_path(model_config.drop_path_prob * epoch / total_epoch) | ||||
|     logger.log('\n***{:s}*** start {:s} {:s}, LR=[{:.6f} ~ {:.6f}], scheduler={:}'.format(time_string(), epoch_str, need_time, min(LRs), max(LRs), scheduler)) | ||||
|      | ||||
|     # train for one epoch | ||||
|     train_loss, train_acc1, train_acc5 = train_func(train_loader, network, criterion, scheduler, optimizer, optim_config, epoch_str, args.print_freq, logger) | ||||
|     # log the results     | ||||
|     logger.log('***{:s}*** TRAIN [{:}] loss = {:.6f}, accuracy-1 = {:.2f}, accuracy-5 = {:.2f}'.format(time_string(), epoch_str, train_loss, train_acc1, train_acc5)) | ||||
|  | ||||
|     # evaluate the performance | ||||
|     if (epoch % args.eval_frequency == 0) or (epoch + 1 == total_epoch): | ||||
|       logger.log('-'*150) | ||||
|       valid_loss, valid_acc1, valid_acc5 = valid_func(valid_loader, network, criterion, optim_config, epoch_str, args.print_freq_eval, logger) | ||||
|       valid_accuracies[epoch] = valid_acc1 | ||||
|       logger.log('***{:s}*** VALID [{:}] loss = {:.6f}, accuracy@1 = {:.2f}, accuracy@5 = {:.2f} | Best-Valid-Acc@1={:.2f}, Error@1={:.2f}'.format(time_string(), epoch_str, valid_loss, valid_acc1, valid_acc5, valid_accuracies['best'], 100-valid_accuracies['best'])) | ||||
|       if valid_acc1 > valid_accuracies['best']: | ||||
|         valid_accuracies['best'] = valid_acc1 | ||||
|         find_best                = True | ||||
|         logger.log('Currently, the best validation accuracy found at {:03d}-epoch :: acc@1={:.2f}, acc@5={:.2f}, error@1={:.2f}, error@5={:.2f}, save into {:}.'.format(epoch, valid_acc1, valid_acc5, 100-valid_acc1, 100-valid_acc5, model_best_path)) | ||||
|       num_bytes = torch.cuda.max_memory_cached( next(network.parameters()).device ) * 1.0 | ||||
|       logger.log('[GPU-Memory-Usage on {:} is {:} bytes, {:.2f} KB, {:.2f} MB, {:.2f} GB.]'.format(next(network.parameters()).device, int(num_bytes), num_bytes / 1e3, num_bytes / 1e6, num_bytes / 1e9)) | ||||
|       max_bytes[epoch] = num_bytes | ||||
|     if epoch % 10 == 0: torch.cuda.empty_cache() | ||||
|  | ||||
|     # save checkpoint | ||||
|     save_path = save_checkpoint({ | ||||
|           'epoch'        : epoch, | ||||
|           'args'         : deepcopy(args), | ||||
|           'max_bytes'    : deepcopy(max_bytes), | ||||
|           'FLOP'         : flop, | ||||
|           'PARAM'        : param, | ||||
|           'valid_accuracies': deepcopy(valid_accuracies), | ||||
|           'model-config' : model_config._asdict(), | ||||
|           'optim-config' : optim_config._asdict(), | ||||
|           'base-model'   : base_model.state_dict(), | ||||
|           'scheduler'    : scheduler.state_dict(), | ||||
|           'optimizer'    : optimizer.state_dict(), | ||||
|           }, model_base_path, logger) | ||||
|     if find_best: copy_checkpoint(model_base_path, model_best_path, logger) | ||||
|     last_info = save_checkpoint({ | ||||
|           'epoch': epoch, | ||||
|           'args' : deepcopy(args), | ||||
|           'last_checkpoint': save_path, | ||||
|           }, logger.path('info'), logger) | ||||
|  | ||||
|     # measure elapsed time | ||||
|     epoch_time.update(time.time() - start_time) | ||||
|     start_time = time.time() | ||||
|  | ||||
|   logger.log('\n' + '-'*200) | ||||
|   logger.log('Finish training/validation in {:} with Max-GPU-Memory of {:.2f} MB, and save final checkpoint into {:}'.format(convert_secs2time(epoch_time.sum, True), max(v for k, v in max_bytes.items()) / 1e6, logger.path('info'))) | ||||
|   logger.log('-'*200 + '\n') | ||||
|   logger.close() | ||||
|  | ||||
|  | ||||
| if __name__ == '__main__': | ||||
|   args = obtain_args() | ||||
|   main(args) | ||||
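A rough invocation sketch for the plain training entry point. The actual flag spellings are defined by obtain_basic_args, which is not part of this diff, so the names below simply mirror the attributes accessed in main() and all paths and values are placeholders:

python exps/basic-main.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python \
    --model_config ./configs/archs/CIFAR.config --optim_config ./configs/opts/CIFAR.config \
    --procedure basic --model_source normal --cutout_length -1 \
    --batch_size 256 --workers 4 --eval_frequency 1 --print_freq 100 --print_freq_eval 200 --rand_seed 777

As checked at the top of main(), --model_source accepts 'normal' (obtain_model) or 'nas' (obtain_nas_infer_model).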
							
								
								
									
exps/compare.py (new file, 78 lines)
							| @@ -0,0 +1,78 @@ | ||||
| ################################################## | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 # | ||||
| ################################################## | ||||
| # python exps/compare.py --checkpoints basic.pth order.pth --names basic order --save ./output/vis/basic-vs-order.pdf | ||||
| import sys, time, torch, random, argparse | ||||
| from PIL     import ImageFile | ||||
| ImageFile.LOAD_TRUNCATED_IMAGES = True | ||||
| from copy    import deepcopy | ||||
| from pathlib import Path | ||||
| import numpy as np | ||||
| import matplotlib | ||||
| matplotlib.use('agg') | ||||
| import matplotlib.pyplot as plt | ||||
|  | ||||
| lib_dir = (Path(__file__).parent / '..' / 'lib').resolve() | ||||
| if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) | ||||
|  | ||||
| parser = argparse.ArgumentParser(description='Visualize the checkpoint and compare', formatter_class=argparse.ArgumentDefaultsHelpFormatter) | ||||
| parser.add_argument('--checkpoints', type=str,    nargs='+',     help='checkpoint paths.') | ||||
| parser.add_argument('--names',       type=str,    nargs='+',     help='names.') | ||||
| parser.add_argument('--save',        type=str,                   help='the save path.') | ||||
| args = parser.parse_args() | ||||
|  | ||||
|  | ||||
| def visualize_acc(epochs, accuracies, names, save_path): | ||||
|  | ||||
|   LabelSize = 24 | ||||
|   LegendFontsize = 22 | ||||
|   matplotlib.rcParams['xtick.labelsize'] = LabelSize  | ||||
|   matplotlib.rcParams['ytick.labelsize'] = LabelSize  | ||||
|   color_set = ['r', 'b', 'g', 'c', 'm', 'y', 'k'] | ||||
|   dpi = 300 | ||||
|   width, height = 3400, 3600 | ||||
|   figsize = width / float(dpi), height / float(dpi) | ||||
|  | ||||
|   fig = plt.figure(figsize=figsize) | ||||
|   plt.xlim(0, max(epochs)) | ||||
|   plt.ylim(0, 100) | ||||
|   interval_x, interval_y = 20, 10 | ||||
|   plt.xticks(np.arange(0, max(epochs) + interval_x, interval_x), fontsize=LegendFontsize) | ||||
|   plt.yticks(np.arange(0, 100 + interval_y, interval_y), fontsize=LegendFontsize) | ||||
|   plt.grid() | ||||
|    | ||||
|   plt.xlabel('epoch', fontsize=16) | ||||
|   plt.ylabel('accuracy (%)', fontsize=16) | ||||
|  | ||||
|   for idx, tag in enumerate(names): | ||||
|     xaccs = [accuracies[idx][x] for x in epochs] | ||||
|     plt.plot(epochs, xaccs, color=color_set[idx], linestyle='-', label='Test Accuracy : {:}'.format(tag), lw=3) | ||||
|     plt.legend(loc=4, fontsize=LegendFontsize) | ||||
|    | ||||
|   if save_path is not None: | ||||
|     fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='pdf') | ||||
|     print ('---- save figure into {:}.'.format(save_path)) | ||||
|   plt.close(fig) | ||||
|  | ||||
|  | ||||
| def main(): | ||||
|   checkpoints, names = args.checkpoints, args.names | ||||
|   assert len(checkpoints) == len(names), 'invalid length : {:} vs {:}'.format(len(checkpoints), len(names)) | ||||
|   for i, checkpoint in enumerate(checkpoints): | ||||
|     assert Path(checkpoint).exists(), 'The {:}-th checkpoint : {:} does not exist'.format(i, checkpoint) | ||||
|  | ||||
|   save_path = Path(args.save) | ||||
|   save_dir  = save_path.parent | ||||
|   save_dir.mkdir(parents=True, exist_ok=True) | ||||
|   accuracies = [] | ||||
|   for checkpoint in checkpoints: | ||||
|     checkpoint = torch.load( checkpoint ) | ||||
|     accuracies.append( checkpoint['valid_accuracies'] ) | ||||
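|   # 'valid_accuracies' uses integer epoch keys plus a 'best' entry (see the training scripts); keep only the per-epoch records | ||||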
|   epochs = [x for x in accuracies[0].keys() if isinstance(x, int)] | ||||
|   epochs = sorted( epochs ) | ||||
|     | ||||
|   visualize_acc(epochs, accuracies, names, save_path) | ||||
|  | ||||
|  | ||||
| if __name__ == '__main__': | ||||
|   main() | ||||
							
								
								
									
exps/prepare.py (new file, 77 lines)
							| @@ -0,0 +1,77 @@ | ||||
| # python exps/prepare.py --name cifar10     --root $TORCH_HOME/cifar.python --save ./data/cifar10.split.pth | ||||
| # python exps/prepare.py --name cifar100    --root $TORCH_HOME/cifar.python --save ./data/cifar100.split.pth | ||||
| # python exps/prepare.py --name imagenet-1k --root $TORCH_HOME/ILSVRC2012   --save ./data/imagenet-1k.split.pth | ||||
| import sys, time, torch, random, argparse | ||||
| from collections import defaultdict | ||||
| import os.path as osp | ||||
| from PIL     import ImageFile | ||||
| ImageFile.LOAD_TRUNCATED_IMAGES = True | ||||
| from copy    import deepcopy | ||||
| from pathlib import Path | ||||
| import torchvision | ||||
| import torchvision.datasets as dset | ||||
|  | ||||
| lib_dir = (Path(__file__).parent / '..' / 'lib').resolve() | ||||
| if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) | ||||
| parser = argparse.ArgumentParser(description='Prepare splits for searching', formatter_class=argparse.ArgumentDefaultsHelpFormatter) | ||||
| parser.add_argument('--name' , type=str,    help='The dataset name.') | ||||
| parser.add_argument('--root' , type=str,    help='The directory to the dataset.') | ||||
| parser.add_argument('--save' , type=str,    help='The save path.') | ||||
| parser.add_argument('--ratio', type=float,  help='The ratio of samples (per class) assigned to the train split.') | ||||
| args = parser.parse_args() | ||||
|  | ||||
| def main(): | ||||
|   save_path = Path(args.save) | ||||
|   save_dir  = save_path.parent | ||||
|   name      = args.name | ||||
|   save_dir.mkdir(parents=True, exist_ok=True) | ||||
|   assert not save_path.exists(), '{:} already exists'.format(save_path) | ||||
|   print ('torchvision version : {:}'.format(torchvision.__version__)) | ||||
|  | ||||
|   if name == 'cifar10': | ||||
|     dataset = dset.CIFAR10 (args.root, train=True) | ||||
|   elif name == 'cifar100': | ||||
|     dataset = dset.CIFAR100(args.root, train=True) | ||||
|   elif name == 'imagenet-1k': | ||||
|     dataset = dset.ImageFolder(osp.join(args.root, 'train')) | ||||
|   else: raise TypeError("Unknow dataset : {:}".format(name)) | ||||
|  | ||||
|   if hasattr(dataset, 'targets'): | ||||
|     targets = dataset.targets | ||||
|   elif hasattr(dataset, 'train_labels'): | ||||
|     targets = dataset.train_labels | ||||
|   elif hasattr(dataset, 'imgs'): | ||||
|     targets = [x[1] for x in dataset.imgs] | ||||
|   else: | ||||
|     raise ValueError('invalid pattern') | ||||
|   print ('There are {:} samples in this dataset.'.format( len(targets) )) | ||||
|  | ||||
|   class2index = defaultdict(list) | ||||
|   train, valid = [], [] | ||||
|   random.seed(111) | ||||
|   for index, cls in enumerate(targets): | ||||
|     class2index[cls].append( index ) | ||||
|   classes = sorted( list(class2index.keys()) ) | ||||
|   for cls in classes: | ||||
|     xlist = class2index[cls] | ||||
|     xtrain = random.sample(xlist, int(len(xlist)*args.ratio)) | ||||
|     xvalid = list(set(xlist) - set(xtrain)) | ||||
|     train += xtrain | ||||
|     valid += xvalid | ||||
|   train.sort() | ||||
|   valid.sort() | ||||
|   ## for statistics | ||||
|   class2numT, class2numV = defaultdict(int), defaultdict(int) | ||||
|   for index in train: | ||||
|     class2numT[ targets[index] ] += 1 | ||||
|   for index in valid: | ||||
|     class2numV[ targets[index] ] += 1 | ||||
|   class2numT, class2numV = dict(class2numT), dict(class2numV) | ||||
|   torch.save({'train': train, | ||||
|               'valid': valid, | ||||
|               'class2numTrain': class2numT, | ||||
|               'class2numValid': class2numV}, save_path) | ||||
|   print ('-'*80) | ||||
|  | ||||
| if __name__ == '__main__': | ||||
|   main() | ||||
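Note that main() sizes the per-class train split as int(len(xlist) * args.ratio), so --ratio must be passed explicitly even though the example commands at the top of this file omit it. A sketch assuming an even 50/50 split:

python exps/prepare.py --name cifar10 --root $TORCH_HOME/cifar.python --save ./data/cifar10.split.pth --ratio 0.5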
							
								
								
									
exps/search-shape.py (new file, 201 lines)
							| @@ -0,0 +1,201 @@ | ||||
| ################################################## | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 # | ||||
| ################################################## | ||||
| import sys, time, torch, random, argparse | ||||
| from PIL     import ImageFile | ||||
| from os      import path as osp | ||||
| ImageFile.LOAD_TRUNCATED_IMAGES = True | ||||
| import numpy as np | ||||
| from copy    import deepcopy | ||||
| from pathlib import Path | ||||
|  | ||||
| lib_dir = (Path(__file__).parent / '..' / 'lib').resolve() | ||||
| if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) | ||||
| from config_utils import load_config, configure2str, obtain_search_single_args as obtain_args | ||||
| from procedures   import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint | ||||
| from procedures   import get_optim_scheduler, get_procedures | ||||
| from datasets     import get_datasets, SearchDataset | ||||
| from models       import obtain_search_model, obtain_model, change_key | ||||
| from utils        import get_model_infos | ||||
| from log_utils    import AverageMeter, time_string, convert_secs2time | ||||
|  | ||||
|  | ||||
| def main(args): | ||||
|   assert torch.cuda.is_available(), 'CUDA is not available.' | ||||
|   torch.backends.cudnn.enabled   = True | ||||
|   torch.backends.cudnn.benchmark = True | ||||
|   #torch.backends.cudnn.deterministic = True | ||||
|   torch.set_num_threads( args.workers ) | ||||
|    | ||||
|   prepare_seed(args.rand_seed) | ||||
|   logger = prepare_logger(args) | ||||
|    | ||||
|   # prepare dataset | ||||
|   train_data, valid_data, xshape, class_num = get_datasets(args.dataset, args.data_path, args.cutout_length) | ||||
|   #train_loader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size, shuffle=True , num_workers=args.workers, pin_memory=True) | ||||
|   valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=args.batch_size, shuffle=False, num_workers=args.workers, pin_memory=True) | ||||
|  | ||||
|   split_file_path = Path(args.split_path) | ||||
|   assert split_file_path.exists(), '{:} does not exist'.format(split_file_path) | ||||
|   split_info      = torch.load(split_file_path) | ||||
|  | ||||
|   train_split, valid_split = split_info['train'], split_info['valid'] | ||||
|   assert len( set(train_split).intersection( set(valid_split) ) ) == 0, 'There should be 0 element that belongs to both train and valid' | ||||
|   assert len(train_split) + len(valid_split) == len(train_data), '{:} + {:} vs {:}'.format(len(train_split), len(valid_split), len(train_data)) | ||||
|   search_dataset  = SearchDataset(args.dataset, train_data, train_split, valid_split) | ||||
|    | ||||
|   search_train_loader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size, | ||||
|                       sampler=torch.utils.data.sampler.SubsetRandomSampler(train_split), pin_memory=True, num_workers=args.workers) | ||||
|   search_valid_loader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size, | ||||
|                       sampler=torch.utils.data.sampler.SubsetRandomSampler(valid_split), pin_memory=True, num_workers=args.workers) | ||||
|   search_loader       = torch.utils.data.DataLoader(search_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.workers, pin_memory=True, sampler=None) | ||||
|   # get configures | ||||
|   model_config = load_config(args.model_config, {'class_num': class_num, 'search_mode': args.search_shape}, logger) | ||||
|  | ||||
|   # obtain the model | ||||
|   search_model = obtain_search_model(model_config) | ||||
|   MAX_FLOP, param  = get_model_infos(search_model, xshape) | ||||
|   optim_config = load_config(args.optim_config, {'class_num': class_num, 'FLOP': MAX_FLOP}, logger) | ||||
|   logger.log('Model Information : {:}'.format(search_model.get_message())) | ||||
|   logger.log('MAX_FLOP = {:} M'.format(MAX_FLOP)) | ||||
|   logger.log('Params   = {:} M'.format(param)) | ||||
|   logger.log('train_data : {:}'.format(train_data)) | ||||
|   logger.log('search-data: {:}'.format(search_dataset)) | ||||
|   logger.log('search_train_loader : {:} samples'.format( len(train_split) )) | ||||
|   logger.log('search_valid_loader : {:} samples'.format( len(valid_split) )) | ||||
|   base_optimizer, scheduler, criterion = get_optim_scheduler(search_model.base_parameters(), optim_config) | ||||
|   arch_optimizer = torch.optim.Adam(search_model.arch_parameters(), lr=optim_config.arch_LR, betas=(0.5, 0.999), weight_decay=optim_config.arch_decay) | ||||
|   logger.log('base-optimizer : {:}'.format(base_optimizer)) | ||||
|   logger.log('arch-optimizer : {:}'.format(arch_optimizer)) | ||||
|   logger.log('scheduler      : {:}'.format(scheduler)) | ||||
|   logger.log('criterion      : {:}'.format(criterion)) | ||||
|    | ||||
|   last_info, model_base_path, model_best_path = logger.path('info'), logger.path('model'), logger.path('best') | ||||
|   network, criterion = torch.nn.DataParallel(search_model).cuda(), criterion.cuda() | ||||
|  | ||||
|   # load checkpoint | ||||
|   if last_info.exists() or (args.resume is not None and osp.isfile(args.resume)): # automatically resume from previous checkpoint | ||||
|     if args.resume is not None and osp.isfile(args.resume): | ||||
|       resume_path = Path(args.resume) | ||||
|     elif last_info.exists(): | ||||
|       resume_path = last_info | ||||
|     else: raise ValueError('Something is wrong.') | ||||
|     logger.log("=> loading checkpoint of the last-info '{:}' start".format(resume_path)) | ||||
|     checkpoint  = torch.load(resume_path) | ||||
|     if 'last_checkpoint' in checkpoint: | ||||
|       last_checkpoint_path = checkpoint['last_checkpoint'] | ||||
|       if not last_checkpoint_path.exists(): | ||||
|         logger.log('Did not find {:}, trying another path'.format(last_checkpoint_path)) | ||||
|         last_checkpoint_path = resume_path.parent / last_checkpoint_path.parent.name / last_checkpoint_path.name | ||||
|       assert last_checkpoint_path.exists(), 'can not find the checkpoint from {:}'.format(last_checkpoint_path) | ||||
|       checkpoint = torch.load( last_checkpoint_path ) | ||||
|     start_epoch = checkpoint['epoch'] + 1 | ||||
|     search_model.load_state_dict( checkpoint['search_model'] ) | ||||
|     scheduler.load_state_dict ( checkpoint['scheduler'] ) | ||||
|     base_optimizer.load_state_dict ( checkpoint['base_optimizer'] ) | ||||
|     arch_optimizer.load_state_dict ( checkpoint['arch_optimizer'] ) | ||||
|     valid_accuracies = checkpoint['valid_accuracies'] | ||||
|     arch_genotypes   = checkpoint['arch_genotypes'] | ||||
|     discrepancies    = checkpoint['discrepancies'] | ||||
|     logger.log("=> loading checkpoint of the last-info '{:}' start with {:}-th epoch.".format(resume_path, start_epoch)) | ||||
|   else: | ||||
|     logger.log("=> do not find the last-info file : {:} or resume : {:}".format(last_info, args.resume)) | ||||
|     start_epoch, valid_accuracies, arch_genotypes, discrepancies = 0, {'best': -1}, {}, {} | ||||
|  | ||||
|   # main procedure | ||||
|   train_func, valid_func = get_procedures(args.procedure) | ||||
|   total_epoch = optim_config.epochs + optim_config.warmup | ||||
|   start_time, epoch_time = time.time(), AverageMeter() | ||||
|   for epoch in range(start_epoch, total_epoch): | ||||
|     scheduler.update(epoch, 0.0) | ||||
|     search_model.set_tau(args.gumbel_tau_max, args.gumbel_tau_min, epoch*1.0/total_epoch) | ||||
|     need_time = 'Time Left: {:}'.format( convert_secs2time(epoch_time.avg * (total_epoch-epoch), True) ) | ||||
|     epoch_str = 'epoch={:03d}/{:03d}'.format(epoch, total_epoch) | ||||
|     LRs       = scheduler.get_lr() | ||||
|     find_best = False | ||||
|     | ||||
|     logger.log('\n***{:s}*** start {:s} {:s}, LR=[{:.6f} ~ {:.6f}], scheduler={:}, tau={:}, FLOP={:.2f}'.format(time_string(), epoch_str, need_time, min(LRs), max(LRs), scheduler, search_model.tau, MAX_FLOP)) | ||||
|  | ||||
|     # train for one epoch | ||||
|     train_base_loss, train_arch_loss, train_acc1, train_acc5 = train_func(search_loader, network, criterion, scheduler, base_optimizer, arch_optimizer, optim_config, \ | ||||
|                                                                                 {'epoch-str'  : epoch_str,        'FLOP-exp': MAX_FLOP * args.FLOP_ratio, | ||||
|                                                                                  'FLOP-weight': args.FLOP_weight, 'FLOP-tolerant': MAX_FLOP * args.FLOP_tolerant}, args.print_freq, logger) | ||||
|     # log the results | ||||
|     logger.log('***{:s}*** TRAIN [{:}] base-loss = {:.6f}, arch-loss = {:.6f}, accuracy-1 = {:.2f}, accuracy-5 = {:.2f}'.format(time_string(), epoch_str, train_base_loss, train_arch_loss, train_acc1, train_acc5)) | ||||
|     cur_FLOP, genotype = search_model.get_flop('genotype', model_config._asdict(), None) | ||||
|     arch_genotypes[epoch]  = genotype | ||||
|     arch_genotypes['last'] = genotype | ||||
|     logger.log('[{:}] genotype : {:}'.format(epoch_str, genotype)) | ||||
|     arch_info, discrepancy = search_model.get_arch_info() | ||||
|     logger.log(arch_info) | ||||
|     discrepancies[epoch]   = discrepancy | ||||
|     logger.log('[{:}] FLOP : {:.2f} M, ratio : {:.4f}, Expected-ratio : {:.4f}, Discrepancy : {:.3f}'.format(epoch_str, cur_FLOP, cur_FLOP/MAX_FLOP, args.FLOP_ratio, np.mean(discrepancy))) | ||||
|  | ||||
|     #if cur_FLOP/MAX_FLOP > args.FLOP_ratio: | ||||
|     #  init_flop_weight = init_flop_weight * args.FLOP_decay | ||||
|     #else: | ||||
|     #  init_flop_weight = init_flop_weight / args.FLOP_decay | ||||
|      | ||||
|     # evaluate the performance | ||||
|     if (epoch % args.eval_frequency == 0) or (epoch + 1 == total_epoch): | ||||
|       logger.log('-'*150) | ||||
|       valid_loss, valid_acc1, valid_acc5 = valid_func(search_valid_loader, network, criterion, epoch_str, args.print_freq_eval, logger) | ||||
|       valid_accuracies[epoch] = valid_acc1 | ||||
|       logger.log('***{:s}*** VALID [{:}] loss = {:.6f}, accuracy@1 = {:.2f}, accuracy@5 = {:.2f} | Best-Valid-Acc@1={:.2f}, Error@1={:.2f}'.format(time_string(), epoch_str, valid_loss, valid_acc1, valid_acc5, valid_accuracies['best'], 100-valid_accuracies['best'])) | ||||
|       if valid_acc1 > valid_accuracies['best']: | ||||
|         valid_accuracies['best'] = valid_acc1 | ||||
|         arch_genotypes['best']   = genotype | ||||
|         find_best                = True | ||||
|         logger.log('Currently, the best validation accuracy found at {:03d}-epoch :: acc@1={:.2f}, acc@5={:.2f}, error@1={:.2f}, error@5={:.2f}, save into {:}.'.format(epoch, valid_acc1, valid_acc5, 100-valid_acc1, 100-valid_acc5, model_best_path)) | ||||
|  | ||||
|     # save checkpoint | ||||
|     save_path = save_checkpoint({ | ||||
|           'epoch'        : epoch, | ||||
|           'args'         : deepcopy(args), | ||||
|           'valid_accuracies': deepcopy(valid_accuracies), | ||||
|           'model-config' : model_config._asdict(), | ||||
|           'optim-config' : optim_config._asdict(), | ||||
|           'search_model' : search_model.state_dict(), | ||||
|           'scheduler'    : scheduler.state_dict(), | ||||
|           'base_optimizer': base_optimizer.state_dict(), | ||||
|           'arch_optimizer': arch_optimizer.state_dict(), | ||||
|           'arch_genotypes': arch_genotypes, | ||||
|           'discrepancies' : discrepancies, | ||||
|           }, model_base_path, logger) | ||||
|     if find_best: copy_checkpoint(model_base_path, model_best_path, logger) | ||||
|     last_info = save_checkpoint({ | ||||
|           'epoch': epoch, | ||||
|           'args' : deepcopy(args), | ||||
|           'last_checkpoint': save_path, | ||||
|           }, logger.path('info'), logger) | ||||
|  | ||||
|     # measure elapsed time | ||||
|     epoch_time.update(time.time() - start_time) | ||||
|     start_time = time.time() | ||||
|      | ||||
|  | ||||
|   logger.log('') | ||||
|   logger.log('-'*100) | ||||
|   last_config_path = logger.path('log') / 'seed-{:}-last.config'.format(args.rand_seed) | ||||
|   configure2str(arch_genotypes['last'], str(last_config_path)) | ||||
|   logger.log('save the last config into {:} :\n{:}'.format(last_config_path, arch_genotypes['last'])) | ||||
|  | ||||
|   best_arch, valid_acc = arch_genotypes['best'], valid_accuracies['best'] | ||||
|   for key, config in arch_genotypes.items(): | ||||
|     if key == 'last': continue | ||||
|     FLOP_ratio = config['estimated_FLOP'] / MAX_FLOP | ||||
|     if abs(FLOP_ratio - args.FLOP_ratio) <= args.FLOP_tolerant: | ||||
|       if valid_acc < valid_accuracies[key]: | ||||
|         best_arch, valid_acc = config, valid_accuracies[key] | ||||
|   print('Best-Arch : {:}\nRatio={:}, Valid-ACC={:}'.format(best_arch, best_arch['estimated_FLOP'] / MAX_FLOP, valid_acc)) | ||||
|   best_config_path = logger.path('log') / 'seed-{:}-best.config'.format(args.rand_seed) | ||||
|   configure2str(best_arch, str(best_config_path)) | ||||
|   logger.log('save the best config into {:} :\n{:}'.format(best_config_path, best_arch)) | ||||
|   logger.log('\n' + '-'*200) | ||||
|   logger.log('Finish training/validation in {:}, and save final checkpoint into {:}'.format(convert_secs2time(epoch_time.sum, True), logger.path('info'))) | ||||
|   logger.close() | ||||
|  | ||||
|  | ||||
| if __name__ == '__main__': | ||||
|   args = obtain_args() | ||||
|   main(args) | ||||
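The file loaded via args.split_path is exactly the dict written by exps/prepare.py above (integer index lists under 'train' and 'valid'). The concrete flag spellings come from obtain_search_single_args, which is outside this diff, so the sketch below only mirrors the attribute names used in main() and every value is a placeholder:

python exps/search-shape.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python \
    --split_path ./data/cifar10.split.pth --search_shape depth \
    --model_config ./configs/search-shape.config --optim_config ./configs/search-opt.config \
    --procedure search --FLOP_ratio 0.5 --FLOP_weight 2 --FLOP_tolerant 0.05 \
    --gumbel_tau_max 5 --gumbel_tau_min 0.1 --cutout_length -1 \
    --batch_size 256 --workers 4 --eval_frequency 1 --print_freq 100 --print_freq_eval 200 --rand_seed 777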
							
								
								
									
exps/search-transformable.py (new file, 215 lines)
							| @@ -0,0 +1,215 @@ | ||||
| ################################################## | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 # | ||||
| ################################################## | ||||
| import sys, time, torch, random, argparse | ||||
| from PIL     import ImageFile | ||||
| from os      import path as osp | ||||
| ImageFile.LOAD_TRUNCATED_IMAGES = True | ||||
| import numpy as np | ||||
| from copy    import deepcopy | ||||
| from pathlib import Path | ||||
|  | ||||
| lib_dir = (Path(__file__).parent / '..' / 'lib').resolve() | ||||
| if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) | ||||
| from config_utils import load_config, configure2str, obtain_search_args as obtain_args | ||||
| from procedures   import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint | ||||
| from procedures   import get_optim_scheduler, get_procedures | ||||
| from datasets     import get_datasets, SearchDataset | ||||
| from models       import obtain_search_model, obtain_model, change_key | ||||
| from utils        import get_model_infos | ||||
| from log_utils    import AverageMeter, time_string, convert_secs2time | ||||
|  | ||||
|  | ||||
| def main(args): | ||||
|   assert torch.cuda.is_available(), 'CUDA is not available.' | ||||
|   torch.backends.cudnn.enabled   = True | ||||
|   torch.backends.cudnn.benchmark = True | ||||
|   #torch.backends.cudnn.deterministic = True | ||||
|   torch.set_num_threads( args.workers ) | ||||
|    | ||||
|   prepare_seed(args.rand_seed) | ||||
|   logger = prepare_logger(args) | ||||
|    | ||||
|   # prepare dataset | ||||
|   train_data, valid_data, xshape, class_num = get_datasets(args.dataset, args.data_path, args.cutout_length) | ||||
|   #train_loader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size, shuffle=True , num_workers=args.workers, pin_memory=True) | ||||
|   valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=args.batch_size, shuffle=False, num_workers=args.workers, pin_memory=True) | ||||
|  | ||||
|   split_file_path = Path(args.split_path) | ||||
|   assert split_file_path.exists(), '{:} does not exist'.format(split_file_path) | ||||
|   split_info      = torch.load(split_file_path) | ||||
|  | ||||
|   train_split, valid_split = split_info['train'], split_info['valid'] | ||||
|   assert len( set(train_split).intersection( set(valid_split) ) ) == 0, 'There should be 0 element that belongs to both train and valid' | ||||
|   assert len(train_split) + len(valid_split) == len(train_data), '{:} + {:} vs {:}'.format(len(train_split), len(valid_split), len(train_data)) | ||||
|   search_dataset  = SearchDataset(args.dataset, train_data, train_split, valid_split) | ||||
|    | ||||
|   search_train_loader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size, | ||||
|                       sampler=torch.utils.data.sampler.SubsetRandomSampler(train_split), pin_memory=True, num_workers=args.workers) | ||||
|   search_valid_loader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size, | ||||
|                       sampler=torch.utils.data.sampler.SubsetRandomSampler(valid_split), pin_memory=True, num_workers=args.workers) | ||||
|   search_loader       = torch.utils.data.DataLoader(search_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.workers, pin_memory=True, sampler=None) | ||||
|   # get configures | ||||
|   if args.ablation_num_select is None or args.ablation_num_select <= 0: | ||||
|     model_config = load_config(args.model_config, {'class_num': class_num, 'search_mode': 'shape'}, logger) | ||||
|   else: | ||||
|     model_config = load_config(args.model_config, {'class_num': class_num, 'search_mode': 'ablation', 'num_random_select': args.ablation_num_select}, logger) | ||||
|  | ||||
|   # obtain the model | ||||
|   search_model = obtain_search_model(model_config) | ||||
|   MAX_FLOP, param  = get_model_infos(search_model, xshape) | ||||
|   optim_config = load_config(args.optim_config, {'class_num': class_num, 'FLOP': MAX_FLOP}, logger) | ||||
|   logger.log('Model Information : {:}'.format(search_model.get_message())) | ||||
|   logger.log('MAX_FLOP = {:} M'.format(MAX_FLOP)) | ||||
|   logger.log('Params   = {:} M'.format(param)) | ||||
|   logger.log('train_data : {:}'.format(train_data)) | ||||
|   logger.log('search-data: {:}'.format(search_dataset)) | ||||
|   logger.log('search_train_loader : {:} samples'.format( len(train_split) )) | ||||
|   logger.log('search_valid_loader : {:} samples'.format( len(valid_split) )) | ||||
|   base_optimizer, scheduler, criterion = get_optim_scheduler(search_model.base_parameters(), optim_config) | ||||
|   arch_optimizer = torch.optim.Adam(search_model.arch_parameters(optim_config.arch_LR), lr=optim_config.arch_LR, betas=(0.5, 0.999), weight_decay=optim_config.arch_decay) | ||||
|   logger.log('base-optimizer : {:}'.format(base_optimizer)) | ||||
|   logger.log('arch-optimizer : {:}'.format(arch_optimizer)) | ||||
|   logger.log('scheduler      : {:}'.format(scheduler)) | ||||
|   logger.log('criterion      : {:}'.format(criterion)) | ||||
|    | ||||
|   last_info, model_base_path, model_best_path = logger.path('info'), logger.path('model'), logger.path('best') | ||||
|   network, criterion = torch.nn.DataParallel(search_model).cuda(), criterion.cuda() | ||||
|  | ||||
|   # load checkpoint | ||||
|   if last_info.exists() or (args.resume is not None and osp.isfile(args.resume)): # automatically resume from previous checkpoint | ||||
|     if args.resume is not None and osp.isfile(args.resume): | ||||
|       resume_path = Path(args.resume) | ||||
|     elif last_info.exists(): | ||||
|       resume_path = last_info | ||||
|     else: raise ValueError('Something is wrong.') | ||||
|     logger.log("=> loading checkpoint of the last-info '{:}' start".format(resume_path)) | ||||
|     checkpoint  = torch.load(resume_path) | ||||
|     if 'last_checkpoint' in checkpoint: | ||||
|       last_checkpoint_path = checkpoint['last_checkpoint'] | ||||
|       if not last_checkpoint_path.exists(): | ||||
|         logger.log('Did not find {:}, trying another path'.format(last_checkpoint_path)) | ||||
|         last_checkpoint_path = resume_path.parent / last_checkpoint_path.parent.name / last_checkpoint_path.name | ||||
|       assert last_checkpoint_path.exists(), 'can not find the checkpoint from {:}'.format(last_checkpoint_path) | ||||
|       checkpoint = torch.load( last_checkpoint_path ) | ||||
|     start_epoch = checkpoint['epoch'] + 1 | ||||
|     #for key, value in checkpoint['search_model'].items(): | ||||
|     #  print('K {:} = Shape={:}'.format(key, value.shape)) | ||||
|     search_model.load_state_dict( checkpoint['search_model'] ) | ||||
|     scheduler.load_state_dict ( checkpoint['scheduler'] ) | ||||
|     base_optimizer.load_state_dict ( checkpoint['base_optimizer'] ) | ||||
|     arch_optimizer.load_state_dict ( checkpoint['arch_optimizer'] ) | ||||
|     valid_accuracies = checkpoint['valid_accuracies'] | ||||
|     arch_genotypes   = checkpoint['arch_genotypes'] | ||||
|     discrepancies    = checkpoint['discrepancies'] | ||||
|     max_bytes        = checkpoint['max_bytes'] | ||||
|     logger.log("=> loading checkpoint of the last-info '{:}' start with {:}-th epoch.".format(resume_path, start_epoch)) | ||||
|   else: | ||||
|     logger.log("=> do not find the last-info file : {:} or resume : {:}".format(last_info, args.resume)) | ||||
|     start_epoch, valid_accuracies, arch_genotypes, discrepancies, max_bytes = 0, {'best': -1}, {}, {}, {} | ||||
|  | ||||
|   # main procedure | ||||
|   train_func, valid_func = get_procedures(args.procedure) | ||||
|   total_epoch = optim_config.epochs + optim_config.warmup | ||||
|   start_time, epoch_time = time.time(), AverageMeter() | ||||
|   for epoch in range(start_epoch, total_epoch): | ||||
|     scheduler.update(epoch, 0.0) | ||||
|     search_model.set_tau(args.gumbel_tau_max, args.gumbel_tau_min, epoch*1.0/total_epoch) | ||||
|     need_time = 'Time Left: {:}'.format( convert_secs2time(epoch_time.avg * (total_epoch-epoch), True) ) | ||||
|     epoch_str = 'epoch={:03d}/{:03d}'.format(epoch, total_epoch) | ||||
|     LRs       = scheduler.get_lr() | ||||
|     find_best = False | ||||
|     | ||||
|     logger.log('\n***{:s}*** start {:s} {:s}, LR=[{:.6f} ~ {:.6f}], scheduler={:}, tau={:}, FLOP={:.2f}'.format(time_string(), epoch_str, need_time, min(LRs), max(LRs), scheduler, search_model.tau, MAX_FLOP)) | ||||
|  | ||||
|     # train for one epoch | ||||
|     train_base_loss, train_arch_loss, train_acc1, train_acc5 = train_func(search_loader, network, criterion, scheduler, base_optimizer, arch_optimizer, optim_config, \ | ||||
|                                                                                 {'epoch-str'  : epoch_str,        'FLOP-exp': MAX_FLOP * args.FLOP_ratio, | ||||
|                                                                                  'FLOP-weight': args.FLOP_weight, 'FLOP-tolerant': MAX_FLOP * args.FLOP_tolerant}, args.print_freq, logger) | ||||
|     # log the results | ||||
|     logger.log('***{:s}*** TRAIN [{:}] base-loss = {:.6f}, arch-loss = {:.6f}, accuracy-1 = {:.2f}, accuracy-5 = {:.2f}'.format(time_string(), epoch_str, train_base_loss, train_arch_loss, train_acc1, train_acc5)) | ||||
|     cur_FLOP, genotype = search_model.get_flop('genotype', model_config._asdict(), None) | ||||
|     arch_genotypes[epoch]  = genotype | ||||
|     arch_genotypes['last'] = genotype | ||||
|     logger.log('[{:}] genotype : {:}'.format(epoch_str, genotype)) | ||||
|     # save the configuration | ||||
|     configure2str(genotype, str( logger.path('log') / 'seed-{:}-temp.config'.format(args.rand_seed) )) | ||||
|     arch_info, discrepancy = search_model.get_arch_info() | ||||
|     logger.log(arch_info) | ||||
|     discrepancies[epoch]   = discrepancy | ||||
|     logger.log('[{:}] FLOP : {:.2f} M, ratio : {:.4f}, Expected-ratio : {:.4f}, Discrepancy : {:.3f}'.format(epoch_str, cur_FLOP, cur_FLOP/MAX_FLOP, args.FLOP_ratio, np.mean(discrepancy))) | ||||
|  | ||||
|     #if cur_FLOP/MAX_FLOP > args.FLOP_ratio: | ||||
|     #  init_flop_weight = init_flop_weight * args.FLOP_decay | ||||
|     #else: | ||||
|     #  init_flop_weight = init_flop_weight / args.FLOP_decay | ||||
|      | ||||
|     # evaluate the performance | ||||
|     if (epoch % args.eval_frequency == 0) or (epoch + 1 == total_epoch): | ||||
|       logger.log('-'*150) | ||||
|       valid_loss, valid_acc1, valid_acc5 = valid_func(search_valid_loader, network, criterion, epoch_str, args.print_freq_eval, logger) | ||||
|       valid_accuracies[epoch] = valid_acc1 | ||||
|       logger.log('***{:s}*** VALID [{:}] loss = {:.6f}, accuracy@1 = {:.2f}, accuracy@5 = {:.2f} | Best-Valid-Acc@1={:.2f}, Error@1={:.2f}'.format(time_string(), epoch_str, valid_loss, valid_acc1, valid_acc5, valid_accuracies['best'], 100-valid_accuracies['best'])) | ||||
|       if valid_acc1 > valid_accuracies['best']: | ||||
|         valid_accuracies['best'] = valid_acc1 | ||||
|         arch_genotypes['best']   = genotype | ||||
|         find_best                = True | ||||
|         logger.log('Currently, the best validation accuracy found at {:03d}-epoch :: acc@1={:.2f}, acc@5={:.2f}, error@1={:.2f}, error@5={:.2f}, save into {:}.'.format(epoch, valid_acc1, valid_acc5, 100-valid_acc1, 100-valid_acc5, model_best_path)) | ||||
|       # log the GPU memory usage | ||||
|       #num_bytes = torch.cuda.max_memory_allocated( next(network.parameters()).device ) * 1.0 | ||||
|       num_bytes = torch.cuda.max_memory_cached( next(network.parameters()).device ) * 1.0 | ||||
|       logger.log('[GPU-Memory-Usage on {:} is {:} bytes, {:.2f} KB, {:.2f} MB, {:.2f} GB.]'.format(next(network.parameters()).device, int(num_bytes), num_bytes / 1e3, num_bytes / 1e6, num_bytes / 1e9)) | ||||
|       max_bytes[epoch] = num_bytes | ||||
|  | ||||
|     # save checkpoint | ||||
|     save_path = save_checkpoint({ | ||||
|           'epoch'        : epoch, | ||||
|           'args'         : deepcopy(args), | ||||
|           'max_bytes'    : deepcopy(max_bytes), | ||||
|           'valid_accuracies': deepcopy(valid_accuracies), | ||||
|           'model-config' : model_config._asdict(), | ||||
|           'optim-config' : optim_config._asdict(), | ||||
|           'search_model' : search_model.state_dict(), | ||||
|           'scheduler'    : scheduler.state_dict(), | ||||
|           'base_optimizer': base_optimizer.state_dict(), | ||||
|           'arch_optimizer': arch_optimizer.state_dict(), | ||||
|           'arch_genotypes': arch_genotypes, | ||||
|           'discrepancies' : discrepancies, | ||||
|           }, model_base_path, logger) | ||||
|     if find_best: copy_checkpoint(model_base_path, model_best_path, logger) | ||||
|     last_info = save_checkpoint({ | ||||
|           'epoch': epoch, | ||||
|           'args' : deepcopy(args), | ||||
|           'last_checkpoint': save_path, | ||||
|           }, logger.path('info'), logger) | ||||
|  | ||||
|     # measure elapsed time | ||||
|     epoch_time.update(time.time() - start_time) | ||||
|     start_time = time.time() | ||||
|      | ||||
|  | ||||
|   logger.log('') | ||||
|   logger.log('-'*100) | ||||
|   last_config_path = logger.path('log') / 'seed-{:}-last.config'.format(args.rand_seed) | ||||
|   configure2str(arch_genotypes['last'], str(last_config_path)) | ||||
|   logger.log('save the last config into {:} :\n{:}'.format(last_config_path, arch_genotypes['last'])) | ||||
|  | ||||
|   best_arch, valid_acc = arch_genotypes['best'], valid_accuracies['best'] | ||||
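|   # among all recorded genotypes whose estimated FLOP falls within the tolerance band around the target ratio, report the one with the highest validation accuracy | ||||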
|   for key, config in arch_genotypes.items(): | ||||
|     if key == 'last': continue | ||||
|     FLOP_ratio = config['estimated_FLOP'] / MAX_FLOP | ||||
|     if abs(FLOP_ratio - args.FLOP_ratio) <= args.FLOP_tolerant: | ||||
|       if valid_acc <= valid_accuracies[key]: | ||||
|         best_arch, valid_acc = config, valid_accuracies[key] | ||||
|   print('Best-Arch : {:}\nRatio={:}, Valid-ACC={:}'.format(best_arch, best_arch['estimated_FLOP'] / MAX_FLOP, valid_acc)) | ||||
|   best_config_path = logger.path('log') / 'seed-{:}-best.config'.format(args.rand_seed) | ||||
|   configure2str(best_arch, str(best_config_path)) | ||||
|   logger.log('save the best config into {:} :\n{:}'.format(best_config_path, best_arch)) | ||||
|   logger.log('\n' + '-'*200) | ||||
|   logger.log('Finish training/validation in {:} with Max-GPU-Memory of {:.2f} GB, and save final checkpoint into {:}'.format(convert_secs2time(epoch_time.sum, True), max(v for k, v in max_bytes.items()) / 1e9, logger.path('info'))) | ||||
|   logger.close() | ||||
|  | ||||
|  | ||||
| if __name__ == '__main__': | ||||
|   args = obtain_args() | ||||
|   main(args) | ||||