diff --git a/exps/LFNA/lfna.py b/exps/LFNA/lfna.py index bba6025..6c827fd 100644 --- a/exps/LFNA/lfna.py +++ b/exps/LFNA/lfna.py @@ -5,7 +5,7 @@ # python exps/LFNA/lfna.py --env_version v1 --device cuda --lr 0.001 # python exps/LFNA/lfna.py --env_version v1 --device cuda --lr 0.002 --meta_batch 128 ##################################################### -import sys, time, copy, torch, random, argparse +import pdb, sys, time, copy, torch, random, argparse from tqdm import tqdm from copy import deepcopy from pathlib import Path @@ -95,19 +95,13 @@ def epoch_evaluate(loader, meta_model, base_model, criterion, device, logger): def online_evaluate(env, meta_model, base_model, criterion, args, logger): logger.log("Online evaluate: {:}".format(env)) - for idx, (timestamp, (future_x, future_y)) in enumerate(env): - future_time = timestamp.item() - time_seqs = [ - future_time - iseq * env.timestamp_interval - for iseq in range(args.seq_length) - ] - time_seqs.reverse() + for idx, (future_time, (future_x, future_y)) in enumerate(env): with torch.no_grad(): meta_model.eval() base_model.eval() - time_seqs = torch.Tensor(time_seqs).view(1, -1).to(args.device) - [seq_containers], _ = meta_model(time_seqs, None) - future_container = seq_containers[-1] + _, [future_container], _ = meta_model( + future_time.to(args.device).view(1, 1), None, True + ) future_x, future_y = future_x.to(args.device), future_y.to(args.device) future_y_hat = base_model.forward_with_container(future_x, future_container) future_loss = criterion(future_y_hat, future_y) @@ -116,18 +110,17 @@ def online_evaluate(env, meta_model, base_model, criterion, args, logger): idx, len(env), future_loss.item() ) ) - meta_model.adapt( - future_time, + refine = meta_model.adapt( + base_model, + criterion, + future_time.item(), future_x, future_y, - env.timestamp_interval, args.refine_lr, args.refine_epochs, ) - import pdb - - pdb.set_trace() - print("-") + meta_model.clear_fixed() + meta_model.clear_learnt() def pretrain_v2(base_model, meta_model, criterion, xenv, args, logger): @@ -251,7 +244,7 @@ def main(args): logger.log("The meta-model is\n{:}".format(meta_model)) batch_sampler = EnvSampler(train_env, args.meta_batch, args.sampler_enlarge) - train_env.reset_max_seq_length(args.seq_length) + # train_env.reset_max_seq_length(args.seq_length) # valid_env.reset_max_seq_length(args.seq_length) valid_env_loader = torch.utils.data.DataLoader( valid_env, @@ -269,8 +262,8 @@ def main(args): pretrain_v2(base_model, meta_model, criterion, train_env, args, logger) # try to evaluate once + online_evaluate(train_env, meta_model, base_model, criterion, args, logger) online_evaluate(valid_env, meta_model, base_model, criterion, args, logger) - import pdb pdb.set_trace() optimizer = torch.optim.Adam( @@ -510,11 +503,11 @@ if __name__ == "__main__": parser.add_argument( "--refine_lr", type=float, - default=0.001, + default=0.002, help="The learning rate for the optimizer, during refine", ) parser.add_argument( - "--refine_epochs", type=int, default=1000, help="The final refine #epochs." + "--refine_epochs", type=int, default=50, help="The final refine #epochs." ) parser.add_argument( "--early_stop_thresh", diff --git a/exps/LFNA/lfna_meta_model.py b/exps/LFNA/lfna_meta_model.py index d847366..7291f76 100644 --- a/exps/LFNA/lfna_meta_model.py +++ b/exps/LFNA/lfna_meta_model.py @@ -60,6 +60,17 @@ class LFNA_Meta(super_core.SuperModule): ) # build transformer + self._trans_att = super_core.SuperQKVAttentionV2( + qk_att_dim=time_embedding, + in_v_dim=time_embedding, + hidden_dim=time_embedding, + num_heads=4, + proj_dim=time_embedding, + qkv_bias=True, + attn_drop=None, + proj_drop=dropout, + ) + """ self._trans_att = super_core.SuperQKVAttention( time_embedding, time_embedding, @@ -70,6 +81,7 @@ class LFNA_Meta(super_core.SuperModule): attn_drop=None, proj_drop=dropout, ) + """ layers = [] for ilayer in range(mha_depth): layers.append( @@ -153,6 +165,13 @@ class LFNA_Meta(super_core.SuperModule): def meta_length(self): return self.meta_timestamps.numel() + def clear_fixed(self): + self._append_meta_timestamps["fixed"] = None + self._append_meta_embed["fixed"] = None + + def clear_learnt(self): + self.replace_append_learnt(None, None) + def append_fixed(self, timestamp, meta_embed): with torch.no_grad(): device = self._super_meta_embed.device @@ -175,9 +194,15 @@ class LFNA_Meta(super_core.SuperModule): # timestamps is a batch of sequence of timestamps batch, seq = timestamps.shape meta_timestamps, meta_embeds = self.meta_timestamps, self.super_meta_embed + """ timestamp_q_embed = self._tscalar_embed(timestamps) timestamp_k_embed = self._tscalar_embed(meta_timestamps.view(1, -1)) timestamp_v_embed = meta_embeds.unsqueeze(dim=0) + """ + timestamp_v_embed = meta_embeds.unsqueeze(dim=0) + timestamp_qk_att_embed = self._tscalar_embed( + torch.unsqueeze(timestamps, dim=-1) - meta_timestamps + ) # create the mask mask = ( torch.unsqueeze(timestamps, dim=-1) <= meta_timestamps.view(1, 1, -1) @@ -188,7 +213,10 @@ class LFNA_Meta(super_core.SuperModule): > self._thresh ) timestamp_embeds = self._trans_att( - timestamp_q_embed, timestamp_k_embed, timestamp_v_embed, mask + # timestamp_q_embed, timestamp_k_embed, timestamp_v_embed, mask + timestamp_qk_att_embed, + timestamp_v_embed, + mask, ) relative_timestamps = timestamps - timestamps[:, :1] relative_pos_embeds = self._tscalar_embed(relative_timestamps) @@ -248,18 +276,41 @@ class LFNA_Meta(super_core.SuperModule): def forward_candidate(self, input): raise NotImplementedError - def adapt(self, timestamp, x, y, threshold, lr, epochs): - if distance + threshold * 1e-2 <= threshold: + def adapt(self, base_model, criterion, timestamp, x, y, lr, epochs): + distance = self.get_closest_meta_distance(timestamp) + if distance + self._interval * 1e-2 <= self._interval: return False + x, y = x.to(self._meta_timestamps.device), y.to(self._meta_timestamps.device) with torch.set_grad_enabled(True): new_param = self.create_meta_embed() optimizer = torch.optim.Adam( - [new_param], lr=args.refine_lr, weight_decay=1e-5, amsgrad=True + [new_param], lr=lr, weight_decay=1e-5, amsgrad=True ) - import pdb + timestamp = torch.Tensor([timestamp]).to(new_param.device) + self.replace_append_learnt(timestamp, new_param) + self.train() + base_model.train() + best_new_param, best_loss = None, 1e9 + for iepoch in range(epochs): + optimizer.zero_grad() + _, [_], time_embed = self(timestamp.view(1, 1), None, True) + match_loss = criterion(new_param, time_embed) - pdb.set_trace() - print("-") + _, [container], time_embed = self(None, new_param.view(1, 1, -1), True) + y_hat = base_model.forward_with_container(x, container) + meta_loss = criterion(y_hat, y) + loss = meta_loss + match_loss + loss.backward() + optimizer.step() + # print("{:03d}/{:03d} : loss : {:.4f} = {:.4f} + {:.4f}".format(iepoch, epochs, loss.item(), meta_loss.item(), match_loss.item())) + if loss.item() < best_loss: + with torch.no_grad(): + best_loss = loss.item() + best_new_param = new_param.detach() + with torch.no_grad(): + self.replace_append_learnt(None, None) + self.append_fixed(timestamp, best_new_param) + return True def extra_repr(self) -> str: return "(_super_layer_embed): {:}, (_super_meta_embed): {:}, (_meta_timestamps): {:}".format(