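"""Differentiable architecture search network over the NAS-Bench-201 cell
search space: every edge mixes its candidate operations with
softmax-normalized architecture parameters (DARTS-style), with helpers to
project, save, and restore those parameters."""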
import torch
import torch.nn as nn
import torch.nn.functional as F
from copy import deepcopy

from .cell_operations import ResNetBasicblock
from .search_cells import NAS201SearchCell as SearchCell
from .genotypes import Structure

class TinyNetwork(nn.Module):

    def __init__(self, C, N, max_nodes, num_classes, criterion, search_space, args,
                 affine=False, track_running_stats=True, stem_channels=3):
        super(TinyNetwork, self).__init__()
        self._C = C
        self._layerN = N
        self.max_nodes = max_nodes
        self._num_classes = num_classes
        self._criterion = criterion
        self._args = args
        self._affine = affine
        self._track_running_stats = track_running_stats
        # whether to zero out a fixed subset of edges in forward(); assumed to
        # be supplied via args (defaults to False when absent)
        self.slim = getattr(args, 'slim', False)
        self.stem = nn.Sequential(
            nn.Conv2d(stem_channels, C, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(C))

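        # Macro skeleton (NAS-Bench-201 layout): N search cells, a reduction
        # ResNet block, N cells, another reduction block, then N more cells.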
        layer_channels = [C] * N + [C * 2] + [C * 2] * N + [C * 4] + [C * 4] * N
        layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N

        C_prev, num_edge, edge2index = C, None, None
        self.cells = nn.ModuleList()
        for index, (C_curr, reduction) in enumerate(zip(layer_channels, layer_reductions)):
            if reduction:
                cell = ResNetBasicblock(C_prev, C_curr, 2)
            else:
                cell = SearchCell(C_prev, C_curr, 1, max_nodes, search_space, affine, track_running_stats)
                # every search cell must expose the same edge layout
                if num_edge is None:
                    num_edge, edge2index = cell.num_edges, cell.edge2index
                else:
                    assert num_edge == cell.num_edges and edge2index == cell.edge2index, 'invalid {:} vs. {:}.'.format(num_edge, cell.num_edges)
            self.cells.append(cell)
            C_prev = cell.out_dim
        self.num_edge = num_edge
        self.num_op = len(search_space)
        self.op_names = deepcopy(search_space)
        self._Layer = len(self.cells)
        self.edge2index = edge2index
        self.lastact = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True))
        self.global_pooling = nn.AdaptiveAvgPool2d(1)
        self.classifier = nn.Linear(C_prev, num_classes)
        # Architecture parameters (alpha): one row per edge, one column per
        # candidate op.  Kept as a plain tensor rather than nn.Parameter so
        # they never show up in self.parameters().
        # self._arch_parameters = nn.Parameter(1e-3 * torch.randn(num_edge, len(search_space)))
        self._arch_parameters = (1e-3 * torch.randn(num_edge, len(search_space))).cuda().requires_grad_(True)

        ## optimizer
        # record each tensor's id() (its memory address) so the architecture
        # parameters can be told apart from the model weights
        arch_params = set(id(m) for m in self.arch_parameters())
        self._model_params = [m for m in self.parameters() if id(m) not in arch_params]

        # optimizer for the model (weight) parameters only
        self.optimizer = torch.optim.SGD(
            self._model_params,
            args.learning_rate,
            momentum=args.momentum,
            weight_decay=args.weight_decay,
            nesterov=args.nesterov)

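    # Shannon entropy of the predicted class distribution, averaged over the
    # batch: H = -(1/B) * sum_b sum_c softmax(logit)_bc * log_softmax(logit)_bc.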
    def entropy_y_x(self, p_logit):
        p = F.softmax(p_logit, dim=1)
        return -torch.sum(p * F.log_softmax(p_logit, dim=1)) / p_logit.shape[0]

    def _loss(self, input, target, return_logits=False):
        logits = self(input)
        loss = self._criterion(logits, target)
        return (loss, logits) if return_logits else loss

    def get_weights(self):
        xlist = list(self.stem.parameters()) + list(self.cells.parameters())
        xlist += list(self.lastact.parameters()) + list(self.global_pooling.parameters())
        xlist += list(self.classifier.parameters())
        return xlist

    def arch_parameters(self):
        return [self._arch_parameters]

    def get_theta(self):
        # softmaxed (num_edge, num_op) architecture weights, moved to CPU
        return F.softmax(self._arch_parameters, dim=-1).cpu()

    def get_message(self):
        string = self.extra_repr()
        for i, cell in enumerate(self.cells):
            string += '\n {:02d}/{:02d} :: {:}'.format(i, len(self.cells), cell.extra_repr())
        return string

    def extra_repr(self):
        return ('{name}(C={_C}, Max-Nodes={max_nodes}, N={_layerN}, L={_Layer})'.format(
            name=self.__class__.__name__, **self.__dict__))

    def genotype(self):
        # discretize: for every edge, keep the op with the largest alpha
        genotypes = []
        for i in range(1, self.max_nodes):
            xlist = []
            for j in range(i):
                node_str = '{:}<-{:}'.format(i, j)
                with torch.no_grad():
                    weights = self._arch_parameters[self.edge2index[node_str]]
                    op_name = self.op_names[weights.argmax().item()]
                xlist.append((op_name, j))
            genotypes.append(tuple(xlist))
        return Structure(genotypes)

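    # forward() feeds each search cell the full alpha matrix; the cell
    # (NAS201SearchCell) mixes its candidate ops with these weights in the
    # usual DARTS relaxation, f_edge(x) = sum_k w_edge,k * op_k(x).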
    def forward(self, inputs, weights=None):
        # mix candidate ops with softmaxed alphas unless explicit weights are given
        weights = F.softmax(self._arch_parameters, dim=-1) if weights is None else weights

        if self.slim:
            # zero out the mixing weights of edges 1, 3, and 4
            weights[1].data.fill_(0)
            weights[3].data.fill_(0)
            weights[4].data.fill_(0)

        feature = self.stem(inputs)
        for i, cell in enumerate(self.cells):
            if isinstance(cell, SearchCell):
                feature = cell(feature, weights)
            else:
                feature = cell(feature)

        out = self.lastact(feature)
        out = self.global_pooling(out)
        out = out.view(out.size(0), -1)
        logits = self.classifier(out)
        return logits

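    # Projection utilities: project_arch() snapshots the current alphas and
    # discretizes every edge to a one-hot vector at its argmax op;
    # restore_arch_parameters() undoes the projection from the snapshot.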
    def _save_arch_parameters(self):
        self._saved_arch_parameters = [p.clone() for p in self.arch_parameters()]

    def project_arch(self):
        self._save_arch_parameters()
        for p in self.arch_parameters():
            maxIndexs = p.data.cpu().numpy().argmax(axis=1)
            p.data = self.proximal_step(p, maxIndexs)

    def proximal_step(self, var, maxIndexs=None):
        # build the one-hot rows selected by maxIndexs
        values = var.data.cpu().numpy()
        m, n = values.shape
        for i in range(m):
            for j in range(n):
                values[i][j] = 1 if j == maxIndexs[i] else 0
        return torch.Tensor(values).cuda()

    def restore_arch_parameters(self):
        for i, p in enumerate(self.arch_parameters()):
            p.data.copy_(self._saved_arch_parameters[i])
        del self._saved_arch_parameters

    def new(self):
        # fresh network with the same hyper-parameters; copy over the alphas
        model_new = TinyNetwork(self._C, self._layerN, self.max_nodes, self._num_classes,
                                self._criterion, self.op_names, self._args,
                                self._affine, self._track_running_stats).cuda()
        for x, y in zip(model_new.arch_parameters(), self.arch_parameters()):
            x.data.copy_(y.data)
        return model_new

    def step(self, input, target, args, shared=None, return_grad=False):
        # one SGD update of the model weights on (input, target)
        self.optimizer.zero_grad()
        Lt, logit_t = self._loss(input, target, return_logits=True)
        Lt.backward()
        if args.grad_clip != 0:
            nn.utils.clip_grad_norm_(self.get_weights(), args.grad_clip)
        self.optimizer.step()

        if return_grad:
            grad = torch.nn.utils.parameters_to_vector([p.grad for p in self.get_weights()])
            return logit_t, Lt, grad
        else:
            return logit_t, Lt

    def printing(self, logging):
        logging.info(self.get_theta())

    def set_arch_parameters(self, new_alphas):
        for alpha, new_alpha in zip(self.arch_parameters(), new_alphas):
            alpha.data.copy_(new_alpha.data)

    def save_arch_parameters(self):
        # public alias for _save_arch_parameters(); restore_arch_parameters()
        # above undoes it
        self._save_arch_parameters()

    def reset_optimizer(self, lr, momentum, weight_decay):
        del self.optimizer
        self.optimizer = torch.optim.SGD(
            self.get_weights(),
            lr,
            momentum=momentum,
            weight_decay=weight_decay,
            nesterov=self._args.nesterov)
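
# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only): assumes a CUDA device, the five
# NAS-Bench-201 candidate ops, and an `args` namespace providing
# learning_rate, momentum, weight_decay, nesterov, and grad_clip.
#
#   search_space = ['none', 'skip_connect', 'nor_conv_1x1',
#                   'nor_conv_3x3', 'avg_pool_3x3']
#   net = TinyNetwork(C=16, N=5, max_nodes=4, num_classes=10,
#                     criterion=nn.CrossEntropyLoss(),
#                     search_space=search_space, args=args).cuda()
#   logits, loss = net.step(images, labels, args)   # one weight update
#   print(net.genotype())                           # current discrete cell
# ---------------------------------------------------------------------------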