From 0dbbc286c9f56136291590136fffd513af881c36 Mon Sep 17 00:00:00 2001 From: D-X-Y <280835372@qq.com> Date: Mon, 10 May 2021 14:14:06 +0800 Subject: [PATCH] Update DEBUG INFO --- exps/LFNA/lfna-debug.py | 256 ++++++++++++++++++++++++++++++ exps/LFNA/lfna-fix-init.py | 239 ++++++++++++++++++++++++++++ exps/LFNA/lfna-v0.py | 272 -------------------------------- exps/LFNA/lfna_utils.py | 21 ++- exps/LFNA/vis-synthetic.py | 4 +- lib/datasets/synthetic_env.py | 7 + lib/datasets/synthetic_utils.py | 4 + lib/xlayers/super_module.py | 7 + 8 files changed, 536 insertions(+), 274 deletions(-) create mode 100644 exps/LFNA/lfna-debug.py create mode 100644 exps/LFNA/lfna-fix-init.py delete mode 100644 exps/LFNA/lfna-v0.py diff --git a/exps/LFNA/lfna-debug.py b/exps/LFNA/lfna-debug.py new file mode 100644 index 0000000..b5a3963 --- /dev/null +++ b/exps/LFNA/lfna-debug.py @@ -0,0 +1,256 @@ +##################################################### +# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 # +##################################################### +# python exps/LFNA/lfna-debug.py --env_version v1 --hidden_dim 16 +##################################################### +import sys, time, copy, torch, random, argparse +from tqdm import tqdm +from copy import deepcopy +from pathlib import Path + +lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve() +if str(lib_dir) not in sys.path: + sys.path.insert(0, str(lib_dir)) +from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint +from log_utils import time_string +from log_utils import AverageMeter, convert_secs2time + +from utils import split_str2indexes + +from procedures.advanced_main import basic_train_fn, basic_eval_fn +from procedures.metric_utils import SaveMetric, MSEMetric, ComposeMetric +from datasets.synthetic_core import get_synthetic_env +from models.xcore import get_model +from xlayers import super_core + + +from lfna_utils import lfna_setup, train_model, TimeData + + +class LFNAmlp: + """A LFNA meta-model that uses the MLP as delta-net.""" + + def __init__(self, obs_dim, hidden_sizes, act_name, criterion): + self.delta_net = super_core.SuperSequential( + super_core.SuperLinear(obs_dim, hidden_sizes[0]), + super_core.super_name2activation[act_name](), + super_core.SuperLinear(hidden_sizes[0], hidden_sizes[1]), + super_core.super_name2activation[act_name](), + super_core.SuperLinear(hidden_sizes[1], 1), + ) + self.meta_optimizer = torch.optim.Adam( + self.delta_net.parameters(), lr=0.01, amsgrad=True + ) + self.criterion = criterion + + def adapt(self, model, seq_flatten_w): + delta_inputs = torch.stack(seq_flatten_w, dim=-1) + delta = self.delta_net(delta_inputs) + container = model.get_w_container() + unflatten_delta = container.unflatten(delta) + future_container = container.create_container(unflatten_delta) + return future_container + + def step(self): + torch.nn.utils.clip_grad_norm_(self.delta_net.parameters(), 1.0) + self.meta_optimizer.step() + + def zero_grad(self): + self.meta_optimizer.zero_grad() + self.delta_net.zero_grad() + + def state_dict(self): + return dict( + delta_net=self.delta_net.state_dict(), + meta_optimizer=self.meta_optimizer.state_dict(), + ) + + +def main(args): + logger, env_info, model_kwargs = lfna_setup(args) + dynamic_env = env_info["dynamic_env"] + model = get_model(dict(model_type="simple_mlp"), **model_kwargs) + + total_time = env_info["total"] + for i in range(total_time): + for xkey in ("timestamp", "x", "y"): + nkey = "{:}-{:}".format(i, xkey) + assert nkey in 
env_info, "{:} not in {:}".format(nkey, list(env_info.keys()))
+    train_time_bar = total_time // 2
+    network = get_model(dict(model_type="simple_mlp"), **model_kwargs)
+
+    criterion = torch.nn.MSELoss()
+    logger.log("There are {:} weights.".format(network.get_w_container().numel()))
+
+    adaptor = LFNAmlp(args.meta_seq, (200, 200), "leaky_relu", criterion)
+
+    # pre-train the model
+    init_dataset = TimeData(0, env_info["0-x"], env_info["0-y"])
+    init_loss = train_model(network, init_dataset, args.init_lr, args.epochs)
+    logger.log("The pre-training loss is {:.4f}".format(init_loss))
+
+    all_past_containers = []
+    ground_truth_path = (
+        logger.path(None) / ".." / "use-same-timestamp-v1-d16" / "final-ckp.pth"
+    )
+    ground_truth_data = torch.load(ground_truth_path)
+    all_gt_containers = ground_truth_data["w_container_per_epoch"]
+    all_gt_flattens = dict()
+    for idx, container in all_gt_containers.items():
+        all_gt_flattens[idx] = container.no_grad_clone().flatten()
+
+    # LFNA meta-training
+    meta_loss_meter = AverageMeter()
+    per_epoch_time, start_time = AverageMeter(), time.time()
+    for iepoch in range(args.epochs):
+
+        need_time = "Time Left: {:}".format(
+            convert_secs2time(per_epoch_time.avg * (args.epochs - iepoch), True)
+        )
+        logger.log(
+            "[{:}] [{:04d}/{:04d}] ".format(time_string(), iepoch, args.epochs)
+            + need_time
+        )
+
+        adaptor.zero_grad()
+
+        meta_losses = []
+        for ibatch in range(args.meta_batch):
+            future_timestamp = random.randint(args.meta_seq, train_time_bar)
+            future_dataset = TimeData(
+                future_timestamp,
+                env_info["{:}-x".format(future_timestamp)],
+                env_info["{:}-y".format(future_timestamp)],
+            )
+            seq_datasets = []
+            for iseq in range(args.meta_seq):
+                cur_time = future_timestamp - iseq - 1
+                cur_x = env_info["{:}-x".format(cur_time)]
+                cur_y = env_info["{:}-y".format(cur_time)]
+                seq_datasets.append(TimeData(cur_time, cur_x, cur_y))
+            seq_datasets.reverse()
+            seq_flatten_w = [
+                all_gt_flattens[dataset.timestamp] for dataset in seq_datasets
+            ]
+            future_container = adaptor.adapt(network, seq_flatten_w)
+            """
+            future_y_hat = network.forward_with_container(
+                future_dataset.x, future_container
+            )
+            future_loss = adaptor.criterion(future_y_hat, future_dataset.y)
+            """
+            future_loss = adaptor.criterion(
+                future_container.flatten(), all_gt_flattens[future_timestamp]
+            )
+            # import pdb; pdb.set_trace()
+            meta_losses.append(future_loss)
+        meta_loss = torch.stack(meta_losses).mean()
+        meta_loss.backward()
+        adaptor.step()
+
+        meta_loss_meter.update(meta_loss.item())
+
+        logger.log(
+            "meta-loss: {:.4f} ({:.4f}) ".format(
+                meta_loss_meter.avg, meta_loss_meter.val
+            )
+        )
+        if iepoch % 200 == 0:
+            save_checkpoint(
+                {"adaptor": adaptor.state_dict(), "iepoch": iepoch},
+                logger.path("model"),
+                logger,
+            )
+        per_epoch_time.update(time.time() - start_time)
+        start_time = time.time()
+
+    w_container_per_epoch = dict()
+    # import pdb; pdb.set_trace()
+    for idx in range(1, env_info["total"]):
+        future_time = env_info["{:}-timestamp".format(idx)]
+        future_x = env_info["{:}-x".format(idx)]
+        future_y = env_info["{:}-y".format(idx)]
+        seq_datasets = []
+        for iseq in range(1, args.meta_seq + 1):
+            cur_time = idx - iseq
+            if cur_time < 0:
+                cur_time = 0
+            cur_x = env_info["{:}-x".format(cur_time)]
+            cur_y = env_info["{:}-y".format(cur_time)]
+            seq_datasets.append(TimeData(cur_time, cur_x, cur_y))
+        seq_datasets.reverse()
+        seq_flatten_w = [all_gt_flattens[dataset.timestamp] for dataset in seq_datasets]
+        future_container = adaptor.adapt(network, seq_flatten_w)
+        w_container_per_epoch[idx] = future_container.no_grad_clone()
+        with torch.no_grad():
+            future_y_hat = network.forward_with_container(
+                future_x, w_container_per_epoch[idx]
+            )
+            future_loss = adaptor.criterion(future_y_hat, future_y)
+        logger.log("meta-test: [{:03d}] -> loss={:.4f}".format(idx, future_loss.item()))
+
+    save_checkpoint(
+        {"w_container_per_epoch": w_container_per_epoch},
+        logger.path(None) / "final-ckp.pth",
+        logger,
+    )
+
+    logger.log("-" * 200 + "\n")
+    logger.close()
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser("Use the data in the past.")
+    parser.add_argument(
+        "--save_dir",
+        type=str,
+        default="./outputs/lfna-synthetic/lfna-debug",
+        help="The checkpoint directory.",
+    )
+    parser.add_argument(
+        "--env_version",
+        type=str,
+        required=True,
+        help="The synthetic environment version.",
+    )
+    parser.add_argument(
+        "--hidden_dim",
+        type=int,
+        required=True,
+        help="The hidden dimension.",
+    )
+    #####
+    parser.add_argument(
+        "--init_lr",
+        type=float,
+        default=0.1,
+        help="The initial learning rate for the optimizer (default is Adam)",
+    )
+    parser.add_argument(
+        "--meta_batch",
+        type=int,
+        default=32,
+        help="The batch size for the meta-model",
+    )
+    parser.add_argument(
+        "--meta_seq",
+        type=int,
+        default=10,
+        help="The length of the sequence for meta-model.",
+    )
+    parser.add_argument(
+        "--epochs",
+        type=int,
+        default=2000,
+        help="The total number of epochs.",
+    )
+    # Random Seed
+    parser.add_argument("--rand_seed", type=int, default=-1, help="manual seed")
+    args = parser.parse_args()
+    if args.rand_seed is None or args.rand_seed < 0:
+        args.rand_seed = random.randint(1, 100000)
+    assert args.save_dir is not None, "The save dir argument can not be None"
+    args.save_dir = "{:}-{:}-d{:}".format(
+        args.save_dir, args.env_version, args.hidden_dim
+    )
+    main(args)
diff --git a/exps/LFNA/lfna-fix-init.py b/exps/LFNA/lfna-fix-init.py
new file mode 100644
index 0000000..c3e8e7b
--- /dev/null
+++ b/exps/LFNA/lfna-fix-init.py
@@ -0,0 +1,239 @@
+#####################################################
+# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 #
+#####################################################
+# python exps/LFNA/lfna-fix-init.py --env_version v1 --hidden_dim 16
+#####################################################
+import sys, time, copy, torch, random, argparse
+from tqdm import tqdm
+from copy import deepcopy
+from pathlib import Path
+
+lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve()
/ "lib").resolve() +if str(lib_dir) not in sys.path: + sys.path.insert(0, str(lib_dir)) +from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint +from log_utils import time_string +from log_utils import AverageMeter, convert_secs2time + +from utils import split_str2indexes + +from procedures.advanced_main import basic_train_fn, basic_eval_fn +from procedures.metric_utils import SaveMetric, MSEMetric, ComposeMetric +from datasets.synthetic_core import get_synthetic_env +from models.xcore import get_model +from xlayers import super_core + + +from lfna_utils import lfna_setup, train_model, TimeData + + +class LFNAmlp: + """A LFNA meta-model that uses the MLP as delta-net.""" + + def __init__(self, obs_dim, hidden_sizes, act_name, criterion): + self.delta_net = super_core.SuperSequential( + super_core.SuperLinear(obs_dim, hidden_sizes[0]), + super_core.super_name2activation[act_name](), + super_core.SuperLinear(hidden_sizes[0], hidden_sizes[1]), + super_core.super_name2activation[act_name](), + super_core.SuperLinear(hidden_sizes[1], 1), + ) + self.meta_optimizer = torch.optim.Adam( + self.delta_net.parameters(), lr=0.001, amsgrad=True + ) + self.criterion = criterion + + def adapt(self, model, seq_datasets): + delta_inputs = [] + container = model.get_w_container() + for iseq, dataset in enumerate(seq_datasets): + y_hat = model.forward_with_container(dataset.x, container) + loss = self.criterion(y_hat, dataset.y) + gradients = torch.autograd.grad(loss, container.parameters()) + with torch.no_grad(): + flatten_g = container.flatten(gradients) + delta_inputs.append(flatten_g) + flatten_w = container.no_grad_clone().flatten() + delta_inputs.append(flatten_w) + delta_inputs = torch.stack(delta_inputs, dim=-1) + delta = self.delta_net(delta_inputs) + + delta = torch.clamp(delta, -0.8, 0.8) + unflatten_delta = container.unflatten(delta) + future_container = container.no_grad_clone().additive(unflatten_delta) + return future_container + + def step(self): + torch.nn.utils.clip_grad_norm_(self.delta_net.parameters(), 1.0) + self.meta_optimizer.step() + + def zero_grad(self): + self.meta_optimizer.zero_grad() + self.delta_net.zero_grad() + + def state_dict(self): + return dict( + delta_net=self.delta_net.state_dict(), + meta_optimizer=self.meta_optimizer.state_dict(), + ) + + +def main(args): + logger, env_info, model_kwargs = lfna_setup(args) + dynamic_env = env_info["dynamic_env"] + model = get_model(dict(model_type="simple_mlp"), **model_kwargs) + + total_time = env_info["total"] + for i in range(total_time): + for xkey in ("timestamp", "x", "y"): + nkey = "{:}-{:}".format(i, xkey) + assert nkey in env_info, "{:} no in {:}".format(nkey, list(env_info.keys())) + train_time_bar = total_time // 2 + network = get_model(dict(model_type="simple_mlp"), **model_kwargs) + + criterion = torch.nn.MSELoss() + logger.log("There are {:} weights.".format(network.get_w_container().numel())) + + adaptor = LFNAmlp(1 + args.meta_seq, (20, 20), "leaky_relu", criterion) + + # pre-train the model + init_dataset = TimeData(0, env_info["0-x"], env_info["0-y"]) + init_loss = train_model(network, init_dataset, args.init_lr, args.epochs) + logger.log("The pre-training loss is {:.4f}".format(init_loss)) + + # LFNA meta-training + meta_loss_meter = AverageMeter() + per_epoch_time, start_time = AverageMeter(), time.time() + for iepoch in range(args.epochs): + + need_time = "Time Left: {:}".format( + convert_secs2time(per_epoch_time.avg * (args.epochs - iepoch), True) + ) + logger.log( + "[{:}] 
[{:04d}/{:04d}] ".format(time_string(), iepoch, args.epochs) + + need_time + ) + + adaptor.zero_grad() + + batch_indexes, meta_losses = [], [] + for ibatch in range(args.meta_batch): + sampled_timestamp = random.random() * train_time_bar + batch_indexes.append("{:.3f}".format(sampled_timestamp)) + seq_datasets = [] + for iseq in range(args.meta_seq + 1): + cur_time = sampled_timestamp + iseq * dynamic_env.timestamp_interval + cur_time, (x, y) = dynamic_env(cur_time) + seq_datasets.append(TimeData(cur_time, x, y)) + history_datasets, future_dataset = seq_datasets[:-1], seq_datasets[-1] + future_container = adaptor.adapt(network, history_datasets) + future_y_hat = network.forward_with_container( + future_dataset.x, future_container + ) + future_loss = adaptor.criterion(future_y_hat, future_dataset.y) + meta_losses.append(future_loss) + meta_loss = torch.stack(meta_losses).mean() + meta_loss.backward() + adaptor.step() + + meta_loss_meter.update(meta_loss.item()) + + logger.log( + "meta-loss: {:.4f} ({:.4f}) batch: {:}".format( + meta_loss_meter.avg, meta_loss_meter.val, ",".join(batch_indexes[:5]) + ) + ) + if iepoch % 200 == 0: + save_checkpoint( + {"adaptor": adaptor.state_dict(), "iepoch": iepoch}, + logger.path("model"), + logger, + ) + per_epoch_time.update(time.time() - start_time) + start_time = time.time() + + w_container_per_epoch = dict() + for idx in range(1, env_info["total"]): + future_time = env_info["{:}-timestamp".format(idx)] + future_x = env_info["{:}-x".format(idx)] + future_y = env_info["{:}-y".format(idx)] + seq_datasets = [] + for iseq in range(1, args.meta_seq + 1): + cur_time = future_time - iseq * dynamic_env.timestamp_interval + cur_time, (x, y) = dynamic_env(cur_time) + seq_datasets.append(TimeData(cur_time, x, y)) + seq_datasets.reverse() + future_container = adaptor.adapt(network, seq_datasets) + w_container_per_epoch[idx] = future_container.no_grad_clone() + with torch.no_grad(): + future_y_hat = network.forward_with_container( + future_x, w_container_per_epoch[idx] + ) + future_loss = adaptor.criterion(future_y_hat, future_y) + logger.log("meta-test: [{:03d}] -> loss={:.4f}".format(idx, future_loss.item())) + + save_checkpoint( + {"w_container_per_epoch": w_container_per_epoch}, + logger.path(None) / "final-ckp.pth", + logger, + ) + + logger.log("-" * 200 + "\n") + logger.close() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser("Use the data in the past.") + parser.add_argument( + "--save_dir", + type=str, + default="./outputs/lfna-synthetic/lfna-fix-init", + help="The checkpoint directory.", + ) + parser.add_argument( + "--env_version", + type=str, + required=True, + help="The synthetic enviornment version.", + ) + parser.add_argument( + "--hidden_dim", + type=int, + required=True, + help="The hidden dimension.", + ) + ##### + parser.add_argument( + "--init_lr", + type=float, + default=0.1, + help="The initial learning rate for the optimizer (default is Adam)", + ) + parser.add_argument( + "--meta_batch", + type=int, + default=32, + help="The batch size for the meta-model", + ) + parser.add_argument( + "--meta_seq", + type=int, + default=10, + help="The length of the sequence for meta-model.", + ) + parser.add_argument( + "--epochs", + type=int, + default=1000, + help="The total number of epochs.", + ) + # Random Seed + parser.add_argument("--rand_seed", type=int, default=-1, help="manual seed") + args = parser.parse_args() + if args.rand_seed is None or args.rand_seed < 0: + args.rand_seed = random.randint(1, 100000) + assert args.save_dir 
is not None, "The save dir argument can not be None" + args.save_dir = "{:}-{:}-d{:}".format( + args.save_dir, args.env_version, args.hidden_dim + ) + main(args) diff --git a/exps/LFNA/lfna-v0.py b/exps/LFNA/lfna-v0.py deleted file mode 100644 index e3f937b..0000000 --- a/exps/LFNA/lfna-v0.py +++ /dev/null @@ -1,272 +0,0 @@ -##################################################### -# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 # -##################################################### -# python exps/LFNA/lfna-v0.py -##################################################### -import sys, time, copy, torch, random, argparse -from tqdm import tqdm -from copy import deepcopy -from pathlib import Path - -lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve() -if str(lib_dir) not in sys.path: - sys.path.insert(0, str(lib_dir)) -from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint -from log_utils import time_string -from log_utils import AverageMeter, convert_secs2time - -from utils import split_str2indexes - -from procedures.advanced_main import basic_train_fn, basic_eval_fn -from procedures.metric_utils import SaveMetric, MSEMetric, ComposeMetric -from datasets.synthetic_core import get_synthetic_env -from models.xcore import get_model -from xlayers import super_core - - -class LFNAmlp: - """A LFNA meta-model that uses the MLP as delta-net.""" - - def __init__(self, obs_dim, hidden_sizes, act_name): - self.delta_net = super_core.SuperSequential( - super_core.SuperLinear(obs_dim, hidden_sizes[0]), - super_core.super_name2activation[act_name](), - super_core.SuperLinear(hidden_sizes[0], hidden_sizes[1]), - super_core.super_name2activation[act_name](), - super_core.SuperLinear(hidden_sizes[1], 1), - ) - self.meta_optimizer = torch.optim.Adam( - self.delta_net.parameters(), lr=0.01, amsgrad=True - ) - - def adapt(self, model, criterion, w_container, seq_datasets): - w_container.requires_grad_(True) - containers = [w_container] - for idx, dataset in enumerate(seq_datasets): - x, y = dataset.x, dataset.y - y_hat = model.forward_with_container(x, containers[-1]) - loss = criterion(y_hat, y) - gradients = torch.autograd.grad(loss, containers[-1].tensors) - with torch.no_grad(): - flatten_w = containers[-1].flatten().view(-1, 1) - flatten_g = containers[-1].flatten(gradients).view(-1, 1) - input_statistics = torch.tensor([x.mean(), x.std()]).view(1, 2) - input_statistics = input_statistics.expand(flatten_w.numel(), -1) - delta_inputs = torch.cat((flatten_w, flatten_g, input_statistics), dim=-1) - delta = self.delta_net(delta_inputs).view(-1) - delta = torch.clamp(delta, -0.5, 0.5) - unflatten_delta = containers[-1].unflatten(delta) - future_container = containers[-1].no_grad_clone().additive(unflatten_delta) - # future_container = containers[-1].additive(unflatten_delta) - containers.append(future_container) - # containers = containers[1:] - meta_loss = [] - temp_containers = [] - for idx, dataset in enumerate(seq_datasets): - if idx == 0: - continue - current_container = containers[idx] - y_hat = model.forward_with_container(dataset.x, current_container) - loss = criterion(y_hat, dataset.y) - meta_loss.append(loss) - temp_containers.append((dataset.timestamp, current_container, -loss.item())) - meta_loss = sum(meta_loss) - w_container.requires_grad_(False) - # meta_loss.backward() - # self.meta_optimizer.step() - return meta_loss, temp_containers - - def step(self): - torch.nn.utils.clip_grad_norm_(self.delta_net.parameters(), 1.0) - self.meta_optimizer.step() - 
- def zero_grad(self): - self.meta_optimizer.zero_grad() - self.delta_net.zero_grad() - - -class TimeData: - def __init__(self, timestamp, xs, ys): - self._timestamp = timestamp - self._xs = xs - self._ys = ys - - @property - def x(self): - return self._xs - - @property - def y(self): - return self._ys - - @property - def timestamp(self): - return self._timestamp - - -class Population: - """A population used to maintain models at different timestamps.""" - - def __init__(self): - self._time2model = dict() - self._time2score = dict() # higher is better - - def append(self, timestamp, model, score): - if timestamp in self._time2model: - if self._time2score[timestamp] > score: - return - self._time2model[timestamp] = model.no_grad_clone() - self._time2score[timestamp] = score - - def query(self, timestamp): - closet_timestamp = None - for xtime, model in self._time2model.items(): - if closet_timestamp is None or ( - xtime < timestamp and timestamp - closet_timestamp >= timestamp - xtime - ): - closet_timestamp = xtime - return self._time2model[closet_timestamp], closet_timestamp - - def debug_info(self, timestamps): - xstrs = [] - for timestamp in timestamps: - if timestamp in self._time2score: - xstrs.append( - "{:04d}: {:.4f}".format(timestamp, self._time2score[timestamp]) - ) - return ", ".join(xstrs) - - -def main(args): - prepare_seed(args.rand_seed) - logger = prepare_logger(args) - - cache_path = (logger.path(None) / ".." / "env-info.pth").resolve() - if cache_path.exists(): - env_info = torch.load(cache_path) - else: - env_info = dict() - dynamic_env = get_synthetic_env() - env_info["total"] = len(dynamic_env) - for idx, (timestamp, (_allx, _ally)) in enumerate(tqdm(dynamic_env)): - env_info["{:}-timestamp".format(idx)] = timestamp - env_info["{:}-x".format(idx)] = _allx - env_info["{:}-y".format(idx)] = _ally - env_info["dynamic_env"] = dynamic_env - torch.save(env_info, cache_path) - - total_time = env_info["total"] - for i in range(total_time): - for xkey in ("timestamp", "x", "y"): - nkey = "{:}-{:}".format(i, xkey) - assert nkey in env_info, "{:} no in {:}".format(nkey, list(env_info.keys())) - train_time_bar = total_time // 2 - base_model = get_model( - dict(model_type="simple_mlp"), - act_cls="leaky_relu", - norm_cls="identity", - input_dim=1, - output_dim=1, - ) - - w_container = base_model.get_w_container() - criterion = torch.nn.MSELoss() - print("There are {:} weights.".format(w_container.numel())) - - adaptor = LFNAmlp(4, (50, 20), "leaky_relu") - - pool = Population() - pool.append(0, w_container, -100) - - # LFNA meta-training - per_epoch_time, start_time = AverageMeter(), time.time() - for iepoch in range(args.epochs): - - need_time = "Time Left: {:}".format( - convert_secs2time(per_epoch_time.avg * (args.epochs - iepoch), True) - ) - logger.log( - "[{:}] [{:04d}/{:04d}] ".format(time_string(), iepoch, args.epochs) - + need_time - ) - - adaptor.zero_grad() - - debug_timestamp = set() - all_meta_losses = [] - for ibatch in range(args.meta_batch): - sampled_timestamp = random.randint(0, train_time_bar) - query_w_container, query_timestamp = pool.query(sampled_timestamp) - # def adapt(self, model, w_container, xs, ys): - seq_datasets = [] - # xs, ys = [], [] - for it in range(sampled_timestamp, sampled_timestamp + args.max_seq): - xs = env_info["{:}-x".format(it)] - ys = env_info["{:}-y".format(it)] - seq_datasets.append(TimeData(it, xs, ys)) - temp_meta_loss, temp_containers = adaptor.adapt( - base_model, criterion, query_w_container, seq_datasets - ) - 
all_meta_losses.append(temp_meta_loss) - for temp_time, temp_container, temp_score in temp_containers: - pool.append(temp_time, temp_container, temp_score) - debug_timestamp.add(temp_time) - meta_loss = torch.stack(all_meta_losses).mean() - meta_loss.backward() - adaptor.step() - - debug_str = pool.debug_info(debug_timestamp) - logger.log("meta-loss: {:.4f}".format(meta_loss.item())) - - per_epoch_time.update(time.time() - start_time) - start_time = time.time() - - logger.log("-" * 200 + "\n") - logger.close() - - -if __name__ == "__main__": - parser = argparse.ArgumentParser("Use the data in the past.") - parser.add_argument( - "--save_dir", - type=str, - default="./outputs/lfna-synthetic/lfna-v1", - help="The checkpoint directory.", - ) - parser.add_argument( - "--init_lr", - type=float, - default=0.1, - help="The initial learning rate for the optimizer (default is Adam)", - ) - parser.add_argument( - "--meta_batch", - type=int, - default=5, - help="The batch size for the meta-model", - ) - parser.add_argument( - "--epochs", - type=int, - default=1000, - help="The total number of epochs.", - ) - parser.add_argument( - "--max_seq", - type=int, - default=5, - help="The maximum length of the sequence.", - ) - parser.add_argument( - "--workers", - type=int, - default=4, - help="The number of data loading workers (default: 4)", - ) - # Random Seed - parser.add_argument("--rand_seed", type=int, default=-1, help="manual seed") - args = parser.parse_args() - if args.rand_seed is None or args.rand_seed < 0: - args.rand_seed = random.randint(1, 100000) - assert args.save_dir is not None, "The save dir argument can not be None" - main(args) diff --git a/exps/LFNA/lfna_utils.py b/exps/LFNA/lfna_utils.py index a46854c..067b47b 100644 --- a/exps/LFNA/lfna_utils.py +++ b/exps/LFNA/lfna_utils.py @@ -1,6 +1,7 @@ ##################################################### # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 # ##################################################### +import copy import torch from tqdm import tqdm from procedures import prepare_seed, prepare_logger @@ -37,6 +38,24 @@ def lfna_setup(args): return logger, env_info, model_kwargs +def train_model(model, dataset, lr, epochs): + criterion = torch.nn.MSELoss() + optimizer = torch.optim.Adam(model.parameters(), lr=lr, amsgrad=True) + best_loss, best_param = None, None + for _iepoch in range(epochs): + preds = model(dataset.x) + optimizer.zero_grad() + loss = criterion(preds, dataset.y) + loss.backward() + optimizer.step() + # save best + if best_loss is None or best_loss > loss.item(): + best_loss = loss.item() + best_param = copy.deepcopy(model.state_dict()) + model.load_state_dict(best_param) + return best_loss + + class TimeData: def __init__(self, timestamp, xs, ys): self._timestamp = timestamp @@ -56,6 +75,6 @@ class TimeData: return self._timestamp def __repr__(self): - return "{name}(timestamp={:}, with {num} samples)".format( + return "{name}(timestamp={timestamp}, with {num} samples)".format( name=self.__class__.__name__, timestamp=self._timestamp, num=len(self._xs) ) diff --git a/exps/LFNA/vis-synthetic.py b/exps/LFNA/vis-synthetic.py index 432bb65..395d760 100644 --- a/exps/LFNA/vis-synthetic.py +++ b/exps/LFNA/vis-synthetic.py @@ -237,6 +237,8 @@ def compare_algs(save_dir, version, alg_dir="./outputs/lfna-synthetic"): alg_name2dir["Optimal"] = "use-same-timestamp" alg_name2dir["Supervised Learning (History Data)"] = "use-all-past-data" alg_name2dir["MAML"] = "use-maml-s1" + alg_name2dir["LFNA (fix init)"] = "lfna-fix-init" + 
alg_name2dir["LFNA (debug)"] = "lfna-debug" alg_name2all_containers = OrderedDict() if version == "v1": poststr = "v1-d16" @@ -256,7 +258,7 @@ def compare_algs(save_dir, version, alg_dir="./outputs/lfna-synthetic"): ) alg2xs, alg2ys = defaultdict(list), defaultdict(list) - colors = ["r", "g", "b"] + colors = ["r", "g", "b", "m", "y"] dynamic_env = env_info["dynamic_env"] min_t, max_t = dynamic_env.min_timestamp, dynamic_env.max_timestamp diff --git a/lib/datasets/synthetic_env.py b/lib/datasets/synthetic_env.py index 6d24036..6c2c6ed 100644 --- a/lib/datasets/synthetic_env.py +++ b/lib/datasets/synthetic_env.py @@ -51,6 +51,10 @@ class SyntheticDEnv(data.Dataset): def max_timestamp(self): return self._timestamp_generator.max_timestamp + @property + def timestamp_interval(self): + return self._timestamp_generator.interval + def set_oracle_map(self, functor): self._oracle_map = functor @@ -67,6 +71,9 @@ class SyntheticDEnv(data.Dataset): def __getitem__(self, index): assert 0 <= index < len(self), "{:} is not in [0, {:})".format(index, len(self)) index, timestamp = self._timestamp_generator[index] + return self.__call__(timestamp) + + def __call__(self, timestamp): mean_list = [functor(timestamp) for functor in self._mean_functors] cov_matrix = [ [abs(cov_gen(timestamp)) for cov_gen in cov_functor] diff --git a/lib/datasets/synthetic_utils.py b/lib/datasets/synthetic_utils.py index e187c9b..7c95d4b 100644 --- a/lib/datasets/synthetic_utils.py +++ b/lib/datasets/synthetic_utils.py @@ -60,6 +60,10 @@ class TimeStamp(UnifiedSplit, data.Dataset): @property def max_timestamp(self): return self._max_timestamp + + @property + def interval(self): + return self._interval def __iter__(self): self._iter_num = 0 diff --git a/lib/xlayers/super_module.py b/lib/xlayers/super_module.py index 091be02..5a85c51 100644 --- a/lib/xlayers/super_module.py +++ b/lib/xlayers/super_module.py @@ -46,6 +46,13 @@ class TensorContainer: result.append(name, new_tensor, self._param_or_buffers[index]) return result + def create_container(self, tensors): + result = TensorContainer() + for index, name in enumerate(self._names): + new_tensor = tensors[index] + result.append(name, new_tensor, self._param_or_buffers[index]) + return result + def no_grad_clone(self): result = TensorContainer() with torch.no_grad():