D-X-Y 2021-05-26 01:17:38 +00:00
parent 6c1fd745d7
commit f8350d00ed
3 changed files with 62 additions and 72 deletions

View File

@@ -6,10 +6,11 @@
 # python exps/GMOA/lfna.py --env_version v1 --device cuda --lr 0.002 --seq_length 16 --meta_batch 128
 # python exps/GMOA/lfna.py --env_version v1 --device cuda --lr 0.002 --seq_length 24 --time_dim 32 --meta_batch 128
 #####################################################
-import pdb, sys, time, copy, torch, random, argparse
+import sys, time, copy, torch, random, argparse
 from tqdm import tqdm
 from copy import deepcopy
 from pathlib import Path
+from torch.nn import functional as F
 lib_dir = (Path(__file__).parent / ".." / "..").resolve()
 print("LIB-DIR: {:}".format(lib_dir))
@@ -103,7 +104,7 @@ def online_evaluate(env, meta_model, base_model, criterion, args, logger, save=F
         meta_model.eval()
         base_model.eval()
         _, [future_container], time_embeds = meta_model(
-            future_time.to(args.device).view(1, 1), None, True
+            future_time.to(args.device).view(1, 1), None, False
         )
         if save:
             w_containers[idx] = future_container.no_grad_clone()
@@ -159,50 +160,57 @@ def pretrain_v2(base_model, meta_model, criterion, xenv, args, logger):
         left_time = "Time Left: {:}".format(
             convert_secs2time(per_epoch_time.avg * (args.epochs - iepoch), True)
         )
-        total_meta_v1_losses, total_meta_v2_losses, total_match_losses = [], [], []
+        total_future_losses, total_present_losses, total_regu_losses = [], [], []
         optimizer.zero_grad()
         for ibatch in range(args.meta_batch):
             rand_index = random.randint(0, meta_model.meta_length - 1)
             timestamp = meta_model.meta_timestamps[rand_index]
             meta_embed = meta_model.super_meta_embed[rand_index]
             _, [container], time_embed = meta_model(
-                torch.unsqueeze(timestamp, dim=0), None, True
+                torch.unsqueeze(timestamp, dim=0), None, False
             )
             _, (inputs, targets) = xenv(timestamp.item())
             inputs, targets = inputs.to(device), targets.to(device)
             # generate models one step ahead
             predictions = base_model.forward_with_container(inputs, container)
-            total_meta_v1_losses.append(criterion(predictions, targets))
-            # the matching loss
-            match_loss = criterion(torch.squeeze(time_embed, dim=0), meta_embed)
-            total_match_losses.append(match_loss)
+            total_future_losses.append(criterion(predictions, targets))
+            # randomly sample
+            rand_index = random.randint(0, meta_model.meta_length - 1)
+            timestamp = meta_model.meta_timestamps[rand_index]
+            meta_embed = meta_model.super_meta_embed[rand_index]
+            time_embed = meta_model(torch.unsqueeze(timestamp, dim=0), None, True)
+            total_regu_losses.append(
+                F.mse_loss(
+                    torch.squeeze(time_embed, dim=0), meta_embed, reduction="mean"
+                )
+            )
             # generate models via memory
-            _, [container], _ = meta_model(None, meta_embed.view(1, 1, -1), True)
+            _, [container], _ = meta_model(None, meta_embed.view(1, 1, -1), False)
             predictions = base_model.forward_with_container(inputs, container)
-            total_meta_v2_losses.append(criterion(predictions, targets))
+            total_present_losses.append(criterion(predictions, targets))
         with torch.no_grad():
-            meta_std = torch.stack(total_meta_v1_losses).std().item()
-        meta_v1_loss = torch.stack(total_meta_v1_losses).mean()
-        meta_v2_loss = torch.stack(total_meta_v2_losses).mean()
-        match_loss = torch.stack(total_match_losses).mean()
-        total_loss = meta_v1_loss + meta_v2_loss + match_loss
+            meta_std = torch.stack(total_future_losses).std().item()
+        loss_future = torch.stack(total_future_losses).mean()
+        loss_present = torch.stack(total_present_losses).mean()
+        regularization_loss = torch.stack(total_regu_losses).mean()
+        total_loss = loss_future + loss_present + regularization_loss
         total_loss.backward()
         optimizer.step()
         # success
         success, best_score = meta_model.save_best(-total_loss.item())
         logger.log(
-            "{:} [Pre-V2 {:04d}/{:}] loss : {:.4f} +- {:.4f} = {:.4f} + {:.4f} + {:.4f} (match)".format(
+            "{:} [Pre-V2 {:04d}/{:}] loss : {:.4f} +- {:.4f} = {:.4f} + {:.4f} + {:.4f}".format(
                 time_string(),
                 iepoch,
                 args.epochs,
                 total_loss.item(),
                 meta_std,
-                meta_v1_loss.item(),
-                meta_v2_loss.item(),
-                match_loss.item(),
+                loss_future.item(),
+                loss_present.item(),
+                regularization_loss.item(),
             )
-            + ", batch={:}".format(len(total_meta_v1_losses))
+            + ", batch={:}".format(len(total_future_losses))
             + ", success={:}, best={:.4f}".format(success, -best_score)
             + ", LS={:}/{:}".format(iepoch - last_success_epoch, early_stop_thresh)
             + ", {:}".format(left_time)

View File

@@ -34,7 +34,7 @@ class MetaModelV1(super_core.SuperModule):
         assert interval is not None
         self._interval = interval
         self._seq_length = seq_length
-        self._thresh = interval * 30 if thresh is None else thresh
+        self._thresh = interval * 50 if thresh is None else thresh
         self.register_parameter(
             "_super_layer_embed",
@@ -183,7 +183,7 @@ class MetaModelV1(super_core.SuperModule):
         )
         return timestamp_embeds

-    def forward_raw(self, timestamps, time_embeds, get_seq_last):
+    def forward_raw(self, timestamps, time_embeds, tembed_only=False):
         if time_embeds is None:
             time_seq = timestamps.view(-1, 1) + self._time_interval.view(1, -1)
             B, S = time_seq.shape
@@ -193,41 +193,23 @@
             B, S, _ = time_embeds.shape
         # create joint embed
         num_layer, _ = self._super_layer_embed.shape
-        if get_seq_last:
-            time_embeds = time_embeds[:, -1, :]
-            # The shape of `joint_embed` is batch * num-layers * input-dim
-            joint_embeds = torch.cat(
-                (
-                    time_embeds.view(B, 1, -1).expand(-1, num_layer, -1),
-                    self._super_layer_embed.view(1, num_layer, -1).expand(B, -1, -1),
-                ),
-                dim=-1,
-            )
-        else:
-            # The shape of `joint_embed` is batch * seq * num-layers * input-dim
-            joint_embeds = torch.cat(
-                (
-                    time_embeds.view(B, S, 1, -1).expand(-1, -1, num_layer, -1),
-                    self._super_layer_embed.view(1, 1, num_layer, -1).expand(
-                        B, S, -1, -1
-                    ),
-                ),
-                dim=-1,
-            )
+        time_embeds = time_embeds[:, -1, :]
+        if tembed_only:
+            return time_embeds
+        # The shape of `joint_embed` is batch * num-layers * input-dim
+        joint_embeds = torch.cat(
+            (
+                time_embeds.view(B, 1, -1).expand(-1, num_layer, -1),
+                self._super_layer_embed.view(1, num_layer, -1).expand(B, -1, -1),
+            ),
+            dim=-1,
+        )
         batch_weights = self._generator(joint_embeds)
         batch_containers = []
         for weights in torch.split(batch_weights, 1):
-            if get_seq_last:
-                batch_containers.append(
-                    self._shape_container.translate(torch.split(weights.squeeze(0), 1))
-                )
-            else:
-                seq_containers = []
-                for ws in torch.split(weights.squeeze(0), 1):
-                    seq_containers.append(
-                        self._shape_container.translate(torch.split(ws.squeeze(0), 1))
-                    )
-                batch_containers.append(seq_containers)
+            batch_containers.append(
+                self._shape_container.translate(torch.split(weights.squeeze(0), 1))
+            )
         return time_seq, batch_containers, time_embeds

     def forward_candidate(self, input):
@@ -241,7 +223,9 @@ class MetaModelV1(super_core.SuperModule):
         with torch.set_grad_enabled(True):
             new_param = self.create_meta_embed()
-            optimizer = torch.optim.Adam([new_param], lr=lr, weight_decay=1e-5, amsgrad=True)
+            optimizer = torch.optim.Adam(
+                [new_param], lr=lr, weight_decay=1e-5, amsgrad=True
+            )
             timestamp = torch.Tensor([timestamp]).to(new_param.device)
             self.replace_append_learnt(timestamp, new_param)
             self.train()
@@ -255,10 +239,10 @@ class MetaModelV1(super_core.SuperModule):
             best_new_param = new_param.detach().clone()
             for iepoch in range(epochs):
                 optimizer.zero_grad()
-                _, [_], time_embed = self(timestamp.view(1, 1), None, True)
+                _, [_], time_embed = self(timestamp.view(1, 1), None)
                 match_loss = criterion(new_param, time_embed)
-                _, [container], time_embed = self(None, new_param.view(1, 1, -1), True)
+                _, [container], time_embed = self(None, new_param.view(1, 1, -1))
                 y_hat = base_model.forward_with_container(x, container)
                 meta_loss = criterion(y_hat, y)
                 loss = meta_loss + match_loss
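
Note: replacing get_seq_last with tembed_only removes the per-sequence container branch entirely: the last element of the time-embedding sequence is now always taken, and tembed_only=True returns those embeddings before any weights are generated, which is why call sites that still need containers simply drop the old trailing True and rely on the default. A self-contained toy sketch of that control flow follows; the class below only mimics the structure of forward_raw, not the real weight generator.

import torch

class ForwardRawSketch:
    # Mimics the post-refactor control flow of MetaModelV1.forward_raw.
    def __call__(self, time_embeds, tembed_only=False):
        time_embeds = time_embeds[:, -1, :]  # always keep the last step now
        if tembed_only:
            return time_embeds  # short-circuit: no weight generation
        containers = [object() for _ in range(time_embeds.shape[0])]
        return None, containers, time_embeds

model = ForwardRawSketch()
embeds = torch.randn(2, 5, 8)        # batch=2, seq=5, dim=8
print(model(embeds, True).shape)     # torch.Size([2, 8])
_, containers, last = model(embeds)  # default False returns the full triple
print(len(containers), last.shape)   # 2 torch.Size([2, 8])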

View File

@@ -1,51 +1,49 @@
 #####################################################
 # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.05 #
 #####################################################
 import math
 from .synthetic_utils import TimeStamp
 from .synthetic_env import SyntheticDEnv
 from .math_core import LinearFunc
 from .math_core import DynamicLinearFunc
 from .math_core import DynamicQuadraticFunc
-from .math_core import ConstantFunc, ComposedSinFunc
+from .math_core import ConstantFunc, ComposedSinFunc as SinFunc
 from .math_core import GaussianDGenerator

 __all__ = ["TimeStamp", "SyntheticDEnv", "get_synthetic_env"]

-def get_synthetic_env(total_timestamp=1000, num_per_task=1000, mode=None, version="v1"):
-    if version == "v0":
+def get_synthetic_env(total_timestamp=1600, num_per_task=1000, mode=None, version="v1"):
+    max_time = math.pi * 10
+    if version == "v1":
         mean_generator = ConstantFunc(0)
         std_generator = ConstantFunc(1)
         data_generator = GaussianDGenerator(
             [mean_generator], [[std_generator]], (-2, 2)
         )
         time_generator = TimeStamp(
-            min_timestamp=0, max_timestamp=math.pi * 8, num=total_timestamp, mode=mode
+            min_timestamp=0, max_timestamp=max_time, num=total_timestamp, mode=mode
         )
         oracle_map = DynamicLinearFunc(
             params={
-                0: ComposedSinFunc(params={0: 2.0, 1: 1.0, 2: 2.2}),
-                1: ConstantFunc(0),
+                0: SinFunc(params={0: 2.0, 1: 1.0, 2: 2.2}),  # 2 sin(t) + 2.2
+                1: SinFunc(params={0: 1.5, 1: 0.6, 2: 1.8}),  # 1.5 sin(0.6t) + 1.8
             }
         )
         dynamic_env = SyntheticDEnv(
             data_generator, oracle_map, time_generator, num_per_task
         )
-    elif version == "v1":
+    elif version == "v2":
         mean_generator = ConstantFunc(0)
         std_generator = ConstantFunc(1)
         data_generator = GaussianDGenerator(
             [mean_generator], [[std_generator]], (-2, 2)
         )
         time_generator = TimeStamp(
-            min_timestamp=0, max_timestamp=math.pi * 8, num=total_timestamp, mode=mode
+            min_timestamp=0, max_timestamp=max_time, num=total_timestamp, mode=mode
         )
-        oracle_map = DynamicLinearFunc(
+        oracle_map = DynamicQuadraticFunc(
             params={
-                0: ComposedSinFunc(params={0: 2.0, 1: 1.0, 2: 2.2}),
-                1: ComposedSinFunc(params={0: 1.5, 1: 0.6, 2: 1.8}),
+                0: LinearFunc(params={0: 0.1, 1: 0}),  # 0.1 * t
+                1: SinFunc(params={0: 1, 1: 1, 2: 0}),  # sin(t)
            }
        )
        dynamic_env = SyntheticDEnv(
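
Note: the inline comments in this hunk pin down the parameter convention of the math_core functions: for SinFunc, index 0 is the amplitude, 1 the frequency, and 2 the offset, i.e. a * sin(b * t) + c; for LinearFunc, index 0 appears to be the slope and 1 the intercept, i.e. a * t + b. The self-contained sketch below only illustrates that convention; the classes mirror, but are not, the .math_core implementations.

import math

class SinFuncSketch:
    # a * sin(b * t) + c, following the params convention in the diff
    def __init__(self, params):
        self.a, self.b, self.c = params[0], params[1], params[2]
    def __call__(self, t):
        return self.a * math.sin(self.b * t) + self.c

class LinearFuncSketch:
    # a * t + b
    def __init__(self, params):
        self.a, self.b = params[0], params[1]
    def __call__(self, t):
        return self.a * t + self.b

coef0 = SinFuncSketch({0: 2.0, 1: 1.0, 2: 2.2})  # new v1 coefficient 0
coef1 = SinFuncSketch({0: 1.5, 1: 0.6, 2: 1.8})  # new v1 coefficient 1
print(coef0(0.0), coef1(0.0))  # 2.2 1.8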