From 5eab0de53ea8ae416ffce0508a82da5d6bdf8c70 Mon Sep 17 00:00:00 2001
From: D-X-Y <280835372@qq.com>
Date: Wed, 26 May 2021 03:35:25 +0000
Subject: [PATCH] Fix bugs

---
 exps/GeMOSA/lfna_meta_model.py | 15 ++++++++-------
 1 file changed, 8 insertions(+), 7 deletions(-)

diff --git a/exps/GeMOSA/lfna_meta_model.py b/exps/GeMOSA/lfna_meta_model.py
index 9c69c83..c36e88b 100644
--- a/exps/GeMOSA/lfna_meta_model.py
+++ b/exps/GeMOSA/lfna_meta_model.py
@@ -46,9 +46,9 @@ class MetaModelV1(super_core.SuperModule):
         )
         self.register_buffer("_meta_timestamps", torch.Tensor(meta_timestamps))
         # register a time difference buffer
-        time_interval = [-i * self._interval for i in range(self._seq_length)]
-        time_interval.reverse()
-        self.register_buffer("_time_interval", torch.Tensor(time_interval))
+        # time_interval = [-i * self._interval for i in range(self._seq_length)]
+        # time_interval.reverse()
+        # self.register_buffer("_time_interval", torch.Tensor(time_interval))
         self._time_embed_dim = time_dim
         self._append_meta_embed = dict(fixed=None, learnt=None)
         self._append_meta_timestamps = dict(fixed=None, learnt=None)
@@ -161,7 +161,8 @@ class MetaModelV1(super_core.SuperModule):
 
     def _obtain_time_embed(self, timestamps):
         # timestamps is a batch of sequence of timestamps
-        batch, seq = timestamps.shape
+        # batch, seq = timestamps.shape
+        timestamps = timestamps.view(-1, 1)
         meta_timestamps, meta_embeds = self.meta_timestamps, self.super_meta_embed
         timestamp_v_embed = meta_embeds.unsqueeze(dim=0)
         timestamp_qk_att_embed = self._tscalar_embed(
@@ -185,9 +186,9 @@ class MetaModelV1(super_core.SuperModule):
 
     def forward_raw(self, timestamps, time_embeds, tembed_only=False):
         if time_embeds is None:
-            time_seq = timestamps.view(-1, 1) + self._time_interval.view(1, -1)
-            B, S = time_seq.shape
-            time_embeds = self._obtain_time_embed(time_seq)
+            # time_seq = timestamps.view(-1, 1) + self._time_interval.view(1, -1)
+            [B] = timestamps.shape
+            time_embeds = self._obtain_time_embed(timestamps)
         else:  # use the hyper-net only
             time_seq = None
             B, _ = time_embeds.shape
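Note (not part of the patch): a minimal sketch of the shape contract behind this change, assuming hypothetical sizes B, N, and time_dim. forward_raw now receives a 1-D batch of timestamps, and the patched _obtain_time_embed reshapes it to (B, 1) so that broadcasting against the stored meta_timestamps yields a (B, N) matrix of time differences, replacing the removed per-sample _time_interval offsets. The tscalar_embed function below is only a stand-in for self._tscalar_embed, whose real implementation is outside this diff.

import torch

# Hypothetical sizes: B query timestamps, N stored meta timestamps, time_dim embedding width.
B, N, time_dim = 4, 10, 8

timestamps = torch.rand(B)               # forward_raw now gets a 1-D batch: [B] = timestamps.shape
meta_timestamps = torch.linspace(0, 1, N)

# Reshape to (B, 1) as the patched _obtain_time_embed does, so broadcasting against
# the (N,) meta timestamps gives a (B, N) matrix of time differences.
time_diff = timestamps.view(-1, 1) - meta_timestamps.view(1, -1)
print(time_diff.shape)  # torch.Size([4, 10])

# Stand-in for self._tscalar_embed: any map from scalar offsets to time_dim-d features
# would do here; a sinusoidal embedding is used purely for illustration.
def tscalar_embed(x):
    freqs = torch.arange(1, time_dim // 2 + 1, dtype=torch.float32)
    angles = x.unsqueeze(-1) * freqs                       # (B, N, time_dim // 2)
    return torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)

timestamp_qk_att_embed = tscalar_embed(time_diff)
print(timestamp_qk_att_embed.shape)  # torch.Size([4, 10, 8])

Under this reading, each query timestamp is compared against every stored meta timestamp directly, rather than first being expanded into a self._seq_length-long sequence of past offsets as the removed _time_interval code did.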