Remove unnecessary model in GMOA

D-X-Y 2021-05-24 13:14:18 +00:00
parent 7787c1b3c7
commit 6c1fd745d7
2 changed files with 8 additions and 33 deletions

@@ -337,11 +337,11 @@ if __name__ == "__main__":
     parser.add_argument(
         "--refine_lr",
         type=float,
-        default=0.002,
+        default=0.001,
         help="The learning rate for the optimizer, during refine",
     )
     parser.add_argument(
-        "--refine_epochs", type=int, default=100, help="The final refine #epochs."
+        "--refine_epochs", type=int, default=150, help="The final refine #epochs."
     )
     parser.add_argument(
         "--early_stop_thresh",

@@ -19,7 +19,6 @@ class MetaModelV1(super_core.SuperModule):
         layer_dim,
         time_dim,
         meta_timestamps,
-        mha_depth: int = 2,
         dropout: float = 0.1,
         seq_length: int = 10,
         interval: float = None,
@@ -69,22 +68,6 @@ class MetaModelV1(super_core.SuperModule):
             attn_drop=None,
             proj_drop=dropout,
         )
-        layers = []
-        for ilayer in range(mha_depth):
-            layers.append(
-                super_core.SuperTransformerEncoderLayer(
-                    time_dim * 2,
-                    4,
-                    True,
-                    4,
-                    dropout,
-                    norm_affine=False,
-                    order=super_core.LayerOrder.PostNorm,
-                    use_mask=True,
-                )
-            )
-        layers.append(super_core.SuperLinear(time_dim * 2, time_dim))
-        self._meta_corrector = super_core.SuperSequential(*layers)
         model_kwargs = dict(
             config=dict(model_type="dual_norm_mlp"),
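
Note: the deleted `_meta_corrector` stacked `mha_depth` post-norm transformer encoder layers over `time_dim * 2`-wide embeddings, then projected back to `time_dim`. A rough stand-alone equivalent in plain PyTorch, where `nn.TransformerEncoderLayer` stands in for the repo's `SuperTransformerEncoderLayer` and the mapping of the positional `4, True, 4` arguments to heads/bias/MLP-ratio is an assumption:

import torch
import torch.nn as nn

time_dim, mha_depth, dropout = 32, 2, 0.1
layers = []
for _ in range(mha_depth):
    layers.append(
        nn.TransformerEncoderLayer(
            d_model=time_dim * 2,              # width of the concatenated time embedding
            nhead=4,                           # assumed: the first positional `4`
            dim_feedforward=time_dim * 2 * 4,  # assumed MLP expansion ratio of 4
            dropout=dropout,
            norm_first=False,                  # post-norm, matching LayerOrder.PostNorm
            batch_first=True,
        )
    )
layers.append(nn.Linear(time_dim * 2, time_dim))  # project back to time_dim
meta_corrector = nn.Sequential(*layers)
out = meta_corrector(torch.randn(1, 10, time_dim * 2))  # -> shape [1, 10, time_dim]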
@@ -103,13 +86,12 @@ class MetaModelV1(super_core.SuperModule):
             std=0.02,
         )

-    def get_parameters(self, time_embed, meta_corrector, generator):
+    def get_parameters(self, time_embed, attention, generator):
         parameters = []
         if time_embed:
             parameters.append(self._super_meta_embed)
-        if meta_corrector:
+        if attention:
             parameters.extend(list(self._trans_att.parameters()))
-            parameters.extend(list(self._meta_corrector.parameters()))
         if generator:
             parameters.append(self._super_layer_embed)
             parameters.extend(list(self._generator.parameters()))
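
Note: with `_meta_corrector` gone, the second flag now only gates the `_trans_att` attention parameters, hence the rename. An illustrative call-site sketch (here `model` is an already-constructed MetaModelV1, which this snippet does not build):

import torch

# Select which parameter groups to adapt; e.g. freeze the generator and
# optimize only the time embeddings plus the attention module.
params = model.get_parameters(time_embed=True, attention=True, generator=False)
optimizer = torch.optim.Adam(params, lr=0.001, weight_decay=1e-5)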
@@ -199,13 +181,7 @@ class MetaModelV1(super_core.SuperModule):
             timestamp_v_embed,
             mask,
         )
-        relative_timestamps = timestamps - timestamps[:, :1]
-        relative_pos_embeds = self._tscalar_embed(relative_timestamps)
-        init_timestamp_embeds = torch.cat(
-            (timestamp_embeds, relative_pos_embeds), dim=-1
-        )
-        corrected_embeds = self._meta_corrector(init_timestamp_embeds)
-        return corrected_embeds
+        return timestamp_embeds

     def forward_raw(self, timestamps, time_embeds, get_seq_last):
         if time_embeds is None:
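
Note: the removed branch augmented each absolute-time embedding with an embedding of its offset from the window's first timestamp before correction. A tiny sketch of just that offset computation (shape assumed: `timestamps` is `[batch, seq]`):

import torch

timestamps = torch.tensor([[3.0, 4.0, 5.0]])  # [batch, seq]
relative = timestamps - timestamps[:, :1]     # offsets from the window start
print(relative)                               # tensor([[0., 1., 2.]])
# These offsets were embedded, concatenated onto timestamp_embeds, and fed
# through the now-removed _meta_corrector; the model now returns the raw
# attention output directly.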
@@ -264,9 +240,8 @@ class MetaModelV1(super_core.SuperModule):
         x, y = x.to(self._meta_timestamps.device), y.to(self._meta_timestamps.device)
         with torch.set_grad_enabled(True):
             new_param = self.create_meta_embed()
-            optimizer = torch.optim.Adam(
-                [new_param], lr=lr, weight_decay=1e-5, amsgrad=True
-            )
+            optimizer = torch.optim.Adam([new_param], lr=lr, weight_decay=1e-5, amsgrad=True)
             timestamp = torch.Tensor([timestamp]).to(new_param.device)
             self.replace_append_learnt(timestamp, new_param)
             self.train()
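
Note: this block adapts a single new time embedding with Adam. A condensed sketch of the adapt-and-track pattern, with a placeholder objective standing in for the real prediction loss on (x, y), and the `best_loss` bookkeeping assumed from the surrounding code:

import torch

new_param = torch.zeros(16, requires_grad=True)  # stand-in for create_meta_embed()
optimizer = torch.optim.Adam([new_param], lr=0.001, weight_decay=1e-5, amsgrad=True)
best_loss, best_new_param = float("inf"), None
for step in range(150):                 # cf. --refine_epochs
    optimizer.zero_grad()
    loss = (new_param ** 2).sum()       # placeholder; the real loss scores (x, y)
    loss.backward()
    optimizer.step()
    if loss.item() < best_loss:         # keep the best iterate, not the last one
        best_loss = loss.item()
        best_new_param = new_param.detach().clone()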
@@ -297,7 +272,7 @@ class MetaModelV1(super_core.SuperModule):
         with torch.no_grad():
             self.replace_append_learnt(None, None)
             self.append_fixed(timestamp, best_new_param)
-        return True, meta_loss.item()
+        return True, best_loss

     def extra_repr(self) -> str:
         return "(_super_layer_embed): {:}, (_super_meta_embed): {:}, (_meta_timestamps): {:}".format(