Prototype generic NAS model (cont.) for GDAS.
This commit is contained in:
		| @@ -102,17 +102,18 @@ class GenericNAS201Model(nn.Module): | ||||
|     self._op_names   = deepcopy(search_space) | ||||
|     self._Layer      = len(self._cells) | ||||
|     self.edge2index  = edge2index | ||||
|     self.lastact     = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True)) | ||||
|     self.lastact     = nn.Sequential(nn.BatchNorm2d(C_prev, affine=affine, track_running_stats=track_running_stats), nn.ReLU(inplace=True)) | ||||
|     self.global_pooling = nn.AdaptiveAvgPool2d(1) | ||||
|     self.classifier  = nn.Linear(C_prev, num_classes) | ||||
|     self._num_edge   = num_edge | ||||
|     # algorithm related | ||||
|     self.arch_parameters = nn.Parameter( 1e-3*torch.randn(num_edge, len(search_space)) ) | ||||
|     self.arch_parameters = nn.Parameter(1e-3*torch.randn(num_edge, len(search_space))) | ||||
|     self._mode        = None | ||||
|     self.dynamic_cell = None | ||||
|     self._tau         = None | ||||
|     self._algo        = None | ||||
|     self._drop_path   = None | ||||
|     self.verbose      = False | ||||
|  | ||||
|   def set_algo(self, algo: Text): | ||||
|     # used for searching | ||||
| @@ -256,33 +257,45 @@ class GenericNAS201Model(nn.Module): | ||||
|         else: break | ||||
|       with torch.no_grad(): | ||||
|         hardwts_cpu = hardwts.detach().cpu() | ||||
|       return hardwts, hardwts_cpu, index | ||||
|       return hardwts, hardwts_cpu, index, 'GUMBEL' | ||||
|     else: | ||||
|       alphas  = nn.functional.softmax(self.arch_parameters, dim=-1) | ||||
|       index   = alphas.max(-1, keepdim=True)[1] | ||||
|       with torch.no_grad(): | ||||
|         alphas_cpu = alphas.detach().cpu() | ||||
|       return alphas, alphas_cpu, index | ||||
|       return alphas, alphas_cpu, index, 'SOFTMAX' | ||||
|  | ||||
|   def forward(self, inputs): | ||||
|     alphas, alphas_cpu, index = self.normalize_archp() | ||||
|     alphas, alphas_cpu, index, verbose_str = self.normalize_archp() | ||||
|     feature = self._stem(inputs) | ||||
|     for i, cell in enumerate(self._cells): | ||||
|       if isinstance(cell, SearchCell): | ||||
|         if self.mode == 'urs': | ||||
|           feature = cell.forward_urs(feature) | ||||
|           if self.verbose: | ||||
|             verbose_str += '-forward_urs' | ||||
|         elif self.mode == 'select': | ||||
|           feature = cell.forward_select(feature, alphas_cpu) | ||||
|           if self.verbose: | ||||
|             verbose_str += '-forward_select' | ||||
|         elif self.mode == 'joint': | ||||
|           feature = cell.forward_joint(feature, alphas) | ||||
|           if self.verbose: | ||||
|             verbose_str += '-forward_joint' | ||||
|         elif self.mode == 'dynamic': | ||||
|           feature = cell.forward_dynamic(feature, self.dynamic_cell) | ||||
|           if self.verbose: | ||||
|             verbose_str += '-forward_dynamic' | ||||
|         elif self.mode == 'gdas': | ||||
|           feature = cell.forward_gdas(feature, alphas, index) | ||||
|           if self.verbose: | ||||
|             verbose_str += '-forward_gdas' | ||||
|         else: raise ValueError('invalid mode={:}'.format(self.mode)) | ||||
|       else: feature = cell(feature) | ||||
|       if self.drop_path is not None: | ||||
|         feature = drop_path(feature, self.drop_path) | ||||
|     if self.verbose and random.random() < 0.001: | ||||
|       print(verbose_str) | ||||
|     out = self.lastact(feature) | ||||
|     out = self.global_pooling(out) | ||||
|     out = out.view(out.size(0), -1) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user