Update LFNA
@@ -107,11 +107,20 @@ def online_evaluate(env, meta_model, base_model, criterion, args, logger):
             base_model.eval()
             time_seqs = torch.Tensor(time_seqs).view(1, -1).to(args.device)
             [seq_containers], _ = meta_model(time_seqs, None)
-            future_container = seq_containers[-2]
-            _, (future_x, future_y) = env(time_seqs[0, -2].item())
-            future_x, future_y = future_x.to(args.device), future_y.to(args.device)
-            future_y_hat = base_model.forward_with_container(future_x, future_container)
-            future_loss = criterion(future_y_hat, future_y)
+            # For Debug
+            for idx in range(time_seqs.numel()):
+                future_container = seq_containers[idx]
+                _, (future_x, future_y) = env(time_seqs[0, idx].item())
+                future_x, future_y = future_x.to(args.device), future_y.to(args.device)
+                future_y_hat = base_model.forward_with_container(
+                    future_x, future_container
+                )
+                future_loss = criterion(future_y_hat, future_y)
+                logger.log(
+                    "--> time={:.4f} -> loss={:.4f}".format(
+                        time_seqs[0, idx].item(), future_loss.item()
+                    )
+                )
             logger.log(
                 "[ONLINE] [{:03d}/{:03d}] loss={:.4f}".format(
                     idx, len(env), future_loss.item()
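
The loop above logs the loss at every queried timestamp instead of only at the second-to-last one. A minimal runnable sketch of the same per-timestamp evaluation pattern, where dummy_env, the linear forward_with_container, and the random containers are hypothetical stand-ins for the LFNA environment, the container-based forward pass, and the meta-model's per-timestamp outputs:

import torch

def forward_with_container(x, container):
    # stand-in for base_model.forward_with_container: a linear map whose
    # parameters are supplied externally through `container`
    return x @ container["weight"].t() + container["bias"]

def dummy_env(t):
    # stand-in for env(t): returns the timestamp and an (x, y) batch for that time
    x = torch.randn(16, 4)
    y = x.sum(dim=1, keepdim=True) * t
    return t, (x, y)

criterion = torch.nn.MSELoss()
time_seqs = torch.linspace(0.0, 1.0, steps=5).view(1, -1)
# one parameter container per queried timestamp, as the meta-model would emit
seq_containers = [
    {"weight": torch.randn(1, 4), "bias": torch.zeros(1)}
    for _ in range(time_seqs.numel())
]

for idx in range(time_seqs.numel()):
    _, (x, y) = dummy_env(time_seqs[0, idx].item())
    y_hat = forward_with_container(x, seq_containers[idx])
    loss = criterion(y_hat, y)
    print("--> time={:.4f} -> loss={:.4f}".format(time_seqs[0, idx].item(), loss.item()))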
@@ -47,17 +47,17 @@ class LFNA_Meta(super_core.SuperModule):
         self._append_meta_timestamps = dict(fixed=None, learnt=None)
 
         self._tscalar_embed = super_core.SuperDynamicPositionE(
-            time_embedding, scale=100
+            time_embedding, scale=500
         )
 
         # build transformer
-        self._trans_att = super_core.SuperQKVAttention(
-            time_embedding,
-            time_embedding,
-            time_embedding,
-            time_embedding,
-            4,
-            True,
+        self._trans_att = super_core.SuperQKVAttentionV2(
+            qk_att_dim=time_embedding,
+            in_v_dim=time_embedding,
+            hidden_dim=time_embedding,
+            num_heads=4,
+            proj_dim=time_embedding,
+            qkv_bias=True,
             attn_drop=None,
             proj_drop=dropout,
         )
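
Neither SuperDynamicPositionE's nor SuperQKVAttentionV2's internals appear in this diff. Assuming the former is the usual sinusoidal encoding applied to a raw scalar timestamp, scale sets the largest wavelength, so raising it from 100 to 500 adds lower-frequency channels that stay distinguishable over a wider time range. A sketch under that assumption (sinusoidal_scalar_embed is illustrative, not the library function):

import torch

def sinusoidal_scalar_embed(t, dim, scale):
    # encode scalar timestamps t (any shape) into vectors of shape (*t.shape, dim);
    # wavelengths grow geometrically from 1 up to `scale`
    half = dim // 2
    div = scale ** (torch.arange(half, dtype=torch.float32) / half)
    angles = t.unsqueeze(-1) / div
    return torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)

t = torch.tensor([0.3, 0.7, 250.0])
emb_100 = sinusoidal_scalar_embed(t, dim=32, scale=100)  # old setting
emb_500 = sinusoidal_scalar_embed(t, dim=32, scale=500)  # new setting
print(emb_100.shape, emb_500.shape)  # torch.Size([3, 32]) twice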
@@ -166,9 +166,12 @@ class LFNA_Meta(super_core.SuperModule):
         # timestamps is a batch of sequence of timestamps
         batch, seq = timestamps.shape
         meta_timestamps, meta_embeds = self.meta_timestamps, self.super_meta_embed
         timestamp_q_embed = self._tscalar_embed(timestamps)
-        timestamp_k_embed = self._tscalar_embed(meta_timestamps.view(1, -1))
+        # timestamp_k_embed = self._tscalar_embed(meta_timestamps.view(1, -1))
         timestamp_v_embed = meta_embeds.unsqueeze(dim=0)
+        timestamp_qk_att_embed = self._tscalar_embed(
+            torch.unsqueeze(timestamps, dim=-1) - meta_timestamps
+        )
         # create the mask
         mask = (
             torch.unsqueeze(timestamps, dim=-1) <= meta_timestamps.view(1, 1, -1)
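
For shape intuition on the new timestamp_qk_att_embed: broadcasting timestamps.unsqueeze(dim=-1) of shape (batch, seq, 1) against meta_timestamps of shape (num_meta,) yields one time difference per query/key pair, and the mask built from the same broadcast hides meta-timestamps at or after the query time. The middle of the mask expression falls between this hunk and the next; assuming it or-combines the causal test with an absolute-distance test against _thresh, a standalone sketch with made-up sizes:

import torch

batch, seq, num_meta, thresh = 2, 3, 7, 0.5
timestamps = torch.rand(batch, seq)       # query times
meta_timestamps = torch.rand(num_meta)    # times of the stored meta embeddings

# (batch, seq, num_meta): one time difference per query/key pair
delta = torch.unsqueeze(timestamps, dim=-1) - meta_timestamps

# hide keys at or after the query time, or farther than `thresh` away
mask = (torch.unsqueeze(timestamps, dim=-1) <= meta_timestamps.view(1, 1, -1)) | (
    torch.abs(delta) > thresh
)
print(delta.shape, mask.shape)  # torch.Size([2, 3, 7]) twice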
@@ -179,11 +182,13 @@ class LFNA_Meta(super_core.SuperModule):
             > self._thresh
         )
         timestamp_embeds = self._trans_att(
-            timestamp_q_embed, timestamp_k_embed, timestamp_v_embed, mask
+            timestamp_qk_att_embed, timestamp_v_embed, mask
         )
-        relative_timestamps = timestamps - timestamps[:, :1]
-        relative_pos_embeds = self._tscalar_embed(relative_timestamps)
-        init_timestamp_embeds = torch.cat(
-            (timestamp_embeds, relative_pos_embeds), dim=-1
-        )
+        # relative_timestamps = timestamps - timestamps[:, :1]
+        # relative_pos_embeds = self._tscalar_embed(relative_timestamps)
+        init_timestamp_embeds = torch.cat((timestamp_q_embed, timestamp_embeds), dim=-1)
         corrected_embeds = self._meta_corrector(init_timestamp_embeds)
         return corrected_embeds
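
The attention call now takes the pairwise time-difference embedding in place of separate query and key embeddings, and the absolute query-time embedding timestamp_q_embed is concatenated onto the attention output where the old relative-position embedding used to go. SuperQKVAttentionV2's source is not part of this diff, so RelTimeAttention below is one plausible reading of such a layer, not the actual implementation: project each pairwise embedding to one logit per head, mask, softmax over the meta dimension, and mix the value embeddings.

import torch
import torch.nn as nn

class RelTimeAttention(nn.Module):
    # assumed sketch of a SuperQKVAttentionV2-style layer: attention logits come
    # from an embedding of (t_query - t_key) rather than from query-key dot products
    def __init__(self, qk_att_dim, in_v_dim, hidden_dim, num_heads, proj_dim):
        super().__init__()
        self.num_heads = num_heads
        self.score = nn.Linear(qk_att_dim, num_heads)      # one logit per head per pair
        self.v_proj = nn.Linear(in_v_dim, hidden_dim * num_heads)
        self.out = nn.Linear(hidden_dim * num_heads, proj_dim)

    def forward(self, qk_att_embed, v_embed, mask=None):
        # qk_att_embed: (B, Q, K, qk_att_dim); v_embed: (B, K, in_v_dim)
        B, Q, K, _ = qk_att_embed.shape
        logits = self.score(qk_att_embed)                  # (B, Q, K, H)
        if mask is not None:
            # mask: (B, Q, K) with True marking hidden pairs (convention assumed)
            logits = logits.masked_fill(mask.unsqueeze(-1), float("-inf"))
        attn = logits.softmax(dim=2)                       # normalize over the K keys
        v = self.v_proj(v_embed).view(B, K, self.num_heads, -1)
        mixed = torch.einsum("bqkh,bkhd->bqhd", attn, v)   # (B, Q, H, hidden)
        return self.out(mixed.reshape(B, Q, -1))           # (B, Q, proj_dim)

layer = RelTimeAttention(qk_att_dim=16, in_v_dim=16, hidden_dim=16, num_heads=4, proj_dim=16)
out = layer(torch.randn(2, 3, 7, 16), torch.randn(2, 7, 16))
print(out.shape)  # torch.Size([2, 3, 16])

Scoring from time differences makes the attention invariant to shifting every timestamp by a constant, which is consistent with commenting out the absolute key embedding in the previous hunk.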