Try a different model / LFNA V3

D-X-Y 2021-05-24 01:06:22 +08:00
parent be274e0b6c
commit 63a0361152
2 changed files with 73 additions and 29 deletions

View File

@@ -5,7 +5,7 @@
 # python exps/LFNA/lfna.py --env_version v1 --device cuda --lr 0.001
 # python exps/LFNA/lfna.py --env_version v1 --device cuda --lr 0.002 --meta_batch 128
 #####################################################
-import sys, time, copy, torch, random, argparse
+import pdb, sys, time, copy, torch, random, argparse
 from tqdm import tqdm
 from copy import deepcopy
 from pathlib import Path
@@ -95,19 +95,13 @@ def epoch_evaluate(loader, meta_model, base_model, criterion, device, logger):
 
 def online_evaluate(env, meta_model, base_model, criterion, args, logger):
     logger.log("Online evaluate: {:}".format(env))
-    for idx, (timestamp, (future_x, future_y)) in enumerate(env):
-        future_time = timestamp.item()
-        time_seqs = [
-            future_time - iseq * env.timestamp_interval
-            for iseq in range(args.seq_length)
-        ]
-        time_seqs.reverse()
+    for idx, (future_time, (future_x, future_y)) in enumerate(env):
         with torch.no_grad():
             meta_model.eval()
             base_model.eval()
-            time_seqs = torch.Tensor(time_seqs).view(1, -1).to(args.device)
-            [seq_containers], _ = meta_model(time_seqs, None)
-            future_container = seq_containers[-1]
+            _, [future_container], _ = meta_model(
+                future_time.to(args.device).view(1, 1), None, True
+            )
             future_x, future_y = future_x.to(args.device), future_y.to(args.device)
             future_y_hat = base_model.forward_with_container(future_x, future_container)
             future_loss = criterion(future_y_hat, future_y)
@@ -116,18 +110,17 @@ def online_evaluate(env, meta_model, base_model, criterion, args, logger):
                 idx, len(env), future_loss.item()
             )
         )
-        meta_model.adapt(
-            future_time,
+        refine = meta_model.adapt(
+            base_model,
+            criterion,
+            future_time.item(),
             future_x,
             future_y,
-            env.timestamp_interval,
             args.refine_lr,
             args.refine_epochs,
         )
-        import pdb
-
-        pdb.set_trace()
-        print("-")
+        meta_model.clear_fixed()
+        meta_model.clear_learnt()
 
 
 def pretrain_v2(base_model, meta_model, criterion, xenv, args, logger):
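
The rewritten loop above follows a test-then-adapt protocol: each incoming timestamp is evaluated with the weights the meta-model generates for it, and only afterwards is the meta-model refined on the revealed (future_x, future_y) pair; the clear_fixed/clear_learnt calls then drop the temporary state before the next step. Below is a minimal, self-contained sketch of that protocol, with a plain nn.Linear standing in for the hypernetwork-generated model (the stream, sizes, and hyperparameters are hypothetical):

    import torch

    torch.manual_seed(0)
    model = torch.nn.Linear(1, 1)              # stand-in for the generated base model
    criterion = torch.nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

    stream = [(t, torch.randn(8, 1), torch.randn(8, 1)) for t in range(5)]
    for t, x, y in stream:
        model.eval()                           # 1) test on the new timestamp first
        with torch.no_grad():
            loss = criterion(model(x), y)
        print("t={:} test-loss={:.4f}".format(t, loss.item()))
        model.train()                          # 2) only then adapt on the revealed pair
        for _ in range(10):
            optimizer.zero_grad()
            criterion(model(x), y).backward()
            optimizer.step()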
@@ -251,7 +244,7 @@ def main(args):
     logger.log("The meta-model is\n{:}".format(meta_model))
 
     batch_sampler = EnvSampler(train_env, args.meta_batch, args.sampler_enlarge)
-    train_env.reset_max_seq_length(args.seq_length)
+    # train_env.reset_max_seq_length(args.seq_length)
     # valid_env.reset_max_seq_length(args.seq_length)
     valid_env_loader = torch.utils.data.DataLoader(
         valid_env,
@ -269,8 +262,8 @@ def main(args):
pretrain_v2(base_model, meta_model, criterion, train_env, args, logger) pretrain_v2(base_model, meta_model, criterion, train_env, args, logger)
# try to evaluate once # try to evaluate once
online_evaluate(train_env, meta_model, base_model, criterion, args, logger)
online_evaluate(valid_env, meta_model, base_model, criterion, args, logger) online_evaluate(valid_env, meta_model, base_model, criterion, args, logger)
import pdb
pdb.set_trace() pdb.set_trace()
optimizer = torch.optim.Adam( optimizer = torch.optim.Adam(
@@ -510,11 +503,11 @@ if __name__ == "__main__":
     parser.add_argument(
         "--refine_lr",
         type=float,
-        default=0.001,
+        default=0.002,
         help="The learning rate for the optimizer, during refine",
     )
     parser.add_argument(
-        "--refine_epochs", type=int, default=1000, help="The final refine #epochs."
+        "--refine_epochs", type=int, default=50, help="The final refine #epochs."
    )
     parser.add_argument(
         "--early_stop_thresh",

View File

@@ -60,6 +60,17 @@ class LFNA_Meta(super_core.SuperModule):
         )
 
         # build transformer
+        self._trans_att = super_core.SuperQKVAttentionV2(
+            qk_att_dim=time_embedding,
+            in_v_dim=time_embedding,
+            hidden_dim=time_embedding,
+            num_heads=4,
+            proj_dim=time_embedding,
+            qkv_bias=True,
+            attn_drop=None,
+            proj_drop=dropout,
+        )
+        """
         self._trans_att = super_core.SuperQKVAttention(
             time_embedding,
             time_embedding,
@@ -70,6 +81,7 @@ class LFNA_Meta(super_core.SuperModule):
             attn_drop=None,
             proj_drop=dropout,
         )
+        """
         layers = []
         for ilayer in range(mha_depth):
             layers.append(
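
The switch to SuperQKVAttentionV2 pairs with the forward change further down: rather than comparing learned query/key projections, the attention logits are derived from an embedding of the pairwise time difference between each query timestamp and each stored meta timestamp. A toy rendering of that idea in plain PyTorch follows; the semantics of qk_att_dim/in_v_dim are assumptions inferred from the call site, not the repo's documented API:

    import torch

    B, Q, K, D = 1, 3, 5, 16
    t_q = torch.rand(B, Q) * 0.5 + 0.5     # query timestamps (kept >= every key)
    t_k = torch.linspace(0.0, 0.5, K)      # stored meta timestamps
    values = torch.randn(B, K, D)          # per-meta-timestamp embeddings
    to_logit = torch.nn.Linear(D, 1)       # difference embedding -> attention logit

    diff = t_q.unsqueeze(-1) - t_k         # (B, Q, K) pairwise time differences
    freqs = torch.logspace(0, 3, D)
    diff_embed = torch.sin(diff.unsqueeze(-1) * freqs)    # toy scalar embedding
    logits = to_logit(diff_embed).squeeze(-1)             # (B, Q, K)
    logits = logits.masked_fill(diff < 0, float("-inf"))  # never attend to the future
    out = logits.softmax(dim=-1) @ values                 # (B, Q, D)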
@@ -153,6 +165,13 @@ class LFNA_Meta(super_core.SuperModule):
     def meta_length(self):
         return self.meta_timestamps.numel()
 
+    def clear_fixed(self):
+        self._append_meta_timestamps["fixed"] = None
+        self._append_meta_embed["fixed"] = None
+
+    def clear_learnt(self):
+        self.replace_append_learnt(None, None)
+
     def append_fixed(self, timestamp, meta_embed):
         with torch.no_grad():
             device = self._super_meta_embed.device
@@ -175,9 +194,15 @@ class LFNA_Meta(super_core.SuperModule):
         # timestamps is a batch of sequence of timestamps
         batch, seq = timestamps.shape
         meta_timestamps, meta_embeds = self.meta_timestamps, self.super_meta_embed
+        """
         timestamp_q_embed = self._tscalar_embed(timestamps)
         timestamp_k_embed = self._tscalar_embed(meta_timestamps.view(1, -1))
         timestamp_v_embed = meta_embeds.unsqueeze(dim=0)
+        """
+        timestamp_v_embed = meta_embeds.unsqueeze(dim=0)
+        timestamp_qk_att_embed = self._tscalar_embed(
+            torch.unsqueeze(timestamps, dim=-1) - meta_timestamps
+        )
         # create the mask
         mask = (
             torch.unsqueeze(timestamps, dim=-1) <= meta_timestamps.view(1, 1, -1)
@@ -188,7 +213,10 @@ class LFNA_Meta(super_core.SuperModule):
             > self._thresh
         )
         timestamp_embeds = self._trans_att(
-            timestamp_q_embed, timestamp_k_embed, timestamp_v_embed, mask
+            # timestamp_q_embed, timestamp_k_embed, timestamp_v_embed, mask
+            timestamp_qk_att_embed,
+            timestamp_v_embed,
+            mask,
         )
         relative_timestamps = timestamps - timestamps[:, :1]
         relative_pos_embeds = self._tscalar_embed(relative_timestamps)
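
For reference, the broadcasting that feeds the new attention path: subtracting the stored meta timestamps from a (batch, seq, 1) view yields one scalar difference per query/key pair, which _tscalar_embed then lifts to a feature vector. A small shape check with illustrative sizes:

    import torch

    batch, seq, num_meta = 2, 4, 7
    timestamps = torch.rand(batch, seq)
    meta_timestamps = torch.rand(num_meta)

    # (batch, seq, 1) - (num_meta,) broadcasts to (batch, seq, num_meta)
    pairwise = torch.unsqueeze(timestamps, dim=-1) - meta_timestamps
    assert pairwise.shape == (batch, seq, num_meta)
    # _tscalar_embed (per the code above) lifts each scalar to a feature vector,
    # giving the (batch, seq, num_meta, dim) tensor consumed by the V2 attention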
@@ -248,18 +276,41 @@ class LFNA_Meta(super_core.SuperModule):
     def forward_candidate(self, input):
         raise NotImplementedError
 
-    def adapt(self, timestamp, x, y, threshold, lr, epochs):
-        if distance + threshold * 1e-2 <= threshold:
+    def adapt(self, base_model, criterion, timestamp, x, y, lr, epochs):
+        distance = self.get_closest_meta_distance(timestamp)
+        if distance + self._interval * 1e-2 <= self._interval:
             return False
+        x, y = x.to(self._meta_timestamps.device), y.to(self._meta_timestamps.device)
         with torch.set_grad_enabled(True):
             new_param = self.create_meta_embed()
             optimizer = torch.optim.Adam(
-                [new_param], lr=args.refine_lr, weight_decay=1e-5, amsgrad=True
+                [new_param], lr=lr, weight_decay=1e-5, amsgrad=True
             )
-            import pdb
-
-            pdb.set_trace()
-            print("-")
+            timestamp = torch.Tensor([timestamp]).to(new_param.device)
+            self.replace_append_learnt(timestamp, new_param)
+            self.train()
+            base_model.train()
+            best_new_param, best_loss = None, 1e9
+            for iepoch in range(epochs):
+                optimizer.zero_grad()
+                _, [_], time_embed = self(timestamp.view(1, 1), None, True)
+                match_loss = criterion(new_param, time_embed)
+
+                _, [container], time_embed = self(None, new_param.view(1, 1, -1), True)
+                y_hat = base_model.forward_with_container(x, container)
+                meta_loss = criterion(y_hat, y)
+                loss = meta_loss + match_loss
+                loss.backward()
+                optimizer.step()
+                # print("{:03d}/{:03d} : loss : {:.4f} = {:.4f} + {:.4f}".format(iepoch, epochs, loss.item(), meta_loss.item(), match_loss.item()))
+                if loss.item() < best_loss:
+                    with torch.no_grad():
+                        best_loss = loss.item()
+                        best_new_param = new_param.detach()
+        with torch.no_grad():
+            self.replace_append_learnt(None, None)
+            self.append_fixed(timestamp, best_new_param)
+        return True
 
     def extra_repr(self) -> str:
         return "(_super_layer_embed): {:}, (_super_meta_embed): {:}, (_meta_timestamps): {:}".format(