Prototype MAML
This commit is contained in:
		| @@ -1,8 +1,8 @@ | |||||||
| ##################################################### | ##################################################### | ||||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 # | # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 # | ||||||
| ##################################################### | ##################################################### | ||||||
| # python exps/LFNA/basic-maml.py --env_version v1   # | # python exps/LFNA/basic-maml.py --env_version v1 --hidden_dim 16 --inner_step 5 | ||||||
| # python exps/LFNA/basic-maml.py --env_version v2   # | # python exps/LFNA/basic-maml.py --env_version v2 | ||||||
| ##################################################### | ##################################################### | ||||||
| import sys, time, copy, torch, random, argparse | import sys, time, copy, torch, random, argparse | ||||||
| from tqdm import tqdm | from tqdm import tqdm | ||||||
| @@ -30,73 +30,71 @@ from lfna_utils import lfna_setup, TimeData | |||||||
| class MAML: | class MAML: | ||||||
|     """A LFNA meta-model that uses the MLP as delta-net.""" |     """A LFNA meta-model that uses the MLP as delta-net.""" | ||||||
|  |  | ||||||
|     def __init__(self, container, criterion, meta_lr, inner_lr=0.01, inner_step=1): |     def __init__( | ||||||
|  |         self, network, criterion, epochs, meta_lr, inner_lr=0.01, inner_step=1 | ||||||
|  |     ): | ||||||
|         self.criterion = criterion |         self.criterion = criterion | ||||||
|         self.container = container |         # self.container = container | ||||||
|  |         self.network = network | ||||||
|         self.meta_optimizer = torch.optim.Adam( |         self.meta_optimizer = torch.optim.Adam( | ||||||
|             self.container.parameters(), lr=meta_lr, amsgrad=True |             self.network.parameters(), lr=meta_lr, amsgrad=True | ||||||
|  |         ) | ||||||
|  |         self.meta_lr_scheduler = torch.optim.lr_scheduler.MultiStepLR( | ||||||
|  |             optimizer, | ||||||
|  |             milestones=[ | ||||||
|  |                 int(epochs * 0.25), | ||||||
|  |                 int(epochs * 0.5), | ||||||
|  |                 int(epochs * 0.75), | ||||||
|  |             ], | ||||||
|  |             gamma=0.3, | ||||||
|         ) |         ) | ||||||
|         self.inner_lr = inner_lr |         self.inner_lr = inner_lr | ||||||
|         self.inner_step = inner_step |         self.inner_step = inner_step | ||||||
|  |         self._best_info = dict(state_dict=None, score=None) | ||||||
|  |         print("There are {:} weights.".format(w_container.numel())) | ||||||
|  |  | ||||||
|     def adapt(self, model, dataset): |     def adapt(self, dataset): | ||||||
|         # create a container for the future timestamp |         # create a container for the future timestamp | ||||||
|         y_hat = model.forward_with_container(dataset.x, self.container) |         container = self.network.get_w_container() | ||||||
|  |  | ||||||
|  |         for k in range(0, self.inner_step): | ||||||
|  |             y_hat = self.network.forward_with_container(dataset.x, container) | ||||||
|             loss = self.criterion(y_hat, dataset.y) |             loss = self.criterion(y_hat, dataset.y) | ||||||
|         grads = torch.autograd.grad(loss, self.container.parameters()) |             grads = torch.autograd.grad(loss, container.parameters()) | ||||||
|  |  | ||||||
|         fast_container = self.container.additive( |             container = container.additive([-self.inner_lr * grad for grad in grads]) | ||||||
|             [-self.inner_lr * grad for grad in grads] |         return container | ||||||
|         ) |  | ||||||
|         import pdb |  | ||||||
|  |  | ||||||
|         pdb.set_trace() |     def predict(self, x, container=None): | ||||||
|         w_container.requires_grad_(True) |         if container is not None: | ||||||
|         containers = [w_container] |             y_hat = self.network.forward_with_container(x, container) | ||||||
|         for idx, dataset in enumerate(seq_datasets): |         else: | ||||||
|             x, y = dataset.x, dataset.y |             y_hat = self.network(x) | ||||||
|             y_hat = model.forward_with_container(x, containers[-1]) |         return y_hat | ||||||
|             loss = criterion(y_hat, y) |  | ||||||
|             gradients = torch.autograd.grad(loss, containers[-1].tensors) |  | ||||||
|             with torch.no_grad(): |  | ||||||
|                 flatten_w = containers[-1].flatten().view(-1, 1) |  | ||||||
|                 flatten_g = containers[-1].flatten(gradients).view(-1, 1) |  | ||||||
|                 input_statistics = torch.tensor([x.mean(), x.std()]).view(1, 2) |  | ||||||
|                 input_statistics = input_statistics.expand(flatten_w.numel(), -1) |  | ||||||
|             delta_inputs = torch.cat((flatten_w, flatten_g, input_statistics), dim=-1) |  | ||||||
|             delta = self.delta_net(delta_inputs).view(-1) |  | ||||||
|             delta = torch.clamp(delta, -0.5, 0.5) |  | ||||||
|             unflatten_delta = containers[-1].unflatten(delta) |  | ||||||
|             future_container = containers[-1].no_grad_clone().additive(unflatten_delta) |  | ||||||
|             # future_container = containers[-1].additive(unflatten_delta) |  | ||||||
|             containers.append(future_container) |  | ||||||
|         # containers = containers[1:] |  | ||||||
|         meta_loss = [] |  | ||||||
|         temp_containers = [] |  | ||||||
|         for idx, dataset in enumerate(seq_datasets): |  | ||||||
|             if idx == 0: |  | ||||||
|                 continue |  | ||||||
|             current_container = containers[idx] |  | ||||||
|             y_hat = model.forward_with_container(dataset.x, current_container) |  | ||||||
|             loss = criterion(y_hat, dataset.y) |  | ||||||
|             meta_loss.append(loss) |  | ||||||
|             temp_containers.append((dataset.timestamp, current_container, -loss.item())) |  | ||||||
|         meta_loss = sum(meta_loss) |  | ||||||
|         w_container.requires_grad_(False) |  | ||||||
|         # meta_loss.backward() |  | ||||||
|         # self.meta_optimizer.step() |  | ||||||
|         return meta_loss, temp_containers |  | ||||||
|  |  | ||||||
|     def step(self): |     def step(self): | ||||||
|         torch.nn.utils.clip_grad_norm_(self.delta_net.parameters(), 1.0) |         torch.nn.utils.clip_grad_norm_(self.container.parameters(), 1.0) | ||||||
|         self.meta_optimizer.step() |         self.meta_optimizer.step() | ||||||
|  |         self.meta_lr_scheduler.step() | ||||||
|  |  | ||||||
|     def zero_grad(self): |     def zero_grad(self): | ||||||
|         self.meta_optimizer.zero_grad() |         self.meta_optimizer.zero_grad() | ||||||
|  |  | ||||||
|  |     def save_best(self, network, score): | ||||||
|  |         if self._best_info["score"] is None or self._best_info["score"] < score: | ||||||
|  |             state_dict = dict( | ||||||
|  |                 criterion=criterion, | ||||||
|  |                 network=network.state_dict(), | ||||||
|  |                 meta_optimizer=self.meta_optimizer.state_dict(), | ||||||
|  |                 meta_lr_scheduler=self.meta_lr_scheduler.state_dict(), | ||||||
|  |             ) | ||||||
|  |             self._best_info["state_dict"] = state_dict | ||||||
|  |             self._best_info["score"] = score | ||||||
|  |  | ||||||
|  |  | ||||||
| def main(args): | def main(args): | ||||||
|     logger, env_info = lfna_setup(args) |     logger, env_info, model_kwargs = lfna_setup(args) | ||||||
|  |     model = get_model(dict(model_type="simple_mlp"), **model_kwargs) | ||||||
|  |  | ||||||
|     total_time = env_info["total"] |     total_time = env_info["total"] | ||||||
|     for i in range(total_time): |     for i in range(total_time): | ||||||
| @@ -104,19 +102,12 @@ def main(args): | |||||||
|             nkey = "{:}-{:}".format(i, xkey) |             nkey = "{:}-{:}".format(i, xkey) | ||||||
|             assert nkey in env_info, "{:} no in {:}".format(nkey, list(env_info.keys())) |             assert nkey in env_info, "{:} no in {:}".format(nkey, list(env_info.keys())) | ||||||
|     train_time_bar = total_time // 2 |     train_time_bar = total_time // 2 | ||||||
|     base_model = get_model( |  | ||||||
|         dict(model_type="simple_mlp"), |  | ||||||
|         act_cls="leaky_relu", |  | ||||||
|         norm_cls="identity", |  | ||||||
|         input_dim=1, |  | ||||||
|         output_dim=1, |  | ||||||
|     ) |  | ||||||
|  |  | ||||||
|     w_container = base_model.get_w_container() |  | ||||||
|     criterion = torch.nn.MSELoss() |     criterion = torch.nn.MSELoss() | ||||||
|     print("There are {:} weights.".format(w_container.numel())) |  | ||||||
|  |  | ||||||
|     maml = MAML(w_container, criterion, args.meta_lr, args.inner_lr, args.inner_step) |     maml = MAML( | ||||||
|  |         model, criterion, args.epochs, args.meta_lr, args.inner_lr, args.inner_step | ||||||
|  |     ) | ||||||
|  |  | ||||||
|     # meta-training |     # meta-training | ||||||
|     per_epoch_time, start_time = AverageMeter(), time.time() |     per_epoch_time, start_time = AverageMeter(), time.time() | ||||||
| @@ -131,8 +122,7 @@ def main(args): | |||||||
|         ) |         ) | ||||||
|  |  | ||||||
|         maml.zero_grad() |         maml.zero_grad() | ||||||
|  |         meta_losses = [] | ||||||
|         all_meta_losses = [] |  | ||||||
|         for ibatch in range(args.meta_batch): |         for ibatch in range(args.meta_batch): | ||||||
|             sampled_timestamp = random.randint(0, train_time_bar) |             sampled_timestamp = random.randint(0, train_time_bar) | ||||||
|             past_dataset = TimeData( |             past_dataset = TimeData( | ||||||
| @@ -145,21 +135,23 @@ def main(args): | |||||||
|                 env_info["{:}-x".format(sampled_timestamp + 1)], |                 env_info["{:}-x".format(sampled_timestamp + 1)], | ||||||
|                 env_info["{:}-y".format(sampled_timestamp + 1)], |                 env_info["{:}-y".format(sampled_timestamp + 1)], | ||||||
|             ) |             ) | ||||||
|             maml.adapt(base_model, past_dataset) |             future_container = maml.adapt(model, past_dataset) | ||||||
|             import pdb |             future_y_hat = maml.predict(future_dataset.x, future_container) | ||||||
|  |             future_loss = maml.criterion(future_y_hat, future_dataset.y) | ||||||
|             pdb.set_trace() |             meta_losses.append(future_loss) | ||||||
|  |         meta_loss = torch.stack(meta_losses).mean() | ||||||
|         meta_loss = torch.stack(all_meta_losses).mean() |  | ||||||
|         meta_loss.backward() |         meta_loss.backward() | ||||||
|         adaptor.step() |         maml.step() | ||||||
|  |  | ||||||
|         debug_str = pool.debug_info(debug_timestamp) |  | ||||||
|         logger.log("meta-loss: {:.4f}".format(meta_loss.item())) |         logger.log("meta-loss: {:.4f}".format(meta_loss.item())) | ||||||
|  |  | ||||||
|         per_epoch_time.update(time.time() - start_time) |         per_epoch_time.update(time.time() - start_time) | ||||||
|         start_time = time.time() |         start_time = time.time() | ||||||
|  |  | ||||||
|  |     import pdb | ||||||
|  |  | ||||||
|  |     pdb.set_trace() | ||||||
|  |  | ||||||
|     logger.log("-" * 200 + "\n") |     logger.log("-" * 200 + "\n") | ||||||
|     logger.close() |     logger.close() | ||||||
|  |  | ||||||
| @@ -187,7 +179,7 @@ if __name__ == "__main__": | |||||||
|     parser.add_argument( |     parser.add_argument( | ||||||
|         "--meta_lr", |         "--meta_lr", | ||||||
|         type=float, |         type=float, | ||||||
|         default=0.01, |         default=0.1, | ||||||
|         help="The learning rate for the MAML optimizer (default is Adam)", |         help="The learning rate for the MAML optimizer (default is Adam)", | ||||||
|     ) |     ) | ||||||
|     parser.add_argument( |     parser.add_argument( | ||||||
| @@ -202,7 +194,7 @@ if __name__ == "__main__": | |||||||
|     parser.add_argument( |     parser.add_argument( | ||||||
|         "--meta_batch", |         "--meta_batch", | ||||||
|         type=int, |         type=int, | ||||||
|         default=5, |         default=10, | ||||||
|         help="The batch size for the meta-model", |         help="The batch size for the meta-model", | ||||||
|     ) |     ) | ||||||
|     parser.add_argument( |     parser.add_argument( | ||||||
| @@ -223,7 +215,7 @@ if __name__ == "__main__": | |||||||
|     if args.rand_seed is None or args.rand_seed < 0: |     if args.rand_seed is None or args.rand_seed < 0: | ||||||
|         args.rand_seed = random.randint(1, 100000) |         args.rand_seed = random.randint(1, 100000) | ||||||
|     assert args.save_dir is not None, "The save dir argument can not be None" |     assert args.save_dir is not None, "The save dir argument can not be None" | ||||||
|     args.save_dir = "{:}-{:}-d{:}".format( |     args.save_dir = "{:}-s{:}-{:}-d{:}".format( | ||||||
|         args.save_dir, args.env_version, args.hidden_dim |         args.save_dir, args.inner_step, args.env_version, args.hidden_dim | ||||||
|     ) |     ) | ||||||
|     main(args) |     main(args) | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user