Prototype generic nas model (cont.) for ENAS.
This commit is contained in:
		| @@ -5,11 +5,75 @@ import torch, random | ||||
| import torch.nn as nn | ||||
| from copy import deepcopy | ||||
| from typing import Text | ||||
| from torch.distributions.categorical import Categorical | ||||
|  | ||||
| from ..cell_operations import ResNetBasicblock, drop_path | ||||
| from .search_cells     import NAS201SearchCell as SearchCell | ||||
| from .genotypes        import Structure | ||||
| from .search_model_enas_utils import Controller | ||||
|  | ||||
|  | ||||
class Controller(nn.Module):
  """LSTM-based controller that samples one operation per cell edge.

  The design follows
  https://github.com/TDeVries/enas_pytorch/blob/master/models/controller.py

  Args:
    edge2index: mapping from an edge name of the form ``'i<-j'`` to the
      position of that edge inside a sampled architecture list.
    op_names: sequence of candidate operation names; a sampled index is
      looked up in this sequence.
    max_nodes: number of nodes in the cell; edges are enumerated for every
      pair ``j < i < max_nodes``.
    lstm_size: input and hidden size of the LSTM.
    lstm_num_layers: number of stacked LSTM layers.
    tanh_constant: multiplier applied to the tanh-squashed logits.
    temperature: divisor applied to the raw logits before the tanh.
  """

  def __init__(self, edge2index, op_names, max_nodes, lstm_size=32, lstm_num_layers=2, tanh_constant=2.5, temperature=5.0):
    super().__init__()
    # plain attributes
    self.max_nodes     = max_nodes
    self.edge2index    = edge2index
    self.num_edge      = len(edge2index)
    self.op_names      = op_names
    self.num_ops       = len(op_names)
    self.lstm_size     = lstm_size
    self.lstm_N        = lstm_num_layers
    self.tanh_constant = tanh_constant
    self.temperature   = temperature

    # learnable pieces: the first LSTM input, the LSTM itself, the embedding
    # of the previously sampled operation, and the per-step classifier
    self.register_parameter('input_vars', nn.Parameter(torch.Tensor(1, 1, lstm_size)))
    self.w_lstm = nn.LSTM(input_size=self.lstm_size, hidden_size=self.lstm_size, num_layers=self.lstm_N)
    self.w_embd = nn.Embedding(self.num_ops, self.lstm_size)
    self.w_pred = nn.Linear(self.lstm_size, self.num_ops)

    # NOTE(review): only the layer-0 LSTM weights are re-initialised below;
    # with lstm_num_layers > 1 the deeper layers (and all biases) keep the
    # PyTorch default initialisation -- confirm this is intended.
    uniform_targets = (
      self.input_vars,
      self.w_lstm.weight_hh_l0,
      self.w_lstm.weight_ih_l0,
      self.w_embd.weight,
      self.w_pred.weight,
    )
    for tensor in uniform_targets:
      nn.init.uniform_(tensor, -0.1, 0.1)

  def convert_structure(self, _arch):
    """Turn a flat list of sampled operation indexes into a ``Structure``.

    ``_arch[self.edge2index['i<-j']]`` is read as the operation index of the
    edge from node ``j`` to node ``i``.
    """
    def incoming_edges(target):
      # one (operation name, source node) pair per predecessor of `target`
      pairs = []
      for source in range(target):
        edge_key = '{:}<-{:}'.format(target, source)
        chosen   = _arch[self.edge2index[edge_key]]
        pairs.append((self.op_names[chosen], source))
      return tuple(pairs)

    return Structure([incoming_edges(target) for target in range(1, self.max_nodes)])

  def forward(self):
    """Sample one operation for every edge.

    Returns:
      A tuple ``(log_prob, entropy, structure)``: the summed log-probability
      of the sampled operations, the summed entropy of the per-edge
      distributions, and the result of ``convert_structure`` applied to the
      sampled operation indexes.
    """
    step_input, hidden = self.input_vars, None
    all_log_probs, all_entropies, choices = [], [], []
    for _ in range(self.num_edge):
      lstm_out, hidden = self.w_lstm(step_input, hidden)

      # temperature-scaled, tanh-bounded logits over the candidate operations
      scaled = self.w_pred(lstm_out) / self.temperature
      policy = Categorical(logits=self.tanh_constant * torch.tanh(scaled))

      picked = policy.sample()
      choices.append(picked.item())
      all_log_probs.append(policy.log_prob(picked).view(-1))
      all_entropies.append(policy.entropy().view(-1))

      # the embedding of the sampled operation is the input of the next step
      step_input = self.w_embd(picked)

    total_log_prob = torch.sum(torch.cat(all_log_probs))
    total_entropy  = torch.sum(torch.cat(all_entropies))
    return total_log_prob, total_entropy, self.convert_structure(choices)
|  | ||||
|  | ||||
|  | ||||
| class GenericNAS201Model(nn.Module): | ||||
| @@ -55,7 +119,7 @@ class GenericNAS201Model(nn.Module): | ||||
|     assert self._algo is None, 'This functioin can only be called once.' | ||||
|     self._algo = algo | ||||
|     if algo == 'enas': | ||||
|       self.controller = Controller(len(self.edge2index), len(self._op_names)) | ||||
|       self.controller = Controller(self.edge2index, self._op_names, self._max_nodes) | ||||
|     else: | ||||
|       self.arch_parameters = nn.Parameter( 1e-3*torch.randn(self._num_edge, len(self._op_names)) ) | ||||
|       if algo == 'gdas': | ||||
| @@ -116,10 +180,9 @@ class GenericNAS201Model(nn.Module): | ||||
|   def show_alphas(self): | ||||
|     with torch.no_grad(): | ||||
|       if self._algo == 'enas': | ||||
|         import pdb; pdb.set_trace() | ||||
|         print('-') | ||||
|         return 'w_pred :\n{:}'.format(self.controller.w_pred.weight) | ||||
|       else: | ||||
|         return 'arch-parameters :\n{:}'.format( nn.functional.softmax(self.arch_parameters, dim=-1).cpu() ) | ||||
|         return 'arch-parameters :\n{:}'.format(nn.functional.softmax(self.arch_parameters, dim=-1).cpu()) | ||||
|            | ||||
|  | ||||
|   def extra_repr(self): | ||||
|   | ||||
		Reference in New Issue
	
	Block a user