# MeCo/nasbench201/search_model.py

import torch
import torch.nn as nn
import torch.nn.functional as F
from copy import deepcopy

from .cell_operations import ResNetBasicblock
from .search_cells import NAS201SearchCell as SearchCell
from .genotypes import Structure


class TinyNetwork(nn.Module):
    def __init__(self, C, N, max_nodes, num_classes, criterion, search_space, args,
                 affine=False, track_running_stats=True, stem_channels=3):
        super(TinyNetwork, self).__init__()
        self._C = C
        self._layerN = N
        self.max_nodes = max_nodes
        self._num_classes = num_classes
        self._criterion = criterion
        self._args = args
        self._affine = affine
        self._track_running_stats = track_running_stats
        # `slim` is read in forward() but was never set anywhere in this file;
        # assume it is carried by args and default to False when absent.
        self.slim = getattr(args, 'slim', False)
        self.stem = nn.Sequential(
            nn.Conv2d(stem_channels, C, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(C))

        # Channel schedule: N normal cells, a reduction cell, N normal cells,
        # another reduction cell, then N more normal cells.
        layer_channels = [C] * N + [C * 2] + [C * 2] * N + [C * 4] + [C * 4] * N
        layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N

        C_prev, num_edge, edge2index = C, None, None
        self.cells = nn.ModuleList()
        for index, (C_curr, reduction) in enumerate(zip(layer_channels, layer_reductions)):
            if reduction:
                cell = ResNetBasicblock(C_prev, C_curr, 2)
            else:
                cell = SearchCell(C_prev, C_curr, 1, max_nodes, search_space, affine, track_running_stats)
                if num_edge is None:
                    num_edge, edge2index = cell.num_edges, cell.edge2index
                else:
                    assert num_edge == cell.num_edges and edge2index == cell.edge2index, \
                        'invalid {:} vs. {:}.'.format(num_edge, cell.num_edges)
            self.cells.append(cell)
            C_prev = cell.out_dim
        self.num_edge = num_edge
        self.num_op = len(search_space)
        self.op_names = deepcopy(search_space)
        self._Layer = len(self.cells)
        self.edge2index = edge2index
        self.lastact = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True))
        self.global_pooling = nn.AdaptiveAvgPool2d(1)
        self.classifier = nn.Linear(C_prev, num_classes)
        # Architecture parameters: one row per edge, one column per candidate op.
        # torch.autograd.Variable is deprecated; a leaf tensor with
        # requires_grad=True is its modern equivalent.
        # self._arch_parameters = nn.Parameter(1e-3 * torch.randn(num_edge, len(search_space)))
        self._arch_parameters = (1e-3 * torch.randn(num_edge, len(search_space))).cuda().requires_grad_(True)

        ## optimizer
        # Record the id() (memory address) of each architecture parameter so the
        # model weights can be separated from them below.
        arch_params = set(id(m) for m in self.arch_parameters())
        self._model_params = [m for m in self.parameters() if id(m) not in arch_params]
        # Optimizer for the model (weight) parameters only.
        self.optimizer = torch.optim.SGD(
            self._model_params,
            args.learning_rate,
            momentum=args.momentum,
            weight_decay=args.weight_decay,
            nesterov=args.nesterov)

    def entropy_y_x(self, p_logit):
        # Mean per-sample entropy of the softmax distribution over the logits.
        p = F.softmax(p_logit, dim=1)
        return -torch.sum(p * F.log_softmax(p_logit, dim=1)) / p_logit.shape[0]

    def _loss(self, input, target, return_logits=False):
        logits = self(input)
        loss = self._criterion(logits, target)
        return (loss, logits) if return_logits else loss

    def get_weights(self):
        xlist = list(self.stem.parameters()) + list(self.cells.parameters())
        xlist += list(self.lastact.parameters()) + list(self.global_pooling.parameters())
        xlist += list(self.classifier.parameters())
        return xlist

    def arch_parameters(self):
        return [self._arch_parameters]

    def get_theta(self):
        return nn.functional.softmax(self._arch_parameters, dim=-1).cpu()

    def get_message(self):
        string = self.extra_repr()
        for i, cell in enumerate(self.cells):
            string += '\n {:02d}/{:02d} :: {:}'.format(i, len(self.cells), cell.extra_repr())
        return string

    def extra_repr(self):
        return ('{name}(C={_C}, Max-Nodes={max_nodes}, N={_layerN}, L={_Layer})'.format(
            name=self.__class__.__name__, **self.__dict__))

    def genotype(self):
        genotypes = []
        for i in range(1, self.max_nodes):
            xlist = []
            for j in range(i):
                node_str = '{:}<-{:}'.format(i, j)
                with torch.no_grad():
                    weights = self._arch_parameters[self.edge2index[node_str]]
                    op_name = self.op_names[weights.argmax().item()]
                xlist.append((op_name, j))
            genotypes.append(tuple(xlist))
        return Structure(genotypes)
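
    # For reference: with max_nodes=4 the returned Structure covers six edges
    # and stringifies to the usual NAS-Bench-201 form, e.g. (illustrative only)
    #   |nor_conv_3x3~0|+|skip_connect~0|nor_conv_1x1~1|+|none~0|avg_pool_3x3~1|skip_connect~2|
    # where each |op~j| names the op chosen on the edge coming from node j.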

    def forward(self, inputs, weights=None):
        weights = nn.functional.softmax(self._arch_parameters, dim=-1) if weights is None else weights
        if self.slim:
            # Slim mode: zero the mixture weights on edges 1, 3 and 4.
            weights[1].data.fill_(0)
            weights[3].data.fill_(0)
            weights[4].data.fill_(0)
        feature = self.stem(inputs)
        for i, cell in enumerate(self.cells):
            if isinstance(cell, SearchCell):
                feature = cell(feature, weights)
            else:
                feature = cell(feature)
        out = self.lastact(feature)
        out = self.global_pooling(out)
        out = out.view(out.size(0), -1)
        logits = self.classifier(out)
        return logits

    def _save_arch_parameters(self):
        self._saved_arch_parameters = [p.clone() for p in self.arch_parameters()]

    def project_arch(self):
        # Snapshot the continuous alphas, then replace them with their
        # one-hot (argmax) projection.
        self._save_arch_parameters()
        for p in self.arch_parameters():
            maxIndexs = p.data.cpu().numpy().argmax(axis=1)
            p.data = self.proximal_step(p, maxIndexs)

    def proximal_step(self, var, maxIndexs=None):
        # Project each row of `var` onto the one-hot vector of its argmax.
        values = var.data.cpu().numpy()
        m, n = values.shape
        for i in range(m):
            for j in range(n):
                values[i][j] = 1 if j == maxIndexs[i] else 0
        return torch.Tensor(values).cuda()
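
    # Illustration of the projection above: a row of alphas such as
    # [0.2, 0.5, 0.3] (argmax index 1) becomes the one-hot row [0., 1., 0.].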

    def restore_arch_parameters(self):
        for p, saved in zip(self.arch_parameters(), self._saved_arch_parameters):
            p.data.copy_(saved.data)
        del self._saved_arch_parameters

    def new(self):
        model_new = TinyNetwork(self._C, self._layerN, self.max_nodes, self._num_classes,
                                self._criterion, self.op_names, self._args,
                                self._affine, self._track_running_stats).cuda()
        for x, y in zip(model_new.arch_parameters(), self.arch_parameters()):
            x.data.copy_(y.data)
        return model_new

    def step(self, input, target, args, shared=None, return_grad=False):
        # One SGD step on the model weights (gradients are assumed to have
        # been zeroed by the caller).
        Lt, logit_t = self._loss(input, target, return_logits=True)
        Lt.backward()
        if args.grad_clip != 0:
            nn.utils.clip_grad_norm_(self.get_weights(), args.grad_clip)
        self.optimizer.step()
        if return_grad:
            grad = torch.nn.utils.parameters_to_vector([p.grad for p in self.get_weights()])
            return logit_t, Lt, grad
        return logit_t, Lt

    def printing(self, logging):
        logging.info(self.get_theta())

    def set_arch_parameters(self, new_alphas):
        for alpha, new_alpha in zip(self.arch_parameters(), new_alphas):
            alpha.data.copy_(new_alpha.data)

    def save_arch_parameters(self):
        # Public alias of _save_arch_parameters, so a single snapshot format
        # is shared with restore_arch_parameters above.
        self._save_arch_parameters()

    def reset_optimizer(self, lr, momentum, weight_decay):
        del self.optimizer
        self.optimizer = torch.optim.SGD(
            self.get_weights(),
            lr,
            momentum=momentum,
            weight_decay=weight_decay,
            nesterov=self._args.nesterov)  # `args` was not in scope here; use the stored namespace
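

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only): it assumes a CUDA device, the
# standard NAS-Bench-201 operation set, and an args namespace carrying the
# hyper-parameter fields the class reads above; the concrete values below are
# placeholders, not the repository's defaults. Because of the relative
# imports, run it as a module, e.g. `python -m MeCo.nasbench201.search_model`.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    from argparse import Namespace

    args = Namespace(learning_rate=0.025, momentum=0.9, weight_decay=3e-4,
                     nesterov=True, grad_clip=5.0)
    search_space = ['none', 'skip_connect', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3']
    criterion = nn.CrossEntropyLoss().cuda()
    model = TinyNetwork(C=16, N=5, max_nodes=4, num_classes=10, criterion=criterion,
                        search_space=search_space, args=args).cuda()

    x = torch.randn(2, 3, 32, 32).cuda()      # CIFAR-10-sized dummy batch
    y = torch.randint(0, 10, (2,)).cuda()
    logits, loss = model.step(x, y, args)     # one weight-update step
    print(loss.item())
    print(model.genotype())                   # current argmax architecture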