From 6c1fd745d78c24a8ae156037873b2d45e40399a6 Mon Sep 17 00:00:00 2001
From: D-X-Y <280835372@qq.com>
Date: Mon, 24 May 2021 13:14:18 +0000
Subject: [PATCH] Remove unnecessary model in GMOA

Drop the _meta_corrector transformer stack (and its mha_depth argument)
from MetaModelV1 and return the attention-derived timestamp embeddings
directly. Rename the matching get_parameters flag from meta_corrector to
attention, update the refine defaults (--refine_lr: 0.002 -> 0.001,
--refine_epochs: 100 -> 150), and return the tracked best_loss instead of
the final meta_loss.

---
 exps/GMOA/lfna.py            |  4 ++--
 exps/GMOA/lfna_meta_model.py | 37 ++++++-------------------------------
 2 files changed, 8 insertions(+), 33 deletions(-)

diff --git a/exps/GMOA/lfna.py b/exps/GMOA/lfna.py
index fe6d883..93ae7d0 100644
--- a/exps/GMOA/lfna.py
+++ b/exps/GMOA/lfna.py
@@ -337,11 +337,11 @@ if __name__ == "__main__":
     parser.add_argument(
         "--refine_lr",
         type=float,
-        default=0.002,
+        default=0.001,
         help="The learning rate for the optimizer, during refine",
     )
     parser.add_argument(
-        "--refine_epochs", type=int, default=100, help="The final refine #epochs."
+        "--refine_epochs", type=int, default=150, help="The final refine #epochs."
     )
     parser.add_argument(
         "--early_stop_thresh",
diff --git a/exps/GMOA/lfna_meta_model.py b/exps/GMOA/lfna_meta_model.py
index dea34b6..10ebff1 100644
--- a/exps/GMOA/lfna_meta_model.py
+++ b/exps/GMOA/lfna_meta_model.py
@@ -19,7 +19,6 @@ class MetaModelV1(super_core.SuperModule):
         layer_dim,
         time_dim,
         meta_timestamps,
-        mha_depth: int = 2,
         dropout: float = 0.1,
         seq_length: int = 10,
         interval: float = None,
@@ -69,22 +68,6 @@ class MetaModelV1(super_core.SuperModule):
             attn_drop=None,
             proj_drop=dropout,
         )
-        layers = []
-        for ilayer in range(mha_depth):
-            layers.append(
-                super_core.SuperTransformerEncoderLayer(
-                    time_dim * 2,
-                    4,
-                    True,
-                    4,
-                    dropout,
-                    norm_affine=False,
-                    order=super_core.LayerOrder.PostNorm,
-                    use_mask=True,
-                )
-            )
-        layers.append(super_core.SuperLinear(time_dim * 2, time_dim))
-        self._meta_corrector = super_core.SuperSequential(*layers)
 
         model_kwargs = dict(
             config=dict(model_type="dual_norm_mlp"),
@@ -103,13 +86,12 @@ class MetaModelV1(super_core.SuperModule):
             std=0.02,
         )
 
-    def get_parameters(self, time_embed, meta_corrector, generator):
+    def get_parameters(self, time_embed, attention, generator):
         parameters = []
         if time_embed:
             parameters.append(self._super_meta_embed)
-        if meta_corrector:
+        if attention:
             parameters.extend(list(self._trans_att.parameters()))
-            parameters.extend(list(self._meta_corrector.parameters()))
         if generator:
             parameters.append(self._super_layer_embed)
             parameters.extend(list(self._generator.parameters()))
@@ -199,13 +181,7 @@ class MetaModelV1(super_core.SuperModule):
             timestamp_v_embed,
             mask,
         )
-        relative_timestamps = timestamps - timestamps[:, :1]
-        relative_pos_embeds = self._tscalar_embed(relative_timestamps)
-        init_timestamp_embeds = torch.cat(
-            (timestamp_embeds, relative_pos_embeds), dim=-1
-        )
-        corrected_embeds = self._meta_corrector(init_timestamp_embeds)
-        return corrected_embeds
+        return timestamp_embeds
 
     def forward_raw(self, timestamps, time_embeds, get_seq_last):
         if time_embeds is None:
@@ -264,9 +240,8 @@ class MetaModelV1(super_core.SuperModule):
         x, y = x.to(self._meta_timestamps.device), y.to(self._meta_timestamps.device)
         with torch.set_grad_enabled(True):
             new_param = self.create_meta_embed()
-            optimizer = torch.optim.Adam(
-                [new_param], lr=lr, weight_decay=1e-5, amsgrad=True
-            )
+
+            optimizer = torch.optim.Adam([new_param], lr=lr, weight_decay=1e-5, amsgrad=True)
             timestamp = torch.Tensor([timestamp]).to(new_param.device)
             self.replace_append_learnt(timestamp, new_param)
             self.train()
@@ -297,7 +272,7 @@ class MetaModelV1(super_core.SuperModule):
         with torch.no_grad():
             self.replace_append_learnt(None, None)
             self.append_fixed(timestamp, best_new_param)
-        return True, meta_loss.item()
+        return True, best_loss
 
     def extra_repr(self) -> str:
         return "(_super_layer_embed): {:}, (_super_meta_embed): {:}, (_meta_timestamps): {:}".format(
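
Note: with the _meta_corrector removed, the meta time-embedding path is a
single cross-attention step whose output is returned as-is. Below is a
minimal stand-alone sketch of that pattern in plain PyTorch; ToyTimeEmbed,
the 4-head setting, and the shapes are illustrative assumptions, not the
repo's super_core API.

import torch
import torch.nn as nn


class ToyTimeEmbed(nn.Module):
    """Hypothetical stand-in for the simplified MetaModelV1 embedding path."""

    def __init__(self, time_dim: int, num_meta: int):
        super().__init__()
        # Learnable per-timestamp meta embeddings (loosely, _super_meta_embed).
        self._meta_embed = nn.Parameter(torch.randn(num_meta, time_dim) * 0.02)
        # time_dim must be divisible by the head count.
        self._attn = nn.MultiheadAttention(time_dim, num_heads=4, batch_first=True)

    def forward(self, query_embeds: torch.Tensor) -> torch.Tensor:
        # query_embeds: (batch, seq, time_dim) encodings of the query timestamps.
        keys = self._meta_embed.unsqueeze(0).expand(query_embeds.size(0), -1, -1)
        out, _ = self._attn(query_embeds, keys, keys)
        # Returned directly, mirroring "return timestamp_embeds"; there is no
        # corrector stack post-processing the attention output anymore.
        return out

For example, ToyTimeEmbed(time_dim=64, num_meta=100)(torch.randn(2, 10, 64))
produces a (2, 10, 64) tensor.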