Update LFNA version 1.0
This commit is contained in:
parent
80aaac4dfa
commit
34560ad8d1
@ -86,9 +86,10 @@ def main(args):
|
|||||||
input_dim=1,
|
input_dim=1,
|
||||||
output_dim=1,
|
output_dim=1,
|
||||||
act_cls="leaky_relu",
|
act_cls="leaky_relu",
|
||||||
norm_cls="simple_norm",
|
norm_cls="identity",
|
||||||
mean=mean,
|
# norm_cls="simple_norm",
|
||||||
std=std,
|
# mean=mean,
|
||||||
|
# std=std,
|
||||||
)
|
)
|
||||||
model = get_model(dict(model_type="simple_mlp"), **model_kwargs)
|
model = get_model(dict(model_type="simple_mlp"), **model_kwargs)
|
||||||
# build optimizer
|
# build optimizer
|
||||||
|
@ -58,6 +58,8 @@ def main(args):
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
w_container_per_epoch = dict()
|
||||||
|
|
||||||
per_timestamp_time, start_time = AverageMeter(), time.time()
|
per_timestamp_time, start_time = AverageMeter(), time.time()
|
||||||
for i, idx in enumerate(to_evaluate_indexes):
|
for i, idx in enumerate(to_evaluate_indexes):
|
||||||
|
|
||||||
@ -73,7 +75,6 @@ def main(args):
|
|||||||
+ need_time
|
+ need_time
|
||||||
)
|
)
|
||||||
# train the same data
|
# train the same data
|
||||||
assert idx != 0
|
|
||||||
historical_x = env_info["{:}-x".format(idx)]
|
historical_x = env_info["{:}-x".format(idx)]
|
||||||
historical_y = env_info["{:}-y".format(idx)]
|
historical_y = env_info["{:}-y".format(idx)]
|
||||||
# build model
|
# build model
|
||||||
@ -82,9 +83,10 @@ def main(args):
|
|||||||
input_dim=1,
|
input_dim=1,
|
||||||
output_dim=1,
|
output_dim=1,
|
||||||
act_cls="leaky_relu",
|
act_cls="leaky_relu",
|
||||||
norm_cls="simple_norm",
|
norm_cls="identity",
|
||||||
mean=mean,
|
# norm_cls="simple_norm",
|
||||||
std=std,
|
# mean=mean,
|
||||||
|
# std=std,
|
||||||
)
|
)
|
||||||
model = get_model(dict(model_type="simple_mlp"), **model_kwargs)
|
model = get_model(dict(model_type="simple_mlp"), **model_kwargs)
|
||||||
# build optimizer
|
# build optimizer
|
||||||
@ -137,6 +139,7 @@ def main(args):
|
|||||||
save_path = logger.path(None) / "{:04d}-{:04d}.pth".format(
|
save_path = logger.path(None) / "{:04d}-{:04d}.pth".format(
|
||||||
idx, env_info["total"]
|
idx, env_info["total"]
|
||||||
)
|
)
|
||||||
|
w_container_per_epoch[idx] = model.get_w_container().no_grad_clone()
|
||||||
save_checkpoint(
|
save_checkpoint(
|
||||||
{
|
{
|
||||||
"model_state_dict": model.state_dict(),
|
"model_state_dict": model.state_dict(),
|
||||||
@ -151,6 +154,11 @@ def main(args):
|
|||||||
|
|
||||||
per_timestamp_time.update(time.time() - start_time)
|
per_timestamp_time.update(time.time() - start_time)
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
save_checkpoint(
|
||||||
|
{"w_container_per_epoch": w_container_per_epoch},
|
||||||
|
logger.path(None) / "final-ckp.pth",
|
||||||
|
logger,
|
||||||
|
)
|
||||||
|
|
||||||
logger.log("-" * 200 + "\n")
|
logger.log("-" * 200 + "\n")
|
||||||
logger.close()
|
logger.close()
|
||||||
|
@ -39,9 +39,11 @@ class LFNAmlp:
|
|||||||
self.delta_net.parameters(), lr=0.01, amsgrad=True
|
self.delta_net.parameters(), lr=0.01, amsgrad=True
|
||||||
)
|
)
|
||||||
|
|
||||||
def adapt(self, model, criterion, w_container, xs, ys):
|
def adapt(self, model, criterion, w_container, seq_datasets):
|
||||||
|
w_container.requires_grad_(True)
|
||||||
containers = [w_container]
|
containers = [w_container]
|
||||||
for idx, (x, y) in enumerate(zip(xs, ys)):
|
for idx, dataset in enumerate(seq_datasets):
|
||||||
|
x, y = dataset.x, dataset.y
|
||||||
y_hat = model.forward_with_container(x, containers[-1])
|
y_hat = model.forward_with_container(x, containers[-1])
|
||||||
loss = criterion(y_hat, y)
|
loss = criterion(y_hat, y)
|
||||||
gradients = torch.autograd.grad(loss, containers[-1].tensors)
|
gradients = torch.autograd.grad(loss, containers[-1].tensors)
|
||||||
@ -52,21 +54,30 @@ class LFNAmlp:
|
|||||||
input_statistics = input_statistics.expand(flatten_w.numel(), -1)
|
input_statistics = input_statistics.expand(flatten_w.numel(), -1)
|
||||||
delta_inputs = torch.cat((flatten_w, flatten_g, input_statistics), dim=-1)
|
delta_inputs = torch.cat((flatten_w, flatten_g, input_statistics), dim=-1)
|
||||||
delta = self.delta_net(delta_inputs).view(-1)
|
delta = self.delta_net(delta_inputs).view(-1)
|
||||||
# delta = torch.clamp(delta, -0.5, 0.5)
|
delta = torch.clamp(delta, -0.5, 0.5)
|
||||||
unflatten_delta = containers[-1].unflatten(delta)
|
unflatten_delta = containers[-1].unflatten(delta)
|
||||||
future_container = containers[-1].additive(unflatten_delta)
|
future_container = containers[-1].no_grad_clone().additive(unflatten_delta)
|
||||||
|
# future_container = containers[-1].additive(unflatten_delta)
|
||||||
containers.append(future_container)
|
containers.append(future_container)
|
||||||
# containers = containers[1:]
|
# containers = containers[1:]
|
||||||
meta_loss = []
|
meta_loss = []
|
||||||
for idx, (x, y) in enumerate(zip(xs, ys)):
|
temp_containers = []
|
||||||
|
for idx, dataset in enumerate(seq_datasets):
|
||||||
if idx == 0:
|
if idx == 0:
|
||||||
continue
|
continue
|
||||||
current_container = containers[idx]
|
current_container = containers[idx]
|
||||||
y_hat = model.forward_with_container(x, current_container)
|
y_hat = model.forward_with_container(dataset.x, current_container)
|
||||||
loss = criterion(y_hat, y)
|
loss = criterion(y_hat, dataset.y)
|
||||||
meta_loss.append(loss)
|
meta_loss.append(loss)
|
||||||
|
temp_containers.append((dataset.timestamp, current_container, -loss.item()))
|
||||||
meta_loss = sum(meta_loss)
|
meta_loss = sum(meta_loss)
|
||||||
meta_loss.backward()
|
w_container.requires_grad_(False)
|
||||||
|
# meta_loss.backward()
|
||||||
|
# self.meta_optimizer.step()
|
||||||
|
return meta_loss, temp_containers
|
||||||
|
|
||||||
|
def step(self):
|
||||||
|
torch.nn.utils.clip_grad_norm_(self.delta_net.parameters(), 1.0)
|
||||||
self.meta_optimizer.step()
|
self.meta_optimizer.step()
|
||||||
|
|
||||||
def zero_grad(self):
|
def zero_grad(self):
|
||||||
@ -74,6 +85,25 @@ class LFNAmlp:
|
|||||||
self.delta_net.zero_grad()
|
self.delta_net.zero_grad()
|
||||||
|
|
||||||
|
|
||||||
|
class TimeData:
|
||||||
|
def __init__(self, timestamp, xs, ys):
|
||||||
|
self._timestamp = timestamp
|
||||||
|
self._xs = xs
|
||||||
|
self._ys = ys
|
||||||
|
|
||||||
|
@property
|
||||||
|
def x(self):
|
||||||
|
return self._xs
|
||||||
|
|
||||||
|
@property
|
||||||
|
def y(self):
|
||||||
|
return self._ys
|
||||||
|
|
||||||
|
@property
|
||||||
|
def timestamp(self):
|
||||||
|
return self._timestamp
|
||||||
|
|
||||||
|
|
||||||
class Population:
|
class Population:
|
||||||
"""A population used to maintain models at different timestamps."""
|
"""A population used to maintain models at different timestamps."""
|
||||||
|
|
||||||
@ -83,20 +113,29 @@ class Population:
|
|||||||
|
|
||||||
def append(self, timestamp, model, score):
|
def append(self, timestamp, model, score):
|
||||||
if timestamp in self._time2model:
|
if timestamp in self._time2model:
|
||||||
raise ValueError("This timestamp has been added.")
|
if self._time2score[timestamp] > score:
|
||||||
self._time2model[timestamp] = model
|
return
|
||||||
|
self._time2model[timestamp] = model.no_grad_clone()
|
||||||
self._time2score[timestamp] = score
|
self._time2score[timestamp] = score
|
||||||
|
|
||||||
def query(self, timestamp):
|
def query(self, timestamp):
|
||||||
closet_timestamp = None
|
closet_timestamp = None
|
||||||
for xtime, model in self._time2model.items():
|
for xtime, model in self._time2model.items():
|
||||||
if (
|
if closet_timestamp is None or (
|
||||||
closet_timestamp is None
|
xtime < timestamp and timestamp - closet_timestamp >= timestamp - xtime
|
||||||
or timestamp - closet_timestamp >= timestamp - xtime
|
|
||||||
):
|
):
|
||||||
closet_timestamp = xtime
|
closet_timestamp = xtime
|
||||||
return self._time2model[closet_timestamp], closet_timestamp
|
return self._time2model[closet_timestamp], closet_timestamp
|
||||||
|
|
||||||
|
def debug_info(self, timestamps):
|
||||||
|
xstrs = []
|
||||||
|
for timestamp in timestamps:
|
||||||
|
if timestamp in self._time2score:
|
||||||
|
xstrs.append(
|
||||||
|
"{:04d}: {:.4f}".format(timestamp, self._time2score[timestamp])
|
||||||
|
)
|
||||||
|
return ", ".join(xstrs)
|
||||||
|
|
||||||
|
|
||||||
def main(args):
|
def main(args):
|
||||||
prepare_seed(args.rand_seed)
|
prepare_seed(args.rand_seed)
|
||||||
@ -125,21 +164,19 @@ def main(args):
|
|||||||
base_model = get_model(
|
base_model = get_model(
|
||||||
dict(model_type="simple_mlp"),
|
dict(model_type="simple_mlp"),
|
||||||
act_cls="leaky_relu",
|
act_cls="leaky_relu",
|
||||||
norm_cls="simple_learn_norm",
|
norm_cls="identity",
|
||||||
mean=0,
|
|
||||||
std=1,
|
|
||||||
input_dim=1,
|
input_dim=1,
|
||||||
output_dim=1,
|
output_dim=1,
|
||||||
)
|
)
|
||||||
|
|
||||||
w_container = base_model.named_parameters_buffers()
|
w_container = base_model.get_w_container()
|
||||||
criterion = torch.nn.MSELoss()
|
criterion = torch.nn.MSELoss()
|
||||||
print("There are {:} weights.".format(w_container.numel()))
|
print("There are {:} weights.".format(w_container.numel()))
|
||||||
|
|
||||||
adaptor = LFNAmlp(4, (50, 20), "leaky_relu")
|
adaptor = LFNAmlp(4, (50, 20), "leaky_relu")
|
||||||
|
|
||||||
pool = Population()
|
pool = Population()
|
||||||
pool.append(0, w_container)
|
pool.append(0, w_container, -100)
|
||||||
|
|
||||||
# LFNA meta-training
|
# LFNA meta-training
|
||||||
per_epoch_time, start_time = AverageMeter(), time.time()
|
per_epoch_time, start_time = AverageMeter(), time.time()
|
||||||
@ -153,22 +190,35 @@ def main(args):
|
|||||||
+ need_time
|
+ need_time
|
||||||
)
|
)
|
||||||
|
|
||||||
|
adaptor.zero_grad()
|
||||||
|
|
||||||
|
debug_timestamp = set()
|
||||||
|
all_meta_losses = []
|
||||||
for ibatch in range(args.meta_batch):
|
for ibatch in range(args.meta_batch):
|
||||||
sampled_timestamp = random.randint(0, train_time_bar)
|
sampled_timestamp = random.randint(0, train_time_bar)
|
||||||
query_w_container, query_timestamp = pool.query(sampled_timestamp)
|
query_w_container, query_timestamp = pool.query(sampled_timestamp)
|
||||||
# def adapt(self, model, w_container, xs, ys):
|
# def adapt(self, model, w_container, xs, ys):
|
||||||
xs, ys = [], []
|
seq_datasets = []
|
||||||
|
# xs, ys = [], []
|
||||||
for it in range(sampled_timestamp, sampled_timestamp + args.max_seq):
|
for it in range(sampled_timestamp, sampled_timestamp + args.max_seq):
|
||||||
xs.append(env_info["{:}-x".format(it)])
|
xs = env_info["{:}-x".format(it)]
|
||||||
ys.append(env_info["{:}-y".format(it)])
|
ys = env_info["{:}-y".format(it)]
|
||||||
adaptor.adapt(base_model, criterion, query_w_container, xs, ys)
|
seq_datasets.append(TimeData(it, xs, ys))
|
||||||
import pdb
|
temp_meta_loss, temp_containers = adaptor.adapt(
|
||||||
|
base_model, criterion, query_w_container, seq_datasets
|
||||||
|
)
|
||||||
|
all_meta_losses.append(temp_meta_loss)
|
||||||
|
for temp_time, temp_container, temp_score in temp_containers:
|
||||||
|
pool.append(temp_time, temp_container, temp_score)
|
||||||
|
debug_timestamp.add(temp_time)
|
||||||
|
meta_loss = torch.stack(all_meta_losses).mean()
|
||||||
|
meta_loss.backward()
|
||||||
|
adaptor.step()
|
||||||
|
|
||||||
pdb.set_trace()
|
debug_str = pool.debug_info(debug_timestamp)
|
||||||
print("-")
|
logger.log("meta-loss: {:.4f}".format(meta_loss.item()))
|
||||||
logger.log("")
|
|
||||||
|
|
||||||
per_timestamp_time.update(time.time() - start_time)
|
per_epoch_time.update(time.time() - start_time)
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
|
||||||
logger.log("-" * 200 + "\n")
|
logger.log("-" * 200 + "\n")
|
||||||
@ -192,7 +242,7 @@ if __name__ == "__main__":
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--meta_batch",
|
"--meta_batch",
|
||||||
type=int,
|
type=int,
|
||||||
default=2,
|
default=5,
|
||||||
help="The batch size for the meta-model",
|
help="The batch size for the meta-model",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
|
@ -23,7 +23,7 @@ lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve()
|
|||||||
if str(lib_dir) not in sys.path:
|
if str(lib_dir) not in sys.path:
|
||||||
sys.path.insert(0, str(lib_dir))
|
sys.path.insert(0, str(lib_dir))
|
||||||
|
|
||||||
|
from models.xcore import get_model
|
||||||
from datasets.synthetic_core import get_synthetic_env
|
from datasets.synthetic_core import get_synthetic_env
|
||||||
from datasets.synthetic_example import create_example_v1
|
from datasets.synthetic_example import create_example_v1
|
||||||
from utils.temp_sync import optimize_fn, evaluate_fn
|
from utils.temp_sync import optimize_fn, evaluate_fn
|
||||||
@ -300,8 +300,20 @@ def compare_algs_v2(save_dir, alg_dir="./outputs/lfna-synthetic"):
|
|||||||
|
|
||||||
alg_name2dir = OrderedDict()
|
alg_name2dir = OrderedDict()
|
||||||
alg_name2dir["Optimal"] = "use-same-timestamp"
|
alg_name2dir["Optimal"] = "use-same-timestamp"
|
||||||
alg_name2dir["Supervised Learning (History Data)"] = "use-all-past-data"
|
# alg_name2dir["Supervised Learning (History Data)"] = "use-all-past-data"
|
||||||
colors = ["r", "g"]
|
alg_name2all_containers = OrderedDict()
|
||||||
|
for idx_alg, (alg, xdir) in enumerate(alg_name2dir.items()):
|
||||||
|
ckp_path = Path(alg_dir) / xdir / "final-ckp.pth"
|
||||||
|
xdata = torch.load(ckp_path)
|
||||||
|
alg_name2all_containers[alg] = xdata["w_container_per_epoch"]
|
||||||
|
# load the basic model
|
||||||
|
model = get_model(
|
||||||
|
dict(model_type="simple_mlp"),
|
||||||
|
act_cls="leaky_relu",
|
||||||
|
norm_cls="identity",
|
||||||
|
input_dim=1,
|
||||||
|
output_dim=1,
|
||||||
|
)
|
||||||
|
|
||||||
alg2xs, alg2ys = defaultdict(list), defaultdict(list)
|
alg2xs, alg2ys = defaultdict(list), defaultdict(list)
|
||||||
colors = ["r", "g"]
|
colors = ["r", "g"]
|
||||||
@ -323,6 +335,7 @@ def compare_algs_v2(save_dir, alg_dir="./outputs/lfna-synthetic"):
|
|||||||
plot_scatter(cur_ax, allx, ally, "k", 0.99, linewidths, "Raw Data")
|
plot_scatter(cur_ax, allx, ally, "k", 0.99, linewidths, "Raw Data")
|
||||||
|
|
||||||
for idx_alg, (alg, xdir) in enumerate(alg_name2dir.items()):
|
for idx_alg, (alg, xdir) in enumerate(alg_name2dir.items()):
|
||||||
|
"""
|
||||||
ckp_path = (
|
ckp_path = (
|
||||||
Path(alg_dir)
|
Path(alg_dir)
|
||||||
/ xdir
|
/ xdir
|
||||||
@ -330,8 +343,12 @@ def compare_algs_v2(save_dir, alg_dir="./outputs/lfna-synthetic"):
|
|||||||
)
|
)
|
||||||
assert ckp_path.exists()
|
assert ckp_path.exists()
|
||||||
ckp_data = torch.load(ckp_path)
|
ckp_data = torch.load(ckp_path)
|
||||||
|
"""
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
predicts = ckp_data["model"](ori_allx)
|
# predicts = ckp_data["model"](ori_allx)
|
||||||
|
predicts = model.forward_with_container(
|
||||||
|
ori_allx, alg_name2all_containers[alg][idx]
|
||||||
|
)
|
||||||
predicts = predicts.cpu()
|
predicts = predicts.cpu()
|
||||||
# keep data
|
# keep data
|
||||||
metric = MSEMetric()
|
metric = MSEMetric()
|
||||||
|
@ -55,6 +55,10 @@ class TensorContainer:
|
|||||||
)
|
)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
def requires_grad_(self, requires_grad=True):
|
||||||
|
for tensor in self._tensors:
|
||||||
|
tensor.requires_grad_(requires_grad)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def tensors(self):
|
def tensors(self):
|
||||||
return self._tensors
|
return self._tensors
|
||||||
@ -162,7 +166,7 @@ class SuperModule(abc.ABC, nn.Module):
|
|||||||
)
|
)
|
||||||
self._abstract_child = abstract_child
|
self._abstract_child = abstract_child
|
||||||
|
|
||||||
def named_parameters_buffers(self):
|
def get_w_container(self):
|
||||||
container = TensorContainer()
|
container = TensorContainer()
|
||||||
for name, param in self.named_parameters():
|
for name, param in self.named_parameters():
|
||||||
container.append(name, param, True)
|
container.append(name, param, True)
|
||||||
|
Loading…
Reference in New Issue
Block a user