Prototype MAML
This commit is contained in:
		| @@ -8,98 +8,110 @@ from .initialization import initialize_resnet | ||||
|  | ||||
|  | ||||
| class Bottleneck(nn.Module): | ||||
|   def __init__(self, nChannels, growthRate): | ||||
|     super(Bottleneck, self).__init__() | ||||
|     interChannels = 4*growthRate | ||||
|     self.bn1 = nn.BatchNorm2d(nChannels) | ||||
|     self.conv1 = nn.Conv2d(nChannels, interChannels, kernel_size=1, bias=False) | ||||
|     self.bn2 = nn.BatchNorm2d(interChannels) | ||||
|     self.conv2 = nn.Conv2d(interChannels, growthRate, kernel_size=3, padding=1, bias=False) | ||||
|     def __init__(self, nChannels, growthRate): | ||||
|         super(Bottleneck, self).__init__() | ||||
|         interChannels = 4 * growthRate | ||||
|         self.bn1 = nn.BatchNorm2d(nChannels) | ||||
|         self.conv1 = nn.Conv2d(nChannels, interChannels, kernel_size=1, bias=False) | ||||
|         self.bn2 = nn.BatchNorm2d(interChannels) | ||||
|         self.conv2 = nn.Conv2d( | ||||
|             interChannels, growthRate, kernel_size=3, padding=1, bias=False | ||||
|         ) | ||||
|  | ||||
|   def forward(self, x): | ||||
|     out = self.conv1(F.relu(self.bn1(x))) | ||||
|     out = self.conv2(F.relu(self.bn2(out))) | ||||
|     out = torch.cat((x, out), 1) | ||||
|     return out | ||||
|     def forward(self, x): | ||||
|         out = self.conv1(F.relu(self.bn1(x))) | ||||
|         out = self.conv2(F.relu(self.bn2(out))) | ||||
|         out = torch.cat((x, out), 1) | ||||
|         return out | ||||
|  | ||||
|  | ||||
| class SingleLayer(nn.Module): | ||||
|   def __init__(self, nChannels, growthRate): | ||||
|     super(SingleLayer, self).__init__() | ||||
|     self.bn1 = nn.BatchNorm2d(nChannels) | ||||
|     self.conv1 = nn.Conv2d(nChannels, growthRate, kernel_size=3, padding=1, bias=False) | ||||
|     def __init__(self, nChannels, growthRate): | ||||
|         super(SingleLayer, self).__init__() | ||||
|         self.bn1 = nn.BatchNorm2d(nChannels) | ||||
|         self.conv1 = nn.Conv2d( | ||||
|             nChannels, growthRate, kernel_size=3, padding=1, bias=False | ||||
|         ) | ||||
|  | ||||
|   def forward(self, x): | ||||
|     out = self.conv1(F.relu(self.bn1(x))) | ||||
|     out = torch.cat((x, out), 1) | ||||
|     return out | ||||
|     def forward(self, x): | ||||
|         out = self.conv1(F.relu(self.bn1(x))) | ||||
|         out = torch.cat((x, out), 1) | ||||
|         return out | ||||
|  | ||||
|  | ||||
| class Transition(nn.Module): | ||||
|   def __init__(self, nChannels, nOutChannels): | ||||
|     super(Transition, self).__init__() | ||||
|     self.bn1 = nn.BatchNorm2d(nChannels) | ||||
|     self.conv1 = nn.Conv2d(nChannels, nOutChannels, kernel_size=1, bias=False) | ||||
|     def __init__(self, nChannels, nOutChannels): | ||||
|         super(Transition, self).__init__() | ||||
|         self.bn1 = nn.BatchNorm2d(nChannels) | ||||
|         self.conv1 = nn.Conv2d(nChannels, nOutChannels, kernel_size=1, bias=False) | ||||
|  | ||||
|   def forward(self, x): | ||||
|     out = self.conv1(F.relu(self.bn1(x))) | ||||
|     out = F.avg_pool2d(out, 2) | ||||
|     return out | ||||
|     def forward(self, x): | ||||
|         out = self.conv1(F.relu(self.bn1(x))) | ||||
|         out = F.avg_pool2d(out, 2) | ||||
|         return out | ||||
|  | ||||
|  | ||||
| class DenseNet(nn.Module): | ||||
|   def __init__(self, growthRate, depth, reduction, nClasses, bottleneck): | ||||
|     super(DenseNet, self).__init__() | ||||
|     def __init__(self, growthRate, depth, reduction, nClasses, bottleneck): | ||||
|         super(DenseNet, self).__init__() | ||||
|  | ||||
|     if bottleneck:  nDenseBlocks = int( (depth-4) / 6 ) | ||||
|     else         :  nDenseBlocks = int( (depth-4) / 3 ) | ||||
|         if bottleneck: | ||||
|             nDenseBlocks = int((depth - 4) / 6) | ||||
|         else: | ||||
|             nDenseBlocks = int((depth - 4) / 3) | ||||
|  | ||||
|     self.message = 'CifarDenseNet : block : {:}, depth : {:}, reduction : {:}, growth-rate = {:}, class = {:}'.format('bottleneck' if bottleneck else 'basic', depth, reduction, growthRate, nClasses) | ||||
|         self.message = "CifarDenseNet : block : {:}, depth : {:}, reduction : {:}, growth-rate = {:}, class = {:}".format( | ||||
|             "bottleneck" if bottleneck else "basic", | ||||
|             depth, | ||||
|             reduction, | ||||
|             growthRate, | ||||
|             nClasses, | ||||
|         ) | ||||
|  | ||||
|     nChannels = 2*growthRate | ||||
|     self.conv1 = nn.Conv2d(3, nChannels, kernel_size=3, padding=1, bias=False) | ||||
|         nChannels = 2 * growthRate | ||||
|         self.conv1 = nn.Conv2d(3, nChannels, kernel_size=3, padding=1, bias=False) | ||||
|  | ||||
|     self.dense1 = self._make_dense(nChannels, growthRate, nDenseBlocks, bottleneck) | ||||
|     nChannels += nDenseBlocks*growthRate | ||||
|     nOutChannels = int(math.floor(nChannels*reduction)) | ||||
|     self.trans1 = Transition(nChannels, nOutChannels) | ||||
|         self.dense1 = self._make_dense(nChannels, growthRate, nDenseBlocks, bottleneck) | ||||
|         nChannels += nDenseBlocks * growthRate | ||||
|         nOutChannels = int(math.floor(nChannels * reduction)) | ||||
|         self.trans1 = Transition(nChannels, nOutChannels) | ||||
|  | ||||
|     nChannels = nOutChannels | ||||
|     self.dense2 = self._make_dense(nChannels, growthRate, nDenseBlocks, bottleneck) | ||||
|     nChannels += nDenseBlocks*growthRate | ||||
|     nOutChannels = int(math.floor(nChannels*reduction)) | ||||
|     self.trans2 = Transition(nChannels, nOutChannels) | ||||
|         nChannels = nOutChannels | ||||
|         self.dense2 = self._make_dense(nChannels, growthRate, nDenseBlocks, bottleneck) | ||||
|         nChannels += nDenseBlocks * growthRate | ||||
|         nOutChannels = int(math.floor(nChannels * reduction)) | ||||
|         self.trans2 = Transition(nChannels, nOutChannels) | ||||
|  | ||||
|     nChannels = nOutChannels | ||||
|     self.dense3 = self._make_dense(nChannels, growthRate, nDenseBlocks, bottleneck) | ||||
|     nChannels += nDenseBlocks*growthRate | ||||
|         nChannels = nOutChannels | ||||
|         self.dense3 = self._make_dense(nChannels, growthRate, nDenseBlocks, bottleneck) | ||||
|         nChannels += nDenseBlocks * growthRate | ||||
|  | ||||
|     self.act = nn.Sequential( | ||||
|                   nn.BatchNorm2d(nChannels), nn.ReLU(inplace=True), | ||||
|                   nn.AvgPool2d(8)) | ||||
|     self.fc  = nn.Linear(nChannels, nClasses) | ||||
|         self.act = nn.Sequential( | ||||
|             nn.BatchNorm2d(nChannels), nn.ReLU(inplace=True), nn.AvgPool2d(8) | ||||
|         ) | ||||
|         self.fc = nn.Linear(nChannels, nClasses) | ||||
|  | ||||
|     self.apply(initialize_resnet) | ||||
|         self.apply(initialize_resnet) | ||||
|  | ||||
|   def get_message(self): | ||||
|     return self.message | ||||
|     def get_message(self): | ||||
|         return self.message | ||||
|  | ||||
|   def _make_dense(self, nChannels, growthRate, nDenseBlocks, bottleneck): | ||||
|     layers = [] | ||||
|     for i in range(int(nDenseBlocks)): | ||||
|       if bottleneck: | ||||
|         layers.append(Bottleneck(nChannels, growthRate)) | ||||
|       else: | ||||
|         layers.append(SingleLayer(nChannels, growthRate)) | ||||
|       nChannels += growthRate | ||||
|     return nn.Sequential(*layers) | ||||
|     def _make_dense(self, nChannels, growthRate, nDenseBlocks, bottleneck): | ||||
|         layers = [] | ||||
|         for i in range(int(nDenseBlocks)): | ||||
|             if bottleneck: | ||||
|                 layers.append(Bottleneck(nChannels, growthRate)) | ||||
|             else: | ||||
|                 layers.append(SingleLayer(nChannels, growthRate)) | ||||
|             nChannels += growthRate | ||||
|         return nn.Sequential(*layers) | ||||
|  | ||||
|   def forward(self, inputs): | ||||
|     out = self.conv1( inputs ) | ||||
|     out = self.trans1(self.dense1(out)) | ||||
|     out = self.trans2(self.dense2(out)) | ||||
|     out = self.dense3(out) | ||||
|     features = self.act(out) | ||||
|     features = features.view(features.size(0), -1) | ||||
|     out = self.fc(features) | ||||
|     return features, out | ||||
|     def forward(self, inputs): | ||||
|         out = self.conv1(inputs) | ||||
|         out = self.trans1(self.dense1(out)) | ||||
|         out = self.trans2(self.dense2(out)) | ||||
|         out = self.dense3(out) | ||||
|         features = self.act(out) | ||||
|         features = features.view(features.size(0), -1) | ||||
|         out = self.fc(features) | ||||
|         return features, out | ||||
|   | ||||
| @@ -2,156 +2,179 @@ import torch | ||||
| import torch.nn as nn | ||||
| import torch.nn.functional as F | ||||
| from .initialization import initialize_resnet | ||||
| from .SharedUtils    import additive_func | ||||
| from .SharedUtils import additive_func | ||||
|  | ||||
|  | ||||
| class Downsample(nn.Module):   | ||||
| class Downsample(nn.Module): | ||||
|     def __init__(self, nIn, nOut, stride): | ||||
|         super(Downsample, self).__init__() | ||||
|         assert stride == 2 and nOut == 2 * nIn, "stride:{} IO:{},{}".format( | ||||
|             stride, nIn, nOut | ||||
|         ) | ||||
|         self.in_dim = nIn | ||||
|         self.out_dim = nOut | ||||
|         self.avg = nn.AvgPool2d(kernel_size=2, stride=2, padding=0) | ||||
|         self.conv = nn.Conv2d(nIn, nOut, kernel_size=1, stride=1, padding=0, bias=False) | ||||
|  | ||||
|   def __init__(self, nIn, nOut, stride): | ||||
|     super(Downsample, self).__init__()  | ||||
|     assert stride == 2 and nOut == 2*nIn, 'stride:{} IO:{},{}'.format(stride, nIn, nOut) | ||||
|     self.in_dim  = nIn | ||||
|     self.out_dim = nOut | ||||
|     self.avg  = nn.AvgPool2d(kernel_size=2, stride=2, padding=0)    | ||||
|     self.conv = nn.Conv2d(nIn, nOut, kernel_size=1, stride=1, padding=0, bias=False) | ||||
|  | ||||
|   def forward(self, x): | ||||
|     x   = self.avg(x) | ||||
|     out = self.conv(x) | ||||
|     return out | ||||
|     def forward(self, x): | ||||
|         x = self.avg(x) | ||||
|         out = self.conv(x) | ||||
|         return out | ||||
|  | ||||
|  | ||||
| class ConvBNReLU(nn.Module): | ||||
|    | ||||
|   def __init__(self, nIn, nOut, kernel, stride, padding, bias, relu): | ||||
|     super(ConvBNReLU, self).__init__() | ||||
|     self.conv = nn.Conv2d(nIn, nOut, kernel_size=kernel, stride=stride, padding=padding, bias=bias) | ||||
|     self.bn   = nn.BatchNorm2d(nOut) | ||||
|     if relu: self.relu = nn.ReLU(inplace=True) | ||||
|     else   : self.relu = None | ||||
|     self.out_dim = nOut | ||||
|     self.num_conv = 1 | ||||
|     def __init__(self, nIn, nOut, kernel, stride, padding, bias, relu): | ||||
|         super(ConvBNReLU, self).__init__() | ||||
|         self.conv = nn.Conv2d( | ||||
|             nIn, nOut, kernel_size=kernel, stride=stride, padding=padding, bias=bias | ||||
|         ) | ||||
|         self.bn = nn.BatchNorm2d(nOut) | ||||
|         if relu: | ||||
|             self.relu = nn.ReLU(inplace=True) | ||||
|         else: | ||||
|             self.relu = None | ||||
|         self.out_dim = nOut | ||||
|         self.num_conv = 1 | ||||
|  | ||||
|   def forward(self, x): | ||||
|     conv = self.conv( x ) | ||||
|     bn   = self.bn( conv ) | ||||
|     if self.relu: return self.relu( bn ) | ||||
|     else        : return bn | ||||
|     def forward(self, x): | ||||
|         conv = self.conv(x) | ||||
|         bn = self.bn(conv) | ||||
|         if self.relu: | ||||
|             return self.relu(bn) | ||||
|         else: | ||||
|             return bn | ||||
|  | ||||
|  | ||||
| class ResNetBasicblock(nn.Module): | ||||
|   expansion = 1 | ||||
|   def __init__(self, inplanes, planes, stride): | ||||
|     super(ResNetBasicblock, self).__init__() | ||||
|     assert stride == 1 or stride == 2, 'invalid stride {:}'.format(stride) | ||||
|     self.conv_a = ConvBNReLU(inplanes, planes, 3, stride, 1, False, True) | ||||
|     self.conv_b = ConvBNReLU(  planes, planes, 3,      1, 1, False, False) | ||||
|     if stride == 2: | ||||
|       self.downsample = Downsample(inplanes, planes, stride) | ||||
|     elif inplanes != planes: | ||||
|       self.downsample = ConvBNReLU(inplanes, planes, 1, 1, 0, False, False) | ||||
|     else: | ||||
|       self.downsample = None | ||||
|     self.out_dim = planes | ||||
|     self.num_conv = 2 | ||||
|     expansion = 1 | ||||
|  | ||||
|   def forward(self, inputs): | ||||
|     def __init__(self, inplanes, planes, stride): | ||||
|         super(ResNetBasicblock, self).__init__() | ||||
|         assert stride == 1 or stride == 2, "invalid stride {:}".format(stride) | ||||
|         self.conv_a = ConvBNReLU(inplanes, planes, 3, stride, 1, False, True) | ||||
|         self.conv_b = ConvBNReLU(planes, planes, 3, 1, 1, False, False) | ||||
|         if stride == 2: | ||||
|             self.downsample = Downsample(inplanes, planes, stride) | ||||
|         elif inplanes != planes: | ||||
|             self.downsample = ConvBNReLU(inplanes, planes, 1, 1, 0, False, False) | ||||
|         else: | ||||
|             self.downsample = None | ||||
|         self.out_dim = planes | ||||
|         self.num_conv = 2 | ||||
|  | ||||
|     basicblock = self.conv_a(inputs) | ||||
|     basicblock = self.conv_b(basicblock) | ||||
|     def forward(self, inputs): | ||||
|  | ||||
|     if self.downsample is not None: | ||||
|       residual = self.downsample(inputs) | ||||
|     else: | ||||
|       residual = inputs | ||||
|     out = additive_func(residual, basicblock) | ||||
|     return F.relu(out, inplace=True) | ||||
|         basicblock = self.conv_a(inputs) | ||||
|         basicblock = self.conv_b(basicblock) | ||||
|  | ||||
|         if self.downsample is not None: | ||||
|             residual = self.downsample(inputs) | ||||
|         else: | ||||
|             residual = inputs | ||||
|         out = additive_func(residual, basicblock) | ||||
|         return F.relu(out, inplace=True) | ||||
|  | ||||
|  | ||||
| class ResNetBottleneck(nn.Module): | ||||
|   expansion = 4 | ||||
|   def __init__(self, inplanes, planes, stride): | ||||
|     super(ResNetBottleneck, self).__init__() | ||||
|     assert stride == 1 or stride == 2, 'invalid stride {:}'.format(stride) | ||||
|     self.conv_1x1 = ConvBNReLU(inplanes, planes, 1,      1, 0, False, True) | ||||
|     self.conv_3x3 = ConvBNReLU(  planes, planes, 3, stride, 1, False, True) | ||||
|     self.conv_1x4 = ConvBNReLU(planes, planes*self.expansion, 1, 1, 0, False, False) | ||||
|     if stride == 2: | ||||
|       self.downsample = Downsample(inplanes, planes*self.expansion, stride) | ||||
|     elif inplanes != planes*self.expansion: | ||||
|       self.downsample = ConvBNReLU(inplanes, planes*self.expansion, 1, 1, 0, False, False) | ||||
|     else: | ||||
|       self.downsample = None | ||||
|     self.out_dim = planes * self.expansion | ||||
|     self.num_conv = 3 | ||||
|     expansion = 4 | ||||
|  | ||||
|   def forward(self, inputs): | ||||
|     def __init__(self, inplanes, planes, stride): | ||||
|         super(ResNetBottleneck, self).__init__() | ||||
|         assert stride == 1 or stride == 2, "invalid stride {:}".format(stride) | ||||
|         self.conv_1x1 = ConvBNReLU(inplanes, planes, 1, 1, 0, False, True) | ||||
|         self.conv_3x3 = ConvBNReLU(planes, planes, 3, stride, 1, False, True) | ||||
|         self.conv_1x4 = ConvBNReLU( | ||||
|             planes, planes * self.expansion, 1, 1, 0, False, False | ||||
|         ) | ||||
|         if stride == 2: | ||||
|             self.downsample = Downsample(inplanes, planes * self.expansion, stride) | ||||
|         elif inplanes != planes * self.expansion: | ||||
|             self.downsample = ConvBNReLU( | ||||
|                 inplanes, planes * self.expansion, 1, 1, 0, False, False | ||||
|             ) | ||||
|         else: | ||||
|             self.downsample = None | ||||
|         self.out_dim = planes * self.expansion | ||||
|         self.num_conv = 3 | ||||
|  | ||||
|     bottleneck = self.conv_1x1(inputs) | ||||
|     bottleneck = self.conv_3x3(bottleneck) | ||||
|     bottleneck = self.conv_1x4(bottleneck) | ||||
|     def forward(self, inputs): | ||||
|  | ||||
|     if self.downsample is not None: | ||||
|       residual = self.downsample(inputs) | ||||
|     else: | ||||
|       residual = inputs | ||||
|     out = additive_func(residual, bottleneck) | ||||
|     return F.relu(out, inplace=True) | ||||
|         bottleneck = self.conv_1x1(inputs) | ||||
|         bottleneck = self.conv_3x3(bottleneck) | ||||
|         bottleneck = self.conv_1x4(bottleneck) | ||||
|  | ||||
|         if self.downsample is not None: | ||||
|             residual = self.downsample(inputs) | ||||
|         else: | ||||
|             residual = inputs | ||||
|         out = additive_func(residual, bottleneck) | ||||
|         return F.relu(out, inplace=True) | ||||
|  | ||||
|  | ||||
| class CifarResNet(nn.Module): | ||||
|     def __init__(self, block_name, depth, num_classes, zero_init_residual): | ||||
|         super(CifarResNet, self).__init__() | ||||
|  | ||||
|   def __init__(self, block_name, depth, num_classes, zero_init_residual): | ||||
|     super(CifarResNet, self).__init__() | ||||
|         # Model type specifies number of layers for CIFAR-10 and CIFAR-100 model | ||||
|         if block_name == "ResNetBasicblock": | ||||
|             block = ResNetBasicblock | ||||
|             assert (depth - 2) % 6 == 0, "depth should be one of 20, 32, 44, 56, 110" | ||||
|             layer_blocks = (depth - 2) // 6 | ||||
|         elif block_name == "ResNetBottleneck": | ||||
|             block = ResNetBottleneck | ||||
|             assert (depth - 2) % 9 == 0, "depth should be one of 164" | ||||
|             layer_blocks = (depth - 2) // 9 | ||||
|         else: | ||||
|             raise ValueError("invalid block : {:}".format(block_name)) | ||||
|  | ||||
|     #Model type specifies number of layers for CIFAR-10 and CIFAR-100 model | ||||
|     if block_name == 'ResNetBasicblock': | ||||
|       block = ResNetBasicblock | ||||
|       assert (depth - 2) % 6 == 0, 'depth should be one of 20, 32, 44, 56, 110' | ||||
|       layer_blocks = (depth - 2) // 6 | ||||
|     elif block_name == 'ResNetBottleneck': | ||||
|       block = ResNetBottleneck | ||||
|       assert (depth - 2) % 9 == 0, 'depth should be one of 164' | ||||
|       layer_blocks = (depth - 2) // 9 | ||||
|     else: | ||||
|       raise ValueError('invalid block : {:}'.format(block_name)) | ||||
|         self.message = "CifarResNet : Block : {:}, Depth : {:}, Layers for each block : {:}".format( | ||||
|             block_name, depth, layer_blocks | ||||
|         ) | ||||
|         self.num_classes = num_classes | ||||
|         self.channels = [16] | ||||
|         self.layers = nn.ModuleList([ConvBNReLU(3, 16, 3, 1, 1, False, True)]) | ||||
|         for stage in range(3): | ||||
|             for iL in range(layer_blocks): | ||||
|                 iC = self.channels[-1] | ||||
|                 planes = 16 * (2 ** stage) | ||||
|                 stride = 2 if stage > 0 and iL == 0 else 1 | ||||
|                 module = block(iC, planes, stride) | ||||
|                 self.channels.append(module.out_dim) | ||||
|                 self.layers.append(module) | ||||
|                 self.message += "\nstage={:}, ilayer={:02d}/{:02d}, block={:03d}, iC={:3d}, oC={:3d}, stride={:}".format( | ||||
|                     stage, | ||||
|                     iL, | ||||
|                     layer_blocks, | ||||
|                     len(self.layers) - 1, | ||||
|                     iC, | ||||
|                     module.out_dim, | ||||
|                     stride, | ||||
|                 ) | ||||
|  | ||||
|     self.message     = 'CifarResNet : Block : {:}, Depth : {:}, Layers for each block : {:}'.format(block_name, depth, layer_blocks) | ||||
|     self.num_classes = num_classes | ||||
|     self.channels    = [16] | ||||
|     self.layers      = nn.ModuleList( [ ConvBNReLU(3, 16, 3, 1, 1, False, True) ] ) | ||||
|     for stage in range(3): | ||||
|       for iL in range(layer_blocks): | ||||
|         iC     = self.channels[-1] | ||||
|         planes = 16 * (2**stage) | ||||
|         stride = 2 if stage > 0 and iL == 0 else 1 | ||||
|         module = block(iC, planes, stride) | ||||
|         self.channels.append( module.out_dim ) | ||||
|         self.layers.append  ( module ) | ||||
|         self.message += "\nstage={:}, ilayer={:02d}/{:02d}, block={:03d}, iC={:3d}, oC={:3d}, stride={:}".format(stage, iL, layer_blocks, len(self.layers)-1, iC, module.out_dim, stride) | ||||
|         self.avgpool = nn.AvgPool2d(8) | ||||
|         self.classifier = nn.Linear(module.out_dim, num_classes) | ||||
|         assert ( | ||||
|             sum(x.num_conv for x in self.layers) + 1 == depth | ||||
|         ), "invalid depth check {:} vs {:}".format( | ||||
|             sum(x.num_conv for x in self.layers) + 1, depth | ||||
|         ) | ||||
|  | ||||
|     self.avgpool = nn.AvgPool2d(8) | ||||
|     self.classifier = nn.Linear(module.out_dim, num_classes) | ||||
|     assert sum(x.num_conv for x in self.layers) + 1 == depth, 'invalid depth check {:} vs {:}'.format(sum(x.num_conv for x in self.layers)+1, depth) | ||||
|         self.apply(initialize_resnet) | ||||
|         if zero_init_residual: | ||||
|             for m in self.modules(): | ||||
|                 if isinstance(m, ResNetBasicblock): | ||||
|                     nn.init.constant_(m.conv_b.bn.weight, 0) | ||||
|                 elif isinstance(m, ResNetBottleneck): | ||||
|                     nn.init.constant_(m.conv_1x4.bn.weight, 0) | ||||
|  | ||||
|     self.apply(initialize_resnet) | ||||
|     if zero_init_residual: | ||||
|       for m in self.modules(): | ||||
|         if isinstance(m, ResNetBasicblock): | ||||
|           nn.init.constant_(m.conv_b.bn.weight, 0) | ||||
|         elif isinstance(m, ResNetBottleneck): | ||||
|           nn.init.constant_(m.conv_1x4.bn.weight, 0) | ||||
|     def get_message(self): | ||||
|         return self.message | ||||
|  | ||||
|   def get_message(self): | ||||
|     return self.message | ||||
|  | ||||
|   def forward(self, inputs): | ||||
|     x = inputs | ||||
|     for i, layer in enumerate(self.layers): | ||||
|       x = layer( x ) | ||||
|     features = self.avgpool(x) | ||||
|     features = features.view(features.size(0), -1) | ||||
|     logits   = self.classifier(features) | ||||
|     return features, logits | ||||
|     def forward(self, inputs): | ||||
|         x = inputs | ||||
|         for i, layer in enumerate(self.layers): | ||||
|             x = layer(x) | ||||
|         features = self.avgpool(x) | ||||
|         features = features.view(features.size(0), -1) | ||||
|         logits = self.classifier(features) | ||||
|         return features, logits | ||||
|   | ||||
| @@ -5,90 +5,111 @@ from .initialization import initialize_resnet | ||||
|  | ||||
|  | ||||
| class WideBasicblock(nn.Module): | ||||
|   def __init__(self, inplanes, planes, stride, dropout=False): | ||||
|     super(WideBasicblock, self).__init__() | ||||
|     def __init__(self, inplanes, planes, stride, dropout=False): | ||||
|         super(WideBasicblock, self).__init__() | ||||
|  | ||||
|     self.bn_a = nn.BatchNorm2d(inplanes) | ||||
|     self.conv_a = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride, padding=1, bias=False) | ||||
|         self.bn_a = nn.BatchNorm2d(inplanes) | ||||
|         self.conv_a = nn.Conv2d( | ||||
|             inplanes, planes, kernel_size=3, stride=stride, padding=1, bias=False | ||||
|         ) | ||||
|  | ||||
|     self.bn_b = nn.BatchNorm2d(planes) | ||||
|     if dropout: | ||||
|       self.dropout = nn.Dropout2d(p=0.5, inplace=True) | ||||
|     else: | ||||
|       self.dropout = None | ||||
|     self.conv_b = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) | ||||
|         self.bn_b = nn.BatchNorm2d(planes) | ||||
|         if dropout: | ||||
|             self.dropout = nn.Dropout2d(p=0.5, inplace=True) | ||||
|         else: | ||||
|             self.dropout = None | ||||
|         self.conv_b = nn.Conv2d( | ||||
|             planes, planes, kernel_size=3, stride=1, padding=1, bias=False | ||||
|         ) | ||||
|  | ||||
|     if inplanes != planes: | ||||
|       self.downsample = nn.Conv2d(inplanes, planes, kernel_size=1, stride=stride, padding=0, bias=False) | ||||
|     else: | ||||
|       self.downsample = None | ||||
|         if inplanes != planes: | ||||
|             self.downsample = nn.Conv2d( | ||||
|                 inplanes, planes, kernel_size=1, stride=stride, padding=0, bias=False | ||||
|             ) | ||||
|         else: | ||||
|             self.downsample = None | ||||
|  | ||||
|   def forward(self, x): | ||||
|     def forward(self, x): | ||||
|  | ||||
|     basicblock = self.bn_a(x) | ||||
|     basicblock = F.relu(basicblock) | ||||
|     basicblock = self.conv_a(basicblock) | ||||
|         basicblock = self.bn_a(x) | ||||
|         basicblock = F.relu(basicblock) | ||||
|         basicblock = self.conv_a(basicblock) | ||||
|  | ||||
|     basicblock = self.bn_b(basicblock) | ||||
|     basicblock = F.relu(basicblock) | ||||
|     if self.dropout is not None: | ||||
|       basicblock = self.dropout(basicblock) | ||||
|     basicblock = self.conv_b(basicblock) | ||||
|         basicblock = self.bn_b(basicblock) | ||||
|         basicblock = F.relu(basicblock) | ||||
|         if self.dropout is not None: | ||||
|             basicblock = self.dropout(basicblock) | ||||
|         basicblock = self.conv_b(basicblock) | ||||
|  | ||||
|     if self.downsample is not None: | ||||
|       x = self.downsample(x) | ||||
|      | ||||
|     return x + basicblock | ||||
|         if self.downsample is not None: | ||||
|             x = self.downsample(x) | ||||
|  | ||||
|         return x + basicblock | ||||
|  | ||||
|  | ||||
| class CifarWideResNet(nn.Module): | ||||
|   """ | ||||
|   ResNet optimized for the Cifar dataset, as specified in | ||||
|   https://arxiv.org/abs/1512.03385.pdf | ||||
|   """ | ||||
|   def __init__(self, depth, widen_factor, num_classes, dropout): | ||||
|     super(CifarWideResNet, self).__init__() | ||||
|     """ | ||||
|     ResNet optimized for the Cifar dataset, as specified in | ||||
|     https://arxiv.org/abs/1512.03385.pdf | ||||
|     """ | ||||
|  | ||||
|     #Model type specifies number of layers for CIFAR-10 and CIFAR-100 model | ||||
|     assert (depth - 4) % 6 == 0, 'depth should be one of 20, 32, 44, 56, 110' | ||||
|     layer_blocks = (depth - 4) // 6 | ||||
|     print ('CifarPreResNet : Depth : {} , Layers for each block : {}'.format(depth, layer_blocks)) | ||||
|     def __init__(self, depth, widen_factor, num_classes, dropout): | ||||
|         super(CifarWideResNet, self).__init__() | ||||
|  | ||||
|     self.num_classes = num_classes | ||||
|     self.dropout = dropout | ||||
|     self.conv_3x3 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False) | ||||
|         # Model type specifies number of layers for CIFAR-10 and CIFAR-100 model | ||||
|         assert (depth - 4) % 6 == 0, "depth should be one of 20, 32, 44, 56, 110" | ||||
|         layer_blocks = (depth - 4) // 6 | ||||
|         print( | ||||
|             "CifarPreResNet : Depth : {} , Layers for each block : {}".format( | ||||
|                 depth, layer_blocks | ||||
|             ) | ||||
|         ) | ||||
|  | ||||
|     self.message  = 'Wide ResNet : depth={:}, widen_factor={:}, class={:}'.format(depth, widen_factor, num_classes) | ||||
|     self.inplanes = 16 | ||||
|     self.stage_1 = self._make_layer(WideBasicblock, 16*widen_factor, layer_blocks, 1) | ||||
|     self.stage_2 = self._make_layer(WideBasicblock, 32*widen_factor, layer_blocks, 2) | ||||
|     self.stage_3 = self._make_layer(WideBasicblock, 64*widen_factor, layer_blocks, 2) | ||||
|     self.lastact = nn.Sequential(nn.BatchNorm2d(64*widen_factor), nn.ReLU(inplace=True)) | ||||
|     self.avgpool = nn.AvgPool2d(8) | ||||
|     self.classifier = nn.Linear(64*widen_factor, num_classes) | ||||
|         self.num_classes = num_classes | ||||
|         self.dropout = dropout | ||||
|         self.conv_3x3 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False) | ||||
|  | ||||
|     self.apply(initialize_resnet) | ||||
|         self.message = "Wide ResNet : depth={:}, widen_factor={:}, class={:}".format( | ||||
|             depth, widen_factor, num_classes | ||||
|         ) | ||||
|         self.inplanes = 16 | ||||
|         self.stage_1 = self._make_layer( | ||||
|             WideBasicblock, 16 * widen_factor, layer_blocks, 1 | ||||
|         ) | ||||
|         self.stage_2 = self._make_layer( | ||||
|             WideBasicblock, 32 * widen_factor, layer_blocks, 2 | ||||
|         ) | ||||
|         self.stage_3 = self._make_layer( | ||||
|             WideBasicblock, 64 * widen_factor, layer_blocks, 2 | ||||
|         ) | ||||
|         self.lastact = nn.Sequential( | ||||
|             nn.BatchNorm2d(64 * widen_factor), nn.ReLU(inplace=True) | ||||
|         ) | ||||
|         self.avgpool = nn.AvgPool2d(8) | ||||
|         self.classifier = nn.Linear(64 * widen_factor, num_classes) | ||||
|  | ||||
|   def get_message(self): | ||||
|     return self.message | ||||
|         self.apply(initialize_resnet) | ||||
|  | ||||
|   def _make_layer(self, block, planes, blocks, stride): | ||||
|     def get_message(self): | ||||
|         return self.message | ||||
|  | ||||
|     layers = [] | ||||
|     layers.append(block(self.inplanes, planes, stride, self.dropout)) | ||||
|     self.inplanes = planes | ||||
|     for i in range(1, blocks): | ||||
|       layers.append(block(self.inplanes, planes, 1, self.dropout)) | ||||
|     def _make_layer(self, block, planes, blocks, stride): | ||||
|  | ||||
|     return nn.Sequential(*layers) | ||||
|         layers = [] | ||||
|         layers.append(block(self.inplanes, planes, stride, self.dropout)) | ||||
|         self.inplanes = planes | ||||
|         for i in range(1, blocks): | ||||
|             layers.append(block(self.inplanes, planes, 1, self.dropout)) | ||||
|  | ||||
|   def forward(self, x): | ||||
|     x = self.conv_3x3(x) | ||||
|     x = self.stage_1(x) | ||||
|     x = self.stage_2(x) | ||||
|     x = self.stage_3(x) | ||||
|     x = self.lastact(x) | ||||
|     x = self.avgpool(x) | ||||
|     features = x.view(x.size(0), -1) | ||||
|     outs     = self.classifier(features) | ||||
|     return features, outs | ||||
|         return nn.Sequential(*layers) | ||||
|  | ||||
|     def forward(self, x): | ||||
|         x = self.conv_3x3(x) | ||||
|         x = self.stage_1(x) | ||||
|         x = self.stage_2(x) | ||||
|         x = self.stage_3(x) | ||||
|         x = self.lastact(x) | ||||
|         x = self.avgpool(x) | ||||
|         features = x.view(x.size(0), -1) | ||||
|         outs = self.classifier(features) | ||||
|         return features, outs | ||||
|   | ||||
| @@ -4,98 +4,114 @@ from .initialization import initialize_resnet | ||||
|  | ||||
|  | ||||
| class ConvBNReLU(nn.Module): | ||||
|   def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1): | ||||
|     super(ConvBNReLU, self).__init__() | ||||
|     padding = (kernel_size - 1) // 2 | ||||
|     self.conv = nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, bias=False) | ||||
|     self.bn   = nn.BatchNorm2d(out_planes) | ||||
|     self.relu = nn.ReLU6(inplace=True) | ||||
|    | ||||
|   def forward(self, x): | ||||
|     out = self.conv( x ) | ||||
|     out = self.bn  ( out ) | ||||
|     out = self.relu( out ) | ||||
|     return out | ||||
|     def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1): | ||||
|         super(ConvBNReLU, self).__init__() | ||||
|         padding = (kernel_size - 1) // 2 | ||||
|         self.conv = nn.Conv2d( | ||||
|             in_planes, | ||||
|             out_planes, | ||||
|             kernel_size, | ||||
|             stride, | ||||
|             padding, | ||||
|             groups=groups, | ||||
|             bias=False, | ||||
|         ) | ||||
|         self.bn = nn.BatchNorm2d(out_planes) | ||||
|         self.relu = nn.ReLU6(inplace=True) | ||||
|  | ||||
|     def forward(self, x): | ||||
|         out = self.conv(x) | ||||
|         out = self.bn(out) | ||||
|         out = self.relu(out) | ||||
|         return out | ||||
|  | ||||
|  | ||||
| class InvertedResidual(nn.Module): | ||||
|   def __init__(self, inp, oup, stride, expand_ratio): | ||||
|     super(InvertedResidual, self).__init__() | ||||
|     self.stride = stride | ||||
|     assert stride in [1, 2] | ||||
|     def __init__(self, inp, oup, stride, expand_ratio): | ||||
|         super(InvertedResidual, self).__init__() | ||||
|         self.stride = stride | ||||
|         assert stride in [1, 2] | ||||
|  | ||||
|     hidden_dim = int(round(inp * expand_ratio)) | ||||
|     self.use_res_connect = self.stride == 1 and inp == oup | ||||
|         hidden_dim = int(round(inp * expand_ratio)) | ||||
|         self.use_res_connect = self.stride == 1 and inp == oup | ||||
|  | ||||
|     layers = [] | ||||
|     if expand_ratio != 1: | ||||
|       # pw | ||||
|       layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1)) | ||||
|     layers.extend([ | ||||
|       # dw | ||||
|       ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim), | ||||
|       # pw-linear | ||||
|       nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), | ||||
|       nn.BatchNorm2d(oup), | ||||
|     ]) | ||||
|     self.conv = nn.Sequential(*layers) | ||||
|         layers = [] | ||||
|         if expand_ratio != 1: | ||||
|             # pw | ||||
|             layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1)) | ||||
|         layers.extend( | ||||
|             [ | ||||
|                 # dw | ||||
|                 ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim), | ||||
|                 # pw-linear | ||||
|                 nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), | ||||
|                 nn.BatchNorm2d(oup), | ||||
|             ] | ||||
|         ) | ||||
|         self.conv = nn.Sequential(*layers) | ||||
|  | ||||
|   def forward(self, x): | ||||
|     if self.use_res_connect: | ||||
|       return x + self.conv(x) | ||||
|     else: | ||||
|       return self.conv(x) | ||||
|     def forward(self, x): | ||||
|         if self.use_res_connect: | ||||
|             return x + self.conv(x) | ||||
|         else: | ||||
|             return self.conv(x) | ||||
|  | ||||
|  | ||||
| class MobileNetV2(nn.Module): | ||||
|   def __init__(self, num_classes, width_mult, input_channel, last_channel, block_name, dropout): | ||||
|     super(MobileNetV2, self).__init__() | ||||
|     if block_name == 'InvertedResidual': | ||||
|       block = InvertedResidual | ||||
|     else: | ||||
|       raise ValueError('invalid block name : {:}'.format(block_name)) | ||||
|     inverted_residual_setting = [ | ||||
|       # t, c,  n, s | ||||
|       [1, 16 , 1, 1], | ||||
|       [6, 24 , 2, 2], | ||||
|       [6, 32 , 3, 2], | ||||
|       [6, 64 , 4, 2], | ||||
|       [6, 96 , 3, 1], | ||||
|       [6, 160, 3, 2], | ||||
|       [6, 320, 1, 1], | ||||
|     ] | ||||
|     def __init__( | ||||
|         self, num_classes, width_mult, input_channel, last_channel, block_name, dropout | ||||
|     ): | ||||
|         super(MobileNetV2, self).__init__() | ||||
|         if block_name == "InvertedResidual": | ||||
|             block = InvertedResidual | ||||
|         else: | ||||
|             raise ValueError("invalid block name : {:}".format(block_name)) | ||||
|         inverted_residual_setting = [ | ||||
|             # t, c,  n, s | ||||
|             [1, 16, 1, 1], | ||||
|             [6, 24, 2, 2], | ||||
|             [6, 32, 3, 2], | ||||
|             [6, 64, 4, 2], | ||||
|             [6, 96, 3, 1], | ||||
|             [6, 160, 3, 2], | ||||
|             [6, 320, 1, 1], | ||||
|         ] | ||||
|  | ||||
|     # building first layer | ||||
|     input_channel = int(input_channel * width_mult) | ||||
|     self.last_channel = int(last_channel * max(1.0, width_mult)) | ||||
|     features = [ConvBNReLU(3, input_channel, stride=2)] | ||||
|     # building inverted residual blocks | ||||
|     for t, c, n, s in inverted_residual_setting: | ||||
|       output_channel = int(c * width_mult) | ||||
|       for i in range(n): | ||||
|         stride = s if i == 0 else 1 | ||||
|         features.append(block(input_channel, output_channel, stride, expand_ratio=t)) | ||||
|         input_channel = output_channel | ||||
|     # building last several layers | ||||
|     features.append(ConvBNReLU(input_channel, self.last_channel, kernel_size=1)) | ||||
|     # make it nn.Sequential | ||||
|     self.features = nn.Sequential(*features) | ||||
|         # building first layer | ||||
|         input_channel = int(input_channel * width_mult) | ||||
|         self.last_channel = int(last_channel * max(1.0, width_mult)) | ||||
|         features = [ConvBNReLU(3, input_channel, stride=2)] | ||||
|         # building inverted residual blocks | ||||
|         for t, c, n, s in inverted_residual_setting: | ||||
|             output_channel = int(c * width_mult) | ||||
|             for i in range(n): | ||||
|                 stride = s if i == 0 else 1 | ||||
|                 features.append( | ||||
|                     block(input_channel, output_channel, stride, expand_ratio=t) | ||||
|                 ) | ||||
|                 input_channel = output_channel | ||||
|         # building last several layers | ||||
|         features.append(ConvBNReLU(input_channel, self.last_channel, kernel_size=1)) | ||||
|         # make it nn.Sequential | ||||
|         self.features = nn.Sequential(*features) | ||||
|  | ||||
|     # building classifier | ||||
|     self.classifier = nn.Sequential( | ||||
|       nn.Dropout(dropout), | ||||
|       nn.Linear(self.last_channel, num_classes), | ||||
|     ) | ||||
|     self.message = 'MobileNetV2 : width_mult={:}, in-C={:}, last-C={:}, block={:}, dropout={:}'.format(width_mult, input_channel, last_channel, block_name, dropout) | ||||
|         # building classifier | ||||
|         self.classifier = nn.Sequential( | ||||
|             nn.Dropout(dropout), | ||||
|             nn.Linear(self.last_channel, num_classes), | ||||
|         ) | ||||
|         self.message = "MobileNetV2 : width_mult={:}, in-C={:}, last-C={:}, block={:}, dropout={:}".format( | ||||
|             width_mult, input_channel, last_channel, block_name, dropout | ||||
|         ) | ||||
|  | ||||
|     # weight initialization | ||||
|     self.apply( initialize_resnet ) | ||||
|         # weight initialization | ||||
|         self.apply(initialize_resnet) | ||||
|  | ||||
|   def get_message(self): | ||||
|     return self.message | ||||
|     def get_message(self): | ||||
|         return self.message | ||||
|  | ||||
|   def forward(self, inputs): | ||||
|     features = self.features(inputs) | ||||
|     vectors  = features.mean([2, 3]) | ||||
|     predicts = self.classifier(vectors) | ||||
|     return features, predicts | ||||
|     def forward(self, inputs): | ||||
|         features = self.features(inputs) | ||||
|         vectors = features.mean([2, 3]) | ||||
|         predicts = self.classifier(vectors) | ||||
|         return features, predicts | ||||
|   | ||||
| @@ -2,171 +2,216 @@ | ||||
| import torch.nn as nn | ||||
| from .initialization import initialize_resnet | ||||
|  | ||||
|  | ||||
| def conv3x3(in_planes, out_planes, stride=1, groups=1): | ||||
|   return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, groups=groups, bias=False) | ||||
|     return nn.Conv2d( | ||||
|         in_planes, | ||||
|         out_planes, | ||||
|         kernel_size=3, | ||||
|         stride=stride, | ||||
|         padding=1, | ||||
|         groups=groups, | ||||
|         bias=False, | ||||
|     ) | ||||
|  | ||||
|  | ||||
| def conv1x1(in_planes, out_planes, stride=1): | ||||
|   return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) | ||||
|     return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) | ||||
|  | ||||
|  | ||||
| class BasicBlock(nn.Module): | ||||
|   expansion = 1 | ||||
|     expansion = 1 | ||||
|  | ||||
|   def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, base_width=64): | ||||
|     super(BasicBlock, self).__init__() | ||||
|     if groups != 1 or base_width != 64: | ||||
|       raise ValueError('BasicBlock only supports groups=1 and base_width=64') | ||||
|     # Both self.conv1 and self.downsample layers downsample the input when stride != 1 | ||||
|     self.conv1 = conv3x3(inplanes, planes, stride) | ||||
|     self.bn1   = nn.BatchNorm2d(planes) | ||||
|     self.relu  = nn.ReLU(inplace=True) | ||||
|     self.conv2 = conv3x3(planes, planes) | ||||
|     self.bn2   = nn.BatchNorm2d(planes) | ||||
|     self.downsample = downsample | ||||
|     self.stride = stride | ||||
|     def __init__( | ||||
|         self, inplanes, planes, stride=1, downsample=None, groups=1, base_width=64 | ||||
|     ): | ||||
|         super(BasicBlock, self).__init__() | ||||
|         if groups != 1 or base_width != 64: | ||||
|             raise ValueError("BasicBlock only supports groups=1 and base_width=64") | ||||
|         # Both self.conv1 and self.downsample layers downsample the input when stride != 1 | ||||
|         self.conv1 = conv3x3(inplanes, planes, stride) | ||||
|         self.bn1 = nn.BatchNorm2d(planes) | ||||
|         self.relu = nn.ReLU(inplace=True) | ||||
|         self.conv2 = conv3x3(planes, planes) | ||||
|         self.bn2 = nn.BatchNorm2d(planes) | ||||
|         self.downsample = downsample | ||||
|         self.stride = stride | ||||
|  | ||||
|   def forward(self, x): | ||||
|     identity = x | ||||
|     def forward(self, x): | ||||
|         identity = x | ||||
|  | ||||
|     out = self.conv1(x) | ||||
|     out = self.bn1(out) | ||||
|     out = self.relu(out) | ||||
|         out = self.conv1(x) | ||||
|         out = self.bn1(out) | ||||
|         out = self.relu(out) | ||||
|  | ||||
|     out = self.conv2(out) | ||||
|     out = self.bn2(out) | ||||
|         out = self.conv2(out) | ||||
|         out = self.bn2(out) | ||||
|  | ||||
|     if self.downsample is not None: | ||||
|       identity = self.downsample(x) | ||||
|         if self.downsample is not None: | ||||
|             identity = self.downsample(x) | ||||
|  | ||||
|     out += identity | ||||
|     out = self.relu(out) | ||||
|         out += identity | ||||
|         out = self.relu(out) | ||||
|  | ||||
|     return out | ||||
|         return out | ||||
|  | ||||
|  | ||||
| class Bottleneck(nn.Module): | ||||
|   expansion = 4 | ||||
|     expansion = 4 | ||||
|  | ||||
|   def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, base_width=64): | ||||
|     super(Bottleneck, self).__init__() | ||||
|     width = int(planes * (base_width / 64.)) * groups | ||||
|     # Both self.conv2 and self.downsample layers downsample the input when stride != 1 | ||||
|     self.conv1 = conv1x1(inplanes, width) | ||||
|     self.bn1   = nn.BatchNorm2d(width) | ||||
|     self.conv2 = conv3x3(width, width, stride, groups) | ||||
|     self.bn2   = nn.BatchNorm2d(width) | ||||
|     self.conv3 = conv1x1(width, planes * self.expansion) | ||||
|     self.bn3   = nn.BatchNorm2d(planes * self.expansion) | ||||
|     self.relu  = nn.ReLU(inplace=True) | ||||
|     self.downsample = downsample | ||||
|     self.stride = stride | ||||
|     def __init__( | ||||
|         self, inplanes, planes, stride=1, downsample=None, groups=1, base_width=64 | ||||
|     ): | ||||
|         super(Bottleneck, self).__init__() | ||||
|         width = int(planes * (base_width / 64.0)) * groups | ||||
|         # Both self.conv2 and self.downsample layers downsample the input when stride != 1 | ||||
|         self.conv1 = conv1x1(inplanes, width) | ||||
|         self.bn1 = nn.BatchNorm2d(width) | ||||
|         self.conv2 = conv3x3(width, width, stride, groups) | ||||
|         self.bn2 = nn.BatchNorm2d(width) | ||||
|         self.conv3 = conv1x1(width, planes * self.expansion) | ||||
|         self.bn3 = nn.BatchNorm2d(planes * self.expansion) | ||||
|         self.relu = nn.ReLU(inplace=True) | ||||
|         self.downsample = downsample | ||||
|         self.stride = stride | ||||
|  | ||||
|   def forward(self, x): | ||||
|     identity = x | ||||
|     def forward(self, x): | ||||
|         identity = x | ||||
|  | ||||
|     out = self.conv1(x) | ||||
|     out = self.bn1(out) | ||||
|     out = self.relu(out) | ||||
|         out = self.conv1(x) | ||||
|         out = self.bn1(out) | ||||
|         out = self.relu(out) | ||||
|  | ||||
|     out = self.conv2(out) | ||||
|     out = self.bn2(out) | ||||
|     out = self.relu(out) | ||||
|         out = self.conv2(out) | ||||
|         out = self.bn2(out) | ||||
|         out = self.relu(out) | ||||
|  | ||||
|     out = self.conv3(out) | ||||
|     out = self.bn3(out) | ||||
|         out = self.conv3(out) | ||||
|         out = self.bn3(out) | ||||
|  | ||||
|     if self.downsample is not None: | ||||
|       identity = self.downsample(x) | ||||
|         if self.downsample is not None: | ||||
|             identity = self.downsample(x) | ||||
|  | ||||
|     out += identity | ||||
|     out = self.relu(out) | ||||
|         out += identity | ||||
|         out = self.relu(out) | ||||
|  | ||||
|     return out | ||||
|         return out | ||||
|  | ||||
|  | ||||
| class ResNet(nn.Module): | ||||
|     def __init__( | ||||
|         self, | ||||
|         block_name, | ||||
|         layers, | ||||
|         deep_stem, | ||||
|         num_classes, | ||||
|         zero_init_residual, | ||||
|         groups, | ||||
|         width_per_group, | ||||
|     ): | ||||
|         super(ResNet, self).__init__() | ||||
|  | ||||
|   def __init__(self, block_name, layers, deep_stem, num_classes, zero_init_residual, groups, width_per_group): | ||||
|     super(ResNet, self).__init__() | ||||
|         # planes = [int(width_per_group * groups * 2 ** i) for i in range(4)] | ||||
|         if block_name == "BasicBlock": | ||||
|             block = BasicBlock | ||||
|         elif block_name == "Bottleneck": | ||||
|             block = Bottleneck | ||||
|         else: | ||||
|             raise ValueError("invalid block-name : {:}".format(block_name)) | ||||
|  | ||||
|     #planes = [int(width_per_group * groups * 2 ** i) for i in range(4)] | ||||
|     if block_name == 'BasicBlock'  : block= BasicBlock | ||||
|     elif block_name == 'Bottleneck': block= Bottleneck | ||||
|     else                           : raise ValueError('invalid block-name : {:}'.format(block_name)) | ||||
|  | ||||
|     if not deep_stem: | ||||
|       self.conv = nn.Sequential( | ||||
|                    nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False), | ||||
|                    nn.BatchNorm2d(64), nn.ReLU(inplace=True)) | ||||
|     else: | ||||
|       self.conv = nn.Sequential( | ||||
|                    nn.Conv2d(           3, 32, kernel_size=3, stride=2, padding=1, bias=False), | ||||
|                    nn.BatchNorm2d(32), nn.ReLU(inplace=True), | ||||
|                    nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1, bias=False), | ||||
|                    nn.BatchNorm2d(32), nn.ReLU(inplace=True), | ||||
|                    nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1, bias=False), | ||||
|                    nn.BatchNorm2d(64), nn.ReLU(inplace=True)) | ||||
|     self.inplanes = 64 | ||||
|     self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) | ||||
|     self.layer1 = self._make_layer(block, 64 , layers[0], stride=1, groups=groups, base_width=width_per_group) | ||||
|     self.layer2 = self._make_layer(block, 128, layers[1], stride=2, groups=groups, base_width=width_per_group) | ||||
|     self.layer3 = self._make_layer(block, 256, layers[2], stride=2, groups=groups, base_width=width_per_group) | ||||
|     self.layer4 = self._make_layer(block, 512, layers[3], stride=2, groups=groups, base_width=width_per_group) | ||||
|     self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) | ||||
|     self.fc      = nn.Linear(512 * block.expansion, num_classes) | ||||
|     self.message = 'block = {:}, layers = {:}, deep_stem = {:}, num_classes = {:}'.format(block, layers, deep_stem, num_classes) | ||||
|  | ||||
|     self.apply( initialize_resnet ) | ||||
|  | ||||
|     # Zero-initialize the last BN in each residual branch, | ||||
|     # so that the residual branch starts with zeros, and each residual block behaves like an identity. | ||||
|     # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 | ||||
|     if zero_init_residual: | ||||
|       for m in self.modules(): | ||||
|         if isinstance(m, Bottleneck): | ||||
|           nn.init.constant_(m.bn3.weight, 0) | ||||
|         elif isinstance(m, BasicBlock): | ||||
|           nn.init.constant_(m.bn2.weight, 0) | ||||
|  | ||||
|   def _make_layer(self, block, planes, blocks, stride, groups, base_width): | ||||
|     downsample = None | ||||
|     if stride != 1 or self.inplanes != planes * block.expansion: | ||||
|       if stride == 2: | ||||
|         downsample = nn.Sequential( | ||||
|           nn.AvgPool2d(kernel_size=2, stride=2, padding=0), | ||||
|           conv1x1(self.inplanes, planes * block.expansion, 1), | ||||
|           nn.BatchNorm2d(planes * block.expansion), | ||||
|         if not deep_stem: | ||||
|             self.conv = nn.Sequential( | ||||
|                 nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False), | ||||
|                 nn.BatchNorm2d(64), | ||||
|                 nn.ReLU(inplace=True), | ||||
|             ) | ||||
|         else: | ||||
|             self.conv = nn.Sequential( | ||||
|                 nn.Conv2d(3, 32, kernel_size=3, stride=2, padding=1, bias=False), | ||||
|                 nn.BatchNorm2d(32), | ||||
|                 nn.ReLU(inplace=True), | ||||
|                 nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1, bias=False), | ||||
|                 nn.BatchNorm2d(32), | ||||
|                 nn.ReLU(inplace=True), | ||||
|                 nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1, bias=False), | ||||
|                 nn.BatchNorm2d(64), | ||||
|                 nn.ReLU(inplace=True), | ||||
|             ) | ||||
|         self.inplanes = 64 | ||||
|         self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) | ||||
|         self.layer1 = self._make_layer( | ||||
|             block, 64, layers[0], stride=1, groups=groups, base_width=width_per_group | ||||
|         ) | ||||
|       elif stride == 1: | ||||
|         downsample = nn.Sequential( | ||||
|           conv1x1(self.inplanes, planes * block.expansion, stride), | ||||
|           nn.BatchNorm2d(planes * block.expansion), | ||||
|         self.layer2 = self._make_layer( | ||||
|             block, 128, layers[1], stride=2, groups=groups, base_width=width_per_group | ||||
|         ) | ||||
|         self.layer3 = self._make_layer( | ||||
|             block, 256, layers[2], stride=2, groups=groups, base_width=width_per_group | ||||
|         ) | ||||
|         self.layer4 = self._make_layer( | ||||
|             block, 512, layers[3], stride=2, groups=groups, base_width=width_per_group | ||||
|         ) | ||||
|         self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) | ||||
|         self.fc = nn.Linear(512 * block.expansion, num_classes) | ||||
|         self.message = ( | ||||
|             "block = {:}, layers = {:}, deep_stem = {:}, num_classes = {:}".format( | ||||
|                 block, layers, deep_stem, num_classes | ||||
|             ) | ||||
|         ) | ||||
|       else: raise ValueError('invalid stride [{:}] for downsample'.format(stride)) | ||||
|  | ||||
|     layers = [] | ||||
|     layers.append(block(self.inplanes, planes, stride, downsample, groups, base_width)) | ||||
|     self.inplanes = planes * block.expansion | ||||
|     for _ in range(1, blocks): | ||||
|       layers.append(block(self.inplanes, planes, 1, None, groups, base_width)) | ||||
|         self.apply(initialize_resnet) | ||||
|  | ||||
|     return nn.Sequential(*layers) | ||||
|         # Zero-initialize the last BN in each residual branch, | ||||
|         # so that the residual branch starts with zeros, and each residual block behaves like an identity. | ||||
|         # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 | ||||
|         if zero_init_residual: | ||||
|             for m in self.modules(): | ||||
|                 if isinstance(m, Bottleneck): | ||||
|                     nn.init.constant_(m.bn3.weight, 0) | ||||
|                 elif isinstance(m, BasicBlock): | ||||
|                     nn.init.constant_(m.bn2.weight, 0) | ||||
|  | ||||
|   def get_message(self): | ||||
|     return self.message | ||||
|     def _make_layer(self, block, planes, blocks, stride, groups, base_width): | ||||
|         downsample = None | ||||
|         if stride != 1 or self.inplanes != planes * block.expansion: | ||||
|             if stride == 2: | ||||
|                 downsample = nn.Sequential( | ||||
|                     nn.AvgPool2d(kernel_size=2, stride=2, padding=0), | ||||
|                     conv1x1(self.inplanes, planes * block.expansion, 1), | ||||
|                     nn.BatchNorm2d(planes * block.expansion), | ||||
|                 ) | ||||
|             elif stride == 1: | ||||
|                 downsample = nn.Sequential( | ||||
|                     conv1x1(self.inplanes, planes * block.expansion, stride), | ||||
|                     nn.BatchNorm2d(planes * block.expansion), | ||||
|                 ) | ||||
|             else: | ||||
|                 raise ValueError("invalid stride [{:}] for downsample".format(stride)) | ||||
|  | ||||
|   def forward(self, x): | ||||
|     x = self.conv(x) | ||||
|     x = self.maxpool(x) | ||||
|         layers = [] | ||||
|         layers.append( | ||||
|             block(self.inplanes, planes, stride, downsample, groups, base_width) | ||||
|         ) | ||||
|         self.inplanes = planes * block.expansion | ||||
|         for _ in range(1, blocks): | ||||
|             layers.append(block(self.inplanes, planes, 1, None, groups, base_width)) | ||||
|  | ||||
|     x = self.layer1(x) | ||||
|     x = self.layer2(x) | ||||
|     x = self.layer3(x) | ||||
|     x = self.layer4(x) | ||||
|         return nn.Sequential(*layers) | ||||
|  | ||||
|     features = self.avgpool(x) | ||||
|     features = features.view(features.size(0), -1) | ||||
|     logits   = self.fc(features) | ||||
|     def get_message(self): | ||||
|         return self.message | ||||
|  | ||||
|     return features, logits | ||||
|     def forward(self, x): | ||||
|         x = self.conv(x) | ||||
|         x = self.maxpool(x) | ||||
|  | ||||
|         x = self.layer1(x) | ||||
|         x = self.layer2(x) | ||||
|         x = self.layer3(x) | ||||
|         x = self.layer4(x) | ||||
|  | ||||
|         features = self.avgpool(x) | ||||
|         features = features.view(features.size(0), -1) | ||||
|         logits = self.fc(features) | ||||
|  | ||||
|         return features, logits | ||||
|   | ||||
| @@ -6,29 +6,32 @@ import torch.nn as nn | ||||
|  | ||||
|  | ||||
| def additive_func(A, B): | ||||
|   assert A.dim() == B.dim() and A.size(0) == B.size(0), '{:} vs {:}'.format(A.size(), B.size()) | ||||
|   C = min(A.size(1), B.size(1)) | ||||
|   if A.size(1) == B.size(1): | ||||
|     return A + B | ||||
|   elif A.size(1) < B.size(1): | ||||
|     out = B.clone() | ||||
|     out[:,:C] += A | ||||
|     return out | ||||
|   else: | ||||
|     out = A.clone() | ||||
|     out[:,:C] += B | ||||
|     return out | ||||
|     assert A.dim() == B.dim() and A.size(0) == B.size(0), "{:} vs {:}".format( | ||||
|         A.size(), B.size() | ||||
|     ) | ||||
|     C = min(A.size(1), B.size(1)) | ||||
|     if A.size(1) == B.size(1): | ||||
|         return A + B | ||||
|     elif A.size(1) < B.size(1): | ||||
|         out = B.clone() | ||||
|         out[:, :C] += A | ||||
|         return out | ||||
|     else: | ||||
|         out = A.clone() | ||||
|         out[:, :C] += B | ||||
|         return out | ||||
|  | ||||
|  | ||||
| def change_key(key, value): | ||||
|   def func(m): | ||||
|     if hasattr(m, key): | ||||
|       setattr(m, key, value) | ||||
|   return func | ||||
|     def func(m): | ||||
|         if hasattr(m, key): | ||||
|             setattr(m, key, value) | ||||
|  | ||||
|     return func | ||||
|  | ||||
|  | ||||
| def parse_channel_info(xstring): | ||||
|   blocks = xstring.split(' ') | ||||
|   blocks = [x.split('-') for x in blocks] | ||||
|   blocks = [[int(_) for _ in x] for x in blocks] | ||||
|   return blocks | ||||
|     blocks = xstring.split(" ") | ||||
|     blocks = [x.split("-") for x in blocks] | ||||
|     blocks = [[int(_) for _ in x] for x in blocks] | ||||
|     return blocks | ||||
|   | ||||
| @@ -5,10 +5,18 @@ from os import path as osp | ||||
| from typing import List, Text | ||||
| import torch | ||||
|  | ||||
| __all__ = ['change_key', 'get_cell_based_tiny_net', 'get_search_spaces', 'get_cifar_models', 'get_imagenet_models', \ | ||||
|            'obtain_model', 'obtain_search_model', 'load_net_from_checkpoint', \ | ||||
|            'CellStructure', 'CellArchitectures' | ||||
|            ] | ||||
| __all__ = [ | ||||
|     "change_key", | ||||
|     "get_cell_based_tiny_net", | ||||
|     "get_search_spaces", | ||||
|     "get_cifar_models", | ||||
|     "get_imagenet_models", | ||||
|     "obtain_model", | ||||
|     "obtain_search_model", | ||||
|     "load_net_from_checkpoint", | ||||
|     "CellStructure", | ||||
|     "CellArchitectures", | ||||
| ] | ||||
|  | ||||
| # useful modules | ||||
| from config_utils import dict2config | ||||
| @@ -18,178 +26,301 @@ from models.cell_searchs import CellStructure, CellArchitectures | ||||
|  | ||||
| # Cell-based NAS Models | ||||
| def get_cell_based_tiny_net(config): | ||||
|   if isinstance(config, dict): config = dict2config(config, None) # to support the argument being a dict | ||||
|   super_type = getattr(config, 'super_type', 'basic') | ||||
|   group_names = ['DARTS-V1', 'DARTS-V2', 'GDAS', 'SETN', 'ENAS', 'RANDOM', 'generic'] | ||||
|   if super_type == 'basic' and config.name in group_names: | ||||
|     from .cell_searchs import nas201_super_nets as nas_super_nets | ||||
|     try: | ||||
|       return nas_super_nets[config.name](config.C, config.N, config.max_nodes, config.num_classes, config.space, config.affine, config.track_running_stats) | ||||
|     except: | ||||
|       return nas_super_nets[config.name](config.C, config.N, config.max_nodes, config.num_classes, config.space) | ||||
|   elif super_type == 'search-shape': | ||||
|     from .shape_searchs import GenericNAS301Model | ||||
|     genotype = CellStructure.str2structure(config.genotype) | ||||
|     return GenericNAS301Model(config.candidate_Cs, config.max_num_Cs, genotype, config.num_classes, config.affine, config.track_running_stats) | ||||
|   elif super_type == 'nasnet-super': | ||||
|     from .cell_searchs import nasnet_super_nets as nas_super_nets | ||||
|     return nas_super_nets[config.name](config.C, config.N, config.steps, config.multiplier, \ | ||||
|                     config.stem_multiplier, config.num_classes, config.space, config.affine, config.track_running_stats) | ||||
|   elif config.name == 'infer.tiny': | ||||
|     from .cell_infers import TinyNetwork | ||||
|     if hasattr(config, 'genotype'): | ||||
|       genotype = config.genotype | ||||
|     elif hasattr(config, 'arch_str'): | ||||
|       genotype = CellStructure.str2structure(config.arch_str) | ||||
|     else: raise ValueError('Can not find genotype from this config : {:}'.format(config)) | ||||
|     return TinyNetwork(config.C, config.N, genotype, config.num_classes) | ||||
|   elif config.name == 'infer.shape.tiny': | ||||
|     from .shape_infers import DynamicShapeTinyNet | ||||
|     if isinstance(config.channels, str): | ||||
|       channels = tuple([int(x) for x in config.channels.split(':')]) | ||||
|     else: channels = config.channels | ||||
|     genotype = CellStructure.str2structure(config.genotype) | ||||
|     return DynamicShapeTinyNet(channels, genotype, config.num_classes) | ||||
|   elif config.name == 'infer.nasnet-cifar': | ||||
|     from .cell_infers import NASNetonCIFAR | ||||
|     raise NotImplementedError | ||||
|   else: | ||||
|     raise ValueError('invalid network name : {:}'.format(config.name)) | ||||
|     if isinstance(config, dict): | ||||
|         config = dict2config(config, None)  # to support the argument being a dict | ||||
|     super_type = getattr(config, "super_type", "basic") | ||||
|     group_names = ["DARTS-V1", "DARTS-V2", "GDAS", "SETN", "ENAS", "RANDOM", "generic"] | ||||
|     if super_type == "basic" and config.name in group_names: | ||||
|         from .cell_searchs import nas201_super_nets as nas_super_nets | ||||
|  | ||||
|         try: | ||||
|             return nas_super_nets[config.name]( | ||||
|                 config.C, | ||||
|                 config.N, | ||||
|                 config.max_nodes, | ||||
|                 config.num_classes, | ||||
|                 config.space, | ||||
|                 config.affine, | ||||
|                 config.track_running_stats, | ||||
|             ) | ||||
|         except: | ||||
|             return nas_super_nets[config.name]( | ||||
|                 config.C, config.N, config.max_nodes, config.num_classes, config.space | ||||
|             ) | ||||
|     elif super_type == "search-shape": | ||||
|         from .shape_searchs import GenericNAS301Model | ||||
|  | ||||
|         genotype = CellStructure.str2structure(config.genotype) | ||||
|         return GenericNAS301Model( | ||||
|             config.candidate_Cs, | ||||
|             config.max_num_Cs, | ||||
|             genotype, | ||||
|             config.num_classes, | ||||
|             config.affine, | ||||
|             config.track_running_stats, | ||||
|         ) | ||||
|     elif super_type == "nasnet-super": | ||||
|         from .cell_searchs import nasnet_super_nets as nas_super_nets | ||||
|  | ||||
|         return nas_super_nets[config.name]( | ||||
|             config.C, | ||||
|             config.N, | ||||
|             config.steps, | ||||
|             config.multiplier, | ||||
|             config.stem_multiplier, | ||||
|             config.num_classes, | ||||
|             config.space, | ||||
|             config.affine, | ||||
|             config.track_running_stats, | ||||
|         ) | ||||
|     elif config.name == "infer.tiny": | ||||
|         from .cell_infers import TinyNetwork | ||||
|  | ||||
|         if hasattr(config, "genotype"): | ||||
|             genotype = config.genotype | ||||
|         elif hasattr(config, "arch_str"): | ||||
|             genotype = CellStructure.str2structure(config.arch_str) | ||||
|         else: | ||||
|             raise ValueError( | ||||
|                 "Can not find genotype from this config : {:}".format(config) | ||||
|             ) | ||||
|         return TinyNetwork(config.C, config.N, genotype, config.num_classes) | ||||
|     elif config.name == "infer.shape.tiny": | ||||
|         from .shape_infers import DynamicShapeTinyNet | ||||
|  | ||||
|         if isinstance(config.channels, str): | ||||
|             channels = tuple([int(x) for x in config.channels.split(":")]) | ||||
|         else: | ||||
|             channels = config.channels | ||||
|         genotype = CellStructure.str2structure(config.genotype) | ||||
|         return DynamicShapeTinyNet(channels, genotype, config.num_classes) | ||||
|     elif config.name == "infer.nasnet-cifar": | ||||
|         from .cell_infers import NASNetonCIFAR | ||||
|  | ||||
|         raise NotImplementedError | ||||
|     else: | ||||
|         raise ValueError("invalid network name : {:}".format(config.name)) | ||||
|  | ||||
|  | ||||
| # obtain the search space, i.e., a dict mapping the operation name into a python-function for this op | ||||
| def get_search_spaces(xtype, name) -> List[Text]: | ||||
|   if xtype == 'cell' or xtype == 'tss':  # The topology search space. | ||||
|     from .cell_operations import SearchSpaceNames | ||||
|     assert name in SearchSpaceNames, 'invalid name [{:}] in {:}'.format(name, SearchSpaceNames.keys()) | ||||
|     return SearchSpaceNames[name] | ||||
|   elif xtype == 'sss':  # The size search space. | ||||
|     if name in ['nats-bench', 'nats-bench-size']: | ||||
|       return {'candidates': [8, 16, 24, 32, 40, 48, 56, 64], | ||||
|               'numbers': 5} | ||||
|     if xtype == "cell" or xtype == "tss":  # The topology search space. | ||||
|         from .cell_operations import SearchSpaceNames | ||||
|  | ||||
|         assert name in SearchSpaceNames, "invalid name [{:}] in {:}".format( | ||||
|             name, SearchSpaceNames.keys() | ||||
|         ) | ||||
|         return SearchSpaceNames[name] | ||||
|     elif xtype == "sss":  # The size search space. | ||||
|         if name in ["nats-bench", "nats-bench-size"]: | ||||
|             return {"candidates": [8, 16, 24, 32, 40, 48, 56, 64], "numbers": 5} | ||||
|         else: | ||||
|             raise ValueError("Invalid name : {:}".format(name)) | ||||
|     else: | ||||
|       raise ValueError('Invalid name : {:}'.format(name)) | ||||
|   else: | ||||
|     raise ValueError('invalid search-space type is {:}'.format(xtype)) | ||||
|         raise ValueError("invalid search-space type is {:}".format(xtype)) | ||||
|  | ||||
|  | ||||
| def get_cifar_models(config, extra_path=None): | ||||
|   super_type = getattr(config, 'super_type', 'basic') | ||||
|   if super_type == 'basic': | ||||
|     from .CifarResNet      import CifarResNet | ||||
|     from .CifarDenseNet    import DenseNet | ||||
|     from .CifarWideResNet  import CifarWideResNet | ||||
|     if config.arch == 'resnet': | ||||
|       return CifarResNet(config.module, config.depth, config.class_num, config.zero_init_residual) | ||||
|     elif config.arch == 'densenet': | ||||
|       return DenseNet(config.growthRate, config.depth, config.reduction, config.class_num, config.bottleneck) | ||||
|     elif config.arch == 'wideresnet': | ||||
|       return CifarWideResNet(config.depth, config.wide_factor, config.class_num, config.dropout) | ||||
|     super_type = getattr(config, "super_type", "basic") | ||||
|     if super_type == "basic": | ||||
|         from .CifarResNet import CifarResNet | ||||
|         from .CifarDenseNet import DenseNet | ||||
|         from .CifarWideResNet import CifarWideResNet | ||||
|  | ||||
|         if config.arch == "resnet": | ||||
|             return CifarResNet( | ||||
|                 config.module, config.depth, config.class_num, config.zero_init_residual | ||||
|             ) | ||||
|         elif config.arch == "densenet": | ||||
|             return DenseNet( | ||||
|                 config.growthRate, | ||||
|                 config.depth, | ||||
|                 config.reduction, | ||||
|                 config.class_num, | ||||
|                 config.bottleneck, | ||||
|             ) | ||||
|         elif config.arch == "wideresnet": | ||||
|             return CifarWideResNet( | ||||
|                 config.depth, config.wide_factor, config.class_num, config.dropout | ||||
|             ) | ||||
|         else: | ||||
|             raise ValueError("invalid module type : {:}".format(config.arch)) | ||||
|     elif super_type.startswith("infer"): | ||||
|         from .shape_infers import InferWidthCifarResNet | ||||
|         from .shape_infers import InferDepthCifarResNet | ||||
|         from .shape_infers import InferCifarResNet | ||||
|         from .cell_infers import NASNetonCIFAR | ||||
|  | ||||
|         assert len(super_type.split("-")) == 2, "invalid super_type : {:}".format( | ||||
|             super_type | ||||
|         ) | ||||
|         infer_mode = super_type.split("-")[1] | ||||
|         if infer_mode == "width": | ||||
|             return InferWidthCifarResNet( | ||||
|                 config.module, | ||||
|                 config.depth, | ||||
|                 config.xchannels, | ||||
|                 config.class_num, | ||||
|                 config.zero_init_residual, | ||||
|             ) | ||||
|         elif infer_mode == "depth": | ||||
|             return InferDepthCifarResNet( | ||||
|                 config.module, | ||||
|                 config.depth, | ||||
|                 config.xblocks, | ||||
|                 config.class_num, | ||||
|                 config.zero_init_residual, | ||||
|             ) | ||||
|         elif infer_mode == "shape": | ||||
|             return InferCifarResNet( | ||||
|                 config.module, | ||||
|                 config.depth, | ||||
|                 config.xblocks, | ||||
|                 config.xchannels, | ||||
|                 config.class_num, | ||||
|                 config.zero_init_residual, | ||||
|             ) | ||||
|         elif infer_mode == "nasnet.cifar": | ||||
|             genotype = config.genotype | ||||
|             if extra_path is not None:  # reload genotype by extra_path | ||||
|                 if not osp.isfile(extra_path): | ||||
|                     raise ValueError("invalid extra_path : {:}".format(extra_path)) | ||||
|                 xdata = torch.load(extra_path) | ||||
|                 current_epoch = xdata["epoch"] | ||||
|                 genotype = xdata["genotypes"][current_epoch - 1] | ||||
|             C = config.C if hasattr(config, "C") else config.ichannel | ||||
|             N = config.N if hasattr(config, "N") else config.layers | ||||
|             return NASNetonCIFAR( | ||||
|                 C, N, config.stem_multi, config.class_num, genotype, config.auxiliary | ||||
|             ) | ||||
|         else: | ||||
|             raise ValueError("invalid infer-mode : {:}".format(infer_mode)) | ||||
|     else: | ||||
|       raise ValueError('invalid module type : {:}'.format(config.arch)) | ||||
|   elif super_type.startswith('infer'): | ||||
|     from .shape_infers import InferWidthCifarResNet | ||||
|     from .shape_infers import InferDepthCifarResNet | ||||
|     from .shape_infers import InferCifarResNet | ||||
|     from .cell_infers import NASNetonCIFAR | ||||
|     assert len(super_type.split('-')) == 2, 'invalid super_type : {:}'.format(super_type) | ||||
|     infer_mode = super_type.split('-')[1] | ||||
|     if infer_mode == 'width': | ||||
|       return InferWidthCifarResNet(config.module, config.depth, config.xchannels, config.class_num, config.zero_init_residual) | ||||
|     elif infer_mode == 'depth': | ||||
|       return InferDepthCifarResNet(config.module, config.depth, config.xblocks, config.class_num, config.zero_init_residual) | ||||
|     elif infer_mode == 'shape': | ||||
|       return InferCifarResNet(config.module, config.depth, config.xblocks, config.xchannels, config.class_num, config.zero_init_residual) | ||||
|     elif infer_mode == 'nasnet.cifar': | ||||
|       genotype = config.genotype | ||||
|       if extra_path is not None:  # reload genotype by extra_path | ||||
|         if not osp.isfile(extra_path): raise ValueError('invalid extra_path : {:}'.format(extra_path)) | ||||
|         xdata = torch.load(extra_path) | ||||
|         current_epoch = xdata['epoch'] | ||||
|         genotype = xdata['genotypes'][current_epoch-1] | ||||
|       C = config.C if hasattr(config, 'C') else config.ichannel | ||||
|       N = config.N if hasattr(config, 'N') else config.layers | ||||
|       return NASNetonCIFAR(C, N, config.stem_multi, config.class_num, genotype, config.auxiliary) | ||||
|     else: | ||||
|       raise ValueError('invalid infer-mode : {:}'.format(infer_mode)) | ||||
|   else: | ||||
|     raise ValueError('invalid super-type : {:}'.format(super_type)) | ||||
|         raise ValueError("invalid super-type : {:}".format(super_type)) | ||||
|  | ||||
|  | ||||
| def get_imagenet_models(config): | ||||
|   super_type = getattr(config, 'super_type', 'basic') | ||||
|   if super_type == 'basic': | ||||
|     from .ImageNet_ResNet import ResNet | ||||
|     from .ImageNet_MobileNetV2 import MobileNetV2 | ||||
|     if config.arch == 'resnet': | ||||
|       return ResNet(config.block_name, config.layers, config.deep_stem, config.class_num, config.zero_init_residual, config.groups, config.width_per_group) | ||||
|     elif config.arch == 'mobilenet_v2': | ||||
|       return MobileNetV2(config.class_num, config.width_multi, config.input_channel, config.last_channel, 'InvertedResidual', config.dropout) | ||||
|     super_type = getattr(config, "super_type", "basic") | ||||
|     if super_type == "basic": | ||||
|         from .ImageNet_ResNet import ResNet | ||||
|         from .ImageNet_MobileNetV2 import MobileNetV2 | ||||
|  | ||||
|         if config.arch == "resnet": | ||||
|             return ResNet( | ||||
|                 config.block_name, | ||||
|                 config.layers, | ||||
|                 config.deep_stem, | ||||
|                 config.class_num, | ||||
|                 config.zero_init_residual, | ||||
|                 config.groups, | ||||
|                 config.width_per_group, | ||||
|             ) | ||||
|         elif config.arch == "mobilenet_v2": | ||||
|             return MobileNetV2( | ||||
|                 config.class_num, | ||||
|                 config.width_multi, | ||||
|                 config.input_channel, | ||||
|                 config.last_channel, | ||||
|                 "InvertedResidual", | ||||
|                 config.dropout, | ||||
|             ) | ||||
|         else: | ||||
|             raise ValueError("invalid arch : {:}".format(config.arch)) | ||||
|     elif super_type.startswith("infer"):  # NAS searched architecture | ||||
|         assert len(super_type.split("-")) == 2, "invalid super_type : {:}".format( | ||||
|             super_type | ||||
|         ) | ||||
|         infer_mode = super_type.split("-")[1] | ||||
|         if infer_mode == "shape": | ||||
|             from .shape_infers import InferImagenetResNet | ||||
|             from .shape_infers import InferMobileNetV2 | ||||
|  | ||||
|             if config.arch == "resnet": | ||||
|                 return InferImagenetResNet( | ||||
|                     config.block_name, | ||||
|                     config.layers, | ||||
|                     config.xblocks, | ||||
|                     config.xchannels, | ||||
|                     config.deep_stem, | ||||
|                     config.class_num, | ||||
|                     config.zero_init_residual, | ||||
|                 ) | ||||
|             elif config.arch == "MobileNetV2": | ||||
|                 return InferMobileNetV2( | ||||
|                     config.class_num, config.xchannels, config.xblocks, config.dropout | ||||
|                 ) | ||||
|             else: | ||||
|                 raise ValueError("invalid arch-mode : {:}".format(config.arch)) | ||||
|         else: | ||||
|             raise ValueError("invalid infer-mode : {:}".format(infer_mode)) | ||||
|     else: | ||||
|       raise ValueError('invalid arch : {:}'.format( config.arch )) | ||||
|   elif super_type.startswith('infer'): # NAS searched architecture | ||||
|     assert len(super_type.split('-')) == 2, 'invalid super_type : {:}'.format(super_type) | ||||
|     infer_mode = super_type.split('-')[1] | ||||
|     if infer_mode == 'shape': | ||||
|       from .shape_infers import InferImagenetResNet | ||||
|       from .shape_infers import InferMobileNetV2 | ||||
|       if config.arch == 'resnet': | ||||
|         return InferImagenetResNet(config.block_name, config.layers, config.xblocks, config.xchannels, config.deep_stem, config.class_num, config.zero_init_residual) | ||||
|       elif config.arch == "MobileNetV2": | ||||
|         return InferMobileNetV2(config.class_num, config.xchannels, config.xblocks, config.dropout) | ||||
|       else: | ||||
|         raise ValueError('invalid arch-mode : {:}'.format(config.arch)) | ||||
|     else: | ||||
|       raise ValueError('invalid infer-mode : {:}'.format(infer_mode)) | ||||
|   else: | ||||
|     raise ValueError('invalid super-type : {:}'.format(super_type)) | ||||
|         raise ValueError("invalid super-type : {:}".format(super_type)) | ||||
|  | ||||
|  | ||||
| # Try to obtain the network by config. | ||||
| def obtain_model(config, extra_path=None): | ||||
|   if config.dataset == 'cifar': | ||||
|     return get_cifar_models(config, extra_path) | ||||
|   elif config.dataset == 'imagenet': | ||||
|     return get_imagenet_models(config) | ||||
|   else: | ||||
|     raise ValueError('invalid dataset in the model config : {:}'.format(config)) | ||||
|     if config.dataset == "cifar": | ||||
|         return get_cifar_models(config, extra_path) | ||||
|     elif config.dataset == "imagenet": | ||||
|         return get_imagenet_models(config) | ||||
|     else: | ||||
|         raise ValueError("invalid dataset in the model config : {:}".format(config)) | ||||
|  | ||||
|  | ||||
| def obtain_search_model(config): | ||||
|   if config.dataset == 'cifar': | ||||
|     if config.arch == 'resnet': | ||||
|       from .shape_searchs import SearchWidthCifarResNet | ||||
|       from .shape_searchs import SearchDepthCifarResNet | ||||
|       from .shape_searchs import SearchShapeCifarResNet | ||||
|       if config.search_mode == 'width': | ||||
|         return SearchWidthCifarResNet(config.module, config.depth, config.class_num) | ||||
|       elif config.search_mode == 'depth': | ||||
|         return SearchDepthCifarResNet(config.module, config.depth, config.class_num) | ||||
|       elif config.search_mode == 'shape': | ||||
|         return SearchShapeCifarResNet(config.module, config.depth, config.class_num) | ||||
|       else: raise ValueError('invalid search mode : {:}'.format(config.search_mode)) | ||||
|     elif config.arch == 'simres': | ||||
|       from .shape_searchs import SearchWidthSimResNet | ||||
|       if config.search_mode == 'width': | ||||
|         return SearchWidthSimResNet(config.depth, config.class_num) | ||||
|       else: raise ValueError('invalid search mode : {:}'.format(config.search_mode)) | ||||
|     if config.dataset == "cifar": | ||||
|         if config.arch == "resnet": | ||||
|             from .shape_searchs import SearchWidthCifarResNet | ||||
|             from .shape_searchs import SearchDepthCifarResNet | ||||
|             from .shape_searchs import SearchShapeCifarResNet | ||||
|  | ||||
|             if config.search_mode == "width": | ||||
|                 return SearchWidthCifarResNet( | ||||
|                     config.module, config.depth, config.class_num | ||||
|                 ) | ||||
|             elif config.search_mode == "depth": | ||||
|                 return SearchDepthCifarResNet( | ||||
|                     config.module, config.depth, config.class_num | ||||
|                 ) | ||||
|             elif config.search_mode == "shape": | ||||
|                 return SearchShapeCifarResNet( | ||||
|                     config.module, config.depth, config.class_num | ||||
|                 ) | ||||
|             else: | ||||
|                 raise ValueError("invalid search mode : {:}".format(config.search_mode)) | ||||
|         elif config.arch == "simres": | ||||
|             from .shape_searchs import SearchWidthSimResNet | ||||
|  | ||||
|             if config.search_mode == "width": | ||||
|                 return SearchWidthSimResNet(config.depth, config.class_num) | ||||
|             else: | ||||
|                 raise ValueError("invalid search mode : {:}".format(config.search_mode)) | ||||
|         else: | ||||
|             raise ValueError( | ||||
|                 "invalid arch : {:} for dataset [{:}]".format( | ||||
|                     config.arch, config.dataset | ||||
|                 ) | ||||
|             ) | ||||
|     elif config.dataset == "imagenet": | ||||
|         from .shape_searchs import SearchShapeImagenetResNet | ||||
|  | ||||
|         assert config.search_mode == "shape", "invalid search-mode : {:}".format( | ||||
|             config.search_mode | ||||
|         ) | ||||
|         if config.arch == "resnet": | ||||
|             return SearchShapeImagenetResNet( | ||||
|                 config.block_name, config.layers, config.deep_stem, config.class_num | ||||
|             ) | ||||
|         else: | ||||
|             raise ValueError("invalid model config : {:}".format(config)) | ||||
|     else: | ||||
|       raise ValueError('invalid arch : {:} for dataset [{:}]'.format(config.arch, config.dataset)) | ||||
|   elif config.dataset == 'imagenet': | ||||
|     from .shape_searchs import SearchShapeImagenetResNet | ||||
|     assert config.search_mode == 'shape', 'invalid search-mode : {:}'.format( config.search_mode ) | ||||
|     if config.arch == 'resnet': | ||||
|       return SearchShapeImagenetResNet(config.block_name, config.layers, config.deep_stem, config.class_num) | ||||
|     else: | ||||
|       raise ValueError('invalid model config : {:}'.format(config)) | ||||
|   else: | ||||
|     raise ValueError('invalid dataset in the model config : {:}'.format(config)) | ||||
|         raise ValueError("invalid dataset in the model config : {:}".format(config)) | ||||
|  | ||||
|  | ||||
| def load_net_from_checkpoint(checkpoint): | ||||
|   assert osp.isfile(checkpoint), 'checkpoint {:} does not exist'.format(checkpoint) | ||||
|   checkpoint   = torch.load(checkpoint) | ||||
|   model_config = dict2config(checkpoint['model-config'], None) | ||||
|   model        = obtain_model(model_config) | ||||
|   model.load_state_dict(checkpoint['base-model']) | ||||
|   return model | ||||
|     assert osp.isfile(checkpoint), "checkpoint {:} does not exist".format(checkpoint) | ||||
|     checkpoint = torch.load(checkpoint) | ||||
|     model_config = dict2config(checkpoint["model-config"], None) | ||||
|     model = obtain_model(model_config) | ||||
|     model.load_state_dict(checkpoint["base-model"]) | ||||
|     return model | ||||
|   | ||||
| @@ -21,8 +21,12 @@ def get_model(config: Dict[Text, Any], **kwargs): | ||||
|         act_cls = super_name2activation[kwargs["act_cls"]] | ||||
|         norm_cls = super_name2norm[kwargs["norm_cls"]] | ||||
|         mean, std = kwargs.get("mean", None), kwargs.get("std", None) | ||||
|         hidden_dim1 = kwargs.get("hidden_dim1", 200) | ||||
|         hidden_dim2 = kwargs.get("hidden_dim2", 100) | ||||
|         if "hidden_dim" in kwargs: | ||||
|             hidden_dim1 = kwargs.get("hidden_dim") | ||||
|             hidden_dim2 = kwargs.get("hidden_dim") | ||||
|         else: | ||||
|             hidden_dim1 = kwargs.get("hidden_dim1", 200) | ||||
|             hidden_dim2 = kwargs.get("hidden_dim2", 100) | ||||
|         model = SuperSequential( | ||||
|             norm_cls(mean=mean, std=std), | ||||
|             SuperLinear(kwargs["input_dim"], hidden_dim1), | ||||
| @@ -34,4 +38,3 @@ def get_model(config: Dict[Text, Any], **kwargs): | ||||
|     else: | ||||
|         raise TypeError("Unkonwn model type: {:}".format(model_type)) | ||||
|     return model | ||||
|  | ||||
|   | ||||
| @@ -59,6 +59,9 @@ class TensorContainer: | ||||
|         for tensor in self._tensors: | ||||
|             tensor.requires_grad_(requires_grad) | ||||
|  | ||||
|     def parameters(self): | ||||
|         return self._tensors | ||||
|  | ||||
|     @property | ||||
|     def tensors(self): | ||||
|         return self._tensors | ||||
|   | ||||
		Reference in New Issue
	
	Block a user