diff --git a/exps/LFNA/basic-his.py b/exps/LFNA/basic-his.py
index 5ba3d68..857d1da 100644
--- a/exps/LFNA/basic-his.py
+++ b/exps/LFNA/basic-his.py
@@ -86,9 +86,10 @@ def main(args):
             input_dim=1,
             output_dim=1,
             act_cls="leaky_relu",
-            norm_cls="simple_norm",
-            mean=mean,
-            std=std,
+            norm_cls="identity",
+            # norm_cls="simple_norm",
+            # mean=mean,
+            # std=std,
         )
         model = get_model(dict(model_type="simple_mlp"), **model_kwargs)
         # build optimizer
diff --git a/exps/LFNA/basic-same.py b/exps/LFNA/basic-same.py
index 4bcb702..a1bb87b 100644
--- a/exps/LFNA/basic-same.py
+++ b/exps/LFNA/basic-same.py
@@ -58,6 +58,8 @@ def main(args):
         )
     )
 
+    w_container_per_epoch = dict()
+
     per_timestamp_time, start_time = AverageMeter(), time.time()
     for i, idx in enumerate(to_evaluate_indexes):
 
@@ -73,7 +75,6 @@ def main(args):
             + need_time
         )
         # train the same data
-        assert idx != 0
         historical_x = env_info["{:}-x".format(idx)]
         historical_y = env_info["{:}-y".format(idx)]
         # build model
@@ -82,9 +83,10 @@ def main(args):
             input_dim=1,
             output_dim=1,
             act_cls="leaky_relu",
-            norm_cls="simple_norm",
-            mean=mean,
-            std=std,
+            norm_cls="identity",
+            # norm_cls="simple_norm",
+            # mean=mean,
+            # std=std,
         )
         model = get_model(dict(model_type="simple_mlp"), **model_kwargs)
         # build optimizer
@@ -137,6 +139,7 @@ def main(args):
         save_path = logger.path(None) / "{:04d}-{:04d}.pth".format(
             idx, env_info["total"]
         )
+        w_container_per_epoch[idx] = model.get_w_container().no_grad_clone()
         save_checkpoint(
             {
                 "model_state_dict": model.state_dict(),
@@ -151,6 +154,11 @@ def main(args):
         per_timestamp_time.update(time.time() - start_time)
         start_time = time.time()
 
+    save_checkpoint(
+        {"w_container_per_epoch": w_container_per_epoch},
+        logger.path(None) / "final-ckp.pth",
+        logger,
+    )
     logger.log("-" * 200 + "\n")
     logger.close()
 
diff --git a/exps/LFNA/lfna-v1.py b/exps/LFNA/lfna-v1.py
index c604d59..60dd62a 100644
--- a/exps/LFNA/lfna-v1.py
+++ b/exps/LFNA/lfna-v1.py
@@ -39,9 +39,11 @@ class LFNAmlp:
             self.delta_net.parameters(), lr=0.01, amsgrad=True
         )
 
-    def adapt(self, model, criterion, w_container, xs, ys):
+    def adapt(self, model, criterion, w_container, seq_datasets):
+        w_container.requires_grad_(True)
         containers = [w_container]
-        for idx, (x, y) in enumerate(zip(xs, ys)):
+        for idx, dataset in enumerate(seq_datasets):
+            x, y = dataset.x, dataset.y
             y_hat = model.forward_with_container(x, containers[-1])
             loss = criterion(y_hat, y)
             gradients = torch.autograd.grad(loss, containers[-1].tensors)
@@ -52,21 +54,30 @@ class LFNAmlp:
             input_statistics = input_statistics.expand(flatten_w.numel(), -1)
             delta_inputs = torch.cat((flatten_w, flatten_g, input_statistics), dim=-1)
             delta = self.delta_net(delta_inputs).view(-1)
-            # delta = torch.clamp(delta, -0.5, 0.5)
+            delta = torch.clamp(delta, -0.5, 0.5)
             unflatten_delta = containers[-1].unflatten(delta)
-            future_container = containers[-1].additive(unflatten_delta)
+            future_container = containers[-1].no_grad_clone().additive(unflatten_delta)
+            # future_container = containers[-1].additive(unflatten_delta)
             containers.append(future_container)
         # containers = containers[1:]
         meta_loss = []
-        for idx, (x, y) in enumerate(zip(xs, ys)):
+        temp_containers = []
+        for idx, dataset in enumerate(seq_datasets):
             if idx == 0:
                 continue
             current_container = containers[idx]
-            y_hat = model.forward_with_container(x, current_container)
-            loss = criterion(y_hat, y)
+            y_hat = model.forward_with_container(dataset.x, current_container)
+            loss = criterion(y_hat, dataset.y)
             meta_loss.append(loss)
+            temp_containers.append((dataset.timestamp, current_container, -loss.item()))
         meta_loss = sum(meta_loss)
-        meta_loss.backward()
+        w_container.requires_grad_(False)
+        # meta_loss.backward()
+        # self.meta_optimizer.step()
+        return meta_loss, temp_containers
+
+    def step(self):
+        torch.nn.utils.clip_grad_norm_(self.delta_net.parameters(), 1.0)
         self.meta_optimizer.step()
 
     def zero_grad(self):
@@ -74,6 +85,25 @@ class LFNAmlp:
         self.delta_net.zero_grad()
 
 
+class TimeData:
+    def __init__(self, timestamp, xs, ys):
+        self._timestamp = timestamp
+        self._xs = xs
+        self._ys = ys
+
+    @property
+    def x(self):
+        return self._xs
+
+    @property
+    def y(self):
+        return self._ys
+
+    @property
+    def timestamp(self):
+        return self._timestamp
+
+
 class Population:
     """A population used to maintain models at different timestamps."""
 
@@ -83,20 +113,29 @@ class Population:
 
     def append(self, timestamp, model, score):
         if timestamp in self._time2model:
-            raise ValueError("This timestamp has been added.")
-        self._time2model[timestamp] = model
+            if self._time2score[timestamp] > score:
+                return
+        self._time2model[timestamp] = model.no_grad_clone()
         self._time2score[timestamp] = score
 
     def query(self, timestamp):
         closet_timestamp = None
         for xtime, model in self._time2model.items():
-            if (
-                closet_timestamp is None
-                or timestamp - closet_timestamp >= timestamp - xtime
+            if closet_timestamp is None or (
+                xtime < timestamp and timestamp - closet_timestamp >= timestamp - xtime
             ):
                 closet_timestamp = xtime
         return self._time2model[closet_timestamp], closet_timestamp
 
+    def debug_info(self, timestamps):
+        xstrs = []
+        for timestamp in timestamps:
+            if timestamp in self._time2score:
+                xstrs.append(
+                    "{:04d}: {:.4f}".format(timestamp, self._time2score[timestamp])
+                )
+        return ", ".join(xstrs)
+
 
 def main(args):
     prepare_seed(args.rand_seed)
@@ -125,21 +164,19 @@ def main(args):
     base_model = get_model(
         dict(model_type="simple_mlp"),
         act_cls="leaky_relu",
-        norm_cls="simple_learn_norm",
-        mean=0,
-        std=1,
+        norm_cls="identity",
         input_dim=1,
         output_dim=1,
     )
 
-    w_container = base_model.named_parameters_buffers()
+    w_container = base_model.get_w_container()
     criterion = torch.nn.MSELoss()
     print("There are {:} weights.".format(w_container.numel()))
 
     adaptor = LFNAmlp(4, (50, 20), "leaky_relu")
 
     pool = Population()
-    pool.append(0, w_container)
+    pool.append(0, w_container, -100)
 
     # LFNA meta-training
     per_epoch_time, start_time = AverageMeter(), time.time()
@@ -153,22 +190,35 @@ def main(args):
             + need_time
         )
 
+        adaptor.zero_grad()
+
+        debug_timestamp = set()
+        all_meta_losses = []
         for ibatch in range(args.meta_batch):
             sampled_timestamp = random.randint(0, train_time_bar)
             query_w_container, query_timestamp = pool.query(sampled_timestamp)
             # def adapt(self, model, w_container, xs, ys):
-            xs, ys = [], []
+            seq_datasets = []
+            # xs, ys = [], []
             for it in range(sampled_timestamp, sampled_timestamp + args.max_seq):
-                xs.append(env_info["{:}-x".format(it)])
-                ys.append(env_info["{:}-y".format(it)])
-            adaptor.adapt(base_model, criterion, query_w_container, xs, ys)
-            import pdb
+                xs = env_info["{:}-x".format(it)]
+                ys = env_info["{:}-y".format(it)]
+                seq_datasets.append(TimeData(it, xs, ys))
+            temp_meta_loss, temp_containers = adaptor.adapt(
+                base_model, criterion, query_w_container, seq_datasets
+            )
+            all_meta_losses.append(temp_meta_loss)
+            for temp_time, temp_container, temp_score in temp_containers:
+                pool.append(temp_time, temp_container, temp_score)
+                debug_timestamp.add(temp_time)
+        meta_loss = torch.stack(all_meta_losses).mean()
+        meta_loss.backward()
+        adaptor.step()
 
-            pdb.set_trace()
-            print("-")
-        logger.log("")
+        debug_str = pool.debug_info(debug_timestamp)
+        logger.log("meta-loss: {:.4f}".format(meta_loss.item()))
 
-        per_timestamp_time.update(time.time() - start_time)
+        per_epoch_time.update(time.time() - start_time)
         start_time = time.time()
 
     logger.log("-" * 200 + "\n")
@@ -192,7 +242,7 @@ if __name__ == "__main__":
     parser.add_argument(
         "--meta_batch",
         type=int,
-        default=2,
+        default=5,
         help="The batch size for the meta-model",
     )
     parser.add_argument(
diff --git a/exps/LFNA/vis-synthetic.py b/exps/LFNA/vis-synthetic.py
index 357d7b1..210d639 100644
--- a/exps/LFNA/vis-synthetic.py
+++ b/exps/LFNA/vis-synthetic.py
@@ -23,7 +23,7 @@
 lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve()
 if str(lib_dir) not in sys.path:
     sys.path.insert(0, str(lib_dir))
-
+from models.xcore import get_model
 from datasets.synthetic_core import get_synthetic_env
 from datasets.synthetic_example import create_example_v1
 from utils.temp_sync import optimize_fn, evaluate_fn
@@ -300,8 +300,20 @@ def compare_algs_v2(save_dir, alg_dir="./outputs/lfna-synthetic"):
 
     alg_name2dir = OrderedDict()
     alg_name2dir["Optimal"] = "use-same-timestamp"
-    alg_name2dir["Supervised Learning (History Data)"] = "use-all-past-data"
-    colors = ["r", "g"]
+    # alg_name2dir["Supervised Learning (History Data)"] = "use-all-past-data"
+    alg_name2all_containers = OrderedDict()
+    for idx_alg, (alg, xdir) in enumerate(alg_name2dir.items()):
+        ckp_path = Path(alg_dir) / xdir / "final-ckp.pth"
+        xdata = torch.load(ckp_path)
+        alg_name2all_containers[alg] = xdata["w_container_per_epoch"]
+    # load the basic model
+    model = get_model(
+        dict(model_type="simple_mlp"),
+        act_cls="leaky_relu",
+        norm_cls="identity",
+        input_dim=1,
+        output_dim=1,
+    )
 
     alg2xs, alg2ys = defaultdict(list), defaultdict(list)
     colors = ["r", "g"]
@@ -323,6 +335,7 @@ def compare_algs_v2(save_dir, alg_dir="./outputs/lfna-synthetic"):
         plot_scatter(cur_ax, allx, ally, "k", 0.99, linewidths, "Raw Data")
 
         for idx_alg, (alg, xdir) in enumerate(alg_name2dir.items()):
+            """
             ckp_path = (
                 Path(alg_dir)
                 / xdir
@@ -330,8 +343,12 @@ def compare_algs_v2(save_dir, alg_dir="./outputs/lfna-synthetic"):
             )
             assert ckp_path.exists()
             ckp_data = torch.load(ckp_path)
+            """
             with torch.no_grad():
-                predicts = ckp_data["model"](ori_allx)
+                # predicts = ckp_data["model"](ori_allx)
+                predicts = model.forward_with_container(
+                    ori_allx, alg_name2all_containers[alg][idx]
+                )
                 predicts = predicts.cpu()
             # keep data
             metric = MSEMetric()
diff --git a/lib/xlayers/super_module.py b/lib/xlayers/super_module.py
index d9fde80..8ee9ad9 100644
--- a/lib/xlayers/super_module.py
+++ b/lib/xlayers/super_module.py
@@ -55,6 +55,10 @@ class TensorContainer:
         )
         return result
 
+    def requires_grad_(self, requires_grad=True):
+        for tensor in self._tensors:
+            tensor.requires_grad_(requires_grad)
+
     @property
     def tensors(self):
         return self._tensors
@@ -162,7 +166,7 @@ class SuperModule(abc.ABC, nn.Module):
         )
         self._abstract_child = abstract_child
 
-    def named_parameters_buffers(self):
+    def get_w_container(self):
         container = TensorContainer()
         for name, param in self.named_parameters():
             container.append(name, param, True)