Complete LFNA 1.0

This commit is contained in:
D-X-Y 2021-05-14 00:36:37 +08:00
parent c2fa181bc5
commit b81ef2dd74
4 changed files with 311 additions and 19 deletions

190
exps/LFNA/basic-prev.py Normal file
View File

@ -0,0 +1,190 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 #
#####################################################
# python exps/LFNA/basic-prev.py --env_version v1 --hidden_dim 16 --epochs 500 --init_lr 0.1
# python exps/LFNA/basic-prev.py --env_version v2 --hidden_dim 16 --epochs 1000 --init_lr 0.05
#####################################################
import sys, time, copy, torch, random, argparse
from tqdm import tqdm
from copy import deepcopy
from pathlib import Path
lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve()
if str(lib_dir) not in sys.path:
sys.path.insert(0, str(lib_dir))
from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint
from log_utils import time_string
from log_utils import AverageMeter, convert_secs2time
from utils import split_str2indexes
from procedures.advanced_main import basic_train_fn, basic_eval_fn
from procedures.metric_utils import SaveMetric, MSEMetric, ComposeMetric
from datasets.synthetic_core import get_synthetic_env
from models.xcore import get_model
from lfna_utils import lfna_setup
def subsample(historical_x, historical_y, maxn=10000):
total = historical_x.size(0)
if total <= maxn:
return historical_x, historical_y
else:
indexes = torch.randint(low=0, high=total, size=[maxn])
return historical_x[indexes], historical_y[indexes]
def main(args):
logger, env_info, model_kwargs = lfna_setup(args)
w_container_per_epoch = dict()
per_timestamp_time, start_time = AverageMeter(), time.time()
for idx in range(1, env_info["total"]):
need_time = "Time Left: {:}".format(
convert_secs2time(per_timestamp_time.avg * (env_info["total"] - idx), True)
)
logger.log(
"[{:}]".format(time_string())
+ " [{:04d}/{:04d}]".format(idx, env_info["total"])
+ " "
+ need_time
)
# train the same data
historical_x = env_info["{:}-x".format(idx - 1)]
historical_y = env_info["{:}-y".format(idx - 1)]
# build model
model = get_model(**model_kwargs)
print(model)
# build optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=args.init_lr, amsgrad=True)
criterion = torch.nn.MSELoss()
lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
optimizer,
milestones=[
int(args.epochs * 0.25),
int(args.epochs * 0.5),
int(args.epochs * 0.75),
],
gamma=0.3,
)
train_metric = MSEMetric()
best_loss, best_param = None, None
for _iepoch in range(args.epochs):
preds = model(historical_x)
optimizer.zero_grad()
loss = criterion(preds, historical_y)
loss.backward()
optimizer.step()
lr_scheduler.step()
# save best
if best_loss is None or best_loss > loss.item():
best_loss = loss.item()
best_param = copy.deepcopy(model.state_dict())
model.load_state_dict(best_param)
model.analyze_weights()
with torch.no_grad():
train_metric(preds, historical_y)
train_results = train_metric.get_info()
metric = ComposeMetric(MSEMetric(), SaveMetric())
eval_dataset = torch.utils.data.TensorDataset(
env_info["{:}-x".format(idx)], env_info["{:}-y".format(idx)]
)
eval_loader = torch.utils.data.DataLoader(
eval_dataset, batch_size=args.batch_size, shuffle=False, num_workers=0
)
results = basic_eval_fn(eval_loader, model, metric, logger)
log_str = (
"[{:}]".format(time_string())
+ " [{:04d}/{:04d}]".format(idx, env_info["total"])
+ " train-mse: {:.5f}, eval-mse: {:.5f}".format(
train_results["mse"], results["mse"]
)
)
logger.log(log_str)
save_path = logger.path(None) / "{:04d}-{:04d}.pth".format(
idx, env_info["total"]
)
w_container_per_epoch[idx] = model.get_w_container().no_grad_clone()
save_checkpoint(
{
"model_state_dict": model.state_dict(),
"model": model,
"index": idx,
"timestamp": env_info["{:}-timestamp".format(idx)],
},
save_path,
logger,
)
logger.log("")
per_timestamp_time.update(time.time() - start_time)
start_time = time.time()
save_checkpoint(
{"w_container_per_epoch": w_container_per_epoch},
logger.path(None) / "final-ckp.pth",
logger,
)
logger.log("-" * 200 + "\n")
logger.close()
if __name__ == "__main__":
parser = argparse.ArgumentParser("Use the data in the last timestamp.")
parser.add_argument(
"--save_dir",
type=str,
default="./outputs/lfna-synthetic/use-prev-timestamp",
help="The checkpoint directory.",
)
parser.add_argument(
"--env_version",
type=str,
required=True,
help="The synthetic enviornment version.",
)
parser.add_argument(
"--hidden_dim",
type=int,
required=True,
help="The hidden dimension.",
)
parser.add_argument(
"--init_lr",
type=float,
default=0.1,
help="The initial learning rate for the optimizer (default is Adam)",
)
parser.add_argument(
"--batch_size",
type=int,
default=512,
help="The batch size",
)
parser.add_argument(
"--epochs",
type=int,
default=300,
help="The total number of epochs.",
)
parser.add_argument(
"--workers",
type=int,
default=4,
help="The number of data loading workers (default: 4)",
)
# Random Seed
parser.add_argument("--rand_seed", type=int, default=-1, help="manual seed")
args = parser.parse_args()
if args.rand_seed is None or args.rand_seed < 0:
args.rand_seed = random.randint(1, 100000)
assert args.save_dir is not None, "The save dir argument can not be None"
args.save_dir = "{:}-{:}-d{:}".format(
args.save_dir, args.env_version, args.hidden_dim
)
main(args)

View File

@ -1,6 +1,7 @@
##################################################### #####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 # # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 #
##################################################### #####################################################
# python exps/LFNA/lfna.py --env_version v1 --workers 0
# python exps/LFNA/lfna.py --env_version v1 --device cuda # python exps/LFNA/lfna.py --env_version v1 --device cuda
##################################################### #####################################################
import sys, time, copy, torch, random, argparse import sys, time, copy, torch, random, argparse
@ -156,19 +157,61 @@ def main(args):
per_epoch_time.update(time.time() - start_time) per_epoch_time.update(time.time() - start_time)
start_time = time.time() start_time = time.time()
# meta-training
meta_model.load_best()
eval_env = env_info["dynamic_env"]
w_container_per_epoch = dict() w_container_per_epoch = dict()
for idx in range(0, total_bar): for idx in range(args.seq_length, env_info["total"]):
# build-timestamp
future_time = env_info["{:}-timestamp".format(idx)] future_time = env_info["{:}-timestamp".format(idx)]
future_x = env_info["{:}-x".format(idx)] time_seqs = []
future_y = env_info["{:}-y".format(idx)] for iseq in range(args.seq_length):
future_container = hypernet(task_embeds[idx]) time_seqs.append(future_time - iseq * eval_env.timestamp_interval)
w_container_per_epoch[idx] = future_container.no_grad_clone() time_seqs.reverse()
with torch.no_grad(): with torch.no_grad():
future_y_hat = model.forward_with_container( meta_model.eval()
base_model.eval()
time_seqs = torch.Tensor(time_seqs).view(1, -1).to(args.device)
[seq_containers] = meta_model(time_seqs)
future_container = seq_containers[-1]
w_container_per_epoch[idx] = future_container.no_grad_clone()
# evaluation
future_x = env_info["{:}-x".format(idx)]
future_y = env_info["{:}-y".format(idx)]
future_y_hat = base_model.forward_with_container(
future_x, w_container_per_epoch[idx] future_x, w_container_per_epoch[idx]
) )
future_loss = criterion(future_y_hat, future_y) future_loss = criterion(future_y_hat, future_y)
logger.log("meta-test: [{:03d}] -> loss={:.4f}".format(idx, future_loss.item())) logger.log(
"meta-test: [{:03d}] -> loss={:.4f}".format(idx, future_loss.item())
)
# creating the new meta-time-embedding
distance = meta_model.get_closest_meta_distance(future_time)
if distance < eval_env.timestamp_interval:
continue
#
new_param = meta_model.create_meta_embed()
optimizer = torch.optim.Adam(
[new_param], lr=args.init_lr, weight_decay=1e-5, amsgrad=True
)
meta_model.replace_append_learnt(torch.Tensor([future_time]), new_param)
meta_model.eval()
base_model.train()
for iepoch in range(args.epochs):
optimizer.zero_grad()
[seq_containers] = meta_model(time_seqs)
future_container = seq_containers[-1]
future_y_hat = base_model.forward_with_container(future_x, future_container)
future_loss = criterion(future_y_hat, future_y)
future_loss.backward()
optimizer.step()
logger.log(
"post-meta-test: [{:03d}] -> loss={:.4f}".format(idx, future_loss.item())
)
with torch.no_grad():
meta_model.replace_append_learnt(None, None)
meta_model.append_fixed(torch.Tensor([future_time]), new_param)
save_checkpoint( save_checkpoint(
{"w_container_per_epoch": w_container_per_epoch}, {"w_container_per_epoch": w_container_per_epoch},
@ -216,7 +259,7 @@ if __name__ == "__main__":
parser.add_argument( parser.add_argument(
"--init_lr", "--init_lr",
type=float, type=float,
default=0.01, default=0.005,
help="The initial learning rate for the optimizer (default is Adam)", help="The initial learning rate for the optimizer (default is Adam)",
) )
parser.add_argument( parser.add_argument(
@ -235,7 +278,7 @@ if __name__ == "__main__":
parser.add_argument( parser.add_argument(
"--early_stop_thresh", "--early_stop_thresh",
type=int, type=int,
default=50, default=25,
help="The maximum epochs for early stop.", help="The maximum epochs for early stop.",
) )
parser.add_argument( parser.add_argument(
@ -256,7 +299,12 @@ if __name__ == "__main__":
if args.rand_seed is None or args.rand_seed < 0: if args.rand_seed is None or args.rand_seed < 0:
args.rand_seed = random.randint(1, 100000) args.rand_seed = random.randint(1, 100000)
assert args.save_dir is not None, "The save dir argument can not be None" assert args.save_dir is not None, "The save dir argument can not be None"
args.save_dir = "{:}-{:}-d{:}_{:}_{:}".format( args.save_dir = "{:}-{:}-d{:}_{:}_{:}-e{:}".format(
args.save_dir, args.env_version, args.hidden_dim, args.layer_dim, args.time_dim args.save_dir,
args.env_version,
args.hidden_dim,
args.layer_dim,
args.time_dim,
args.epochs,
) )
main(args) main(args)

View File

@ -17,7 +17,7 @@ class LFNA_Meta(super_core.SuperModule):
def __init__( def __init__(
self, self,
shape_container, shape_container,
layer_embeding, layer_embedding,
time_embedding, time_embedding,
meta_timestamps, meta_timestamps,
mha_depth: int = 2, mha_depth: int = 2,
@ -33,13 +33,16 @@ class LFNA_Meta(super_core.SuperModule):
self.register_parameter( self.register_parameter(
"_super_layer_embed", "_super_layer_embed",
torch.nn.Parameter(torch.Tensor(self._num_layers, layer_embeding)), torch.nn.Parameter(torch.Tensor(self._num_layers, layer_embedding)),
) )
self.register_parameter( self.register_parameter(
"_super_meta_embed", "_super_meta_embed",
torch.nn.Parameter(torch.Tensor(len(meta_timestamps), time_embedding)), torch.nn.Parameter(torch.Tensor(len(meta_timestamps), time_embedding)),
) )
self.register_buffer("_meta_timestamps", torch.Tensor(meta_timestamps)) self.register_buffer("_meta_timestamps", torch.Tensor(meta_timestamps))
self._time_embed_dim = time_embedding
self._append_meta_embed = dict(fixed=None, learnt=None)
self._append_meta_timestamps = dict(fixed=None, learnt=None)
# build transformer # build transformer
layers = [] layers = []
@ -60,9 +63,9 @@ class LFNA_Meta(super_core.SuperModule):
model_kwargs = dict( model_kwargs = dict(
config=dict(model_type="dual_norm_mlp"), config=dict(model_type="dual_norm_mlp"),
input_dim=layer_embeding + time_embedding, input_dim=layer_embedding + time_embedding,
output_dim=max(self._numel_per_layer), output_dim=max(self._numel_per_layer),
hidden_dims=[(layer_embeding + time_embedding) * 2] * 3, hidden_dims=[(layer_embedding + time_embedding) * 2] * 3,
act_cls="gelu", act_cls="gelu",
norm_cls="layer_norm_1d", norm_cls="layer_norm_1d",
dropout=dropout, dropout=dropout,
@ -82,21 +85,68 @@ class LFNA_Meta(super_core.SuperModule):
std=0.02, std=0.02,
) )
@property
def meta_timestamps(self):
meta_timestamps = [self._meta_timestamps]
for key in ("fixed", "learnt"):
if self._append_meta_timestamps[key] is not None:
meta_timestamps.append(self._append_meta_timestamps[key])
return torch.cat(meta_timestamps)
@property
def super_meta_embed(self):
meta_embed = [self._super_meta_embed]
for key in ("fixed", "learnt"):
if self._append_meta_embed[key] is not None:
meta_embed.append(self._append_meta_embed[key])
return torch.cat(meta_embed)
def create_meta_embed(self):
param = torch.nn.Parameter(torch.Tensor(1, self._time_embed_dim))
trunc_normal_(param, std=0.02)
return param.to(self._super_meta_embed.device)
def get_closest_meta_distance(self, timestamp):
with torch.no_grad():
distances = torch.abs(self.meta_timestamps - timestamp)
return torch.min(distances).item()
def replace_append_learnt(self, timestamp, meta_embed):
self._append_meta_embed["learnt"] = meta_embed
self._append_meta_timestamps["learnt"] = timestamp
def append_fixed(self, timestamp, meta_embed):
with torch.no_grad():
timestamp, meta_embed = timestamp.clone(), meta_embed.clone()
if self._append_meta_timestamps["fixed"] is None:
self._append_meta_timestamps["fixed"] = timestamp
else:
self._append_meta_timestamps["fixed"] = torch.cat(
(self._append_meta_timestamps["fixed"], timestamp), dim=0
)
if self._append_meta_embed["fixed"] is None:
self._append_meta_embed["fixed"] = meta_embed
else:
self._append_meta_embed["fixed"] = torch.cat(
(self._append_meta_embed["fixed"], meta_embed), dim=0
)
def forward_raw(self, timestamps): def forward_raw(self, timestamps):
# timestamps is a batch of sequence of timestamps # timestamps is a batch of sequence of timestamps
batch, seq = timestamps.shape batch, seq = timestamps.shape
timestamps = timestamps.unsqueeze(dim=-1) timestamps = timestamps.unsqueeze(dim=-1)
meta_timestamps = self._meta_timestamps.view(1, 1, -1) meta_timestamps = self.meta_timestamps.view(1, 1, -1)
time_diffs = timestamps - meta_timestamps time_diffs = timestamps - meta_timestamps
time_match_v, time_match_i = torch.min(torch.abs(time_diffs), dim=-1) time_match_v, time_match_i = torch.min(torch.abs(time_diffs), dim=-1)
# select corresponding meta-knowledge # select corresponding meta-knowledge
meta_match = torch.index_select( meta_match = torch.index_select(
self._super_meta_embed, dim=0, index=time_match_i.view(-1) self.super_meta_embed, dim=0, index=time_match_i.view(-1)
) )
meta_match = meta_match.view(batch, seq, -1) meta_match = meta_match.view(batch, seq, -1)
# create the probability # create the probability
time_probs = (1 / torch.exp(time_match_v * 10)).view(batch, seq, 1) time_probs = (1 / torch.exp(time_match_v * 10)).view(batch, seq, 1)
time_probs[:, -1, :] = 0 if self.training:
time_probs[:, -1, :] = 0
unknown_token = self._unknown_token.view(1, 1, -1) unknown_token = self._unknown_token.view(1, 1, -1)
raw_meta_embed = time_probs * meta_match + (1 - time_probs) * unknown_token raw_meta_embed = time_probs * meta_match + (1 - time_probs) * unknown_token

View File

@ -43,6 +43,7 @@ class SyntheticDEnv(data.Dataset):
num_per_task: int = 5000, num_per_task: int = 5000,
timestamp_config: Optional[Dict] = None, timestamp_config: Optional[Dict] = None,
mode: Optional[str] = None, mode: Optional[str] = None,
timestamp_noise_scale: float = 0.3,
): ):
self._ndim = len(mean_functors) self._ndim = len(mean_functors)
assert self._ndim == len( assert self._ndim == len(
@ -59,6 +60,7 @@ class SyntheticDEnv(data.Dataset):
timestamp_config["mode"] = mode timestamp_config["mode"] = mode
self._timestamp_generator = TimeStamp(**timestamp_config) self._timestamp_generator = TimeStamp(**timestamp_config)
self._timestamp_noise_scale = timestamp_noise_scale
self._mean_functors = mean_functors self._mean_functors = mean_functors
self._cov_functors = cov_functors self._cov_functors = cov_functors
@ -110,7 +112,9 @@ class SyntheticDEnv(data.Dataset):
if self._seq_length is None: if self._seq_length is None:
return self.__call__(timestamp) return self.__call__(timestamp)
else: else:
noise = random.random() * self.timestamp_interval * 0.3 noise = (
random.random() * self.timestamp_interval * self._timestamp_noise_scale
)
timestamps = [ timestamps = [
timestamp + i * self.timestamp_interval + noise timestamp + i * self.timestamp_interval + noise
for i in range(self._seq_length) for i in range(self._seq_length)