56 lines
		
	
	
		
			2.4 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			56 lines
		
	
	
		
			2.4 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| ##################################################
 | |
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
 | |
| ##########################################################################
 | |
| # Efficient Neural Architecture Search via Parameters Sharing, ICML 2018 #
 | |
| ##########################################################################
 | |
| import torch
 | |
| import torch.nn as nn
 | |
| from torch.distributions.categorical import Categorical
 | |
| 
 | |
class Controller(nn.Module):
  """LSTM controller that samples one candidate operation per edge (ENAS-style).

  We refer to https://github.com/TDeVries/enas_pytorch/blob/master/models/controller.py

  Args:
    num_edge: number of edges; ``forward`` samples one operation per edge.
    num_ops: number of candidate operations on each edge.
    lstm_size: LSTM hidden size; also the size of the op embeddings fed back
      into the LSTM.
    lstm_num_layers: number of stacked LSTM layers (stored as ``self.lstm_N``).
    tanh_constant: logits are squashed into ``[-tanh_constant, tanh_constant]``.
    temperature: logits are divided by this before the tanh squashing.
  """

  def __init__(self, num_edge, num_ops, lstm_size=32, lstm_num_layers=2, tanh_constant=2.5, temperature=5.0):
    super(Controller, self).__init__()
    # assign the attributes
    self.num_edge  = num_edge
    self.num_ops   = num_ops
    self.lstm_size = lstm_size
    self.lstm_N    = lstm_num_layers
    self.tanh_constant = tanh_constant
    self.temperature   = temperature
    # create parameters
    # Learned input for the first LSTM step, laid out as (seq_len=1, batch=1, lstm_size)
    # because nn.LSTM is built with the default batch_first=False.
    self.register_parameter('input_vars', nn.Parameter(torch.empty(1, 1, lstm_size)))
    self.w_lstm = nn.LSTM(input_size=self.lstm_size, hidden_size=self.lstm_size, num_layers=self.lstm_N)
    self.w_embd = nn.Embedding(self.num_ops, self.lstm_size)
    self.w_pred = nn.Linear(self.lstm_size, self.num_ops)

    nn.init.uniform_(self.input_vars, -0.1, 0.1)
    # Initialize the input-to-hidden and hidden-to-hidden weights of EVERY LSTM layer.
    # Previously only layer 0 (weight_*_l0) was touched, so with lstm_num_layers > 1
    # the deeper layers silently kept PyTorch's default init. Biases keep the default
    # init, as they did before.
    for name, param in self.w_lstm.named_parameters():
      if name.startswith('weight_'):
        nn.init.uniform_(param, -0.1, 0.1)
    nn.init.uniform_(self.w_embd.weight, -0.1, 0.1)
    nn.init.uniform_(self.w_pred.weight, -0.1, 0.1)

  def forward(self):
    """Sample one architecture: one operation index per edge.

    At each step the LSTM output is projected to ``num_ops`` logits. The logits are
    divided by ``temperature``, squashed with ``tanh_constant * tanh(.)``, and an op
    is sampled from the resulting categorical distribution. The embedding of the
    sampled op becomes the LSTM input for the next edge.

    Returns:
      A tuple ``(log_prob, entropy, sampled_arch)``:
        log_prob: 0-dim tensor, sum over edges of the log-probability of each
          sampled op (differentiable w.r.t. the controller parameters).
        entropy: 0-dim tensor, sum over edges of the per-edge distribution entropy.
        sampled_arch: list of ``num_edge`` Python ints, each in ``[0, num_ops)``.
    """
    inputs, hidden = self.input_vars, None
    log_probs, entropys, sampled_arch = [], [], []
    for _ in range(self.num_edge):
      outputs, hidden = self.w_lstm(inputs, hidden)

      logits = self.w_pred(outputs) / self.temperature
      logits = self.tanh_constant * torch.tanh(logits)
      # distribution over the candidate ops of this edge
      op_distribution = Categorical(logits=logits)
      op_index = op_distribution.sample()
      sampled_arch.append(op_index.item())

      log_probs.append(op_distribution.log_prob(op_index).view(-1))
      entropys.append(op_distribution.entropy().view(-1))

      # obtain the input embedding for the next step
      inputs = self.w_embd(op_index)
    return torch.sum(torch.cat(log_probs)), torch.sum(torch.cat(entropys)), sampled_arch