Update TuNAS
This commit is contained in:
		| @@ -47,10 +47,10 @@ class GenericNAS301Model(nn.Module): | ||||
|   def set_algo(self, algo: Text): | ||||
|     # used for searching | ||||
|     assert self._algo is None, 'This functioin can only be called once.' | ||||
|     assert algo in ['fbv2', 'enas', 'tas'], 'invalid algo : {:}'.format(algo) | ||||
|     assert algo in ['fbv2', 'tunas', 'tas'], 'invalid algo : {:}'.format(algo) | ||||
|     self._algo = algo | ||||
|     self._arch_parameters = nn.Parameter(1e-3*torch.randn(self._max_num_Cs, len(self._candidate_Cs))) | ||||
|     if algo == 'fbv2' or algo == 'enas': | ||||
|     if algo == 'fbv2' or algo == 'tunas': | ||||
|       self.register_buffer('_masks', torch.zeros(len(self._candidate_Cs), max(self._candidate_Cs))) | ||||
|       for i in range(len(self._candidate_Cs)): | ||||
|         self._masks.data[i, :self._candidate_Cs[i]] = 1 | ||||
| @@ -106,15 +106,17 @@ class GenericNAS301Model(nn.Module): | ||||
|  | ||||
|   def forward(self, inputs): | ||||
|     feature = inputs | ||||
|  | ||||
|     log_probs = [] | ||||
|     for i, cell in enumerate(self._cells): | ||||
|       feature = cell(feature) | ||||
|       # apply different searching algorithms | ||||
|       idx = max(0, i-1) | ||||
|       if self._algo == 'fbv2': | ||||
|         idx = max(0, i-1) | ||||
|         weights = nn.functional.gumbel_softmax(self._arch_parameters[idx:idx+1], tau=self.tau, dim=-1) | ||||
|         mask = torch.matmul(weights, self._masks).view(1, -1, 1, 1) | ||||
|         feature = feature * mask | ||||
|       elif self._algo == 'tas': | ||||
|         idx = max(0, i-1) | ||||
|         selected_cs, selected_probs = select2withP(self._arch_parameters[idx:idx+1], self.tau, num=2) | ||||
|         with torch.no_grad(): | ||||
|           i1, i2 = selected_cs.cpu().view(-1).tolist() | ||||
| @@ -128,6 +130,13 @@ class GenericNAS301Model(nn.Module): | ||||
|         else: | ||||
|           miss = torch.zeros(feature.shape[0], feature.shape[1]-out.shape[1], feature.shape[2], feature.shape[3], device=feature.device) | ||||
|           feature = torch.cat((out, miss), dim=1) | ||||
|       elif self._algo == 'tunas': | ||||
|         prob = nn.functional.softmax(self._arch_parameters[idx:idx+1], dim=-1) | ||||
|         dist = torch.distributions.Categorical(prob) | ||||
|         action = dist.sample() | ||||
|         log_probs.append(dist.log_prob(action)) | ||||
|         mask = self._masks[action.item()].view(1, -1, 1, 1) | ||||
|         feature = feature * mask | ||||
|       else: | ||||
|         raise ValueError('invalid algorithm : {:}'.format(self._algo)) | ||||
|  | ||||
| @@ -136,4 +145,4 @@ class GenericNAS301Model(nn.Module): | ||||
|     out = out.view(out.size(0), -1) | ||||
|     logits = self.classifier(out) | ||||
|  | ||||
|     return out, logits | ||||
|     return out, logits, log_probs | ||||
|   | ||||
		Reference in New Issue
	
	Block a user