From f8350d00edfc3110ab0b772a6988e6d16806c2af Mon Sep 17 00:00:00 2001
From: D-X-Y <280835372@qq.com>
Date: Wed, 26 May 2021 01:17:38 +0000
Subject: [PATCH] Rename the Pre-V2 loss terms, simplify the meta-model
 forward pass, and update the synthetic environments

---
 exps/GMOA/lfna.py                  | 50 +++++++++++++++-----------
 exps/GMOA/lfna_meta_model.py       | 58 +++++++++++-------------------
 xautodl/datasets/synthetic_core.py | 26 +++++++-------
 3 files changed, 62 insertions(+), 72 deletions(-)

diff --git a/exps/GMOA/lfna.py b/exps/GMOA/lfna.py
index 93ae7d0..004896b 100644
--- a/exps/GMOA/lfna.py
+++ b/exps/GMOA/lfna.py
@@ -6,10 +6,11 @@
 # python exps/GMOA/lfna.py --env_version v1 --device cuda --lr 0.002 --seq_length 16 --meta_batch 128
 # python exps/GMOA/lfna.py --env_version v1 --device cuda --lr 0.002 --seq_length 24 --time_dim 32 --meta_batch 128
 #####################################################
-import pdb, sys, time, copy, torch, random, argparse
+import sys, time, copy, torch, random, argparse
 from tqdm import tqdm
 from copy import deepcopy
 from pathlib import Path
+from torch.nn import functional as F
 
 lib_dir = (Path(__file__).parent / ".." / "..").resolve()
 print("LIB-DIR: {:}".format(lib_dir))
@@ -103,7 +104,7 @@ def online_evaluate(env, meta_model, base_model, criterion, args, logger, save=F
         meta_model.eval()
         base_model.eval()
         _, [future_container], time_embeds = meta_model(
-            future_time.to(args.device).view(1, 1), None, True
+            future_time.to(args.device).view(1, 1), None, False
         )
         if save:
             w_containers[idx] = future_container.no_grad_clone()
@@ -159,50 +160,57 @@ def pretrain_v2(base_model, meta_model, criterion, xenv, args, logger):
         left_time = "Time Left: {:}".format(
             convert_secs2time(per_epoch_time.avg * (args.epochs - iepoch), True)
         )
-        total_meta_v1_losses, total_meta_v2_losses, total_match_losses = [], [], []
+        total_future_losses, total_present_losses, total_regu_losses = [], [], []
         optimizer.zero_grad()
         for ibatch in range(args.meta_batch):
             rand_index = random.randint(0, meta_model.meta_length - 1)
             timestamp = meta_model.meta_timestamps[rand_index]
-            meta_embed = meta_model.super_meta_embed[rand_index]
 
             _, [container], time_embed = meta_model(
-                torch.unsqueeze(timestamp, dim=0), None, True
+                torch.unsqueeze(timestamp, dim=0), None, False
             )
             _, (inputs, targets) = xenv(timestamp.item())
             inputs, targets = inputs.to(device), targets.to(device)
             # generate models one step ahead
             predictions = base_model.forward_with_container(inputs, container)
-            total_meta_v1_losses.append(criterion(predictions, targets))
-            # the matching loss
-            match_loss = criterion(torch.squeeze(time_embed, dim=0), meta_embed)
-            total_match_losses.append(match_loss)
+            total_future_losses.append(criterion(predictions, targets))
+            # randomly sample another timestamp for the embedding regularizer
+            rand_index = random.randint(0, meta_model.meta_length - 1)
+            timestamp = meta_model.meta_timestamps[rand_index]
+            meta_embed = meta_model.super_meta_embed[rand_index]
+
+            time_embed = meta_model(torch.unsqueeze(timestamp, dim=0), None, True)
+            total_regu_losses.append(
+                F.mse_loss(
+                    torch.squeeze(time_embed, dim=0), meta_embed, reduction="mean"
+                )
+            )
             # generate models via memory
-            _, [container], _ = meta_model(None, meta_embed.view(1, 1, -1), True)
+            _, [container], _ = meta_model(None, meta_embed.view(1, 1, -1), False)
             predictions = base_model.forward_with_container(inputs, container)
-            total_meta_v2_losses.append(criterion(predictions, targets))
+            total_present_losses.append(criterion(predictions, targets))
         with torch.no_grad():
-            meta_std = torch.stack(total_meta_v1_losses).std().item()
-            meta_v1_loss = torch.stack(total_meta_v1_losses).mean()
-            meta_v2_loss = torch.stack(total_meta_v2_losses).mean()
-            match_loss = torch.stack(total_match_losses).mean()
-        total_loss = meta_v1_loss + meta_v2_loss + match_loss
+            meta_std = torch.stack(total_future_losses).std().item()
+        loss_future = torch.stack(total_future_losses).mean()
+        loss_present = torch.stack(total_present_losses).mean()
+        regularization_loss = torch.stack(total_regu_losses).mean()
+        total_loss = loss_future + loss_present + regularization_loss
         total_loss.backward()
         optimizer.step()
         # success
         success, best_score = meta_model.save_best(-total_loss.item())
         logger.log(
-            "{:} [Pre-V2 {:04d}/{:}] loss : {:.4f} +- {:.4f} = {:.4f} + {:.4f} + {:.4f} (match)".format(
+            "{:} [Pre-V2 {:04d}/{:}] loss : {:.4f} +- {:.4f} = {:.4f} + {:.4f} + {:.4f}".format(
                 time_string(),
                 iepoch,
                 args.epochs,
                 total_loss.item(),
                 meta_std,
-                meta_v1_loss.item(),
-                meta_v2_loss.item(),
-                match_loss.item(),
+                loss_future.item(),
+                loss_present.item(),
+                regularization_loss.item(),
             )
-            + ", batch={:}".format(len(total_meta_v1_losses))
+            + ", batch={:}".format(len(total_future_losses))
             + ", success={:}, best={:.4f}".format(success, -best_score)
             + ", LS={:}/{:}".format(iepoch - last_success_epoch, early_stop_thresh)
             + ", {:}".format(left_time)
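
Note on the lfna.py hunks above: the Pre-V2 objective is reorganized into three
named terms, a future loss (containers decoded from a sampled timestamp), a
present loss (containers decoded from a stored meta embedding), and an MSE
regularizer that ties the predicted time embedding to that stored embedding;
the regularizer is why lfna.py now imports torch.nn.functional. Below is a
minimal sketch of one such step, assuming the patched
(timestamps, time_embeds, tembed_only) calling convention; pretrain_v2_step is
a hypothetical helper, and meta_model, base_model, criterion, and xenv stand in
for the real objects.

# A sketch, not repo code: one Pre-V2 step under the assumptions above.
import random
import torch
from torch.nn import functional as F

def pretrain_v2_step(meta_model, base_model, criterion, xenv, meta_batch, device):
    future_losses, present_losses, regu_losses = [], [], []
    for _ in range(meta_batch):
        # (1) future loss: decode a weight container from a random timestamp
        index = random.randint(0, meta_model.meta_length - 1)
        timestamp = meta_model.meta_timestamps[index]
        _, [container], _ = meta_model(torch.unsqueeze(timestamp, dim=0), None, False)
        _, (inputs, targets) = xenv(timestamp.item())
        inputs, targets = inputs.to(device), targets.to(device)
        predictions = base_model.forward_with_container(inputs, container)
        future_losses.append(criterion(predictions, targets))
        # (2) regularizer: the predicted time embedding of an independently
        # sampled timestamp should match its stored meta embedding
        index = random.randint(0, meta_model.meta_length - 1)
        timestamp = meta_model.meta_timestamps[index]
        meta_embed = meta_model.super_meta_embed[index]
        time_embed = meta_model(torch.unsqueeze(timestamp, dim=0), None, True)
        regu_losses.append(F.mse_loss(torch.squeeze(time_embed, dim=0), meta_embed))
        # (3) present loss: decode a container directly from the stored meta
        # embedding; as in the patch, this reuses the batch drawn in step (1)
        _, [container], _ = meta_model(None, meta_embed.view(1, 1, -1), False)
        predictions = base_model.forward_with_container(inputs, container)
        present_losses.append(criterion(predictions, targets))
    return (
        torch.stack(future_losses).mean()
        + torch.stack(present_losses).mean()
        + torch.stack(regu_losses).mean()
    )
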
diff --git a/exps/GMOA/lfna_meta_model.py b/exps/GMOA/lfna_meta_model.py
index 10ebff1..dc61b47 100644
--- a/exps/GMOA/lfna_meta_model.py
+++ b/exps/GMOA/lfna_meta_model.py
@@ -34,7 +34,7 @@ class MetaModelV1(super_core.SuperModule):
         assert interval is not None
         self._interval = interval
         self._seq_length = seq_length
-        self._thresh = interval * 30 if thresh is None else thresh
+        self._thresh = interval * 50 if thresh is None else thresh
 
         self.register_parameter(
             "_super_layer_embed",
@@ -183,7 +183,7 @@ class MetaModelV1(super_core.SuperModule):
             )
             return timestamp_embeds
 
-    def forward_raw(self, timestamps, time_embeds, get_seq_last):
+    def forward_raw(self, timestamps, time_embeds, tembed_only=False):
         if time_embeds is None:
             time_seq = timestamps.view(-1, 1) + self._time_interval.view(1, -1)
             B, S = time_seq.shape
@@ -193,41 +193,23 @@ class MetaModelV1(super_core.SuperModule):
             B, S, _ = time_embeds.shape
         # create joint embed
         num_layer, _ = self._super_layer_embed.shape
-        if get_seq_last:
-            time_embeds = time_embeds[:, -1, :]
-            # The shape of `joint_embed` is batch * num-layers * input-dim
-            joint_embeds = torch.cat(
-                (
-                    time_embeds.view(B, 1, -1).expand(-1, num_layer, -1),
-                    self._super_layer_embed.view(1, num_layer, -1).expand(B, -1, -1),
-                ),
-                dim=-1,
-            )
-        else:
-            # The shape of `joint_embed` is batch * seq * num-layers * input-dim
-            joint_embeds = torch.cat(
-                (
-                    time_embeds.view(B, S, 1, -1).expand(-1, -1, num_layer, -1),
-                    self._super_layer_embed.view(1, 1, num_layer, -1).expand(
-                        B, S, -1, -1
-                    ),
-                ),
-                dim=-1,
-            )
+        time_embeds = time_embeds[:, -1, :]
+        if tembed_only:
+            return time_embeds
+        # The shape of `joint_embed` is batch * num-layers * input-dim
+        joint_embeds = torch.cat(
+            (
+                time_embeds.view(B, 1, -1).expand(-1, num_layer, -1),
+                self._super_layer_embed.view(1, num_layer, -1).expand(B, -1, -1),
+            ),
+            dim=-1,
+        )
         batch_weights = self._generator(joint_embeds)
         batch_containers = []
         for weights in torch.split(batch_weights, 1):
-            if get_seq_last:
-                batch_containers.append(
-                    self._shape_container.translate(torch.split(weights.squeeze(0), 1))
-                )
-            else:
-                seq_containers = []
-                for ws in torch.split(weights.squeeze(0), 1):
-                    seq_containers.append(
-                        self._shape_container.translate(torch.split(ws.squeeze(0), 1))
-                    )
-                batch_containers.append(seq_containers)
+            batch_containers.append(
+                self._shape_container.translate(torch.split(weights.squeeze(0), 1))
+            )
         return time_seq, batch_containers, time_embeds
 
     def forward_candidate(self, input):
@@ -241,7 +223,9 @@ class MetaModelV1(super_core.SuperModule):
 
         with torch.set_grad_enabled(True):
             new_param = self.create_meta_embed()
-            optimizer = torch.optim.Adam([new_param], lr=lr, weight_decay=1e-5, amsgrad=True)
+            optimizer = torch.optim.Adam(
+                [new_param], lr=lr, weight_decay=1e-5, amsgrad=True
+            )
             timestamp = torch.Tensor([timestamp]).to(new_param.device)
             self.replace_append_learnt(timestamp, new_param)
             self.train()
@@ -255,10 +239,10 @@ class MetaModelV1(super_core.SuperModule):
             best_new_param = new_param.detach().clone()
             for iepoch in range(epochs):
                 optimizer.zero_grad()
-                _, [_], time_embed = self(timestamp.view(1, 1), None, True)
+                _, [_], time_embed = self(timestamp.view(1, 1), None)
                 match_loss = criterion(new_param, time_embed)
 
-                _, [container], time_embed = self(None, new_param.view(1, 1, -1), True)
+                _, [container], time_embed = self(None, new_param.view(1, 1, -1))
                 y_hat = base_model.forward_with_container(x, container)
                 meta_loss = criterion(y_hat, y)
                 loss = meta_loss + match_loss
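
Note on the lfna_meta_model.py hunks above: with get_seq_last gone, forward_raw
has a single path. It keeps only the last time embedding of the sequence,
returns it immediately when tembed_only=True, and otherwise pairs it with every
per-layer embedding before self._generator emits the flat weights. A
shape-level sketch with hypothetical sizes (B=4 queries, S=8 steps, 3 layers,
16-d embeddings), where nn.Linear stands in for the real generator:

# A sketch, not repo code: the tensor shapes behind the simplified forward_raw.
import torch
import torch.nn as nn

B, S, num_layer, embed_dim = 4, 8, 3, 16
time_embeds = torch.randn(B, S, embed_dim)             # per-step time embeddings
super_layer_embed = torch.randn(num_layer, embed_dim)  # one embedding per layer

last_embeds = time_embeds[:, -1, :]  # (B, embed_dim); the per-step branch is gone
# tembed_only=True would return `last_embeds` right here

joint_embeds = torch.cat(
    (
        last_embeds.view(B, 1, -1).expand(-1, num_layer, -1),
        super_layer_embed.view(1, num_layer, -1).expand(B, -1, -1),
    ),
    dim=-1,
)                                         # (B, num_layer, 2 * embed_dim)
generator = nn.Linear(2 * embed_dim, 32)  # stand-in for self._generator
batch_weights = generator(joint_embeds)   # (B, num_layer, 32), one row per layer
print(joint_embeds.shape, batch_weights.shape)

Callers therefore pass True only when they need the bare embedding (as in the
regularizer above) and False, or nothing at all, when they need weight
containers, which is what the call-site changes in lfna.py and in adapt()
implement.
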
diff --git a/xautodl/datasets/synthetic_core.py b/xautodl/datasets/synthetic_core.py
index d656597..1c7cac0 100644
--- a/xautodl/datasets/synthetic_core.py
+++ b/xautodl/datasets/synthetic_core.py
@@ -1,51 +1,49 @@
-#####################################################
-# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.05 #
-#####################################################
 import math
 
 from .synthetic_utils import TimeStamp
 from .synthetic_env import SyntheticDEnv
 from .math_core import LinearFunc
 from .math_core import DynamicLinearFunc
 from .math_core import DynamicQuadraticFunc
-from .math_core import ConstantFunc, ComposedSinFunc
+from .math_core import ConstantFunc, ComposedSinFunc as SinFunc
 from .math_core import GaussianDGenerator
 
 __all__ = ["TimeStamp", "SyntheticDEnv", "get_synthetic_env"]
 
 
-def get_synthetic_env(total_timestamp=1000, num_per_task=1000, mode=None, version="v1"):
-    if version == "v0":
+def get_synthetic_env(total_timestamp=1600, num_per_task=1000, mode=None, version="v1"):
+    max_time = math.pi * 10
+    if version == "v1":
         mean_generator = ConstantFunc(0)
         std_generator = ConstantFunc(1)
         data_generator = GaussianDGenerator(
             [mean_generator], [[std_generator]], (-2, 2)
         )
         time_generator = TimeStamp(
-            min_timestamp=0, max_timestamp=math.pi * 8, num=total_timestamp, mode=mode
+            min_timestamp=0, max_timestamp=max_time, num=total_timestamp, mode=mode
         )
         oracle_map = DynamicLinearFunc(
             params={
-                0: ComposedSinFunc(params={0: 2.0, 1: 1.0, 2: 2.2}),
-                1: ConstantFunc(0),
+                0: SinFunc(params={0: 2.0, 1: 1.0, 2: 2.2}),  # 2 sin(t) + 2.2
+                1: SinFunc(params={0: 1.5, 1: 0.6, 2: 1.8}),  # 1.5 sin(0.6t) + 1.8
             }
         )
         dynamic_env = SyntheticDEnv(
             data_generator, oracle_map, time_generator, num_per_task
         )
-    elif version == "v1":
+    elif version == "v2":
         mean_generator = ConstantFunc(0)
         std_generator = ConstantFunc(1)
         data_generator = GaussianDGenerator(
             [mean_generator], [[std_generator]], (-2, 2)
         )
         time_generator = TimeStamp(
-            min_timestamp=0, max_timestamp=math.pi * 8, num=total_timestamp, mode=mode
+            min_timestamp=0, max_timestamp=max_time, num=total_timestamp, mode=mode
        )
-        oracle_map = DynamicLinearFunc(
+        oracle_map = DynamicQuadraticFunc(
             params={
-                0: ComposedSinFunc(params={0: 2.0, 1: 1.0, 2: 2.2}),
-                1: ComposedSinFunc(params={0: 1.5, 1: 0.6, 2: 1.8}),
+                0: LinearFunc(params={0: 0.1, 1: 0}),  # 0.1 * t
+                1: SinFunc(params={0: 1, 1: 1, 2: 0}),  # sin(t)
             }
         )
         dynamic_env = SyntheticDEnv(
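
Note on the synthetic_core.py diff above: the v1 environment is now a linear
map whose slope and intercept drift with t over [0, 10 * pi]. A self-contained
numeric sketch, assuming SinFunc(params={0: a, 1: b, 2: c}) evaluates to
a * sin(b * t) + c (as the inline comments state) and that DynamicLinearFunc
treats params[0] as the slope and params[1] as the intercept; sin_func and
oracle_v1 are illustrative names, not functions from the repo.

# A sketch, not repo code: the assumed v1 oracle map.
import math

def sin_func(a, b, c):
    # hypothetical stand-in for SinFunc(params={0: a, 1: b, 2: c})
    return lambda t: a * math.sin(b * t) + c

slope = sin_func(2.0, 1.0, 2.2)      # 2 sin(t) + 2.2
intercept = sin_func(1.5, 0.6, 1.8)  # 1.5 sin(0.6 t) + 1.8

def oracle_v1(x, t):
    # assumed DynamicLinearFunc semantics: y = params[0](t) * x + params[1](t)
    return slope(t) * x + intercept(t)

max_time = math.pi * 10
for t in (0.0, max_time / 2, max_time):
    print("t={:7.3f} -> y(x=1.0)={:.3f}".format(t, oracle_v1(1.0, t)))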