MeCo/nasbench201/search_model.py
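# -----------------------------------------------------------------------------
# TinyNetwork: a DARTS-style one-shot supernet over the NAS-Bench-201 cell
# search space. Architecture choices live in a single matrix
# `_arch_parameters` of shape [num_edge, num_op]; a softmax over its rows
# weights the candidate operations on every cell edge.
# -----------------------------------------------------------------------------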
import torch
import torch.nn as nn
import torch.nn.functional as F
from copy import deepcopy
from .cell_operations import ResNetBasicblock
from .search_cells import NAS201SearchCell as SearchCell
from .genotypes import Structure
from torch.autograd import Variable


class TinyNetwork(nn.Module):
    def __init__(self, C, N, max_nodes, num_classes, criterion, search_space, args,
                 affine=False, track_running_stats=True, stem_channels=3):
        super(TinyNetwork, self).__init__()
        self._C = C
        self._layerN = N
        self.max_nodes = max_nodes
        self._num_classes = num_classes
        self._criterion = criterion
        self._args = args
        self._affine = affine
        self._track_running_stats = track_running_stats
        # Flag consulted in forward(); assumed to default to False when `args`
        # does not define it (callers may also set it directly).
        self.slim = getattr(args, 'slim', False)
        self.stem = nn.Sequential(
            nn.Conv2d(stem_channels, C, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(C))
        # Macro skeleton: N cells at width C, a reduction block, N cells at 2C,
        # another reduction block, then N cells at 4C.
        layer_channels = [C] * N + [C * 2] + [C * 2] * N + [C * 4] + [C * 4] * N
        layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N
        C_prev, num_edge, edge2index = C, None, None
        self.cells = nn.ModuleList()
        for index, (C_curr, reduction) in enumerate(zip(layer_channels, layer_reductions)):
            if reduction:
                cell = ResNetBasicblock(C_prev, C_curr, 2)
            else:
                cell = SearchCell(C_prev, C_curr, 1, max_nodes, search_space, affine, track_running_stats)
                if num_edge is None:
                    num_edge, edge2index = cell.num_edges, cell.edge2index
                else:
                    assert num_edge == cell.num_edges and edge2index == cell.edge2index, \
                        'invalid {:} vs. {:}.'.format(num_edge, cell.num_edges)
            self.cells.append(cell)
            C_prev = cell.out_dim
        self.num_edge = num_edge
        self.num_op = len(search_space)
        self.op_names = deepcopy(search_space)
        self._Layer = len(self.cells)
        self.edge2index = edge2index
        self.lastact = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True))
        self.global_pooling = nn.AdaptiveAvgPool2d(1)
        self.classifier = nn.Linear(C_prev, num_classes)
        # Architecture parameters: one row of op logits per cell edge.
        # (`Variable` is deprecated; this is simply a CUDA tensor with
        # requires_grad=True, equivalent to the commented nn.Parameter form.)
        # self._arch_parameters = nn.Parameter(1e-3 * torch.randn(num_edge, len(search_space)))
        self._arch_parameters = Variable(1e-3 * torch.randn(num_edge, len(search_space)).cuda(), requires_grad=True)
        # Record tensor ids (memory addresses) to separate the architecture
        # parameters from the model weights.
        arch_params = set(id(m) for m in self.arch_parameters())
        self._model_params = [m for m in self.parameters() if id(m) not in arch_params]
        # Optimizer for the model weights only.
        self.optimizer = torch.optim.SGD(
            self._model_params,
            args.learning_rate,
            momentum=args.momentum,
            weight_decay=args.weight_decay,
            nesterov=args.nesterov)

    def entropy_y_x(self, p_logit):
        # Mean Shannon entropy of the predicted class distribution:
        # H(p) = -sum_c p_c * log p_c, averaged over the batch.
        p = F.softmax(p_logit, dim=1)
        return -torch.sum(p * F.log_softmax(p_logit, dim=1)) / p_logit.shape[0]

    def _loss(self, input, target, return_logits=False):
        logits = self(input)
        loss = self._criterion(logits, target)
        return (loss, logits) if return_logits else loss

    def get_weights(self):
        xlist = list(self.stem.parameters()) + list(self.cells.parameters())
        xlist += list(self.lastact.parameters()) + list(self.global_pooling.parameters())
        xlist += list(self.classifier.parameters())
        return xlist

    def arch_parameters(self):
        return [self._arch_parameters]

    def get_theta(self):
        return nn.functional.softmax(self._arch_parameters, dim=-1).cpu()

    def get_message(self):
        string = self.extra_repr()
        for i, cell in enumerate(self.cells):
            string += '\n {:02d}/{:02d} :: {:}'.format(i, len(self.cells), cell.extra_repr())
        return string

    def extra_repr(self):
        return ('{name}(C={_C}, Max-Nodes={max_nodes}, N={_layerN}, L={_Layer})'.format(
            name=self.__class__.__name__, **self.__dict__))
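    # For example, with C=16, N=5, max_nodes=4 the skeleton built in __init__
    # has 3*N + 2 = 17 cells, so extra_repr() yields
    # 'TinyNetwork(C=16, Max-Nodes=4, N=5, L=17)'.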

    def genotype(self):
        # Discretize: for every edge j -> i, pick the operation with the
        # largest architecture weight.
        genotypes = []
        for i in range(1, self.max_nodes):
            xlist = []
            for j in range(i):
                node_str = '{:}<-{:}'.format(i, j)
                with torch.no_grad():
                    weights = self._arch_parameters[self.edge2index[node_str]]
                    op_name = self.op_names[weights.argmax().item()]
                xlist.append((op_name, j))
            genotypes.append(tuple(xlist))
        return Structure(genotypes)
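    # Illustrative example only: with max_nodes=4, Structure(genotypes) wraps
    # three node specs such as
    #   ((op_a, 0),), ((op_b, 0), (op_c, 1)), ((op_d, 0), (op_e, 1), (op_f, 2))
    # where each (op_name, j) entry is the argmax operation on edge j -> i and
    # the op names come from the `search_space` list given to __init__.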

    def forward(self, inputs, weights=None):
        # Mixture weights: softmax over the op logits unless explicit weights
        # (e.g., a one-hot projection) are supplied.
        weights = nn.functional.softmax(self._arch_parameters, dim=-1) if weights is None else weights
        if self.slim:
            # Zero the mixture weights on edges 1, 3 and 4.
            weights[1].data.fill_(0)
            weights[3].data.fill_(0)
            weights[4].data.fill_(0)
        feature = self.stem(inputs)
        for i, cell in enumerate(self.cells):
            if isinstance(cell, SearchCell):
                feature = cell(feature, weights)
            else:
                feature = cell(feature)
        out = self.lastact(feature)
        out = self.global_pooling(out)
        out = out.view(out.size(0), -1)
        logits = self.classifier(out)
        return logits
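    # Usage sketch (hypothetical call pattern, not from the repo): evaluate the
    # discretized architecture and then undo the projection --
    #   model.project_arch()             # snap alphas to one-hot rows
    #   logits = model(images)           # forward with the discrete choice
    #   model.restore_arch_parameters()  # bring the continuous alphas back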

    def _save_arch_parameters(self):
        self._saved_arch_parameters = [p.clone() for p in self.arch_parameters()]

    def project_arch(self):
        # Replace each row of the architecture matrix with the one-hot vector
        # of its argmax; the continuous values are saved first so they can be
        # restored later.
        self._save_arch_parameters()
        for p in self.arch_parameters():
            maxIndexs = p.data.cpu().numpy().argmax(axis=1)
            p.data = self.proximal_step(p, maxIndexs)

    def proximal_step(self, var, maxIndexs=None):
        # Build the one-hot projection of each row at its argmax position.
        values = var.data.cpu().numpy()
        m, n = values.shape
        for i in range(m):
            for j in range(n):
                values[i][j] = 1 if j == maxIndexs[i] else 0
        return torch.Tensor(values).cuda()

    def restore_arch_parameters(self):
        for i, p in enumerate(self.arch_parameters()):
            p.data.copy_(self._saved_arch_parameters[i])
        del self._saved_arch_parameters

    def new(self):
        model_new = TinyNetwork(self._C, self._layerN, self.max_nodes, self._num_classes,
                                self._criterion, self.op_names, self._args, self._affine,
                                self._track_running_stats).cuda()
        for x, y in zip(model_new.arch_parameters(), self.arch_parameters()):
            x.data.copy_(y.data)
        return model_new

    def step(self, input, target, args, shared=None, return_grad=False):
        # One weight-update step on the supernet. Stale gradients are cleared
        # first (assumption: no caller relies on gradient accumulation here).
        self.optimizer.zero_grad()
        Lt, logit_t = self._loss(input, target, return_logits=True)
        Lt.backward()
        if args.grad_clip != 0:
            nn.utils.clip_grad_norm_(self.get_weights(), args.grad_clip)
        self.optimizer.step()
        if return_grad:
            grad = torch.nn.utils.parameters_to_vector([p.grad for p in self.get_weights()])
            return logit_t, Lt, grad
        else:
            return logit_t, Lt

    def printing(self, logging):
        logging.info(self.get_theta())

    def set_arch_parameters(self, new_alphas):
        for alpha, new_alpha in zip(self.arch_parameters(), new_alphas):
            alpha.data.copy_(new_alpha.data)

    def save_arch_parameters(self):
        # Public alias of _save_arch_parameters(); restore with
        # restore_arch_parameters().
        self._save_arch_parameters()

    def reset_optimizer(self, lr, momentum, weight_decay):
        del self.optimizer
        self.optimizer = torch.optim.SGD(
            self.get_weights(),
            lr,
            momentum=momentum,
            weight_decay=weight_decay,
            nesterov=self._args.nesterov)
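

# Minimal smoke-test sketch, assuming a CUDA device is available and that the
# operation names below match the NAS-Bench-201 search space this repo targets;
# the hyper-parameter values are illustrative placeholders, not the repo's
# settings.
if __name__ == '__main__':
    from argparse import Namespace

    args = Namespace(learning_rate=0.025, momentum=0.9, weight_decay=3e-4,
                     nesterov=False, grad_clip=5.0, slim=False)
    search_space = ['none', 'skip_connect', 'nor_conv_1x1',
                    'nor_conv_3x3', 'avg_pool_3x3']
    net = TinyNetwork(C=16, N=5, max_nodes=4, num_classes=10,
                      criterion=nn.CrossEntropyLoss().cuda(),
                      search_space=search_space, args=args).cuda()
    x = torch.randn(2, 3, 32, 32).cuda()     # CIFAR-sized dummy batch
    y = torch.randint(0, 10, (2,)).cuda()
    logits, loss = net.step(x, y, args)      # one supernet weight update
    print(net.genotype())                    # current discretized architecture
    print('loss:', loss.item())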