34 lines
1.1 KiB
Python
34 lines
1.1 KiB
Python
import torch
|
|
import torch.nn as nn
|
|
from .search_cells import NAS201SearchCell as SearchCell
|
|
from .search_model import TinyNetwork as TinyNetwork
|
|
|
|
|
|
class TinyNetworkDarts(TinyNetwork):
    """DARTS-style search network built on ``TinyNetwork``.

    The architecture parameters (``self._arch_parameters``, presumably created
    by ``TinyNetwork.__init__`` -- confirm there) are turned into mixing
    weights with a softmax over their last dimension. Those weights are passed
    to every ``SearchCell`` in ``self.cells`` during the forward pass.
    """

    def __init__(self, C, N, max_nodes, num_classes, criterion, search_space, args,
                 affine=False, track_running_stats=True, stem_channels=3):
        """Build the network; every argument is forwarded unchanged to ``TinyNetwork``."""
        super().__init__(C, N, max_nodes, num_classes, criterion, search_space, args,
                         affine=affine, track_running_stats=track_running_stats,
                         stem_channels=stem_channels)
        # Kept as an instance attribute so callers/subclasses that read or
        # replace ``theta_map`` keep working. It points at a staticmethod rather
        # than a lambda defined here: a lambda local to ``__init__`` cannot be
        # pickled, which breaks ``torch.save(model)`` and any other path that
        # pickles the whole module.
        self.theta_map = self._softmax_last_dim

    @staticmethod
    def _softmax_last_dim(x):
        """Return ``torch.softmax(x, dim=-1)`` (the default ``theta_map``)."""
        return torch.softmax(x, dim=-1)

    def get_theta(self):
        """Return ``theta_map`` applied to the architecture parameters, moved to the CPU.

        The result is not detached. It therefore still carries autograd history
        when ``_arch_parameters`` requires grad. NOTE(review): confirm callers
        expect that, or that they call this under ``torch.no_grad()``.
        """
        return self.theta_map(self._arch_parameters).cpu()

    def forward(self, inputs):
        """Run a forward pass over ``inputs`` and return the classifier logits.

        Each cell that is a ``SearchCell`` is called with the current
        architecture weights. Any other cell is called with the feature map only.
        """
        weights = self.theta_map(self._arch_parameters)
        feature = self.stem(inputs)

        for cell in self.cells:
            if isinstance(cell, SearchCell):
                feature = cell(feature, weights)
            else:
                feature = cell(feature)

        out = self.lastact(feature)
        out = self.global_pooling(out)
        # Flatten every dimension after dim 0 (presumably the batch) so the
        # classifier receives a 2-D input.
        out = out.view(out.size(0), -1)
        return self.classifier(out)
|