From ce787df02c88f536d8e96904e8f489491cf5d06c Mon Sep 17 00:00:00 2001 From: D-X-Y <280835372@qq.com> Date: Sat, 22 May 2021 17:36:09 +0800 Subject: [PATCH] Update LFNA --- exps/LFNA/lfna.py | 65 ++++++++++++++++++++++++++++++- exps/LFNA/lfna_meta_model.py | 28 ++++++++----- xautodl/datasets/synthetic_env.py | 12 ++++++ 3 files changed, 94 insertions(+), 11 deletions(-) diff --git a/exps/LFNA/lfna.py b/exps/LFNA/lfna.py index e9b5abf..d916cc4 100644 --- a/exps/LFNA/lfna.py +++ b/exps/LFNA/lfna.py @@ -93,6 +93,67 @@ def epoch_evaluate(loader, meta_model, base_model, criterion, device, logger): return loss_meter +def pretrain(base_model, meta_model, criterion, xenv, args, logger): + optimizer = torch.optim.Adam( + meta_model.parameters(), + lr=args.lr, + weight_decay=args.weight_decay, + amsgrad=True, + ) + + meta_model.set_best_dir(logger.path(None) / "checkpoint-pretrain") + for iepoch in range(args.epochs): + total_meta_losses, total_match_losses = [], [] + for ibatch in range(args.meta_batch): + rand_index = random.randint(0, meta_model.meta_length - xenv.seq_length - 1) + timestamps = meta_model.meta_timestamps[ + rand_index : rand_index + xenv.seq_length + ] + + seq_timestamps, (seq_inputs, seq_targets) = xenv.seq_call(timestamps) + [seq_containers], time_embeds = meta_model( + torch.unsqueeze(timestamps, dim=0) + ) + # performance loss + losses = [] + seq_inputs, seq_targets = seq_inputs.to(args.device), seq_targets.to( + args.device + ) + for container, inputs, targets in zip( + seq_containers, seq_inputs, seq_targets + ): + predictions = base_model.forward_with_container(inputs, container) + loss = criterion(predictions, targets) + losses.append(loss) + meta_loss = torch.stack(losses).mean() + match_loss = criterion( + torch.squeeze(time_embeds, dim=0), + meta_model.super_meta_embed[rand_index : rand_index + xenv.seq_length], + ) + # batch_loss = meta_loss + match_loss * 0.1 + # total_losses.append(batch_loss) + total_meta_losses.append(meta_loss) + total_match_losses.append(match_loss) + final_meta_loss = torch.stack(total_meta_losses).mean() + final_match_loss = torch.stack(total_match_losses).mean() + total_loss = final_meta_loss + final_match_loss + total_loss.backward() + optimizer.step() + # success + success, best_score = meta_model.save_best(-total_loss.item()) + logger.log( + "{:} [{:04d}/{:}] loss : {:.5f} = {:.5f} + {:.5f} (match)".format( + time_string(), + iepoch, + args.epochs, + total_loss.item(), + final_meta_loss.item(), + final_match_loss.item(), + ) + + ", batch={:}".format(len(total_meta_losses)) + ) + + def main(args): logger, env_info, model_kwargs = lfna_setup(args) train_env = get_synthetic_env(mode="train", version=args.env_version) @@ -148,6 +209,8 @@ def main(args): logger.log("The scheduler is\n{:}".format(lr_scheduler)) logger.log("Per epoch iterations = {:}".format(len(train_env_loader))) + pretrain(base_model, meta_model, criterion, train_env, args, logger) + if logger.path("model").exists(): ckp_data = torch.load(logger.path("model")) base_model.load_state_dict(ckp_data["base_model"]) @@ -345,7 +408,7 @@ if __name__ == "__main__": parser.add_argument( "--lr", type=float, - default=0.005, + default=0.002, help="The initial learning rate for the optimizer (default is Adam)", ) parser.add_argument( diff --git a/exps/LFNA/lfna_meta_model.py b/exps/LFNA/lfna_meta_model.py index 66fbcd5..b9aa87a 100644 --- a/exps/LFNA/lfna_meta_model.py +++ b/exps/LFNA/lfna_meta_model.py @@ -63,7 +63,7 @@ class LFNA_Meta(super_core.SuperModule): for ilayer in range(mha_depth): layers.append( super_core.SuperTransformerEncoderLayer( - time_embedding, + time_embedding * 2, 4, True, 4, @@ -72,7 +72,7 @@ class LFNA_Meta(super_core.SuperModule): order=super_core.LayerOrder.PostNorm, ) ) - layers.append(super_core.SuperLinear(time_embedding, time_embedding)) + layers.append(super_core.SuperLinear(time_embedding * 2, time_embedding)) self.meta_corrector = super_core.SuperSequential(*layers) model_kwargs = dict( @@ -95,10 +95,11 @@ class LFNA_Meta(super_core.SuperModule): @property def meta_timestamps(self): - meta_timestamps = [self._meta_timestamps] - for key in ("fixed", "learnt"): - if self._append_meta_timestamps[key] is not None: - meta_timestamps.append(self._append_meta_timestamps[key]) + with torch.no_grad(): + meta_timestamps = [self._meta_timestamps] + for key in ("fixed", "learnt"): + if self._append_meta_timestamps[key] is not None: + meta_timestamps.append(self._append_meta_timestamps[key]) return torch.cat(meta_timestamps) @property @@ -125,6 +126,10 @@ class LFNA_Meta(super_core.SuperModule): self._append_meta_timestamps["learnt"] = timestamp self._append_meta_embed["learnt"] = meta_embed + @property + def meta_length(self): + return self.meta_timestamps.numel() + def append_fixed(self, timestamp, meta_embed): with torch.no_grad(): device = self._super_meta_embed.device @@ -152,15 +157,18 @@ class LFNA_Meta(super_core.SuperModule): timestamp_embeds = self._trans_att( timestamp_q_embed, timestamp_k_embed, timestamp_v_embed ) - corrected_embeds = self.meta_corrector(timestamp_embeds) + # relative_timestamps = timestamps - timestamps[:, :1] + # relative_pos_embeds = self._tscalar_embed(relative_timestamps) + init_timestamp_embeds = torch.cat((timestamp_q_embed, timestamp_embeds), dim=-1) + corrected_embeds = self.meta_corrector(init_timestamp_embeds) return corrected_embeds def forward_raw(self, timestamps): batch, seq = timestamps.shape - meta_embed = self._obtain_time_embed(timestamps) + time_embed = self._obtain_time_embed(timestamps) # create joint embed num_layer, _ = self._super_layer_embed.shape - meta_embed = meta_embed.view(batch, seq, 1, -1).expand(-1, -1, num_layer, -1) + meta_embed = time_embed.view(batch, seq, 1, -1).expand(-1, -1, num_layer, -1) layer_embed = self._super_layer_embed.view(1, 1, num_layer, -1).expand( batch, seq, -1, -1 ) @@ -173,7 +181,7 @@ class LFNA_Meta(super_core.SuperModule): weights = torch.split(weights.squeeze(0), 1) seq_containers.append(self._shape_container.translate(weights)) batch_containers.append(seq_containers) - return batch_containers + return batch_containers, time_embed def forward_candidate(self, input): raise NotImplementedError diff --git a/xautodl/datasets/synthetic_env.py b/xautodl/datasets/synthetic_env.py index 7f5e33c..944d5a5 100644 --- a/xautodl/datasets/synthetic_env.py +++ b/xautodl/datasets/synthetic_env.py @@ -68,6 +68,10 @@ class SyntheticDEnv(data.Dataset): self._oracle_map = None self._seq_length = None + @property + def seq_length(self): + return self._seq_length + @property def min_timestamp(self): return self._timestamp_generator.min_timestamp @@ -125,6 +129,14 @@ class SyntheticDEnv(data.Dataset): timestamp + i * self.timestamp_interval + noise for i in range(self._seq_length) ] + # xdata = [self.__call__(timestamp) for timestamp in timestamps] + # return zip_sequence(xdata) + return self.seq_call(timestamps) + + def seq_call(self, timestamps): + with torch.no_grad(): + if isinstance(timestamps, torch.Tensor): + timestamps = timestamps.cpu().tolist() xdata = [self.__call__(timestamp) for timestamp in timestamps] return zip_sequence(xdata)