diff --git a/exps/LFNA/lfna.py b/exps/LFNA/lfna.py
index 00a998b..db7f36b 100644
--- a/exps/LFNA/lfna.py
+++ b/exps/LFNA/lfna.py
@@ -94,8 +94,10 @@ def epoch_evaluate(loader, meta_model, base_model, criterion, device, logger):
 
 def pretrain_v2(base_model, meta_model, criterion, xenv, args, logger):
+    base_model.train()
+    meta_model.train()
     optimizer = torch.optim.Adam(
-        meta_model.parameters(),
+        meta_model.get_parameters(True, True, True),
         lr=args.lr,
         weight_decay=args.weight_decay,
         amsgrad=True,
     )
@@ -103,13 +105,16 @@ def pretrain_v2(base_model, meta_model, criterion, xenv, args, logger):
     logger.log("Pre-train the meta-model")
     logger.log("Using the optimizer: {:}".format(optimizer))
-    meta_model.set_best_dir(logger.path(None) / "checkpoint-pretrain")
+    meta_model.set_best_dir(logger.path(None) / "ckps-pretrain-v2")
+    meta_model.set_best_name("pretrain-{:}.pth".format(args.rand_seed))
+    last_success_epoch = 0
     per_epoch_time, start_time = AverageMeter(), time.time()
     for iepoch in range(args.epochs):
         left_time = "Time Left: {:}".format(
             convert_secs2time(per_epoch_time.avg * (args.epochs - iepoch), True)
         )
         total_meta_losses, total_match_losses = [], []
+        optimizer.zero_grad()
         for ibatch in range(args.meta_batch):
             rand_index = random.randint(0, meta_model.meta_length - xenv.seq_length - 1)
             timestamps = meta_model.meta_timestamps[
@@ -118,7 +123,7 @@ def pretrain_v2(base_model, meta_model, criterion, xenv, args, logger):
             seq_timestamps, (seq_inputs, seq_targets) = xenv.seq_call(timestamps)
 
             [seq_containers], time_embeds = meta_model(
-                torch.unsqueeze(timestamps, dim=0)
+                torch.unsqueeze(timestamps, dim=0), None
             )
             # performance loss
             losses = []
@@ -136,10 +141,10 @@ def pretrain_v2(base_model, meta_model, criterion, xenv, args, logger):
                 torch.squeeze(time_embeds, dim=0),
                 meta_model.super_meta_embed[rand_index : rand_index + xenv.seq_length],
             )
-            # batch_loss = meta_loss + match_loss * 0.1
-            # total_losses.append(batch_loss)
             total_meta_losses.append(meta_loss)
             total_match_losses.append(match_loss)
+        with torch.no_grad():
+            meta_std = torch.stack(total_meta_losses).std().item()
         final_meta_loss = torch.stack(total_meta_losses).mean()
         final_match_loss = torch.stack(total_match_losses).mean()
         total_loss = final_meta_loss + final_match_loss
@@ -148,11 +153,12 @@ def pretrain_v2(base_model, meta_model, criterion, xenv, args, logger):
         # success
         success, best_score = meta_model.save_best(-total_loss.item())
         logger.log(
-            "{:} [{:04d}/{:}] loss : {:.5f} = {:.5f} + {:.5f} (match)".format(
+            "{:} [Pre-V2 {:04d}/{:}] loss : {:.5f} +- {:.5f} = {:.5f} + {:.5f} (match)".format(
                 time_string(),
                 iepoch,
                 args.epochs,
                 total_loss.item(),
+                meta_std,
                 final_meta_loss.item(),
                 final_match_loss.item(),
             )
@@ -160,11 +166,15 @@ def pretrain_v2(base_model, meta_model, criterion, xenv, args, logger):
             + ", success={:}, best_score={:.4f}".format(success, -best_score)
             + " {:}".format(left_time)
         )
+        if iepoch - last_success_epoch >= args.early_stop_thresh * 5:
+            logger.log("Early stop the pre-training at {:}".format(iepoch))
+            break
         per_epoch_time.update(time.time() - start_time)
         start_time = time.time()
+    meta_model.load_best()
 
 
-def pretrain(base_model, meta_model, criterion, xenv, args, logger):
+def pretrain_v1(base_model, meta_model, criterion, xenv, args, logger):
     base_model.train()
     meta_model.train()
     optimizer = torch.optim.Adam(
@@ -173,12 +183,13 @@ def pretrain(base_model, meta_model, criterion, xenv, args, logger):
         weight_decay=args.weight_decay,
         amsgrad=True,
     )
-    logger.log("Pre-train the meta-model")
+    logger.log("Pre-train the meta-model's embeddings")
logger.log("Pre-train the meta-model's embeddings") logger.log("Using the optimizer: {:}".format(optimizer)) - meta_model.set_best_dir(logger.path(None) / "ckps-basic-pretrain") + meta_model.set_best_dir(logger.path(None) / "ckps-pretrain-v1") meta_model.set_best_name("pretrain-{:}.pth".format(args.rand_seed)) per_epoch_time, start_time = AverageMeter(), time.time() + last_success_epoch = 0 for iepoch in range(args.epochs): left_time = "Time Left: {:}".format( convert_secs2time(per_epoch_time.avg * (args.epochs - iepoch), True) @@ -213,7 +224,7 @@ def pretrain(base_model, meta_model, criterion, xenv, args, logger): # success success, best_score = meta_model.save_best(-final_loss.item()) logger.log( - "{:} [{:04d}/{:}] loss : {:.5f}".format( + "{:} [Pre-V1 {:04d}/{:}] loss : {:.5f}".format( time_string(), iepoch, args.epochs, @@ -223,8 +234,12 @@ def pretrain(base_model, meta_model, criterion, xenv, args, logger): + ", success={:}, best_score={:.4f}".format(success, -best_score) + " {:}".format(left_time) ) + if iepoch - last_success_epoch >= args.early_stop_thresh * 5: + logger.log("Early stop the pre-training at {:}".format(iepoch)) + break per_epoch_time.update(time.time() - start_time) start_time = time.time() + meta_model.load_best() def main(args): @@ -282,7 +297,7 @@ def main(args): logger.log("The scheduler is\n{:}".format(lr_scheduler)) logger.log("Per epoch iterations = {:}".format(len(train_env_loader))) - pretrain(base_model, meta_model, criterion, train_env, args, logger) + pretrain_v2(base_model, meta_model, criterion, train_env, args, logger) if logger.path("model").exists(): ckp_data = torch.load(logger.path("model")) diff --git a/exps/LFNA/lfna_meta_model.py b/exps/LFNA/lfna_meta_model.py index 10cdeb2..10cfd51 100644 --- a/exps/LFNA/lfna_meta_model.py +++ b/exps/LFNA/lfna_meta_model.py @@ -20,7 +20,7 @@ class LFNA_Meta(super_core.SuperModule): layer_embedding, time_embedding, meta_timestamps, - mha_depth: int = 1, + mha_depth: int = 2, dropout: float = 0.1, ): super(LFNA_Meta, self).__init__() @@ -73,7 +73,7 @@ class LFNA_Meta(super_core.SuperModule): ) ) layers.append(super_core.SuperLinear(time_embedding * 2, time_embedding)) - self.meta_corrector = super_core.SuperSequential(*layers) + self._meta_corrector = super_core.SuperSequential(*layers) model_kwargs = dict( config=dict(model_type="dual_norm_mlp"), @@ -92,6 +92,18 @@ class LFNA_Meta(super_core.SuperModule): std=0.02, ) + def get_parameters(self, time_embed, meta_corrector, generator): + parameters = [] + if time_embed: + parameters.append(self._super_meta_embed) + if meta_corrector: + parameters.extend(list(self._trans_att.parameters())) + parameters.extend(list(self._meta_corrector.parameters())) + if generator: + parameters.append(self._super_layer_embed) + parameters.extend(list(self._generator.parameters())) + return parameters + @property def meta_timestamps(self): with torch.no_grad(): @@ -159,7 +171,7 @@ class LFNA_Meta(super_core.SuperModule): # relative_timestamps = timestamps - timestamps[:, :1] # relative_pos_embeds = self._tscalar_embed(relative_timestamps) init_timestamp_embeds = torch.cat((timestamp_q_embed, timestamp_embeds), dim=-1) - corrected_embeds = self.meta_corrector(init_timestamp_embeds) + corrected_embeds = self._meta_corrector(init_timestamp_embeds) return corrected_embeds def forward_raw(self, timestamps, time_embed): diff --git a/tests/test_super_att.py b/tests/test_super_att.py index 48f8bf6..fc55b1c 100644 --- a/tests/test_super_att.py +++ b/tests/test_super_att.py @@ -14,14 +14,14 @@ 
 from xautodl import spaces
 from xautodl.xlayers import super_core
 
 
-class TestSuperAttention(unittest.TestCase):
+class TestSuperSelfAttention(unittest.TestCase):
     """Test the super attention layer."""
 
     def _internal_func(self, inputs, model):
         outputs = model(inputs)
         abstract_space = model.abstract_search_space
         print(
-            "The abstract search space for SuperAttention is:\n{:}".format(
+            "The abstract search space for SuperSelfAttention is:\n{:}".format(
                 abstract_space
             )
         )
@@ -36,7 +36,7 @@ class TestSuperAttention(unittest.TestCase):
     def test_super_attention(self):
         proj_dim = spaces.Categorical(12, 24, 36)
         num_heads = spaces.Categorical(2, 4, 6)
-        model = super_core.SuperAttention(10, proj_dim, num_heads)
+        model = super_core.SuperSelfAttention(10, proj_dim, num_heads)
         print(model)
         model.apply_verbose(True)
 
diff --git a/xautodl/xlayers/super_attention.py b/xautodl/xlayers/super_attention.py
index 924cdd0..a1eb317 100644
--- a/xautodl/xlayers/super_attention.py
+++ b/xautodl/xlayers/super_attention.py
@@ -78,7 +78,7 @@ class SuperSelfAttention(SuperModule):
         return root_node
 
     def apply_candidate(self, abstract_child: spaces.VirtualNode):
-        super(SuperAttention, self).apply_candidate(abstract_child)
+        super(SuperSelfAttention, self).apply_candidate(abstract_child)
         if "q_fc" in abstract_child:
             self.q_fc.apply_candidate(abstract_child["q_fc"])
         if "k_fc" in abstract_child:
@@ -222,7 +222,7 @@ class SuperQKVAttention(SuperModule):
         return root_node
 
     def apply_candidate(self, abstract_child: spaces.VirtualNode):
-        super(SuperAttention, self).apply_candidate(abstract_child)
+        super(SuperQKVAttention, self).apply_candidate(abstract_child)
         if "q_fc" in abstract_child:
             self.q_fc.apply_candidate(abstract_child["q_fc"])
         if "k_fc" in abstract_child:
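Note on the new parameter selector: pretrain_v2 above swaps meta_model.parameters() for meta_model.get_parameters(True, True, True), so each pre-training stage can optimize only a subset of the meta-model (the time embeddings, the attention/corrector stack, or the weight generator). The sketch below is a minimal, self-contained toy stand-in, not the repo's LFNA_Meta; only the attribute names and the selection logic mirror the diff, and the Linear layers are hypothetical placeholders for the real super_core modules.

```python
import torch
import torch.nn as nn


class ToyMeta(nn.Module):
    """Hypothetical stand-in for LFNA_Meta; only the attribute names mirror the diff."""

    def __init__(self, time_dim=16, layer_dim=16, num_ts=10, num_layers=4):
        super().__init__()
        self._super_meta_embed = nn.Parameter(torch.randn(num_ts, time_dim))
        self._super_layer_embed = nn.Parameter(torch.randn(num_layers, layer_dim))
        self._trans_att = nn.Linear(time_dim, time_dim)  # placeholder for the MHA stack
        self._meta_corrector = nn.Linear(time_dim * 2, time_dim)  # placeholder
        self._generator = nn.Linear(time_dim + layer_dim, 32)  # placeholder

    def get_parameters(self, time_embed, meta_corrector, generator):
        # same selection logic as the get_parameters() added in lfna_meta_model.py
        parameters = []
        if time_embed:
            parameters.append(self._super_meta_embed)
        if meta_corrector:
            parameters.extend(list(self._trans_att.parameters()))
            parameters.extend(list(self._meta_corrector.parameters()))
        if generator:
            parameters.append(self._super_layer_embed)
            parameters.extend(list(self._generator.parameters()))
        return parameters


meta_model = ToyMeta()
# pretrain_v2 optimizes every group: get_parameters(True, True, True)
full = torch.optim.Adam(meta_model.get_parameters(True, True, True), lr=0.01)
# a narrower stage could touch only the time embeddings
embed_only = torch.optim.Adam(meta_model.get_parameters(True, False, False), lr=0.01)
print(len(full.param_groups[0]["params"]), len(embed_only.param_groups[0]["params"]))
```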