Remove unnecessary model in GMOA

D-X-Y 2021-05-24 13:14:18 +00:00
parent 7787c1b3c7
commit 6c1fd745d7
2 changed files with 8 additions and 33 deletions

View File

@@ -337,11 +337,11 @@ if __name__ == "__main__":
     parser.add_argument(
         "--refine_lr",
         type=float,
-        default=0.002,
+        default=0.001,
         help="The learning rate for the optimizer, during refine",
     )
     parser.add_argument(
-        "--refine_epochs", type=int, default=100, help="The final refine #epochs."
+        "--refine_epochs", type=int, default=150, help="The final refine #epochs."
     )
     parser.add_argument(
         "--early_stop_thresh",

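The new defaults lower the refine learning rate from 0.002 to 0.001 and lengthen the refine schedule from 100 to 150 epochs. As a hedged, self-contained sketch of how such flags are typically consumed (the model, data, and loss below are illustrative placeholders, not code from this repository):

import argparse

import torch

parser = argparse.ArgumentParser()
parser.add_argument(
    "--refine_lr",
    type=float,
    default=0.001,
    help="The learning rate for the optimizer, during refine",
)
parser.add_argument(
    "--refine_epochs", type=int, default=150, help="The final refine #epochs."
)
args = parser.parse_args([])  # take the defaults, for illustration

model = torch.nn.Linear(4, 1)  # placeholder model
optimizer = torch.optim.Adam(model.parameters(), lr=args.refine_lr)
x, y = torch.randn(8, 4), torch.randn(8, 1)
for _epoch in range(args.refine_epochs):
    optimizer.zero_grad()
    loss = torch.nn.functional.mse_loss(model(x), y)
    loss.backward()
    optimizer.step()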
View File

@@ -19,7 +19,6 @@ class MetaModelV1(super_core.SuperModule):
         layer_dim,
         time_dim,
         meta_timestamps,
-        mha_depth: int = 2,
         dropout: float = 0.1,
         seq_length: int = 10,
         interval: float = None,
@@ -69,22 +68,6 @@ class MetaModelV1(super_core.SuperModule):
             attn_drop=None,
             proj_drop=dropout,
         )
-        layers = []
-        for ilayer in range(mha_depth):
-            layers.append(
-                super_core.SuperTransformerEncoderLayer(
-                    time_dim * 2,
-                    4,
-                    True,
-                    4,
-                    dropout,
-                    norm_affine=False,
-                    order=super_core.LayerOrder.PostNorm,
-                    use_mask=True,
-                )
-            )
-        layers.append(super_core.SuperLinear(time_dim * 2, time_dim))
-        self._meta_corrector = super_core.SuperSequential(*layers)
 
         model_kwargs = dict(
             config=dict(model_type="dual_norm_mlp"),
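For orientation, the deleted _meta_corrector was a stack of mha_depth post-norm transformer encoder layers operating on 2 * time_dim features, capped by a linear projection back to time_dim. A rough vanilla-PyTorch analogue (a hedged sketch: build_meta_corrector is a hypothetical helper, and super_core.SuperTransformerEncoderLayer's real signature, post-norm ordering, and mask handling are not reproduced):

import torch
import torch.nn as nn

# Sketch of the removed corrector: mha_depth encoder layers over
# 2 * time_dim features with 4 heads, then a projection back to time_dim.
def build_meta_corrector(time_dim, mha_depth=2, dropout=0.1):
    layers = []
    for _ in range(mha_depth):
        layers.append(
            nn.TransformerEncoderLayer(
                d_model=time_dim * 2,
                nhead=4,
                dim_feedforward=time_dim * 2 * 4,
                dropout=dropout,
                batch_first=True,
            )
        )
    layers.append(nn.Linear(time_dim * 2, time_dim))
    return nn.Sequential(*layers)

corrector = build_meta_corrector(time_dim=32)
out = corrector(torch.randn(2, 10, 64))  # (batch, seq, 2 * time_dim) -> (2, 10, 32)

Removing this stack (and the mha_depth argument that sized it) leaves the QKV attention output as the final timestamp embedding.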
@@ -103,13 +86,12 @@ class MetaModelV1(super_core.SuperModule):
             std=0.02,
         )
 
-    def get_parameters(self, time_embed, meta_corrector, generator):
+    def get_parameters(self, time_embed, attention, generator):
         parameters = []
         if time_embed:
             parameters.append(self._super_meta_embed)
-        if meta_corrector:
+        if attention:
             parameters.extend(list(self._trans_att.parameters()))
-            parameters.extend(list(self._meta_corrector.parameters()))
         if generator:
             parameters.append(self._super_layer_embed)
             parameters.extend(list(self._generator.parameters()))
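After the rename, the second switch matches what it now controls: only the attention block's parameters. A minimal self-contained sketch of the selection pattern (ToyMeta is a stand-in that mirrors the method above, not the real MetaModelV1):

import torch
import torch.nn as nn

class ToyMeta(nn.Module):
    def __init__(self):
        super().__init__()
        self._super_meta_embed = nn.Parameter(torch.zeros(4, 8))
        self._trans_att = nn.Linear(8, 8)
        self._super_layer_embed = nn.Parameter(torch.zeros(4, 8))
        self._generator = nn.Linear(8, 8)

    # Same contract as above: boolean switches pick parameter groups.
    def get_parameters(self, time_embed, attention, generator):
        parameters = []
        if time_embed:
            parameters.append(self._super_meta_embed)
        if attention:
            parameters.extend(list(self._trans_att.parameters()))
        if generator:
            parameters.append(self._super_layer_embed)
            parameters.extend(list(self._generator.parameters()))
        return parameters

model = ToyMeta()
# e.g., refine only the attention and generator parts:
optimizer = torch.optim.Adam(model.get_parameters(False, True, True), lr=0.001)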
@@ -199,13 +181,7 @@ class MetaModelV1(super_core.SuperModule):
             timestamp_v_embed,
             mask,
         )
-        relative_timestamps = timestamps - timestamps[:, :1]
-        relative_pos_embeds = self._tscalar_embed(relative_timestamps)
-        init_timestamp_embeds = torch.cat(
-            (timestamp_embeds, relative_pos_embeds), dim=-1
-        )
-        corrected_embeds = self._meta_corrector(init_timestamp_embeds)
-        return corrected_embeds
+        return timestamp_embeds
 
     def forward_raw(self, timestamps, time_embeds, get_seq_last):
         if time_embeds is None:
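The attention output is now returned as-is; the removed tail computed per-sequence relative timestamps, embedded them with _tscalar_embed, concatenated, and ran the result through the now-deleted corrector. For reference, the relative-timestamp step it drops is just a broadcasted subtraction:

import torch

# timestamps has shape (batch, seq); subtracting the first column gives
# offsets relative to each sequence's first timestamp.
timestamps = torch.tensor([[0.0, 1.0, 2.0], [5.0, 6.0, 7.0]])
relative_timestamps = timestamps - timestamps[:, :1]
print(relative_timestamps)  # tensor([[0., 1., 2.], [0., 1., 2.]])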
@@ -264,9 +240,8 @@ class MetaModelV1(super_core.SuperModule):
         x, y = x.to(self._meta_timestamps.device), y.to(self._meta_timestamps.device)
         with torch.set_grad_enabled(True):
             new_param = self.create_meta_embed()
-            optimizer = torch.optim.Adam(
-                [new_param], lr=lr, weight_decay=1e-5, amsgrad=True
-            )
+            optimizer = torch.optim.Adam([new_param], lr=lr, weight_decay=1e-5, amsgrad=True)
             timestamp = torch.Tensor([timestamp]).to(new_param.device)
             self.replace_append_learnt(timestamp, new_param)
             self.train()
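The optimizer change in this hunk is purely cosmetic (one line instead of three). For context, a hedged, self-contained sketch of the adaptation pattern visible around it: optimize a single fresh embedding with Adam (amsgrad, small weight decay) while keeping the best iterate. The target and loss are stand-ins; the real code trains new_param through the meta-model on the (x, y) support data:

import torch

new_param = torch.nn.Parameter(torch.randn(16))
optimizer = torch.optim.Adam([new_param], lr=0.002, weight_decay=1e-5, amsgrad=True)

target = torch.zeros(16)  # stand-in objective
best_loss, best_new_param = float("inf"), None
for _step in range(100):
    optimizer.zero_grad()
    loss = torch.nn.functional.mse_loss(new_param, target)
    loss.backward()
    optimizer.step()
    if loss.item() < best_loss:  # track the best iterate, as the real loop does
        best_loss = loss.item()
        best_new_param = new_param.detach().clone()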
@@ -297,7 +272,7 @@ class MetaModelV1(super_core.SuperModule):
         with torch.no_grad():
             self.replace_append_learnt(None, None)
             self.append_fixed(timestamp, best_new_param)
-        return True, meta_loss.item()
+        return True, best_loss
 
     def extra_repr(self) -> str:
         return "(_super_layer_embed): {:}, (_super_meta_embed): {:}, (_meta_timestamps): {:}".format(
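Returning best_loss, the Python float maintained by that tracking, instead of meta_loss.item() reports the best loss seen during adaptation rather than whatever the final step happened to produce, and it no longer references the loop-local meta_loss tensor after the loop. A toy illustration of the difference:

# per-step losses (made-up numbers): the last step need not be the best
losses = [0.9, 0.4, 0.6]
best_loss = min(losses)  # what the function now returns: 0.4
last_loss = losses[-1]   # roughly what meta_loss.item() reflected: 0.6
assert best_loss <= last_loss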