Update LFNA version 1.0
This commit is contained in:
		| @@ -86,9 +86,10 @@ def main(args): | ||||
|             input_dim=1, | ||||
|             output_dim=1, | ||||
|             act_cls="leaky_relu", | ||||
|             norm_cls="simple_norm", | ||||
|             mean=mean, | ||||
|             std=std, | ||||
|             norm_cls="identity", | ||||
|             # norm_cls="simple_norm", | ||||
|             # mean=mean, | ||||
|             # std=std, | ||||
|         ) | ||||
|         model = get_model(dict(model_type="simple_mlp"), **model_kwargs) | ||||
|         # build optimizer | ||||
|   | ||||
| @@ -58,6 +58,8 @@ def main(args): | ||||
|         ) | ||||
|     ) | ||||
|  | ||||
|     w_container_per_epoch = dict() | ||||
|  | ||||
|     per_timestamp_time, start_time = AverageMeter(), time.time() | ||||
|     for i, idx in enumerate(to_evaluate_indexes): | ||||
|  | ||||
| @@ -73,7 +75,6 @@ def main(args): | ||||
|             + need_time | ||||
|         ) | ||||
|         # train the same data | ||||
|         assert idx != 0 | ||||
|         historical_x = env_info["{:}-x".format(idx)] | ||||
|         historical_y = env_info["{:}-y".format(idx)] | ||||
|         # build model | ||||
| @@ -82,9 +83,10 @@ def main(args): | ||||
|             input_dim=1, | ||||
|             output_dim=1, | ||||
|             act_cls="leaky_relu", | ||||
|             norm_cls="simple_norm", | ||||
|             mean=mean, | ||||
|             std=std, | ||||
|             norm_cls="identity", | ||||
|             # norm_cls="simple_norm", | ||||
|             # mean=mean, | ||||
|             # std=std, | ||||
|         ) | ||||
|         model = get_model(dict(model_type="simple_mlp"), **model_kwargs) | ||||
|         # build optimizer | ||||
| @@ -137,6 +139,7 @@ def main(args): | ||||
|         save_path = logger.path(None) / "{:04d}-{:04d}.pth".format( | ||||
|             idx, env_info["total"] | ||||
|         ) | ||||
|         w_container_per_epoch[idx] = model.get_w_container().no_grad_clone() | ||||
|         save_checkpoint( | ||||
|             { | ||||
|                 "model_state_dict": model.state_dict(), | ||||
| @@ -151,6 +154,11 @@ def main(args): | ||||
|  | ||||
|         per_timestamp_time.update(time.time() - start_time) | ||||
|         start_time = time.time() | ||||
|     save_checkpoint( | ||||
|         {"w_container_per_epoch": w_container_per_epoch}, | ||||
|         logger.path(None) / "final-ckp.pth", | ||||
|         logger, | ||||
|     ) | ||||
|  | ||||
|     logger.log("-" * 200 + "\n") | ||||
|     logger.close() | ||||
|   | ||||
| @@ -39,9 +39,11 @@ class LFNAmlp: | ||||
|             self.delta_net.parameters(), lr=0.01, amsgrad=True | ||||
|         ) | ||||
|  | ||||
|     def adapt(self, model, criterion, w_container, xs, ys): | ||||
|     def adapt(self, model, criterion, w_container, seq_datasets): | ||||
|         w_container.requires_grad_(True) | ||||
|         containers = [w_container] | ||||
|         for idx, (x, y) in enumerate(zip(xs, ys)): | ||||
|         for idx, dataset in enumerate(seq_datasets): | ||||
|             x, y = dataset.x, dataset.y | ||||
|             y_hat = model.forward_with_container(x, containers[-1]) | ||||
|             loss = criterion(y_hat, y) | ||||
|             gradients = torch.autograd.grad(loss, containers[-1].tensors) | ||||
| @@ -52,21 +54,30 @@ class LFNAmlp: | ||||
|                 input_statistics = input_statistics.expand(flatten_w.numel(), -1) | ||||
|             delta_inputs = torch.cat((flatten_w, flatten_g, input_statistics), dim=-1) | ||||
|             delta = self.delta_net(delta_inputs).view(-1) | ||||
|             # delta = torch.clamp(delta, -0.5, 0.5) | ||||
|             delta = torch.clamp(delta, -0.5, 0.5) | ||||
|             unflatten_delta = containers[-1].unflatten(delta) | ||||
|             future_container = containers[-1].additive(unflatten_delta) | ||||
|             future_container = containers[-1].no_grad_clone().additive(unflatten_delta) | ||||
|             # future_container = containers[-1].additive(unflatten_delta) | ||||
|             containers.append(future_container) | ||||
|         # containers = containers[1:] | ||||
|         meta_loss = [] | ||||
|         for idx, (x, y) in enumerate(zip(xs, ys)): | ||||
|         temp_containers = [] | ||||
|         for idx, dataset in enumerate(seq_datasets): | ||||
|             if idx == 0: | ||||
|                 continue | ||||
|             current_container = containers[idx] | ||||
|             y_hat = model.forward_with_container(x, current_container) | ||||
|             loss = criterion(y_hat, y) | ||||
|             y_hat = model.forward_with_container(dataset.x, current_container) | ||||
|             loss = criterion(y_hat, dataset.y) | ||||
|             meta_loss.append(loss) | ||||
|             temp_containers.append((dataset.timestamp, current_container, -loss.item())) | ||||
|         meta_loss = sum(meta_loss) | ||||
|         meta_loss.backward() | ||||
|         w_container.requires_grad_(False) | ||||
|         # meta_loss.backward() | ||||
|         # self.meta_optimizer.step() | ||||
|         return meta_loss, temp_containers | ||||
|  | ||||
|     def step(self): | ||||
|         torch.nn.utils.clip_grad_norm_(self.delta_net.parameters(), 1.0) | ||||
|         self.meta_optimizer.step() | ||||
|  | ||||
|     def zero_grad(self): | ||||
| @@ -74,6 +85,25 @@ class LFNAmlp: | ||||
|         self.delta_net.zero_grad() | ||||
|  | ||||
|  | ||||
| class TimeData: | ||||
|     def __init__(self, timestamp, xs, ys): | ||||
|         self._timestamp = timestamp | ||||
|         self._xs = xs | ||||
|         self._ys = ys | ||||
|  | ||||
|     @property | ||||
|     def x(self): | ||||
|         return self._xs | ||||
|  | ||||
|     @property | ||||
|     def y(self): | ||||
|         return self._ys | ||||
|  | ||||
|     @property | ||||
|     def timestamp(self): | ||||
|         return self._timestamp | ||||
|  | ||||
|  | ||||
| class Population: | ||||
|     """A population used to maintain models at different timestamps.""" | ||||
|  | ||||
| @@ -83,20 +113,29 @@ class Population: | ||||
|  | ||||
|     def append(self, timestamp, model, score): | ||||
|         if timestamp in self._time2model: | ||||
|             raise ValueError("This timestamp has been added.") | ||||
|         self._time2model[timestamp] = model | ||||
|             if self._time2score[timestamp] > score: | ||||
|                 return | ||||
|         self._time2model[timestamp] = model.no_grad_clone() | ||||
|         self._time2score[timestamp] = score | ||||
|  | ||||
|     def query(self, timestamp): | ||||
|         closet_timestamp = None | ||||
|         for xtime, model in self._time2model.items(): | ||||
|             if ( | ||||
|                 closet_timestamp is None | ||||
|                 or timestamp - closet_timestamp >= timestamp - xtime | ||||
|             if closet_timestamp is None or ( | ||||
|                 xtime < timestamp and timestamp - closet_timestamp >= timestamp - xtime | ||||
|             ): | ||||
|                 closet_timestamp = xtime | ||||
|         return self._time2model[closet_timestamp], closet_timestamp | ||||
|  | ||||
|     def debug_info(self, timestamps): | ||||
|         xstrs = [] | ||||
|         for timestamp in timestamps: | ||||
|             if timestamp in self._time2score: | ||||
|                 xstrs.append( | ||||
|                     "{:04d}: {:.4f}".format(timestamp, self._time2score[timestamp]) | ||||
|                 ) | ||||
|         return ", ".join(xstrs) | ||||
|  | ||||
|  | ||||
| def main(args): | ||||
|     prepare_seed(args.rand_seed) | ||||
| @@ -125,21 +164,19 @@ def main(args): | ||||
|     base_model = get_model( | ||||
|         dict(model_type="simple_mlp"), | ||||
|         act_cls="leaky_relu", | ||||
|         norm_cls="simple_learn_norm", | ||||
|         mean=0, | ||||
|         std=1, | ||||
|         norm_cls="identity", | ||||
|         input_dim=1, | ||||
|         output_dim=1, | ||||
|     ) | ||||
|  | ||||
|     w_container = base_model.named_parameters_buffers() | ||||
|     w_container = base_model.get_w_container() | ||||
|     criterion = torch.nn.MSELoss() | ||||
|     print("There are {:} weights.".format(w_container.numel())) | ||||
|  | ||||
|     adaptor = LFNAmlp(4, (50, 20), "leaky_relu") | ||||
|  | ||||
|     pool = Population() | ||||
|     pool.append(0, w_container) | ||||
|     pool.append(0, w_container, -100) | ||||
|  | ||||
|     # LFNA meta-training | ||||
|     per_epoch_time, start_time = AverageMeter(), time.time() | ||||
| @@ -153,22 +190,35 @@ def main(args): | ||||
|             + need_time | ||||
|         ) | ||||
|  | ||||
|         adaptor.zero_grad() | ||||
|  | ||||
|         debug_timestamp = set() | ||||
|         all_meta_losses = [] | ||||
|         for ibatch in range(args.meta_batch): | ||||
|             sampled_timestamp = random.randint(0, train_time_bar) | ||||
|             query_w_container, query_timestamp = pool.query(sampled_timestamp) | ||||
|             # def adapt(self, model, w_container, xs, ys): | ||||
|             xs, ys = [], [] | ||||
|             seq_datasets = [] | ||||
|             # xs, ys = [], [] | ||||
|             for it in range(sampled_timestamp, sampled_timestamp + args.max_seq): | ||||
|                 xs.append(env_info["{:}-x".format(it)]) | ||||
|                 ys.append(env_info["{:}-y".format(it)]) | ||||
|             adaptor.adapt(base_model, criterion, query_w_container, xs, ys) | ||||
|             import pdb | ||||
|                 xs = env_info["{:}-x".format(it)] | ||||
|                 ys = env_info["{:}-y".format(it)] | ||||
|                 seq_datasets.append(TimeData(it, xs, ys)) | ||||
|             temp_meta_loss, temp_containers = adaptor.adapt( | ||||
|                 base_model, criterion, query_w_container, seq_datasets | ||||
|             ) | ||||
|             all_meta_losses.append(temp_meta_loss) | ||||
|             for temp_time, temp_container, temp_score in temp_containers: | ||||
|                 pool.append(temp_time, temp_container, temp_score) | ||||
|                 debug_timestamp.add(temp_time) | ||||
|         meta_loss = torch.stack(all_meta_losses).mean() | ||||
|         meta_loss.backward() | ||||
|         adaptor.step() | ||||
|  | ||||
|             pdb.set_trace() | ||||
|         print("-") | ||||
|         logger.log("") | ||||
|         debug_str = pool.debug_info(debug_timestamp) | ||||
|         logger.log("meta-loss: {:.4f}".format(meta_loss.item())) | ||||
|  | ||||
|         per_timestamp_time.update(time.time() - start_time) | ||||
|         per_epoch_time.update(time.time() - start_time) | ||||
|         start_time = time.time() | ||||
|  | ||||
|     logger.log("-" * 200 + "\n") | ||||
| @@ -192,7 +242,7 @@ if __name__ == "__main__": | ||||
|     parser.add_argument( | ||||
|         "--meta_batch", | ||||
|         type=int, | ||||
|         default=2, | ||||
|         default=5, | ||||
|         help="The batch size for the meta-model", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|   | ||||
| @@ -23,7 +23,7 @@ lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve() | ||||
| if str(lib_dir) not in sys.path: | ||||
|     sys.path.insert(0, str(lib_dir)) | ||||
|  | ||||
|  | ||||
| from models.xcore import get_model | ||||
| from datasets.synthetic_core import get_synthetic_env | ||||
| from datasets.synthetic_example import create_example_v1 | ||||
| from utils.temp_sync import optimize_fn, evaluate_fn | ||||
| @@ -300,8 +300,20 @@ def compare_algs_v2(save_dir, alg_dir="./outputs/lfna-synthetic"): | ||||
|  | ||||
|     alg_name2dir = OrderedDict() | ||||
|     alg_name2dir["Optimal"] = "use-same-timestamp" | ||||
|     alg_name2dir["Supervised Learning (History Data)"] = "use-all-past-data" | ||||
|     colors = ["r", "g"] | ||||
|     # alg_name2dir["Supervised Learning (History Data)"] = "use-all-past-data" | ||||
|     alg_name2all_containers = OrderedDict() | ||||
|     for idx_alg, (alg, xdir) in enumerate(alg_name2dir.items()): | ||||
|         ckp_path = Path(alg_dir) / xdir / "final-ckp.pth" | ||||
|         xdata = torch.load(ckp_path) | ||||
|         alg_name2all_containers[alg] = xdata["w_container_per_epoch"] | ||||
|     # load the basic model | ||||
|     model = get_model( | ||||
|         dict(model_type="simple_mlp"), | ||||
|         act_cls="leaky_relu", | ||||
|         norm_cls="identity", | ||||
|         input_dim=1, | ||||
|         output_dim=1, | ||||
|     ) | ||||
|  | ||||
|     alg2xs, alg2ys = defaultdict(list), defaultdict(list) | ||||
|     colors = ["r", "g"] | ||||
| @@ -323,6 +335,7 @@ def compare_algs_v2(save_dir, alg_dir="./outputs/lfna-synthetic"): | ||||
|         plot_scatter(cur_ax, allx, ally, "k", 0.99, linewidths, "Raw Data") | ||||
|  | ||||
|         for idx_alg, (alg, xdir) in enumerate(alg_name2dir.items()): | ||||
|             """ | ||||
|             ckp_path = ( | ||||
|                 Path(alg_dir) | ||||
|                 / xdir | ||||
| @@ -330,8 +343,12 @@ def compare_algs_v2(save_dir, alg_dir="./outputs/lfna-synthetic"): | ||||
|             ) | ||||
|             assert ckp_path.exists() | ||||
|             ckp_data = torch.load(ckp_path) | ||||
|             """ | ||||
|             with torch.no_grad(): | ||||
|                 predicts = ckp_data["model"](ori_allx) | ||||
|                 # predicts = ckp_data["model"](ori_allx) | ||||
|                 predicts = model.forward_with_container( | ||||
|                     ori_allx, alg_name2all_containers[alg][idx] | ||||
|                 ) | ||||
|                 predicts = predicts.cpu() | ||||
|                 # keep data | ||||
|                 metric = MSEMetric() | ||||
|   | ||||
| @@ -55,6 +55,10 @@ class TensorContainer: | ||||
|                 ) | ||||
|         return result | ||||
|  | ||||
|     def requires_grad_(self, requires_grad=True): | ||||
|         for tensor in self._tensors: | ||||
|           tensor.requires_grad_(requires_grad) | ||||
|  | ||||
|     @property | ||||
|     def tensors(self): | ||||
|         return self._tensors | ||||
| @@ -162,7 +166,7 @@ class SuperModule(abc.ABC, nn.Module): | ||||
|             ) | ||||
|         self._abstract_child = abstract_child | ||||
|  | ||||
|     def named_parameters_buffers(self): | ||||
|     def get_w_container(self): | ||||
|         container = TensorContainer() | ||||
|         for name, param in self.named_parameters(): | ||||
|             container.append(name, param, True) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user