Remove unnecessary model in GMOA
@@ -337,11 +337,11 @@ if __name__ == "__main__":
     parser.add_argument(
         "--refine_lr",
         type=float,
-        default=0.002,
+        default=0.001,
         help="The learning rate for the optimizer, during refine",
     )
     parser.add_argument(
-        "--refine_epochs", type=int, default=100, help="The final refine #epochs."
+        "--refine_epochs", type=int, default=150, help="The final refine #epochs."
     )
     parser.add_argument(
         "--early_stop_thresh",
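To see the new defaults in action, here is a minimal, runnable excerpt of the refine flags as they now parse. Only the two changed arguments are shown; the real script defines more:

    import argparse

    # Minimal excerpt of the refine-stage CLI after this commit.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--refine_lr",
        type=float,
        default=0.001,  # was 0.002
        help="The learning rate for the optimizer, during refine",
    )
    parser.add_argument(
        "--refine_epochs", type=int, default=150, help="The final refine #epochs."  # was 100
    )
    args = parser.parse_args([])  # empty argv, so defaults are used
    assert (args.refine_lr, args.refine_epochs) == (0.001, 150)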
@@ -19,7 +19,6 @@ class MetaModelV1(super_core.SuperModule):
         layer_dim,
         time_dim,
         meta_timestamps,
-        mha_depth: int = 2,
         dropout: float = 0.1,
         seq_length: int = 10,
         interval: float = None,
@@ -69,22 +68,6 @@ class MetaModelV1(super_core.SuperModule):
             attn_drop=None,
             proj_drop=dropout,
         )
-        layers = []
-        for ilayer in range(mha_depth):
-            layers.append(
-                super_core.SuperTransformerEncoderLayer(
-                    time_dim * 2,
-                    4,
-                    True,
-                    4,
-                    dropout,
-                    norm_affine=False,
-                    order=super_core.LayerOrder.PostNorm,
-                    use_mask=True,
-                )
-            )
-        layers.append(super_core.SuperLinear(time_dim * 2, time_dim))
-        self._meta_corrector = super_core.SuperSequential(*layers)
 
         model_kwargs = dict(
             config=dict(model_type="dual_norm_mlp"),
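The deleted _meta_corrector was a small stack of post-norm transformer encoder layers over 2*time_dim features, followed by a linear projection back to time_dim. A rough plain-PyTorch analogue of what it computed; this is a sketch, and the mapping of the positional arguments (4, True, 4) onto head count and feed-forward width is an assumption, since super_core's exact API is not reproduced here:

    import torch
    import torch.nn as nn

    time_dim, mha_depth, dropout = 32, 2, 0.1

    # Rough analogue of the removed corrector: mha_depth encoder layers over
    # 2*time_dim features, then a linear projection back down to time_dim.
    # Head count 4 and feed-forward multiplier 4 are assumed readings of the
    # deleted positional arguments.
    corrector = nn.Sequential(
        *[
            nn.TransformerEncoderLayer(
                d_model=time_dim * 2,
                nhead=4,
                dim_feedforward=time_dim * 2 * 4,
                dropout=dropout,
                batch_first=True,
                norm_first=False,  # post-norm, as in LayerOrder.PostNorm
            )
            for _ in range(mha_depth)
        ],
        nn.Linear(time_dim * 2, time_dim),
    )
    x = torch.randn(2, 10, time_dim * 2)  # (batch, seq_length, 2*time_dim)
    print(corrector(x).shape)  # torch.Size([2, 10, 32])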
@@ -103,13 +86,12 @@ class MetaModelV1(super_core.SuperModule):
             std=0.02,
         )
 
-    def get_parameters(self, time_embed, meta_corrector, generator):
+    def get_parameters(self, time_embed, attention, generator):
         parameters = []
         if time_embed:
             parameters.append(self._super_meta_embed)
-        if meta_corrector:
+        if attention:
             parameters.extend(list(self._trans_att.parameters()))
-            parameters.extend(list(self._meta_corrector.parameters()))
         if generator:
             parameters.append(self._super_layer_embed)
             parameters.extend(list(self._generator.parameters()))
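Renaming the flag from meta_corrector to attention matches what remains: only the _trans_att parameters are gated by it. A self-contained sketch of this boolean-gated parameter-grouping pattern, using toy stand-in modules rather than the real MetaModelV1 internals:

    import torch
    import torch.nn as nn

    # Toy stand-in for the grouping pattern in get_parameters; the module
    # names here are hypothetical.
    class Toy(nn.Module):
        def __init__(self):
            super().__init__()
            self.time_embed = nn.Parameter(torch.zeros(8))
            self.attention = nn.MultiheadAttention(8, 2)
            self.generator = nn.Linear(8, 8)

        def get_parameters(self, time_embed, attention, generator):
            parameters = []
            if time_embed:
                parameters.append(self.time_embed)
            if attention:
                parameters.extend(list(self.attention.parameters()))
            if generator:
                parameters.extend(list(self.generator.parameters()))
            return parameters

    model = Toy()
    # e.g. refine only the time embedding and generator, freezing attention:
    optimizer = torch.optim.Adam(
        model.get_parameters(True, False, True), lr=0.001, weight_decay=1e-5, amsgrad=True
    )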
@@ -199,13 +181,7 @@ class MetaModelV1(super_core.SuperModule):
             timestamp_v_embed,
             mask,
         )
-        relative_timestamps = timestamps - timestamps[:, :1]
-        relative_pos_embeds = self._tscalar_embed(relative_timestamps)
-        init_timestamp_embeds = torch.cat(
-            (timestamp_embeds, relative_pos_embeds), dim=-1
-        )
-        corrected_embeds = self._meta_corrector(init_timestamp_embeds)
-        return corrected_embeds
+        return timestamp_embeds
 
     def forward_raw(self, timestamps, time_embeds, get_seq_last):
         if time_embeds is None:
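With the corrector gone, this method returns the cross-attention output directly instead of concatenating relative positional embeddings and passing the result through _meta_corrector. A minimal sketch of the simplified flow, with assumed shapes and plain nn.MultiheadAttention standing in for _trans_att:

    import torch
    import torch.nn as nn

    batch, seq_length, time_dim = 2, 10, 32  # assumed shapes
    trans_att = nn.MultiheadAttention(time_dim, num_heads=4, batch_first=True)

    timestamp_q_embed = torch.randn(batch, seq_length, time_dim)
    timestamp_k_embed = torch.randn(batch, seq_length, time_dim)
    timestamp_v_embed = torch.randn(batch, seq_length, time_dim)

    # Before: the attention output was concatenated with relative positional
    # embeds and corrected. Now it is returned as-is.
    timestamp_embeds, _ = trans_att(
        timestamp_q_embed, timestamp_k_embed, timestamp_v_embed
    )
    print(timestamp_embeds.shape)  # torch.Size([2, 10, 32])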
@@ -264,9 +240,8 @@ class MetaModelV1(super_core.SuperModule):
         x, y = x.to(self._meta_timestamps.device), y.to(self._meta_timestamps.device)
         with torch.set_grad_enabled(True):
             new_param = self.create_meta_embed()
-            optimizer = torch.optim.Adam(
-                [new_param], lr=lr, weight_decay=1e-5, amsgrad=True
-            )
+
+            optimizer = torch.optim.Adam([new_param], lr=lr, weight_decay=1e-5, amsgrad=True)
             timestamp = torch.Tensor([timestamp]).to(new_param.device)
             self.replace_append_learnt(timestamp, new_param)
             self.train()
@@ -297,7 +272,7 @@ class MetaModelV1(super_core.SuperModule):
         with torch.no_grad():
             self.replace_append_learnt(None, None)
             self.append_fixed(timestamp, best_new_param)
-        return True, meta_loss.item()
+        return True, best_loss
 
     def extra_repr(self) -> str:
         return "(_super_layer_embed): {:}, (_super_meta_embed): {:}, (_meta_timestamps): {:}".format(
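The last two hunks belong to the per-timestamp adaptation step: the Adam construction is collapsed onto one line, and the method now returns the tracked best_loss rather than the final meta_loss. A self-contained sketch of that adapt-and-track-best pattern; the parameter and loss here are placeholders, not the real meta embedding:

    import torch

    # Placeholder stand-ins for the learnable meta embedding and its target.
    new_param = torch.zeros(16, requires_grad=True)
    target = torch.randn(16)

    optimizer = torch.optim.Adam([new_param], lr=0.001, weight_decay=1e-5, amsgrad=True)
    best_loss, best_new_param = None, None
    for _ in range(150):  # refine_epochs in the script above
        optimizer.zero_grad()
        loss = torch.nn.functional.mse_loss(new_param, target)
        loss.backward()
        optimizer.step()
        # Track the best iterate, which the commit now reports instead of
        # the last meta_loss.
        if best_loss is None or loss.item() < best_loss:
            best_loss, best_new_param = loss.item(), new_param.detach().clone()
    print(best_loss, best_new_param.shape)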