This commit is contained in:
D-X-Y 2021-05-26 03:35:25 +00:00
parent 33a8e7c88d
commit 5eab0de53e

View File

@ -46,9 +46,9 @@ class MetaModelV1(super_core.SuperModule):
)
self.register_buffer("_meta_timestamps", torch.Tensor(meta_timestamps))
# register a time difference buffer
time_interval = [-i * self._interval for i in range(self._seq_length)]
time_interval.reverse()
self.register_buffer("_time_interval", torch.Tensor(time_interval))
# time_interval = [-i * self._interval for i in range(self._seq_length)]
# time_interval.reverse()
# self.register_buffer("_time_interval", torch.Tensor(time_interval))
self._time_embed_dim = time_dim
self._append_meta_embed = dict(fixed=None, learnt=None)
self._append_meta_timestamps = dict(fixed=None, learnt=None)
@ -161,7 +161,8 @@ class MetaModelV1(super_core.SuperModule):
def _obtain_time_embed(self, timestamps):
# timestamps is a batch of sequence of timestamps
batch, seq = timestamps.shape
# batch, seq = timestamps.shape
timestamps = timestamps.view(-1, 1)
meta_timestamps, meta_embeds = self.meta_timestamps, self.super_meta_embed
timestamp_v_embed = meta_embeds.unsqueeze(dim=0)
timestamp_qk_att_embed = self._tscalar_embed(
@ -185,9 +186,9 @@ class MetaModelV1(super_core.SuperModule):
def forward_raw(self, timestamps, time_embeds, tembed_only=False):
if time_embeds is None:
time_seq = timestamps.view(-1, 1) + self._time_interval.view(1, -1)
B, S = time_seq.shape
time_embeds = self._obtain_time_embed(time_seq)
# time_seq = timestamps.view(-1, 1) + self._time_interval.view(1, -1)
[B] = timestamps.shape
time_embeds = self._obtain_time_embed(timestamps)
else: # use the hyper-net only
time_seq = None
B, _ = time_embeds.shape