Prototype generic nas model (cont.).
This commit is contained in:
		| @@ -67,6 +67,14 @@ class GenericNAS201Model(nn.Module): | ||||
|     if mode == 'dynamic': self.dynamic_cell = deepcopy(dynamic_cell) | ||||
|     else                : self.dynamic_cell = None | ||||
|  | ||||
|   def set_drop_path(self, progress, drop_path_rate): | ||||
|     if drop_path_rate is None: | ||||
|       self._drop_path = None | ||||
|     elif progress is None: | ||||
|       self._drop_path = drop_path_rate | ||||
|     else: | ||||
|       self._drop_path = progress * drop_path_rate | ||||
|  | ||||
|   @property | ||||
|   def mode(self): | ||||
|     return self._mode | ||||
| @@ -210,6 +218,8 @@ class GenericNAS201Model(nn.Module): | ||||
|           feature = cell.forward_gdas(feature, alphas, index) | ||||
|         else: raise ValueError('invalid mode={:}'.format(self.mode)) | ||||
|       else: feature = cell(feature) | ||||
|       if self.drop_path is not None: | ||||
|         feature = drop_path(feature, self.drop_path) | ||||
|     out = self.lastact(feature) | ||||
|     out = self.global_pooling(out) | ||||
|     out = out.view(out.size(0), -1) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user