Update LFNA

This commit is contained in:
D-X-Y 2021-05-22 11:02:29 +00:00
parent ec241e4d69
commit 5b09f059fd
4 changed files with 46 additions and 19 deletions

View File

@ -94,8 +94,10 @@ def epoch_evaluate(loader, meta_model, base_model, criterion, device, logger):
def pretrain_v2(base_model, meta_model, criterion, xenv, args, logger):
base_model.train()
meta_model.train()
optimizer = torch.optim.Adam(
meta_model.parameters(),
meta_model.get_parameters(True, True, True),
lr=args.lr,
weight_decay=args.weight_decay,
amsgrad=True,
@ -103,13 +105,16 @@ def pretrain_v2(base_model, meta_model, criterion, xenv, args, logger):
logger.log("Pre-train the meta-model")
logger.log("Using the optimizer: {:}".format(optimizer))
meta_model.set_best_dir(logger.path(None) / "checkpoint-pretrain")
meta_model.set_best_dir(logger.path(None) / "ckps-pretrain-v2")
meta_model.set_best_name("pretrain-{:}.pth".format(args.rand_seed))
last_success_epoch = 0
per_epoch_time, start_time = AverageMeter(), time.time()
for iepoch in range(args.epochs):
left_time = "Time Left: {:}".format(
convert_secs2time(per_epoch_time.avg * (args.epochs - iepoch), True)
)
total_meta_losses, total_match_losses = [], []
optimizer.zero_grad()
for ibatch in range(args.meta_batch):
rand_index = random.randint(0, meta_model.meta_length - xenv.seq_length - 1)
timestamps = meta_model.meta_timestamps[
@ -118,7 +123,7 @@ def pretrain_v2(base_model, meta_model, criterion, xenv, args, logger):
seq_timestamps, (seq_inputs, seq_targets) = xenv.seq_call(timestamps)
[seq_containers], time_embeds = meta_model(
torch.unsqueeze(timestamps, dim=0)
torch.unsqueeze(timestamps, dim=0), None
)
# performance loss
losses = []
@ -136,10 +141,10 @@ def pretrain_v2(base_model, meta_model, criterion, xenv, args, logger):
torch.squeeze(time_embeds, dim=0),
meta_model.super_meta_embed[rand_index : rand_index + xenv.seq_length],
)
# batch_loss = meta_loss + match_loss * 0.1
# total_losses.append(batch_loss)
total_meta_losses.append(meta_loss)
total_match_losses.append(match_loss)
with torch.no_grad():
meta_std = torch.stack(total_meta_losses).std().item()
final_meta_loss = torch.stack(total_meta_losses).mean()
final_match_loss = torch.stack(total_match_losses).mean()
total_loss = final_meta_loss + final_match_loss
@ -148,11 +153,12 @@ def pretrain_v2(base_model, meta_model, criterion, xenv, args, logger):
# success
success, best_score = meta_model.save_best(-total_loss.item())
logger.log(
"{:} [{:04d}/{:}] loss : {:.5f} = {:.5f} + {:.5f} (match)".format(
"{:} [Pre-V2 {:04d}/{:}] loss : {:.5f} +- {:.5f} = {:.5f} + {:.5f} (match)".format(
time_string(),
iepoch,
args.epochs,
total_loss.item(),
meta_std,
final_meta_loss.item(),
final_match_loss.item(),
)
@ -160,11 +166,15 @@ def pretrain_v2(base_model, meta_model, criterion, xenv, args, logger):
+ ", success={:}, best_score={:.4f}".format(success, -best_score)
+ " {:}".format(left_time)
)
if iepoch - last_success_epoch >= args.early_stop_thresh * 5:
logger.log("Early stop the pre-training at {:}".format(iepoch))
break
per_epoch_time.update(time.time() - start_time)
start_time = time.time()
meta_model.load_best()
def pretrain(base_model, meta_model, criterion, xenv, args, logger):
def pretrain_v1(base_model, meta_model, criterion, xenv, args, logger):
base_model.train()
meta_model.train()
optimizer = torch.optim.Adam(
@ -173,12 +183,13 @@ def pretrain(base_model, meta_model, criterion, xenv, args, logger):
weight_decay=args.weight_decay,
amsgrad=True,
)
logger.log("Pre-train the meta-model")
logger.log("Pre-train the meta-model's embeddings")
logger.log("Using the optimizer: {:}".format(optimizer))
meta_model.set_best_dir(logger.path(None) / "ckps-basic-pretrain")
meta_model.set_best_dir(logger.path(None) / "ckps-pretrain-v1")
meta_model.set_best_name("pretrain-{:}.pth".format(args.rand_seed))
per_epoch_time, start_time = AverageMeter(), time.time()
last_success_epoch = 0
for iepoch in range(args.epochs):
left_time = "Time Left: {:}".format(
convert_secs2time(per_epoch_time.avg * (args.epochs - iepoch), True)
@ -213,7 +224,7 @@ def pretrain(base_model, meta_model, criterion, xenv, args, logger):
# success
success, best_score = meta_model.save_best(-final_loss.item())
logger.log(
"{:} [{:04d}/{:}] loss : {:.5f}".format(
"{:} [Pre-V1 {:04d}/{:}] loss : {:.5f}".format(
time_string(),
iepoch,
args.epochs,
@ -223,8 +234,12 @@ def pretrain(base_model, meta_model, criterion, xenv, args, logger):
+ ", success={:}, best_score={:.4f}".format(success, -best_score)
+ " {:}".format(left_time)
)
if iepoch - last_success_epoch >= args.early_stop_thresh * 5:
logger.log("Early stop the pre-training at {:}".format(iepoch))
break
per_epoch_time.update(time.time() - start_time)
start_time = time.time()
meta_model.load_best()
def main(args):
@ -282,7 +297,7 @@ def main(args):
logger.log("The scheduler is\n{:}".format(lr_scheduler))
logger.log("Per epoch iterations = {:}".format(len(train_env_loader)))
pretrain(base_model, meta_model, criterion, train_env, args, logger)
pretrain_v2(base_model, meta_model, criterion, train_env, args, logger)
if logger.path("model").exists():
ckp_data = torch.load(logger.path("model"))

View File

@ -20,7 +20,7 @@ class LFNA_Meta(super_core.SuperModule):
layer_embedding,
time_embedding,
meta_timestamps,
mha_depth: int = 1,
mha_depth: int = 2,
dropout: float = 0.1,
):
super(LFNA_Meta, self).__init__()
@ -73,7 +73,7 @@ class LFNA_Meta(super_core.SuperModule):
)
)
layers.append(super_core.SuperLinear(time_embedding * 2, time_embedding))
self.meta_corrector = super_core.SuperSequential(*layers)
self._meta_corrector = super_core.SuperSequential(*layers)
model_kwargs = dict(
config=dict(model_type="dual_norm_mlp"),
@ -92,6 +92,18 @@ class LFNA_Meta(super_core.SuperModule):
std=0.02,
)
def get_parameters(self, time_embed, meta_corrector, generator):
parameters = []
if time_embed:
parameters.append(self._super_meta_embed)
if meta_corrector:
parameters.extend(list(self._trans_att.parameters()))
parameters.extend(list(self._meta_corrector.parameters()))
if generator:
parameters.append(self._super_layer_embed)
parameters.extend(list(self._generator.parameters()))
return parameters
@property
def meta_timestamps(self):
with torch.no_grad():
@ -159,7 +171,7 @@ class LFNA_Meta(super_core.SuperModule):
# relative_timestamps = timestamps - timestamps[:, :1]
# relative_pos_embeds = self._tscalar_embed(relative_timestamps)
init_timestamp_embeds = torch.cat((timestamp_q_embed, timestamp_embeds), dim=-1)
corrected_embeds = self.meta_corrector(init_timestamp_embeds)
corrected_embeds = self._meta_corrector(init_timestamp_embeds)
return corrected_embeds
def forward_raw(self, timestamps, time_embed):

View File

@ -14,14 +14,14 @@ from xautodl import spaces
from xautodl.xlayers import super_core
class TestSuperAttention(unittest.TestCase):
class TestSuperSelfAttention(unittest.TestCase):
"""Test the super attention layer."""
def _internal_func(self, inputs, model):
outputs = model(inputs)
abstract_space = model.abstract_search_space
print(
"The abstract search space for SuperAttention is:\n{:}".format(
"The abstract search space for SuperSelfAttention is:\n{:}".format(
abstract_space
)
)
@ -36,7 +36,7 @@ class TestSuperAttention(unittest.TestCase):
def test_super_attention(self):
proj_dim = spaces.Categorical(12, 24, 36)
num_heads = spaces.Categorical(2, 4, 6)
model = super_core.SuperAttention(10, proj_dim, num_heads)
model = super_core.SuperSelfAttention(10, proj_dim, num_heads)
print(model)
model.apply_verbose(True)

View File

@ -78,7 +78,7 @@ class SuperSelfAttention(SuperModule):
return root_node
def apply_candidate(self, abstract_child: spaces.VirtualNode):
super(SuperAttention, self).apply_candidate(abstract_child)
super(SuperSelfAttention, self).apply_candidate(abstract_child)
if "q_fc" in abstract_child:
self.q_fc.apply_candidate(abstract_child["q_fc"])
if "k_fc" in abstract_child:
@ -222,7 +222,7 @@ class SuperQKVAttention(SuperModule):
return root_node
def apply_candidate(self, abstract_child: spaces.VirtualNode):
super(SuperAttention, self).apply_candidate(abstract_child)
super(SuperQVKAttention, self).apply_candidate(abstract_child)
if "q_fc" in abstract_child:
self.q_fc.apply_candidate(abstract_child["q_fc"])
if "k_fc" in abstract_child: