Remove unnecessary model in GMOA
This commit is contained in:
parent
7787c1b3c7
commit
6c1fd745d7
@ -337,11 +337,11 @@ if __name__ == "__main__":
|
||||
parser.add_argument(
|
||||
"--refine_lr",
|
||||
type=float,
|
||||
default=0.002,
|
||||
default=0.001,
|
||||
help="The learning rate for the optimizer, during refine",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--refine_epochs", type=int, default=100, help="The final refine #epochs."
|
||||
"--refine_epochs", type=int, default=150, help="The final refine #epochs."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--early_stop_thresh",
|
||||
|
@ -19,7 +19,6 @@ class MetaModelV1(super_core.SuperModule):
|
||||
layer_dim,
|
||||
time_dim,
|
||||
meta_timestamps,
|
||||
mha_depth: int = 2,
|
||||
dropout: float = 0.1,
|
||||
seq_length: int = 10,
|
||||
interval: float = None,
|
||||
@ -69,22 +68,6 @@ class MetaModelV1(super_core.SuperModule):
|
||||
attn_drop=None,
|
||||
proj_drop=dropout,
|
||||
)
|
||||
layers = []
|
||||
for ilayer in range(mha_depth):
|
||||
layers.append(
|
||||
super_core.SuperTransformerEncoderLayer(
|
||||
time_dim * 2,
|
||||
4,
|
||||
True,
|
||||
4,
|
||||
dropout,
|
||||
norm_affine=False,
|
||||
order=super_core.LayerOrder.PostNorm,
|
||||
use_mask=True,
|
||||
)
|
||||
)
|
||||
layers.append(super_core.SuperLinear(time_dim * 2, time_dim))
|
||||
self._meta_corrector = super_core.SuperSequential(*layers)
|
||||
|
||||
model_kwargs = dict(
|
||||
config=dict(model_type="dual_norm_mlp"),
|
||||
@ -103,13 +86,12 @@ class MetaModelV1(super_core.SuperModule):
|
||||
std=0.02,
|
||||
)
|
||||
|
||||
def get_parameters(self, time_embed, meta_corrector, generator):
|
||||
def get_parameters(self, time_embed, attention, generator):
|
||||
parameters = []
|
||||
if time_embed:
|
||||
parameters.append(self._super_meta_embed)
|
||||
if meta_corrector:
|
||||
if attention:
|
||||
parameters.extend(list(self._trans_att.parameters()))
|
||||
parameters.extend(list(self._meta_corrector.parameters()))
|
||||
if generator:
|
||||
parameters.append(self._super_layer_embed)
|
||||
parameters.extend(list(self._generator.parameters()))
|
||||
@ -199,13 +181,7 @@ class MetaModelV1(super_core.SuperModule):
|
||||
timestamp_v_embed,
|
||||
mask,
|
||||
)
|
||||
relative_timestamps = timestamps - timestamps[:, :1]
|
||||
relative_pos_embeds = self._tscalar_embed(relative_timestamps)
|
||||
init_timestamp_embeds = torch.cat(
|
||||
(timestamp_embeds, relative_pos_embeds), dim=-1
|
||||
)
|
||||
corrected_embeds = self._meta_corrector(init_timestamp_embeds)
|
||||
return corrected_embeds
|
||||
return timestamp_embeds
|
||||
|
||||
def forward_raw(self, timestamps, time_embeds, get_seq_last):
|
||||
if time_embeds is None:
|
||||
@ -264,9 +240,8 @@ class MetaModelV1(super_core.SuperModule):
|
||||
x, y = x.to(self._meta_timestamps.device), y.to(self._meta_timestamps.device)
|
||||
with torch.set_grad_enabled(True):
|
||||
new_param = self.create_meta_embed()
|
||||
optimizer = torch.optim.Adam(
|
||||
[new_param], lr=lr, weight_decay=1e-5, amsgrad=True
|
||||
)
|
||||
|
||||
optimizer = torch.optim.Adam([new_param], lr=lr, weight_decay=1e-5, amsgrad=True)
|
||||
timestamp = torch.Tensor([timestamp]).to(new_param.device)
|
||||
self.replace_append_learnt(timestamp, new_param)
|
||||
self.train()
|
||||
@ -297,7 +272,7 @@ class MetaModelV1(super_core.SuperModule):
|
||||
with torch.no_grad():
|
||||
self.replace_append_learnt(None, None)
|
||||
self.append_fixed(timestamp, best_new_param)
|
||||
return True, meta_loss.item()
|
||||
return True, best_loss
|
||||
|
||||
def extra_repr(self) -> str:
|
||||
return "(_super_layer_embed): {:}, (_super_meta_embed): {:}, (_meta_timestamps): {:}".format(
|
||||
|
Loading…
Reference in New Issue
Block a user