import torch
import torch.nn as nn

from .search_cells import NAS201SearchCell as SearchCell
from .search_model import TinyNetwork as TinyNetwork


def _softmax_last_dim(x):
    """Return ``torch.softmax(x, dim=-1)``.

    Kept as a module-level function (instead of a lambda created inside
    ``__init__``) so that a model holding a reference to it can still be
    pickled, e.g. by ``torch.save(model)``.
    """
    return torch.softmax(x, dim=-1)


class TinyNetworkDarts(TinyNetwork):
    """``TinyNetwork`` variant whose cell weights are a softmax of the
    architecture parameters.

    The raw ``self._arch_parameters`` (provided by the parent class) are
    mapped through ``self.theta_map`` -- a softmax over the last dimension --
    before being handed to every search cell in ``forward``.
    """

    def __init__(self, C, N, max_nodes, num_classes, criterion, search_space,
                 args, affine=False, track_running_stats=True,
                 stem_channels=3):
        """Build the network and install the softmax ``theta_map``.

        All arguments are forwarded unchanged to ``TinyNetwork.__init__``;
        see that class for their meaning.
        """
        super(TinyNetworkDarts, self).__init__(
            C, N, max_nodes, num_classes, criterion, search_space, args,
            affine=affine,
            track_running_stats=track_running_stats,
            stem_channels=stem_channels,
        )
        # Callable mapping raw architecture parameters to mixing weights.
        # It stays a plain instance attribute so it can still be reassigned.
        self.theta_map = _softmax_last_dim

    def get_theta(self):
        """Return ``theta_map(self._arch_parameters)`` moved to the CPU.

        No ``detach()`` is applied here, so the result keeps whatever
        autograd history the architecture parameters carry.
        """
        return self.theta_map(self._arch_parameters).cpu()

    def forward(self, inputs):
        """Run ``inputs`` through stem, cells and classifier; return logits.

        Cells that are ``SearchCell`` instances receive the mapped
        architecture weights as a second argument; every other cell is
        called with the feature map alone.
        """
        weights = self.theta_map(self._arch_parameters)

        feature = self.stem(inputs)
        for cell in self.cells:
            if isinstance(cell, SearchCell):
                feature = cell(feature, weights)
            else:
                feature = cell(feature)

        out = self.lastact(feature)
        out = self.global_pooling(out)
        # Flatten everything after the first (batch) dimension.
        out = out.view(out.size(0), -1)
        logits = self.classifier(out)
        return logits