Update ablation for GeMOSA
This commit is contained in:
		| @@ -4,6 +4,7 @@ | |||||||
| # python exps/GeMOSA/basic-same.py --env_version v1 --hidden_dim 16 --epochs 500 --init_lr 0.1 --device cuda | # python exps/GeMOSA/basic-same.py --env_version v1 --hidden_dim 16 --epochs 500 --init_lr 0.1 --device cuda | ||||||
| # python exps/GeMOSA/basic-same.py --env_version v2 --hidden_dim 16 --epochs 500 --init_lr 0.1 --device cuda | # python exps/GeMOSA/basic-same.py --env_version v2 --hidden_dim 16 --epochs 500 --init_lr 0.1 --device cuda | ||||||
| # python exps/GeMOSA/basic-same.py --env_version v3 --hidden_dim 32 --epochs 1000 --init_lr 0.05 --device cuda | # python exps/GeMOSA/basic-same.py --env_version v3 --hidden_dim 32 --epochs 1000 --init_lr 0.05 --device cuda | ||||||
|  | # python exps/GeMOSA/basic-same.py --env_version v4 --hidden_dim 32 --epochs 1000 --init_lr 0.05 --device cuda | ||||||
| ##################################################### | ##################################################### | ||||||
| import sys, time, copy, torch, random, argparse | import sys, time, copy, torch, random, argparse | ||||||
| from tqdm import tqdm | from tqdm import tqdm | ||||||
| @@ -28,7 +29,12 @@ from xautodl.log_utils import AverageMeter, convert_secs2time | |||||||
| from xautodl.utils import split_str2indexes | from xautodl.utils import split_str2indexes | ||||||
|  |  | ||||||
| from xautodl.procedures.advanced_main import basic_train_fn, basic_eval_fn | from xautodl.procedures.advanced_main import basic_train_fn, basic_eval_fn | ||||||
| from xautodl.procedures.metric_utils import SaveMetric, MSEMetric, ComposeMetric | from xautodl.procedures.metric_utils import ( | ||||||
|  |     SaveMetric, | ||||||
|  |     MSEMetric, | ||||||
|  |     Top1AccMetric, | ||||||
|  |     ComposeMetric, | ||||||
|  | ) | ||||||
| from xautodl.datasets.synthetic_core import get_synthetic_env | from xautodl.datasets.synthetic_core import get_synthetic_env | ||||||
| from xautodl.models.xcore import get_model | from xautodl.models.xcore import get_model | ||||||
|  |  | ||||||
| @@ -57,6 +63,17 @@ def main(args): | |||||||
|     logger.log("The total enviornment: {:}".format(env)) |     logger.log("The total enviornment: {:}".format(env)) | ||||||
|     w_containers = dict() |     w_containers = dict() | ||||||
|  |  | ||||||
|  |     if env.meta_info["task"] == "regression": | ||||||
|  |         criterion = torch.nn.MSELoss() | ||||||
|  |         metric_cls = MSEMetric | ||||||
|  |     elif env.meta_info["task"] == "classification": | ||||||
|  |         criterion = torch.nn.CrossEntropyLoss() | ||||||
|  |         metric_cls = Top1AccMetric | ||||||
|  |     else: | ||||||
|  |         raise ValueError( | ||||||
|  |             "This task ({:}) is not supported.".format(all_env.meta_info["task"]) | ||||||
|  |         ) | ||||||
|  |  | ||||||
|     per_timestamp_time, start_time = AverageMeter(), time.time() |     per_timestamp_time, start_time = AverageMeter(), time.time() | ||||||
|     for idx, (future_time, (future_x, future_y)) in enumerate(env): |     for idx, (future_time, (future_x, future_y)) in enumerate(env): | ||||||
|  |  | ||||||
| @@ -79,7 +96,6 @@ def main(args): | |||||||
|             print(model) |             print(model) | ||||||
|         # build optimizer |         # build optimizer | ||||||
|         optimizer = torch.optim.Adam(model.parameters(), lr=args.init_lr, amsgrad=True) |         optimizer = torch.optim.Adam(model.parameters(), lr=args.init_lr, amsgrad=True) | ||||||
|         criterion = torch.nn.MSELoss() |  | ||||||
|         lr_scheduler = torch.optim.lr_scheduler.MultiStepLR( |         lr_scheduler = torch.optim.lr_scheduler.MultiStepLR( | ||||||
|             optimizer, |             optimizer, | ||||||
|             milestones=[ |             milestones=[ | ||||||
| @@ -89,7 +105,7 @@ def main(args): | |||||||
|             ], |             ], | ||||||
|             gamma=0.3, |             gamma=0.3, | ||||||
|         ) |         ) | ||||||
|         train_metric = MSEMetric() |         train_metric = metric_cls(True) | ||||||
|         best_loss, best_param = None, None |         best_loss, best_param = None, None | ||||||
|         for _iepoch in range(args.epochs): |         for _iepoch in range(args.epochs): | ||||||
|             preds = model(historical_x) |             preds = model(historical_x) | ||||||
| @@ -108,19 +124,19 @@ def main(args): | |||||||
|             train_metric(preds, historical_y) |             train_metric(preds, historical_y) | ||||||
|         train_results = train_metric.get_info() |         train_results = train_metric.get_info() | ||||||
|  |  | ||||||
|         metric = ComposeMetric(MSEMetric(), SaveMetric()) |         xmetric = ComposeMetric(metric_cls(True), SaveMetric()) | ||||||
|         eval_dataset = torch.utils.data.TensorDataset( |         eval_dataset = torch.utils.data.TensorDataset( | ||||||
|             future_x.to(args.device), future_y.to(args.device) |             future_x.to(args.device), future_y.to(args.device) | ||||||
|         ) |         ) | ||||||
|         eval_loader = torch.utils.data.DataLoader( |         eval_loader = torch.utils.data.DataLoader( | ||||||
|             eval_dataset, batch_size=args.batch_size, shuffle=False, num_workers=0 |             eval_dataset, batch_size=args.batch_size, shuffle=False, num_workers=0 | ||||||
|         ) |         ) | ||||||
|         results = basic_eval_fn(eval_loader, model, metric, logger) |         results = basic_eval_fn(eval_loader, model, xmetric, logger) | ||||||
|         log_str = ( |         log_str = ( | ||||||
|             "[{:}]".format(time_string()) |             "[{:}]".format(time_string()) | ||||||
|             + " [{:04d}/{:04d}]".format(idx, len(env)) |             + " [{:04d}/{:04d}]".format(idx, len(env)) | ||||||
|             + " train-mse: {:.5f}, eval-mse: {:.5f}".format( |             + " train-score: {:.5f}, eval-score: {:.5f}".format( | ||||||
|                 train_results["mse"], results["mse"] |                 train_results["score"], results["score"] | ||||||
|             ) |             ) | ||||||
|         ) |         ) | ||||||
|         logger.log(log_str) |         logger.log(log_str) | ||||||
|   | |||||||
| @@ -1,12 +1,16 @@ | |||||||
| ##################################################### | ########################################################## | ||||||
| # Learning to Generate Model One Step Ahead         # | # Learning to Efficiently Generate Models One Step Ahead # | ||||||
| ##################################################### | ########################################################## | ||||||
|  | # <----> run on CPU | ||||||
| # python exps/GeMOSA/main.py --env_version v1 --workers 0 | # python exps/GeMOSA/main.py --env_version v1 --workers 0 | ||||||
|  | # <----> run on a GPU | ||||||
| # python exps/GeMOSA/main.py --env_version v1 --lr 0.002 --hidden_dim 16 --meta_batch 256 --device cuda | # python exps/GeMOSA/main.py --env_version v1 --lr 0.002 --hidden_dim 16 --meta_batch 256 --device cuda | ||||||
| # python exps/GeMOSA/main.py --env_version v2 --lr 0.002 --hidden_dim 16 --meta_batch 256 --device cuda | # python exps/GeMOSA/main.py --env_version v2 --lr 0.002 --hidden_dim 16 --meta_batch 256 --device cuda | ||||||
| # python exps/GeMOSA/main.py --env_version v3 --lr 0.002 --hidden_dim 32 --time_dim 32 --meta_batch 256 --device cuda | # python exps/GeMOSA/main.py --env_version v3 --lr 0.002 --hidden_dim 32 --time_dim 32 --meta_batch 256 --device cuda | ||||||
| # python exps/GeMOSA/main.py --env_version v4 --lr 0.002 --hidden_dim 32 --time_dim 32 --meta_batch 256 --device cuda | # python exps/GeMOSA/main.py --env_version v4 --lr 0.002 --hidden_dim 32 --time_dim 32 --meta_batch 256 --device cuda | ||||||
| ##################################################### | # <----> ablation commands | ||||||
|  | # python exps/GeMOSA/main.py --env_version v4 --lr 0.002 --hidden_dim 32 --time_dim 32 --meta_batch 256 --ablation old --device cuda | ||||||
|  | ########################################################## | ||||||
| import sys, time, copy, torch, random, argparse | import sys, time, copy, torch, random, argparse | ||||||
| from tqdm import tqdm | from tqdm import tqdm | ||||||
| from copy import deepcopy | from copy import deepcopy | ||||||
| @@ -36,6 +40,7 @@ from xautodl.models.xcore import get_model | |||||||
| from xautodl.procedures.metric_utils import MSEMetric, Top1AccMetric | from xautodl.procedures.metric_utils import MSEMetric, Top1AccMetric | ||||||
|  |  | ||||||
| from meta_model import MetaModelV1 | from meta_model import MetaModelV1 | ||||||
|  | from meta_model_ablation import MetaModel_TraditionalAtt | ||||||
|  |  | ||||||
|  |  | ||||||
| def online_evaluate( | def online_evaluate( | ||||||
| @@ -230,7 +235,13 @@ def main(args): | |||||||
|  |  | ||||||
|     # pre-train the hypernetwork |     # pre-train the hypernetwork | ||||||
|     timestamps = trainval_env.get_timestamp(None) |     timestamps = trainval_env.get_timestamp(None) | ||||||
|     meta_model = MetaModelV1( |     if args.ablation is None: | ||||||
|  |         MetaModel_cls = MetaModelV1 | ||||||
|  |     elif args.ablation == "old": | ||||||
|  |         MetaModel_cls = MetaModel_TraditionalAtt | ||||||
|  |     else: | ||||||
|  |         raise ValueError("Unknown ablation : {:}".format(args.ablation)) | ||||||
|  |     meta_model = MetaModel_cls( | ||||||
|         shape_container, |         shape_container, | ||||||
|         args.layer_dim, |         args.layer_dim, | ||||||
|         args.time_dim, |         args.time_dim, | ||||||
| @@ -373,6 +384,9 @@ if __name__ == "__main__": | |||||||
|     parser.add_argument( |     parser.add_argument( | ||||||
|         "--workers", type=int, default=4, help="The number of workers in parallel." |         "--workers", type=int, default=4, help="The number of workers in parallel." | ||||||
|     ) |     ) | ||||||
|  |     parser.add_argument( | ||||||
|  |         "--ablation", type=str, default=None, help="The ablation indicator." | ||||||
|  |     ) | ||||||
|     parser.add_argument( |     parser.add_argument( | ||||||
|         "--device", |         "--device", | ||||||
|         type=str, |         type=str, | ||||||
| @@ -385,7 +399,7 @@ if __name__ == "__main__": | |||||||
|     if args.rand_seed is None or args.rand_seed < 0: |     if args.rand_seed is None or args.rand_seed < 0: | ||||||
|         args.rand_seed = random.randint(1, 100000) |         args.rand_seed = random.randint(1, 100000) | ||||||
|     assert args.save_dir is not None, "The save dir argument can not be None" |     assert args.save_dir is not None, "The save dir argument can not be None" | ||||||
|     args.save_dir = "{:}-bs{:}-d{:}_{:}_{:}-s{:}-lr{:}-wd{:}-e{:}-env{:}".format( |     args.save_dir = "{:}-bs{:}-d{:}_{:}_{:}-s{:}-lr{:}-wd{:}-e{:}-ab{:}-env{:}".format( | ||||||
|         args.save_dir, |         args.save_dir, | ||||||
|         args.meta_batch, |         args.meta_batch, | ||||||
|         args.hidden_dim, |         args.hidden_dim, | ||||||
| @@ -395,6 +409,7 @@ if __name__ == "__main__": | |||||||
|         args.lr, |         args.lr, | ||||||
|         args.weight_decay, |         args.weight_decay, | ||||||
|         args.epochs, |         args.epochs, | ||||||
|  |         args.ablation, | ||||||
|         args.env_version, |         args.env_version, | ||||||
|     ) |     ) | ||||||
|     main(args) |     main(args) | ||||||
|   | |||||||
| @@ -1,6 +1,3 @@ | |||||||
| ##################################################### |  | ||||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 # |  | ||||||
| ##################################################### |  | ||||||
| import torch | import torch | ||||||
|  |  | ||||||
| import torch.nn.functional as F | import torch.nn.functional as F | ||||||
|   | |||||||
							
								
								
									
										260
									
								
								exps/GeMOSA/meta_model_ablation.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										260
									
								
								exps/GeMOSA/meta_model_ablation.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,260 @@ | |||||||
|  | # | ||||||
|  | # This is used for the ablation studies: | ||||||
|  | # The meta-model in this file uses the traditional attention in | ||||||
|  | # transformer. | ||||||
|  | # | ||||||
|  | import torch | ||||||
|  |  | ||||||
|  | import torch.nn.functional as F | ||||||
|  |  | ||||||
|  | from xautodl.xlayers import super_core | ||||||
|  | from xautodl.xlayers import trunc_normal_ | ||||||
|  | from xautodl.models.xcore import get_model | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class MetaModel_TraditionalAtt(super_core.SuperModule): | ||||||
|  |     """Learning to Generate Models One Step Ahead (Meta Model Design).""" | ||||||
|  |  | ||||||
|  |     def __init__( | ||||||
|  |         self, | ||||||
|  |         shape_container, | ||||||
|  |         layer_dim, | ||||||
|  |         time_dim, | ||||||
|  |         meta_timestamps, | ||||||
|  |         dropout: float = 0.1, | ||||||
|  |         seq_length: int = None, | ||||||
|  |         interval: float = None, | ||||||
|  |         thresh: float = None, | ||||||
|  |     ): | ||||||
|  |         super(MetaModel_TraditionalAtt, self).__init__() | ||||||
|  |         self._shape_container = shape_container | ||||||
|  |         self._num_layers = len(shape_container) | ||||||
|  |         self._numel_per_layer = [] | ||||||
|  |         for ilayer in range(self._num_layers): | ||||||
|  |             self._numel_per_layer.append(shape_container[ilayer].numel()) | ||||||
|  |         self._raw_meta_timestamps = meta_timestamps | ||||||
|  |         assert interval is not None | ||||||
|  |         self._interval = interval | ||||||
|  |         self._thresh = interval * seq_length if thresh is None else thresh | ||||||
|  |  | ||||||
|  |         self.register_parameter( | ||||||
|  |             "_super_layer_embed", | ||||||
|  |             torch.nn.Parameter(torch.Tensor(self._num_layers, layer_dim)), | ||||||
|  |         ) | ||||||
|  |         self.register_parameter( | ||||||
|  |             "_super_meta_embed", | ||||||
|  |             torch.nn.Parameter(torch.Tensor(len(meta_timestamps), time_dim)), | ||||||
|  |         ) | ||||||
|  |         self.register_buffer("_meta_timestamps", torch.Tensor(meta_timestamps)) | ||||||
|  |         self._time_embed_dim = time_dim | ||||||
|  |         self._append_meta_embed = dict(fixed=None, learnt=None) | ||||||
|  |         self._append_meta_timestamps = dict(fixed=None, learnt=None) | ||||||
|  |  | ||||||
|  |         self._tscalar_embed = super_core.SuperDynamicPositionE( | ||||||
|  |             time_dim, scale=1 / interval | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |         # build transformer | ||||||
|  |         self._trans_att = super_core.SuperQKVAttention( | ||||||
|  |             in_q_dim=time_dim, | ||||||
|  |             in_k_dim=time_dim, | ||||||
|  |             in_v_dim=time_dim, | ||||||
|  |             num_heads=4, | ||||||
|  |             proj_dim=time_dim, | ||||||
|  |             qkv_bias=True, | ||||||
|  |             attn_drop=None, | ||||||
|  |             proj_drop=dropout, | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |         model_kwargs = dict( | ||||||
|  |             config=dict(model_type="dual_norm_mlp"), | ||||||
|  |             input_dim=layer_dim + time_dim, | ||||||
|  |             output_dim=max(self._numel_per_layer), | ||||||
|  |             hidden_dims=[(layer_dim + time_dim) * 2] * 3, | ||||||
|  |             act_cls="gelu", | ||||||
|  |             norm_cls="layer_norm_1d", | ||||||
|  |             dropout=dropout, | ||||||
|  |         ) | ||||||
|  |         self._generator = get_model(**model_kwargs) | ||||||
|  |  | ||||||
|  |         # initialization | ||||||
|  |         trunc_normal_( | ||||||
|  |             [self._super_layer_embed, self._super_meta_embed], | ||||||
|  |             std=0.02, | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |     def get_parameters(self, time_embed, attention, generator): | ||||||
|  |         parameters = [] | ||||||
|  |         if time_embed: | ||||||
|  |             parameters.append(self._super_meta_embed) | ||||||
|  |         if attention: | ||||||
|  |             parameters.extend(list(self._trans_att.parameters())) | ||||||
|  |         if generator: | ||||||
|  |             parameters.append(self._super_layer_embed) | ||||||
|  |             parameters.extend(list(self._generator.parameters())) | ||||||
|  |         return parameters | ||||||
|  |  | ||||||
|  |     @property | ||||||
|  |     def meta_timestamps(self): | ||||||
|  |         with torch.no_grad(): | ||||||
|  |             meta_timestamps = [self._meta_timestamps] | ||||||
|  |             for key in ("fixed", "learnt"): | ||||||
|  |                 if self._append_meta_timestamps[key] is not None: | ||||||
|  |                     meta_timestamps.append(self._append_meta_timestamps[key]) | ||||||
|  |         return torch.cat(meta_timestamps) | ||||||
|  |  | ||||||
|  |     @property | ||||||
|  |     def super_meta_embed(self): | ||||||
|  |         meta_embed = [self._super_meta_embed] | ||||||
|  |         for key in ("fixed", "learnt"): | ||||||
|  |             if self._append_meta_embed[key] is not None: | ||||||
|  |                 meta_embed.append(self._append_meta_embed[key]) | ||||||
|  |         return torch.cat(meta_embed) | ||||||
|  |  | ||||||
|  |     def create_meta_embed(self): | ||||||
|  |         param = torch.Tensor(1, self._time_embed_dim) | ||||||
|  |         trunc_normal_(param, std=0.02) | ||||||
|  |         param = param.to(self._super_meta_embed.device) | ||||||
|  |         param = torch.nn.Parameter(param, True) | ||||||
|  |         return param | ||||||
|  |  | ||||||
|  |     def get_closest_meta_distance(self, timestamp): | ||||||
|  |         with torch.no_grad(): | ||||||
|  |             distances = torch.abs(self.meta_timestamps - timestamp) | ||||||
|  |             return torch.min(distances).item() | ||||||
|  |  | ||||||
|  |     def replace_append_learnt(self, timestamp, meta_embed): | ||||||
|  |         self._append_meta_timestamps["learnt"] = timestamp | ||||||
|  |         self._append_meta_embed["learnt"] = meta_embed | ||||||
|  |  | ||||||
|  |     @property | ||||||
|  |     def meta_length(self): | ||||||
|  |         return self.meta_timestamps.numel() | ||||||
|  |  | ||||||
|  |     def clear_fixed(self): | ||||||
|  |         self._append_meta_timestamps["fixed"] = None | ||||||
|  |         self._append_meta_embed["fixed"] = None | ||||||
|  |  | ||||||
|  |     def clear_learnt(self): | ||||||
|  |         self.replace_append_learnt(None, None) | ||||||
|  |  | ||||||
|  |     def append_fixed(self, timestamp, meta_embed): | ||||||
|  |         with torch.no_grad(): | ||||||
|  |             device = self._super_meta_embed.device | ||||||
|  |             timestamp = timestamp.detach().clone().to(device) | ||||||
|  |             meta_embed = meta_embed.detach().clone().to(device) | ||||||
|  |             if self._append_meta_timestamps["fixed"] is None: | ||||||
|  |                 self._append_meta_timestamps["fixed"] = timestamp | ||||||
|  |             else: | ||||||
|  |                 self._append_meta_timestamps["fixed"] = torch.cat( | ||||||
|  |                     (self._append_meta_timestamps["fixed"], timestamp), dim=0 | ||||||
|  |                 ) | ||||||
|  |             if self._append_meta_embed["fixed"] is None: | ||||||
|  |                 self._append_meta_embed["fixed"] = meta_embed | ||||||
|  |             else: | ||||||
|  |                 self._append_meta_embed["fixed"] = torch.cat( | ||||||
|  |                     (self._append_meta_embed["fixed"], meta_embed), dim=0 | ||||||
|  |                 ) | ||||||
|  |  | ||||||
|  |     def gen_time_embed(self, timestamps): | ||||||
|  |         # timestamps is a batch of timestamps | ||||||
|  |         [B] = timestamps.shape | ||||||
|  |         # batch, seq = timestamps.shape | ||||||
|  |         timestamps = timestamps.view(-1, 1) | ||||||
|  |         meta_timestamps, meta_embeds = self.meta_timestamps, self.super_meta_embed | ||||||
|  |         timestamp_v_embed = meta_embeds.unsqueeze(dim=0) | ||||||
|  |         timestamp_q_embed = self._tscalar_embed(timestamps) | ||||||
|  |         timestamp_k_embed = self._tscalar_embed(meta_timestamps.view(1, -1)) | ||||||
|  |  | ||||||
|  |         # create the mask | ||||||
|  |         mask = ( | ||||||
|  |             torch.unsqueeze(timestamps, dim=-1) <= meta_timestamps.view(1, 1, -1) | ||||||
|  |         ) | ( | ||||||
|  |             torch.abs( | ||||||
|  |                 torch.unsqueeze(timestamps, dim=-1) - meta_timestamps.view(1, 1, -1) | ||||||
|  |             ) | ||||||
|  |             > self._thresh | ||||||
|  |         ) | ||||||
|  |         timestamp_embeds = self._trans_att( | ||||||
|  |             timestamp_q_embed, timestamp_k_embed, timestamp_v_embed, mask | ||||||
|  |         ) | ||||||
|  |         return timestamp_embeds[:, -1, :] | ||||||
|  |  | ||||||
|  |     def gen_model(self, time_embeds): | ||||||
|  |         B, _ = time_embeds.shape | ||||||
|  |         # create joint embed | ||||||
|  |         num_layer, _ = self._super_layer_embed.shape | ||||||
|  |         # The shape of `joint_embed` is batch * num-layers * input-dim | ||||||
|  |         joint_embeds = torch.cat( | ||||||
|  |             ( | ||||||
|  |                 time_embeds.view(B, 1, -1).expand(-1, num_layer, -1), | ||||||
|  |                 self._super_layer_embed.view(1, num_layer, -1).expand(B, -1, -1), | ||||||
|  |             ), | ||||||
|  |             dim=-1, | ||||||
|  |         ) | ||||||
|  |         batch_weights = self._generator(joint_embeds) | ||||||
|  |         batch_containers = [] | ||||||
|  |         for weights in torch.split(batch_weights, 1): | ||||||
|  |             batch_containers.append( | ||||||
|  |                 self._shape_container.translate(torch.split(weights.squeeze(0), 1)) | ||||||
|  |             ) | ||||||
|  |         return batch_containers | ||||||
|  |  | ||||||
|  |     def forward_raw(self, timestamps, time_embeds, tembed_only=False): | ||||||
|  |         raise NotImplementedError | ||||||
|  |  | ||||||
|  |     def forward_candidate(self, input): | ||||||
|  |         raise NotImplementedError | ||||||
|  |  | ||||||
|  |     def easy_adapt(self, timestamp, time_embed): | ||||||
|  |         with torch.no_grad(): | ||||||
|  |             timestamp = torch.Tensor([timestamp]).to(self._meta_timestamps.device) | ||||||
|  |             self.replace_append_learnt(None, None) | ||||||
|  |             self.append_fixed(timestamp, time_embed) | ||||||
|  |  | ||||||
|  |     def adapt(self, base_model, criterion, timestamp, x, y, lr, epochs, init_info): | ||||||
|  |         distance = self.get_closest_meta_distance(timestamp) | ||||||
|  |         if distance + self._interval * 1e-2 <= self._interval: | ||||||
|  |             return False, None | ||||||
|  |         x, y = x.to(self._meta_timestamps.device), y.to(self._meta_timestamps.device) | ||||||
|  |         with torch.set_grad_enabled(True): | ||||||
|  |             new_param = self.create_meta_embed() | ||||||
|  |  | ||||||
|  |             optimizer = torch.optim.Adam( | ||||||
|  |                 [new_param], lr=lr, weight_decay=1e-5, amsgrad=True | ||||||
|  |             ) | ||||||
|  |             timestamp = torch.Tensor([timestamp]).to(new_param.device) | ||||||
|  |             self.replace_append_learnt(timestamp, new_param) | ||||||
|  |             self.train() | ||||||
|  |             base_model.train() | ||||||
|  |             if init_info is not None: | ||||||
|  |                 best_loss = init_info["loss"] | ||||||
|  |                 new_param.data.copy_(init_info["param"].data) | ||||||
|  |             else: | ||||||
|  |                 best_loss = 1e9 | ||||||
|  |             with torch.no_grad(): | ||||||
|  |                 best_new_param = new_param.detach().clone() | ||||||
|  |             for iepoch in range(epochs): | ||||||
|  |                 optimizer.zero_grad() | ||||||
|  |                 time_embed = self.gen_time_embed(timestamp.view(1)) | ||||||
|  |                 match_loss = criterion(new_param, time_embed) | ||||||
|  |  | ||||||
|  |                 [container] = self.gen_model(new_param.view(1, -1)) | ||||||
|  |                 y_hat = base_model.forward_with_container(x, container) | ||||||
|  |                 meta_loss = criterion(y_hat, y) | ||||||
|  |                 loss = meta_loss + match_loss | ||||||
|  |                 loss.backward() | ||||||
|  |                 optimizer.step() | ||||||
|  |                 if meta_loss.item() < best_loss: | ||||||
|  |                     with torch.no_grad(): | ||||||
|  |                         best_loss = meta_loss.item() | ||||||
|  |                         best_new_param = new_param.detach().clone() | ||||||
|  |         self.easy_adapt(timestamp, best_new_param) | ||||||
|  |         return True, best_loss | ||||||
|  |  | ||||||
|  |     def extra_repr(self) -> str: | ||||||
|  |         return "(_super_layer_embed): {:}, (_super_meta_embed): {:}, (_meta_timestamps): {:}".format( | ||||||
|  |             list(self._super_layer_embed.shape), | ||||||
|  |             list(self._super_meta_embed.shape), | ||||||
|  |             list(self._meta_timestamps.shape), | ||||||
|  |         ) | ||||||
		Reference in New Issue
	
	Block a user