Fix bugs
This commit is contained in:
		| @@ -20,7 +20,7 @@ class MetaModelV1(super_core.SuperModule): | ||||
|         time_dim, | ||||
|         meta_timestamps, | ||||
|         dropout: float = 0.1, | ||||
|         seq_length: int = 10, | ||||
|         seq_length: int = None, | ||||
|         interval: float = None, | ||||
|         thresh: float = None, | ||||
|     ): | ||||
| @@ -33,8 +33,7 @@ class MetaModelV1(super_core.SuperModule): | ||||
|         self._raw_meta_timestamps = meta_timestamps | ||||
|         assert interval is not None | ||||
|         self._interval = interval | ||||
|         self._seq_length = seq_length | ||||
|         self._thresh = interval * 50 if thresh is None else thresh | ||||
|         self._thresh = interval * seq_length if thresh is None else thresh | ||||
|  | ||||
|         self.register_parameter( | ||||
|             "_super_layer_embed", | ||||
| @@ -45,10 +44,6 @@ class MetaModelV1(super_core.SuperModule): | ||||
|             torch.nn.Parameter(torch.Tensor(len(meta_timestamps), time_dim)), | ||||
|         ) | ||||
|         self.register_buffer("_meta_timestamps", torch.Tensor(meta_timestamps)) | ||||
|         # register a time difference buffer | ||||
|         # time_interval = [-i * self._interval for i in range(self._seq_length)] | ||||
|         # time_interval.reverse() | ||||
|         # self.register_buffer("_time_interval", torch.Tensor(time_interval)) | ||||
|         self._time_embed_dim = time_dim | ||||
|         self._append_meta_embed = dict(fixed=None, learnt=None) | ||||
|         self._append_meta_timestamps = dict(fixed=None, learnt=None) | ||||
| @@ -186,7 +181,6 @@ class MetaModelV1(super_core.SuperModule): | ||||
|  | ||||
|     def forward_raw(self, timestamps, time_embeds, tembed_only=False): | ||||
|         if time_embeds is None: | ||||
|             # time_seq = timestamps.view(-1, 1) + self._time_interval.view(1, -1) | ||||
|             [B] = timestamps.shape | ||||
|             time_embeds = self._obtain_time_embed(timestamps) | ||||
|         else:  # use the hyper-net only | ||||
| @@ -210,7 +204,7 @@ class MetaModelV1(super_core.SuperModule): | ||||
|             batch_containers.append( | ||||
|                 self._shape_container.translate(torch.split(weights.squeeze(0), 1)) | ||||
|             ) | ||||
|         return time_seq, batch_containers, time_embeds | ||||
|         return batch_containers, time_embeds | ||||
|  | ||||
|     def forward_candidate(self, input): | ||||
|         raise NotImplementedError | ||||
| @@ -239,10 +233,10 @@ class MetaModelV1(super_core.SuperModule): | ||||
|                 best_new_param = new_param.detach().clone() | ||||
|             for iepoch in range(epochs): | ||||
|                 optimizer.zero_grad() | ||||
|                 _, [_], time_embed = self(timestamp.view(1, 1), None) | ||||
|                 _, time_embed = self(timestamp.view(1), None) | ||||
|                 match_loss = criterion(new_param, time_embed) | ||||
|  | ||||
|                 _, [container], time_embed = self(None, new_param.view(1, -1)) | ||||
|                 [container], time_embed = self(None, new_param.view(1, -1)) | ||||
|                 y_hat = base_model.forward_with_container(x, container) | ||||
|                 meta_loss = criterion(y_hat, y) | ||||
|                 loss = meta_loss + match_loss | ||||
|   | ||||
| @@ -46,8 +46,8 @@ def online_evaluate(env, meta_model, base_model, criterion, args, logger, save=F | ||||
|         with torch.no_grad(): | ||||
|             meta_model.eval() | ||||
|             base_model.eval() | ||||
|             _, [future_container], time_embeds = meta_model( | ||||
|                 future_time.to(args.device).view(1, 1), None, False | ||||
|             [future_container], time_embeds = meta_model( | ||||
|                 future_time.to(args.device).view(-1), None, False | ||||
|             ) | ||||
|             if save: | ||||
|                 w_containers[idx] = future_container.no_grad_clone() | ||||
| @@ -117,10 +117,10 @@ def meta_train_procedure(base_model, meta_model, criterion, xenv, args, logger): | ||||
|         ) | ||||
|         # future loss | ||||
|         total_future_losses, total_present_losses = [], [] | ||||
|         _, future_containers, _ = meta_model( | ||||
|         future_containers, _ = meta_model( | ||||
|             None, generated_time_embeds[batch_indexes], False | ||||
|         ) | ||||
|         _, present_containers, _ = meta_model( | ||||
|         present_containers, _ = meta_model( | ||||
|             None, meta_model.super_meta_embed[batch_indexes], False | ||||
|         ) | ||||
|         for ibatch, time_step in enumerate(raw_time_steps.cpu().tolist()): | ||||
|   | ||||
		Reference in New Issue
	
	Block a user