Update LFNA version 1.0
This commit is contained in:
parent
80aaac4dfa
commit
34560ad8d1
@ -86,9 +86,10 @@ def main(args):
|
||||
input_dim=1,
|
||||
output_dim=1,
|
||||
act_cls="leaky_relu",
|
||||
norm_cls="simple_norm",
|
||||
mean=mean,
|
||||
std=std,
|
||||
norm_cls="identity",
|
||||
# norm_cls="simple_norm",
|
||||
# mean=mean,
|
||||
# std=std,
|
||||
)
|
||||
model = get_model(dict(model_type="simple_mlp"), **model_kwargs)
|
||||
# build optimizer
|
||||
|
@ -58,6 +58,8 @@ def main(args):
|
||||
)
|
||||
)
|
||||
|
||||
w_container_per_epoch = dict()
|
||||
|
||||
per_timestamp_time, start_time = AverageMeter(), time.time()
|
||||
for i, idx in enumerate(to_evaluate_indexes):
|
||||
|
||||
@ -73,7 +75,6 @@ def main(args):
|
||||
+ need_time
|
||||
)
|
||||
# train the same data
|
||||
assert idx != 0
|
||||
historical_x = env_info["{:}-x".format(idx)]
|
||||
historical_y = env_info["{:}-y".format(idx)]
|
||||
# build model
|
||||
@ -82,9 +83,10 @@ def main(args):
|
||||
input_dim=1,
|
||||
output_dim=1,
|
||||
act_cls="leaky_relu",
|
||||
norm_cls="simple_norm",
|
||||
mean=mean,
|
||||
std=std,
|
||||
norm_cls="identity",
|
||||
# norm_cls="simple_norm",
|
||||
# mean=mean,
|
||||
# std=std,
|
||||
)
|
||||
model = get_model(dict(model_type="simple_mlp"), **model_kwargs)
|
||||
# build optimizer
|
||||
@ -137,6 +139,7 @@ def main(args):
|
||||
save_path = logger.path(None) / "{:04d}-{:04d}.pth".format(
|
||||
idx, env_info["total"]
|
||||
)
|
||||
w_container_per_epoch[idx] = model.get_w_container().no_grad_clone()
|
||||
save_checkpoint(
|
||||
{
|
||||
"model_state_dict": model.state_dict(),
|
||||
@ -151,6 +154,11 @@ def main(args):
|
||||
|
||||
per_timestamp_time.update(time.time() - start_time)
|
||||
start_time = time.time()
|
||||
save_checkpoint(
|
||||
{"w_container_per_epoch": w_container_per_epoch},
|
||||
logger.path(None) / "final-ckp.pth",
|
||||
logger,
|
||||
)
|
||||
|
||||
logger.log("-" * 200 + "\n")
|
||||
logger.close()
|
||||
|
@ -39,9 +39,11 @@ class LFNAmlp:
|
||||
self.delta_net.parameters(), lr=0.01, amsgrad=True
|
||||
)
|
||||
|
||||
def adapt(self, model, criterion, w_container, xs, ys):
|
||||
def adapt(self, model, criterion, w_container, seq_datasets):
|
||||
w_container.requires_grad_(True)
|
||||
containers = [w_container]
|
||||
for idx, (x, y) in enumerate(zip(xs, ys)):
|
||||
for idx, dataset in enumerate(seq_datasets):
|
||||
x, y = dataset.x, dataset.y
|
||||
y_hat = model.forward_with_container(x, containers[-1])
|
||||
loss = criterion(y_hat, y)
|
||||
gradients = torch.autograd.grad(loss, containers[-1].tensors)
|
||||
@ -52,21 +54,30 @@ class LFNAmlp:
|
||||
input_statistics = input_statistics.expand(flatten_w.numel(), -1)
|
||||
delta_inputs = torch.cat((flatten_w, flatten_g, input_statistics), dim=-1)
|
||||
delta = self.delta_net(delta_inputs).view(-1)
|
||||
# delta = torch.clamp(delta, -0.5, 0.5)
|
||||
delta = torch.clamp(delta, -0.5, 0.5)
|
||||
unflatten_delta = containers[-1].unflatten(delta)
|
||||
future_container = containers[-1].additive(unflatten_delta)
|
||||
future_container = containers[-1].no_grad_clone().additive(unflatten_delta)
|
||||
# future_container = containers[-1].additive(unflatten_delta)
|
||||
containers.append(future_container)
|
||||
# containers = containers[1:]
|
||||
meta_loss = []
|
||||
for idx, (x, y) in enumerate(zip(xs, ys)):
|
||||
temp_containers = []
|
||||
for idx, dataset in enumerate(seq_datasets):
|
||||
if idx == 0:
|
||||
continue
|
||||
current_container = containers[idx]
|
||||
y_hat = model.forward_with_container(x, current_container)
|
||||
loss = criterion(y_hat, y)
|
||||
y_hat = model.forward_with_container(dataset.x, current_container)
|
||||
loss = criterion(y_hat, dataset.y)
|
||||
meta_loss.append(loss)
|
||||
temp_containers.append((dataset.timestamp, current_container, -loss.item()))
|
||||
meta_loss = sum(meta_loss)
|
||||
meta_loss.backward()
|
||||
w_container.requires_grad_(False)
|
||||
# meta_loss.backward()
|
||||
# self.meta_optimizer.step()
|
||||
return meta_loss, temp_containers
|
||||
|
||||
def step(self):
|
||||
torch.nn.utils.clip_grad_norm_(self.delta_net.parameters(), 1.0)
|
||||
self.meta_optimizer.step()
|
||||
|
||||
def zero_grad(self):
|
||||
@ -74,6 +85,25 @@ class LFNAmlp:
|
||||
self.delta_net.zero_grad()
|
||||
|
||||
|
||||
class TimeData:
|
||||
def __init__(self, timestamp, xs, ys):
|
||||
self._timestamp = timestamp
|
||||
self._xs = xs
|
||||
self._ys = ys
|
||||
|
||||
@property
|
||||
def x(self):
|
||||
return self._xs
|
||||
|
||||
@property
|
||||
def y(self):
|
||||
return self._ys
|
||||
|
||||
@property
|
||||
def timestamp(self):
|
||||
return self._timestamp
|
||||
|
||||
|
||||
class Population:
|
||||
"""A population used to maintain models at different timestamps."""
|
||||
|
||||
@ -83,20 +113,29 @@ class Population:
|
||||
|
||||
def append(self, timestamp, model, score):
|
||||
if timestamp in self._time2model:
|
||||
raise ValueError("This timestamp has been added.")
|
||||
self._time2model[timestamp] = model
|
||||
if self._time2score[timestamp] > score:
|
||||
return
|
||||
self._time2model[timestamp] = model.no_grad_clone()
|
||||
self._time2score[timestamp] = score
|
||||
|
||||
def query(self, timestamp):
|
||||
closet_timestamp = None
|
||||
for xtime, model in self._time2model.items():
|
||||
if (
|
||||
closet_timestamp is None
|
||||
or timestamp - closet_timestamp >= timestamp - xtime
|
||||
if closet_timestamp is None or (
|
||||
xtime < timestamp and timestamp - closet_timestamp >= timestamp - xtime
|
||||
):
|
||||
closet_timestamp = xtime
|
||||
return self._time2model[closet_timestamp], closet_timestamp
|
||||
|
||||
def debug_info(self, timestamps):
|
||||
xstrs = []
|
||||
for timestamp in timestamps:
|
||||
if timestamp in self._time2score:
|
||||
xstrs.append(
|
||||
"{:04d}: {:.4f}".format(timestamp, self._time2score[timestamp])
|
||||
)
|
||||
return ", ".join(xstrs)
|
||||
|
||||
|
||||
def main(args):
|
||||
prepare_seed(args.rand_seed)
|
||||
@ -125,21 +164,19 @@ def main(args):
|
||||
base_model = get_model(
|
||||
dict(model_type="simple_mlp"),
|
||||
act_cls="leaky_relu",
|
||||
norm_cls="simple_learn_norm",
|
||||
mean=0,
|
||||
std=1,
|
||||
norm_cls="identity",
|
||||
input_dim=1,
|
||||
output_dim=1,
|
||||
)
|
||||
|
||||
w_container = base_model.named_parameters_buffers()
|
||||
w_container = base_model.get_w_container()
|
||||
criterion = torch.nn.MSELoss()
|
||||
print("There are {:} weights.".format(w_container.numel()))
|
||||
|
||||
adaptor = LFNAmlp(4, (50, 20), "leaky_relu")
|
||||
|
||||
pool = Population()
|
||||
pool.append(0, w_container)
|
||||
pool.append(0, w_container, -100)
|
||||
|
||||
# LFNA meta-training
|
||||
per_epoch_time, start_time = AverageMeter(), time.time()
|
||||
@ -153,22 +190,35 @@ def main(args):
|
||||
+ need_time
|
||||
)
|
||||
|
||||
adaptor.zero_grad()
|
||||
|
||||
debug_timestamp = set()
|
||||
all_meta_losses = []
|
||||
for ibatch in range(args.meta_batch):
|
||||
sampled_timestamp = random.randint(0, train_time_bar)
|
||||
query_w_container, query_timestamp = pool.query(sampled_timestamp)
|
||||
# def adapt(self, model, w_container, xs, ys):
|
||||
xs, ys = [], []
|
||||
seq_datasets = []
|
||||
# xs, ys = [], []
|
||||
for it in range(sampled_timestamp, sampled_timestamp + args.max_seq):
|
||||
xs.append(env_info["{:}-x".format(it)])
|
||||
ys.append(env_info["{:}-y".format(it)])
|
||||
adaptor.adapt(base_model, criterion, query_w_container, xs, ys)
|
||||
import pdb
|
||||
xs = env_info["{:}-x".format(it)]
|
||||
ys = env_info["{:}-y".format(it)]
|
||||
seq_datasets.append(TimeData(it, xs, ys))
|
||||
temp_meta_loss, temp_containers = adaptor.adapt(
|
||||
base_model, criterion, query_w_container, seq_datasets
|
||||
)
|
||||
all_meta_losses.append(temp_meta_loss)
|
||||
for temp_time, temp_container, temp_score in temp_containers:
|
||||
pool.append(temp_time, temp_container, temp_score)
|
||||
debug_timestamp.add(temp_time)
|
||||
meta_loss = torch.stack(all_meta_losses).mean()
|
||||
meta_loss.backward()
|
||||
adaptor.step()
|
||||
|
||||
pdb.set_trace()
|
||||
print("-")
|
||||
logger.log("")
|
||||
debug_str = pool.debug_info(debug_timestamp)
|
||||
logger.log("meta-loss: {:.4f}".format(meta_loss.item()))
|
||||
|
||||
per_timestamp_time.update(time.time() - start_time)
|
||||
per_epoch_time.update(time.time() - start_time)
|
||||
start_time = time.time()
|
||||
|
||||
logger.log("-" * 200 + "\n")
|
||||
@ -192,7 +242,7 @@ if __name__ == "__main__":
|
||||
parser.add_argument(
|
||||
"--meta_batch",
|
||||
type=int,
|
||||
default=2,
|
||||
default=5,
|
||||
help="The batch size for the meta-model",
|
||||
)
|
||||
parser.add_argument(
|
||||
|
@ -23,7 +23,7 @@ lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve()
|
||||
if str(lib_dir) not in sys.path:
|
||||
sys.path.insert(0, str(lib_dir))
|
||||
|
||||
|
||||
from models.xcore import get_model
|
||||
from datasets.synthetic_core import get_synthetic_env
|
||||
from datasets.synthetic_example import create_example_v1
|
||||
from utils.temp_sync import optimize_fn, evaluate_fn
|
||||
@ -300,8 +300,20 @@ def compare_algs_v2(save_dir, alg_dir="./outputs/lfna-synthetic"):
|
||||
|
||||
alg_name2dir = OrderedDict()
|
||||
alg_name2dir["Optimal"] = "use-same-timestamp"
|
||||
alg_name2dir["Supervised Learning (History Data)"] = "use-all-past-data"
|
||||
colors = ["r", "g"]
|
||||
# alg_name2dir["Supervised Learning (History Data)"] = "use-all-past-data"
|
||||
alg_name2all_containers = OrderedDict()
|
||||
for idx_alg, (alg, xdir) in enumerate(alg_name2dir.items()):
|
||||
ckp_path = Path(alg_dir) / xdir / "final-ckp.pth"
|
||||
xdata = torch.load(ckp_path)
|
||||
alg_name2all_containers[alg] = xdata["w_container_per_epoch"]
|
||||
# load the basic model
|
||||
model = get_model(
|
||||
dict(model_type="simple_mlp"),
|
||||
act_cls="leaky_relu",
|
||||
norm_cls="identity",
|
||||
input_dim=1,
|
||||
output_dim=1,
|
||||
)
|
||||
|
||||
alg2xs, alg2ys = defaultdict(list), defaultdict(list)
|
||||
colors = ["r", "g"]
|
||||
@ -323,6 +335,7 @@ def compare_algs_v2(save_dir, alg_dir="./outputs/lfna-synthetic"):
|
||||
plot_scatter(cur_ax, allx, ally, "k", 0.99, linewidths, "Raw Data")
|
||||
|
||||
for idx_alg, (alg, xdir) in enumerate(alg_name2dir.items()):
|
||||
"""
|
||||
ckp_path = (
|
||||
Path(alg_dir)
|
||||
/ xdir
|
||||
@ -330,8 +343,12 @@ def compare_algs_v2(save_dir, alg_dir="./outputs/lfna-synthetic"):
|
||||
)
|
||||
assert ckp_path.exists()
|
||||
ckp_data = torch.load(ckp_path)
|
||||
"""
|
||||
with torch.no_grad():
|
||||
predicts = ckp_data["model"](ori_allx)
|
||||
# predicts = ckp_data["model"](ori_allx)
|
||||
predicts = model.forward_with_container(
|
||||
ori_allx, alg_name2all_containers[alg][idx]
|
||||
)
|
||||
predicts = predicts.cpu()
|
||||
# keep data
|
||||
metric = MSEMetric()
|
||||
|
@ -55,6 +55,10 @@ class TensorContainer:
|
||||
)
|
||||
return result
|
||||
|
||||
def requires_grad_(self, requires_grad=True):
|
||||
for tensor in self._tensors:
|
||||
tensor.requires_grad_(requires_grad)
|
||||
|
||||
@property
|
||||
def tensors(self):
|
||||
return self._tensors
|
||||
@ -162,7 +166,7 @@ class SuperModule(abc.ABC, nn.Module):
|
||||
)
|
||||
self._abstract_child = abstract_child
|
||||
|
||||
def named_parameters_buffers(self):
|
||||
def get_w_container(self):
|
||||
container = TensorContainer()
|
||||
for name, param in self.named_parameters():
|
||||
container.append(name, param, True)
|
||||
|
Loading…
Reference in New Issue
Block a user