This commit is contained in:
D-X-Y 2021-05-26 01:17:38 +00:00
parent 6c1fd745d7
commit f8350d00ed
3 changed files with 62 additions and 72 deletions

View File

@@ -6,10 +6,11 @@
# python exps/GMOA/lfna.py --env_version v1 --device cuda --lr 0.002 --seq_length 16 --meta_batch 128 # python exps/GMOA/lfna.py --env_version v1 --device cuda --lr 0.002 --seq_length 16 --meta_batch 128
# python exps/GMOA/lfna.py --env_version v1 --device cuda --lr 0.002 --seq_length 24 --time_dim 32 --meta_batch 128 # python exps/GMOA/lfna.py --env_version v1 --device cuda --lr 0.002 --seq_length 24 --time_dim 32 --meta_batch 128
##################################################### #####################################################
import pdb, sys, time, copy, torch, random, argparse import sys, time, copy, torch, random, argparse
from tqdm import tqdm from tqdm import tqdm
from copy import deepcopy from copy import deepcopy
from pathlib import Path from pathlib import Path
from torch.nn import functional as F
lib_dir = (Path(__file__).parent / ".." / "..").resolve() lib_dir = (Path(__file__).parent / ".." / "..").resolve()
print("LIB-DIR: {:}".format(lib_dir)) print("LIB-DIR: {:}".format(lib_dir))
@@ -103,7 +104,7 @@ def online_evaluate(env, meta_model, base_model, criterion, args, logger, save=F
meta_model.eval() meta_model.eval()
base_model.eval() base_model.eval()
_, [future_container], time_embeds = meta_model( _, [future_container], time_embeds = meta_model(
future_time.to(args.device).view(1, 1), None, True future_time.to(args.device).view(1, 1), None, False
) )
if save: if save:
w_containers[idx] = future_container.no_grad_clone() w_containers[idx] = future_container.no_grad_clone()
@@ -159,50 +160,57 @@ def pretrain_v2(base_model, meta_model, criterion, xenv, args, logger):
left_time = "Time Left: {:}".format( left_time = "Time Left: {:}".format(
convert_secs2time(per_epoch_time.avg * (args.epochs - iepoch), True) convert_secs2time(per_epoch_time.avg * (args.epochs - iepoch), True)
) )
total_meta_v1_losses, total_meta_v2_losses, total_match_losses = [], [], [] total_future_losses, total_present_losses, total_regu_losses = [], [], []
optimizer.zero_grad() optimizer.zero_grad()
for ibatch in range(args.meta_batch): for ibatch in range(args.meta_batch):
rand_index = random.randint(0, meta_model.meta_length - 1) rand_index = random.randint(0, meta_model.meta_length - 1)
timestamp = meta_model.meta_timestamps[rand_index] timestamp = meta_model.meta_timestamps[rand_index]
meta_embed = meta_model.super_meta_embed[rand_index]
_, [container], time_embed = meta_model( _, [container], time_embed = meta_model(
torch.unsqueeze(timestamp, dim=0), None, True torch.unsqueeze(timestamp, dim=0), None, False
) )
_, (inputs, targets) = xenv(timestamp.item()) _, (inputs, targets) = xenv(timestamp.item())
inputs, targets = inputs.to(device), targets.to(device) inputs, targets = inputs.to(device), targets.to(device)
# generate models one step ahead # generate models one step ahead
predictions = base_model.forward_with_container(inputs, container) predictions = base_model.forward_with_container(inputs, container)
total_meta_v1_losses.append(criterion(predictions, targets)) total_future_losses.append(criterion(predictions, targets))
# the matching loss # randomly sample
match_loss = criterion(torch.squeeze(time_embed, dim=0), meta_embed) rand_index = random.randint(0, meta_model.meta_length - 1)
total_match_losses.append(match_loss) timestamp = meta_model.meta_timestamps[rand_index]
meta_embed = meta_model.super_meta_embed[rand_index]
time_embed = meta_model(torch.unsqueeze(timestamp, dim=0), None, True)
total_regu_losses.append(
F.mse_loss(
torch.squeeze(time_embed, dim=0), meta_embed, reduction="mean"
)
)
# generate models via memory # generate models via memory
_, [container], _ = meta_model(None, meta_embed.view(1, 1, -1), True) _, [container], _ = meta_model(None, meta_embed.view(1, 1, -1), False)
predictions = base_model.forward_with_container(inputs, container) predictions = base_model.forward_with_container(inputs, container)
total_meta_v2_losses.append(criterion(predictions, targets)) total_present_losses.append(criterion(predictions, targets))
with torch.no_grad(): with torch.no_grad():
meta_std = torch.stack(total_meta_v1_losses).std().item() meta_std = torch.stack(total_future_losses).std().item()
meta_v1_loss = torch.stack(total_meta_v1_losses).mean() loss_future = torch.stack(total_future_losses).mean()
meta_v2_loss = torch.stack(total_meta_v2_losses).mean() loss_present = torch.stack(total_present_losses).mean()
match_loss = torch.stack(total_match_losses).mean() regularization_loss = torch.stack(total_regu_losses).mean()
total_loss = meta_v1_loss + meta_v2_loss + match_loss total_loss = loss_future + loss_present + regularization_loss
total_loss.backward() total_loss.backward()
optimizer.step() optimizer.step()
# success # success
success, best_score = meta_model.save_best(-total_loss.item()) success, best_score = meta_model.save_best(-total_loss.item())
logger.log( logger.log(
"{:} [Pre-V2 {:04d}/{:}] loss : {:.4f} +- {:.4f} = {:.4f} + {:.4f} + {:.4f} (match)".format( "{:} [Pre-V2 {:04d}/{:}] loss : {:.4f} +- {:.4f} = {:.4f} + {:.4f} + {:.4f}".format(
time_string(), time_string(),
iepoch, iepoch,
args.epochs, args.epochs,
total_loss.item(), total_loss.item(),
meta_std, meta_std,
meta_v1_loss.item(), loss_future.item(),
meta_v2_loss.item(), loss_present.item(),
match_loss.item(), regularization_loss.item(),
) )
+ ", batch={:}".format(len(total_meta_v1_losses)) + ", batch={:}".format(len(total_future_losses))
+ ", success={:}, best={:.4f}".format(success, -best_score) + ", success={:}, best={:.4f}".format(success, -best_score)
+ ", LS={:}/{:}".format(iepoch - last_success_epoch, early_stop_thresh) + ", LS={:}/{:}".format(iepoch - last_success_epoch, early_stop_thresh)
+ ", {:}".format(left_time) + ", {:}".format(left_time)

View File

@@ -34,7 +34,7 @@ class MetaModelV1(super_core.SuperModule):
assert interval is not None assert interval is not None
self._interval = interval self._interval = interval
self._seq_length = seq_length self._seq_length = seq_length
self._thresh = interval * 30 if thresh is None else thresh self._thresh = interval * 50 if thresh is None else thresh
self.register_parameter( self.register_parameter(
"_super_layer_embed", "_super_layer_embed",
@@ -183,7 +183,7 @@ class MetaModelV1(super_core.SuperModule):
) )
return timestamp_embeds return timestamp_embeds
def forward_raw(self, timestamps, time_embeds, get_seq_last): def forward_raw(self, timestamps, time_embeds, tembed_only=False):
if time_embeds is None: if time_embeds is None:
time_seq = timestamps.view(-1, 1) + self._time_interval.view(1, -1) time_seq = timestamps.view(-1, 1) + self._time_interval.view(1, -1)
B, S = time_seq.shape B, S = time_seq.shape
@@ -193,41 +193,23 @@ class MetaModelV1(super_core.SuperModule):
B, S, _ = time_embeds.shape B, S, _ = time_embeds.shape
# create joint embed # create joint embed
num_layer, _ = self._super_layer_embed.shape num_layer, _ = self._super_layer_embed.shape
if get_seq_last: time_embeds = time_embeds[:, -1, :]
time_embeds = time_embeds[:, -1, :] if tembed_only:
# The shape of `joint_embed` is batch * num-layers * input-dim return time_embeds
joint_embeds = torch.cat( # The shape of `joint_embed` is batch * num-layers * input-dim
( joint_embeds = torch.cat(
time_embeds.view(B, 1, -1).expand(-1, num_layer, -1), (
self._super_layer_embed.view(1, num_layer, -1).expand(B, -1, -1), time_embeds.view(B, 1, -1).expand(-1, num_layer, -1),
), self._super_layer_embed.view(1, num_layer, -1).expand(B, -1, -1),
dim=-1, ),
) dim=-1,
else: )
# The shape of `joint_embed` is batch * seq * num-layers * input-dim
joint_embeds = torch.cat(
(
time_embeds.view(B, S, 1, -1).expand(-1, -1, num_layer, -1),
self._super_layer_embed.view(1, 1, num_layer, -1).expand(
B, S, -1, -1
),
),
dim=-1,
)
batch_weights = self._generator(joint_embeds) batch_weights = self._generator(joint_embeds)
batch_containers = [] batch_containers = []
for weights in torch.split(batch_weights, 1): for weights in torch.split(batch_weights, 1):
if get_seq_last: batch_containers.append(
batch_containers.append( self._shape_container.translate(torch.split(weights.squeeze(0), 1))
self._shape_container.translate(torch.split(weights.squeeze(0), 1)) )
)
else:
seq_containers = []
for ws in torch.split(weights.squeeze(0), 1):
seq_containers.append(
self._shape_container.translate(torch.split(ws.squeeze(0), 1))
)
batch_containers.append(seq_containers)
return time_seq, batch_containers, time_embeds return time_seq, batch_containers, time_embeds
def forward_candidate(self, input): def forward_candidate(self, input):
@@ -241,7 +223,9 @@ class MetaModelV1(super_core.SuperModule):
with torch.set_grad_enabled(True): with torch.set_grad_enabled(True):
new_param = self.create_meta_embed() new_param = self.create_meta_embed()
optimizer = torch.optim.Adam([new_param], lr=lr, weight_decay=1e-5, amsgrad=True) optimizer = torch.optim.Adam(
[new_param], lr=lr, weight_decay=1e-5, amsgrad=True
)
timestamp = torch.Tensor([timestamp]).to(new_param.device) timestamp = torch.Tensor([timestamp]).to(new_param.device)
self.replace_append_learnt(timestamp, new_param) self.replace_append_learnt(timestamp, new_param)
self.train() self.train()
@@ -255,10 +239,10 @@ class MetaModelV1(super_core.SuperModule):
best_new_param = new_param.detach().clone() best_new_param = new_param.detach().clone()
for iepoch in range(epochs): for iepoch in range(epochs):
optimizer.zero_grad() optimizer.zero_grad()
_, [_], time_embed = self(timestamp.view(1, 1), None, True) _, [_], time_embed = self(timestamp.view(1, 1), None)
match_loss = criterion(new_param, time_embed) match_loss = criterion(new_param, time_embed)
_, [container], time_embed = self(None, new_param.view(1, 1, -1), True) _, [container], time_embed = self(None, new_param.view(1, 1, -1))
y_hat = base_model.forward_with_container(x, container) y_hat = base_model.forward_with_container(x, container)
meta_loss = criterion(y_hat, y) meta_loss = criterion(y_hat, y)
loss = meta_loss + match_loss loss = meta_loss + match_loss

View File

@@ -1,51 +1,49 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.05 #
#####################################################
import math import math
from .synthetic_utils import TimeStamp from .synthetic_utils import TimeStamp
from .synthetic_env import SyntheticDEnv from .synthetic_env import SyntheticDEnv
from .math_core import LinearFunc from .math_core import LinearFunc
from .math_core import DynamicLinearFunc from .math_core import DynamicLinearFunc
from .math_core import DynamicQuadraticFunc from .math_core import DynamicQuadraticFunc
from .math_core import ConstantFunc, ComposedSinFunc from .math_core import ConstantFunc, ComposedSinFunc as SinFunc
from .math_core import GaussianDGenerator from .math_core import GaussianDGenerator
__all__ = ["TimeStamp", "SyntheticDEnv", "get_synthetic_env"] __all__ = ["TimeStamp", "SyntheticDEnv", "get_synthetic_env"]
def get_synthetic_env(total_timestamp=1000, num_per_task=1000, mode=None, version="v1"): def get_synthetic_env(total_timestamp=1600, num_per_task=1000, mode=None, version="v1"):
if version == "v0": max_time = math.pi * 10
if version == "v1":
mean_generator = ConstantFunc(0) mean_generator = ConstantFunc(0)
std_generator = ConstantFunc(1) std_generator = ConstantFunc(1)
data_generator = GaussianDGenerator( data_generator = GaussianDGenerator(
[mean_generator], [[std_generator]], (-2, 2) [mean_generator], [[std_generator]], (-2, 2)
) )
time_generator = TimeStamp( time_generator = TimeStamp(
min_timestamp=0, max_timestamp=math.pi * 8, num=total_timestamp, mode=mode min_timestamp=0, max_timestamp=max_time, num=total_timestamp, mode=mode
) )
oracle_map = DynamicLinearFunc( oracle_map = DynamicLinearFunc(
params={ params={
0: ComposedSinFunc(params={0: 2.0, 1: 1.0, 2: 2.2}), 0: SinFunc(params={0: 2.0, 1: 1.0, 2: 2.2}), # 2 sin(t) + 2.2
1: ConstantFunc(0), 1: SinFunc(params={0: 1.5, 1: 0.6, 2: 1.8}), # 1.5 sin(0.6t) + 1.8
} }
) )
dynamic_env = SyntheticDEnv( dynamic_env = SyntheticDEnv(
data_generator, oracle_map, time_generator, num_per_task data_generator, oracle_map, time_generator, num_per_task
) )
elif version == "v1": elif version == "v2":
mean_generator = ConstantFunc(0) mean_generator = ConstantFunc(0)
std_generator = ConstantFunc(1) std_generator = ConstantFunc(1)
data_generator = GaussianDGenerator( data_generator = GaussianDGenerator(
[mean_generator], [[std_generator]], (-2, 2) [mean_generator], [[std_generator]], (-2, 2)
) )
time_generator = TimeStamp( time_generator = TimeStamp(
min_timestamp=0, max_timestamp=math.pi * 8, num=total_timestamp, mode=mode min_timestamp=0, max_timestamp=max_time, num=total_timestamp, mode=mode
) )
oracle_map = DynamicLinearFunc( oracle_map = DynamicQuadraticFunc(
params={ params={
0: ComposedSinFunc(params={0: 2.0, 1: 1.0, 2: 2.2}), 0: LinearFunc(params={0: 0.1, 1: 0}), # 0.1 * t
1: ComposedSinFunc(params={0: 1.5, 1: 0.6, 2: 1.8}), 1: SinFunc(params={0: 1, 1: 1, 2: 0}), # sin(t)
} }
) )
dynamic_env = SyntheticDEnv( dynamic_env = SyntheticDEnv(