Updates
This commit is contained in:
		| @@ -6,10 +6,11 @@ | |||||||
| # python exps/GMOA/lfna.py --env_version v1 --device cuda --lr 0.002 --seq_length 16 --meta_batch 128 | # python exps/GMOA/lfna.py --env_version v1 --device cuda --lr 0.002 --seq_length 16 --meta_batch 128 | ||||||
| # python exps/GMOA/lfna.py --env_version v1 --device cuda --lr 0.002 --seq_length 24 --time_dim 32 --meta_batch 128 | # python exps/GMOA/lfna.py --env_version v1 --device cuda --lr 0.002 --seq_length 24 --time_dim 32 --meta_batch 128 | ||||||
| ##################################################### | ##################################################### | ||||||
| import pdb, sys, time, copy, torch, random, argparse | import sys, time, copy, torch, random, argparse | ||||||
| from tqdm import tqdm | from tqdm import tqdm | ||||||
| from copy import deepcopy | from copy import deepcopy | ||||||
| from pathlib import Path | from pathlib import Path | ||||||
|  | from torch.nn import functional as F | ||||||
|  |  | ||||||
| lib_dir = (Path(__file__).parent / ".." / "..").resolve() | lib_dir = (Path(__file__).parent / ".." / "..").resolve() | ||||||
| print("LIB-DIR: {:}".format(lib_dir)) | print("LIB-DIR: {:}".format(lib_dir)) | ||||||
| @@ -103,7 +104,7 @@ def online_evaluate(env, meta_model, base_model, criterion, args, logger, save=F | |||||||
|             meta_model.eval() |             meta_model.eval() | ||||||
|             base_model.eval() |             base_model.eval() | ||||||
|             _, [future_container], time_embeds = meta_model( |             _, [future_container], time_embeds = meta_model( | ||||||
|                 future_time.to(args.device).view(1, 1), None, True |                 future_time.to(args.device).view(1, 1), None, False | ||||||
|             ) |             ) | ||||||
|             if save: |             if save: | ||||||
|                 w_containers[idx] = future_container.no_grad_clone() |                 w_containers[idx] = future_container.no_grad_clone() | ||||||
| @@ -159,50 +160,57 @@ def pretrain_v2(base_model, meta_model, criterion, xenv, args, logger): | |||||||
|         left_time = "Time Left: {:}".format( |         left_time = "Time Left: {:}".format( | ||||||
|             convert_secs2time(per_epoch_time.avg * (args.epochs - iepoch), True) |             convert_secs2time(per_epoch_time.avg * (args.epochs - iepoch), True) | ||||||
|         ) |         ) | ||||||
|         total_meta_v1_losses, total_meta_v2_losses, total_match_losses = [], [], [] |         total_future_losses, total_present_losses, total_regu_losses = [], [], [] | ||||||
|         optimizer.zero_grad() |         optimizer.zero_grad() | ||||||
|         for ibatch in range(args.meta_batch): |         for ibatch in range(args.meta_batch): | ||||||
|             rand_index = random.randint(0, meta_model.meta_length - 1) |             rand_index = random.randint(0, meta_model.meta_length - 1) | ||||||
|             timestamp = meta_model.meta_timestamps[rand_index] |             timestamp = meta_model.meta_timestamps[rand_index] | ||||||
|             meta_embed = meta_model.super_meta_embed[rand_index] |  | ||||||
|  |  | ||||||
|             _, [container], time_embed = meta_model( |             _, [container], time_embed = meta_model( | ||||||
|                 torch.unsqueeze(timestamp, dim=0), None, True |                 torch.unsqueeze(timestamp, dim=0), None, False | ||||||
|             ) |             ) | ||||||
|             _, (inputs, targets) = xenv(timestamp.item()) |             _, (inputs, targets) = xenv(timestamp.item()) | ||||||
|             inputs, targets = inputs.to(device), targets.to(device) |             inputs, targets = inputs.to(device), targets.to(device) | ||||||
|             # generate models one step ahead |             # generate models one step ahead | ||||||
|             predictions = base_model.forward_with_container(inputs, container) |             predictions = base_model.forward_with_container(inputs, container) | ||||||
|             total_meta_v1_losses.append(criterion(predictions, targets)) |             total_future_losses.append(criterion(predictions, targets)) | ||||||
|             # the matching loss |             # randomly sample | ||||||
|             match_loss = criterion(torch.squeeze(time_embed, dim=0), meta_embed) |             rand_index = random.randint(0, meta_model.meta_length - 1) | ||||||
|             total_match_losses.append(match_loss) |             timestamp = meta_model.meta_timestamps[rand_index] | ||||||
|  |             meta_embed = meta_model.super_meta_embed[rand_index] | ||||||
|  |  | ||||||
|  |             time_embed = meta_model(torch.unsqueeze(timestamp, dim=0), None, True) | ||||||
|  |             total_regu_losses.append( | ||||||
|  |                 F.mse_loss( | ||||||
|  |                     torch.squeeze(time_embed, dim=0), meta_embed, reduction="mean" | ||||||
|  |                 ) | ||||||
|  |             ) | ||||||
|             # generate models via memory |             # generate models via memory | ||||||
|             _, [container], _ = meta_model(None, meta_embed.view(1, 1, -1), True) |             _, [container], _ = meta_model(None, meta_embed.view(1, 1, -1), False) | ||||||
|             predictions = base_model.forward_with_container(inputs, container) |             predictions = base_model.forward_with_container(inputs, container) | ||||||
|             total_meta_v2_losses.append(criterion(predictions, targets)) |             total_present_losses.append(criterion(predictions, targets)) | ||||||
|         with torch.no_grad(): |         with torch.no_grad(): | ||||||
|             meta_std = torch.stack(total_meta_v1_losses).std().item() |             meta_std = torch.stack(total_future_losses).std().item() | ||||||
|         meta_v1_loss = torch.stack(total_meta_v1_losses).mean() |         loss_future = torch.stack(total_future_losses).mean() | ||||||
|         meta_v2_loss = torch.stack(total_meta_v2_losses).mean() |         loss_present = torch.stack(total_present_losses).mean() | ||||||
|         match_loss = torch.stack(total_match_losses).mean() |         regularization_loss = torch.stack(total_regu_losses).mean() | ||||||
|         total_loss = meta_v1_loss + meta_v2_loss + match_loss |         total_loss = loss_future + loss_present + regularization_loss | ||||||
|         total_loss.backward() |         total_loss.backward() | ||||||
|         optimizer.step() |         optimizer.step() | ||||||
|         # success |         # success | ||||||
|         success, best_score = meta_model.save_best(-total_loss.item()) |         success, best_score = meta_model.save_best(-total_loss.item()) | ||||||
|         logger.log( |         logger.log( | ||||||
|             "{:} [Pre-V2 {:04d}/{:}] loss : {:.4f} +- {:.4f} = {:.4f} + {:.4f} + {:.4f} (match)".format( |             "{:} [Pre-V2 {:04d}/{:}] loss : {:.4f} +- {:.4f} = {:.4f} + {:.4f} + {:.4f}".format( | ||||||
|                 time_string(), |                 time_string(), | ||||||
|                 iepoch, |                 iepoch, | ||||||
|                 args.epochs, |                 args.epochs, | ||||||
|                 total_loss.item(), |                 total_loss.item(), | ||||||
|                 meta_std, |                 meta_std, | ||||||
|                 meta_v1_loss.item(), |                 loss_future.item(), | ||||||
|                 meta_v2_loss.item(), |                 loss_present.item(), | ||||||
|                 match_loss.item(), |                 regularization_loss.item(), | ||||||
|             ) |             ) | ||||||
|             + ", batch={:}".format(len(total_meta_v1_losses)) |             + ", batch={:}".format(len(total_future_losses)) | ||||||
|             + ", success={:}, best={:.4f}".format(success, -best_score) |             + ", success={:}, best={:.4f}".format(success, -best_score) | ||||||
|             + ", LS={:}/{:}".format(iepoch - last_success_epoch, early_stop_thresh) |             + ", LS={:}/{:}".format(iepoch - last_success_epoch, early_stop_thresh) | ||||||
|             + ", {:}".format(left_time) |             + ", {:}".format(left_time) | ||||||
|   | |||||||
| @@ -34,7 +34,7 @@ class MetaModelV1(super_core.SuperModule): | |||||||
|         assert interval is not None |         assert interval is not None | ||||||
|         self._interval = interval |         self._interval = interval | ||||||
|         self._seq_length = seq_length |         self._seq_length = seq_length | ||||||
|         self._thresh = interval * 30 if thresh is None else thresh |         self._thresh = interval * 50 if thresh is None else thresh | ||||||
|  |  | ||||||
|         self.register_parameter( |         self.register_parameter( | ||||||
|             "_super_layer_embed", |             "_super_layer_embed", | ||||||
| @@ -183,7 +183,7 @@ class MetaModelV1(super_core.SuperModule): | |||||||
|         ) |         ) | ||||||
|         return timestamp_embeds |         return timestamp_embeds | ||||||
|  |  | ||||||
|     def forward_raw(self, timestamps, time_embeds, get_seq_last): |     def forward_raw(self, timestamps, time_embeds, tembed_only=False): | ||||||
|         if time_embeds is None: |         if time_embeds is None: | ||||||
|             time_seq = timestamps.view(-1, 1) + self._time_interval.view(1, -1) |             time_seq = timestamps.view(-1, 1) + self._time_interval.view(1, -1) | ||||||
|             B, S = time_seq.shape |             B, S = time_seq.shape | ||||||
| @@ -193,41 +193,23 @@ class MetaModelV1(super_core.SuperModule): | |||||||
|             B, S, _ = time_embeds.shape |             B, S, _ = time_embeds.shape | ||||||
|         # create joint embed |         # create joint embed | ||||||
|         num_layer, _ = self._super_layer_embed.shape |         num_layer, _ = self._super_layer_embed.shape | ||||||
|         if get_seq_last: |         time_embeds = time_embeds[:, -1, :] | ||||||
|             time_embeds = time_embeds[:, -1, :] |         if tembed_only: | ||||||
|             # The shape of `joint_embed` is batch * num-layers * input-dim |             return time_embeds | ||||||
|             joint_embeds = torch.cat( |         # The shape of `joint_embed` is batch * num-layers * input-dim | ||||||
|                 ( |         joint_embeds = torch.cat( | ||||||
|                     time_embeds.view(B, 1, -1).expand(-1, num_layer, -1), |             ( | ||||||
|                     self._super_layer_embed.view(1, num_layer, -1).expand(B, -1, -1), |                 time_embeds.view(B, 1, -1).expand(-1, num_layer, -1), | ||||||
|                 ), |                 self._super_layer_embed.view(1, num_layer, -1).expand(B, -1, -1), | ||||||
|                 dim=-1, |             ), | ||||||
|             ) |             dim=-1, | ||||||
|         else: |         ) | ||||||
|             # The shape of `joint_embed` is batch * seq * num-layers * input-dim |  | ||||||
|             joint_embeds = torch.cat( |  | ||||||
|                 ( |  | ||||||
|                     time_embeds.view(B, S, 1, -1).expand(-1, -1, num_layer, -1), |  | ||||||
|                     self._super_layer_embed.view(1, 1, num_layer, -1).expand( |  | ||||||
|                         B, S, -1, -1 |  | ||||||
|                     ), |  | ||||||
|                 ), |  | ||||||
|                 dim=-1, |  | ||||||
|             ) |  | ||||||
|         batch_weights = self._generator(joint_embeds) |         batch_weights = self._generator(joint_embeds) | ||||||
|         batch_containers = [] |         batch_containers = [] | ||||||
|         for weights in torch.split(batch_weights, 1): |         for weights in torch.split(batch_weights, 1): | ||||||
|             if get_seq_last: |             batch_containers.append( | ||||||
|                 batch_containers.append( |                 self._shape_container.translate(torch.split(weights.squeeze(0), 1)) | ||||||
|                     self._shape_container.translate(torch.split(weights.squeeze(0), 1)) |             ) | ||||||
|                 ) |  | ||||||
|             else: |  | ||||||
|                 seq_containers = [] |  | ||||||
|                 for ws in torch.split(weights.squeeze(0), 1): |  | ||||||
|                     seq_containers.append( |  | ||||||
|                         self._shape_container.translate(torch.split(ws.squeeze(0), 1)) |  | ||||||
|                     ) |  | ||||||
|                 batch_containers.append(seq_containers) |  | ||||||
|         return time_seq, batch_containers, time_embeds |         return time_seq, batch_containers, time_embeds | ||||||
|  |  | ||||||
|     def forward_candidate(self, input): |     def forward_candidate(self, input): | ||||||
| @@ -241,7 +223,9 @@ class MetaModelV1(super_core.SuperModule): | |||||||
|         with torch.set_grad_enabled(True): |         with torch.set_grad_enabled(True): | ||||||
|             new_param = self.create_meta_embed() |             new_param = self.create_meta_embed() | ||||||
|  |  | ||||||
|             optimizer = torch.optim.Adam([new_param], lr=lr, weight_decay=1e-5, amsgrad=True) |             optimizer = torch.optim.Adam( | ||||||
|  |                 [new_param], lr=lr, weight_decay=1e-5, amsgrad=True | ||||||
|  |             ) | ||||||
|             timestamp = torch.Tensor([timestamp]).to(new_param.device) |             timestamp = torch.Tensor([timestamp]).to(new_param.device) | ||||||
|             self.replace_append_learnt(timestamp, new_param) |             self.replace_append_learnt(timestamp, new_param) | ||||||
|             self.train() |             self.train() | ||||||
| @@ -255,10 +239,10 @@ class MetaModelV1(super_core.SuperModule): | |||||||
|                 best_new_param = new_param.detach().clone() |                 best_new_param = new_param.detach().clone() | ||||||
|             for iepoch in range(epochs): |             for iepoch in range(epochs): | ||||||
|                 optimizer.zero_grad() |                 optimizer.zero_grad() | ||||||
|                 _, [_], time_embed = self(timestamp.view(1, 1), None, True) |                 _, [_], time_embed = self(timestamp.view(1, 1), None) | ||||||
|                 match_loss = criterion(new_param, time_embed) |                 match_loss = criterion(new_param, time_embed) | ||||||
|  |  | ||||||
|                 _, [container], time_embed = self(None, new_param.view(1, 1, -1), True) |                 _, [container], time_embed = self(None, new_param.view(1, 1, -1)) | ||||||
|                 y_hat = base_model.forward_with_container(x, container) |                 y_hat = base_model.forward_with_container(x, container) | ||||||
|                 meta_loss = criterion(y_hat, y) |                 meta_loss = criterion(y_hat, y) | ||||||
|                 loss = meta_loss + match_loss |                 loss = meta_loss + match_loss | ||||||
|   | |||||||
| @@ -1,51 +1,49 @@ | |||||||
| ##################################################### |  | ||||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.05 # |  | ||||||
| ##################################################### |  | ||||||
| import math | import math | ||||||
| from .synthetic_utils import TimeStamp | from .synthetic_utils import TimeStamp | ||||||
| from .synthetic_env import SyntheticDEnv | from .synthetic_env import SyntheticDEnv | ||||||
| from .math_core import LinearFunc | from .math_core import LinearFunc | ||||||
| from .math_core import DynamicLinearFunc | from .math_core import DynamicLinearFunc | ||||||
| from .math_core import DynamicQuadraticFunc | from .math_core import DynamicQuadraticFunc | ||||||
| from .math_core import ConstantFunc, ComposedSinFunc | from .math_core import ConstantFunc, ComposedSinFunc as SinFunc | ||||||
| from .math_core import GaussianDGenerator | from .math_core import GaussianDGenerator | ||||||
|  |  | ||||||
|  |  | ||||||
| __all__ = ["TimeStamp", "SyntheticDEnv", "get_synthetic_env"] | __all__ = ["TimeStamp", "SyntheticDEnv", "get_synthetic_env"] | ||||||
|  |  | ||||||
|  |  | ||||||
| def get_synthetic_env(total_timestamp=1000, num_per_task=1000, mode=None, version="v1"): | def get_synthetic_env(total_timestamp=1600, num_per_task=1000, mode=None, version="v1"): | ||||||
|     if version == "v0": |     max_time = math.pi * 10 | ||||||
|  |     if version == "v1": | ||||||
|         mean_generator = ConstantFunc(0) |         mean_generator = ConstantFunc(0) | ||||||
|         std_generator = ConstantFunc(1) |         std_generator = ConstantFunc(1) | ||||||
|         data_generator = GaussianDGenerator( |         data_generator = GaussianDGenerator( | ||||||
|             [mean_generator], [[std_generator]], (-2, 2) |             [mean_generator], [[std_generator]], (-2, 2) | ||||||
|         ) |         ) | ||||||
|         time_generator = TimeStamp( |         time_generator = TimeStamp( | ||||||
|             min_timestamp=0, max_timestamp=math.pi * 8, num=total_timestamp, mode=mode |             min_timestamp=0, max_timestamp=max_time, num=total_timestamp, mode=mode | ||||||
|         ) |         ) | ||||||
|         oracle_map = DynamicLinearFunc( |         oracle_map = DynamicLinearFunc( | ||||||
|             params={ |             params={ | ||||||
|                 0: ComposedSinFunc(params={0: 2.0, 1: 1.0, 2: 2.2}), |                 0: SinFunc(params={0: 2.0, 1: 1.0, 2: 2.2}),  # 2 sin(t) + 2.2 | ||||||
|                 1: ConstantFunc(0), |                 1: SinFunc(params={0: 1.5, 1: 0.6, 2: 1.8}),  # 1.5 sin(0.6t) + 1.8 | ||||||
|             } |             } | ||||||
|         ) |         ) | ||||||
|         dynamic_env = SyntheticDEnv( |         dynamic_env = SyntheticDEnv( | ||||||
|             data_generator, oracle_map, time_generator, num_per_task |             data_generator, oracle_map, time_generator, num_per_task | ||||||
|         ) |         ) | ||||||
|     elif version == "v1": |     elif version == "v2": | ||||||
|         mean_generator = ConstantFunc(0) |         mean_generator = ConstantFunc(0) | ||||||
|         std_generator = ConstantFunc(1) |         std_generator = ConstantFunc(1) | ||||||
|         data_generator = GaussianDGenerator( |         data_generator = GaussianDGenerator( | ||||||
|             [mean_generator], [[std_generator]], (-2, 2) |             [mean_generator], [[std_generator]], (-2, 2) | ||||||
|         ) |         ) | ||||||
|         time_generator = TimeStamp( |         time_generator = TimeStamp( | ||||||
|             min_timestamp=0, max_timestamp=math.pi * 8, num=total_timestamp, mode=mode |             min_timestamp=0, max_timestamp=max_time, num=total_timestamp, mode=mode | ||||||
|         ) |         ) | ||||||
|         oracle_map = DynamicLinearFunc( |         oracle_map = DynamicQuadraticFunc( | ||||||
|             params={ |             params={ | ||||||
|                 0: ComposedSinFunc(params={0: 2.0, 1: 1.0, 2: 2.2}), |                 0: LinearFunc(params={0: 0.1, 1: 0}),  # 0.1 * t | ||||||
|                 1: ComposedSinFunc(params={0: 1.5, 1: 0.6, 2: 1.8}), |                 1: SinFunc(params={0: 1, 1: 1, 2: 0}),  # sin(t) | ||||||
|             } |             } | ||||||
|         ) |         ) | ||||||
|         dynamic_env = SyntheticDEnv( |         dynamic_env = SyntheticDEnv( | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user