Update LFNA version 1.0
@@ -86,9 +86,10 @@ def main(args):
             input_dim=1,
             output_dim=1,
             act_cls="leaky_relu",
-            norm_cls="simple_norm",
-            mean=mean,
-            std=std,
+            norm_cls="identity",
+            # norm_cls="simple_norm",
+            # mean=mean,
+            # std=std,
         )
         model = get_model(dict(model_type="simple_mlp"), **model_kwargs)
         # build optimizer
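The hunk above swaps the dataset-statistics normalization ("simple_norm") for an identity normalization, which is why the mean/std arguments are commented out rather than passed. A minimal sketch of the difference, with hypothetical stand-in modules (the real norm layers are chosen inside this repo's get_model):

    import torch
    import torch.nn as nn

    class SimpleNorm(nn.Module):
        """Hypothetical stand-in: normalize with fixed dataset statistics."""
        def __init__(self, mean, std):
            super().__init__()
            self.mean, self.std = mean, std

        def forward(self, x):
            return (x - self.mean) / self.std

    class IdentityNorm(nn.Module):
        """Hypothetical stand-in: pass inputs through; no statistics needed."""
        def forward(self, x):
            return x

    x = torch.rand(4, 1)
    print(SimpleNorm(0.5, 2.0)(x))  # needs mean/std at construction
    print(IdentityNorm()(x))        # needs nothing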
===== next file (path not shown) =====
@@ -58,6 +58,8 @@ def main(args):
         )
     )
 
+    w_container_per_epoch = dict()
+
     per_timestamp_time, start_time = AverageMeter(), time.time()
     for i, idx in enumerate(to_evaluate_indexes):
 
@@ -73,7 +75,6 @@ def main(args):
             + need_time
         )
         # train the same data
-        assert idx != 0
         historical_x = env_info["{:}-x".format(idx)]
         historical_y = env_info["{:}-y".format(idx)]
         # build model
@@ -82,9 +83,10 @@ def main(args):
             input_dim=1,
             output_dim=1,
             act_cls="leaky_relu",
-            norm_cls="simple_norm",
-            mean=mean,
-            std=std,
+            norm_cls="identity",
+            # norm_cls="simple_norm",
+            # mean=mean,
+            # std=std,
         )
         model = get_model(dict(model_type="simple_mlp"), **model_kwargs)
         # build optimizer
@@ -137,6 +139,7 @@ def main(args):
         save_path = logger.path(None) / "{:04d}-{:04d}.pth".format(
             idx, env_info["total"]
         )
+        w_container_per_epoch[idx] = model.get_w_container().no_grad_clone()
         save_checkpoint(
             {
                 "model_state_dict": model.state_dict(),
@@ -151,6 +154,11 @@ def main(args):
 
         per_timestamp_time.update(time.time() - start_time)
         start_time = time.time()
+    save_checkpoint(
+        {"w_container_per_epoch": w_container_per_epoch},
+        logger.path(None) / "final-ckp.pth",
+        logger,
+    )
 
     logger.log("-" * 200 + "\n")
     logger.close()
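Together, the additions above snapshot the trained weights at every timestamp into w_container_per_epoch and write them all once at the end. The same pattern in a runnable, self-contained form; nn.Linear and torch.save stand in for the repo's model, get_w_container().no_grad_clone(), and save_checkpoint helpers:

    import torch
    import torch.nn as nn

    model = nn.Linear(1, 1)
    w_container_per_epoch = dict()
    for idx in range(3):
        # ... per-timestamp training would happen here ...
        # snapshot with detached clones so later training cannot mutate it,
        # mirroring get_w_container().no_grad_clone()
        w_container_per_epoch[idx] = {
            name: p.detach().clone() for name, p in model.named_parameters()
        }
    torch.save({"w_container_per_epoch": w_container_per_epoch}, "final-ckp.pth")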
===== next file (path not shown) =====
@@ -39,9 +39,11 @@ class LFNAmlp:
             self.delta_net.parameters(), lr=0.01, amsgrad=True
         )
 
-    def adapt(self, model, criterion, w_container, xs, ys):
+    def adapt(self, model, criterion, w_container, seq_datasets):
+        w_container.requires_grad_(True)
         containers = [w_container]
-        for idx, (x, y) in enumerate(zip(xs, ys)):
+        for idx, dataset in enumerate(seq_datasets):
+            x, y = dataset.x, dataset.y
             y_hat = model.forward_with_container(x, containers[-1])
             loss = criterion(y_hat, y)
             gradients = torch.autograd.grad(loss, containers[-1].tensors)
@@ -52,21 +54,30 @@ class LFNAmlp:
                 input_statistics = input_statistics.expand(flatten_w.numel(), -1)
             delta_inputs = torch.cat((flatten_w, flatten_g, input_statistics), dim=-1)
             delta = self.delta_net(delta_inputs).view(-1)
-            # delta = torch.clamp(delta, -0.5, 0.5)
+            delta = torch.clamp(delta, -0.5, 0.5)
             unflatten_delta = containers[-1].unflatten(delta)
-            future_container = containers[-1].additive(unflatten_delta)
+            future_container = containers[-1].no_grad_clone().additive(unflatten_delta)
+            # future_container = containers[-1].additive(unflatten_delta)
             containers.append(future_container)
         # containers = containers[1:]
         meta_loss = []
-        for idx, (x, y) in enumerate(zip(xs, ys)):
+        temp_containers = []
+        for idx, dataset in enumerate(seq_datasets):
             if idx == 0:
                 continue
             current_container = containers[idx]
-            y_hat = model.forward_with_container(x, current_container)
-            loss = criterion(y_hat, y)
+            y_hat = model.forward_with_container(dataset.x, current_container)
+            loss = criterion(y_hat, dataset.y)
             meta_loss.append(loss)
+            temp_containers.append((dataset.timestamp, current_container, -loss.item()))
         meta_loss = sum(meta_loss)
-        meta_loss.backward()
+        w_container.requires_grad_(False)
+        # meta_loss.backward()
+        # self.meta_optimizer.step()
+        return meta_loss, temp_containers
+
+    def step(self):
+        torch.nn.utils.clip_grad_norm_(self.delta_net.parameters(), 1.0)
         self.meta_optimizer.step()
 
     def zero_grad(self):
@@ -74,6 +85,25 @@ class LFNAmlp:
         self.delta_net.zero_grad()
 
 
+class TimeData:
+    def __init__(self, timestamp, xs, ys):
+        self._timestamp = timestamp
+        self._xs = xs
+        self._ys = ys
+
+    @property
+    def x(self):
+        return self._xs
+
+    @property
+    def y(self):
+        return self._ys
+
+    @property
+    def timestamp(self):
+        return self._timestamp
+
+
 class Population:
     """A population used to maintain models at different timestamps."""
 
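TimeData is a small read-only record pairing a timestamp with its batch, so adapt() can walk one list of dated batches instead of parallel xs/ys lists. A usage sketch, assuming the class as defined in the hunk above is in scope:

    import torch

    batch = TimeData(timestamp=3, xs=torch.rand(32, 1), ys=torch.rand(32, 1))
    print(batch.timestamp, batch.x.shape, batch.y.shape)  # 3, (32, 1), (32, 1)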
@@ -83,20 +113,29 @@ class Population:
 
     def append(self, timestamp, model, score):
         if timestamp in self._time2model:
-            raise ValueError("This timestamp has been added.")
-        self._time2model[timestamp] = model
+            if self._time2score[timestamp] > score:
+                return
+        self._time2model[timestamp] = model.no_grad_clone()
         self._time2score[timestamp] = score
 
     def query(self, timestamp):
         closet_timestamp = None
         for xtime, model in self._time2model.items():
-            if (
-                closet_timestamp is None
-                or timestamp - closet_timestamp >= timestamp - xtime
+            if closet_timestamp is None or (
+                xtime < timestamp and timestamp - closet_timestamp >= timestamp - xtime
             ):
                 closet_timestamp = xtime
         return self._time2model[closet_timestamp], closet_timestamp
 
+    def debug_info(self, timestamps):
+        xstrs = []
+        for timestamp in timestamps:
+            if timestamp in self._time2score:
+                xstrs.append(
+                    "{:04d}: {:.4f}".format(timestamp, self._time2score[timestamp])
+                )
+        return ", ".join(xstrs)
+
 
 def main(args):
     prepare_seed(args.rand_seed)
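Population.append now keeps only the best-scoring container per timestamp (scores are negated losses), and query prefers the closest earlier timestamp. A pure-Python sketch of these semantics; MiniPopulation, its simplified query rule, and the string "models" are illustrative only:

    class MiniPopulation:
        def __init__(self):
            self._time2model, self._time2score = {}, {}

        def append(self, timestamp, model, score):
            if timestamp in self._time2model and self._time2score[timestamp] > score:
                return  # keep the existing, better-scoring snapshot
            self._time2model[timestamp] = model
            self._time2score[timestamp] = score

        def query(self, timestamp):
            best = None
            for xtime in self._time2model:
                if best is None or (xtime <= timestamp and xtime > best):
                    best = xtime
            return self._time2model[best], best

    pool = MiniPopulation()
    pool.append(0, "w0", -100)
    pool.append(2, "w2", -0.5)
    pool.append(2, "w2-worse", -0.9)  # rejected: lower score than -0.5
    print(pool.query(3))              # ('w2', 2) -- closest timestamp <= 3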
@@ -125,21 +164,19 @@ def main(args):
     base_model = get_model(
         dict(model_type="simple_mlp"),
         act_cls="leaky_relu",
-        norm_cls="simple_learn_norm",
-        mean=0,
-        std=1,
+        norm_cls="identity",
         input_dim=1,
         output_dim=1,
     )
 
-    w_container = base_model.named_parameters_buffers()
+    w_container = base_model.get_w_container()
     criterion = torch.nn.MSELoss()
     print("There are {:} weights.".format(w_container.numel()))
 
     adaptor = LFNAmlp(4, (50, 20), "leaky_relu")
 
     pool = Population()
-    pool.append(0, w_container)
+    pool.append(0, w_container, -100)
 
     # LFNA meta-training
     per_epoch_time, start_time = AverageMeter(), time.time()
@@ -153,22 +190,35 @@ def main(args):
             + need_time
         )
 
+        adaptor.zero_grad()
+
+        debug_timestamp = set()
+        all_meta_losses = []
         for ibatch in range(args.meta_batch):
             sampled_timestamp = random.randint(0, train_time_bar)
             query_w_container, query_timestamp = pool.query(sampled_timestamp)
             # def adapt(self, model, w_container, xs, ys):
-            xs, ys = [], []
+            seq_datasets = []
+            # xs, ys = [], []
             for it in range(sampled_timestamp, sampled_timestamp + args.max_seq):
-                xs.append(env_info["{:}-x".format(it)])
-                ys.append(env_info["{:}-y".format(it)])
-            adaptor.adapt(base_model, criterion, query_w_container, xs, ys)
-            import pdb
+                xs = env_info["{:}-x".format(it)]
+                ys = env_info["{:}-y".format(it)]
+                seq_datasets.append(TimeData(it, xs, ys))
+            temp_meta_loss, temp_containers = adaptor.adapt(
+                base_model, criterion, query_w_container, seq_datasets
+            )
+            all_meta_losses.append(temp_meta_loss)
+            for temp_time, temp_container, temp_score in temp_containers:
+                pool.append(temp_time, temp_container, temp_score)
+                debug_timestamp.add(temp_time)
+        meta_loss = torch.stack(all_meta_losses).mean()
+        meta_loss.backward()
+        adaptor.step()
 
-            pdb.set_trace()
-        print("-")
-        logger.log("")
+        debug_str = pool.debug_info(debug_timestamp)
+        logger.log("meta-loss: {:.4f}".format(meta_loss.item()))
 
-        per_timestamp_time.update(time.time() - start_time)
+        per_epoch_time.update(time.time() - start_time)
         start_time = time.time()
 
     logger.log("-" * 200 + "\n")
@@ -192,7 +242,7 @@ if __name__ == "__main__":
     parser.add_argument(
         "--meta_batch",
         type=int,
-        default=2,
+        default=5,
         help="The batch size for the meta-model",
     )
     parser.add_argument(
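The training loop now gathers one meta-loss per sampled sequence, averages them, and applies a single clipped update through the new adaptor.step(). A self-contained sketch of that update pattern; the Linear layer and random batches stand in for the delta network and the per-sequence losses:

    import torch

    delta_net = torch.nn.Linear(4, 1)
    meta_optimizer = torch.optim.Adam(delta_net.parameters(), lr=0.01, amsgrad=True)

    meta_optimizer.zero_grad()  # adaptor.zero_grad()
    all_meta_losses = []
    for _ in range(5):  # one loss per sampled sequence; meta_batch is now 5
        fake_batch = torch.rand(8, 4)
        all_meta_losses.append(delta_net(fake_batch).pow(2).mean())
    meta_loss = torch.stack(all_meta_losses).mean()
    meta_loss.backward()
    # adaptor.step(): clip the meta-gradients, then apply them
    torch.nn.utils.clip_grad_norm_(delta_net.parameters(), 1.0)
    meta_optimizer.step()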
===== next file (path not shown) =====
@@ -23,6 +23,7 @@ lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve()
 if str(lib_dir) not in sys.path:
     sys.path.insert(0, str(lib_dir))
 
+from models.xcore import get_model
 from datasets.synthetic_core import get_synthetic_env
 from datasets.synthetic_example import create_example_v1
 from utils.temp_sync import optimize_fn, evaluate_fn
@@ -300,8 +300,20 @@ def compare_algs_v2(save_dir, alg_dir="./outputs/lfna-synthetic"):
 
     alg_name2dir = OrderedDict()
     alg_name2dir["Optimal"] = "use-same-timestamp"
-    alg_name2dir["Supervised Learning (History Data)"] = "use-all-past-data"
-    colors = ["r", "g"]
+    # alg_name2dir["Supervised Learning (History Data)"] = "use-all-past-data"
+    alg_name2all_containers = OrderedDict()
+    for idx_alg, (alg, xdir) in enumerate(alg_name2dir.items()):
+        ckp_path = Path(alg_dir) / xdir / "final-ckp.pth"
+        xdata = torch.load(ckp_path)
+        alg_name2all_containers[alg] = xdata["w_container_per_epoch"]
+    # load the basic model
+    model = get_model(
+        dict(model_type="simple_mlp"),
+        act_cls="leaky_relu",
+        norm_cls="identity",
+        input_dim=1,
+        output_dim=1,
+    )
 
     alg2xs, alg2ys = defaultdict(list), defaultdict(list)
     colors = ["r", "g"]
@@ -323,6 +335,7 @@ def compare_algs_v2(save_dir, alg_dir="./outputs/lfna-synthetic"):
         plot_scatter(cur_ax, allx, ally, "k", 0.99, linewidths, "Raw Data")
 
         for idx_alg, (alg, xdir) in enumerate(alg_name2dir.items()):
+            """
             ckp_path = (
                 Path(alg_dir)
                 / xdir
@@ -330,8 +343,12 @@ def compare_algs_v2(save_dir, alg_dir="./outputs/lfna-synthetic"):
             )
             assert ckp_path.exists()
             ckp_data = torch.load(ckp_path)
+            """
             with torch.no_grad():
-                predicts = ckp_data["model"](ori_allx)
+                # predicts = ckp_data["model"](ori_allx)
+                predicts = model.forward_with_container(
+                    ori_allx, alg_name2all_containers[alg][idx]
+                )
                 predicts = predicts.cpu()
                 # keep data
                 metric = MSEMetric()
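The visualization path above rebuilds a single model and replays each saved weight container through forward_with_container instead of unpickling a model per checkpoint. A self-contained sketch of the functional-forward idea; the dict-of-tensors container and F.linear stand in for the repo's TensorContainer and MLP:

    import torch
    import torch.nn.functional as F

    def forward_with_container(x, container):
        # apply externally stored weights functionally; the module itself
        # holds no state, so any saved container can be replayed through it
        return F.linear(x, container["weight"], container["bias"])

    container = {"weight": torch.rand(1, 1), "bias": torch.zeros(1)}
    with torch.no_grad():
        predicts = forward_with_container(torch.rand(5, 1), container)
    print(predicts.shape)  # torch.Size([5, 1])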
===== next file (path not shown) =====
@@ -55,6 +55,10 @@ class TensorContainer:
                 )
         return result
 
+    def requires_grad_(self, requires_grad=True):
+        for tensor in self._tensors:
+            tensor.requires_grad_(requires_grad)
+
     @property
     def tensors(self):
         return self._tensors
@@ -162,7 +166,7 @@ class SuperModule(abc.ABC, nn.Module):
             )
         self._abstract_child = abstract_child
 
-    def named_parameters_buffers(self):
+    def get_w_container(self):
         container = TensorContainer()
         for name, param in self.named_parameters():
             container.append(name, param, True)
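TensorContainer.requires_grad_ is what lets adapt() switch gradients on for the queried weights, differentiate through them with torch.autograd.grad, and switch them off again; get_w_container (renamed from named_parameters_buffers) builds that container. The core toggle, sketched over a plain list of tensors:

    import torch

    tensors = [torch.rand(3, 3), torch.rand(3)]
    for t in tensors:            # requires_grad_(True) over the container
        t.requires_grad_(True)
    loss = sum(t.sum() for t in tensors)
    grads = torch.autograd.grad(loss, tensors)
    for t in tensors:            # requires_grad_(False) once gradients are taken
        t.requires_grad_(False)
    print([g.shape for g in grads])  # [torch.Size([3, 3]), torch.Size([3])]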