init
exps-cnn/acc_search_v2.py (new file, 310 lines)
							| @@ -0,0 +1,310 @@ | ||||
| import os, sys, time, glob, random, argparse | ||||
| import numpy as np | ||||
| from copy import deepcopy | ||||
| import torch | ||||
| import torch.nn.functional as F | ||||
| import torchvision.datasets as dset | ||||
| import torch.backends.cudnn as cudnn | ||||
| import torchvision.transforms as transforms | ||||
| from pathlib import Path | ||||
| lib_dir = (Path(__file__).parent / '..' / 'lib').resolve() | ||||
| if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) | ||||
| from utils import AverageMeter, time_string, convert_secs2time | ||||
| from utils import print_log, obtain_accuracy | ||||
| from utils import Cutout, count_parameters_in_MB | ||||
| from nas import Network, NetworkACC2, NetworkV3, NetworkV4, NetworkV5, NetworkFACC1 | ||||
| from nas import return_alphas_str | ||||
| from train_utils import main_procedure | ||||
| from scheduler import load_config | ||||
|  | ||||
| Networks = {'base': Network, 'acc2': NetworkACC2, 'facc1': NetworkFACC1, 'NetworkV3': NetworkV3, 'NetworkV4': NetworkV4, 'NetworkV5': NetworkV5} | ||||
|  | ||||
|  | ||||
| parser = argparse.ArgumentParser("cifar") | ||||
| parser.add_argument('--data_path',         type=str,   help='Path to dataset') | ||||
| parser.add_argument('--dataset',           type=str,   choices=['cifar10', 'cifar100'], help='Choose between CIFAR-10 and CIFAR-100.') | ||||
| parser.add_argument('--arch',              type=str,   choices=Networks.keys(),         help='Choose networks.') | ||||
| parser.add_argument('--batch_size',        type=int,   help='the batch size') | ||||
| parser.add_argument('--learning_rate_max', type=float, help='initial learning rate') | ||||
| parser.add_argument('--learning_rate_min', type=float, help='minimum learning rate') | ||||
| parser.add_argument('--tau_max',           type=float, help='initial tau') | ||||
| parser.add_argument('--tau_min',           type=float, help='minimum tau') | ||||
| parser.add_argument('--momentum',          type=float, help='momentum') | ||||
| parser.add_argument('--weight_decay',      type=float, help='weight decay') | ||||
| parser.add_argument('--epochs',            type=int,   help='num of training epochs') | ||||
| # architecture learning rate | ||||
| parser.add_argument('--arch_learning_rate', type=float, default=3e-4, help='learning rate for arch encoding') | ||||
| parser.add_argument('--arch_weight_decay',  type=float, default=1e-3, help='weight decay for arch encoding') | ||||
| # | ||||
| parser.add_argument('--init_channels',      type=int, help='num of init channels') | ||||
| parser.add_argument('--layers',             type=int, help='total number of layers') | ||||
| #  | ||||
| parser.add_argument('--cutout',         type=int,   help='cutout length, negative means no cutout') | ||||
| parser.add_argument('--grad_clip',      type=float, help='gradient clipping') | ||||
| parser.add_argument('--model_config',   type=str  , help='the model configuration') | ||||
|  | ||||
| # resume | ||||
| parser.add_argument('--resume',         type=str  , help='the resume path') | ||||
| parser.add_argument('--only_base',action='store_true', default=False, help='only train the searched model') | ||||
| # split data | ||||
| parser.add_argument('--validate', action='store_true', default=False, help='split the train data into train/val or not') | ||||
| parser.add_argument('--train_portion', type=float, default=0.5, help='portion of training data') | ||||
| # log | ||||
| parser.add_argument('--workers',       type=int, default=2, help='number of data loading workers (default: 2)') | ||||
| parser.add_argument('--save_path',     type=str, help='Folder to save checkpoints and log.') | ||||
| parser.add_argument('--print_freq',    type=int, help='print frequency (default: 200)') | ||||
| parser.add_argument('--manualSeed',    type=int, help='manual seed') | ||||
| args = parser.parse_args() | ||||
|  | ||||
| assert torch.cuda.is_available(), 'torch.cuda is not available' | ||||
|  | ||||
| if args.manualSeed is None: | ||||
|   args.manualSeed = random.randint(1, 10000) | ||||
| random.seed(args.manualSeed) | ||||
| cudnn.benchmark = True | ||||
| cudnn.enabled   = True | ||||
| torch.manual_seed(args.manualSeed) | ||||
| torch.cuda.manual_seed_all(args.manualSeed) | ||||
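|  | ||||
| # Illustrative invocation (hypothetical values; the repository's own run scripts may use different settings): | ||||
| #   python ./exps-cnn/acc_search_v2.py --data_path $TORCH_HOME/cifar.python --dataset cifar10 --arch acc2 \ | ||||
| #     --batch_size 64 --learning_rate_max 0.025 --learning_rate_min 0.001 --tau_max 10 --tau_min 4 --momentum 0.9 \ | ||||
| #     --weight_decay 3e-4 --epochs 50 --init_channels 16 --layers 8 --cutout -1 --grad_clip 5 \ | ||||
| #     --model_config ./configs/search.config --save_path ./snapshots/search-v2 --print_freq 100 --manualSeed 1 | ||||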
|  | ||||
|  | ||||
| def main(): | ||||
|  | ||||
|   # Init logger | ||||
|   args.save_path = os.path.join(args.save_path, 'seed-{:}'.format(args.manualSeed)) | ||||
|   if not os.path.isdir(args.save_path): | ||||
|     os.makedirs(args.save_path) | ||||
|   log = open(os.path.join(args.save_path, 'log-seed-{:}.txt'.format(args.manualSeed)), 'w') | ||||
|   print_log('save path : {}'.format(args.save_path), log) | ||||
|   state = {k: v for k, v in args._get_kwargs()} | ||||
|   print_log(state, log) | ||||
|   print_log("Random Seed: {}".format(args.manualSeed), log) | ||||
|   print_log("Python version : {}".format(sys.version.replace('\n', ' ')), log) | ||||
|   print_log("Torch  version : {}".format(torch.__version__), log) | ||||
|   print_log("CUDA   version : {}".format(torch.version.cuda), log) | ||||
|   print_log("cuDNN  version : {}".format(cudnn.version()), log) | ||||
|   print_log("Num of GPUs    : {}".format(torch.cuda.device_count()), log) | ||||
|   args.dataset = args.dataset.lower() | ||||
|  | ||||
|   # Mean + Std | ||||
|   if args.dataset == 'cifar10': | ||||
|     mean = [x / 255 for x in [125.3, 123.0, 113.9]] | ||||
|     std = [x / 255 for x in [63.0, 62.1, 66.7]] | ||||
|   elif args.dataset == 'cifar100': | ||||
|     mean = [x / 255 for x in [129.3, 124.1, 112.4]] | ||||
|     std = [x / 255 for x in [68.2, 65.4, 70.4]] | ||||
|   else: | ||||
|     raise TypeError("Unknow dataset : {:}".format(args.dataset)) | ||||
|   # Data Argumentation | ||||
|   if args.dataset == 'cifar10' or args.dataset == 'cifar100': | ||||
|     lists = [transforms.RandomHorizontalFlip(), transforms.RandomCrop(32, padding=4), transforms.ToTensor(), | ||||
|              transforms.Normalize(mean, std)] | ||||
|     if args.cutout > 0 : lists += [Cutout(args.cutout)] | ||||
|     train_transform = transforms.Compose(lists) | ||||
|     test_transform  = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean, std)]) | ||||
|   else: | ||||
|     raise TypeError("Unknow dataset : {:}".format(args.dataset)) | ||||
|   # Datasets | ||||
|   if args.dataset == 'cifar10': | ||||
|     train_data = dset.CIFAR10(args.data_path, train= True, transform=train_transform, download=True) | ||||
|     test_data  = dset.CIFAR10(args.data_path, train=False, transform=test_transform , download=True) | ||||
|     num_classes = 10 | ||||
|   elif args.dataset == 'cifar100': | ||||
|     train_data = dset.CIFAR100(args.data_path, train= True, transform=train_transform, download=True) | ||||
|     test_data  = dset.CIFAR100(args.data_path, train=False, transform=test_transform , download=True) | ||||
|     num_classes = 100 | ||||
|   else: | ||||
|     raise TypeError("Unknow dataset : {:}".format(args.dataset)) | ||||
|   # Data Loader | ||||
|   if args.validate: | ||||
|     indices = list(range(len(train_data))) | ||||
|     split   = int(args.train_portion * len(indices)) | ||||
|     random.shuffle(indices) | ||||
|     train_loader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size, | ||||
|                       sampler=torch.utils.data.sampler.SubsetRandomSampler(indices[:split]), | ||||
|                       pin_memory=True, num_workers=args.workers) | ||||
|     test_loader  = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size, | ||||
|                       sampler=torch.utils.data.sampler.SubsetRandomSampler(indices[split:]), | ||||
|                       pin_memory=True, num_workers=args.workers) | ||||
|   else: | ||||
|     train_loader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size, shuffle=True, num_workers=args.workers, pin_memory=True) | ||||
|     test_loader  = torch.utils.data.DataLoader(test_data,  batch_size=args.batch_size, shuffle=True, num_workers=args.workers, pin_memory=True) | ||||
|  | ||||
|   # network and criterion | ||||
|   criterion = torch.nn.CrossEntropyLoss().cuda() | ||||
|   basemodel = Networks[args.arch](args.init_channels, num_classes, args.layers) | ||||
|   model     = torch.nn.DataParallel(basemodel).cuda() | ||||
|   print_log("Parameter size = {:.3f} MB".format(count_parameters_in_MB(basemodel.base_parameters())), log) | ||||
|   print_log("Train-transformation : {:}\nTest--transformation : {:}".format(train_transform, test_transform), log) | ||||
|  | ||||
|   # optimizer and LR-scheduler | ||||
|   base_optimizer = torch.optim.SGD (basemodel.base_parameters(), args.learning_rate_max, momentum=args.momentum, weight_decay=args.weight_decay) | ||||
|   #base_optimizer = torch.optim.Adam(basemodel.base_parameters(), lr=args.learning_rate_max, betas=(0.5, 0.999), weight_decay=args.weight_decay) | ||||
|   base_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(base_optimizer, float(args.epochs), eta_min=args.learning_rate_min) | ||||
|   arch_optimizer = torch.optim.Adam(basemodel.arch_parameters(), lr=args.arch_learning_rate, betas=(0.5, 0.999), weight_decay=args.arch_weight_decay) | ||||
|  | ||||
|   # snapshot | ||||
|   checkpoint_path = os.path.join(args.save_path, 'checkpoint-search.pth') | ||||
|   if args.resume is not None and os.path.isfile(args.resume): | ||||
|     checkpoint = torch.load(args.resume) | ||||
|     start_epoch = checkpoint['epoch'] | ||||
|     basemodel.load_state_dict( checkpoint['state_dict'] ) | ||||
|     base_optimizer.load_state_dict( checkpoint['base_optimizer'] ) | ||||
|     arch_optimizer.load_state_dict( checkpoint['arch_optimizer'] ) | ||||
|     base_scheduler.load_state_dict( checkpoint['base_scheduler'] ) | ||||
|     genotypes   = checkpoint['genotypes'] | ||||
|     print_log('Load resume from {:} with start-epoch = {:}'.format(args.resume, start_epoch), log) | ||||
|   elif os.path.isfile(checkpoint_path): | ||||
|     checkpoint = torch.load(checkpoint_path) | ||||
|     start_epoch = checkpoint['epoch'] | ||||
|     basemodel.load_state_dict( checkpoint['state_dict'] ) | ||||
|     base_optimizer.load_state_dict( checkpoint['base_optimizer'] ) | ||||
|     arch_optimizer.load_state_dict( checkpoint['arch_optimizer'] ) | ||||
|     base_scheduler.load_state_dict( checkpoint['base_scheduler'] ) | ||||
|     genotypes   = checkpoint['genotypes'] | ||||
|     print_log('Load checkpoint from {:} with start-epoch = {:}'.format(checkpoint_path, start_epoch), log) | ||||
|   else: | ||||
|     start_epoch, genotypes = 0, {} | ||||
|     print_log('Train model-search from scratch.', log) | ||||
|  | ||||
|   config = load_config(args.model_config) | ||||
|  | ||||
|   if args.only_base: | ||||
|     print_log('---- Only Train the Searched Model ----', log) | ||||
|     main_procedure(config, args.dataset, args.data_path, args, basemodel.genotype(), 36, 20, log) | ||||
|     return | ||||
|  | ||||
|   # Main loop | ||||
|   start_time, epoch_time, total_train_time = time.time(), AverageMeter(), 0 | ||||
|   for epoch in range(start_epoch, args.epochs): | ||||
|     base_scheduler.step() | ||||
|  | ||||
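|     # linearly anneal the (Gumbel-softmax) temperature tau from tau_max down to tau_min over the search epochs | ||||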
|     basemodel.set_tau( args.tau_max - epoch*1.0/args.epochs*(args.tau_max-args.tau_min) ) | ||||
|     #if epoch + 2 == args.epochs: | ||||
|     #  torch.cuda.empty_cache() | ||||
|     #  basemodel.set_gumbel(False) | ||||
|  | ||||
|     need_time = convert_secs2time(epoch_time.val * (args.epochs-epoch), True) | ||||
|     print_log('\n==>>{:s} [Epoch={:03d}/{:03d}] {:s} [LR={:6.4f} ~ {:6.4f}] [Batch={:d}], tau={:}'.format(time_string(), epoch, args.epochs, need_time, min(base_scheduler.get_lr()), max(base_scheduler.get_lr()), args.batch_size, basemodel.get_tau()), log) | ||||
|  | ||||
|     genotype = basemodel.genotype() | ||||
|     print_log('genotype = {:}'.format(genotype), log) | ||||
|  | ||||
|     print_log('{:03d}/{:03d} alphas :\n{:}'.format(epoch, args.epochs, return_alphas_str(basemodel)), log) | ||||
|  | ||||
|     # training | ||||
|     train_acc1, train_acc5, train_obj, train_time \ | ||||
|                                       = train(train_loader, test_loader, model, criterion, base_optimizer, arch_optimizer, epoch, log) | ||||
|     total_train_time += train_time | ||||
|     # validation | ||||
|     valid_acc1, valid_acc5, valid_obj = infer(test_loader, model, criterion, epoch, log) | ||||
|     print_log('{:03d}/{:03d}, Train-Accuracy = {:.2f}, Test-Accuracy = {:.2f}'.format(epoch, args.epochs, train_acc1, valid_acc1), log) | ||||
|     # save genotype | ||||
|     genotypes[epoch] = basemodel.genotype() | ||||
|     # save checkpoint | ||||
|     torch.save({'epoch' : epoch + 1, | ||||
|                 'args'  : deepcopy(args), | ||||
|                 'state_dict': basemodel.state_dict(), | ||||
|                 'genotypes' : genotypes, | ||||
|                 'base_optimizer' : base_optimizer.state_dict(), | ||||
|                 'arch_optimizer' : arch_optimizer.state_dict(), | ||||
|                 'base_scheduler' : base_scheduler.state_dict()}, | ||||
|                 checkpoint_path) | ||||
|     print_log('----> Save into {:}'.format(checkpoint_path), log) | ||||
|  | ||||
|  | ||||
|     # measure elapsed time | ||||
|     epoch_time.update(time.time() - start_time) | ||||
|     start_time = time.time() | ||||
|  | ||||
|   print_log('Finish with training time = {:}'.format( convert_secs2time(total_train_time, True) ), log) | ||||
|  | ||||
|   # clear GPU cache | ||||
|   torch.cuda.empty_cache() | ||||
|   main_procedure(config, args.dataset, args.data_path, args, basemodel.genotype(), 36, 20, log) | ||||
|   log.close() | ||||
|  | ||||
|  | ||||
| def train(train_queue, valid_queue, model, criterion, base_optimizer, arch_optimizer, epoch, log): | ||||
|   data_time, batch_time = AverageMeter(), AverageMeter() | ||||
|   objs, top1, top5 = AverageMeter(), AverageMeter(), AverageMeter() | ||||
|   model.train() | ||||
|  | ||||
|   valid_iter = iter(valid_queue) | ||||
|   end = time.time() | ||||
|   for step, (inputs, targets) in enumerate(train_queue): | ||||
|     batch, C, H, W = inputs.size() | ||||
|  | ||||
|     #inputs, targets = inputs.cuda(), targets.cuda(non_blocking=True) | ||||
|     targets = targets.cuda(non_blocking=True) | ||||
|     data_time.update(time.time() - end) | ||||
|  | ||||
|     # get a random minibatch from the search queue with replacement | ||||
|     try: | ||||
|       input_search, target_search = next(valid_iter) | ||||
|     except StopIteration:  # the validation iterator is exhausted; restart it | ||||
|       valid_iter = iter(valid_queue) | ||||
|       input_search, target_search = next(valid_iter) | ||||
|      | ||||
|     target_search = target_search.cuda(non_blocking=True) | ||||
|  | ||||
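|     # alternating first-order update: architecture parameters on a validation batch, then network weights on a training batch | ||||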
|     # update the architecture | ||||
|     arch_optimizer.zero_grad() | ||||
|     output_search = model(input_search) | ||||
|     arch_loss = criterion(output_search, target_search) | ||||
|     arch_loss.backward() | ||||
|     arch_optimizer.step() | ||||
|  | ||||
|     # update the parameters | ||||
|     base_optimizer.zero_grad() | ||||
|     logits = model(inputs) | ||||
|     loss = criterion(logits, targets) | ||||
|  | ||||
|     loss.backward() | ||||
|     torch.nn.utils.clip_grad_norm_(model.module.base_parameters(), args.grad_clip) | ||||
|     base_optimizer.step() | ||||
|  | ||||
|     prec1, prec5 = obtain_accuracy(logits.data, targets.data, topk=(1, 5)) | ||||
|     objs.update(loss.item() , batch) | ||||
|     top1.update(prec1.item(), batch) | ||||
|     top5.update(prec5.item(), batch) | ||||
|  | ||||
|     # measure elapsed time | ||||
|     batch_time.update(time.time() - end) | ||||
|     end = time.time() | ||||
|  | ||||
|     if step % args.print_freq == 0 or (step+1) == len(train_queue): | ||||
|       Sstr = ' TRAIN-SEARCH ' + time_string() + ' Epoch: [{:03d}][{:03d}/{:03d}]'.format(epoch, step, len(train_queue)) | ||||
|       Tstr = 'Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})'.format(batch_time=batch_time, data_time=data_time) | ||||
|       Lstr = 'Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})'.format(loss=objs, top1=top1, top5=top5) | ||||
|       print_log(Sstr + ' ' + Tstr + ' ' + Lstr, log) | ||||
|  | ||||
|   return top1.avg, top5.avg, objs.avg, batch_time.sum | ||||
|  | ||||
|  | ||||
| def infer(valid_queue, model, criterion, epoch, log): | ||||
|   objs, top1, top5 = AverageMeter(), AverageMeter(), AverageMeter() | ||||
|    | ||||
|   model.eval() | ||||
|   with torch.no_grad(): | ||||
|     for step, (inputs, targets) in enumerate(valid_queue): | ||||
|       batch, C, H, W = inputs.size() | ||||
|       targets = targets.cuda(non_blocking=True) | ||||
|  | ||||
|       logits = model(inputs) | ||||
|       loss = criterion(logits, targets) | ||||
|  | ||||
|       prec1, prec5 = obtain_accuracy(logits.data, targets.data, topk=(1, 5)) | ||||
|       objs.update(loss.item() , batch) | ||||
|       top1.update(prec1.item(), batch) | ||||
|       top5.update(prec5.item(), batch) | ||||
|  | ||||
|       if step % args.print_freq == 0 or (step+1) == len(valid_queue): | ||||
|         Sstr = ' VALID-SEARCH ' + time_string() + ' Epoch: [{:03d}][{:03d}/{:03d}]'.format(epoch, step, len(valid_queue)) | ||||
|         Lstr = 'Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})'.format(loss=objs, top1=top1, top5=top5) | ||||
|         print_log(Sstr + ' ' + Lstr, log) | ||||
|  | ||||
|   return top1.avg, top5.avg, objs.avg | ||||
|  | ||||
|  | ||||
| if __name__ == '__main__': | ||||
|   main()  | ||||
							
								
								
									
exps-cnn/acc_search_v3.py (new file, 397 lines)
							| @@ -0,0 +1,397 @@ | ||||
| import os, sys, time, glob, random, argparse | ||||
| import numpy as np | ||||
| from copy import deepcopy | ||||
| import torch | ||||
| import torch.nn.functional as F | ||||
| import torchvision.datasets as dset | ||||
| import torch.backends.cudnn as cudnn | ||||
| import torchvision.transforms as transforms | ||||
| from pathlib import Path | ||||
| lib_dir = (Path(__file__).parent / '..' / 'lib').resolve() | ||||
| if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) | ||||
| from utils import AverageMeter, time_string, convert_secs2time | ||||
| from utils import print_log, obtain_accuracy | ||||
| from utils import Cutout, count_parameters_in_MB | ||||
| from nas import Network, NetworkACC2, NetworkV3, NetworkV4, NetworkV5, NetworkFACC1 | ||||
| from nas import return_alphas_str | ||||
| from train_utils import main_procedure | ||||
| from scheduler import load_config | ||||
|  | ||||
| Networks = {'base': Network, 'acc2': NetworkACC2, 'facc1': NetworkFACC1, 'NetworkV3': NetworkV3, 'NetworkV4': NetworkV4, 'NetworkV5': NetworkV5} | ||||
|  | ||||
|  | ||||
| parser = argparse.ArgumentParser("cifar") | ||||
| parser.add_argument('--data_path',         type=str,   help='Path to dataset') | ||||
| parser.add_argument('--dataset',           type=str,   choices=['cifar10', 'cifar100'], help='Choose between CIFAR-10 and CIFAR-100.') | ||||
| parser.add_argument('--arch',              type=str,   choices=Networks.keys(),         help='Choose networks.') | ||||
| parser.add_argument('--batch_size',        type=int,   help='the batch size') | ||||
| parser.add_argument('--learning_rate_max', type=float, help='initial learning rate') | ||||
| parser.add_argument('--learning_rate_min', type=float, help='minimum learning rate') | ||||
| parser.add_argument('--tau_max',           type=float, help='initial tau') | ||||
| parser.add_argument('--tau_min',           type=float, help='minimum tau') | ||||
| parser.add_argument('--momentum',          type=float, help='momentum') | ||||
| parser.add_argument('--weight_decay',      type=float, help='weight decay') | ||||
| parser.add_argument('--epochs',            type=int,   help='num of training epochs') | ||||
| # architecture learning rate | ||||
| parser.add_argument('--arch_learning_rate', type=float, default=3e-4, help='learning rate for arch encoding') | ||||
| parser.add_argument('--arch_weight_decay',  type=float, default=1e-3, help='weight decay for arch encoding') | ||||
| # | ||||
| parser.add_argument('--init_channels',      type=int, help='num of init channels') | ||||
| parser.add_argument('--layers',             type=int, help='total number of layers') | ||||
| #  | ||||
| parser.add_argument('--cutout',         type=int,   help='cutout length, negative means no cutout') | ||||
| parser.add_argument('--grad_clip',      type=float, help='gradient clipping') | ||||
| parser.add_argument('--model_config',   type=str  , help='the model configuration') | ||||
|  | ||||
| # resume | ||||
| parser.add_argument('--resume',         type=str  , help='the resume path') | ||||
| parser.add_argument('--only_base',action='store_true', default=False, help='only train the searched model') | ||||
| # split data | ||||
| parser.add_argument('--validate', action='store_true', default=False, help='split the train data into train/val or not') | ||||
| parser.add_argument('--train_portion', type=float, default=0.5, help='portion of training data') | ||||
| # log | ||||
| parser.add_argument('--workers',       type=int, default=2, help='number of data loading workers (default: 2)') | ||||
| parser.add_argument('--save_path',     type=str, help='Folder to save checkpoints and log.') | ||||
| parser.add_argument('--print_freq',    type=int, help='print frequency (default: 200)') | ||||
| parser.add_argument('--manualSeed',    type=int, help='manual seed') | ||||
| args = parser.parse_args() | ||||
|  | ||||
| assert torch.cuda.is_available(), 'torch.cuda is not available' | ||||
|  | ||||
| if args.manualSeed is None: | ||||
|   args.manualSeed = random.randint(1, 10000) | ||||
| random.seed(args.manualSeed) | ||||
| cudnn.benchmark = True | ||||
| cudnn.enabled   = True | ||||
| torch.manual_seed(args.manualSeed) | ||||
| torch.cuda.manual_seed_all(args.manualSeed) | ||||
|  | ||||
|  | ||||
| def main(): | ||||
|  | ||||
|   # Init logger | ||||
|   args.save_path = os.path.join(args.save_path, 'seed-{:}'.format(args.manualSeed)) | ||||
|   if not os.path.isdir(args.save_path): | ||||
|     os.makedirs(args.save_path) | ||||
|   log = open(os.path.join(args.save_path, 'log-seed-{:}.txt'.format(args.manualSeed)), 'w') | ||||
|   print_log('save path : {}'.format(args.save_path), log) | ||||
|   state = {k: v for k, v in args._get_kwargs()} | ||||
|   print_log(state, log) | ||||
|   print_log("Random Seed: {}".format(args.manualSeed), log) | ||||
|   print_log("Python version : {}".format(sys.version.replace('\n', ' ')), log) | ||||
|   print_log("Torch  version : {}".format(torch.__version__), log) | ||||
|   print_log("CUDA   version : {}".format(torch.version.cuda), log) | ||||
|   print_log("cuDNN  version : {}".format(cudnn.version()), log) | ||||
|   print_log("Num of GPUs    : {}".format(torch.cuda.device_count()), log) | ||||
|   args.dataset = args.dataset.lower() | ||||
|  | ||||
|   # Mean + Std | ||||
|   if args.dataset == 'cifar10': | ||||
|     mean = [x / 255 for x in [125.3, 123.0, 113.9]] | ||||
|     std = [x / 255 for x in [63.0, 62.1, 66.7]] | ||||
|   elif args.dataset == 'cifar100': | ||||
|     mean = [x / 255 for x in [129.3, 124.1, 112.4]] | ||||
|     std = [x / 255 for x in [68.2, 65.4, 70.4]] | ||||
|   else: | ||||
|     raise TypeError("Unknow dataset : {:}".format(args.dataset)) | ||||
|   # Data Argumentation | ||||
|   if args.dataset == 'cifar10' or args.dataset == 'cifar100': | ||||
|     lists = [transforms.RandomHorizontalFlip(), transforms.RandomCrop(32, padding=4), transforms.ToTensor(), | ||||
|              transforms.Normalize(mean, std)] | ||||
|     if args.cutout > 0 : lists += [Cutout(args.cutout)] | ||||
|     train_transform = transforms.Compose(lists) | ||||
|     test_transform  = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean, std)]) | ||||
|   else: | ||||
|     raise TypeError("Unknow dataset : {:}".format(args.dataset)) | ||||
|   # Datasets | ||||
|   if args.dataset == 'cifar10': | ||||
|     train_data = dset.CIFAR10(args.data_path, train= True, transform=train_transform, download=True) | ||||
|     test_data  = dset.CIFAR10(args.data_path, train=False, transform=test_transform , download=True) | ||||
|     num_classes = 10 | ||||
|   elif args.dataset == 'cifar100': | ||||
|     train_data = dset.CIFAR100(args.data_path, train= True, transform=train_transform, download=True) | ||||
|     test_data  = dset.CIFAR100(args.data_path, train=False, transform=test_transform , download=True) | ||||
|     num_classes = 100 | ||||
|   else: | ||||
|     raise TypeError("Unknow dataset : {:}".format(args.dataset)) | ||||
|   # Data Loader | ||||
|   if args.validate: | ||||
|     indices = list(range(len(train_data))) | ||||
|     split   = int(args.train_portion * len(indices)) | ||||
|     random.shuffle(indices) | ||||
|     train_loader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size, | ||||
|                       sampler=torch.utils.data.sampler.SubsetRandomSampler(indices[:split]), | ||||
|                       pin_memory=True, num_workers=args.workers) | ||||
|     test_loader  = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size, | ||||
|                       sampler=torch.utils.data.sampler.SubsetRandomSampler(indices[split:]), | ||||
|                       pin_memory=True, num_workers=args.workers) | ||||
|   else: | ||||
|     train_loader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size, shuffle=True, num_workers=args.workers, pin_memory=True) | ||||
|     test_loader  = torch.utils.data.DataLoader(test_data,  batch_size=args.batch_size, shuffle=True, num_workers=args.workers, pin_memory=True) | ||||
|  | ||||
|   # network and criterion | ||||
|   criterion = torch.nn.CrossEntropyLoss().cuda() | ||||
|   basemodel = Networks[args.arch](args.init_channels, num_classes, args.layers) | ||||
|   model     = torch.nn.DataParallel(basemodel).cuda() | ||||
|   print_log("Parameter size = {:.3f} MB".format(count_parameters_in_MB(basemodel.base_parameters())), log) | ||||
|   print_log("Train-transformation : {:}\nTest--transformation : {:}".format(train_transform, test_transform), log) | ||||
|  | ||||
|   # optimizer and LR-scheduler | ||||
|   base_optimizer = torch.optim.SGD (basemodel.base_parameters(), args.learning_rate_max, momentum=args.momentum, weight_decay=args.weight_decay) | ||||
|   #base_optimizer = torch.optim.Adam(basemodel.base_parameters(), lr=args.learning_rate_max, betas=(0.5, 0.999), weight_decay=args.weight_decay) | ||||
|   base_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(base_optimizer, float(args.epochs), eta_min=args.learning_rate_min) | ||||
|   arch_optimizer = torch.optim.Adam(basemodel.arch_parameters(), lr=args.arch_learning_rate, betas=(0.5, 0.999), weight_decay=args.arch_weight_decay) | ||||
|  | ||||
|   # snapshot | ||||
|   checkpoint_path = os.path.join(args.save_path, 'checkpoint-search.pth') | ||||
|   if args.resume is not None and os.path.isfile(args.resume): | ||||
|     checkpoint = torch.load(args.resume) | ||||
|     start_epoch = checkpoint['epoch'] | ||||
|     basemodel.load_state_dict( checkpoint['state_dict'] ) | ||||
|     base_optimizer.load_state_dict( checkpoint['base_optimizer'] ) | ||||
|     arch_optimizer.load_state_dict( checkpoint['arch_optimizer'] ) | ||||
|     base_scheduler.load_state_dict( checkpoint['base_scheduler'] ) | ||||
|     genotypes   = checkpoint['genotypes'] | ||||
|     print_log('Load resume from {:} with start-epoch = {:}'.format(args.resume, start_epoch), log) | ||||
|   elif os.path.isfile(checkpoint_path): | ||||
|     checkpoint = torch.load(checkpoint_path) | ||||
|     start_epoch = checkpoint['epoch'] | ||||
|     basemodel.load_state_dict( checkpoint['state_dict'] ) | ||||
|     base_optimizer.load_state_dict( checkpoint['base_optimizer'] ) | ||||
|     arch_optimizer.load_state_dict( checkpoint['arch_optimizer'] ) | ||||
|     base_scheduler.load_state_dict( checkpoint['base_scheduler'] ) | ||||
|     genotypes   = checkpoint['genotypes'] | ||||
|     print_log('Load checkpoint from {:} with start-epoch = {:}'.format(checkpoint_path, start_epoch), log) | ||||
|   else: | ||||
|     start_epoch, genotypes = 0, {} | ||||
|     print_log('Train model-search from scratch.', log) | ||||
|  | ||||
|   config = load_config(args.model_config) | ||||
|  | ||||
|   if args.only_base: | ||||
|     print_log('---- Only Train the Searched Model ----', log) | ||||
|     main_procedure(config, args.dataset, args.data_path, args, basemodel.genotype(), 36, 20, log) | ||||
|     return | ||||
|  | ||||
|   # Main loop | ||||
|   start_time, epoch_time, total_train_time = time.time(), AverageMeter(), 0 | ||||
|   for epoch in range(start_epoch, args.epochs): | ||||
|     base_scheduler.step() | ||||
|  | ||||
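|     # linearly anneal the (Gumbel-softmax) temperature tau from tau_max down to tau_min over the search epochs | ||||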
|     basemodel.set_tau( args.tau_max - epoch*1.0/args.epochs*(args.tau_max-args.tau_min) ) | ||||
|     #if epoch + 1 == args.epochs: | ||||
|     #  torch.cuda.empty_cache() | ||||
|     #  basemodel.set_gumbel(False) | ||||
|  | ||||
|     need_time = convert_secs2time(epoch_time.val * (args.epochs-epoch), True) | ||||
|     print_log('\n==>>{:s} [Epoch={:03d}/{:03d}] {:s} [LR={:6.4f} ~ {:6.4f}] [Batch={:d}], tau={:}'.format(time_string(), epoch, args.epochs, need_time, min(base_scheduler.get_lr()), max(base_scheduler.get_lr()), args.batch_size, basemodel.get_tau()), log) | ||||
|  | ||||
|     genotype = basemodel.genotype() | ||||
|     print_log('genotype = {:}'.format(genotype), log) | ||||
|  | ||||
|     print_log('{:03d}/{:03d} alphas :\n{:}'.format(epoch, args.epochs, return_alphas_str(basemodel)), log) | ||||
|  | ||||
|     # training | ||||
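|     # every epoch except the last trains the network weights and the architecture in two separate passes; | ||||
|     # the final epoch switches to the interleaved (joint) update, as in acc_search_v2.py | ||||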
|     if epoch + 1 == args.epochs: | ||||
|       train_acc1, train_acc5, train_obj, train_time \ | ||||
|                                       = train_joint(train_loader, test_loader, model, criterion, base_optimizer, arch_optimizer, epoch, log) | ||||
|       Arch__acc1 = train_acc1  # the joint update has no separate architecture pass; reuse the train accuracy for the log line below | ||||
|       total_train_time += train_time | ||||
|     else: | ||||
|       train_acc1, train_acc5, train_obj, train_time \ | ||||
|                                       = train_base(train_loader, None, model, criterion, base_optimizer, None, epoch, log) | ||||
|       total_train_time += train_time | ||||
|       Arch__acc1, Arch__acc5, Arch__obj, train_time \ | ||||
|                                       = train_arch(None , test_loader, model, criterion, None, arch_optimizer, epoch, log) | ||||
|       total_train_time += train_time | ||||
|     # validation | ||||
|     valid_acc1, valid_acc5, valid_obj = infer(test_loader, model, criterion, epoch, log) | ||||
|     print_log('{:03d}/{:03d}, Train-Accuracy = {:.2f}, Arch-Accuracy = {:.2f}, Test-Accuracy = {:.2f}'.format(epoch, args.epochs, train_acc1, Arch__acc1, valid_acc1), log) | ||||
|  | ||||
|     # save genotype | ||||
|     genotypes[epoch] = basemodel.genotype() | ||||
|     # save checkpoint | ||||
|     torch.save({'epoch' : epoch + 1, | ||||
|                 'args'  : deepcopy(args), | ||||
|                 'state_dict': basemodel.state_dict(), | ||||
|                 'genotypes' : genotypes, | ||||
|                 'base_optimizer' : base_optimizer.state_dict(), | ||||
|                 'arch_optimizer' : arch_optimizer.state_dict(), | ||||
|                 'base_scheduler' : base_scheduler.state_dict()}, | ||||
|                 checkpoint_path) | ||||
|     print_log('----> Save into {:}'.format(checkpoint_path), log) | ||||
|  | ||||
|  | ||||
|     # measure elapsed time | ||||
|     epoch_time.update(time.time() - start_time) | ||||
|     start_time = time.time() | ||||
|  | ||||
|   print_log('Finish with training time = {:}'.format( convert_secs2time(total_train_time, True) ), log) | ||||
|  | ||||
|   # clear GPU cache | ||||
|   torch.cuda.empty_cache() | ||||
|   main_procedure(config, args.dataset, args.data_path, args, basemodel.genotype(), 36, 20, log) | ||||
|   log.close() | ||||
|  | ||||
|  | ||||
| def train_base(train_queue, _, model, criterion, base_optimizer, __, epoch, log): | ||||
|   data_time, batch_time = AverageMeter(), AverageMeter() | ||||
|   objs, top1, top5 = AverageMeter(), AverageMeter(), AverageMeter() | ||||
|   model.train() | ||||
|  | ||||
|   end = time.time() | ||||
|   for step, (inputs, targets) in enumerate(train_queue): | ||||
|     batch, C, H, W = inputs.size() | ||||
|  | ||||
|     #inputs, targets = inputs.cuda(), targets.cuda(non_blocking=True) | ||||
|     targets = targets.cuda(non_blocking=True) | ||||
|     data_time.update(time.time() - end) | ||||
|  | ||||
|     # update the parameters | ||||
|     base_optimizer.zero_grad() | ||||
|     logits = model(inputs) | ||||
|     loss = criterion(logits, targets) | ||||
|  | ||||
|     loss.backward() | ||||
|     torch.nn.utils.clip_grad_norm_(model.module.base_parameters(), args.grad_clip) | ||||
|     base_optimizer.step() | ||||
|  | ||||
|     prec1, prec5 = obtain_accuracy(logits.data, targets.data, topk=(1, 5)) | ||||
|     objs.update(loss.item() , batch) | ||||
|     top1.update(prec1.item(), batch) | ||||
|     top5.update(prec5.item(), batch) | ||||
|  | ||||
|     # measure elapsed time | ||||
|     batch_time.update(time.time() - end) | ||||
|     end = time.time() | ||||
|  | ||||
|     if step % args.print_freq == 0 or (step+1) == len(train_queue): | ||||
|       Sstr = ' TRAIN-BASE ' + time_string() + ' Epoch: [{:03d}][{:03d}/{:03d}]'.format(epoch, step, len(train_queue)) | ||||
|       Tstr = 'Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})'.format(batch_time=batch_time, data_time=data_time) | ||||
|       Lstr = 'Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})'.format(loss=objs, top1=top1, top5=top5) | ||||
|       print_log(Sstr + ' ' + Tstr + ' ' + Lstr, log) | ||||
|  | ||||
|   return top1.avg, top5.avg, objs.avg, batch_time.sum | ||||
|  | ||||
|  | ||||
| def train_arch(_, valid_queue, model, criterion, __, arch_optimizer, epoch, log): | ||||
|   data_time, batch_time = AverageMeter(), AverageMeter() | ||||
|   objs, top1, top5 = AverageMeter(), AverageMeter(), AverageMeter() | ||||
|   model.train() | ||||
|  | ||||
|   end = time.time() | ||||
|   for step, (inputs, targets) in enumerate(valid_queue): | ||||
|     batch, C, H, W = inputs.size() | ||||
|  | ||||
|     #inputs, targets = inputs.cuda(), targets.cuda(non_blocking=True) | ||||
|     targets = targets.cuda(non_blocking=True) | ||||
|     data_time.update(time.time() - end) | ||||
|  | ||||
|     # update the architecture | ||||
|     arch_optimizer.zero_grad() | ||||
|     outputs = model(inputs) | ||||
|     arch_loss = criterion(outputs, targets) | ||||
|     arch_loss.backward() | ||||
|     arch_optimizer.step() | ||||
|  | ||||
|     prec1, prec5 = obtain_accuracy(outputs.data, targets.data, topk=(1, 5)) | ||||
|     objs.update(arch_loss.item() , batch) | ||||
|     top1.update(prec1.item(), batch) | ||||
|     top5.update(prec5.item(), batch) | ||||
|  | ||||
|     # measure elapsed time | ||||
|     batch_time.update(time.time() - end) | ||||
|     end = time.time() | ||||
|  | ||||
|     if step % args.print_freq == 0 or (step+1) == len(valid_queue): | ||||
|       Sstr = ' TRAIN-ARCH ' + time_string() + ' Epoch: [{:03d}][{:03d}/{:03d}]'.format(epoch, step, len(valid_queue)) | ||||
|       Tstr = 'Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})'.format(batch_time=batch_time, data_time=data_time) | ||||
|       Lstr = 'Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})'.format(loss=objs, top1=top1, top5=top5) | ||||
|       print_log(Sstr + ' ' + Tstr + ' ' + Lstr, log) | ||||
|  | ||||
|   return top1.avg, top5.avg, objs.avg, batch_time.sum | ||||
|  | ||||
|  | ||||
| def train_joint(train_queue, valid_queue, model, criterion, base_optimizer, arch_optimizer, epoch, log): | ||||
|   data_time, batch_time = AverageMeter(), AverageMeter() | ||||
|   objs, top1, top5 = AverageMeter(), AverageMeter(), AverageMeter() | ||||
|   model.train() | ||||
|  | ||||
|   valid_iter = iter(valid_queue) | ||||
|   end = time.time() | ||||
|   for step, (inputs, targets) in enumerate(train_queue): | ||||
|     batch, C, H, W = inputs.size() | ||||
|  | ||||
|     #inputs, targets = inputs.cuda(), targets.cuda(non_blocking=True) | ||||
|     targets = targets.cuda(non_blocking=True) | ||||
|     data_time.update(time.time() - end) | ||||
|  | ||||
|     # get a random minibatch from the search queue with replacement | ||||
|     try: | ||||
|       input_search, target_search = next(valid_iter) | ||||
|     except StopIteration:  # the validation iterator is exhausted; restart it | ||||
|       valid_iter = iter(valid_queue) | ||||
|       input_search, target_search = next(valid_iter) | ||||
|      | ||||
|     target_search = target_search.cuda(non_blocking=True) | ||||
|  | ||||
|     # update the architecture | ||||
|     arch_optimizer.zero_grad() | ||||
|     output_search = model(input_search) | ||||
|     arch_loss = criterion(output_search, target_search) | ||||
|     arch_loss.backward() | ||||
|     arch_optimizer.step() | ||||
|  | ||||
|     # update the parameters | ||||
|     base_optimizer.zero_grad() | ||||
|     logits = model(inputs) | ||||
|     loss = criterion(logits, targets) | ||||
|  | ||||
|     loss.backward() | ||||
|     torch.nn.utils.clip_grad_norm_(model.module.base_parameters(), args.grad_clip) | ||||
|     base_optimizer.step() | ||||
|  | ||||
|     prec1, prec5 = obtain_accuracy(logits.data, targets.data, topk=(1, 5)) | ||||
|     objs.update(loss.item() , batch) | ||||
|     top1.update(prec1.item(), batch) | ||||
|     top5.update(prec5.item(), batch) | ||||
|  | ||||
|     # measure elapsed time | ||||
|     batch_time.update(time.time() - end) | ||||
|     end = time.time() | ||||
|  | ||||
|     if step % args.print_freq == 0 or (step+1) == len(train_queue): | ||||
|       Sstr = ' TRAIN-SEARCH ' + time_string() + ' Epoch: [{:03d}][{:03d}/{:03d}]'.format(epoch, step, len(train_queue)) | ||||
|       Tstr = 'Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})'.format(batch_time=batch_time, data_time=data_time) | ||||
|       Lstr = 'Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})'.format(loss=objs, top1=top1, top5=top5) | ||||
|       print_log(Sstr + ' ' + Tstr + ' ' + Lstr, log) | ||||
|  | ||||
|   return top1.avg, top5.avg, objs.avg, batch_time.sum | ||||
|  | ||||
|  | ||||
| def infer(valid_queue, model, criterion, epoch, log): | ||||
|   objs, top1, top5 = AverageMeter(), AverageMeter(), AverageMeter() | ||||
|    | ||||
|   model.eval() | ||||
|   with torch.no_grad(): | ||||
|     for step, (inputs, targets) in enumerate(valid_queue): | ||||
|       batch, C, H, W = inputs.size() | ||||
|       targets = targets.cuda(non_blocking=True) | ||||
|  | ||||
|       logits = model(inputs) | ||||
|       loss = criterion(logits, targets) | ||||
|  | ||||
|       prec1, prec5 = obtain_accuracy(logits.data, targets.data, topk=(1, 5)) | ||||
|       objs.update(loss.item() , batch) | ||||
|       top1.update(prec1.item(), batch) | ||||
|       top5.update(prec5.item(), batch) | ||||
|  | ||||
|       if step % args.print_freq == 0 or (step+1) == len(valid_queue): | ||||
|         Sstr = ' VALID-SEARCH ' + time_string() + ' Epoch: [{:03d}][{:03d}/{:03d}]'.format(epoch, step, len(valid_queue)) | ||||
|         Lstr = 'Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})'.format(loss=objs, top1=top1, top5=top5) | ||||
|         print_log(Sstr + ' ' + Lstr, log) | ||||
|  | ||||
|   return top1.avg, top5.avg, objs.avg | ||||
|  | ||||
|  | ||||
| if __name__ == '__main__': | ||||
|   main()  | ||||
							
								
								
									
exps-cnn/cvpr-vis.py (new file, 94 lines)
							| @@ -0,0 +1,94 @@ | ||||
| # python ./exps-cnn/cvpr-vis.py --save_dir ./snapshots/NAS-VIS/ | ||||
| import os, sys, time, glob, random, argparse | ||||
| import numpy as np | ||||
| from copy import deepcopy | ||||
| import torch | ||||
| from pathlib import Path | ||||
| lib_dir = (Path(__file__).parent / '..' / 'lib').resolve() | ||||
| if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) | ||||
| from nas import DMS_V1, DMS_F1 | ||||
| from nas_rnn import DARTS_V2, GDAS | ||||
| from graphviz import Digraph | ||||
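| # NOTE: requires the 'graphviz' Python package and the Graphviz 'dot' executable on PATH; | ||||
| # with format='pdf', g.render(filename) writes the DOT source to `filename` and the plot to `filename + '.pdf'`. | ||||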
|  | ||||
| parser = argparse.ArgumentParser("Visualize the Networks") | ||||
| parser.add_argument('--save_dir',   type=str,   help='The directory to save the network plot.') | ||||
| args = parser.parse_args() | ||||
|  | ||||
|  | ||||
| def plot_cnn(genotype, filename): | ||||
|   g = Digraph( | ||||
|       format='pdf', | ||||
|       edge_attr=dict(fontsize='20', fontname="times"), | ||||
|       node_attr=dict(style='filled', shape='rect', align='center', fontsize='20', height='0.5', width='0.5', penwidth='2', fontname="times"), | ||||
|       engine='dot') | ||||
|   g.body.extend(['rankdir=LR']) | ||||
|  | ||||
|   g.node("c_{k-2}", fillcolor='darkseagreen2') | ||||
|   g.node("c_{k-1}", fillcolor='darkseagreen2') | ||||
|   assert len(genotype) % 2 == 0, '{:}'.format(genotype) | ||||
|   steps = len(genotype) // 2 | ||||
|  | ||||
|   for i in range(steps): | ||||
|     g.node(str(i), fillcolor='lightblue') | ||||
|  | ||||
|   for i in range(steps): | ||||
|     for k in [2*i, 2*i + 1]: | ||||
|       op, j, weight = genotype[k] | ||||
|       if j == 0: | ||||
|         u = "c_{k-2}" | ||||
|       elif j == 1: | ||||
|         u = "c_{k-1}" | ||||
|       else: | ||||
|         u = str(j-2) | ||||
|       v = str(i) | ||||
|       g.edge(u, v, label=op, fillcolor="gray") | ||||
|  | ||||
|   g.node("c_{k}", fillcolor='palegoldenrod') | ||||
|   for i in range(steps): | ||||
|     g.edge(str(i), "c_{k}", fillcolor="gray") | ||||
|  | ||||
|   g.render(filename, view=False) | ||||
|  | ||||
| def plot_rnn(genotype, filename): | ||||
|   g = Digraph( | ||||
|       format='pdf', | ||||
|       edge_attr=dict(fontsize='20', fontname="times"), | ||||
|       node_attr=dict(style='filled', shape='rect', align='center', fontsize='20', height='0.5', width='0.5', penwidth='2', fontname="times"), | ||||
|       engine='dot') | ||||
|   g.body.extend(['rankdir=LR']) | ||||
|  | ||||
|   g.node("x_{t}", fillcolor='darkseagreen2') | ||||
|   g.node("h_{t-1}", fillcolor='darkseagreen2') | ||||
|   g.node("0", fillcolor='lightblue') | ||||
|   g.edge("x_{t}", "0", fillcolor="gray") | ||||
|   g.edge("h_{t-1}", "0", fillcolor="gray") | ||||
|   steps = len(genotype) | ||||
|  | ||||
|   for i in range(1, steps + 1): | ||||
|     g.node(str(i), fillcolor='lightblue') | ||||
|  | ||||
|   for i, (op, j) in enumerate(genotype): | ||||
|     g.edge(str(j), str(i + 1), label=op, fillcolor="gray") | ||||
|  | ||||
|   g.node("h_{t}", fillcolor='palegoldenrod') | ||||
|   for i in range(1, steps + 1): | ||||
|     g.edge(str(i), "h_{t}", fillcolor="gray") | ||||
|  | ||||
|   g.render(filename, view=False) | ||||
|  | ||||
|  | ||||
| if __name__ == '__main__': | ||||
|   save_dir   = Path(args.save_dir) | ||||
|  | ||||
|   save_path = str(save_dir / 'DMS_V1-normal') | ||||
|   plot_cnn(DMS_V1.normal, save_path) | ||||
|   save_path = str(save_dir / 'DMS_V1-reduce') | ||||
|   plot_cnn(DMS_V1.reduce, save_path) | ||||
|   save_path = str(save_dir / 'DMS_F1-normal') | ||||
|   plot_cnn(DMS_F1.normal, save_path) | ||||
|  | ||||
|   save_path = str(save_dir / 'DARTS-V2-RNN') | ||||
|   plot_rnn(DARTS_V2.recurrent, save_path) | ||||
|  | ||||
|   save_path = str(save_dir / 'GDAS-V1-RNN') | ||||
|   plot_rnn(GDAS.recurrent, save_path) | ||||
							
								
								
									
exps-cnn/meta_search.py (new file, 312 lines)
							| @@ -0,0 +1,312 @@ | ||||
| import os, sys, time, glob, random, argparse | ||||
| import numpy as np | ||||
| from copy import deepcopy | ||||
| import torch | ||||
| import torch.nn.functional as F | ||||
| import torch.backends.cudnn as cudnn | ||||
| import torchvision.transforms as transforms | ||||
| from pathlib import Path | ||||
| lib_dir = (Path(__file__).parent / '..' / 'lib').resolve() | ||||
| if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) | ||||
| from datasets import TieredImageNet, MetaBatchSampler | ||||
| from utils import AverageMeter, time_string, convert_secs2time | ||||
| from utils import print_log, obtain_accuracy | ||||
| from utils import Cutout, count_parameters_in_MB | ||||
| from meta_nas import return_alphas_str, MetaNetwork | ||||
| from train_utils import main_procedure | ||||
| from scheduler import load_config | ||||
|  | ||||
| Networks = {'meta': MetaNetwork} | ||||
|  | ||||
| parser = argparse.ArgumentParser("cifar") | ||||
| parser.add_argument('--data_path',          type=str,   help='Path to dataset') | ||||
| parser.add_argument('--arch',               type=str,   choices=Networks.keys(), help='Choose networks.') | ||||
| parser.add_argument('--n_way',              type=int,   help='N-WAY.') | ||||
| parser.add_argument('--k_shot',             type=int,   help='K-SHOT.') | ||||
| # Learning Parameters | ||||
| parser.add_argument('--learning_rate_max',  type=float, help='initial learning rate') | ||||
| parser.add_argument('--learning_rate_min',  type=float, help='minimum learning rate') | ||||
| parser.add_argument('--momentum',           type=float, help='momentum') | ||||
| parser.add_argument('--weight_decay',       type=float, help='weight decay') | ||||
| parser.add_argument('--epochs',             type=int,   help='num of training epochs') | ||||
| # architecture learning rate | ||||
| parser.add_argument('--arch_learning_rate', type=float, default=3e-4, help='learning rate for arch encoding') | ||||
| parser.add_argument('--arch_weight_decay',  type=float, default=1e-3, help='weight decay for arch encoding') | ||||
| # | ||||
| parser.add_argument('--init_channels',      type=int, help='num of init channels') | ||||
| parser.add_argument('--layers',             type=int, help='total number of layers') | ||||
| #  | ||||
| parser.add_argument('--cutout',             type=int,   help='cutout length, negative means no cutout') | ||||
| parser.add_argument('--grad_clip',          type=float, help='gradient clipping') | ||||
| parser.add_argument('--model_config',       type=str  , help='the model configuration') | ||||
|  | ||||
| # resume | ||||
| parser.add_argument('--resume',             type=str  , help='the resume path') | ||||
| parser.add_argument('--only_base',action='store_true', default=False, help='only train the searched model') | ||||
| # split data | ||||
| parser.add_argument('--validate', action='store_true', default=False, help='split the train data into train/val or not') | ||||
| parser.add_argument('--train_portion',      type=float, default=0.5, help='portion of training data') | ||||
| # log | ||||
| parser.add_argument('--workers',            type=int, default=2, help='number of data loading workers (default: 2)') | ||||
| parser.add_argument('--save_path',          type=str, help='Folder to save checkpoints and log.') | ||||
| parser.add_argument('--print_freq',         type=int, help='print frequency (default: 200)') | ||||
| parser.add_argument('--manualSeed',         type=int, help='manual seed') | ||||
| args = parser.parse_args() | ||||
|  | ||||
| assert torch.cuda.is_available(), 'torch.cuda is not available' | ||||
|  | ||||
| if args.manualSeed is None: | ||||
|   args.manualSeed = random.randint(1, 10000) | ||||
| random.seed(args.manualSeed) | ||||
| cudnn.benchmark = True | ||||
| cudnn.enabled   = True | ||||
| torch.manual_seed(args.manualSeed) | ||||
| torch.cuda.manual_seed_all(args.manualSeed) | ||||
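|  | ||||
| # Illustrative invocation (hypothetical values; actual settings live in the repository's run scripts): | ||||
| #   python ./exps-cnn/meta_search.py --data_path $TORCH_HOME/tiered-imagenet --arch meta --n_way 5 --k_shot 1 \ | ||||
| #     --learning_rate_max 0.025 --learning_rate_min 0.001 --momentum 0.9 --weight_decay 3e-4 --epochs 50 \ | ||||
| #     --init_channels 16 --layers 8 --cutout -1 --grad_clip 5 --model_config ./configs/search.config \ | ||||
| #     --save_path ./snapshots/meta-search --print_freq 100 --manualSeed 1 | ||||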
|  | ||||
|  | ||||
| def main(): | ||||
|  | ||||
|   # Init logger | ||||
|   args.save_path = os.path.join(args.save_path, 'seed-{:}'.format(args.manualSeed)) | ||||
|   if not os.path.isdir(args.save_path): | ||||
|     os.makedirs(args.save_path) | ||||
|   log = open(os.path.join(args.save_path, 'log-seed-{:}.txt'.format(args.manualSeed)), 'w') | ||||
|   print_log('save path : {}'.format(args.save_path), log) | ||||
|   state = {k: v for k, v in args._get_kwargs()} | ||||
|   print_log(state, log) | ||||
|   print_log("Random Seed: {}".format(args.manualSeed), log) | ||||
|   print_log("Python version : {}".format(sys.version.replace('\n', ' ')), log) | ||||
|   print_log("Torch  version : {}".format(torch.__version__), log) | ||||
|   print_log("CUDA   version : {}".format(torch.version.cuda), log) | ||||
|   print_log("cuDNN  version : {}".format(cudnn.version()), log) | ||||
|   print_log("Num of GPUs    : {}".format(torch.cuda.device_count()), log) | ||||
|  | ||||
|   # Mean + Std | ||||
|   means, stds = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225] | ||||
|   # Data Augmentation | ||||
|   lists = [transforms.RandomHorizontalFlip(), transforms.RandomCrop(80, padding=4), transforms.ToTensor(), | ||||
|            transforms.Normalize(means, stds)] | ||||
|   if args.cutout > 0 : lists += [Cutout(args.cutout)] | ||||
|   train_transform = transforms.Compose(lists) | ||||
|   test_transform  = transforms.Compose([transforms.CenterCrop(80), transforms.ToTensor(), transforms.Normalize(means, stds)]) | ||||
|    | ||||
|   train_data = TieredImageNet(args.data_path, 'train', train_transform) | ||||
|   test_data  = TieredImageNet(args.data_path, 'val'  , test_transform ) | ||||
|  | ||||
|   train_sampler = MetaBatchSampler(train_data.labels, args.n_way, args.k_shot * 2, len(train_data) // (args.n_way*args.k_shot)) | ||||
|   test_sampler  = MetaBatchSampler( test_data.labels, args.n_way, args.k_shot * 2, len( test_data) // (args.n_way*args.k_shot)) | ||||
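|   # each sampled episode holds n_way classes with 2*k_shot images per class; get_loss() below splits them into support/query halves | ||||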
|  | ||||
|   train_loader = torch.utils.data.DataLoader(train_data, batch_sampler=train_sampler) | ||||
|   test_loader  = torch.utils.data.DataLoader( test_data, batch_sampler= test_sampler) | ||||
|  | ||||
|   # network | ||||
|   basemodel = Networks[args.arch](args.init_channels, args.layers, head='imagenet') | ||||
|   model     = torch.nn.DataParallel(basemodel).cuda() | ||||
|   print_log("Parameter size = {:.3f} MB".format(count_parameters_in_MB(basemodel.base_parameters())), log) | ||||
|   print_log("Train-transformation : {:}\nTest--transformation : {:}".format(train_transform, test_transform), log) | ||||
|  | ||||
|   # optimizer and LR-scheduler | ||||
|   #base_optimizer = torch.optim.SGD (basemodel.base_parameters(), args.learning_rate_max, momentum=args.momentum, weight_decay=args.weight_decay) | ||||
|   base_optimizer = torch.optim.Adam(basemodel.base_parameters(), lr=args.learning_rate_max, betas=(0.5, 0.999), weight_decay=args.weight_decay) | ||||
|   base_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(base_optimizer, float(args.epochs), eta_min=args.learning_rate_min) | ||||
|   arch_optimizer = torch.optim.Adam(basemodel.arch_parameters(), lr=args.arch_learning_rate, betas=(0.5, 0.999), weight_decay=args.arch_weight_decay) | ||||
|  | ||||
|   # snapshot | ||||
|   checkpoint_path = os.path.join(args.save_path, 'checkpoint-meta-search.pth') | ||||
|   if args.resume is not None and os.path.isfile(args.resume): | ||||
|     checkpoint = torch.load(args.resume) | ||||
|     start_epoch = checkpoint['epoch'] | ||||
|     basemodel.load_state_dict( checkpoint['state_dict'] ) | ||||
|     base_optimizer.load_state_dict( checkpoint['base_optimizer'] ) | ||||
|     arch_optimizer.load_state_dict( checkpoint['arch_optimizer'] ) | ||||
|     base_scheduler.load_state_dict( checkpoint['base_scheduler'] ) | ||||
|     genotypes   = checkpoint['genotypes'] | ||||
|     print_log('Load resume from {:} with start-epoch = {:}'.format(args.resume, start_epoch), log) | ||||
|   elif os.path.isfile(checkpoint_path): | ||||
|     checkpoint = torch.load(checkpoint_path) | ||||
|     start_epoch = checkpoint['epoch'] | ||||
|     basemodel.load_state_dict( checkpoint['state_dict'] ) | ||||
|     base_optimizer.load_state_dict( checkpoint['base_optimizer'] ) | ||||
|     arch_optimizer.load_state_dict( checkpoint['arch_optimizer'] ) | ||||
|     base_scheduler.load_state_dict( checkpoint['base_scheduler'] ) | ||||
|     genotypes   = checkpoint['genotypes'] | ||||
|     print_log('Load checkpoint from {:} with start-epoch = {:}'.format(checkpoint_path, start_epoch), log) | ||||
|   else: | ||||
|     start_epoch, genotypes = 0, {} | ||||
|     print_log('Train model-search from scratch.', log) | ||||
|  | ||||
|   config = load_config(args.model_config) | ||||
|  | ||||
|   if args.only_base: | ||||
|     print_log('---- Only Train the Searched Model ----', log) | ||||
|     CIFAR_DATA_DIR = os.environ['TORCH_HOME'] + '/cifar.python' | ||||
|     main_procedure(config, 'cifar10', CIFAR_DATA_DIR, args, basemodel.genotype(), 36, 20, log) | ||||
|     return | ||||
|  | ||||
|   # Main loop | ||||
|   start_time, epoch_time, total_train_time = time.time(), AverageMeter(), 0 | ||||
|   for epoch in range(start_epoch, args.epochs): | ||||
|     base_scheduler.step() | ||||
|  | ||||
|     need_time = convert_secs2time(epoch_time.val * (args.epochs-epoch), True) | ||||
|     print_log('\n==>>{:s} [Epoch={:03d}/{:03d}] {:s} [LR={:6.4f} ~ {:6.4f}]'.format(time_string(), epoch, args.epochs, need_time, min(base_scheduler.get_lr()), max(base_scheduler.get_lr())), log) | ||||
|  | ||||
|     genotype = basemodel.genotype() | ||||
|     print_log('genotype = {:}'.format(genotype), log) | ||||
|     print_log('{:03d}/{:03d} alphas :\n{:}'.format(epoch, args.epochs, return_alphas_str(basemodel)), log) | ||||
|  | ||||
|     # training | ||||
|     train_acc1, train_obj, train_time \ | ||||
|                                       = train(train_loader, test_loader, model, args.n_way, base_optimizer, arch_optimizer, epoch, log) | ||||
|     total_train_time += train_time | ||||
|     # validation | ||||
|     valid_acc1, valid_obj = infer(test_loader, model, epoch, args.n_way, log) | ||||
|  | ||||
|     print_log('META -> {:}-way {:}-shot : {:03d}/{:03d} : Train Acc : {:.2f}, Test Acc : {:.2f}'.format(args.n_way, args.k_shot, epoch, args.epochs, train_acc1, valid_acc1), log) | ||||
|     # save genotype | ||||
|     genotypes[epoch] = basemodel.genotype() | ||||
|    | ||||
|     # save checkpoint | ||||
|     torch.save({'epoch' : epoch + 1, | ||||
|                 'args'  : deepcopy(args), | ||||
|                 'state_dict': basemodel.state_dict(), | ||||
|                 'genotypes' : genotypes, | ||||
|                 'base_optimizer' : base_optimizer.state_dict(), | ||||
|                 'arch_optimizer' : arch_optimizer.state_dict(), | ||||
|                 'base_scheduler' : base_scheduler.state_dict()}, | ||||
|                 checkpoint_path) | ||||
|     print_log('----> Save into {:}'.format(checkpoint_path), log) | ||||
|  | ||||
|  | ||||
|     # measure elapsed time | ||||
|     epoch_time.update(time.time() - start_time) | ||||
|     start_time = time.time() | ||||
|  | ||||
|   print_log('Finish with training time = {:}'.format( convert_secs2time(total_train_time, True) ), log) | ||||
|  | ||||
|   # clear GPU cache | ||||
|   CIFAR_DATA_DIR = os.environ['TORCH_HOME'] + '/cifar.python' | ||||
|   print_log('test for CIFAR-10', log) | ||||
|   torch.cuda.empty_cache() | ||||
|   main_procedure(config, 'cifar10' , CIFAR_DATA_DIR, args, basemodel.genotype(), 36, 20, log) | ||||
|   print_log('test for CIFAR-100', log) | ||||
|   torch.cuda.empty_cache() | ||||
|   main_procedure(config, 'cifar100', CIFAR_DATA_DIR, args, basemodel.genotype(), 36, 20, log) | ||||
|   log.close() | ||||
|  | ||||
|  | ||||
|  | ||||
| def euclidean_dist(A, B): | ||||
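|   # Pairwise squared Euclidean distances between the rows of A (na x d) and B (nb x d); | ||||
|   # broadcasting (na, 1, d) against (1, nb, d) yields an (na, nb) distance matrix. | ||||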
|   na, da = A.size() | ||||
|   nb, db = B.size() | ||||
|   assert da == db, 'invalid feature dim : {:} vs. {:}'.format(da, db) | ||||
|   X, Y = A.view(na, 1, da), B.view(1, nb, db) | ||||
|   return torch.pow(X-Y, 2).sum(2) | ||||
|    | ||||
|  | ||||
|  | ||||
| def get_loss(features, targets, n_way): | ||||
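|   # Prototype-based few-shot loss: the first `shot` samples of each of the n_way classes | ||||
|   # form the support set and are averaged into a per-class prototype; the remaining samples | ||||
|   # are queries, classified by the negative squared distance to each prototype. | ||||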
|   classes = torch.unique(targets) | ||||
|   shot = features.size(0) // n_way // 2 | ||||
|  | ||||
|   support_index, query_index, labels = [], [], [] | ||||
|   for idx, cls in enumerate( classes.tolist() ): | ||||
|     indexs = (targets == cls).nonzero().view(-1).tolist() | ||||
|     support_index.append(indexs[:shot]) | ||||
|     query_index   += indexs[shot:] | ||||
|     labels        += [idx] * shot | ||||
|   query_features = features[query_index, :] | ||||
|   support_features = features[support_index, :] | ||||
|   support_features = torch.mean(support_features, dim=1) | ||||
|      | ||||
|   labels = torch.LongTensor(labels).cuda(non_blocking=True) | ||||
|   logits = -euclidean_dist(query_features, support_features) | ||||
|   loss = F.cross_entropy(logits, labels) | ||||
|   accuracy = obtain_accuracy(logits.data, labels.data, topk=(1,))[0] | ||||
|   return loss, accuracy | ||||
|  | ||||
|  | ||||
|  | ||||
| def train(train_queue, valid_queue, model, n_way, base_optimizer, arch_optimizer, epoch, log): | ||||
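|   # One search epoch with alternating updates: the architecture parameters are optimized on a | ||||
|   # batch drawn from the validation queue, then the network weights are optimized (with gradient | ||||
|   # clipping) on the training batch, both using the few-shot loss from get_loss. | ||||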
|   data_time, batch_time = AverageMeter(), AverageMeter() | ||||
|   objs, accuracies = AverageMeter(), AverageMeter() | ||||
|   model.train() | ||||
|  | ||||
|   valid_iter = iter(valid_queue) | ||||
|   end = time.time() | ||||
|   for step, (inputs, targets) in enumerate(train_queue): | ||||
|     batch, C, H, W = inputs.size() | ||||
|  | ||||
|     #inputs, targets = inputs.cuda(), targets.cuda(non_blocking=True) | ||||
|     #targets = targets.cuda(non_blocking=True) | ||||
|     data_time.update(time.time() - end) | ||||
|  | ||||
|     # get a random minibatch from the search queue with replacement | ||||
|     try: | ||||
|       input_search, target_search = next(valid_iter) | ||||
|     except StopIteration: | ||||
|       valid_iter = iter(valid_queue) | ||||
|       input_search, target_search = next(valid_iter) | ||||
|      | ||||
|     #target_search = target_search.cuda(non_blocking=True) | ||||
|  | ||||
|     # update the architecture | ||||
|     arch_optimizer.zero_grad() | ||||
|     feature_search = model(input_search) | ||||
|     arch_loss, arch_accuracy = get_loss(feature_search, target_search, n_way) | ||||
|     arch_loss.backward() | ||||
|     arch_optimizer.step() | ||||
|  | ||||
|     # update the parameters | ||||
|     base_optimizer.zero_grad() | ||||
|     feature_model = model(inputs) | ||||
|     model_loss, model_accuracy = get_loss(feature_model, targets, n_way) | ||||
|  | ||||
|     model_loss.backward() | ||||
|     torch.nn.utils.clip_grad_norm_(model.module.base_parameters(), args.grad_clip) | ||||
|     base_optimizer.step() | ||||
|  | ||||
|     objs.update(model_loss.item() , batch) | ||||
|     accuracies.update(model_accuracy.item(), batch) | ||||
|  | ||||
|     # measure elapsed time | ||||
|     batch_time.update(time.time() - end) | ||||
|     end = time.time() | ||||
|  | ||||
|     if step % args.print_freq == 0 or (step+1) == len(train_queue): | ||||
|       Sstr = ' TRAIN-SEARCH ' + time_string() + ' Epoch: [{:03d}][{:03d}/{:03d}]'.format(epoch, step, len(train_queue)) | ||||
|       Tstr = 'Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})'.format(batch_time=batch_time, data_time=data_time) | ||||
|       Lstr = 'Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f})'.format(loss=objs, top1=accuracies) | ||||
|       Istr = 'I : {:}'.format( list(inputs.size()) ) | ||||
|       print_log(Sstr + ' ' + Tstr + ' ' + Lstr + ' ' + Istr, log) | ||||
|  | ||||
|   return accuracies.avg, objs.avg, batch_time.sum | ||||
|  | ||||
|  | ||||
|  | ||||
| def infer(valid_queue, model, epoch, n_way, log): | ||||
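|   # Evaluate the search model on the validation queue with the same few-shot loss, without gradients. | ||||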
|   objs, accuracies = AverageMeter(), AverageMeter() | ||||
|    | ||||
|   model.eval() | ||||
|   with torch.no_grad(): | ||||
|     for step, (inputs, targets) in enumerate(valid_queue): | ||||
|       batch, C, H, W = inputs.size() | ||||
|       #targets = targets.cuda(non_blocking=True) | ||||
|  | ||||
|       features = model(inputs) | ||||
|       loss, accuracy = get_loss(features, targets, n_way) | ||||
|  | ||||
|       objs.update(loss.item() , batch) | ||||
|       accuracies.update(accuracy.item(), batch) | ||||
|  | ||||
|       if step % (args.print_freq*4) == 0 or (step+1) == len(valid_queue): | ||||
|         Sstr = ' VALID-SEARCH ' + time_string() + ' Epoch: [{:03d}][{:03d}/{:03d}]'.format(epoch, step, len(valid_queue)) | ||||
|         Lstr = 'Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f})'.format(loss=objs, top1=accuracies) | ||||
|         print_log(Sstr + ' ' + Lstr, log) | ||||
|  | ||||
|   return accuracies.avg, objs.avg | ||||
|  | ||||
|  | ||||
|  | ||||
| if __name__ == '__main__': | ||||
|   main()  | ||||
96  exps-cnn/train_base.py  Normal file
							| @@ -0,0 +1,96 @@ | ||||
| import os, sys, time, glob, random, argparse | ||||
| import numpy as np | ||||
| from copy import deepcopy | ||||
| import torch | ||||
| import torch.nn as nn | ||||
| import torch.nn.functional as F | ||||
| import torchvision.datasets as dset | ||||
| import torch.backends.cudnn as cudnn | ||||
| import torchvision.transforms as transforms | ||||
| from pathlib import Path | ||||
| lib_dir = (Path(__file__).parent / '..' / 'lib').resolve() | ||||
| if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) | ||||
| from utils import AverageMeter, time_string, convert_secs2time | ||||
| from utils import print_log, obtain_accuracy | ||||
| from utils import Cutout, count_parameters_in_MB | ||||
| from nas import DARTS_V1, DARTS_V2, NASNet, PNASNet, AmoebaNet, ENASNet | ||||
| from nas import DMS_V1, DMS_F1, GDAS_CC | ||||
| from meta_nas import META_V1, META_V2 | ||||
| from train_utils import main_procedure | ||||
| from train_utils_imagenet import main_procedure_imagenet | ||||
| from scheduler import load_config | ||||
|  | ||||
| models = {'DARTS_V1': DARTS_V1, | ||||
|           'DARTS_V2': DARTS_V2, | ||||
|           'NASNet'  : NASNet, | ||||
|           'PNASNet' : PNASNet, | ||||
|           'ENASNet' : ENASNet, | ||||
|           'DMS_V1'  : DMS_V1, | ||||
|           'DMS_F1'  : DMS_F1, | ||||
|           'GDAS_CC' : GDAS_CC, | ||||
|           'META_V1' : META_V1, | ||||
|           'META_V2' : META_V2, | ||||
|           'AmoebaNet' : AmoebaNet} | ||||
|  | ||||
|  | ||||
| parser = argparse.ArgumentParser("cifar") | ||||
| parser.add_argument('--data_path',         type=str,   help='Path to dataset') | ||||
| parser.add_argument('--dataset',           type=str,   choices=['imagenet', 'cifar10', 'cifar100'], help='Choose between Cifar10/100 and ImageNet.') | ||||
| parser.add_argument('--arch',              type=str,   choices=models.keys(), help='the searched model.') | ||||
| #  | ||||
| parser.add_argument('--grad_clip',      type=float, help='gradient clipping') | ||||
| parser.add_argument('--model_config',   type=str  , help='the model configuration') | ||||
| parser.add_argument('--init_channels',  type=int  , help='the initial number of channels') | ||||
| parser.add_argument('--layers',         type=int  , help='the number of layers.') | ||||
|  | ||||
| # log | ||||
| parser.add_argument('--workers',       type=int, default=2, help='number of data loading workers (default: 2)') | ||||
| parser.add_argument('--save_path',     type=str, help='Folder to save checkpoints and log.') | ||||
| parser.add_argument('--print_freq',    type=int, help='print frequency (default: 200)') | ||||
| parser.add_argument('--manualSeed',    type=int, help='manual seed') | ||||
| args = parser.parse_args() | ||||
|  | ||||
| assert torch.cuda.is_available(), 'torch.cuda is not available' | ||||
|  | ||||
| if args.manualSeed is None: | ||||
|   args.manualSeed = random.randint(1, 10000) | ||||
| random.seed(args.manualSeed) | ||||
| cudnn.benchmark = True | ||||
| cudnn.enabled   = True | ||||
| torch.manual_seed(args.manualSeed) | ||||
| torch.cuda.manual_seed_all(args.manualSeed) | ||||
|  | ||||
|  | ||||
| def main(): | ||||
|  | ||||
|   # Init logger | ||||
|   args.save_path = os.path.join(args.save_path, 'seed-{:}'.format(args.manualSeed)) | ||||
|   if not os.path.isdir(args.save_path): | ||||
|     os.makedirs(args.save_path) | ||||
|   log = open(os.path.join(args.save_path, 'log-seed-{:}.txt'.format(args.manualSeed)), 'w') | ||||
|   print_log('save path : {}'.format(args.save_path), log) | ||||
|   state = {k: v for k, v in args._get_kwargs()} | ||||
|   print_log(state, log) | ||||
|   print_log("Random Seed: {}".format(args.manualSeed), log) | ||||
|   print_log("Python version : {}".format(sys.version.replace('\n', ' ')), log) | ||||
|   print_log("Torch  version : {}".format(torch.__version__), log) | ||||
|   print_log("CUDA   version : {}".format(torch.version.cuda), log) | ||||
|   print_log("cuDNN  version : {}".format(cudnn.version()), log) | ||||
|   print_log("Num of GPUs    : {}".format(torch.cuda.device_count()), log) | ||||
|   args.dataset = args.dataset.lower() | ||||
|  | ||||
|   config = load_config(args.model_config) | ||||
|   genotype = models[args.arch] | ||||
|   print_log('configuration : {:}'.format(config), log) | ||||
|   print_log('genotype      : {:}'.format(genotype), log) | ||||
|   # clear GPU cache | ||||
|   torch.cuda.empty_cache() | ||||
|   if args.dataset == 'imagenet': | ||||
|     main_procedure_imagenet(config, args.data_path, args, genotype, args.init_channels, args.layers, log) | ||||
|   else: | ||||
|     main_procedure(config, args.dataset, args.data_path, args, genotype, args.init_channels, args.layers, log) | ||||
|   log.close() | ||||
|  | ||||
|  | ||||
| if __name__ == '__main__': | ||||
|   main()  | ||||
312  exps-cnn/train_search.py  Normal file
							| @@ -0,0 +1,312 @@ | ||||
| import os, sys, time, glob, random, argparse | ||||
| import numpy as np | ||||
| from copy import deepcopy | ||||
| import torch | ||||
| import torch.nn.functional as F | ||||
| import torchvision.datasets as dset | ||||
| import torch.backends.cudnn as cudnn | ||||
| import torchvision.transforms as transforms | ||||
| from pathlib import Path | ||||
| lib_dir = (Path(__file__).parent / '..' / 'lib').resolve() | ||||
| if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) | ||||
| from utils import AverageMeter, time_string, convert_secs2time | ||||
| from utils import print_log, obtain_accuracy | ||||
| from utils import Cutout, count_parameters_in_MB | ||||
| from datasets import TieredImageNet | ||||
| from nas import return_alphas_str, Network, NetworkV1, NetworkF1 | ||||
| from train_utils import main_procedure | ||||
| from scheduler import load_config | ||||
|  | ||||
| Networks = {'base': Network, 'share': NetworkV1, 'fix': NetworkF1} | ||||
|  | ||||
| parser = argparse.ArgumentParser("CNN") | ||||
| parser.add_argument('--data_path',         type=str,   help='Path to dataset') | ||||
| parser.add_argument('--dataset',           type=str,   choices=['cifar10', 'cifar100', 'tiered'], help='Choose between Cifar10/100 and TieredImageNet.') | ||||
| parser.add_argument('--arch',              type=str,   choices=Networks.keys(), help='Choose networks.') | ||||
| parser.add_argument('--batch_size',        type=int,   help='the batch size') | ||||
| parser.add_argument('--learning_rate_max', type=float, help='initial learning rate') | ||||
| parser.add_argument('--learning_rate_min', type=float, help='minimum learning rate') | ||||
| parser.add_argument('--momentum',          type=float, help='momentum') | ||||
| parser.add_argument('--weight_decay',      type=float, help='weight decay') | ||||
| parser.add_argument('--epochs',            type=int,   help='num of training epochs') | ||||
| # architecture learning rate | ||||
| parser.add_argument('--arch_learning_rate', type=float, default=3e-4, help='learning rate for arch encoding') | ||||
| parser.add_argument('--arch_weight_decay',  type=float, default=1e-3, help='weight decay for arch encoding') | ||||
| # | ||||
| parser.add_argument('--init_channels',      type=int, help='num of init channels') | ||||
| parser.add_argument('--layers',             type=int, help='total number of layers') | ||||
| #  | ||||
| parser.add_argument('--cutout',         type=int,   help='cutout length, negative means no cutout') | ||||
| parser.add_argument('--grad_clip',      type=float, help='gradient clipping') | ||||
| parser.add_argument('--model_config',   type=str  , help='the model configuration') | ||||
|  | ||||
| # resume | ||||
| parser.add_argument('--resume',         type=str  , help='the resume path') | ||||
| parser.add_argument('--only_base',action='store_true', default=False, help='only train the searched model') | ||||
| # split data | ||||
| parser.add_argument('--validate', action='store_true', default=False, help='split train-data into train/val or not') | ||||
| parser.add_argument('--train_portion', type=float, default=0.5, help='portion of training data') | ||||
| # log | ||||
| parser.add_argument('--workers',       type=int, default=2, help='number of data loading workers (default: 2)') | ||||
| parser.add_argument('--save_path',     type=str, help='Folder to save checkpoints and log.') | ||||
| parser.add_argument('--print_freq',    type=int, help='print frequency (default: 200)') | ||||
| parser.add_argument('--manualSeed',    type=int, help='manual seed') | ||||
| args = parser.parse_args() | ||||
|  | ||||
| assert torch.cuda.is_available(), 'torch.cuda is not available' | ||||
|  | ||||
| if args.manualSeed is None: | ||||
|   args.manualSeed = random.randint(1, 10000) | ||||
| random.seed(args.manualSeed) | ||||
| cudnn.benchmark = True | ||||
| cudnn.enabled   = True | ||||
| torch.manual_seed(args.manualSeed) | ||||
| torch.cuda.manual_seed_all(args.manualSeed) | ||||
|  | ||||
|  | ||||
| def main(): | ||||
|  | ||||
|   # Init logger | ||||
|   args.save_path = os.path.join(args.save_path, 'seed-{:}'.format(args.manualSeed)) | ||||
|   if not os.path.isdir(args.save_path): | ||||
|     os.makedirs(args.save_path) | ||||
|   log = open(os.path.join(args.save_path, 'log-seed-{:}.txt'.format(args.manualSeed)), 'w') | ||||
|   print_log('save path : {}'.format(args.save_path), log) | ||||
|   state = {k: v for k, v in args._get_kwargs()} | ||||
|   print_log(state, log) | ||||
|   print_log("Random Seed: {}".format(args.manualSeed), log) | ||||
|   print_log("Python version : {}".format(sys.version.replace('\n', ' ')), log) | ||||
|   print_log("Torch  version : {}".format(torch.__version__), log) | ||||
|   print_log("CUDA   version : {}".format(torch.version.cuda), log) | ||||
|   print_log("cuDNN  version : {}".format(cudnn.version()), log) | ||||
|   print_log("Num of GPUs    : {}".format(torch.cuda.device_count()), log) | ||||
|   args.dataset = args.dataset.lower() | ||||
|  | ||||
|   # Mean + Std | ||||
|   if args.dataset == 'cifar10': | ||||
|     mean = [x / 255 for x in [125.3, 123.0, 113.9]] | ||||
|     std = [x / 255 for x in [63.0, 62.1, 66.7]] | ||||
|   elif args.dataset == 'cifar100': | ||||
|     mean = [x / 255 for x in [129.3, 124.1, 112.4]] | ||||
|     std = [x / 255 for x in [68.2, 65.4, 70.4]] | ||||
|   elif args.dataset == 'tiered': | ||||
|     mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225] | ||||
|   else: | ||||
|     raise TypeError("Unknown dataset : {:}".format(args.dataset)) | ||||
|   # Data Augmentation | ||||
|   if args.dataset == 'cifar10' or args.dataset == 'cifar100': | ||||
|     lists = [transforms.RandomHorizontalFlip(), transforms.RandomCrop(32, padding=4), transforms.ToTensor(), | ||||
|              transforms.Normalize(mean, std)] | ||||
|     if args.cutout > 0 : lists += [Cutout(args.cutout)] | ||||
|     train_transform = transforms.Compose(lists) | ||||
|     test_transform  = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean, std)]) | ||||
|   elif args.dataset == 'tiered': | ||||
|     lists = [transforms.RandomHorizontalFlip(), transforms.RandomCrop(80, padding=4), transforms.ToTensor(), transforms.Normalize(mean, std)] | ||||
|     if args.cutout > 0 : lists += [Cutout(args.cutout)] | ||||
|     train_transform = transforms.Compose(lists) | ||||
|     test_transform  = transforms.Compose([transforms.CenterCrop(80), transforms.ToTensor(), transforms.Normalize(mean, std)]) | ||||
|   else: | ||||
|     raise TypeError("Unknown dataset : {:}".format(args.dataset)) | ||||
|   # Datasets | ||||
|   if args.dataset == 'cifar10': | ||||
|     train_data = dset.CIFAR10(args.data_path, train= True, transform=train_transform, download=True) | ||||
|     test_data  = dset.CIFAR10(args.data_path, train=False, transform=test_transform , download=True) | ||||
|     num_classes, head = 10, 'cifar' | ||||
|   elif args.dataset == 'cifar100': | ||||
|     train_data = dset.CIFAR100(args.data_path, train= True, transform=train_transform, download=True) | ||||
|     test_data  = dset.CIFAR100(args.data_path, train=False, transform=test_transform , download=True) | ||||
|     num_classes, head = 100, 'cifar' | ||||
|   elif args.dataset == 'tiered': | ||||
|     train_data = TieredImageNet(args.data_path, 'train-val', train_transform) | ||||
|     test_data = None | ||||
|     num_classes, head = train_data.n_classes, 'imagenet' | ||||
|   else: | ||||
|     raise TypeError("Unknown dataset : {:}".format(args.dataset)) | ||||
|   # Data Loader | ||||
|   if args.validate: | ||||
|     indices = list(range(len(train_data))) | ||||
|     split   = int(args.train_portion * len(indices)) | ||||
|     random.shuffle(indices) | ||||
|     train_loader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size, | ||||
|                       sampler=torch.utils.data.sampler.SubsetRandomSampler(indices[:split]), | ||||
|                       pin_memory=True, num_workers=args.workers) | ||||
|     test_loader  = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size, | ||||
|                       sampler=torch.utils.data.sampler.SubsetRandomSampler(indices[split:]), | ||||
|                       pin_memory=True, num_workers=args.workers) | ||||
|   else: | ||||
|     train_loader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size, shuffle=True, num_workers=args.workers, pin_memory=True) | ||||
|     test_loader  = torch.utils.data.DataLoader(test_data,  batch_size=args.batch_size, shuffle=True, num_workers=args.workers, pin_memory=True) | ||||
|  | ||||
|   # network and criterion | ||||
|   criterion = torch.nn.CrossEntropyLoss().cuda() | ||||
|   basemodel = Networks[args.arch](args.init_channels, num_classes, args.layers, head=head) | ||||
|   model     = torch.nn.DataParallel(basemodel).cuda() | ||||
|   print_log("Network : {:}".format(model), log) | ||||
|   print_log("Parameter size = {:.3f} MB".format(count_parameters_in_MB(basemodel.base_parameters())), log) | ||||
|   print_log("Train-transformation : {:}\nTest--transformation : {:}\nClass number : {:}".format(train_transform, test_transform, num_classes), log) | ||||
|  | ||||
|   # optimizer and LR-scheduler | ||||
|   base_optimizer = torch.optim.SGD (basemodel.base_parameters(), args.learning_rate_max, momentum=args.momentum, weight_decay=args.weight_decay) | ||||
|   base_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(base_optimizer, float(args.epochs), eta_min=args.learning_rate_min) | ||||
|   arch_optimizer = torch.optim.Adam(basemodel.arch_parameters(), lr=args.arch_learning_rate, betas=(0.5, 0.999), weight_decay=args.arch_weight_decay) | ||||
|  | ||||
|   # snapshot | ||||
|   checkpoint_path = os.path.join(args.save_path, 'checkpoint-search.pth') | ||||
|   if args.resume is not None and os.path.isfile(args.resume): | ||||
|     checkpoint = torch.load(args.resume) | ||||
|     start_epoch = checkpoint['epoch'] | ||||
|     basemodel.load_state_dict( checkpoint['state_dict'] ) | ||||
|     base_optimizer.load_state_dict( checkpoint['base_optimizer'] ) | ||||
|     arch_optimizer.load_state_dict( checkpoint['arch_optimizer'] ) | ||||
|     base_scheduler.load_state_dict( checkpoint['base_scheduler'] ) | ||||
|     genotypes   = checkpoint['genotypes'] | ||||
|     print_log('Load resume from {:} with start-epoch = {:}'.format(args.resume, start_epoch), log) | ||||
|   elif os.path.isfile(checkpoint_path): | ||||
|     checkpoint = torch.load(checkpoint_path) | ||||
|     start_epoch = checkpoint['epoch'] | ||||
|     basemodel.load_state_dict( checkpoint['state_dict'] ) | ||||
|     base_optimizer.load_state_dict( checkpoint['base_optimizer'] ) | ||||
|     arch_optimizer.load_state_dict( checkpoint['arch_optimizer'] ) | ||||
|     base_scheduler.load_state_dict( checkpoint['base_scheduler'] ) | ||||
|     genotypes   = checkpoint['genotypes'] | ||||
|     print_log('Load checkpoint from {:} with start-epoch = {:}'.format(checkpoint_path, start_epoch), log) | ||||
|   else: | ||||
|     start_epoch, genotypes = 0, {} | ||||
|     print_log('Train model-search from scratch.', log) | ||||
|  | ||||
|   config = load_config(args.model_config) | ||||
|  | ||||
|   if args.only_base: | ||||
|     print_log('---- Only Train the Searched Model ----', log) | ||||
|     main_procedure(config, args.dataset, args.data_path, args, basemodel.genotype(), 36, 20, log) | ||||
|     return | ||||
|  | ||||
|   # Main loop | ||||
|   start_time, epoch_time, total_train_time = time.time(), AverageMeter(), 0 | ||||
|   for epoch in range(start_epoch, args.epochs): | ||||
|     base_scheduler.step() | ||||
|  | ||||
|     need_time = convert_secs2time(epoch_time.val * (args.epochs-epoch), True) | ||||
|     print_log('\n==>>{:s} [Epoch={:03d}/{:03d}] {:s} [LR={:6.4f} ~ {:6.4f}] [Batch={:d}]'.format(time_string(), epoch, args.epochs, need_time, min(base_scheduler.get_lr()), max(base_scheduler.get_lr()), args.batch_size), log) | ||||
|  | ||||
|     genotype = basemodel.genotype() | ||||
|     print_log('genotype = {:}'.format(genotype), log) | ||||
|  | ||||
|     print_log('{:03d}/{:03d} alphas :\n{:}'.format(epoch, args.epochs, return_alphas_str(basemodel)), log) | ||||
|  | ||||
|     # training | ||||
|     train_acc1, train_acc5, train_obj, train_time \ | ||||
|                                       = train(train_loader, test_loader, model, criterion, base_optimizer, arch_optimizer, epoch, log) | ||||
|     total_train_time += train_time | ||||
|     # validation | ||||
|     valid_acc1, valid_acc5, valid_obj = infer(test_loader, model, criterion, epoch, log) | ||||
|     print_log('Base-Search : {:03d}/{:03d} : Train-Acc={:.3f}, Test-Acc={:.3f}'.format(epoch, args.epochs, train_acc1, valid_acc1), log) | ||||
|     # save genotype | ||||
|     genotypes[epoch] = basemodel.genotype() | ||||
|     # save checkpoint | ||||
|     torch.save({'epoch' : epoch + 1, | ||||
|                 'args'  : deepcopy(args), | ||||
|                 'state_dict': basemodel.state_dict(), | ||||
|                 'genotypes' : genotypes, | ||||
|                 'base_optimizer' : base_optimizer.state_dict(), | ||||
|                 'arch_optimizer' : arch_optimizer.state_dict(), | ||||
|                 'base_scheduler' : base_scheduler.state_dict()}, | ||||
|                 checkpoint_path) | ||||
|     print_log('----> Save into {:}'.format(checkpoint_path), log) | ||||
|  | ||||
|     # measure elapsed time | ||||
|     epoch_time.update(time.time() - start_time) | ||||
|     start_time = time.time() | ||||
|  | ||||
|   print_log('Finish with training time = {:}'.format( convert_secs2time(total_train_time, True) ), log) | ||||
|  | ||||
|   # clear GPU cache | ||||
|   torch.cuda.empty_cache() | ||||
|   main_procedure(config, 'cifar10', os.environ['TORCH_HOME'] + '/cifar.python', args, basemodel.genotype(), 36, 20, log) | ||||
|   log.close() | ||||
|  | ||||
|  | ||||
| def train(train_queue, valid_queue, model, criterion, base_optimizer, arch_optimizer, epoch, log): | ||||
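|   # One search epoch with alternating first-order updates: the architecture parameters are | ||||
|   # optimized with cross-entropy on a batch from the validation queue, then the network weights | ||||
|   # are optimized (with gradient clipping) on the training batch. | ||||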
|   data_time, batch_time = AverageMeter(), AverageMeter() | ||||
|   objs, top1, top5 = AverageMeter(), AverageMeter(), AverageMeter() | ||||
|   model.train() | ||||
|  | ||||
|   valid_iter = iter(valid_queue) | ||||
|   end = time.time() | ||||
|   for step, (inputs, targets) in enumerate(train_queue): | ||||
|     batch, C, H, W = inputs.size() | ||||
|  | ||||
|     #inputs, targets = inputs.cuda(), targets.cuda(non_blocking=True) | ||||
|     targets = targets.cuda(non_blocking=True) | ||||
|     data_time.update(time.time() - end) | ||||
|  | ||||
|     # get a random minibatch from the search queue with replacement | ||||
|     try: | ||||
|       input_search, target_search = next(valid_iter) | ||||
|     except StopIteration: | ||||
|       valid_iter = iter(valid_queue) | ||||
|       input_search, target_search = next(valid_iter) | ||||
|      | ||||
|     target_search = target_search.cuda(non_blocking=True) | ||||
|  | ||||
|     # update the architecture | ||||
|     arch_optimizer.zero_grad() | ||||
|     output_search = model(input_search) | ||||
|     arch_loss = criterion(output_search, target_search) | ||||
|     arch_loss.backward() | ||||
|     arch_optimizer.step() | ||||
|  | ||||
|     # update the parameters | ||||
|     base_optimizer.zero_grad() | ||||
|     logits = model(inputs) | ||||
|     loss = criterion(logits, targets) | ||||
|  | ||||
|     loss.backward() | ||||
|     torch.nn.utils.clip_grad_norm_(model.module.base_parameters(), args.grad_clip) | ||||
|     base_optimizer.step() | ||||
|  | ||||
|     prec1, prec5 = obtain_accuracy(logits.data, targets.data, topk=(1, 5)) | ||||
|     objs.update(loss.item() , batch) | ||||
|     top1.update(prec1.item(), batch) | ||||
|     top5.update(prec5.item(), batch) | ||||
|  | ||||
|     # measure elapsed time | ||||
|     batch_time.update(time.time() - end) | ||||
|     end = time.time() | ||||
|  | ||||
|     if step % args.print_freq == 0 or (step+1) == len(train_queue): | ||||
|       Sstr = ' TRAIN-SEARCH ' + time_string() + ' Epoch: [{:03d}][{:03d}/{:03d}]'.format(epoch, step, len(train_queue)) | ||||
|       Tstr = 'Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})'.format(batch_time=batch_time, data_time=data_time) | ||||
|       Lstr = 'Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})'.format(loss=objs, top1=top1, top5=top5) | ||||
|       print_log(Sstr + ' ' + Tstr + ' ' + Lstr, log) | ||||
|  | ||||
|   return top1.avg, top5.avg, objs.avg, batch_time.sum | ||||
|  | ||||
|  | ||||
| def infer(valid_queue, model, criterion, epoch, log): | ||||
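|   # Evaluate the search model on the validation queue without gradients, reporting top-1/top-5 accuracy and loss. | ||||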
|   objs, top1, top5 = AverageMeter(), AverageMeter(), AverageMeter() | ||||
|    | ||||
|   model.eval() | ||||
|   with torch.no_grad(): | ||||
|     for step, (inputs, targets) in enumerate(valid_queue): | ||||
|       batch, C, H, W = inputs.size() | ||||
|       targets = targets.cuda(non_blocking=True) | ||||
|  | ||||
|       logits = model(inputs) | ||||
|       loss = criterion(logits, targets) | ||||
|  | ||||
|       prec1, prec5 = obtain_accuracy(logits.data, targets.data, topk=(1, 5)) | ||||
|       objs.update(loss.item() , batch) | ||||
|       top1.update(prec1.item(), batch) | ||||
|       top5.update(prec5.item(), batch) | ||||
|  | ||||
|       if step % args.print_freq == 0 or (step+1) == len(valid_queue): | ||||
|         Sstr = ' VALID-SEARCH ' + time_string() + ' Epoch: [{:03d}][{:03d}/{:03d}]'.format(epoch, step, len(valid_queue)) | ||||
|         Lstr = 'Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})'.format(loss=objs, top1=top1, top5=top5) | ||||
|         print_log(Sstr + ' ' + Lstr, log) | ||||
|  | ||||
|   return top1.avg, top5.avg, objs.avg | ||||
|  | ||||
|  | ||||
| if __name__ == '__main__': | ||||
|   main()  | ||||
184  exps-cnn/train_utils.py  Normal file
							| @@ -0,0 +1,184 @@ | ||||
| import os, sys, time | ||||
| from copy import deepcopy | ||||
| import torch | ||||
| import torchvision.datasets as dset | ||||
| import torchvision.transforms as transforms | ||||
|  | ||||
|  | ||||
| from utils import print_log, obtain_accuracy, AverageMeter | ||||
| from utils import time_string, convert_secs2time | ||||
| from utils import count_parameters_in_MB | ||||
| from utils import Cutout | ||||
| from nas import NetworkCIFAR as Network | ||||
|  | ||||
| def obtain_best(accuracies): | ||||
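|   # Return the best recorded (top-1, top-5) accuracy pair, i.e. the lexicographic maximum over epochs. | ||||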
|   if len(accuracies) == 0: return (0, 0) | ||||
|   tops = [value for key, value in accuracies.items()] | ||||
|   s2b = sorted( tops ) | ||||
|   return s2b[-1] | ||||
|  | ||||
| def main_procedure(config, dataset, data_path, args, genotype, init_channels, layers, log): | ||||
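|   # Train a fixed genotype from scratch on CIFAR-10/100: build the data pipeline, construct | ||||
|   # NetworkCIFAR from the genotype, and run the full train/eval loop with a cosine LR schedule, | ||||
|   # drop-path, an optional auxiliary loss, and per-epoch checkpointing. | ||||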
|    | ||||
|   # Mean + Std | ||||
|   if dataset == 'cifar10': | ||||
|     mean = [x / 255 for x in [125.3, 123.0, 113.9]] | ||||
|     std = [x / 255 for x in [63.0, 62.1, 66.7]] | ||||
|   elif dataset == 'cifar100': | ||||
|     mean = [x / 255 for x in [129.3, 124.1, 112.4]] | ||||
|     std = [x / 255 for x in [68.2, 65.4, 70.4]] | ||||
|   else: | ||||
|     raise TypeError("Unknown dataset : {:}".format(dataset)) | ||||
|   # Dataset Transformation | ||||
|   if dataset == 'cifar10' or dataset == 'cifar100': | ||||
|     lists = [transforms.RandomHorizontalFlip(), transforms.RandomCrop(32, padding=4), transforms.ToTensor(), | ||||
|              transforms.Normalize(mean, std)] | ||||
|     if config.cutout > 0 : lists += [Cutout(config.cutout)] | ||||
|     train_transform = transforms.Compose(lists) | ||||
|     test_transform  = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean, std)]) | ||||
|   else: | ||||
|     raise TypeError("Unknown dataset : {:}".format(dataset)) | ||||
|   # Dataset Definition | ||||
|   if dataset == 'cifar10': | ||||
|     train_data = dset.CIFAR10(data_path, train= True, transform=train_transform, download=True) | ||||
|     test_data  = dset.CIFAR10(data_path, train=False, transform=test_transform , download=True) | ||||
|     class_num  = 10 | ||||
|   elif dataset == 'cifar100': | ||||
|     train_data = dset.CIFAR100(data_path, train= True, transform=train_transform, download=True) | ||||
|     test_data  = dset.CIFAR100(data_path, train=False, transform=test_transform , download=True) | ||||
|     class_num  = 100 | ||||
|   else: | ||||
|     raise TypeError("Unknown dataset : {:}".format(dataset)) | ||||
|  | ||||
|  | ||||
|   print_log('-------------------------------------- main-procedure', log) | ||||
|   print_log('config        : {:}'.format(config), log) | ||||
|   print_log('genotype      : {:}'.format(genotype), log) | ||||
|   print_log('init_channels : {:}'.format(init_channels), log) | ||||
|   print_log('layers        : {:}'.format(layers), log) | ||||
|   print_log('class_num     : {:}'.format(class_num), log) | ||||
|   basemodel = Network(init_channels, class_num, layers, config.auxiliary, genotype) | ||||
|   model     = torch.nn.DataParallel(basemodel).cuda() | ||||
|  | ||||
|   total_param, aux_param = count_parameters_in_MB(basemodel), count_parameters_in_MB(basemodel.auxiliary_param()) | ||||
|   print_log('Network =>\n{:}'.format(basemodel), log) | ||||
|   print_log('Parameters : {:} - {:} = {:.3f} MB'.format(total_param, aux_param, total_param - aux_param), log) | ||||
|   print_log('config        : {:}'.format(config), log) | ||||
|   print_log('genotype      : {:}'.format(genotype), log) | ||||
|   print_log('args          : {:}'.format(args), log) | ||||
|   print_log('Train-Dataset : {:}'.format(train_data), log) | ||||
|   print_log('Train-Trans   : {:}'.format(train_transform), log) | ||||
|   print_log('Test--Dataset : {:}'.format(test_data ), log) | ||||
|   print_log('Test--Trans   : {:}'.format(test_transform ), log) | ||||
|  | ||||
|  | ||||
|   train_loader = torch.utils.data.DataLoader(train_data, batch_size=config.batch_size, shuffle=True, | ||||
|                          num_workers=args.workers, pin_memory=True) | ||||
|   test_loader  = torch.utils.data.DataLoader(test_data , batch_size=config.batch_size, shuffle=False, | ||||
|                          num_workers=args.workers, pin_memory=True) | ||||
|  | ||||
|   criterion = torch.nn.CrossEntropyLoss().cuda() | ||||
|    | ||||
|   optimizer = torch.optim.SGD(model.parameters(), config.LR, momentum=config.momentum, weight_decay=config.decay) | ||||
|   #optimizer = torch.optim.SGD(model.parameters(), config.LR, momentum=config.momentum, weight_decay=config.decay, nesterov=True) | ||||
|   if config.type == 'cosine': | ||||
|     scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, float(config.epochs)) | ||||
|   else: | ||||
|     raise ValueError('Cannot find the scheduler type : {:}'.format(config.type)) | ||||
|  | ||||
|   checkpoint_path = os.path.join(args.save_path, 'checkpoint-{:}-model.pth'.format(dataset)) | ||||
|   if os.path.isfile(checkpoint_path): | ||||
|     checkpoint = torch.load( checkpoint_path ) | ||||
|  | ||||
|     start_epoch = checkpoint['epoch'] | ||||
|     basemodel.load_state_dict(checkpoint['state_dict']) | ||||
|     optimizer.load_state_dict(checkpoint['optimizer']) | ||||
|     scheduler.load_state_dict(checkpoint['scheduler']) | ||||
|     accuracies  = checkpoint['accuracies'] | ||||
|     print_log('Load checkpoint from {:} with start-epoch = {:}'.format(checkpoint_path, start_epoch), log) | ||||
|   else: | ||||
|     start_epoch, accuracies = 0, {} | ||||
|     print_log('Train model from scratch without pre-trained model or snapshot', log) | ||||
|  | ||||
|  | ||||
|   # Main loop | ||||
|   start_time, epoch_time = time.time(), AverageMeter() | ||||
|   for epoch in range(start_epoch, config.epochs): | ||||
|     scheduler.step() | ||||
|  | ||||
|     need_time = convert_secs2time(epoch_time.val * (config.epochs-epoch), True) | ||||
|     print_log("\n==>>{:s} [Epoch={:03d}/{:03d}] {:s} LR={:6.4f} ~ {:6.4f}, Batch={:d}".format(time_string(), epoch, config.epochs, need_time, min(scheduler.get_lr()), max(scheduler.get_lr()), config.batch_size), log) | ||||
|  | ||||
|     basemodel.update_drop_path(config.drop_path_prob * epoch / config.epochs) | ||||
|  | ||||
|     train_acc1, train_acc5, train_los = _train(train_loader, model, criterion, optimizer, 'train', epoch, config, args.print_freq, log) | ||||
|  | ||||
|     with torch.no_grad(): | ||||
|       valid_acc1, valid_acc5, valid_los = _train(test_loader, model, criterion, optimizer, 'test', epoch, config, args.print_freq, log) | ||||
|     accuracies[epoch] = (valid_acc1, valid_acc5) | ||||
|  | ||||
|     torch.save({'epoch'     : epoch + 1, | ||||
|                 'args'      : deepcopy(args), | ||||
|                 'state_dict': basemodel.state_dict(), | ||||
|                 'optimizer' : optimizer.state_dict(), | ||||
|                 'scheduler' : scheduler.state_dict(), | ||||
|                 'accuracies': accuracies}, | ||||
|                 checkpoint_path) | ||||
|     best_acc = obtain_best( accuracies ) | ||||
|     print_log('----> Best Accuracy : Acc@1={:.2f}, Acc@5={:.2f}, Error@1={:.2f}, Error@5={:.2f}'.format(best_acc[0], best_acc[1], 100-best_acc[0], 100-best_acc[1]), log) | ||||
|     print_log('----> Save into {:}'.format(checkpoint_path), log) | ||||
|  | ||||
|     # measure elapsed time | ||||
|     epoch_time.update(time.time() - start_time) | ||||
|     start_time = time.time() | ||||
|  | ||||
|  | ||||
| def _train(xloader, model, criterion, optimizer, mode, epoch, config, print_freq, log): | ||||
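|   # Shared train/test loop: in 'train' mode the weights are updated (with an optional auxiliary | ||||
|   # loss and gradient clipping); in 'test' mode the model is only evaluated. | ||||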
|   data_time, batch_time, losses, top1, top5 = AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter() | ||||
|   if mode == 'train': | ||||
|     model.train() | ||||
|   elif mode == 'test': | ||||
|     model.eval() | ||||
|   else: raise ValueError("The mode is not right : {:}".format(mode)) | ||||
|    | ||||
|   end = time.time() | ||||
|   for i, (inputs, targets) in enumerate(xloader): | ||||
|     # measure data loading time | ||||
|     data_time.update(time.time() - end) | ||||
|     # calculate prediction and loss | ||||
|     targets = targets.cuda(non_blocking=True) | ||||
|  | ||||
|     if mode == 'train': optimizer.zero_grad() | ||||
|  | ||||
|     if config.auxiliary and model.training: | ||||
|       logits, logits_aux = model(inputs) | ||||
|     else: | ||||
|       logits = model(inputs) | ||||
|  | ||||
|     loss = criterion(logits, targets) | ||||
|     if config.auxiliary and model.training: | ||||
|       loss_aux = criterion(logits_aux, targets) | ||||
|       loss += config.auxiliary_weight * loss_aux | ||||
|      | ||||
|     if mode == 'train': | ||||
|       loss.backward() | ||||
|       if config.grad_clip > 0: | ||||
|         torch.nn.utils.clip_grad_norm_(model.parameters(), config.grad_clip) | ||||
|       optimizer.step() | ||||
|     # record | ||||
|     prec1, prec5 = obtain_accuracy(logits.data, targets.data, topk=(1, 5)) | ||||
|     losses.update(loss.item(),  inputs.size(0)) | ||||
|     top1.update  (prec1.item(), inputs.size(0)) | ||||
|     top5.update  (prec5.item(), inputs.size(0)) | ||||
|  | ||||
|     # measure elapsed time | ||||
|     batch_time.update(time.time() - end) | ||||
|     end = time.time() | ||||
|  | ||||
|     if i % print_freq == 0 or (i+1) == len(xloader): | ||||
|       Sstr = ' {:5s}'.format(mode) + time_string() + ' Epoch: [{:03d}][{:03d}/{:03d}]'.format(epoch, i, len(xloader)) | ||||
|       Tstr = 'Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})'.format(batch_time=batch_time, data_time=data_time) | ||||
|       Lstr = 'Loss {loss.val:.3f} ({loss.avg:.3f})  Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})'.format(loss=losses, top1=top1, top5=top5) | ||||
|       print_log(Sstr + ' ' + Tstr + ' ' + Lstr, log) | ||||
|  | ||||
|   print_log ('{TIME:} **{mode:}** Prec@1 {top1.avg:.2f} Prec@5 {top5.avg:.2f} Error@1 {error1:.2f} Error@5 {error5:.2f} Loss:{loss:.3f}'.format(TIME=time_string(), mode=mode, top1=top1, top5=top5, error1=100-top1.avg, error5=100-top5.avg, loss=losses.avg), log) | ||||
|   return top1.avg, top5.avg, losses.avg | ||||
207  exps-cnn/train_utils_imagenet.py  Normal file
							| @@ -0,0 +1,207 @@ | ||||
| import os, sys, time | ||||
| from copy import deepcopy | ||||
| import torch | ||||
| import torch.nn as nn | ||||
| import torchvision.datasets as dset | ||||
| import torchvision.transforms as transforms | ||||
|  | ||||
|  | ||||
| from utils import print_log, obtain_accuracy, AverageMeter | ||||
| from utils import time_string, convert_secs2time | ||||
| from utils import count_parameters_in_MB | ||||
| from utils import print_FLOPs | ||||
| from utils import Cutout | ||||
| from nas import NetworkImageNet as Network | ||||
|  | ||||
|  | ||||
| def obtain_best(accuracies): | ||||
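|   # Return the best recorded (top-1, top-5) accuracy pair, i.e. the lexicographic maximum over epochs. | ||||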
|   if len(accuracies) == 0: return (0, 0) | ||||
|   tops = [value for key, value in accuracies.items()] | ||||
|   s2b = sorted( tops ) | ||||
|   return s2b[-1] | ||||
|  | ||||
|  | ||||
| class CrossEntropyLabelSmooth(nn.Module): | ||||
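|   # Cross-entropy with label smoothing: the one-hot target is softened so that the true class | ||||
|   # receives probability (1 - epsilon) and every class receives an extra epsilon / num_classes. | ||||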
|  | ||||
|   def __init__(self, num_classes, epsilon): | ||||
|     super(CrossEntropyLabelSmooth, self).__init__() | ||||
|     self.num_classes = num_classes | ||||
|     self.epsilon = epsilon | ||||
|     self.logsoftmax = nn.LogSoftmax(dim=1) | ||||
|  | ||||
|   def forward(self, inputs, targets): | ||||
|     log_probs = self.logsoftmax(inputs) | ||||
|     targets = torch.zeros_like(log_probs).scatter_(1, targets.unsqueeze(1), 1) | ||||
|     targets = (1 - self.epsilon) * targets + self.epsilon / self.num_classes | ||||
|     loss = (-targets * log_probs).mean(0).sum() | ||||
|     return loss | ||||
|  | ||||
|  | ||||
| def main_procedure_imagenet(config, data_path, args, genotype, init_channels, layers, log): | ||||
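|   # Train a fixed genotype from scratch on ImageNet: standard ImageFolder pipelines, | ||||
|   # NetworkImageNet built from the genotype, a label-smoothed loss for training and plain | ||||
|   # cross-entropy for evaluation, with a cosine or step LR schedule and per-epoch checkpointing. | ||||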
|    | ||||
|   # training data and testing data | ||||
|   traindir = os.path.join(data_path, 'train') | ||||
|   validdir = os.path.join(data_path, 'val') | ||||
|   normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) | ||||
|   train_data = dset.ImageFolder( | ||||
|     traindir, | ||||
|     transforms.Compose([ | ||||
|       transforms.RandomResizedCrop(224), | ||||
|       transforms.RandomHorizontalFlip(), | ||||
|       transforms.ColorJitter( | ||||
|         brightness=0.4, | ||||
|         contrast=0.4, | ||||
|         saturation=0.4, | ||||
|         hue=0.2), | ||||
|       transforms.ToTensor(), | ||||
|       normalize, | ||||
|     ])) | ||||
|   valid_data = dset.ImageFolder( | ||||
|     validdir, | ||||
|     transforms.Compose([ | ||||
|       transforms.Resize(256), | ||||
|       transforms.CenterCrop(224), | ||||
|       transforms.ToTensor(), | ||||
|       normalize, | ||||
|     ])) | ||||
|  | ||||
|   train_queue = torch.utils.data.DataLoader( | ||||
|     train_data, batch_size=config.batch_size, shuffle= True, pin_memory=True, num_workers=args.workers) | ||||
|  | ||||
|   valid_queue = torch.utils.data.DataLoader( | ||||
|     valid_data, batch_size=config.batch_size, shuffle=False, pin_memory=True, num_workers=args.workers) | ||||
|  | ||||
|   class_num   = 1000 | ||||
|  | ||||
|  | ||||
|   print_log('-------------------------------------- main-procedure', log) | ||||
|   print_log('config        : {:}'.format(config), log) | ||||
|   print_log('genotype      : {:}'.format(genotype), log) | ||||
|   print_log('init_channels : {:}'.format(init_channels), log) | ||||
|   print_log('layers        : {:}'.format(layers), log) | ||||
|   print_log('class_num     : {:}'.format(class_num), log) | ||||
|   basemodel = Network(init_channels, class_num, layers, config.auxiliary, genotype) | ||||
|   model     = torch.nn.DataParallel(basemodel).cuda() | ||||
|  | ||||
|   total_param, aux_param = count_parameters_in_MB(basemodel), count_parameters_in_MB(basemodel.auxiliary_param()) | ||||
|   print_log('Network =>\n{:}'.format(basemodel), log) | ||||
|   #print_FLOPs(basemodel, (1,3,224,224), [print_log, log]) | ||||
|   print_log('Parameters : {:} - {:} = {:.3f} MB'.format(total_param, aux_param, total_param - aux_param), log) | ||||
|   print_log('config        : {:}'.format(config), log) | ||||
|   print_log('genotype      : {:}'.format(genotype), log) | ||||
|   print_log('Train-Dataset : {:}'.format(train_data), log) | ||||
|   print_log('Valid-Dataset : {:}'.format(valid_data), log) | ||||
|   print_log('Args          : {:}'.format(args), log) | ||||
|  | ||||
|  | ||||
|   criterion = torch.nn.CrossEntropyLoss().cuda() | ||||
|   criterion_smooth = CrossEntropyLabelSmooth(class_num, config.label_smooth).cuda() | ||||
|  | ||||
|  | ||||
|   optimizer = torch.optim.SGD(model.parameters(), config.LR, momentum=config.momentum, weight_decay=config.decay) | ||||
|   #optimizer = torch.optim.SGD(model.parameters(), config.LR, momentum=config.momentum, weight_decay=config.decay, nesterov=True) | ||||
|   if config.type == 'cosine': | ||||
|     scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, float(config.epochs)) | ||||
|   elif config.type == 'steplr': | ||||
|     scheduler = torch.optim.lr_scheduler.StepLR(optimizer, config.decay_period, gamma=config.gamma) | ||||
|   else: | ||||
|     raise ValueError('Cannot find the scheduler type : {:}'.format(config.type)) | ||||
|  | ||||
|  | ||||
|   checkpoint_path = os.path.join(args.save_path, 'checkpoint-imagenet-model.pth') | ||||
|   if os.path.isfile(checkpoint_path): | ||||
|     checkpoint = torch.load( checkpoint_path ) | ||||
|  | ||||
|     start_epoch = checkpoint['epoch'] | ||||
|     basemodel.load_state_dict(checkpoint['state_dict']) | ||||
|     optimizer.load_state_dict(checkpoint['optimizer']) | ||||
|     scheduler.load_state_dict(checkpoint['scheduler']) | ||||
|     accuracies  = checkpoint['accuracies'] | ||||
|     print_log('Load checkpoint from {:} with start-epoch = {:}'.format(checkpoint_path, start_epoch), log) | ||||
|   else: | ||||
|     start_epoch, accuracies = 0, {} | ||||
|     print_log('Train model from scratch without pre-trained model or snapshot', log) | ||||
|  | ||||
|  | ||||
|   # Main loop | ||||
|   start_time, epoch_time = time.time(), AverageMeter() | ||||
|   for epoch in range(start_epoch, config.epochs): | ||||
|     scheduler.step() | ||||
|  | ||||
|     need_time = convert_secs2time(epoch_time.val * (config.epochs-epoch), True) | ||||
|     print_log("\n==>>{:s} [Epoch={:03d}/{:03d}] {:s} LR={:6.4f} ~ {:6.4f}, Batch={:d}".format(time_string(), epoch, config.epochs, need_time, min(scheduler.get_lr()), max(scheduler.get_lr()), config.batch_size), log) | ||||
|  | ||||
|     basemodel.update_drop_path(config.drop_path_prob * epoch / config.epochs) | ||||
|  | ||||
|     train_acc1, train_acc5, train_los = _train(train_queue, model, criterion_smooth, optimizer, 'train', epoch, config, args.print_freq, log) | ||||
|  | ||||
|     with torch.no_grad(): | ||||
|       valid_acc1, valid_acc5, valid_los = _train(valid_queue, model, criterion,           None, 'test' , epoch, config, args.print_freq, log) | ||||
|     accuracies[epoch] = (valid_acc1, valid_acc5) | ||||
|  | ||||
|     torch.save({'epoch'     : epoch + 1, | ||||
|                 'args'      : deepcopy(args), | ||||
|                 'state_dict': basemodel.state_dict(), | ||||
|                 'optimizer' : optimizer.state_dict(), | ||||
|                 'scheduler' : scheduler.state_dict(), | ||||
|                 'accuracies': accuracies}, | ||||
|                 checkpoint_path) | ||||
|     best_acc = obtain_best( accuracies ) | ||||
|     print_log('----> Best Accuracy : Acc@1={:.2f}, Acc@5={:.2f}, Error@1={:.2f}, Error@5={:.2f}'.format(best_acc[0], best_acc[1], 100-best_acc[0], 100-best_acc[1]), log) | ||||
|     print_log('----> Save into {:}'.format(checkpoint_path), log) | ||||
|  | ||||
|     # measure elapsed time | ||||
|     epoch_time.update(time.time() - start_time) | ||||
|     start_time = time.time() | ||||
|  | ||||
|  | ||||
| def _train(xloader, model, criterion, optimizer, mode, epoch, config, print_freq, log): | ||||
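|   # Shared train/test loop: in 'train' mode the weights are updated (with an optional auxiliary | ||||
|   # loss and gradient clipping); in 'test' mode the model is only evaluated. | ||||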
|   data_time, batch_time, losses, top1, top5 = AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter() | ||||
|   if mode == 'train': | ||||
|     model.train() | ||||
|   elif mode == 'test': | ||||
|     model.eval() | ||||
|   else: raise ValueError("The mode is not right : {:}".format(mode)) | ||||
|    | ||||
|   end = time.time() | ||||
|   for i, (inputs, targets) in enumerate(xloader): | ||||
|     # measure data loading time | ||||
|     data_time.update(time.time() - end) | ||||
|     # calculate prediction and loss | ||||
|     targets = targets.cuda(non_blocking=True) | ||||
|  | ||||
|     if mode == 'train': optimizer.zero_grad() | ||||
|  | ||||
|     if config.auxiliary and model.training: | ||||
|       logits, logits_aux = model(inputs) | ||||
|     else: | ||||
|       logits = model(inputs) | ||||
|  | ||||
|     loss = criterion(logits, targets) | ||||
|     if config.auxiliary and model.training: | ||||
|       loss_aux = criterion(logits_aux, targets) | ||||
|       loss += config.auxiliary_weight * loss_aux | ||||
|      | ||||
|     if mode == 'train': | ||||
|       loss.backward() | ||||
|       if config.grad_clip > 0: | ||||
|         torch.nn.utils.clip_grad_norm_(model.parameters(), config.grad_clip) | ||||
|       optimizer.step() | ||||
|     # record | ||||
|     prec1, prec5 = obtain_accuracy(logits.data, targets.data, topk=(1, 5)) | ||||
|     losses.update(loss.item(),  inputs.size(0)) | ||||
|     top1.update  (prec1.item(), inputs.size(0)) | ||||
|     top5.update  (prec5.item(), inputs.size(0)) | ||||
|  | ||||
|     # measure elapsed time | ||||
|     batch_time.update(time.time() - end) | ||||
|     end = time.time() | ||||
|  | ||||
|     if i % print_freq == 0 or (i+1) == len(xloader): | ||||
|       Sstr = ' {:5s}'.format(mode) + time_string() + ' Epoch: [{:03d}][{:03d}/{:03d}]'.format(epoch, i, len(xloader)) | ||||
|       Tstr = 'Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})'.format(batch_time=batch_time, data_time=data_time) | ||||
|       Lstr = 'Loss {loss.val:.3f} ({loss.avg:.3f})  Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})'.format(loss=losses, top1=top1, top5=top5) | ||||
|       print_log(Sstr + ' ' + Tstr + ' ' + Lstr, log) | ||||
|  | ||||
|   print_log ('{TIME:} **{mode:}** Prec@1 {top1.avg:.2f} Prec@5 {top5.avg:.2f} Error@1 {error1:.2f} Error@5 {error5:.2f} Loss:{loss:.3f}'.format(TIME=time_string(), mode=mode, top1=top1, top5=top5, error1=100-top1.avg, error5=100-top5.avg, loss=losses.avg), log) | ||||
|   return top1.avg, top5.avg, losses.avg | ||||
69  exps-cnn/vis-arch.py  Normal file
							| @@ -0,0 +1,69 @@ | ||||
| import os, sys, time, glob, random, argparse | ||||
| import numpy as np | ||||
| from copy import deepcopy | ||||
| import torch | ||||
| from pathlib import Path | ||||
| lib_dir = (Path(__file__).parent / '..' / 'lib').resolve() | ||||
| if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) | ||||
| from graphviz import Digraph | ||||
|  | ||||
| parser = argparse.ArgumentParser("Visualize the Networks") | ||||
| parser.add_argument('--checkpoint', type=str,   help='The path to the checkpoint.') | ||||
| parser.add_argument('--save_dir',   type=str,   help='The directory to save the network plot.') | ||||
| args = parser.parse_args() | ||||
|  | ||||
|  | ||||
| def plot(genotype, filename): | ||||
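|   # Render one cell genotype as a DAG: two input nodes c_{k-2} and c_{k-1}, one node per | ||||
|   # intermediate step whose two selected ops appear as labelled edges, and an output node c_{k} | ||||
|   # that collects every intermediate node. The result is written as a PDF via graphviz. | ||||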
|   g = Digraph( | ||||
|       format='pdf', | ||||
|       edge_attr=dict(fontsize='20', fontname="times"), | ||||
|       node_attr=dict(style='filled', shape='rect', align='center', fontsize='20', height='0.5', width='0.5', penwidth='2', fontname="times"), | ||||
|       engine='dot') | ||||
|   g.body.extend(['rankdir=LR']) | ||||
|  | ||||
|   g.node("c_{k-2}", fillcolor='darkseagreen2') | ||||
|   g.node("c_{k-1}", fillcolor='darkseagreen2') | ||||
|   assert len(genotype) % 2 == 0 | ||||
|   steps = len(genotype) // 2 | ||||
|  | ||||
|   for i in range(steps): | ||||
|     g.node(str(i), fillcolor='lightblue') | ||||
|  | ||||
|   for i in range(steps): | ||||
|     for k in [2*i, 2*i + 1]: | ||||
|       op, j, weight = genotype[k] | ||||
|       if j == 0: | ||||
|         u = "c_{k-2}" | ||||
|       elif j == 1: | ||||
|         u = "c_{k-1}" | ||||
|       else: | ||||
|         u = str(j-2) | ||||
|       v = str(i) | ||||
|       g.edge(u, v, label=op, fillcolor="gray") | ||||
|  | ||||
|   g.node("c_{k}", fillcolor='palegoldenrod') | ||||
|   for i in range(steps): | ||||
|     g.edge(str(i), "c_{k}", fillcolor="gray") | ||||
|  | ||||
|   g.render(filename, view=False) | ||||
|  | ||||
|  | ||||
| if __name__ == '__main__': | ||||
|   checkpoint = args.checkpoint | ||||
|   assert os.path.isfile(checkpoint), 'Invalid path for checkpoint : {:}'.format(checkpoint) | ||||
|   checkpoint = torch.load( checkpoint, map_location='cpu' ) | ||||
|   genotypes  = checkpoint['genotypes'] | ||||
|   save_dir   = Path(args.save_dir) | ||||
|   subs       = ['normal', 'reduce'] | ||||
|   for sub in subs: | ||||
|     if not (save_dir / sub).exists(): | ||||
|       (save_dir / sub).mkdir(parents=True, exist_ok=True) | ||||
|  | ||||
|   for key, network in genotypes.items(): | ||||
|     save_path = str(save_dir / 'normal' / 'epoch-{:03d}'.format( int(key) )) | ||||
|     print('save into {:}'.format(save_path)) | ||||
|     plot(network.normal, save_path) | ||||
|  | ||||
|     save_path = str(save_dir / 'reduce' / 'epoch-{:03d}'.format( int(key) )) | ||||
|     print('save into {:}'.format(save_path)) | ||||
|     plot(network.reduce, save_path) | ||||