Remove unnecessary model in GMOA
This commit is contained in:
Parent: 7787c1b3c7
Commit: 6c1fd745d7
@ -337,11 +337,11 @@ if __name__ == "__main__":
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--refine_lr",
|
"--refine_lr",
|
||||||
type=float,
|
type=float,
|
||||||
default=0.002,
|
default=0.001,
|
||||||
help="The learning rate for the optimizer, during refine",
|
help="The learning rate for the optimizer, during refine",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--refine_epochs", type=int, default=100, help="The final refine #epochs."
|
"--refine_epochs", type=int, default=150, help="The final refine #epochs."
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--early_stop_thresh",
|
"--early_stop_thresh",
|
||||||
|
@ -19,7 +19,6 @@ class MetaModelV1(super_core.SuperModule):
|
|||||||
layer_dim,
|
layer_dim,
|
||||||
time_dim,
|
time_dim,
|
||||||
meta_timestamps,
|
meta_timestamps,
|
||||||
mha_depth: int = 2,
|
|
||||||
dropout: float = 0.1,
|
dropout: float = 0.1,
|
||||||
seq_length: int = 10,
|
seq_length: int = 10,
|
||||||
interval: float = None,
|
interval: float = None,
|
||||||
@ -69,22 +68,6 @@ class MetaModelV1(super_core.SuperModule):
|
|||||||
attn_drop=None,
|
attn_drop=None,
|
||||||
proj_drop=dropout,
|
proj_drop=dropout,
|
||||||
)
|
)
|
||||||
layers = []
|
|
||||||
for ilayer in range(mha_depth):
|
|
||||||
layers.append(
|
|
||||||
super_core.SuperTransformerEncoderLayer(
|
|
||||||
time_dim * 2,
|
|
||||||
4,
|
|
||||||
True,
|
|
||||||
4,
|
|
||||||
dropout,
|
|
||||||
norm_affine=False,
|
|
||||||
order=super_core.LayerOrder.PostNorm,
|
|
||||||
use_mask=True,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
layers.append(super_core.SuperLinear(time_dim * 2, time_dim))
|
|
||||||
self._meta_corrector = super_core.SuperSequential(*layers)
|
|
||||||
|
|
||||||
model_kwargs = dict(
|
model_kwargs = dict(
|
||||||
config=dict(model_type="dual_norm_mlp"),
|
config=dict(model_type="dual_norm_mlp"),
|
||||||
@ -103,13 +86,12 @@ class MetaModelV1(super_core.SuperModule):
|
|||||||
std=0.02,
|
std=0.02,
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_parameters(self, time_embed, meta_corrector, generator):
|
def get_parameters(self, time_embed, attention, generator):
|
||||||
parameters = []
|
parameters = []
|
||||||
if time_embed:
|
if time_embed:
|
||||||
parameters.append(self._super_meta_embed)
|
parameters.append(self._super_meta_embed)
|
||||||
if meta_corrector:
|
if attention:
|
||||||
parameters.extend(list(self._trans_att.parameters()))
|
parameters.extend(list(self._trans_att.parameters()))
|
||||||
parameters.extend(list(self._meta_corrector.parameters()))
|
|
||||||
if generator:
|
if generator:
|
||||||
parameters.append(self._super_layer_embed)
|
parameters.append(self._super_layer_embed)
|
||||||
parameters.extend(list(self._generator.parameters()))
|
parameters.extend(list(self._generator.parameters()))
|
||||||
@ -199,13 +181,7 @@ class MetaModelV1(super_core.SuperModule):
|
|||||||
timestamp_v_embed,
|
timestamp_v_embed,
|
||||||
mask,
|
mask,
|
||||||
)
|
)
|
||||||
relative_timestamps = timestamps - timestamps[:, :1]
|
return timestamp_embeds
|
||||||
relative_pos_embeds = self._tscalar_embed(relative_timestamps)
|
|
||||||
init_timestamp_embeds = torch.cat(
|
|
||||||
(timestamp_embeds, relative_pos_embeds), dim=-1
|
|
||||||
)
|
|
||||||
corrected_embeds = self._meta_corrector(init_timestamp_embeds)
|
|
||||||
return corrected_embeds
|
|
||||||
|
|
||||||
def forward_raw(self, timestamps, time_embeds, get_seq_last):
|
def forward_raw(self, timestamps, time_embeds, get_seq_last):
|
||||||
if time_embeds is None:
|
if time_embeds is None:
|
||||||
@ -264,9 +240,8 @@ class MetaModelV1(super_core.SuperModule):
|
|||||||
x, y = x.to(self._meta_timestamps.device), y.to(self._meta_timestamps.device)
|
x, y = x.to(self._meta_timestamps.device), y.to(self._meta_timestamps.device)
|
||||||
with torch.set_grad_enabled(True):
|
with torch.set_grad_enabled(True):
|
||||||
new_param = self.create_meta_embed()
|
new_param = self.create_meta_embed()
|
||||||
optimizer = torch.optim.Adam(
|
|
||||||
[new_param], lr=lr, weight_decay=1e-5, amsgrad=True
|
optimizer = torch.optim.Adam([new_param], lr=lr, weight_decay=1e-5, amsgrad=True)
|
||||||
)
|
|
||||||
timestamp = torch.Tensor([timestamp]).to(new_param.device)
|
timestamp = torch.Tensor([timestamp]).to(new_param.device)
|
||||||
self.replace_append_learnt(timestamp, new_param)
|
self.replace_append_learnt(timestamp, new_param)
|
||||||
self.train()
|
self.train()
|
||||||
@ -297,7 +272,7 @@ class MetaModelV1(super_core.SuperModule):
|
|||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
self.replace_append_learnt(None, None)
|
self.replace_append_learnt(None, None)
|
||||||
self.append_fixed(timestamp, best_new_param)
|
self.append_fixed(timestamp, best_new_param)
|
||||||
return True, meta_loss.item()
|
return True, best_loss
|
||||||
|
|
||||||
def extra_repr(self) -> str:
|
def extra_repr(self) -> str:
|
||||||
return "(_super_layer_embed): {:}, (_super_meta_embed): {:}, (_meta_timestamps): {:}".format(
|
return "(_super_layer_embed): {:}, (_super_meta_embed): {:}, (_meta_timestamps): {:}".format(
|
||||||
|
Loading…
Reference in New Issue
Block a user