Update LFNA ablation codes
This commit is contained in:
		| @@ -41,4 +41,5 @@ super_name2activation = { | ||||
|  | ||||
|  | ||||
| from .super_trade_stem import SuperAlphaEBDv1 | ||||
| from .super_positional_embedding import SuperDynamicPositionE | ||||
| from .super_positional_embedding import SuperPositionalEncoder | ||||
|   | ||||
| @@ -10,6 +10,41 @@ from .super_module import SuperModule | ||||
| from .super_module import IntSpaceType | ||||
|  | ||||
|  | ||||
class SuperDynamicPositionE(SuperModule):
    """Apply a sinusoidal positional encoding to the input positions.

    Each scalar position ``p`` in the input is mapped to a vector built from
    ``sin(p * scale * w_k)`` followed by ``cos(p * scale * w_k)``, where the
    frequencies ``w_k = exp(-2k * ln(10000) / dimension)`` are the ones
    stored in the ``_div_term`` buffer.

    Args:
        dimension: size of the last axis of the produced embedding.
        scale: multiplier applied to the positions before encoding.
    """

    def __init__(self, dimension: int, scale: float = 1.0) -> None:
        super(SuperDynamicPositionE, self).__init__()

        self._scale = scale
        self._dimension = dimension
        # Fixed frequencies of the sinusoids. They are registered as a buffer
        # (not a parameter), so they are not updated by an optimizer.
        self.register_buffer(
            "_div_term",
            torch.exp(
                torch.arange(0, dimension, 2).float() * (-math.log(10000.0) / dimension)
            ),
        )

    @property
    def abstract_search_space(self):
        """Return a root ``VirtualNode`` keyed by ``id(self)`` with nothing attached."""
        root_node = spaces.VirtualNode(id(self))
        return root_node

    def forward_candidate(self, input: torch.Tensor) -> torch.Tensor:
        """Delegate to :meth:`forward_raw`; both paths compute the same result."""
        return self.forward_raw(input)

    def forward_raw(self, input: torch.Tensor) -> torch.Tensor:
        """Encode ``input`` positions into sinusoidal embeddings.

        Args:
            input: tensor of positions with arbitrary shape ``(*)``.

        Returns:
            A tensor of shape ``(*, dimension)``: the sine components followed
            by the cosine components for every position.
        """
        # Add a trailing axis so each position broadcasts against all frequencies.
        positions = torch.unsqueeze(input * self._scale, dim=-1)
        angles = positions * self._div_term
        embeds = torch.cat((torch.sin(angles), torch.cos(angles)), dim=-1)
        # For an odd `dimension` the concatenation has dimension + 1 channels;
        # trim so the last axis always matches the configured size.
        return embeds[..., : self._dimension]

    def extra_repr(self) -> str:
        """Return the scale and dimension for the module's printed representation."""
        return "scale={:}, dim={:}".format(self._scale, self._dimension)
|  | ||||
|  | ||||
| class SuperPositionalEncoder(SuperModule): | ||||
|     """Attention Is All You Need: https://arxiv.org/pdf/1706.03762.pdf | ||||
|     https://github.com/pytorch/examples/blob/master/word_language_model/model.py#L65 | ||||
|   | ||||
		Reference in New Issue
	
	Block a user