update NAS-Bench
This commit is contained in:
		| @@ -1,4 +1,4 @@ | ||||
| import os, sys, time, random, argparse | ||||
| import random, argparse | ||||
| from .share_args import add_shared_args | ||||
|  | ||||
| def obtain_attention_args(): | ||||
|   | ||||
| @@ -1,7 +1,7 @@ | ||||
| ################################################## | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020 # | ||||
| ################################################## | ||||
| import os, sys, time, random, argparse | ||||
| import random, argparse | ||||
| from .share_args import add_shared_args | ||||
|  | ||||
| def obtain_basic_args(): | ||||
|   | ||||
| @@ -1,4 +1,4 @@ | ||||
| import os, sys, time, random, argparse | ||||
| import random, argparse | ||||
| from .share_args import add_shared_args | ||||
|  | ||||
| def obtain_cls_init_args(): | ||||
|   | ||||
| @@ -1,4 +1,4 @@ | ||||
| import os, sys, time, random, argparse | ||||
| import random, argparse | ||||
| from .share_args import add_shared_args | ||||
|  | ||||
| def obtain_cls_kd_args(): | ||||
|   | ||||
| @@ -4,7 +4,7 @@ | ||||
| # This source code is licensed under the license found in the | ||||
| # LICENSE file in the root directory of this source tree. | ||||
| # | ||||
| import os, sys, json | ||||
| import os, json | ||||
| from os import path as osp | ||||
| from pathlib import Path | ||||
| from collections import namedtuple | ||||
|   | ||||
| @@ -39,6 +39,13 @@ def get_cell_based_tiny_net(config): | ||||
|       genotype = CellStructure.str2structure(config.arch_str) | ||||
|     else: raise ValueError('Can not find genotype from this config : {:}'.format(config)) | ||||
|     return TinyNetwork(config.C, config.N, genotype, config.num_classes) | ||||
|   elif config.name == 'infer.shape.tiny': | ||||
|     from .shape_infers import DynamicShapeTinyNet | ||||
|     if isinstance(config.channels, str): | ||||
|       channels = tuple([int(x) for x in config.channels.split(':')]) | ||||
|     else: channels = config.channels | ||||
|     genotype = CellStructure.str2structure(config.genotype) | ||||
|     return DynamicShapeTinyNet(channels, genotype, config.num_classes) | ||||
|   elif config.name == 'infer.nasnet-cifar': | ||||
|     from .cell_infers import NASNetonCIFAR | ||||
|     raise NotImplementedError | ||||
|   | ||||
| @@ -1,7 +1,6 @@ | ||||
| ##################################################### | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 # | ||||
| ##################################################### | ||||
| import torch | ||||
| import torch.nn as nn | ||||
| from ..cell_operations import ResNetBasicblock | ||||
| from .cells import InferCell | ||||
|   | ||||
| @@ -172,14 +172,19 @@ class FactorizedReduce(nn.Module): | ||||
|       for i in range(2): | ||||
|         self.convs.append( nn.Conv2d(C_in, C_outs[i], 1, stride=stride, padding=0, bias=False) ) | ||||
|       self.pad = nn.ConstantPad2d((0, 1, 0, 1), 0) | ||||
|     elif stride == 1: | ||||
|       self.conv = nn.Conv2d(C_in, C_out, 1, stride=stride, padding=0, bias=False) | ||||
|     else: | ||||
|       raise ValueError('Invalid stride : {:}'.format(stride)) | ||||
|     self.bn = nn.BatchNorm2d(C_out, affine=affine, track_running_stats=track_running_stats) | ||||
|  | ||||
|   def forward(self, x): | ||||
|     x = self.relu(x) | ||||
|     y = self.pad(x) | ||||
|     out = torch.cat([self.convs[0](x), self.convs[1](y[:,:,1:,1:])], dim=1) | ||||
|     if self.stride == 2: | ||||
|       x = self.relu(x) | ||||
|       y = self.pad(x) | ||||
|       out = torch.cat([self.convs[0](x), self.convs[1](y[:,:,1:,1:])], dim=1) | ||||
|     else: | ||||
|       out = self.conv(x) | ||||
|     out = self.bn(out) | ||||
|     return out | ||||
|  | ||||
|   | ||||
| @@ -14,11 +14,11 @@ from .search_model_darts_nasnet import NASNetworkDARTS | ||||
|  | ||||
|  | ||||
| nas201_super_nets = {'DARTS-V1': TinyNetworkDarts, | ||||
|                   'DARTS-V2': TinyNetworkDarts, | ||||
|                   'GDAS'    : TinyNetworkGDAS, | ||||
|                   'SETN'    : TinyNetworkSETN, | ||||
|                   'ENAS'    : TinyNetworkENAS, | ||||
|                   'RANDOM'  : TinyNetworkRANDOM} | ||||
|                      "DARTS-V2": TinyNetworkDarts, | ||||
|                      "GDAS": TinyNetworkGDAS, | ||||
|                      "SETN": TinyNetworkSETN, | ||||
|                      "ENAS": TinyNetworkENAS, | ||||
|                      "RANDOM": TinyNetworkRANDOM} | ||||
|  | ||||
| nasnet_super_nets = {'GDAS' : NASNetworkGDAS, | ||||
|                      'DARTS': NASNetworkDARTS} | ||||
| nasnet_super_nets = {"GDAS": NASNetworkGDAS, | ||||
|                      "DARTS": NASNetworkDARTS} | ||||
|   | ||||
| @@ -1,5 +1,5 @@ | ||||
| #################### | ||||
| # DARTS, ICLR 2019 #  | ||||
| # DARTS, ICLR 2019 # | ||||
| #################### | ||||
| import torch | ||||
| import torch.nn as nn | ||||
| @@ -11,7 +11,8 @@ from .search_cells import NASNetSearchCell as SearchCell | ||||
| # The macro structure is based on NASNet | ||||
| class NASNetworkDARTS(nn.Module): | ||||
|  | ||||
|   def __init__(self, C: int, N: int, steps: int, multiplier: int, stem_multiplier: int, num_classes: int, search_space: List[Text], affine: bool, track_running_stats: bool): | ||||
|   def __init__(self, C: int, N: int, steps: int, multiplier: int, stem_multiplier: int, | ||||
|                num_classes: int, search_space: List[Text], affine: bool, track_running_stats: bool): | ||||
|     super(NASNetworkDARTS, self).__init__() | ||||
|     self._C        = C | ||||
|     self._layerN   = N | ||||
|   | ||||
| @@ -6,14 +6,15 @@ | ||||
| import torch | ||||
| import torch.nn as nn | ||||
| from copy import deepcopy | ||||
| from typing import List, Text, Dict | ||||
| from .search_cells     import NASNetSearchCell as SearchCell | ||||
| from .genotypes        import Structure | ||||
|  | ||||
|  | ||||
| # The macro structure is based on NASNet | ||||
| class NASNetworkSETN(nn.Module): | ||||
|  | ||||
|   def __init__(self, C, N, steps, multiplier, stem_multiplier, num_classes, search_space, affine, track_running_stats): | ||||
|   def __init__(self, C: int, N: int, steps: int, multiplier: int, stem_multiplier: int, | ||||
|                num_classes: int, search_space: List[Text], affine: bool, track_running_stats: bool): | ||||
|     super(NASNetworkSETN, self).__init__() | ||||
|     self._C        = C | ||||
|     self._layerN   = N | ||||
| @@ -45,6 +46,16 @@ class NASNetworkSETN(nn.Module): | ||||
|     self.classifier = nn.Linear(C_prev, num_classes) | ||||
|     self.arch_normal_parameters = nn.Parameter( 1e-3*torch.randn(num_edge, len(search_space)) ) | ||||
|     self.arch_reduce_parameters = nn.Parameter( 1e-3*torch.randn(num_edge, len(search_space)) ) | ||||
|     self.mode = 'urs' | ||||
|     self.dynamic_cell = None | ||||
|  | ||||
|   def set_cal_mode(self, mode, dynamic_cell=None): | ||||
|     assert mode in ['urs', 'joint', 'select', 'dynamic'] | ||||
|     self.mode = mode | ||||
|     if mode == 'dynamic': | ||||
|       self.dynamic_cell = deepcopy(dynamic_cell) | ||||
|     else: | ||||
|       self.dynamic_cell = None | ||||
|  | ||||
|   def get_weights(self): | ||||
|     xlist = list( self.stem.parameters() ) + list( self.cells.parameters() ) | ||||
| @@ -70,6 +81,24 @@ class NASNetworkSETN(nn.Module): | ||||
|   def extra_repr(self): | ||||
|     return ('{name}(C={_C}, N={_layerN}, steps={_steps}, multiplier={_multiplier}, L={_Layer})'.format(name=self.__class__.__name__, **self.__dict__)) | ||||
|  | ||||
|   def dync_genotype(self, use_random=False): | ||||
|     genotypes = [] | ||||
|     with torch.no_grad(): | ||||
|       alphas_cpu = nn.functional.softmax(self.arch_parameters, dim=-1) | ||||
|     for i in range(1, self.max_nodes): | ||||
|       xlist = [] | ||||
|       for j in range(i): | ||||
|         node_str = '{:}<-{:}'.format(i, j) | ||||
|         if use_random: | ||||
|           op_name  = random.choice(self.op_names) | ||||
|         else: | ||||
|           weights  = alphas_cpu[ self.edge2index[node_str] ] | ||||
|           op_index = torch.multinomial(weights, 1).item() | ||||
|           op_name  = self.op_names[ op_index ] | ||||
|         xlist.append((op_name, j)) | ||||
|       genotypes.append( tuple(xlist) ) | ||||
|     return Structure( genotypes ) | ||||
|  | ||||
|   def genotype(self): | ||||
|     def _parse(weights): | ||||
|       gene = [] | ||||
| @@ -94,9 +123,6 @@ class NASNetworkSETN(nn.Module): | ||||
|   def forward(self, inputs): | ||||
|     normal_hardwts = nn.functional.softmax(self.arch_normal_parameters, dim=-1) | ||||
|     reduce_hardwts = nn.functional.softmax(self.arch_reduce_parameters, dim=-1) | ||||
|     with torch.no_grad(): | ||||
|       normal_hardwts_cpu = normal_hardwts.detach().cpu() | ||||
|       reduce_hardwts_cpu = reduce_hardwts.detach().cpu() | ||||
|  | ||||
|     s0 = s1 = self.stem(inputs) | ||||
|     for i, cell in enumerate(self.cells): | ||||
|   | ||||
| @@ -1,8 +1,9 @@ | ||||
| import math, torch | ||||
| ##################################################### | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 # | ||||
| ##################################################### | ||||
| import torch.nn as nn | ||||
| import torch.nn.functional as F | ||||
| from ..initialization import initialize_resnet | ||||
| from ..SharedUtils    import additive_func | ||||
|  | ||||
|  | ||||
| class ConvBNReLU(nn.Module): | ||||
|   | ||||
| @@ -1,8 +1,9 @@ | ||||
| import math, torch | ||||
| ##################################################### | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 # | ||||
| ##################################################### | ||||
| import torch.nn as nn | ||||
| import torch.nn.functional as F | ||||
| from ..initialization import initialize_resnet | ||||
| from ..SharedUtils    import additive_func | ||||
|  | ||||
|  | ||||
| class ConvBNReLU(nn.Module): | ||||
|   | ||||
| @@ -1,8 +1,9 @@ | ||||
| import math | ||||
| ##################################################### | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 # | ||||
| ##################################################### | ||||
| import torch.nn as nn | ||||
| import torch.nn.functional as F | ||||
| from ..initialization import initialize_resnet | ||||
| from ..SharedUtils    import additive_func | ||||
|  | ||||
|  | ||||
| class ConvBNReLU(nn.Module): | ||||
|   | ||||
| @@ -1,8 +1,9 @@ | ||||
| import math, torch | ||||
| ##################################################### | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 # | ||||
| ##################################################### | ||||
| import torch.nn as nn | ||||
| import torch.nn.functional as F | ||||
| from ..initialization import initialize_resnet | ||||
| from ..SharedUtils    import additive_func | ||||
|  | ||||
|  | ||||
| class ConvBNReLU(nn.Module): | ||||
|   | ||||
| @@ -1,7 +1,10 @@ | ||||
| ##################################################### | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 # | ||||
| ##################################################### | ||||
| # MobileNetV2: Inverted Residuals and Linear Bottlenecks, CVPR 2018 | ||||
| from torch import nn | ||||
| from ..initialization import initialize_resnet | ||||
| from ..SharedUtils    import additive_func, parse_channel_info | ||||
| from ..SharedUtils    import parse_channel_info | ||||
|  | ||||
|  | ||||
| class ConvBNReLU(nn.Module): | ||||
|   | ||||
							
								
								
									
										58
									
								
								lib/models/shape_infers/InferTinyCellNet.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										58
									
								
								lib/models/shape_infers/InferTinyCellNet.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,58 @@ | ||||
| ##################################################### | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 # | ||||
| ##################################################### | ||||
| from typing import List, Text, Any | ||||
| import torch.nn as nn | ||||
| from models.cell_operations import ResNetBasicblock | ||||
| from models.cell_infers.cells import InferCell | ||||
|  | ||||
|  | ||||
| class DynamicShapeTinyNet(nn.Module): | ||||
|  | ||||
|   def __init__(self, channels: List[int], genotype: Any, num_classes: int): | ||||
|     super(DynamicShapeTinyNet, self).__init__() | ||||
|     self._channels = channels | ||||
|     if len(channels) % 3 != 2: | ||||
|       raise ValueError('invalid number of layers : {:}'.format(len(channels))) | ||||
|     self._num_stage = N = len(channels) // 3 | ||||
|  | ||||
|     self.stem = nn.Sequential( | ||||
|                     nn.Conv2d(3, channels[0], kernel_size=3, padding=1, bias=False), | ||||
|                     nn.BatchNorm2d(channels[0])) | ||||
|  | ||||
|     # layer_channels   = [C    ] * N + [C*2 ] + [C*2  ] * N + [C*4 ] + [C*4  ] * N     | ||||
|     layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N | ||||
|  | ||||
|     c_prev = channels[0] | ||||
|     self.cells = nn.ModuleList() | ||||
|     for index, (c_curr, reduction) in enumerate(zip(channels, layer_reductions)): | ||||
|       if reduction : cell = ResNetBasicblock(c_prev, c_curr, 2, True) | ||||
|       else         : cell = InferCell(genotype, c_prev, c_curr, 1) | ||||
|       self.cells.append( cell ) | ||||
|       c_prev = cell.out_dim | ||||
|     self._num_layer = len(self.cells) | ||||
|  | ||||
|     self.lastact = nn.Sequential(nn.BatchNorm2d(c_prev), nn.ReLU(inplace=True)) | ||||
|     self.global_pooling = nn.AdaptiveAvgPool2d(1) | ||||
|     self.classifier = nn.Linear(c_prev, num_classes) | ||||
|  | ||||
|   def get_message(self) -> Text: | ||||
|     string = self.extra_repr() | ||||
|     for i, cell in enumerate(self.cells): | ||||
|       string += '\n {:02d}/{:02d} :: {:}'.format(i, len(self.cells), cell.extra_repr()) | ||||
|     return string | ||||
|  | ||||
|   def extra_repr(self): | ||||
|     return ('{name}(C={_channels}, N={_num_stage}, L={_num_layer})'.format(name=self.__class__.__name__, **self.__dict__)) | ||||
|  | ||||
|   def forward(self, inputs): | ||||
|     feature = self.stem(inputs) | ||||
|     for i, cell in enumerate(self.cells): | ||||
|       feature = cell(feature) | ||||
|  | ||||
|     out = self.lastact(feature) | ||||
|     out = self.global_pooling( out ) | ||||
|     out = out.view(out.size(0), -1) | ||||
|     logits = self.classifier(out) | ||||
|  | ||||
|     return out, logits | ||||
| @@ -1,5 +1,9 @@ | ||||
| ##################################################### | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 # | ||||
| ##################################################### | ||||
| from .InferCifarResNet_width import InferWidthCifarResNet | ||||
| from .InferImagenetResNet    import InferImagenetResNet | ||||
| from .InferImagenetResNet import InferImagenetResNet | ||||
| from .InferCifarResNet_depth import InferDepthCifarResNet | ||||
| from .InferCifarResNet       import InferCifarResNet | ||||
| from .InferMobileNetV2       import InferMobileNetV2 | ||||
| from .InferCifarResNet import InferCifarResNet | ||||
| from .InferMobileNetV2 import InferMobileNetV2 | ||||
| from .InferTinyCellNet import DynamicShapeTinyNet | ||||
| @@ -7,7 +7,8 @@ | ||||
| # [2020.03.08] Next version (coming soon) | ||||
| # | ||||
| # | ||||
| import os, sys, copy, random, torch, numpy as np | ||||
| import os, copy, random, torch, numpy as np | ||||
| from typing import List, Text, Union, Dict, Any | ||||
| from collections import OrderedDict, defaultdict | ||||
|  | ||||
|  | ||||
| @@ -43,7 +44,7 @@ This is the class for API of NAS-Bench-201. | ||||
| class NASBench201API(object): | ||||
|  | ||||
|   """ The initialization function that takes the dataset file path (or a dict loaded from that path) as input. """ | ||||
|   def __init__(self, file_path_or_dict, verbose=True): | ||||
|   def __init__(self, file_path_or_dict: Union[Text, Dict], verbose: bool=True): | ||||
|     if isinstance(file_path_or_dict, str): | ||||
|       if verbose: print('try to create the NAS-Bench-201 api from {:}'.format(file_path_or_dict)) | ||||
|       assert os.path.isfile(file_path_or_dict), 'invalid path : {:}'.format(file_path_or_dict) | ||||
| @@ -69,7 +70,7 @@ class NASBench201API(object): | ||||
|       assert arch not in self.archstr2index, 'This [{:}]-th arch {:} already in the dict ({:}).'.format(idx, arch, self.archstr2index[arch]) | ||||
|       self.archstr2index[ arch ] = idx | ||||
|  | ||||
|   def __getitem__(self, index): | ||||
|   def __getitem__(self, index: int): | ||||
|     return copy.deepcopy( self.meta_archs[index] ) | ||||
|  | ||||
|   def __len__(self): | ||||
| @@ -99,7 +100,7 @@ class NASBench201API(object): | ||||
|  | ||||
|   # Overwrite all information of the 'index'-th architecture in the search space. | ||||
|   # It will load its data from 'archive_root'. | ||||
|   def reload(self, archive_root, index): | ||||
|   def reload(self, archive_root: Text, index: int): | ||||
|     assert os.path.isdir(archive_root), 'invalid directory : {:}'.format(archive_root) | ||||
|     xfile_path = os.path.join(archive_root, '{:06d}-FULL.pth'.format(index)) | ||||
|     assert 0 <= index < len(self.meta_archs), 'invalid index of {:}'.format(index) | ||||
| @@ -141,7 +142,8 @@ class NASBench201API(object): | ||||
|   #  -- cifar10 : training the model on the CIFAR-10 training + validation set. | ||||
|   #  -- cifar100 : training the model on the CIFAR-100 training set. | ||||
|   #  -- ImageNet16-120 : training the model on the ImageNet16-120 training set. | ||||
|   def query_by_index(self, arch_index, dataname=None, use_12epochs_result=False): | ||||
|   def query_by_index(self, arch_index: int, dataname: Union[None, Text] = None, | ||||
|                      use_12epochs_result: bool = False): | ||||
|     if use_12epochs_result: basestr, arch2infos = '12epochs' , self.arch2infos_less | ||||
|     else                  : basestr, arch2infos = '200epochs', self.arch2infos_full | ||||
|     assert arch_index in arch2infos, 'arch_index [{:}] does not in arch2info with {:}'.format(arch_index, basestr) | ||||
| @@ -177,7 +179,7 @@ class NASBench201API(object): | ||||
|     return best_index, highest_accuracy | ||||
|  | ||||
|   # return the topology structure of the `index`-th architecture | ||||
|   def arch(self, index): | ||||
|   def arch(self, index: int): | ||||
|     assert 0 <= index < len(self.meta_archs), 'invalid index : {:} vs. {:}.'.format(index, len(self.meta_archs)) | ||||
|     return copy.deepcopy(self.meta_archs[index]) | ||||
|  | ||||
| @@ -238,7 +240,7 @@ class NASBench201API(object): | ||||
|   # `is_random` | ||||
|   #   When is_random=True, the performance of a random architecture will be returned | ||||
|   #   When is_random=False, the performanceo of all trials will be averaged. | ||||
|   def get_more_info(self, index, dataset, iepoch=None, use_12epochs_result=False, is_random=True): | ||||
|   def get_more_info(self, index: int, dataset, iepoch=None, use_12epochs_result=False, is_random=True): | ||||
|     if use_12epochs_result: basestr, arch2infos = '12epochs' , self.arch2infos_less | ||||
|     else                  : basestr, arch2infos = '200epochs', self.arch2infos_full | ||||
|     archresult = arch2infos[index] | ||||
| @@ -301,7 +303,7 @@ class NASBench201API(object): | ||||
|   If the index < 0: it will loop for all architectures and print their information one by one. | ||||
|   else: it will print the information of the 'index'-th archiitecture. | ||||
|   """ | ||||
|   def show(self, index=-1): | ||||
|   def show(self, index: int = -1) -> None: | ||||
|     if index < 0: # show all architectures | ||||
|       print(self) | ||||
|       for i, idx in enumerate(self.evaluated_indexes): | ||||
| @@ -336,8 +338,8 @@ class NASBench201API(object): | ||||
|   #   for i, node in enumerate(arch): | ||||
|   #     print('the {:}-th node is the sum of these {:} nodes with op: {:}'.format(i+1, len(node), node)) | ||||
|   @staticmethod | ||||
|   def str2lists(xstr): | ||||
|     assert isinstance(xstr, str), 'must take string (not {:}) as input'.format(type(xstr)) | ||||
|   def str2lists(xstr: Text) -> List[Any]: | ||||
|     # assert isinstance(xstr, str), 'must take string (not {:}) as input'.format(type(xstr)) | ||||
|     nodestrs = xstr.split('+') | ||||
|     genotypes = [] | ||||
|     for i, node_str in enumerate(nodestrs): | ||||
|   | ||||
| @@ -3,6 +3,8 @@ | ||||
| ################################################## | ||||
| from .starts     import prepare_seed, prepare_logger, get_machine_info, save_checkpoint, copy_checkpoint | ||||
| from .optimizers import get_optim_scheduler | ||||
| from .funcs_nasbench import evaluate_for_seed as bench_evaluate_for_seed | ||||
| from .funcs_nasbench import pure_evaluate as bench_pure_evaluate | ||||
|  | ||||
| def get_procedures(procedure): | ||||
|   from .basic_main     import basic_train, basic_valid | ||||
|   | ||||
							
								
								
									
										129
									
								
								lib/procedures/funcs_nasbench.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										129
									
								
								lib/procedures/funcs_nasbench.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,129 @@ | ||||
| ##################################################### | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.08 # | ||||
| ##################################################### | ||||
| import time, torch | ||||
| from procedures   import prepare_seed, get_optim_scheduler | ||||
| from utils        import get_model_infos, obtain_accuracy | ||||
| from log_utils    import AverageMeter, time_string, convert_secs2time | ||||
| from models       import get_cell_based_tiny_net | ||||
|  | ||||
|  | ||||
| __all__ = ['evaluate_for_seed', 'pure_evaluate'] | ||||
|  | ||||
|  | ||||
| def pure_evaluate(xloader, network, criterion=torch.nn.CrossEntropyLoss()): | ||||
|   data_time, batch_time, batch = AverageMeter(), AverageMeter(), None | ||||
|   losses, top1, top5 = AverageMeter(), AverageMeter(), AverageMeter() | ||||
|   latencies = [] | ||||
|   network.eval() | ||||
|   with torch.no_grad(): | ||||
|     end = time.time() | ||||
|     for i, (inputs, targets) in enumerate(xloader): | ||||
|       targets = targets.cuda(non_blocking=True) | ||||
|       inputs  = inputs.cuda(non_blocking=True) | ||||
|       data_time.update(time.time() - end) | ||||
|       # forward | ||||
|       features, logits = network(inputs) | ||||
|       loss             = criterion(logits, targets) | ||||
|       batch_time.update(time.time() - end) | ||||
|       if batch is None or batch == inputs.size(0): | ||||
|         batch = inputs.size(0) | ||||
|         latencies.append( batch_time.val - data_time.val ) | ||||
|       # record loss and accuracy | ||||
|       prec1, prec5 = obtain_accuracy(logits.data, targets.data, topk=(1, 5)) | ||||
|       losses.update(loss.item(),  inputs.size(0)) | ||||
|       top1.update  (prec1.item(), inputs.size(0)) | ||||
|       top5.update  (prec5.item(), inputs.size(0)) | ||||
|       end = time.time() | ||||
|   if len(latencies) > 2: latencies = latencies[1:] | ||||
|   return losses.avg, top1.avg, top5.avg, latencies | ||||
|  | ||||
|  | ||||
|  | ||||
| def procedure(xloader, network, criterion, scheduler, optimizer, mode: str): | ||||
|   losses, top1, top5 = AverageMeter(), AverageMeter(), AverageMeter() | ||||
|   if mode == 'train'  : network.train() | ||||
|   elif mode == 'valid': network.eval() | ||||
|   else: raise ValueError("The mode is not right : {:}".format(mode)) | ||||
|  | ||||
|   data_time, batch_time, end = AverageMeter(), AverageMeter(), time.time() | ||||
|   for i, (inputs, targets) in enumerate(xloader): | ||||
|     if mode == 'train': scheduler.update(None, 1.0 * i / len(xloader)) | ||||
|  | ||||
|     targets = targets.cuda(non_blocking=True) | ||||
|     if mode == 'train': optimizer.zero_grad() | ||||
|     # forward | ||||
|     features, logits = network(inputs) | ||||
|     loss             = criterion(logits, targets) | ||||
|     # backward | ||||
|     if mode == 'train': | ||||
|       loss.backward() | ||||
|       optimizer.step() | ||||
|     # record loss and accuracy | ||||
|     prec1, prec5 = obtain_accuracy(logits.data, targets.data, topk=(1, 5)) | ||||
|     losses.update(loss.item(),  inputs.size(0)) | ||||
|     top1.update  (prec1.item(), inputs.size(0)) | ||||
|     top5.update  (prec5.item(), inputs.size(0)) | ||||
|     # count time | ||||
|     batch_time.update(time.time() - end) | ||||
|     end = time.time() | ||||
|   return losses.avg, top1.avg, top5.avg, batch_time.sum | ||||
|  | ||||
|  | ||||
| def evaluate_for_seed(arch_config, opt_config, train_loader, valid_loaders, seed: int, logger): | ||||
|  | ||||
|   prepare_seed(seed) # random seed | ||||
|   net = get_cell_based_tiny_net(arch_config) | ||||
|   #net = TinyNetwork(arch_config['channel'], arch_config['num_cells'], arch, config.class_num) | ||||
|   flop, param  = get_model_infos(net, opt_config.xshape) | ||||
|   logger.log('Network : {:}'.format(net.get_message()), False) | ||||
|   logger.log('{:} Seed-------------------------- {:} --------------------------'.format(time_string(), seed)) | ||||
|   logger.log('FLOP = {:} MB, Param = {:} MB'.format(flop, param)) | ||||
|   # train and valid | ||||
|   optimizer, scheduler, criterion = get_optim_scheduler(net.parameters(), opt_config) | ||||
|   network, criterion = torch.nn.DataParallel(net).cuda(), criterion.cuda() | ||||
|   # start training | ||||
|   start_time, epoch_time, total_epoch = time.time(), AverageMeter(), opt_config.epochs + opt_config.warmup | ||||
|   train_losses, train_acc1es, train_acc5es, valid_losses, valid_acc1es, valid_acc5es = {}, {}, {}, {}, {}, {} | ||||
|   train_times , valid_times, lrs = {}, {}, {} | ||||
|   for epoch in range(total_epoch): | ||||
|     scheduler.update(epoch, 0.0) | ||||
|     lr = min(scheduler.get_lr()) | ||||
|     train_loss, train_acc1, train_acc5, train_tm = procedure(train_loader, network, criterion, scheduler, optimizer, 'train') | ||||
|     train_losses[epoch] = train_loss | ||||
|     train_acc1es[epoch] = train_acc1  | ||||
|     train_acc5es[epoch] = train_acc5 | ||||
|     train_times [epoch] = train_tm | ||||
|     lrs[epoch] = lr | ||||
|     with torch.no_grad(): | ||||
|       for key, xloder in valid_loaders.items(): | ||||
|         valid_loss, valid_acc1, valid_acc5, valid_tm = procedure(xloder  , network, criterion,      None,      None, 'valid') | ||||
|         valid_losses['{:}@{:}'.format(key,epoch)] = valid_loss | ||||
|         valid_acc1es['{:}@{:}'.format(key,epoch)] = valid_acc1  | ||||
|         valid_acc5es['{:}@{:}'.format(key,epoch)] = valid_acc5 | ||||
|         valid_times ['{:}@{:}'.format(key,epoch)] = valid_tm | ||||
|  | ||||
|     # measure elapsed time | ||||
|     epoch_time.update(time.time() - start_time) | ||||
|     start_time = time.time() | ||||
|     need_time = 'Time Left: {:}'.format( convert_secs2time(epoch_time.avg * (total_epoch-epoch-1), True) ) | ||||
|     logger.log('{:} {:} epoch={:03d}/{:03d} :: Train [loss={:.5f}, acc@1={:.2f}%, acc@5={:.2f}%] Valid [loss={:.5f}, acc@1={:.2f}%, acc@5={:.2f}%], lr={:}'.format(time_string(), need_time, epoch, total_epoch, train_loss, train_acc1, train_acc5, valid_loss, valid_acc1, valid_acc5, lr)) | ||||
|   info_seed = {'flop' : flop, | ||||
|                'param': param, | ||||
|                'arch_config' : arch_config._asdict(), | ||||
|                'opt_config'  : opt_config._asdict(), | ||||
|                'total_epoch' : total_epoch , | ||||
|                'train_losses': train_losses, | ||||
|                'train_acc1es': train_acc1es, | ||||
|                'train_acc5es': train_acc5es, | ||||
|                'train_times' : train_times, | ||||
|                'valid_losses': valid_losses, | ||||
|                'valid_acc1es': valid_acc1es, | ||||
|                'valid_acc5es': valid_acc5es, | ||||
|                'valid_times' : valid_times, | ||||
|                'learning_rates': lrs, | ||||
|                'net_state_dict': net.state_dict(), | ||||
|                'net_string'  : '{:}'.format(net), | ||||
|                'finish-train': True | ||||
|               } | ||||
|   return info_seed | ||||
		Reference in New Issue
	
	Block a user