75 lines
2.7 KiB
Python
75 lines
2.7 KiB
Python
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##########################################################################
# Efficient Neural Architecture Search via Parameters Sharing, ICML 2018 #
##########################################################################
|
|
import torch
|
|
import torch.nn as nn
|
|
from torch.distributions.categorical import Categorical
|
|
|
|
|
|
class Controller(nn.Module):
    """LSTM controller that samples one operation index per edge.

    Starting from a learned input vector, the controller runs an LSTM for
    ``num_edge`` steps.  At every step the LSTM output is projected to
    ``num_ops`` logits, an operation is sampled from the resulting
    categorical distribution, and the embedding of the sampled operation
    becomes the input of the next step.

    Args:
        num_edge: number of sampling steps (one decision per edge).
        num_ops: number of candidate operations to choose from at each step.
        lstm_size: input size and hidden size of the LSTM, and the
            dimensionality of the operation embeddings.
        lstm_num_layers: number of stacked LSTM layers.
        tanh_constant: scale applied to ``tanh`` of the logits, which bounds
            every logit to ``[-tanh_constant, tanh_constant]``.
        temperature: divisor applied to the raw logits before the ``tanh``.
    """

    # we refer to https://github.com/TDeVries/enas_pytorch/blob/master/models/controller.py
    def __init__(
        self,
        num_edge: int,
        num_ops: int,
        lstm_size: int = 32,
        lstm_num_layers: int = 2,
        tanh_constant: float = 2.5,
        temperature: float = 5.0,
    ):
        super().__init__()
        # assign the attributes
        self.num_edge = num_edge
        self.num_ops = num_ops
        self.lstm_size = lstm_size
        self.lstm_N = lstm_num_layers
        self.tanh_constant = tanh_constant
        self.temperature = temperature

        # create parameters
        # Learned input for the first LSTM step; allocated uninitialised here
        # and filled by the uniform initialisation below.
        self.input_vars = nn.Parameter(torch.empty(1, 1, lstm_size))
        self.w_lstm = nn.LSTM(
            input_size=self.lstm_size,
            hidden_size=self.lstm_size,
            num_layers=self.lstm_N,
        )
        self.w_embd = nn.Embedding(self.num_ops, self.lstm_size)
        self.w_pred = nn.Linear(self.lstm_size, self.num_ops)

        # Re-initialise the weights uniformly in [-0.1, 0.1).
        # NOTE: only the layer-0 LSTM weights are touched; LSTM biases, the
        # weights of any deeper LSTM layer and the bias of `w_pred` keep the
        # initialisation performed by their PyTorch constructors.
        for weight in (
            self.input_vars,
            self.w_lstm.weight_hh_l0,
            self.w_lstm.weight_ih_l0,
            self.w_embd.weight,
            self.w_pred.weight,
        ):
            nn.init.uniform_(weight, -0.1, 0.1)

    def forward(self):
        """Sample one operation per edge.

        Returns:
            A 3-tuple ``(log_prob, entropy, sampled_arch)`` where
            ``log_prob`` is the sum of the log-probabilities of the sampled
            operations, ``entropy`` is the sum of the entropies of the
            per-step distributions, and ``sampled_arch`` is a list with
            ``num_edge`` Python ints, the sampled operation indices.
        """
        # Passing `None` as the state lets nn.LSTM start from its default
        # initial hidden state on the first step.
        inputs, hidden = self.input_vars, None
        log_probs, entropies, sampled_arch = [], [], []
        for _ in range(self.num_edge):
            outputs, hidden = self.w_lstm(inputs, hidden)

            # Temper the logits, then squash them into
            # [-tanh_constant, tanh_constant].
            logits = self.w_pred(outputs) / self.temperature
            logits = self.tanh_constant * torch.tanh(logits)

            # distribution
            op_distribution = Categorical(logits=logits)
            op_index = op_distribution.sample()
            # `.item()` requires a single-element tensor, i.e. exactly one
            # decision is sampled per step.
            sampled_arch.append(op_index.item())

            log_probs.append(op_distribution.log_prob(op_index).view(-1))
            entropies.append(op_distribution.entropy().view(-1))

            # obtain the input embedding for the next step
            inputs = self.w_embd(op_index)

        return (
            torch.sum(torch.cat(log_probs)),
            torch.sum(torch.cat(entropies)),
            sampled_arch,
        )
|