Fix bugs
This commit is contained in:
		| @@ -46,9 +46,9 @@ class MetaModelV1(super_core.SuperModule): | |||||||
|         ) |         ) | ||||||
|         self.register_buffer("_meta_timestamps", torch.Tensor(meta_timestamps)) |         self.register_buffer("_meta_timestamps", torch.Tensor(meta_timestamps)) | ||||||
|         # register a time difference buffer |         # register a time difference buffer | ||||||
|         time_interval = [-i * self._interval for i in range(self._seq_length)] |         # time_interval = [-i * self._interval for i in range(self._seq_length)] | ||||||
|         time_interval.reverse() |         # time_interval.reverse() | ||||||
|         self.register_buffer("_time_interval", torch.Tensor(time_interval)) |         # self.register_buffer("_time_interval", torch.Tensor(time_interval)) | ||||||
|         self._time_embed_dim = time_dim |         self._time_embed_dim = time_dim | ||||||
|         self._append_meta_embed = dict(fixed=None, learnt=None) |         self._append_meta_embed = dict(fixed=None, learnt=None) | ||||||
|         self._append_meta_timestamps = dict(fixed=None, learnt=None) |         self._append_meta_timestamps = dict(fixed=None, learnt=None) | ||||||
| @@ -161,7 +161,8 @@ class MetaModelV1(super_core.SuperModule): | |||||||
|  |  | ||||||
|     def _obtain_time_embed(self, timestamps): |     def _obtain_time_embed(self, timestamps): | ||||||
|         # timestamps is a batch of sequence of timestamps |         # timestamps is a batch of sequence of timestamps | ||||||
|         batch, seq = timestamps.shape |         # batch, seq = timestamps.shape | ||||||
|  |         timestamps = timestamps.view(-1, 1) | ||||||
|         meta_timestamps, meta_embeds = self.meta_timestamps, self.super_meta_embed |         meta_timestamps, meta_embeds = self.meta_timestamps, self.super_meta_embed | ||||||
|         timestamp_v_embed = meta_embeds.unsqueeze(dim=0) |         timestamp_v_embed = meta_embeds.unsqueeze(dim=0) | ||||||
|         timestamp_qk_att_embed = self._tscalar_embed( |         timestamp_qk_att_embed = self._tscalar_embed( | ||||||
| @@ -185,9 +186,9 @@ class MetaModelV1(super_core.SuperModule): | |||||||
|  |  | ||||||
|     def forward_raw(self, timestamps, time_embeds, tembed_only=False): |     def forward_raw(self, timestamps, time_embeds, tembed_only=False): | ||||||
|         if time_embeds is None: |         if time_embeds is None: | ||||||
|             time_seq = timestamps.view(-1, 1) + self._time_interval.view(1, -1) |             # time_seq = timestamps.view(-1, 1) + self._time_interval.view(1, -1) | ||||||
|             B, S = time_seq.shape |             [B] = timestamps.shape | ||||||
|             time_embeds = self._obtain_time_embed(time_seq) |             time_embeds = self._obtain_time_embed(timestamps) | ||||||
|         else:  # use the hyper-net only |         else:  # use the hyper-net only | ||||||
|             time_seq = None |             time_seq = None | ||||||
|             B, _ = time_embeds.shape |             B, _ = time_embeds.shape | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user