Try a different model / LFNA V3
exps/LFNA/lfna.py
@@ -5,7 +5,7 @@
 # python exps/LFNA/lfna.py --env_version v1 --device cuda --lr 0.001
 # python exps/LFNA/lfna.py --env_version v1 --device cuda --lr 0.002 --meta_batch 128
 #####################################################
-import sys, time, copy, torch, random, argparse
+import pdb, sys, time, copy, torch, random, argparse
 from tqdm import tqdm
 from copy import deepcopy
 from pathlib import Path
@@ -95,19 +95,13 @@ def epoch_evaluate(loader, meta_model, base_model, criterion, device, logger):

 def online_evaluate(env, meta_model, base_model, criterion, args, logger):
     logger.log("Online evaluate: {:}".format(env))
-    for idx, (timestamp, (future_x, future_y)) in enumerate(env):
-        future_time = timestamp.item()
-        time_seqs = [
-            future_time - iseq * env.timestamp_interval
-            for iseq in range(args.seq_length)
-        ]
-        time_seqs.reverse()
+    for idx, (future_time, (future_x, future_y)) in enumerate(env):
         with torch.no_grad():
             meta_model.eval()
             base_model.eval()
-            time_seqs = torch.Tensor(time_seqs).view(1, -1).to(args.device)
-            [seq_containers], _ = meta_model(time_seqs, None)
-            future_container = seq_containers[-1]
+            _, [future_container], _ = meta_model(
+                future_time.to(args.device).view(1, 1), None, True
+            )
             future_x, future_y = future_x.to(args.device), future_y.to(args.device)
             future_y_hat = base_model.forward_with_container(future_x, future_container)
             future_loss = criterion(future_y_hat, future_y)
@@ -116,18 +110,17 @@ def online_evaluate(env, meta_model, base_model, criterion, args, logger):
                     idx, len(env), future_loss.item()
                 )
             )
-        meta_model.adapt(
-            future_time,
+        refine = meta_model.adapt(
+            base_model,
+            criterion,
+            future_time.item(),
             future_x,
             future_y,
-            env.timestamp_interval,
             args.refine_lr,
             args.refine_epochs,
         )
-        import pdb
-
-        pdb.set_trace()
-        print("-")
+    meta_model.clear_fixed()
+    meta_model.clear_learnt()


 def pretrain_v2(base_model, meta_model, criterion, xenv, args, logger):
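The rewritten online_evaluate no longer builds a seq_length window of past timestamps; it queries the meta-model at the single future timestamp, scores the base model's prediction, and only then lets the meta-model adapt to that step, so evaluation stays strictly causal. Below is a minimal, self-contained sketch of this evaluate-then-adapt protocol; ToyMeta and the synthetic stream are stand-ins of my own, not the real LFNA_Meta or environment API.

import torch
import torch.nn as nn

class ToyMeta(nn.Module):
    """Maps a scalar timestamp to the (w, b) of a linear regressor."""

    def __init__(self, dim=16):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(1, dim), nn.ReLU(), nn.Linear(dim, 2))
        self.appended = []  # stands in for the "fixed"/"learnt" embedding slots

    def forward(self, t):
        w, b = self.net(t.view(1, 1)).squeeze(0)
        return w, b

    def adapt(self, t, x, y, lr=0.01, epochs=10):
        # refine a per-timestamp correction after the step has been scored
        delta = torch.zeros(2, requires_grad=True)
        opt = torch.optim.Adam([delta], lr=lr)
        for _ in range(epochs):
            opt.zero_grad()
            w, b = self.forward(t)
            loss = ((x * (w + delta[0]) + (b + delta[1]) - y) ** 2).mean()
            loss.backward()
            opt.step()
        self.appended.append((t.item(), delta.detach()))
        return True

    def clear(self):
        # mirrors clear_fixed() / clear_learnt(): forget what this pass appended
        self.appended.clear()

meta = ToyMeta()
stream = [(torch.tensor([float(t)]), torch.randn(8), torch.randn(8)) for t in range(5)]
for idx, (t, x, y) in enumerate(stream):
    with torch.no_grad():  # evaluate first, strictly before any adaptation
        w, b = meta(t)
        loss = ((x * w + b - y) ** 2).mean()
    print("[{:}/{:}] eval loss {:.4f}".format(idx, len(stream), loss.item()))
    meta.adapt(t, x, y)  # adapt only after timestamp t has been scored
meta.clear()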
@@ -251,7 +244,7 @@ def main(args):
     logger.log("The meta-model is\n{:}".format(meta_model))

     batch_sampler = EnvSampler(train_env, args.meta_batch, args.sampler_enlarge)
-    train_env.reset_max_seq_length(args.seq_length)
+    # train_env.reset_max_seq_length(args.seq_length)
     # valid_env.reset_max_seq_length(args.seq_length)
     valid_env_loader = torch.utils.data.DataLoader(
         valid_env,
@@ -269,8 +262,8 @@ def main(args):
     pretrain_v2(base_model, meta_model, criterion, train_env, args, logger)

     # try to evaluate once
+    online_evaluate(train_env, meta_model, base_model, criterion, args, logger)
     online_evaluate(valid_env, meta_model, base_model, criterion, args, logger)
-    import pdb

     pdb.set_trace()
     optimizer = torch.optim.Adam(
@@ -510,11 +503,11 @@ if __name__ == "__main__":
     parser.add_argument(
         "--refine_lr",
         type=float,
-        default=0.001,
+        default=0.002,
         help="The learning rate for the optimizer, during refine",
     )
     parser.add_argument(
-        "--refine_epochs", type=int, default=1000, help="The final refine #epochs."
+        "--refine_epochs", type=int, default=50, help="The final refine #epochs."
     )
     parser.add_argument(
         "--early_stop_thresh",
(second changed file: the LFNA_Meta model definition)
@@ -60,6 +60,17 @@ class LFNA_Meta(super_core.SuperModule):
         )

         # build transformer
+        self._trans_att = super_core.SuperQKVAttentionV2(
+            qk_att_dim=time_embedding,
+            in_v_dim=time_embedding,
+            hidden_dim=time_embedding,
+            num_heads=4,
+            proj_dim=time_embedding,
+            qkv_bias=True,
+            attn_drop=None,
+            proj_drop=dropout,
+        )
+        """
         self._trans_att = super_core.SuperQKVAttention(
             time_embedding,
             time_embedding,
@@ -70,6 +81,7 @@ class LFNA_Meta(super_core.SuperModule):
             attn_drop=None,
             proj_drop=dropout,
         )
+        """
         layers = []
         for ilayer in range(mha_depth):
             layers.append(
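The previous SuperQKVAttention embedded queries and keys separately and scored them by dot product; judging from its keyword arguments (qk_att_dim, in_v_dim), SuperQKVAttentionV2 instead consumes a precomputed query-key attention embedding together with values. The sketch below shows one plausible reading of that interface; the DiffAttention class and its argument names are illustrative assumptions, not the real super_core implementation.

import torch
import torch.nn as nn

class DiffAttention(nn.Module):
    def __init__(self, att_dim, v_dim, proj_dim):
        super().__init__()
        self.score = nn.Linear(att_dim, 1)   # qk_att embedding -> scalar logit
        self.v_proj = nn.Linear(v_dim, proj_dim)

    def forward(self, qk_att_embed, v_embed, mask=None):
        # qk_att_embed: (batch, n_query, n_key, att_dim), one vector per (q, k) pair
        # v_embed:      (1, n_key, v_dim), broadcast over the batch
        logits = self.score(qk_att_embed).squeeze(-1)      # (batch, n_query, n_key)
        if mask is not None:
            logits = logits.masked_fill(mask, float("-inf"))
        att = logits.softmax(dim=-1)
        return att @ self.v_proj(v_embed)                  # (batch, n_query, proj_dim)

# usage: scores depend only on the embedded (t_query - t_key) gap, so unseen
# absolute timestamps fall on the same footing as the training ones
atten = DiffAttention(att_dim=32, v_dim=32, proj_dim=32)
diff_embed = torch.randn(2, 4, 10, 32)  # embedded gaps for 4 queries x 10 keys
values = torch.randn(1, 10, 32)
out = atten(diff_embed, values)
print(out.shape)  # torch.Size([2, 4, 32])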
@@ -153,6 +165,13 @@ class LFNA_Meta(super_core.SuperModule):
     def meta_length(self):
         return self.meta_timestamps.numel()

+    def clear_fixed(self):
+        self._append_meta_timestamps["fixed"] = None
+        self._append_meta_embed["fixed"] = None
+
+    def clear_learnt(self):
+        self.replace_append_learnt(None, None)
+
     def append_fixed(self, timestamp, meta_embed):
         with torch.no_grad():
             device = self._super_meta_embed.device
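The two new helpers undo what an online pass appends: adapt() temporarily occupies the "learnt" slot via replace_append_learnt and, on success, writes the result into the "fixed" slot via append_fixed, so both must be cleared before the next pass starts from the pretrained embeddings alone. A toy sketch of that bookkeeping, with a hypothetical AppendState class mirroring the dict fields above:

class AppendState:
    def __init__(self):
        self._append_meta_timestamps = {"fixed": None, "learnt": None}
        self._append_meta_embed = {"fixed": None, "learnt": None}

    def replace_append_learnt(self, timestamp, embed):
        self._append_meta_timestamps["learnt"] = timestamp
        self._append_meta_embed["learnt"] = embed

    def clear_fixed(self):
        self._append_meta_timestamps["fixed"] = None
        self._append_meta_embed["fixed"] = None

    def clear_learnt(self):
        self.replace_append_learnt(None, None)

state = AppendState()
state._append_meta_timestamps["fixed"] = 3.0  # as left behind by append_fixed(...)
state.clear_fixed()
state.clear_learnt()
assert all(v is None for v in state._append_meta_timestamps.values())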
@@ -175,9 +194,15 @@ class LFNA_Meta(super_core.SuperModule):
         # timestamps is a batch of sequence of timestamps
         batch, seq = timestamps.shape
         meta_timestamps, meta_embeds = self.meta_timestamps, self.super_meta_embed
+        """
         timestamp_q_embed = self._tscalar_embed(timestamps)
         timestamp_k_embed = self._tscalar_embed(meta_timestamps.view(1, -1))
         timestamp_v_embed = meta_embeds.unsqueeze(dim=0)
+        """
+        timestamp_v_embed = meta_embeds.unsqueeze(dim=0)
+        timestamp_qk_att_embed = self._tscalar_embed(
+            torch.unsqueeze(timestamps, dim=-1) - meta_timestamps
+        )
         # create the mask
         mask = (
             torch.unsqueeze(timestamps, dim=-1) <= meta_timestamps.view(1, 1, -1)
@@ -188,7 +213,10 @@ class LFNA_Meta(super_core.SuperModule):
             > self._thresh
         )
         timestamp_embeds = self._trans_att(
-            timestamp_q_embed, timestamp_k_embed, timestamp_v_embed, mask
+            # timestamp_q_embed, timestamp_k_embed, timestamp_v_embed, mask
+            timestamp_qk_att_embed,
+            timestamp_v_embed,
+            mask,
         )
         relative_timestamps = timestamps - timestamps[:, :1]
         relative_pos_embeds = self._tscalar_embed(relative_timestamps)
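With the V2 attention, forward_meta now embeds the pairwise gap timestamps[i] - meta_timestamps[j] directly, so attention logits depend only on relative time rather than on two absolute-time embeddings; the mask still hides meta points that lie at or after the query time or beyond self._thresh. A standalone shape check of this construction, where an nn.Linear stands in for the learned _tscalar_embed:

import torch
import torch.nn as nn

batch, seq, n_meta, dim, thresh = 2, 3, 5, 8, 4.0
tscalar_embed = nn.Linear(1, dim)      # stand-in for self._tscalar_embed
timestamps = torch.rand(batch, seq) * 10
meta_timestamps = torch.linspace(0, 9, n_meta)

# pairwise query-minus-key gaps, one embedding vector per (query, key) pair
diff = timestamps.unsqueeze(-1) - meta_timestamps        # (batch, seq, n_meta)
qk_att_embed = tscalar_embed(diff.unsqueeze(-1))         # (batch, seq, n_meta, dim)

# mask out meta points from the future or too far in the past
mask = (timestamps.unsqueeze(-1) <= meta_timestamps.view(1, 1, -1)) | (
    diff.abs() > thresh
)
print(qk_att_embed.shape, mask.shape)  # (2, 3, 5, 8) and (2, 3, 5)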
@@ -248,18 +276,41 @@ class LFNA_Meta(super_core.SuperModule):
     def forward_candidate(self, input):
         raise NotImplementedError

-    def adapt(self, timestamp, x, y, threshold, lr, epochs):
-        if distance + threshold * 1e-2 <= threshold:
+    def adapt(self, base_model, criterion, timestamp, x, y, lr, epochs):
+        distance = self.get_closest_meta_distance(timestamp)
+        if distance + self._interval * 1e-2 <= self._interval:
             return False
+        x, y = x.to(self._meta_timestamps.device), y.to(self._meta_timestamps.device)
         with torch.set_grad_enabled(True):
             new_param = self.create_meta_embed()
             optimizer = torch.optim.Adam(
-                [new_param], lr=args.refine_lr, weight_decay=1e-5, amsgrad=True
+                [new_param], lr=lr, weight_decay=1e-5, amsgrad=True
             )
-        import pdb
-
-        pdb.set_trace()
-        print("-")
+            timestamp = torch.Tensor([timestamp]).to(new_param.device)
+            self.replace_append_learnt(timestamp, new_param)
+            self.train()
+            base_model.train()
+            best_new_param, best_loss = None, 1e9
+            for iepoch in range(epochs):
+                optimizer.zero_grad()
+                _, [_], time_embed = self(timestamp.view(1, 1), None, True)
+                match_loss = criterion(new_param, time_embed)
+
+                _, [container], time_embed = self(None, new_param.view(1, 1, -1), True)
+                y_hat = base_model.forward_with_container(x, container)
+                meta_loss = criterion(y_hat, y)
+                loss = meta_loss + match_loss
+                loss.backward()
+                optimizer.step()
+                # print("{:03d}/{:03d} : loss : {:.4f} = {:.4f} + {:.4f}".format(iepoch, epochs, loss.item(), meta_loss.item(), match_loss.item()))
+                if loss.item() < best_loss:
+                    with torch.no_grad():
+                        best_loss = loss.item()
+                        best_new_param = new_param.detach()
+        with torch.no_grad():
+            self.replace_append_learnt(None, None)
+            self.append_fixed(timestamp, best_new_param)
+        return True

     def extra_repr(self) -> str:
         return "(_super_layer_embed): {:}, (_super_meta_embed): {:}, (_meta_timestamps): {:}".format(
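The new adapt() is the heart of this commit: when a timestamp lands farther than self._interval from every stored meta-timestamp, a fresh embedding is optimized under two objectives, a match loss pulling it toward what the meta-model itself would predict for that time, and a meta loss asking the weights generated from it to fit the observed (x, y); the best iterate is then frozen via append_fixed. A compact, self-contained rendering of that loop, with toy linear modules in place of LFNA_Meta and the base model:

import torch
import torch.nn as nn

embed_dim = 16
time_to_embed = nn.Linear(1, embed_dim)    # stand-in: timestamp -> predicted embedding
embed_to_weight = nn.Linear(embed_dim, 2)  # stand-in: embedding -> (w, b) of y = w*x + b
criterion = nn.MSELoss()

t = torch.tensor([[4.0]])                  # the out-of-range timestamp
x, y = torch.randn(32), torch.randn(32)    # the newly observed data at time t

new_param = torch.randn(embed_dim, requires_grad=True)
optimizer = torch.optim.Adam([new_param], lr=0.002, weight_decay=1e-5, amsgrad=True)
best_param, best_loss = None, float("inf")
for _ in range(50):                        # mirrors the new refine_epochs default
    optimizer.zero_grad()
    # (a) stay close to the meta-model's own time-conditioned prediction
    match_loss = criterion(new_param, time_to_embed(t).squeeze())
    # (b) make the generated weights fit the observed data
    w, b = embed_to_weight(new_param)
    meta_loss = criterion(x * w + b, y)
    loss = meta_loss + match_loss
    loss.backward()
    optimizer.step()
    if loss.item() < best_loss:            # keep the best iterate, as in the diff
        best_loss, best_param = loss.item(), new_param.detach().clone()
# best_param would now be appended as a fixed meta-embedding for timestamp t
print("best refine loss: {:.4f}".format(best_loss))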