Update LFNA
This commit is contained in:
		| @@ -94,8 +94,10 @@ def epoch_evaluate(loader, meta_model, base_model, criterion, device, logger): | |||||||
|  |  | ||||||
|  |  | ||||||
| def pretrain_v2(base_model, meta_model, criterion, xenv, args, logger): | def pretrain_v2(base_model, meta_model, criterion, xenv, args, logger): | ||||||
|  |     base_model.train() | ||||||
|  |     meta_model.train() | ||||||
|     optimizer = torch.optim.Adam( |     optimizer = torch.optim.Adam( | ||||||
|         meta_model.parameters(), |         meta_model.get_parameters(True, True, True), | ||||||
|         lr=args.lr, |         lr=args.lr, | ||||||
|         weight_decay=args.weight_decay, |         weight_decay=args.weight_decay, | ||||||
|         amsgrad=True, |         amsgrad=True, | ||||||
| @@ -103,13 +105,16 @@ def pretrain_v2(base_model, meta_model, criterion, xenv, args, logger): | |||||||
|     logger.log("Pre-train the meta-model") |     logger.log("Pre-train the meta-model") | ||||||
|     logger.log("Using the optimizer: {:}".format(optimizer)) |     logger.log("Using the optimizer: {:}".format(optimizer)) | ||||||
|  |  | ||||||
|     meta_model.set_best_dir(logger.path(None) / "checkpoint-pretrain") |     meta_model.set_best_dir(logger.path(None) / "ckps-pretrain-v2") | ||||||
|  |     meta_model.set_best_name("pretrain-{:}.pth".format(args.rand_seed)) | ||||||
|  |     last_success_epoch = 0 | ||||||
|     per_epoch_time, start_time = AverageMeter(), time.time() |     per_epoch_time, start_time = AverageMeter(), time.time() | ||||||
|     for iepoch in range(args.epochs): |     for iepoch in range(args.epochs): | ||||||
|         left_time = "Time Left: {:}".format( |         left_time = "Time Left: {:}".format( | ||||||
|             convert_secs2time(per_epoch_time.avg * (args.epochs - iepoch), True) |             convert_secs2time(per_epoch_time.avg * (args.epochs - iepoch), True) | ||||||
|         ) |         ) | ||||||
|         total_meta_losses, total_match_losses = [], [] |         total_meta_losses, total_match_losses = [], [] | ||||||
|  |         optimizer.zero_grad() | ||||||
|         for ibatch in range(args.meta_batch): |         for ibatch in range(args.meta_batch): | ||||||
|             rand_index = random.randint(0, meta_model.meta_length - xenv.seq_length - 1) |             rand_index = random.randint(0, meta_model.meta_length - xenv.seq_length - 1) | ||||||
|             timestamps = meta_model.meta_timestamps[ |             timestamps = meta_model.meta_timestamps[ | ||||||
| @@ -118,7 +123,7 @@ def pretrain_v2(base_model, meta_model, criterion, xenv, args, logger): | |||||||
|  |  | ||||||
|             seq_timestamps, (seq_inputs, seq_targets) = xenv.seq_call(timestamps) |             seq_timestamps, (seq_inputs, seq_targets) = xenv.seq_call(timestamps) | ||||||
|             [seq_containers], time_embeds = meta_model( |             [seq_containers], time_embeds = meta_model( | ||||||
|                 torch.unsqueeze(timestamps, dim=0) |                 torch.unsqueeze(timestamps, dim=0), None | ||||||
|             ) |             ) | ||||||
|             # performance loss |             # performance loss | ||||||
|             losses = [] |             losses = [] | ||||||
| @@ -136,10 +141,10 @@ def pretrain_v2(base_model, meta_model, criterion, xenv, args, logger): | |||||||
|                 torch.squeeze(time_embeds, dim=0), |                 torch.squeeze(time_embeds, dim=0), | ||||||
|                 meta_model.super_meta_embed[rand_index : rand_index + xenv.seq_length], |                 meta_model.super_meta_embed[rand_index : rand_index + xenv.seq_length], | ||||||
|             ) |             ) | ||||||
|             # batch_loss = meta_loss + match_loss * 0.1 |  | ||||||
|             # total_losses.append(batch_loss) |  | ||||||
|             total_meta_losses.append(meta_loss) |             total_meta_losses.append(meta_loss) | ||||||
|             total_match_losses.append(match_loss) |             total_match_losses.append(match_loss) | ||||||
|  |         with torch.no_grad(): | ||||||
|  |             meta_std = torch.stack(total_meta_losses).std().item() | ||||||
|         final_meta_loss = torch.stack(total_meta_losses).mean() |         final_meta_loss = torch.stack(total_meta_losses).mean() | ||||||
|         final_match_loss = torch.stack(total_match_losses).mean() |         final_match_loss = torch.stack(total_match_losses).mean() | ||||||
|         total_loss = final_meta_loss + final_match_loss |         total_loss = final_meta_loss + final_match_loss | ||||||
| @@ -148,11 +153,12 @@ def pretrain_v2(base_model, meta_model, criterion, xenv, args, logger): | |||||||
|         # success |         # success | ||||||
|         success, best_score = meta_model.save_best(-total_loss.item()) |         success, best_score = meta_model.save_best(-total_loss.item()) | ||||||
|         logger.log( |         logger.log( | ||||||
|             "{:} [{:04d}/{:}] loss : {:.5f} = {:.5f} + {:.5f} (match)".format( |             "{:} [Pre-V2 {:04d}/{:}] loss : {:.5f} +- {:.5f} = {:.5f} + {:.5f} (match)".format( | ||||||
|                 time_string(), |                 time_string(), | ||||||
|                 iepoch, |                 iepoch, | ||||||
|                 args.epochs, |                 args.epochs, | ||||||
|                 total_loss.item(), |                 total_loss.item(), | ||||||
|  |                 meta_std, | ||||||
|                 final_meta_loss.item(), |                 final_meta_loss.item(), | ||||||
|                 final_match_loss.item(), |                 final_match_loss.item(), | ||||||
|             ) |             ) | ||||||
| @@ -160,11 +166,15 @@ def pretrain_v2(base_model, meta_model, criterion, xenv, args, logger): | |||||||
|             + ", success={:}, best_score={:.4f}".format(success, -best_score) |             + ", success={:}, best_score={:.4f}".format(success, -best_score) | ||||||
|             + " {:}".format(left_time) |             + " {:}".format(left_time) | ||||||
|         ) |         ) | ||||||
|  |         if iepoch - last_success_epoch >= args.early_stop_thresh * 5: | ||||||
|  |             logger.log("Early stop the pre-training at {:}".format(iepoch)) | ||||||
|  |             break | ||||||
|         per_epoch_time.update(time.time() - start_time) |         per_epoch_time.update(time.time() - start_time) | ||||||
|         start_time = time.time() |         start_time = time.time() | ||||||
|  |     meta_model.load_best() | ||||||
|  |  | ||||||
|  |  | ||||||
| def pretrain(base_model, meta_model, criterion, xenv, args, logger): | def pretrain_v1(base_model, meta_model, criterion, xenv, args, logger): | ||||||
|     base_model.train() |     base_model.train() | ||||||
|     meta_model.train() |     meta_model.train() | ||||||
|     optimizer = torch.optim.Adam( |     optimizer = torch.optim.Adam( | ||||||
| @@ -173,12 +183,13 @@ def pretrain(base_model, meta_model, criterion, xenv, args, logger): | |||||||
|         weight_decay=args.weight_decay, |         weight_decay=args.weight_decay, | ||||||
|         amsgrad=True, |         amsgrad=True, | ||||||
|     ) |     ) | ||||||
|     logger.log("Pre-train the meta-model") |     logger.log("Pre-train the meta-model's embeddings") | ||||||
|     logger.log("Using the optimizer: {:}".format(optimizer)) |     logger.log("Using the optimizer: {:}".format(optimizer)) | ||||||
|  |  | ||||||
|     meta_model.set_best_dir(logger.path(None) / "ckps-basic-pretrain") |     meta_model.set_best_dir(logger.path(None) / "ckps-pretrain-v1") | ||||||
|     meta_model.set_best_name("pretrain-{:}.pth".format(args.rand_seed)) |     meta_model.set_best_name("pretrain-{:}.pth".format(args.rand_seed)) | ||||||
|     per_epoch_time, start_time = AverageMeter(), time.time() |     per_epoch_time, start_time = AverageMeter(), time.time() | ||||||
|  |     last_success_epoch = 0 | ||||||
|     for iepoch in range(args.epochs): |     for iepoch in range(args.epochs): | ||||||
|         left_time = "Time Left: {:}".format( |         left_time = "Time Left: {:}".format( | ||||||
|             convert_secs2time(per_epoch_time.avg * (args.epochs - iepoch), True) |             convert_secs2time(per_epoch_time.avg * (args.epochs - iepoch), True) | ||||||
| @@ -213,7 +224,7 @@ def pretrain(base_model, meta_model, criterion, xenv, args, logger): | |||||||
|         # success |         # success | ||||||
|         success, best_score = meta_model.save_best(-final_loss.item()) |         success, best_score = meta_model.save_best(-final_loss.item()) | ||||||
|         logger.log( |         logger.log( | ||||||
|             "{:} [{:04d}/{:}] loss : {:.5f}".format( |             "{:} [Pre-V1 {:04d}/{:}] loss : {:.5f}".format( | ||||||
|                 time_string(), |                 time_string(), | ||||||
|                 iepoch, |                 iepoch, | ||||||
|                 args.epochs, |                 args.epochs, | ||||||
| @@ -223,8 +234,12 @@ def pretrain(base_model, meta_model, criterion, xenv, args, logger): | |||||||
|             + ", success={:}, best_score={:.4f}".format(success, -best_score) |             + ", success={:}, best_score={:.4f}".format(success, -best_score) | ||||||
|             + " {:}".format(left_time) |             + " {:}".format(left_time) | ||||||
|         ) |         ) | ||||||
|  |         if iepoch - last_success_epoch >= args.early_stop_thresh * 5: | ||||||
|  |             logger.log("Early stop the pre-training at {:}".format(iepoch)) | ||||||
|  |             break | ||||||
|         per_epoch_time.update(time.time() - start_time) |         per_epoch_time.update(time.time() - start_time) | ||||||
|         start_time = time.time() |         start_time = time.time() | ||||||
|  |     meta_model.load_best() | ||||||
|  |  | ||||||
|  |  | ||||||
| def main(args): | def main(args): | ||||||
| @@ -282,7 +297,7 @@ def main(args): | |||||||
|     logger.log("The scheduler is\n{:}".format(lr_scheduler)) |     logger.log("The scheduler is\n{:}".format(lr_scheduler)) | ||||||
|     logger.log("Per epoch iterations = {:}".format(len(train_env_loader))) |     logger.log("Per epoch iterations = {:}".format(len(train_env_loader))) | ||||||
|  |  | ||||||
|     pretrain(base_model, meta_model, criterion, train_env, args, logger) |     pretrain_v2(base_model, meta_model, criterion, train_env, args, logger) | ||||||
|  |  | ||||||
|     if logger.path("model").exists(): |     if logger.path("model").exists(): | ||||||
|         ckp_data = torch.load(logger.path("model")) |         ckp_data = torch.load(logger.path("model")) | ||||||
|   | |||||||
| @@ -20,7 +20,7 @@ class LFNA_Meta(super_core.SuperModule): | |||||||
|         layer_embedding, |         layer_embedding, | ||||||
|         time_embedding, |         time_embedding, | ||||||
|         meta_timestamps, |         meta_timestamps, | ||||||
|         mha_depth: int = 1, |         mha_depth: int = 2, | ||||||
|         dropout: float = 0.1, |         dropout: float = 0.1, | ||||||
|     ): |     ): | ||||||
|         super(LFNA_Meta, self).__init__() |         super(LFNA_Meta, self).__init__() | ||||||
| @@ -73,7 +73,7 @@ class LFNA_Meta(super_core.SuperModule): | |||||||
|                 ) |                 ) | ||||||
|             ) |             ) | ||||||
|         layers.append(super_core.SuperLinear(time_embedding * 2, time_embedding)) |         layers.append(super_core.SuperLinear(time_embedding * 2, time_embedding)) | ||||||
|         self.meta_corrector = super_core.SuperSequential(*layers) |         self._meta_corrector = super_core.SuperSequential(*layers) | ||||||
|  |  | ||||||
|         model_kwargs = dict( |         model_kwargs = dict( | ||||||
|             config=dict(model_type="dual_norm_mlp"), |             config=dict(model_type="dual_norm_mlp"), | ||||||
| @@ -92,6 +92,18 @@ class LFNA_Meta(super_core.SuperModule): | |||||||
|             std=0.02, |             std=0.02, | ||||||
|         ) |         ) | ||||||
|  |  | ||||||
|  |     def get_parameters(self, time_embed, meta_corrector, generator): | ||||||
|  |         parameters = [] | ||||||
|  |         if time_embed: | ||||||
|  |             parameters.append(self._super_meta_embed) | ||||||
|  |         if meta_corrector: | ||||||
|  |             parameters.extend(list(self._trans_att.parameters())) | ||||||
|  |             parameters.extend(list(self._meta_corrector.parameters())) | ||||||
|  |         if generator: | ||||||
|  |             parameters.append(self._super_layer_embed) | ||||||
|  |             parameters.extend(list(self._generator.parameters())) | ||||||
|  |         return parameters | ||||||
|  |  | ||||||
|     @property |     @property | ||||||
|     def meta_timestamps(self): |     def meta_timestamps(self): | ||||||
|         with torch.no_grad(): |         with torch.no_grad(): | ||||||
| @@ -159,7 +171,7 @@ class LFNA_Meta(super_core.SuperModule): | |||||||
|         # relative_timestamps = timestamps - timestamps[:, :1] |         # relative_timestamps = timestamps - timestamps[:, :1] | ||||||
|         # relative_pos_embeds = self._tscalar_embed(relative_timestamps) |         # relative_pos_embeds = self._tscalar_embed(relative_timestamps) | ||||||
|         init_timestamp_embeds = torch.cat((timestamp_q_embed, timestamp_embeds), dim=-1) |         init_timestamp_embeds = torch.cat((timestamp_q_embed, timestamp_embeds), dim=-1) | ||||||
|         corrected_embeds = self.meta_corrector(init_timestamp_embeds) |         corrected_embeds = self._meta_corrector(init_timestamp_embeds) | ||||||
|         return corrected_embeds |         return corrected_embeds | ||||||
|  |  | ||||||
|     def forward_raw(self, timestamps, time_embed): |     def forward_raw(self, timestamps, time_embed): | ||||||
|   | |||||||
| @@ -14,14 +14,14 @@ from xautodl import spaces | |||||||
| from xautodl.xlayers import super_core | from xautodl.xlayers import super_core | ||||||
|  |  | ||||||
|  |  | ||||||
| class TestSuperAttention(unittest.TestCase): | class TestSuperSelfAttention(unittest.TestCase): | ||||||
|     """Test the super attention layer.""" |     """Test the super attention layer.""" | ||||||
|  |  | ||||||
|     def _internal_func(self, inputs, model): |     def _internal_func(self, inputs, model): | ||||||
|         outputs = model(inputs) |         outputs = model(inputs) | ||||||
|         abstract_space = model.abstract_search_space |         abstract_space = model.abstract_search_space | ||||||
|         print( |         print( | ||||||
|             "The abstract search space for SuperAttention is:\n{:}".format( |             "The abstract search space for SuperSelfAttention is:\n{:}".format( | ||||||
|                 abstract_space |                 abstract_space | ||||||
|             ) |             ) | ||||||
|         ) |         ) | ||||||
| @@ -36,7 +36,7 @@ class TestSuperAttention(unittest.TestCase): | |||||||
|     def test_super_attention(self): |     def test_super_attention(self): | ||||||
|         proj_dim = spaces.Categorical(12, 24, 36) |         proj_dim = spaces.Categorical(12, 24, 36) | ||||||
|         num_heads = spaces.Categorical(2, 4, 6) |         num_heads = spaces.Categorical(2, 4, 6) | ||||||
|         model = super_core.SuperAttention(10, proj_dim, num_heads) |         model = super_core.SuperSelfAttention(10, proj_dim, num_heads) | ||||||
|         print(model) |         print(model) | ||||||
|         model.apply_verbose(True) |         model.apply_verbose(True) | ||||||
|  |  | ||||||
|   | |||||||
| @@ -78,7 +78,7 @@ class SuperSelfAttention(SuperModule): | |||||||
|         return root_node |         return root_node | ||||||
|  |  | ||||||
|     def apply_candidate(self, abstract_child: spaces.VirtualNode): |     def apply_candidate(self, abstract_child: spaces.VirtualNode): | ||||||
|         super(SuperAttention, self).apply_candidate(abstract_child) |         super(SuperSelfAttention, self).apply_candidate(abstract_child) | ||||||
|         if "q_fc" in abstract_child: |         if "q_fc" in abstract_child: | ||||||
|             self.q_fc.apply_candidate(abstract_child["q_fc"]) |             self.q_fc.apply_candidate(abstract_child["q_fc"]) | ||||||
|         if "k_fc" in abstract_child: |         if "k_fc" in abstract_child: | ||||||
| @@ -222,7 +222,7 @@ class SuperQKVAttention(SuperModule): | |||||||
|         return root_node |         return root_node | ||||||
|  |  | ||||||
|     def apply_candidate(self, abstract_child: spaces.VirtualNode): |     def apply_candidate(self, abstract_child: spaces.VirtualNode): | ||||||
|         super(SuperAttention, self).apply_candidate(abstract_child) |         super(SuperQVKAttention, self).apply_candidate(abstract_child) | ||||||
|         if "q_fc" in abstract_child: |         if "q_fc" in abstract_child: | ||||||
|             self.q_fc.apply_candidate(abstract_child["q_fc"]) |             self.q_fc.apply_candidate(abstract_child["q_fc"]) | ||||||
|         if "k_fc" in abstract_child: |         if "k_fc" in abstract_child: | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user