naswot/models/cell_searchs/search_model_enas_utils.py
2020-06-03 12:59:01 +01:00

56 lines
2.4 KiB
Python

##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##########################################################################
# Efficient Neural Architecture Search via Parameters Sharing, ICML 2018 #
##########################################################################
import torch
import torch.nn as nn
from torch.distributions.categorical import Categorical
class Controller(nn.Module):
    """LSTM controller that samples one operation index per edge.

    At every step the LSTM output is projected to ``num_ops`` logits, the
    logits are divided by ``temperature`` and squashed with
    ``tanh_constant * tanh(.)``, and one operation is drawn from the resulting
    categorical distribution.  The embedding of the drawn operation is the
    LSTM input of the next step.

    we refer to https://github.com/TDeVries/enas_pytorch/blob/master/models/controller.py

    Args:
        num_edge: number of sampling steps (one operation is drawn per edge).
        num_ops: number of candidate operations to choose from at each step.
        lstm_size: size of the LSTM input, the LSTM hidden state and the
            operation embedding.
        lstm_num_layers: number of stacked LSTM layers.
        tanh_constant: factor multiplied onto the tanh-squashed logits.
        temperature: divisor applied to the raw logits before the tanh.
    """

    def __init__(self, num_edge, num_ops, lstm_size=32, lstm_num_layers=2, tanh_constant=2.5, temperature=5.0):
        super(Controller, self).__init__()
        # assign the attributes
        self.num_edge = num_edge
        self.num_ops = num_ops
        self.lstm_size = lstm_size
        self.lstm_N = lstm_num_layers
        self.tanh_constant = tanh_constant
        self.temperature = temperature

        # create parameters
        # Learnable input of the very first LSTM step, laid out as
        # (seq_len=1, batch=1, lstm_size).  It is allocated uninitialised and
        # filled by the uniform init below.
        self.input_vars = nn.Parameter(torch.empty(1, 1, lstm_size))
        self.w_lstm = nn.LSTM(input_size=self.lstm_size, hidden_size=self.lstm_size, num_layers=self.lstm_N)
        self.w_embd = nn.Embedding(self.num_ops, self.lstm_size)
        self.w_pred = nn.Linear(self.lstm_size, self.num_ops)

        # NOTE(review): only the layer-0 LSTM weights are re-initialised here;
        # the weights of deeper LSTM layers and all biases keep the default
        # initialisation of their modules.  Left unchanged on purpose so that
        # seeded runs stay reproducible.
        # The order of this tuple fixes the order in which random numbers are
        # consumed, so do not reorder it.
        uniformly_initialised = (
            self.input_vars,
            self.w_lstm.weight_hh_l0,
            self.w_lstm.weight_ih_l0,
            self.w_embd.weight,
            self.w_pred.weight,
        )
        for weight in uniformly_initialised:
            nn.init.uniform_(weight, -0.1, 0.1)

    def forward(self):
        """Sample one architecture.

        Returns:
            A tuple ``(log_prob, entropy, sampled_arch)`` where ``log_prob`` is
            the scalar sum of the log-probabilities of the sampled operations,
            ``entropy`` is the scalar sum of the entropies of the per-step
            distributions, and ``sampled_arch`` is the list of sampled
            operation indices (Python ints), one per edge.  When ``num_edge``
            is not positive, both scalars are zero and the list is empty.
        """
        inputs, hidden = self.input_vars, None
        log_probs, entropys, sampled_arch = [], [], []
        for _ in range(self.num_edge):
            outputs, hidden = self.w_lstm(inputs, hidden)
            # Temperature scaling followed by a tanh squash keeps every logit
            # inside [-tanh_constant, tanh_constant].
            logits = self.w_pred(outputs) / self.temperature
            logits = self.tanh_constant * torch.tanh(logits)
            # distribution
            op_distribution = Categorical(logits=logits)
            op_index = op_distribution.sample()
            sampled_arch.append(op_index.item())
            log_probs.append(op_distribution.log_prob(op_index).view(-1))
            entropys.append(op_distribution.entropy().view(-1))
            # obtain the input embedding for the next step
            inputs = self.w_embd(op_index)

        if not sampled_arch:
            # Nothing was sampled: torch.cat would raise on an empty list.
            # An empty sum is zero; deriving it from a parameter keeps the
            # result attached to the autograd graph so backward() still works.
            zero = self.input_vars.sum() * 0.0
            return zero, zero.clone(), sampled_arch
        return torch.sum(torch.cat(log_probs)), torch.sum(torch.cat(entropys)), sampled_arch