Prototype MAML
exps/LFNA/basic-maml.py
@@ -1,8 +1,8 @@
 #####################################################
 # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 #
 #####################################################
-# python exps/LFNA/basic-maml.py --env_version v1   #
-# python exps/LFNA/basic-maml.py --env_version v2   #
+# python exps/LFNA/basic-maml.py --env_version v1 --hidden_dim 16 --inner_step 5
+# python exps/LFNA/basic-maml.py --env_version v2
 #####################################################
 import sys, time, copy, torch, random, argparse
 from tqdm import tqdm
@@ -30,73 +30,71 @@ from lfna_utils import lfna_setup, TimeData
 class MAML:
     """A LFNA meta-model that uses the MLP as delta-net."""
 
-    def __init__(self, container, criterion, meta_lr, inner_lr=0.01, inner_step=1):
+    def __init__(
+        self, network, criterion, epochs, meta_lr, inner_lr=0.01, inner_step=1
+    ):
         self.criterion = criterion
-        self.container = container
+        # self.container = container
+        self.network = network
         self.meta_optimizer = torch.optim.Adam(
-            self.container.parameters(), lr=meta_lr, amsgrad=True
+            self.network.parameters(), lr=meta_lr, amsgrad=True
         )
+        self.meta_lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
+            self.meta_optimizer,
+            milestones=[
+                int(epochs * 0.25),
+                int(epochs * 0.5),
+                int(epochs * 0.75),
+            ],
+            gamma=0.3,
+        )
         self.inner_lr = inner_lr
         self.inner_step = inner_step
-        print("There are {:} weights.".format(w_container.numel()))
+        self._best_info = dict(state_dict=None, score=None)
 
-    def adapt(self, model, dataset):
-        y_hat = model.forward_with_container(dataset.x, self.container)
-        loss = self.criterion(y_hat, dataset.y)
-        grads = torch.autograd.grad(loss, self.container.parameters())
+    def adapt(self, dataset):
+        # create a container for the future timestamp
+        container = self.network.get_w_container()
 
-        fast_container = self.container.additive(
-            [-self.inner_lr * grad for grad in grads]
-        )
-        import pdb
+        for k in range(0, self.inner_step):
+            y_hat = self.network.forward_with_container(dataset.x, container)
+            loss = self.criterion(y_hat, dataset.y)
+            grads = torch.autograd.grad(loss, container.parameters())
 
-        pdb.set_trace()
-        w_container.requires_grad_(True)
-        containers = [w_container]
-        for idx, dataset in enumerate(seq_datasets):
-            x, y = dataset.x, dataset.y
-            y_hat = model.forward_with_container(x, containers[-1])
-            loss = criterion(y_hat, y)
-            gradients = torch.autograd.grad(loss, containers[-1].tensors)
-            with torch.no_grad():
-                flatten_w = containers[-1].flatten().view(-1, 1)
-                flatten_g = containers[-1].flatten(gradients).view(-1, 1)
-                input_statistics = torch.tensor([x.mean(), x.std()]).view(1, 2)
-                input_statistics = input_statistics.expand(flatten_w.numel(), -1)
-            delta_inputs = torch.cat((flatten_w, flatten_g, input_statistics), dim=-1)
-            delta = self.delta_net(delta_inputs).view(-1)
-            delta = torch.clamp(delta, -0.5, 0.5)
-            unflatten_delta = containers[-1].unflatten(delta)
-            future_container = containers[-1].no_grad_clone().additive(unflatten_delta)
-            # future_container = containers[-1].additive(unflatten_delta)
-            containers.append(future_container)
-        # containers = containers[1:]
-        meta_loss = []
-        temp_containers = []
-        for idx, dataset in enumerate(seq_datasets):
-            if idx == 0:
-                continue
-            current_container = containers[idx]
-            y_hat = model.forward_with_container(dataset.x, current_container)
-            loss = criterion(y_hat, dataset.y)
-            meta_loss.append(loss)
-            temp_containers.append((dataset.timestamp, current_container, -loss.item()))
-        meta_loss = sum(meta_loss)
-        w_container.requires_grad_(False)
-        # meta_loss.backward()
-        # self.meta_optimizer.step()
-        return meta_loss, temp_containers
+            container = container.additive([-self.inner_lr * grad for grad in grads])
+        return container
 
+    def predict(self, x, container=None):
+        if container is not None:
+            y_hat = self.network.forward_with_container(x, container)
+        else:
+            y_hat = self.network(x)
+        return y_hat
 
     def step(self):
-        torch.nn.utils.clip_grad_norm_(self.delta_net.parameters(), 1.0)
+        torch.nn.utils.clip_grad_norm_(self.network.parameters(), 1.0)
         self.meta_optimizer.step()
+        self.meta_lr_scheduler.step()
 
     def zero_grad(self):
         self.meta_optimizer.zero_grad()
 
+    def save_best(self, network, score):
+        if self._best_info["score"] is None or self._best_info["score"] < score:
+            state_dict = dict(
+                criterion=self.criterion,
+                network=network.state_dict(),
+                meta_optimizer=self.meta_optimizer.state_dict(),
+                meta_lr_scheduler=self.meta_lr_scheduler.state_dict(),
+            )
+            self._best_info["state_dict"] = state_dict
+            self._best_info["score"] = score
 
 
 def main(args):
-    logger, env_info = lfna_setup(args)
+    logger, env_info, model_kwargs = lfna_setup(args)
+    model = get_model(dict(model_type="simple_mlp"), **model_kwargs)
 
     total_time = env_info["total"]
     for i in range(total_time):
@@ -104,19 +102,12 @@ def main(args):
             nkey = "{:}-{:}".format(i, xkey)
             assert nkey in env_info, "{:} no in {:}".format(nkey, list(env_info.keys()))
     train_time_bar = total_time // 2
-    base_model = get_model(
-        dict(model_type="simple_mlp"),
-        act_cls="leaky_relu",
-        norm_cls="identity",
-        input_dim=1,
-        output_dim=1,
-    )
 
-    w_container = base_model.get_w_container()
     criterion = torch.nn.MSELoss()
-    print("There are {:} weights.".format(w_container.numel()))
 
-    maml = MAML(w_container, criterion, args.meta_lr, args.inner_lr, args.inner_step)
+    maml = MAML(
+        model, criterion, args.epochs, args.meta_lr, args.inner_lr, args.inner_step
+    )
 
     # meta-training
     per_epoch_time, start_time = AverageMeter(), time.time()
@@ -131,8 +122,7 @@ def main(args):
         )
 
         maml.zero_grad()
-
-        all_meta_losses = []
+        meta_losses = []
         for ibatch in range(args.meta_batch):
             sampled_timestamp = random.randint(0, train_time_bar)
             past_dataset = TimeData(
@@ -145,21 +135,23 @@ def main(args):
                 env_info["{:}-x".format(sampled_timestamp + 1)],
                 env_info["{:}-y".format(sampled_timestamp + 1)],
             )
-            maml.adapt(base_model, past_dataset)
-            import pdb
-
-            pdb.set_trace()
-
-        meta_loss = torch.stack(all_meta_losses).mean()
+            future_container = maml.adapt(past_dataset)
+            future_y_hat = maml.predict(future_dataset.x, future_container)
+            future_loss = maml.criterion(future_y_hat, future_dataset.y)
+            meta_losses.append(future_loss)
+        meta_loss = torch.stack(meta_losses).mean()
         meta_loss.backward()
-        adaptor.step()
+        maml.step()
 
-        debug_str = pool.debug_info(debug_timestamp)
+        logger.log("meta-loss: {:.4f}".format(meta_loss.item()))
 
         per_epoch_time.update(time.time() - start_time)
         start_time = time.time()
 
+    import pdb
+
+    pdb.set_trace()
+
     logger.log("-" * 200 + "\n")
     logger.close()
 
@@ -187,7 +179,7 @@ if __name__ == "__main__":
     parser.add_argument(
         "--meta_lr",
         type=float,
-        default=0.01,
+        default=0.1,
         help="The learning rate for the MAML optimizer (default is Adam)",
     )
     parser.add_argument(
@@ -202,7 +194,7 @@ if __name__ == "__main__":
     parser.add_argument(
         "--meta_batch",
         type=int,
-        default=5,
+        default=10,
         help="The batch size for the meta-model",
    )
     parser.add_argument(
@@ -223,7 +215,7 @@ if __name__ == "__main__":
     if args.rand_seed is None or args.rand_seed < 0:
         args.rand_seed = random.randint(1, 100000)
     assert args.save_dir is not None, "The save dir argument can not be None"
-    args.save_dir = "{:}-{:}-d{:}".format(
-        args.save_dir, args.env_version, args.hidden_dim
+    args.save_dir = "{:}-s{:}-{:}-d{:}".format(
+        args.save_dir, args.inner_step, args.env_version, args.hidden_dim
     )
     main(args)
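For orientation, below is a minimal sketch of the inner-loop pattern the new adapt() implements: start from the current meta-weights, take a few SGD steps on the past timestamp's data, and return the adapted weights with the graph intact so the meta-update can differentiate through them. This is a sketch under assumptions, not the repository's code: SimpleMLP and functional_forward are hypothetical stand-ins for the repo's get_model / forward_with_container / w_container machinery, emulated here with torch.func.functional_call (PyTorch 2.x). One caveat: the commit's adapt() calls torch.autograd.grad without create_graph=True, which makes it effectively first-order MAML; the sketch passes create_graph=True for the full second-order variant.

import torch


class SimpleMLP(torch.nn.Module):
    # Hypothetical stand-in for the repo's "simple_mlp" model.
    def __init__(self, hidden_dim=16):
        super().__init__()
        self.net = torch.nn.Sequential(
            torch.nn.Linear(1, hidden_dim),
            torch.nn.LeakyReLU(),
            torch.nn.Linear(hidden_dim, 1),
        )

    def forward(self, x):
        return self.net(x)


def functional_forward(model, params, x):
    # Run the model with an explicit parameter dict; this plays the role of
    # network.forward_with_container(x, container) in the repository.
    return torch.func.functional_call(model, params, (x,))


def adapt(model, criterion, x, y, inner_lr=0.01, inner_step=5):
    # Inner loop: a few SGD steps starting from the current meta-parameters.
    params = dict(model.named_parameters())
    for _ in range(inner_step):
        loss = criterion(functional_forward(model, params, x), y)
        # create_graph=True keeps the graph so the outer meta-loss can
        # backpropagate through these inner updates (second-order MAML).
        grads = torch.autograd.grad(loss, list(params.values()), create_graph=True)
        params = {
            name: p - inner_lr * g
            for (name, p), g in zip(params.items(), grads)
        }
    return params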
|   | ||||
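And a sketch of one outer meta-step, mirroring the zero_grad / per-task adapt / stacked-loss / clip-and-step sequence in main() above. It reuses SimpleMLP, functional_forward, and adapt from the previous sketch; the synthetic (past, future) pairs are a toy stand-in for the env_info timestamps, not the repo's data pipeline.

model = SimpleMLP(hidden_dim=16)
criterion = torch.nn.MSELoss()
meta_optimizer = torch.optim.Adam(model.parameters(), lr=0.1, amsgrad=True)

meta_optimizer.zero_grad()
meta_losses = []
for _ in range(10):  # meta_batch tasks per meta-step
    shift = torch.rand(1).item()
    past_x = torch.rand(32, 1)
    past_y = torch.sin(past_x + shift)            # "timestamp t"
    future_x = torch.rand(32, 1)
    future_y = torch.sin(future_x + shift + 0.1)  # "timestamp t + 1"
    fast_params = adapt(model, criterion, past_x, past_y)
    future_y_hat = functional_forward(model, fast_params, future_x)
    meta_losses.append(criterion(future_y_hat, future_y))
meta_loss = torch.stack(meta_losses).mean()
meta_loss.backward()  # gradients flow through the inner steps into model.parameters()
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
meta_optimizer.step()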