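Updates: replace the `get_seq_last` flag of `MetaModelV1.forward_raw` with `tembed_only`, rename the `pretrain_v2` loss terms (future / present / regularization), and revise the synthetic environments (versions v1/v2, max time 10*pi, 1600 default timestamps)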
This commit is contained in:
		| @@ -6,10 +6,11 @@ | ||||
| # python exps/GMOA/lfna.py --env_version v1 --device cuda --lr 0.002 --seq_length 16 --meta_batch 128 | ||||
| # python exps/GMOA/lfna.py --env_version v1 --device cuda --lr 0.002 --seq_length 24 --time_dim 32 --meta_batch 128 | ||||
| ##################################################### | ||||
| import pdb, sys, time, copy, torch, random, argparse | ||||
| import sys, time, copy, torch, random, argparse | ||||
| from tqdm import tqdm | ||||
| from copy import deepcopy | ||||
| from pathlib import Path | ||||
| from torch.nn import functional as F | ||||
|  | ||||
| lib_dir = (Path(__file__).parent / ".." / "..").resolve() | ||||
| print("LIB-DIR: {:}".format(lib_dir)) | ||||
| @@ -103,7 +104,7 @@ def online_evaluate(env, meta_model, base_model, criterion, args, logger, save=F | ||||
|             meta_model.eval() | ||||
|             base_model.eval() | ||||
|             _, [future_container], time_embeds = meta_model( | ||||
|                 future_time.to(args.device).view(1, 1), None, True | ||||
|                 future_time.to(args.device).view(1, 1), None, False | ||||
|             ) | ||||
|             if save: | ||||
|                 w_containers[idx] = future_container.no_grad_clone() | ||||
| @@ -159,50 +160,57 @@ def pretrain_v2(base_model, meta_model, criterion, xenv, args, logger): | ||||
|         left_time = "Time Left: {:}".format( | ||||
|             convert_secs2time(per_epoch_time.avg * (args.epochs - iepoch), True) | ||||
|         ) | ||||
|         total_meta_v1_losses, total_meta_v2_losses, total_match_losses = [], [], [] | ||||
|         total_future_losses, total_present_losses, total_regu_losses = [], [], [] | ||||
|         optimizer.zero_grad() | ||||
|         for ibatch in range(args.meta_batch): | ||||
|             rand_index = random.randint(0, meta_model.meta_length - 1) | ||||
|             timestamp = meta_model.meta_timestamps[rand_index] | ||||
|             meta_embed = meta_model.super_meta_embed[rand_index] | ||||
|  | ||||
|             _, [container], time_embed = meta_model( | ||||
|                 torch.unsqueeze(timestamp, dim=0), None, True | ||||
|                 torch.unsqueeze(timestamp, dim=0), None, False | ||||
|             ) | ||||
|             _, (inputs, targets) = xenv(timestamp.item()) | ||||
|             inputs, targets = inputs.to(device), targets.to(device) | ||||
|             # generate models one step ahead | ||||
|             predictions = base_model.forward_with_container(inputs, container) | ||||
|             total_meta_v1_losses.append(criterion(predictions, targets)) | ||||
|             # the matching loss | ||||
|             match_loss = criterion(torch.squeeze(time_embed, dim=0), meta_embed) | ||||
|             total_match_losses.append(match_loss) | ||||
|             total_future_losses.append(criterion(predictions, targets)) | ||||
|             # randomly sample | ||||
|             rand_index = random.randint(0, meta_model.meta_length - 1) | ||||
|             timestamp = meta_model.meta_timestamps[rand_index] | ||||
|             meta_embed = meta_model.super_meta_embed[rand_index] | ||||
|  | ||||
|             time_embed = meta_model(torch.unsqueeze(timestamp, dim=0), None, True) | ||||
|             total_regu_losses.append( | ||||
|                 F.mse_loss( | ||||
|                     torch.squeeze(time_embed, dim=0), meta_embed, reduction="mean" | ||||
|                 ) | ||||
|             ) | ||||
|             # generate models via memory | ||||
|             _, [container], _ = meta_model(None, meta_embed.view(1, 1, -1), True) | ||||
|             _, [container], _ = meta_model(None, meta_embed.view(1, 1, -1), False) | ||||
|             predictions = base_model.forward_with_container(inputs, container) | ||||
|             total_meta_v2_losses.append(criterion(predictions, targets)) | ||||
|             total_present_losses.append(criterion(predictions, targets)) | ||||
|         with torch.no_grad(): | ||||
|             meta_std = torch.stack(total_meta_v1_losses).std().item() | ||||
|         meta_v1_loss = torch.stack(total_meta_v1_losses).mean() | ||||
|         meta_v2_loss = torch.stack(total_meta_v2_losses).mean() | ||||
|         match_loss = torch.stack(total_match_losses).mean() | ||||
|         total_loss = meta_v1_loss + meta_v2_loss + match_loss | ||||
|             meta_std = torch.stack(total_future_losses).std().item() | ||||
|         loss_future = torch.stack(total_future_losses).mean() | ||||
|         loss_present = torch.stack(total_present_losses).mean() | ||||
|         regularization_loss = torch.stack(total_regu_losses).mean() | ||||
|         total_loss = loss_future + loss_present + regularization_loss | ||||
|         total_loss.backward() | ||||
|         optimizer.step() | ||||
|         # success | ||||
|         success, best_score = meta_model.save_best(-total_loss.item()) | ||||
|         logger.log( | ||||
|             "{:} [Pre-V2 {:04d}/{:}] loss : {:.4f} +- {:.4f} = {:.4f} + {:.4f} + {:.4f} (match)".format( | ||||
|             "{:} [Pre-V2 {:04d}/{:}] loss : {:.4f} +- {:.4f} = {:.4f} + {:.4f} + {:.4f}".format( | ||||
|                 time_string(), | ||||
|                 iepoch, | ||||
|                 args.epochs, | ||||
|                 total_loss.item(), | ||||
|                 meta_std, | ||||
|                 meta_v1_loss.item(), | ||||
|                 meta_v2_loss.item(), | ||||
|                 match_loss.item(), | ||||
|                 loss_future.item(), | ||||
|                 loss_present.item(), | ||||
|                 regularization_loss.item(), | ||||
|             ) | ||||
|             + ", batch={:}".format(len(total_meta_v1_losses)) | ||||
|             + ", batch={:}".format(len(total_future_losses)) | ||||
|             + ", success={:}, best={:.4f}".format(success, -best_score) | ||||
|             + ", LS={:}/{:}".format(iepoch - last_success_epoch, early_stop_thresh) | ||||
|             + ", {:}".format(left_time) | ||||
|   | ||||
| @@ -34,7 +34,7 @@ class MetaModelV1(super_core.SuperModule): | ||||
|         assert interval is not None | ||||
|         self._interval = interval | ||||
|         self._seq_length = seq_length | ||||
|         self._thresh = interval * 30 if thresh is None else thresh | ||||
|         self._thresh = interval * 50 if thresh is None else thresh | ||||
|  | ||||
|         self.register_parameter( | ||||
|             "_super_layer_embed", | ||||
| @@ -183,7 +183,7 @@ class MetaModelV1(super_core.SuperModule): | ||||
|         ) | ||||
|         return timestamp_embeds | ||||
|  | ||||
|     def forward_raw(self, timestamps, time_embeds, get_seq_last): | ||||
|     def forward_raw(self, timestamps, time_embeds, tembed_only=False): | ||||
|         if time_embeds is None: | ||||
|             time_seq = timestamps.view(-1, 1) + self._time_interval.view(1, -1) | ||||
|             B, S = time_seq.shape | ||||
| @@ -193,8 +193,9 @@ class MetaModelV1(super_core.SuperModule): | ||||
|             B, S, _ = time_embeds.shape | ||||
|         # create joint embed | ||||
|         num_layer, _ = self._super_layer_embed.shape | ||||
|         if get_seq_last: | ||||
|         time_embeds = time_embeds[:, -1, :] | ||||
|         if tembed_only: | ||||
|             return time_embeds | ||||
|         # The shape of `joint_embeds` is batch * num-layers * input-dim | ||||
|         joint_embeds = torch.cat( | ||||
|             ( | ||||
| @@ -203,31 +204,12 @@ class MetaModelV1(super_core.SuperModule): | ||||
|             ), | ||||
|             dim=-1, | ||||
|         ) | ||||
|         else: | ||||
|             # The shape of `joint_embed` is batch * seq * num-layers * input-dim | ||||
|             joint_embeds = torch.cat( | ||||
|                 ( | ||||
|                     time_embeds.view(B, S, 1, -1).expand(-1, -1, num_layer, -1), | ||||
|                     self._super_layer_embed.view(1, 1, num_layer, -1).expand( | ||||
|                         B, S, -1, -1 | ||||
|                     ), | ||||
|                 ), | ||||
|                 dim=-1, | ||||
|             ) | ||||
|         batch_weights = self._generator(joint_embeds) | ||||
|         batch_containers = [] | ||||
|         for weights in torch.split(batch_weights, 1): | ||||
|             if get_seq_last: | ||||
|             batch_containers.append( | ||||
|                 self._shape_container.translate(torch.split(weights.squeeze(0), 1)) | ||||
|             ) | ||||
|             else: | ||||
|                 seq_containers = [] | ||||
|                 for ws in torch.split(weights.squeeze(0), 1): | ||||
|                     seq_containers.append( | ||||
|                         self._shape_container.translate(torch.split(ws.squeeze(0), 1)) | ||||
|                     ) | ||||
|                 batch_containers.append(seq_containers) | ||||
|         return time_seq, batch_containers, time_embeds | ||||
|  | ||||
|     def forward_candidate(self, input): | ||||
| @@ -241,7 +223,9 @@ class MetaModelV1(super_core.SuperModule): | ||||
|         with torch.set_grad_enabled(True): | ||||
|             new_param = self.create_meta_embed() | ||||
|  | ||||
|             optimizer = torch.optim.Adam([new_param], lr=lr, weight_decay=1e-5, amsgrad=True) | ||||
|             optimizer = torch.optim.Adam( | ||||
|                 [new_param], lr=lr, weight_decay=1e-5, amsgrad=True | ||||
|             ) | ||||
|             timestamp = torch.Tensor([timestamp]).to(new_param.device) | ||||
|             self.replace_append_learnt(timestamp, new_param) | ||||
|             self.train() | ||||
| @@ -255,10 +239,10 @@ class MetaModelV1(super_core.SuperModule): | ||||
|                 best_new_param = new_param.detach().clone() | ||||
|             for iepoch in range(epochs): | ||||
|                 optimizer.zero_grad() | ||||
|                 _, [_], time_embed = self(timestamp.view(1, 1), None, True) | ||||
|                 _, [_], time_embed = self(timestamp.view(1, 1), None) | ||||
|                 match_loss = criterion(new_param, time_embed) | ||||
|  | ||||
|                 _, [container], time_embed = self(None, new_param.view(1, 1, -1), True) | ||||
|                 _, [container], time_embed = self(None, new_param.view(1, 1, -1)) | ||||
|                 y_hat = base_model.forward_with_container(x, container) | ||||
|                 meta_loss = criterion(y_hat, y) | ||||
|                 loss = meta_loss + match_loss | ||||
|   | ||||
| @@ -1,51 +1,49 @@ | ||||
| ##################################################### | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.05 # | ||||
| ##################################################### | ||||
| import math | ||||
| from .synthetic_utils import TimeStamp | ||||
| from .synthetic_env import SyntheticDEnv | ||||
| from .math_core import LinearFunc | ||||
| from .math_core import DynamicLinearFunc | ||||
| from .math_core import DynamicQuadraticFunc | ||||
| from .math_core import ConstantFunc, ComposedSinFunc | ||||
| from .math_core import ConstantFunc, ComposedSinFunc as SinFunc | ||||
| from .math_core import GaussianDGenerator | ||||
|  | ||||
|  | ||||
| __all__ = ["TimeStamp", "SyntheticDEnv", "get_synthetic_env"] | ||||
|  | ||||
|  | ||||
| def get_synthetic_env(total_timestamp=1000, num_per_task=1000, mode=None, version="v1"): | ||||
|     if version == "v0": | ||||
| def get_synthetic_env(total_timestamp=1600, num_per_task=1000, mode=None, version="v1"): | ||||
|     max_time = math.pi * 10 | ||||
|     if version == "v1": | ||||
|         mean_generator = ConstantFunc(0) | ||||
|         std_generator = ConstantFunc(1) | ||||
|         data_generator = GaussianDGenerator( | ||||
|             [mean_generator], [[std_generator]], (-2, 2) | ||||
|         ) | ||||
|         time_generator = TimeStamp( | ||||
|             min_timestamp=0, max_timestamp=math.pi * 8, num=total_timestamp, mode=mode | ||||
|             min_timestamp=0, max_timestamp=max_time, num=total_timestamp, mode=mode | ||||
|         ) | ||||
|         oracle_map = DynamicLinearFunc( | ||||
|             params={ | ||||
|                 0: ComposedSinFunc(params={0: 2.0, 1: 1.0, 2: 2.2}), | ||||
|                 1: ConstantFunc(0), | ||||
|                 0: SinFunc(params={0: 2.0, 1: 1.0, 2: 2.2}),  # 2 sin(t) + 2.2 | ||||
|                 1: SinFunc(params={0: 1.5, 1: 0.6, 2: 1.8}),  # 1.5 sin(0.6t) + 1.8 | ||||
|             } | ||||
|         ) | ||||
|         dynamic_env = SyntheticDEnv( | ||||
|             data_generator, oracle_map, time_generator, num_per_task | ||||
|         ) | ||||
|     elif version == "v1": | ||||
|     elif version == "v2": | ||||
|         mean_generator = ConstantFunc(0) | ||||
|         std_generator = ConstantFunc(1) | ||||
|         data_generator = GaussianDGenerator( | ||||
|             [mean_generator], [[std_generator]], (-2, 2) | ||||
|         ) | ||||
|         time_generator = TimeStamp( | ||||
|             min_timestamp=0, max_timestamp=math.pi * 8, num=total_timestamp, mode=mode | ||||
|             min_timestamp=0, max_timestamp=max_time, num=total_timestamp, mode=mode | ||||
|         ) | ||||
|         oracle_map = DynamicLinearFunc( | ||||
|         oracle_map = DynamicQuadraticFunc( | ||||
|             params={ | ||||
|                 0: ComposedSinFunc(params={0: 2.0, 1: 1.0, 2: 2.2}), | ||||
|                 1: ComposedSinFunc(params={0: 1.5, 1: 0.6, 2: 1.8}), | ||||
|                 0: LinearFunc(params={0: 0.1, 1: 0}),  # 0.1 * t | ||||
|                 1: SinFunc(params={0: 1, 1: 1, 2: 0}),  # sin(t) | ||||
|             } | ||||
|         ) | ||||
|         dynamic_env = SyntheticDEnv( | ||||
|   | ||||
		Reference in New Issue
	
	Block a user