Fix bugs
This commit is contained in:
parent
33a8e7c88d
commit
5eab0de53e
@ -46,9 +46,9 @@ class MetaModelV1(super_core.SuperModule):
|
||||
)
|
||||
self.register_buffer("_meta_timestamps", torch.Tensor(meta_timestamps))
|
||||
# register a time difference buffer
|
||||
time_interval = [-i * self._interval for i in range(self._seq_length)]
|
||||
time_interval.reverse()
|
||||
self.register_buffer("_time_interval", torch.Tensor(time_interval))
|
||||
# time_interval = [-i * self._interval for i in range(self._seq_length)]
|
||||
# time_interval.reverse()
|
||||
# self.register_buffer("_time_interval", torch.Tensor(time_interval))
|
||||
self._time_embed_dim = time_dim
|
||||
self._append_meta_embed = dict(fixed=None, learnt=None)
|
||||
self._append_meta_timestamps = dict(fixed=None, learnt=None)
|
||||
@ -161,7 +161,8 @@ class MetaModelV1(super_core.SuperModule):
|
||||
|
||||
def _obtain_time_embed(self, timestamps):
|
||||
# timestamps is a batch of sequence of timestamps
|
||||
batch, seq = timestamps.shape
|
||||
# batch, seq = timestamps.shape
|
||||
timestamps = timestamps.view(-1, 1)
|
||||
meta_timestamps, meta_embeds = self.meta_timestamps, self.super_meta_embed
|
||||
timestamp_v_embed = meta_embeds.unsqueeze(dim=0)
|
||||
timestamp_qk_att_embed = self._tscalar_embed(
|
||||
@ -185,9 +186,9 @@ class MetaModelV1(super_core.SuperModule):
|
||||
|
||||
def forward_raw(self, timestamps, time_embeds, tembed_only=False):
|
||||
if time_embeds is None:
|
||||
time_seq = timestamps.view(-1, 1) + self._time_interval.view(1, -1)
|
||||
B, S = time_seq.shape
|
||||
time_embeds = self._obtain_time_embed(time_seq)
|
||||
# time_seq = timestamps.view(-1, 1) + self._time_interval.view(1, -1)
|
||||
[B] = timestamps.shape
|
||||
time_embeds = self._obtain_time_embed(timestamps)
|
||||
else: # use the hyper-net only
|
||||
time_seq = None
|
||||
B, _ = time_embeds.shape
|
||||
|
Loading…
Reference in New Issue
Block a user