updates for beta
This commit is contained in:
		| @@ -83,7 +83,8 @@ class SearchCell(nn.Module): | ||||
|       for j in range(i): | ||||
|         node_str = '{:}<-{:}'.format(i, j) | ||||
|         weights  = weightss[ self.edge2index[node_str] ] | ||||
|         aggregation = sum( layer(nodes[j]) * w for layer, w in zip(self.edges[node_str], weights) ) / weights.numel() | ||||
|         #aggregation = sum( layer(nodes[j]) * w for layer, w in zip(self.edges[node_str], weights) ) / weights.numel() | ||||
|         aggregation = sum( layer(nodes[j]) * w for layer, w in zip(self.edges[node_str], weights) ) | ||||
|         inter_nodes.append( aggregation ) | ||||
|       nodes.append( sum(inter_nodes) ) | ||||
|     return nodes[-1] | ||||
|   | ||||
| @@ -3,7 +3,7 @@ | ||||
| ###################################################################################### | ||||
| # One-Shot Neural Architecture Search via Self-Evaluated Template Network, ICCV 2019 # | ||||
| ###################################################################################### | ||||
| import torch | ||||
| import torch, random | ||||
| import torch.nn as nn | ||||
| from copy import deepcopy | ||||
| from ..cell_operations import ResNetBasicblock | ||||
| @@ -87,7 +87,7 @@ class TinyNetworkSETN(nn.Module): | ||||
|     return Structure( genotypes ) | ||||
|  | ||||
|  | ||||
|   def dync_genotype(self): | ||||
|   def dync_genotype(self, use_random=False): | ||||
|     genotypes = [] | ||||
|     with torch.no_grad(): | ||||
|       alphas_cpu = nn.functional.softmax(self.arch_parameters, dim=-1) | ||||
| @@ -95,9 +95,12 @@ class TinyNetworkSETN(nn.Module): | ||||
|       xlist = [] | ||||
|       for j in range(i): | ||||
|         node_str = '{:}<-{:}'.format(i, j) | ||||
|         weights  = alphas_cpu[ self.edge2index[node_str] ] | ||||
|         op_index = torch.multinomial(weights, 1).item() | ||||
|         op_name  = self.op_names[ op_index ] | ||||
|         if use_random: | ||||
|           op_name  = random.choice(self.op_names) | ||||
|         else: | ||||
|           weights  = alphas_cpu[ self.edge2index[node_str] ] | ||||
|           op_index = torch.multinomial(weights, 1).item() | ||||
|           op_name  = self.op_names[ op_index ] | ||||
|         xlist.append((op_name, j)) | ||||
|       genotypes.append( tuple(xlist) ) | ||||
|     return Structure( genotypes ) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user