Update LFNA

This commit is contained in:
D-X-Y 2021-05-22 11:02:29 +00:00
parent ec241e4d69
commit 5b09f059fd
4 changed files with 46 additions and 19 deletions

View File

@ -94,8 +94,10 @@ def epoch_evaluate(loader, meta_model, base_model, criterion, device, logger):
def pretrain_v2(base_model, meta_model, criterion, xenv, args, logger): def pretrain_v2(base_model, meta_model, criterion, xenv, args, logger):
base_model.train()
meta_model.train()
optimizer = torch.optim.Adam( optimizer = torch.optim.Adam(
meta_model.parameters(), meta_model.get_parameters(True, True, True),
lr=args.lr, lr=args.lr,
weight_decay=args.weight_decay, weight_decay=args.weight_decay,
amsgrad=True, amsgrad=True,
@ -103,13 +105,16 @@ def pretrain_v2(base_model, meta_model, criterion, xenv, args, logger):
logger.log("Pre-train the meta-model") logger.log("Pre-train the meta-model")
logger.log("Using the optimizer: {:}".format(optimizer)) logger.log("Using the optimizer: {:}".format(optimizer))
meta_model.set_best_dir(logger.path(None) / "checkpoint-pretrain") meta_model.set_best_dir(logger.path(None) / "ckps-pretrain-v2")
meta_model.set_best_name("pretrain-{:}.pth".format(args.rand_seed))
last_success_epoch = 0
per_epoch_time, start_time = AverageMeter(), time.time() per_epoch_time, start_time = AverageMeter(), time.time()
for iepoch in range(args.epochs): for iepoch in range(args.epochs):
left_time = "Time Left: {:}".format( left_time = "Time Left: {:}".format(
convert_secs2time(per_epoch_time.avg * (args.epochs - iepoch), True) convert_secs2time(per_epoch_time.avg * (args.epochs - iepoch), True)
) )
total_meta_losses, total_match_losses = [], [] total_meta_losses, total_match_losses = [], []
optimizer.zero_grad()
for ibatch in range(args.meta_batch): for ibatch in range(args.meta_batch):
rand_index = random.randint(0, meta_model.meta_length - xenv.seq_length - 1) rand_index = random.randint(0, meta_model.meta_length - xenv.seq_length - 1)
timestamps = meta_model.meta_timestamps[ timestamps = meta_model.meta_timestamps[
@ -118,7 +123,7 @@ def pretrain_v2(base_model, meta_model, criterion, xenv, args, logger):
seq_timestamps, (seq_inputs, seq_targets) = xenv.seq_call(timestamps) seq_timestamps, (seq_inputs, seq_targets) = xenv.seq_call(timestamps)
[seq_containers], time_embeds = meta_model( [seq_containers], time_embeds = meta_model(
torch.unsqueeze(timestamps, dim=0) torch.unsqueeze(timestamps, dim=0), None
) )
# performance loss # performance loss
losses = [] losses = []
@ -136,10 +141,10 @@ def pretrain_v2(base_model, meta_model, criterion, xenv, args, logger):
torch.squeeze(time_embeds, dim=0), torch.squeeze(time_embeds, dim=0),
meta_model.super_meta_embed[rand_index : rand_index + xenv.seq_length], meta_model.super_meta_embed[rand_index : rand_index + xenv.seq_length],
) )
# batch_loss = meta_loss + match_loss * 0.1
# total_losses.append(batch_loss)
total_meta_losses.append(meta_loss) total_meta_losses.append(meta_loss)
total_match_losses.append(match_loss) total_match_losses.append(match_loss)
with torch.no_grad():
meta_std = torch.stack(total_meta_losses).std().item()
final_meta_loss = torch.stack(total_meta_losses).mean() final_meta_loss = torch.stack(total_meta_losses).mean()
final_match_loss = torch.stack(total_match_losses).mean() final_match_loss = torch.stack(total_match_losses).mean()
total_loss = final_meta_loss + final_match_loss total_loss = final_meta_loss + final_match_loss
@ -148,11 +153,12 @@ def pretrain_v2(base_model, meta_model, criterion, xenv, args, logger):
# success # success
success, best_score = meta_model.save_best(-total_loss.item()) success, best_score = meta_model.save_best(-total_loss.item())
logger.log( logger.log(
"{:} [{:04d}/{:}] loss : {:.5f} = {:.5f} + {:.5f} (match)".format( "{:} [Pre-V2 {:04d}/{:}] loss : {:.5f} +- {:.5f} = {:.5f} + {:.5f} (match)".format(
time_string(), time_string(),
iepoch, iepoch,
args.epochs, args.epochs,
total_loss.item(), total_loss.item(),
meta_std,
final_meta_loss.item(), final_meta_loss.item(),
final_match_loss.item(), final_match_loss.item(),
) )
@ -160,11 +166,15 @@ def pretrain_v2(base_model, meta_model, criterion, xenv, args, logger):
+ ", success={:}, best_score={:.4f}".format(success, -best_score) + ", success={:}, best_score={:.4f}".format(success, -best_score)
+ " {:}".format(left_time) + " {:}".format(left_time)
) )
if iepoch - last_success_epoch >= args.early_stop_thresh * 5:
logger.log("Early stop the pre-training at {:}".format(iepoch))
break
per_epoch_time.update(time.time() - start_time) per_epoch_time.update(time.time() - start_time)
start_time = time.time() start_time = time.time()
meta_model.load_best()
def pretrain(base_model, meta_model, criterion, xenv, args, logger): def pretrain_v1(base_model, meta_model, criterion, xenv, args, logger):
base_model.train() base_model.train()
meta_model.train() meta_model.train()
optimizer = torch.optim.Adam( optimizer = torch.optim.Adam(
@ -173,12 +183,13 @@ def pretrain(base_model, meta_model, criterion, xenv, args, logger):
weight_decay=args.weight_decay, weight_decay=args.weight_decay,
amsgrad=True, amsgrad=True,
) )
logger.log("Pre-train the meta-model") logger.log("Pre-train the meta-model's embeddings")
logger.log("Using the optimizer: {:}".format(optimizer)) logger.log("Using the optimizer: {:}".format(optimizer))
meta_model.set_best_dir(logger.path(None) / "ckps-basic-pretrain") meta_model.set_best_dir(logger.path(None) / "ckps-pretrain-v1")
meta_model.set_best_name("pretrain-{:}.pth".format(args.rand_seed)) meta_model.set_best_name("pretrain-{:}.pth".format(args.rand_seed))
per_epoch_time, start_time = AverageMeter(), time.time() per_epoch_time, start_time = AverageMeter(), time.time()
last_success_epoch = 0
for iepoch in range(args.epochs): for iepoch in range(args.epochs):
left_time = "Time Left: {:}".format( left_time = "Time Left: {:}".format(
convert_secs2time(per_epoch_time.avg * (args.epochs - iepoch), True) convert_secs2time(per_epoch_time.avg * (args.epochs - iepoch), True)
@ -213,7 +224,7 @@ def pretrain(base_model, meta_model, criterion, xenv, args, logger):
# success # success
success, best_score = meta_model.save_best(-final_loss.item()) success, best_score = meta_model.save_best(-final_loss.item())
logger.log( logger.log(
"{:} [{:04d}/{:}] loss : {:.5f}".format( "{:} [Pre-V1 {:04d}/{:}] loss : {:.5f}".format(
time_string(), time_string(),
iepoch, iepoch,
args.epochs, args.epochs,
@ -223,8 +234,12 @@ def pretrain(base_model, meta_model, criterion, xenv, args, logger):
+ ", success={:}, best_score={:.4f}".format(success, -best_score) + ", success={:}, best_score={:.4f}".format(success, -best_score)
+ " {:}".format(left_time) + " {:}".format(left_time)
) )
if iepoch - last_success_epoch >= args.early_stop_thresh * 5:
logger.log("Early stop the pre-training at {:}".format(iepoch))
break
per_epoch_time.update(time.time() - start_time) per_epoch_time.update(time.time() - start_time)
start_time = time.time() start_time = time.time()
meta_model.load_best()
def main(args): def main(args):
@ -282,7 +297,7 @@ def main(args):
logger.log("The scheduler is\n{:}".format(lr_scheduler)) logger.log("The scheduler is\n{:}".format(lr_scheduler))
logger.log("Per epoch iterations = {:}".format(len(train_env_loader))) logger.log("Per epoch iterations = {:}".format(len(train_env_loader)))
pretrain(base_model, meta_model, criterion, train_env, args, logger) pretrain_v2(base_model, meta_model, criterion, train_env, args, logger)
if logger.path("model").exists(): if logger.path("model").exists():
ckp_data = torch.load(logger.path("model")) ckp_data = torch.load(logger.path("model"))

View File

@ -20,7 +20,7 @@ class LFNA_Meta(super_core.SuperModule):
layer_embedding, layer_embedding,
time_embedding, time_embedding,
meta_timestamps, meta_timestamps,
mha_depth: int = 1, mha_depth: int = 2,
dropout: float = 0.1, dropout: float = 0.1,
): ):
super(LFNA_Meta, self).__init__() super(LFNA_Meta, self).__init__()
@ -73,7 +73,7 @@ class LFNA_Meta(super_core.SuperModule):
) )
) )
layers.append(super_core.SuperLinear(time_embedding * 2, time_embedding)) layers.append(super_core.SuperLinear(time_embedding * 2, time_embedding))
self.meta_corrector = super_core.SuperSequential(*layers) self._meta_corrector = super_core.SuperSequential(*layers)
model_kwargs = dict( model_kwargs = dict(
config=dict(model_type="dual_norm_mlp"), config=dict(model_type="dual_norm_mlp"),
@ -92,6 +92,18 @@ class LFNA_Meta(super_core.SuperModule):
std=0.02, std=0.02,
) )
def get_parameters(self, time_embed, meta_corrector, generator):
parameters = []
if time_embed:
parameters.append(self._super_meta_embed)
if meta_corrector:
parameters.extend(list(self._trans_att.parameters()))
parameters.extend(list(self._meta_corrector.parameters()))
if generator:
parameters.append(self._super_layer_embed)
parameters.extend(list(self._generator.parameters()))
return parameters
@property @property
def meta_timestamps(self): def meta_timestamps(self):
with torch.no_grad(): with torch.no_grad():
@ -159,7 +171,7 @@ class LFNA_Meta(super_core.SuperModule):
# relative_timestamps = timestamps - timestamps[:, :1] # relative_timestamps = timestamps - timestamps[:, :1]
# relative_pos_embeds = self._tscalar_embed(relative_timestamps) # relative_pos_embeds = self._tscalar_embed(relative_timestamps)
init_timestamp_embeds = torch.cat((timestamp_q_embed, timestamp_embeds), dim=-1) init_timestamp_embeds = torch.cat((timestamp_q_embed, timestamp_embeds), dim=-1)
corrected_embeds = self.meta_corrector(init_timestamp_embeds) corrected_embeds = self._meta_corrector(init_timestamp_embeds)
return corrected_embeds return corrected_embeds
def forward_raw(self, timestamps, time_embed): def forward_raw(self, timestamps, time_embed):

View File

@ -14,14 +14,14 @@ from xautodl import spaces
from xautodl.xlayers import super_core from xautodl.xlayers import super_core
class TestSuperAttention(unittest.TestCase): class TestSuperSelfAttention(unittest.TestCase):
"""Test the super attention layer.""" """Test the super attention layer."""
def _internal_func(self, inputs, model): def _internal_func(self, inputs, model):
outputs = model(inputs) outputs = model(inputs)
abstract_space = model.abstract_search_space abstract_space = model.abstract_search_space
print( print(
"The abstract search space for SuperAttention is:\n{:}".format( "The abstract search space for SuperSelfAttention is:\n{:}".format(
abstract_space abstract_space
) )
) )
@ -36,7 +36,7 @@ class TestSuperAttention(unittest.TestCase):
def test_super_attention(self): def test_super_attention(self):
proj_dim = spaces.Categorical(12, 24, 36) proj_dim = spaces.Categorical(12, 24, 36)
num_heads = spaces.Categorical(2, 4, 6) num_heads = spaces.Categorical(2, 4, 6)
model = super_core.SuperAttention(10, proj_dim, num_heads) model = super_core.SuperSelfAttention(10, proj_dim, num_heads)
print(model) print(model)
model.apply_verbose(True) model.apply_verbose(True)

View File

@ -78,7 +78,7 @@ class SuperSelfAttention(SuperModule):
return root_node return root_node
def apply_candidate(self, abstract_child: spaces.VirtualNode): def apply_candidate(self, abstract_child: spaces.VirtualNode):
super(SuperAttention, self).apply_candidate(abstract_child) super(SuperSelfAttention, self).apply_candidate(abstract_child)
if "q_fc" in abstract_child: if "q_fc" in abstract_child:
self.q_fc.apply_candidate(abstract_child["q_fc"]) self.q_fc.apply_candidate(abstract_child["q_fc"])
if "k_fc" in abstract_child: if "k_fc" in abstract_child:
@ -222,7 +222,7 @@ class SuperQKVAttention(SuperModule):
return root_node return root_node
def apply_candidate(self, abstract_child: spaces.VirtualNode): def apply_candidate(self, abstract_child: spaces.VirtualNode):
super(SuperAttention, self).apply_candidate(abstract_child) super(SuperQKVAttention, self).apply_candidate(abstract_child)
if "q_fc" in abstract_child: if "q_fc" in abstract_child:
self.q_fc.apply_candidate(abstract_child["q_fc"]) self.q_fc.apply_candidate(abstract_child["q_fc"])
if "k_fc" in abstract_child: if "k_fc" in abstract_child: