Fix bugs in xlayers
@@ -10,7 +10,8 @@ from tqdm import tqdm
 from copy import deepcopy
 from pathlib import Path
 
-lib_dir = (Path(__file__).parent / "..").resolve()
+lib_dir = (Path(__file__).parent / ".." / "..").resolve()
+print("LIB-DIR: {:}".format(lib_dir))
 if str(lib_dir) not in sys.path:
     sys.path.insert(0, str(lib_dir))
 
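This hunk fixes the library-path bootstrap: lib_dir now resolves two levels up instead of one, and the added print makes the resolved path visible at startup. A minimal self-contained sketch of the same pattern follows; the directory layout is an assumption for illustration, not taken from the repo.

import sys
from pathlib import Path

# Resolve the repo-level library directory relative to this file; the extra
# ".." matches the fix above (two levels up instead of one).
lib_dir = (Path(__file__).parent / ".." / "..").resolve()
print("LIB-DIR: {:}".format(lib_dir))
if str(lib_dir) not in sys.path:
    # prepend so the in-repo library shadows any installed copy
    sys.path.insert(0, str(lib_dir))
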
@@ -20,7 +20,7 @@ class LFNA_Meta(super_core.SuperModule):
         layer_embedding,
         time_embedding,
         meta_timestamps,
-        mha_depth: int = 2,
+        mha_depth: int = 1,
         dropout: float = 0.1,
     ):
         super(LFNA_Meta, self).__init__()
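mha_depth controls how many attention blocks get stacked in the loop shown in the next hunk; the default drops from 2 to 1. A rough sketch of what such a depth parameter governs, using a standard PyTorch encoder layer as a stand-in for the repo's super_core blocks (illustrative only):

import torch

time_embedding, mha_depth, dropout = 32, 1, 0.1  # new default depth is 1

# one self-attention block per unit of depth; nn.TransformerEncoderLayer is
# only a stand-in for the super_core layers built in the real loop
layers = torch.nn.ModuleList(
    torch.nn.TransformerEncoderLayer(
        d_model=time_embedding, nhead=4, dropout=dropout, batch_first=True
    )
    for _ in range(mha_depth)
)

x = torch.randn(2, 5, time_embedding)
for layer in layers:
    x = layer(x)
print(x.shape)  # torch.Size([2, 5, 32])
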
@@ -44,8 +44,21 @@ class LFNA_Meta(super_core.SuperModule):
         self._append_meta_embed = dict(fixed=None, learnt=None)
         self._append_meta_timestamps = dict(fixed=None, learnt=None)
 
-        self._time_prob_drop = super_core.SuperDrop(dropout, (-1, 1), recover=False)
+        self._tscalar_embed = super_core.SuperDynamicPositionE(
+            time_embedding, scale=100
+        )
 
+        # build transformer
+        self._trans_att = super_core.SuperQKVAttention(
+            time_embedding,
+            time_embedding,
+            time_embedding,
+            time_embedding,
+            4,
+            True,
+            attn_drop=None,
+            proj_drop=dropout,
+        )
         layers = []
         for ilayer in range(mha_depth):
             layers.append(
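This hunk swaps the probabilistic time-match dropout for two new modules: a dynamic sinusoidal embedding for scalar timestamps (SuperDynamicPositionE) and a 4-head Q/K/V cross-attention block (SuperQKVAttention). A hand-rolled sketch of the attention computation such a block performs, assuming standard scaled dot-product semantics rather than the module's actual internals:

import torch


def qkv_attention(q, k, v, num_heads=4):
    # scaled dot-product cross attention over separate query/key/value
    # streams; an assumed approximation of SuperQKVAttention that ignores
    # its input projections and dropout options (attn_drop / proj_drop)
    B, Nq, C = q.shape
    Nk = k.shape[1]
    hd = C // num_heads
    q = q.view(B, Nq, num_heads, hd).transpose(1, 2)  # (B, H, Nq, hd)
    k = k.view(B, Nk, num_heads, hd).transpose(1, 2)  # (B, H, Nk, hd)
    v = v.view(B, Nk, num_heads, hd).transpose(1, 2)  # (B, H, Nk, hd)
    attn = torch.softmax(q @ k.transpose(-2, -1) / hd ** 0.5, dim=-1)
    return (attn @ v).transpose(1, 2).reshape(B, Nq, C)


out = qkv_attention(torch.randn(2, 5, 32), torch.randn(2, 7, 32), torch.randn(2, 7, 32))
print(out.shape)  # torch.Size([2, 5, 32])
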
@@ -74,15 +87,9 @@ class LFNA_Meta(super_core.SuperModule):
         self._generator = get_model(**model_kwargs)
         # print("generator: {:}".format(self._generator))
 
-        # unknown token
-        self.register_parameter(
-            "_unknown_token",
-            torch.nn.Parameter(torch.Tensor(1, time_embedding)),
-        )
-
         # initialization
         trunc_normal_(
-            [self._super_layer_embed, self._super_meta_embed, self._unknown_token],
+            [self._super_layer_embed, self._super_meta_embed],
             std=0.02,
         )
 
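With the unknown-token parameter removed, initialization only covers the two remaining embeddings. The repo's trunc_normal_ helper evidently accepts a list of tensors; a small sketch of the same effect using PyTorch's built-in per-tensor initializer, with illustrative names and shapes:

import torch


def trunc_normal_list_(tensors, std=0.02):
    # apply a truncated-normal init to each tensor in turn, mirroring the
    # list-accepting trunc_normal_ helper used above (an assumption)
    for t in tensors:
        torch.nn.init.trunc_normal_(t, std=std)


super_layer_embed = torch.nn.Parameter(torch.empty(10, 32))
super_meta_embed = torch.nn.Parameter(torch.empty(100, 32))
trunc_normal_list_([super_layer_embed, super_meta_embed])
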
@@ -136,28 +143,21 @@ class LFNA_Meta(super_core.SuperModule):
                     (self._append_meta_embed["fixed"], meta_embed), dim=0
                 )
 
-    def forward_raw(self, timestamps):
+    def _obtain_time_embed(self, timestamps):
         # timestamps is a batch of sequence of timestamps
         batch, seq = timestamps.shape
         timestamps = timestamps.unsqueeze(dim=-1)
-        meta_timestamps = self.meta_timestamps.view(1, 1, -1)
-        time_diffs = timestamps - meta_timestamps
-        time_match_v, time_match_i = torch.min(torch.abs(time_diffs), dim=-1)
-        # select corresponding meta-knowledge
-        meta_match = torch.index_select(
-            self.super_meta_embed, dim=0, index=time_match_i.view(-1)
+        timestamp_q_embed = self._tscalar_embed(timestamps)
+        timestamp_k_embed = self._tscalar_embed(self.meta_timestamps.view(1, -1))
+        timestamp_v_embed = self.super_meta_embed.unsqueeze(dim=0)
+        timestamp_embeds = self._trans_att(
+            timestamp_q_embed, timestamp_k_embed, timestamp_v_embed
         )
-        meta_match = meta_match.view(batch, seq, -1)
-        # create the probability
-        time_probs = (1 / torch.exp(time_match_v * 10)).view(batch, seq, 1)
+        corrected_embeds = self.meta_corrector(timestamp_embeds)
+        return corrected_embeds
 
-        x_time_probs = self._time_prob_drop(time_probs)
-        # if self.training:
-        #    time_probs[:, -1, :] = 0
-        unknown_token = self._unknown_token.view(1, 1, -1)
-        raw_meta_embed = x_time_probs * meta_match + (1 - x_time_probs) * unknown_token
-
-        meta_embed = self.meta_corrector(raw_meta_embed)
+    def forward_raw(self, timestamps):
+        batch, seq = timestamps.shape
+        meta_embed = self._obtain_time_embed(timestamps)
         # create joint embed
         num_layer, _ = self._super_layer_embed.shape
         meta_embed = meta_embed.view(batch, seq, 1, -1).expand(-1, -1, num_layer, -1)
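This final hunk is the core fix: instead of snapping each query timestamp to its nearest meta-timestamp and blending the match with a learned unknown token, the model now embeds both query and meta timestamps with the sinusoidal encoder and lets cross attention pool the meta-embeddings softly. An end-to-end sketch of that data flow, with scalar_position_embed standing in for SuperDynamicPositionE and plain single-head attention standing in for _trans_att (both are assumptions about the super_core internals):

import math
import torch


def scalar_position_embed(x, dim, scale=100):
    # encode scalar timestamps of shape (..., 1) as (..., dim) sinusoidal
    # vectors; a hypothetical stand-in for super_core.SuperDynamicPositionE
    freqs = torch.exp(torch.arange(0, dim, 2, dtype=x.dtype) * (-math.log(scale) / dim))
    angles = x * freqs
    return torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)


batch, seq, dim, num_meta = 2, 5, 32, 7
timestamps = torch.rand(batch, seq, 1)         # query timestamps (after unsqueeze)
meta_timestamps = torch.rand(num_meta)         # stored training timestamps
super_meta_embed = torch.randn(num_meta, dim)  # learned per-timestamp embeddings

q = scalar_position_embed(timestamps, dim)                      # (batch, seq, dim)
k = scalar_position_embed(meta_timestamps.view(1, -1, 1), dim)  # (1, num_meta, dim)
v = super_meta_embed.unsqueeze(0)                               # (1, num_meta, dim)

# soft lookup: attention weights replace the old nearest-neighbor index plus
# unknown-token mixing; the real model uses the 4-head SuperQKVAttention and
# then passes the result through meta_corrector
attn = torch.softmax(q @ k.transpose(-2, -1) / dim ** 0.5, dim=-1)
time_embed = attn @ v                                           # (batch, seq, dim)
print(time_embed.shape)  # torch.Size([2, 5, 32])
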