Try a different model / LFNA V3
This commit is contained in:
parent
be274e0b6c
commit
63a0361152
@ -5,7 +5,7 @@
|
|||||||
# python exps/LFNA/lfna.py --env_version v1 --device cuda --lr 0.001
|
# python exps/LFNA/lfna.py --env_version v1 --device cuda --lr 0.001
|
||||||
# python exps/LFNA/lfna.py --env_version v1 --device cuda --lr 0.002 --meta_batch 128
|
# python exps/LFNA/lfna.py --env_version v1 --device cuda --lr 0.002 --meta_batch 128
|
||||||
#####################################################
|
#####################################################
|
||||||
import sys, time, copy, torch, random, argparse
|
import pdb, sys, time, copy, torch, random, argparse
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
@ -95,19 +95,13 @@ def epoch_evaluate(loader, meta_model, base_model, criterion, device, logger):
|
|||||||
|
|
||||||
def online_evaluate(env, meta_model, base_model, criterion, args, logger):
|
def online_evaluate(env, meta_model, base_model, criterion, args, logger):
|
||||||
logger.log("Online evaluate: {:}".format(env))
|
logger.log("Online evaluate: {:}".format(env))
|
||||||
for idx, (timestamp, (future_x, future_y)) in enumerate(env):
|
for idx, (future_time, (future_x, future_y)) in enumerate(env):
|
||||||
future_time = timestamp.item()
|
|
||||||
time_seqs = [
|
|
||||||
future_time - iseq * env.timestamp_interval
|
|
||||||
for iseq in range(args.seq_length)
|
|
||||||
]
|
|
||||||
time_seqs.reverse()
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
meta_model.eval()
|
meta_model.eval()
|
||||||
base_model.eval()
|
base_model.eval()
|
||||||
time_seqs = torch.Tensor(time_seqs).view(1, -1).to(args.device)
|
_, [future_container], _ = meta_model(
|
||||||
[seq_containers], _ = meta_model(time_seqs, None)
|
future_time.to(args.device).view(1, 1), None, True
|
||||||
future_container = seq_containers[-1]
|
)
|
||||||
future_x, future_y = future_x.to(args.device), future_y.to(args.device)
|
future_x, future_y = future_x.to(args.device), future_y.to(args.device)
|
||||||
future_y_hat = base_model.forward_with_container(future_x, future_container)
|
future_y_hat = base_model.forward_with_container(future_x, future_container)
|
||||||
future_loss = criterion(future_y_hat, future_y)
|
future_loss = criterion(future_y_hat, future_y)
|
||||||
@ -116,18 +110,17 @@ def online_evaluate(env, meta_model, base_model, criterion, args, logger):
|
|||||||
idx, len(env), future_loss.item()
|
idx, len(env), future_loss.item()
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
meta_model.adapt(
|
refine = meta_model.adapt(
|
||||||
future_time,
|
base_model,
|
||||||
|
criterion,
|
||||||
|
future_time.item(),
|
||||||
future_x,
|
future_x,
|
||||||
future_y,
|
future_y,
|
||||||
env.timestamp_interval,
|
|
||||||
args.refine_lr,
|
args.refine_lr,
|
||||||
args.refine_epochs,
|
args.refine_epochs,
|
||||||
)
|
)
|
||||||
import pdb
|
meta_model.clear_fixed()
|
||||||
|
meta_model.clear_learnt()
|
||||||
pdb.set_trace()
|
|
||||||
print("-")
|
|
||||||
|
|
||||||
|
|
||||||
def pretrain_v2(base_model, meta_model, criterion, xenv, args, logger):
|
def pretrain_v2(base_model, meta_model, criterion, xenv, args, logger):
|
||||||
@ -251,7 +244,7 @@ def main(args):
|
|||||||
logger.log("The meta-model is\n{:}".format(meta_model))
|
logger.log("The meta-model is\n{:}".format(meta_model))
|
||||||
|
|
||||||
batch_sampler = EnvSampler(train_env, args.meta_batch, args.sampler_enlarge)
|
batch_sampler = EnvSampler(train_env, args.meta_batch, args.sampler_enlarge)
|
||||||
train_env.reset_max_seq_length(args.seq_length)
|
# train_env.reset_max_seq_length(args.seq_length)
|
||||||
# valid_env.reset_max_seq_length(args.seq_length)
|
# valid_env.reset_max_seq_length(args.seq_length)
|
||||||
valid_env_loader = torch.utils.data.DataLoader(
|
valid_env_loader = torch.utils.data.DataLoader(
|
||||||
valid_env,
|
valid_env,
|
||||||
@ -269,8 +262,8 @@ def main(args):
|
|||||||
pretrain_v2(base_model, meta_model, criterion, train_env, args, logger)
|
pretrain_v2(base_model, meta_model, criterion, train_env, args, logger)
|
||||||
|
|
||||||
# try to evaluate once
|
# try to evaluate once
|
||||||
|
online_evaluate(train_env, meta_model, base_model, criterion, args, logger)
|
||||||
online_evaluate(valid_env, meta_model, base_model, criterion, args, logger)
|
online_evaluate(valid_env, meta_model, base_model, criterion, args, logger)
|
||||||
import pdb
|
|
||||||
|
|
||||||
pdb.set_trace()
|
pdb.set_trace()
|
||||||
optimizer = torch.optim.Adam(
|
optimizer = torch.optim.Adam(
|
||||||
@ -510,11 +503,11 @@ if __name__ == "__main__":
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--refine_lr",
|
"--refine_lr",
|
||||||
type=float,
|
type=float,
|
||||||
default=0.001,
|
default=0.002,
|
||||||
help="The learning rate for the optimizer, during refine",
|
help="The learning rate for the optimizer, during refine",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--refine_epochs", type=int, default=1000, help="The final refine #epochs."
|
"--refine_epochs", type=int, default=50, help="The final refine #epochs."
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--early_stop_thresh",
|
"--early_stop_thresh",
|
||||||
|
@ -60,6 +60,17 @@ class LFNA_Meta(super_core.SuperModule):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# build transformer
|
# build transformer
|
||||||
|
self._trans_att = super_core.SuperQKVAttentionV2(
|
||||||
|
qk_att_dim=time_embedding,
|
||||||
|
in_v_dim=time_embedding,
|
||||||
|
hidden_dim=time_embedding,
|
||||||
|
num_heads=4,
|
||||||
|
proj_dim=time_embedding,
|
||||||
|
qkv_bias=True,
|
||||||
|
attn_drop=None,
|
||||||
|
proj_drop=dropout,
|
||||||
|
)
|
||||||
|
"""
|
||||||
self._trans_att = super_core.SuperQKVAttention(
|
self._trans_att = super_core.SuperQKVAttention(
|
||||||
time_embedding,
|
time_embedding,
|
||||||
time_embedding,
|
time_embedding,
|
||||||
@ -70,6 +81,7 @@ class LFNA_Meta(super_core.SuperModule):
|
|||||||
attn_drop=None,
|
attn_drop=None,
|
||||||
proj_drop=dropout,
|
proj_drop=dropout,
|
||||||
)
|
)
|
||||||
|
"""
|
||||||
layers = []
|
layers = []
|
||||||
for ilayer in range(mha_depth):
|
for ilayer in range(mha_depth):
|
||||||
layers.append(
|
layers.append(
|
||||||
@ -153,6 +165,13 @@ class LFNA_Meta(super_core.SuperModule):
|
|||||||
def meta_length(self):
|
def meta_length(self):
|
||||||
return self.meta_timestamps.numel()
|
return self.meta_timestamps.numel()
|
||||||
|
|
||||||
|
def clear_fixed(self):
|
||||||
|
self._append_meta_timestamps["fixed"] = None
|
||||||
|
self._append_meta_embed["fixed"] = None
|
||||||
|
|
||||||
|
def clear_learnt(self):
|
||||||
|
self.replace_append_learnt(None, None)
|
||||||
|
|
||||||
def append_fixed(self, timestamp, meta_embed):
|
def append_fixed(self, timestamp, meta_embed):
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
device = self._super_meta_embed.device
|
device = self._super_meta_embed.device
|
||||||
@ -175,9 +194,15 @@ class LFNA_Meta(super_core.SuperModule):
|
|||||||
# timestamps is a batch of sequence of timestamps
|
# timestamps is a batch of sequence of timestamps
|
||||||
batch, seq = timestamps.shape
|
batch, seq = timestamps.shape
|
||||||
meta_timestamps, meta_embeds = self.meta_timestamps, self.super_meta_embed
|
meta_timestamps, meta_embeds = self.meta_timestamps, self.super_meta_embed
|
||||||
|
"""
|
||||||
timestamp_q_embed = self._tscalar_embed(timestamps)
|
timestamp_q_embed = self._tscalar_embed(timestamps)
|
||||||
timestamp_k_embed = self._tscalar_embed(meta_timestamps.view(1, -1))
|
timestamp_k_embed = self._tscalar_embed(meta_timestamps.view(1, -1))
|
||||||
timestamp_v_embed = meta_embeds.unsqueeze(dim=0)
|
timestamp_v_embed = meta_embeds.unsqueeze(dim=0)
|
||||||
|
"""
|
||||||
|
timestamp_v_embed = meta_embeds.unsqueeze(dim=0)
|
||||||
|
timestamp_qk_att_embed = self._tscalar_embed(
|
||||||
|
torch.unsqueeze(timestamps, dim=-1) - meta_timestamps
|
||||||
|
)
|
||||||
# create the mask
|
# create the mask
|
||||||
mask = (
|
mask = (
|
||||||
torch.unsqueeze(timestamps, dim=-1) <= meta_timestamps.view(1, 1, -1)
|
torch.unsqueeze(timestamps, dim=-1) <= meta_timestamps.view(1, 1, -1)
|
||||||
@ -188,7 +213,10 @@ class LFNA_Meta(super_core.SuperModule):
|
|||||||
> self._thresh
|
> self._thresh
|
||||||
)
|
)
|
||||||
timestamp_embeds = self._trans_att(
|
timestamp_embeds = self._trans_att(
|
||||||
timestamp_q_embed, timestamp_k_embed, timestamp_v_embed, mask
|
# timestamp_q_embed, timestamp_k_embed, timestamp_v_embed, mask
|
||||||
|
timestamp_qk_att_embed,
|
||||||
|
timestamp_v_embed,
|
||||||
|
mask,
|
||||||
)
|
)
|
||||||
relative_timestamps = timestamps - timestamps[:, :1]
|
relative_timestamps = timestamps - timestamps[:, :1]
|
||||||
relative_pos_embeds = self._tscalar_embed(relative_timestamps)
|
relative_pos_embeds = self._tscalar_embed(relative_timestamps)
|
||||||
@ -248,18 +276,41 @@ class LFNA_Meta(super_core.SuperModule):
|
|||||||
def forward_candidate(self, input):
|
def forward_candidate(self, input):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def adapt(self, timestamp, x, y, threshold, lr, epochs):
|
def adapt(self, base_model, criterion, timestamp, x, y, lr, epochs):
|
||||||
if distance + threshold * 1e-2 <= threshold:
|
distance = self.get_closest_meta_distance(timestamp)
|
||||||
|
if distance + self._interval * 1e-2 <= self._interval:
|
||||||
return False
|
return False
|
||||||
|
x, y = x.to(self._meta_timestamps.device), y.to(self._meta_timestamps.device)
|
||||||
with torch.set_grad_enabled(True):
|
with torch.set_grad_enabled(True):
|
||||||
new_param = self.create_meta_embed()
|
new_param = self.create_meta_embed()
|
||||||
optimizer = torch.optim.Adam(
|
optimizer = torch.optim.Adam(
|
||||||
[new_param], lr=args.refine_lr, weight_decay=1e-5, amsgrad=True
|
[new_param], lr=lr, weight_decay=1e-5, amsgrad=True
|
||||||
)
|
)
|
||||||
import pdb
|
timestamp = torch.Tensor([timestamp]).to(new_param.device)
|
||||||
|
self.replace_append_learnt(timestamp, new_param)
|
||||||
|
self.train()
|
||||||
|
base_model.train()
|
||||||
|
best_new_param, best_loss = None, 1e9
|
||||||
|
for iepoch in range(epochs):
|
||||||
|
optimizer.zero_grad()
|
||||||
|
_, [_], time_embed = self(timestamp.view(1, 1), None, True)
|
||||||
|
match_loss = criterion(new_param, time_embed)
|
||||||
|
|
||||||
pdb.set_trace()
|
_, [container], time_embed = self(None, new_param.view(1, 1, -1), True)
|
||||||
print("-")
|
y_hat = base_model.forward_with_container(x, container)
|
||||||
|
meta_loss = criterion(y_hat, y)
|
||||||
|
loss = meta_loss + match_loss
|
||||||
|
loss.backward()
|
||||||
|
optimizer.step()
|
||||||
|
# print("{:03d}/{:03d} : loss : {:.4f} = {:.4f} + {:.4f}".format(iepoch, epochs, loss.item(), meta_loss.item(), match_loss.item()))
|
||||||
|
if loss.item() < best_loss:
|
||||||
|
with torch.no_grad():
|
||||||
|
best_loss = loss.item()
|
||||||
|
best_new_param = new_param.detach()
|
||||||
|
with torch.no_grad():
|
||||||
|
self.replace_append_learnt(None, None)
|
||||||
|
self.append_fixed(timestamp, best_new_param)
|
||||||
|
return True
|
||||||
|
|
||||||
def extra_repr(self) -> str:
|
def extra_repr(self) -> str:
|
||||||
return "(_super_layer_embed): {:}, (_super_meta_embed): {:}, (_meta_timestamps): {:}".format(
|
return "(_super_layer_embed): {:}, (_super_meta_embed): {:}, (_meta_timestamps): {:}".format(
|
||||||
|
Loading…
Reference in New Issue
Block a user