Update LFNA
@@ -94,8 +94,10 @@ def epoch_evaluate(loader, meta_model, base_model, criterion, device, logger):


def pretrain_v2(base_model, meta_model, criterion, xenv, args, logger):
    base_model.train()
    meta_model.train()
    optimizer = torch.optim.Adam(
        meta_model.parameters(),
        meta_model.get_parameters(True, True, True),
        lr=args.lr,
        weight_decay=args.weight_decay,
        amsgrad=True,
@@ -103,13 +105,16 @@ def pretrain_v2(base_model, meta_model, criterion, xenv, args, logger):
    logger.log("Pre-train the meta-model")
    logger.log("Using the optimizer: {:}".format(optimizer))

    meta_model.set_best_dir(logger.path(None) / "checkpoint-pretrain")
    meta_model.set_best_dir(logger.path(None) / "ckps-pretrain-v2")
    meta_model.set_best_name("pretrain-{:}.pth".format(args.rand_seed))
    last_success_epoch = 0
    per_epoch_time, start_time = AverageMeter(), time.time()
    for iepoch in range(args.epochs):
        left_time = "Time Left: {:}".format(
            convert_secs2time(per_epoch_time.avg * (args.epochs - iepoch), True)
        )
        total_meta_losses, total_match_losses = [], []
        optimizer.zero_grad()
        for ibatch in range(args.meta_batch):
            rand_index = random.randint(0, meta_model.meta_length - xenv.seq_length - 1)
            timestamps = meta_model.meta_timestamps[
@@ -118,7 +123,7 @@ def pretrain_v2(base_model, meta_model, criterion, xenv, args, logger):

            seq_timestamps, (seq_inputs, seq_targets) = xenv.seq_call(timestamps)
            [seq_containers], time_embeds = meta_model(
                torch.unsqueeze(timestamps, dim=0)
                torch.unsqueeze(timestamps, dim=0), None
            )
            # performance loss
            losses = []
@@ -136,10 +141,10 @@ def pretrain_v2(base_model, meta_model, criterion, xenv, args, logger):
                torch.squeeze(time_embeds, dim=0),
                meta_model.super_meta_embed[rand_index : rand_index + xenv.seq_length],
            )
            # batch_loss = meta_loss + match_loss * 0.1
            # total_losses.append(batch_loss)
            total_meta_losses.append(meta_loss)
            total_match_losses.append(match_loss)
        with torch.no_grad():
            meta_std = torch.stack(total_meta_losses).std().item()
        final_meta_loss = torch.stack(total_meta_losses).mean()
        final_match_loss = torch.stack(total_match_losses).mean()
        total_loss = final_meta_loss + final_match_loss
@@ -148,11 +153,12 @@ def pretrain_v2(base_model, meta_model, criterion, xenv, args, logger):
        # success
        success, best_score = meta_model.save_best(-total_loss.item())
        logger.log(
            "{:} [{:04d}/{:}] loss : {:.5f} = {:.5f} + {:.5f} (match)".format(
            "{:} [Pre-V2 {:04d}/{:}] loss : {:.5f} +- {:.5f} = {:.5f} + {:.5f} (match)".format(
                time_string(),
                iepoch,
                args.epochs,
                total_loss.item(),
                meta_std,
                final_meta_loss.item(),
                final_match_loss.item(),
            )
@@ -160,11 +166,15 @@ def pretrain_v2(base_model, meta_model, criterion, xenv, args, logger):
            + ", success={:}, best_score={:.4f}".format(success, -best_score)
            + " {:}".format(left_time)
        )
        if iepoch - last_success_epoch >= args.early_stop_thresh * 5:
            logger.log("Early stop the pre-training at {:}".format(iepoch))
            break
        per_epoch_time.update(time.time() - start_time)
        start_time = time.time()
    meta_model.load_best()


def pretrain(base_model, meta_model, criterion, xenv, args, logger):
def pretrain_v1(base_model, meta_model, criterion, xenv, args, logger):
    base_model.train()
    meta_model.train()
    optimizer = torch.optim.Adam(
@@ -173,12 +183,13 @@ def pretrain(base_model, meta_model, criterion, xenv, args, logger):
        weight_decay=args.weight_decay,
        amsgrad=True,
    )
    logger.log("Pre-train the meta-model")
    logger.log("Pre-train the meta-model's embeddings")
    logger.log("Using the optimizer: {:}".format(optimizer))

    meta_model.set_best_dir(logger.path(None) / "ckps-basic-pretrain")
    meta_model.set_best_dir(logger.path(None) / "ckps-pretrain-v1")
    meta_model.set_best_name("pretrain-{:}.pth".format(args.rand_seed))
    per_epoch_time, start_time = AverageMeter(), time.time()
    last_success_epoch = 0
    for iepoch in range(args.epochs):
        left_time = "Time Left: {:}".format(
            convert_secs2time(per_epoch_time.avg * (args.epochs - iepoch), True)
@@ -213,7 +224,7 @@ def pretrain(base_model, meta_model, criterion, xenv, args, logger):
        # success
        success, best_score = meta_model.save_best(-final_loss.item())
        logger.log(
            "{:} [{:04d}/{:}] loss : {:.5f}".format(
            "{:} [Pre-V1 {:04d}/{:}] loss : {:.5f}".format(
                time_string(),
                iepoch,
                args.epochs,
@@ -223,8 +234,12 @@ def pretrain(base_model, meta_model, criterion, xenv, args, logger):
            + ", success={:}, best_score={:.4f}".format(success, -best_score)
            + " {:}".format(left_time)
        )
        if iepoch - last_success_epoch >= args.early_stop_thresh * 5:
            logger.log("Early stop the pre-training at {:}".format(iepoch))
            break
        per_epoch_time.update(time.time() - start_time)
        start_time = time.time()
    meta_model.load_best()


def main(args):
@@ -282,7 +297,7 @@ def main(args):
    logger.log("The scheduler is\n{:}".format(lr_scheduler))
    logger.log("Per epoch iterations = {:}".format(len(train_env_loader)))

    pretrain(base_model, meta_model, criterion, train_env, args, logger)
    pretrain_v2(base_model, meta_model, criterion, train_env, args, logger)

    if logger.path("model").exists():
        ckp_data = torch.load(logger.path("model"))

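Both pre-training loops above now keep a last_success_epoch counter and stop once args.early_stop_thresh * 5 epochs pass without a new best checkpoint. Below is a minimal, self-contained sketch of that patience pattern; the reset of last_success_epoch when save_best reports success is an assumption, since that line falls inside a hunk not shown in this diff.

def early_stop_demo(scores, early_stop_thresh=3):
    # Toy stand-in for the pre-training loops: `scores` plays the role of
    # -total_loss, and `success` the role of meta_model.save_best(...)'s flag.
    best, last_success_epoch = float("-inf"), 0
    for iepoch, score in enumerate(scores):
        success = score > best
        if success:
            best, last_success_epoch = score, iepoch  # assumed: reset on improvement
        if iepoch - last_success_epoch >= early_stop_thresh * 5:
            return iepoch  # no new best for thresh * 5 epochs -> stop early
    return len(scores) - 1

# Example: improvement stalls after epoch 2, so the loop stops at epoch 17 (2 + 3 * 5).
print(early_stop_demo([0.1, 0.2, 0.3] + [0.3] * 30))
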
@@ -20,7 +20,7 @@ class LFNA_Meta(super_core.SuperModule):
        layer_embedding,
        time_embedding,
        meta_timestamps,
        mha_depth: int = 1,
        mha_depth: int = 2,
        dropout: float = 0.1,
    ):
        super(LFNA_Meta, self).__init__()
@@ -73,7 +73,7 @@ class LFNA_Meta(super_core.SuperModule):
                )
            )
        layers.append(super_core.SuperLinear(time_embedding * 2, time_embedding))
        self.meta_corrector = super_core.SuperSequential(*layers)
        self._meta_corrector = super_core.SuperSequential(*layers)

        model_kwargs = dict(
            config=dict(model_type="dual_norm_mlp"),
@@ -92,6 +92,18 @@ class LFNA_Meta(super_core.SuperModule):
            std=0.02,
        )

    def get_parameters(self, time_embed, meta_corrector, generator):
        parameters = []
        if time_embed:
            parameters.append(self._super_meta_embed)
        if meta_corrector:
            parameters.extend(list(self._trans_att.parameters()))
            parameters.extend(list(self._meta_corrector.parameters()))
        if generator:
            parameters.append(self._super_layer_embed)
            parameters.extend(list(self._generator.parameters()))
        return parameters

    @property
    def meta_timestamps(self):
        with torch.no_grad():
@@ -159,7 +171,7 @@ class LFNA_Meta(super_core.SuperModule):
        # relative_timestamps = timestamps - timestamps[:, :1]
        # relative_pos_embeds = self._tscalar_embed(relative_timestamps)
        init_timestamp_embeds = torch.cat((timestamp_q_embed, timestamp_embeds), dim=-1)
        corrected_embeds = self.meta_corrector(init_timestamp_embeds)
        corrected_embeds = self._meta_corrector(init_timestamp_embeds)
        return corrected_embeds

    def forward_raw(self, timestamps, time_embed):

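The new get_parameters(time_embed, meta_corrector, generator) helper returns only the requested parameter groups, which is what lets pretrain_v2 above build its Adam optimizer from meta_model.get_parameters(True, True, True). Below is a hedged sketch of the same grouping pattern on a toy module (not the real LFNA_Meta class); which flag subset pretrain_v1 ends up optimizing is not visible in this diff.

import torch
import torch.nn as nn

class ToyMeta(nn.Module):
    """Toy module mimicking the three parameter groups of LFNA_Meta.get_parameters."""

    def __init__(self):
        super().__init__()
        self.time_embed = nn.Parameter(torch.zeros(8, 16))  # stands in for _super_meta_embed
        self.corrector = nn.Linear(32, 16)   # stands in for _trans_att + _meta_corrector
        self.generator = nn.Linear(16, 64)   # stands in for _super_layer_embed + _generator

    def get_parameters(self, time_embed, meta_corrector, generator):
        parameters = []
        if time_embed:
            parameters.append(self.time_embed)
        if meta_corrector:
            parameters.extend(list(self.corrector.parameters()))
        if generator:
            parameters.extend(list(self.generator.parameters()))
        return parameters

model = ToyMeta()
# pretrain_v2 passes (True, True, True); other flag combinations allow staged training.
optimizer = torch.optim.Adam(model.get_parameters(True, True, True), lr=1e-3, amsgrad=True)
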
@@ -14,14 +14,14 @@ from xautodl import spaces
from xautodl.xlayers import super_core


class TestSuperAttention(unittest.TestCase):
class TestSuperSelfAttention(unittest.TestCase):
    """Test the super attention layer."""

    def _internal_func(self, inputs, model):
        outputs = model(inputs)
        abstract_space = model.abstract_search_space
        print(
            "The abstract search space for SuperAttention is:\n{:}".format(
            "The abstract search space for SuperSelfAttention is:\n{:}".format(
                abstract_space
            )
        )
@@ -36,7 +36,7 @@ class TestSuperAttention(unittest.TestCase):
    def test_super_attention(self):
        proj_dim = spaces.Categorical(12, 24, 36)
        num_heads = spaces.Categorical(2, 4, 6)
        model = super_core.SuperAttention(10, proj_dim, num_heads)
        model = super_core.SuperSelfAttention(10, proj_dim, num_heads)
        print(model)
        model.apply_verbose(True)

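The test class follows the layer rename from SuperAttention to SuperSelfAttention. A short standalone sketch mirroring the updated test, using only the calls that appear above (the constructor, abstract_search_space, apply_verbose); it assumes an installed xautodl package.

from xautodl import spaces
from xautodl.xlayers import super_core

# Projection dim and head count are categorical search dimensions, as in the test.
proj_dim = spaces.Categorical(12, 24, 36)
num_heads = spaces.Categorical(2, 4, 6)
model = super_core.SuperSelfAttention(10, proj_dim, num_heads)
model.apply_verbose(True)
# The abstract search space enumerates the categorical choices declared above.
print("The abstract search space for SuperSelfAttention is:\n{:}".format(model.abstract_search_space))
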
@@ -78,7 +78,7 @@ class SuperSelfAttention(SuperModule):
        return root_node

    def apply_candidate(self, abstract_child: spaces.VirtualNode):
        super(SuperAttention, self).apply_candidate(abstract_child)
        super(SuperSelfAttention, self).apply_candidate(abstract_child)
        if "q_fc" in abstract_child:
            self.q_fc.apply_candidate(abstract_child["q_fc"])
        if "k_fc" in abstract_child:
@@ -222,7 +222,7 @@ class SuperQKVAttention(SuperModule):
        return root_node

    def apply_candidate(self, abstract_child: spaces.VirtualNode):
        super(SuperAttention, self).apply_candidate(abstract_child)
        super(SuperQKVAttention, self).apply_candidate(abstract_child)
        if "q_fc" in abstract_child:
            self.q_fc.apply_candidate(abstract_child["q_fc"])
        if "k_fc" in abstract_child:

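The two apply_candidate fixes above matter because, in the legacy two-argument super(...) form, the first argument must be the class in whose body the call appears; naming another class walks the wrong MRO or fails outright. A minimal illustration with toy classes (not the real attention layers):

class Base:
    def apply_candidate(self, child):
        return "Base handled {!r}".format(child)

class SelfAttn(Base):
    def apply_candidate(self, child):
        # Correct: names its own class, so the lookup continues at Base.
        return super(SelfAttn, self).apply_candidate(child)

class QKVAttn(Base):
    def apply_candidate(self, child):
        # super(SelfAttn, self) here would raise TypeError because self is not a
        # SelfAttn instance; a misspelled name such as SuperQVKAttention raises NameError.
        return super(QKVAttn, self).apply_candidate(child)

print(SelfAttn().apply_candidate("q_fc"))
print(QKVAttn().apply_candidate("k_fc"))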