Update LFNA version 1.0

This commit is contained in:
D-X-Y 2021-05-13 21:33:34 +08:00
parent 3d3a04705f
commit cfabd05de8
11 changed files with 340 additions and 299 deletions

View File

@ -1,7 +1,7 @@
##################################################### #####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 # # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 #
##################################################### #####################################################
# python exps/LFNA/lfna.py --env_version v1 --hidden_dim 16 --layer_dim 32 --epochs 50000 # python exps/LFNA/lfna.py --env_version v1
##################################################### #####################################################
import sys, time, copy, torch, random, argparse import sys, time, copy, torch, random, argparse
from tqdm import tqdm from tqdm import tqdm
@ -19,56 +19,82 @@ from utils import split_str2indexes
from procedures.advanced_main import basic_train_fn, basic_eval_fn from procedures.advanced_main import basic_train_fn, basic_eval_fn
from procedures.metric_utils import SaveMetric, MSEMetric, ComposeMetric from procedures.metric_utils import SaveMetric, MSEMetric, ComposeMetric
from datasets.synthetic_core import get_synthetic_env from datasets.synthetic_core import get_synthetic_env, EnvSampler
from models.xcore import get_model from models.xcore import get_model
from xlayers import super_core, trunc_normal_ from xlayers import super_core, trunc_normal_
from lfna_utils import lfna_setup, train_model, TimeData from lfna_utils import lfna_setup, train_model, TimeData
from lfna_meta_model import LFNA_Meta
from lfna_models_v2 import HyperNet
def epoch_train(loader, meta_model, base_model, optimizer, criterion, device, logger):
    """Run one training pass over ``loader`` and return the loss meter.

    For each batch, ``meta_model`` is called on the timestamps and yields a
    nested collection of containers (one per batch item, one per sequence
    step).  ``base_model.forward_with_container`` is evaluated with every
    container on the matching inputs, ``criterion`` scores each prediction,
    and the mean of all those losses is back-propagated before a single
    ``optimizer.step()``.

    Args:
        loader: iterable yielding ``(timestamps, (inputs, targets))`` batches.
        meta_model: callable mapping timestamps to per-step containers.
        base_model: model exposing ``forward_with_container``.
        optimizer: optimizer stepped once per batch.
        criterion: loss function ``criterion(predictions, targets)``.
        device: device the batch tensors are moved to.
        logger: accepted for interface compatibility; not used here.

    Returns:
        The ``AverageMeter`` updated with one mean loss value per batch.
    """
    base_model.train()
    meta_model.train()
    meter = AverageMeter()
    for timestamps, (batch_inputs, batch_targets) in loader:
        # Drop the trailing singleton axis of the timestamps before use.
        timestamps = timestamps.squeeze(dim=-1).to(device)
        batch_inputs = batch_inputs.to(device)
        batch_targets = batch_targets.to(device)

        optimizer.zero_grad()
        batch_containers = meta_model(timestamps)

        # Outer zip walks the batch items, inner zip walks the sequence steps.
        step_losses = [
            criterion(base_model.forward_with_container(inputs, container), targets)
            for seq_containers, seq_inputs, seq_targets in zip(
                batch_containers, batch_inputs, batch_targets
            )
            for container, inputs, targets in zip(
                seq_containers, seq_inputs, seq_targets
            )
        ]
        mean_loss = torch.stack(step_losses).mean()
        mean_loss.backward()
        optimizer.step()
        meter.update(mean_loss.item())
    return meter
def main(args): def main(args):
logger, env_info, model_kwargs = lfna_setup(args) logger, env_info, model_kwargs = lfna_setup(args)
dynamic_env = env_info["dynamic_env"] dynamic_env = get_synthetic_env(mode="train", version=args.env_version)
model = get_model(**model_kwargs) base_model = get_model(**model_kwargs)
model = model.to(args.device) base_model = base_model.to(args.device)
criterion = torch.nn.MSELoss() criterion = torch.nn.MSELoss()
logger.log("There are {:} weights.".format(model.get_w_container().numel())) shape_container = base_model.get_w_container().to_shape_container()
# meta_train_range = (dynamic_env.min_timestamp, (dynamic_env.min_timestamp + dynamic_env.max_timestamp) / 2)
# meta_train_interval = dynamic_env.timestamp_interval
shape_container = model.get_w_container().to_shape_container()
# pre-train the hypernetwork # pre-train the hypernetwork
timestamps = list( timestamps = dynamic_env.get_timestamp(None)
dynamic_env.get_timestamp(index) for index in range(len(dynamic_env) // 2) meta_model = LFNA_Meta(shape_container, args.layer_dim, args.time_dim, timestamps)
meta_model = meta_model.to(args.device)
logger.log("The base-model has {:} weights.".format(base_model.numel()))
logger.log("The meta-model has {:} weights.".format(meta_model.numel()))
batch_sampler = EnvSampler(dynamic_env, args.meta_batch, args.sampler_enlarge)
dynamic_env.reset_max_seq_length(args.seq_length)
"""
env_loader = torch.utils.data.DataLoader(
dynamic_env,
batch_size=args.meta_batch,
shuffle=True,
num_workers=args.workers,
pin_memory=True,
)
"""
env_loader = torch.utils.data.DataLoader(
dynamic_env,
batch_sampler=batch_sampler,
num_workers=args.workers,
pin_memory=True,
) )
hypernet = HyperNet(shape_container, args.layer_dim, args.task_dim, timestamps) optimizer = torch.optim.Adam(
hypernet = hypernet.to(args.device) meta_model.parameters(), lr=args.init_lr, weight_decay=1e-5, amsgrad=True
)
import pdb
pdb.set_trace()
# task_embed = torch.nn.Parameter(torch.Tensor(env_info["total"], args.task_dim))
total_bar = 16
task_embeds = []
for i in range(total_bar):
tensor = torch.Tensor(1, args.task_dim).to(args.device)
task_embeds.append(torch.nn.Parameter(tensor))
for task_embed in task_embeds:
trunc_normal_(task_embed, std=0.02)
model.train()
hypernet.train()
parameters = list(hypernet.parameters()) + task_embeds
# optimizer = torch.optim.Adam(parameters, lr=args.init_lr, amsgrad=True)
optimizer = torch.optim.Adam(parameters, lr=args.init_lr, weight_decay=1e-5)
lr_scheduler = torch.optim.lr_scheduler.MultiStepLR( lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
optimizer, optimizer,
milestones=[ milestones=[
@ -77,71 +103,59 @@ def main(args):
], ],
gamma=0.1, gamma=0.1,
) )
logger.log("The base-model is\n{:}".format(base_model))
logger.log("The meta-model is\n{:}".format(meta_model))
logger.log("The optimizer is\n{:}".format(optimizer))
logger.log("Per epoch iterations = {:}".format(len(env_loader)))
# total_bar = env_info["total"] - 1
# LFNA meta-training # LFNA meta-training
loss_meter = AverageMeter()
per_epoch_time, start_time = AverageMeter(), time.time() per_epoch_time, start_time = AverageMeter(), time.time()
last_success_epoch = 0
for iepoch in range(args.epochs): for iepoch in range(args.epochs):
need_time = "Time Left: {:}".format( head_str = "[{:}] [{:04d}/{:04d}] ".format(
time_string(), iepoch, args.epochs
) + "Time Left: {:}".format(
convert_secs2time(per_epoch_time.avg * (args.epochs - iepoch), True) convert_secs2time(per_epoch_time.avg * (args.epochs - iepoch), True)
) )
head_str = (
"[{:}] [{:04d}/{:04d}] ".format(time_string(), iepoch, args.epochs) loss_meter = epoch_train(
+ need_time env_loader,
meta_model,
base_model,
optimizer,
criterion,
args.device,
logger,
) )
losses = []
# for ibatch in range(args.meta_batch):
for cur_time in range(total_bar):
# cur_time = random.randint(0, total_bar)
cur_task_embed = task_embeds[cur_time]
cur_container = hypernet(cur_task_embed)
cur_x = env_info["{:}-x".format(cur_time)].to(args.device)
cur_y = env_info["{:}-y".format(cur_time)].to(args.device)
cur_dataset = TimeData(cur_time, cur_x, cur_y)
preds = model.forward_with_container(cur_dataset.x, cur_container)
optimizer.zero_grad()
loss = criterion(preds, cur_dataset.y)
losses.append(loss)
final_loss = torch.stack(losses).mean()
final_loss.backward()
optimizer.step()
lr_scheduler.step() lr_scheduler.step()
logger.log(
loss_meter.update(final_loss.item()) head_str
if iepoch % 100 == 0: + " meta-loss: {meter.avg:.4f} ({meter.count:.0f})".format(meter=loss_meter)
logger.log( + " :: lr={:.5f}".format(min(lr_scheduler.get_last_lr()))
head_str )
+ " meta-loss: {:.4f} ({:.4f}) :: lr={:.5f}, batch={:}".format( success, best_score = meta_model.save_best(-loss_meter.avg)
loss_meter.avg, if success:
loss_meter.val, logger.log("Achieve the best with best_score = {:.3f}".format(best_score))
min(lr_scheduler.get_last_lr()), last_success_epoch = iepoch
len(losses),
)
)
save_checkpoint( save_checkpoint(
{ {
"hypernet": hypernet.state_dict(), "meta_model": meta_model.state_dict(),
"task_embed": task_embed, "optimizer": optimizer.state_dict(),
"lr_scheduler": lr_scheduler.state_dict(), "lr_scheduler": lr_scheduler.state_dict(),
"iepoch": iepoch, "iepoch": iepoch,
"args": args,
}, },
logger.path("model"), logger.path("model"),
logger, logger,
) )
loss_meter.reset() if iepoch - last_success_epoch >= args.early_stop_thresh:
logger.log("Early stop at {:}".format(iepoch))
break
per_epoch_time.update(time.time() - start_time) per_epoch_time.update(time.time() - start_time)
start_time = time.time() start_time = time.time()
print(model)
print(hypernet)
w_container_per_epoch = dict() w_container_per_epoch = dict()
for idx in range(0, total_bar): for idx in range(0, total_bar):
future_time = env_info["{:}-timestamp".format(idx)] future_time = env_info["{:}-timestamp".format(idx)]
@ -183,20 +197,26 @@ if __name__ == "__main__":
parser.add_argument( parser.add_argument(
"--hidden_dim", "--hidden_dim",
type=int, type=int,
required=True, default=16,
help="The hidden dimension.", help="The hidden dimension.",
) )
parser.add_argument( parser.add_argument(
"--layer_dim", "--layer_dim",
type=int, type=int,
required=True, default=16,
help="The hidden dimension.", help="The layer chunk dimension.",
)
parser.add_argument(
"--time_dim",
type=int,
default=16,
help="The timestamp dimension.",
) )
##### #####
parser.add_argument( parser.add_argument(
"--init_lr", "--init_lr",
type=float, type=float,
default=0.1, default=0.01,
help="The initial learning rate for the optimizer (default is Adam)", help="The initial learning rate for the optimizer (default is Adam)",
) )
parser.add_argument( parser.add_argument(
@ -206,10 +226,23 @@ if __name__ == "__main__":
help="The batch size for the meta-model", help="The batch size for the meta-model",
) )
parser.add_argument( parser.add_argument(
"--epochs", "--sampler_enlarge",
type=int, type=int,
default=2000, default=5,
help="The total number of epochs.", help="Enlarge the #iterations for an epoch",
)
parser.add_argument("--epochs", type=int, default=1000, help="The total #epochs.")
parser.add_argument(
"--early_stop_thresh",
type=int,
default=50,
help="The maximum epochs for early stop.",
)
parser.add_argument(
"--seq_length", type=int, default=5, help="The sequence length."
)
parser.add_argument(
"--workers", type=int, default=4, help="The number of workers in parallel."
) )
parser.add_argument( parser.add_argument(
"--device", "--device",
@ -223,8 +256,7 @@ if __name__ == "__main__":
if args.rand_seed is None or args.rand_seed < 0: if args.rand_seed is None or args.rand_seed < 0:
args.rand_seed = random.randint(1, 100000) args.rand_seed = random.randint(1, 100000)
assert args.save_dir is not None, "The save dir argument can not be None" assert args.save_dir is not None, "The save dir argument can not be None"
args.task_dim = args.layer_dim args.save_dir = "{:}-{:}-d{:}_{:}_{:}".format(
args.save_dir = "{:}-{:}-d{:}".format( args.save_dir, args.env_version, args.hidden_dim, args.layer_dim, args.time_dim
args.save_dir, args.env_version, args.hidden_dim
) )
main(args) main(args)

View File

@ -0,0 +1,128 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 #
#####################################################
import copy
import torch
import torch.nn.functional as F
from xlayers import super_core
from xlayers import trunc_normal_
from models.xcore import get_model
class LFNA_Meta(super_core.SuperModule):
    """Learning to Forecast Neural Adaptation (Meta Model Design).

    The module keeps one learnable embedding per base-model layer and one per
    known meta timestamp.  Given a batch of timestamp sequences it looks up
    the embedding of the nearest known timestamp, blends it with a learnable
    "unknown" token, refines the sequence with a stack of transformer encoder
    layers and finally generates one weight chunk per layer, which is
    translated into containers through ``shape_container``.
    """

    def __init__(
        self,
        shape_container,
        layer_embeding,
        time_embedding,
        meta_timestamps,
        mha_depth: int = 2,
        dropout: float = 0.1,
    ):
        """Build the embeddings, the transformer corrector and the generator.

        Args:
            shape_container: indexable container; ``len`` gives the number of
                layers and ``shape_container[i].numel()`` the size of layer i.
            layer_embeding: width of each per-layer embedding.
            time_embedding: width of each per-timestamp embedding.
            meta_timestamps: sequence of known timestamps.
            mha_depth: number of stacked transformer encoder layers.
            dropout: dropout used by the encoder layers and the generator.
        """
        super(LFNA_Meta, self).__init__()
        self._shape_container = shape_container
        self._num_layers = len(shape_container)
        self._numel_per_layer = [
            shape_container[index].numel() for index in range(self._num_layers)
        ]
        self._raw_meta_timestamps = meta_timestamps

        # Learnable tables; their values are filled by trunc_normal_ below.
        layer_table = torch.nn.Parameter(
            torch.Tensor(self._num_layers, layer_embeding)
        )
        self.register_parameter("_super_layer_embed", layer_table)
        meta_table = torch.nn.Parameter(
            torch.Tensor(len(meta_timestamps), time_embedding)
        )
        self.register_parameter("_super_meta_embed", meta_table)
        self.register_buffer("_meta_timestamps", torch.Tensor(meta_timestamps))

        # build transformer
        encoder_layers = [
            super_core.SuperTransformerEncoderLayer(
                time_embedding,
                4,
                True,
                4,
                dropout,
                norm_affine=False,
                order=super_core.LayerOrder.PostNorm,
            )
            for _ in range(mha_depth)
        ]
        self.meta_corrector = super_core.SuperSequential(*encoder_layers)

        # The generator emits max(numel) values for every (timestamp, layer) pair.
        joint_dim = layer_embeding + time_embedding
        self._generator = get_model(
            config=dict(model_type="dual_norm_mlp"),
            input_dim=joint_dim,
            output_dim=max(self._numel_per_layer),
            hidden_dims=[joint_dim * 2] * 3,
            act_cls="gelu",
            norm_cls="layer_norm_1d",
            dropout=dropout,
        )

        # unknown token
        unknown = torch.nn.Parameter(torch.Tensor(1, time_embedding))
        self.register_parameter("_unknown_token", unknown)

        # initialization
        trunc_normal_(
            [self._super_layer_embed, self._super_meta_embed, self._unknown_token],
            std=0.02,
        )

    def forward_raw(self, timestamps):
        """Generate containers for a 2-D ``(batch, seq)`` tensor of timestamps.

        Returns:
            A list (over the batch) of lists (over the sequence) holding the
            result of ``shape_container.translate`` for every step.
        """
        batch, seq = timestamps.shape

        # Distance of every query timestamp to every known meta timestamp.
        queries = timestamps.unsqueeze(dim=-1)
        anchors = self._meta_timestamps.view(1, 1, -1)
        gap, nearest = torch.min(torch.abs(queries - anchors), dim=-1)

        # select corresponding meta-knowledge
        matched = torch.index_select(
            self._super_meta_embed, dim=0, index=nearest.view(-1)
        ).view(batch, seq, -1)

        # create the probability: decays with the distance to the nearest
        # known timestamp; the last step always falls back to the unknown token
        probs = (1 / torch.exp(gap * 10)).view(batch, seq, 1)
        probs[:, -1, :] = 0
        unknown = self._unknown_token.view(1, 1, -1)
        blended = probs * matched + (1 - probs) * unknown
        corrected = self.meta_corrector(blended)

        # create joint embed: pair every time embedding with every layer embedding
        num_layer, _ = self._super_layer_embed.shape
        time_part = corrected.view(batch, seq, 1, -1).expand(-1, -1, num_layer, -1)
        layer_part = self._super_layer_embed.view(1, 1, num_layer, -1).expand(
            batch, seq, -1, -1
        )
        batch_weights = self._generator(torch.cat((time_part, layer_part), dim=-1))

        # Peel off the batch, sequence and layer axes one at a time.
        return [
            [
                self._shape_container.translate(
                    torch.split(step_weights.squeeze(0), 1)
                )
                for step_weights in torch.split(seq_weights.squeeze(0), 1)
            ]
            for seq_weights in torch.split(batch_weights, 1)
        ]

    def forward_candidate(self, input):
        """Candidate forward is not supported by this module."""
        raise NotImplementedError

    def extra_repr(self) -> str:
        """Describe the shapes of the two embedding tables and the timestamps."""
        layer_shape = list(self._super_layer_embed.shape)
        meta_shape = list(self._super_meta_embed.shape)
        stamp_shape = list(self._meta_timestamps.shape)
        return "(_super_layer_embed): {:}, (_super_meta_embed): {:}, (_meta_timestamps): {:}".format(
            layer_shape, meta_shape, stamp_shape
        )

View File

@ -1,72 +0,0 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 #
#####################################################
import copy
import torch
import torch.nn.functional as F
from xlayers import super_core
from xlayers import trunc_normal_
from models.xcore import get_model
class HyperNet(super_core.SuperModule):
    """The hyper-network.

    Holds one learnable embedding per base-model layer and an MLP generator
    that maps the concatenation of a task embedding and a layer embedding to
    the weights of that layer.
    """

    def __init__(
        self,
        shape_container,
        layer_embeding,
        task_embedding,
        meta_timestamps,
        return_container: bool = True,
    ):
        """Create the layer embeddings and the weight generator.

        Args:
            shape_container: indexable container; ``len`` gives the number of
                layers and ``shape_container[i].numel()`` the size of layer i.
            layer_embeding: width of each per-layer embedding.
            task_embedding: width of the task embedding given to ``forward``.
            meta_timestamps: accepted for interface compatibility; not used.
            return_container: when True the forward pass returns the result
                of ``shape_container.translate``; otherwise the raw weights.
        """
        super(HyperNet, self).__init__()
        self._shape_container = shape_container
        self._num_layers = len(shape_container)
        self._numel_per_layer = [
            shape_container[ilayer].numel() for ilayer in range(self._num_layers)
        ]

        self.register_parameter(
            "_super_layer_embed",
            torch.nn.Parameter(torch.Tensor(self._num_layers, layer_embeding)),
        )
        trunc_normal_(self._super_layer_embed, std=0.02)

        # The generator emits max(numel) values for every layer.
        model_kwargs = dict(
            config=dict(model_type="dual_norm_mlp"),
            input_dim=layer_embeding + task_embedding,
            output_dim=max(self._numel_per_layer),
            hidden_dims=[(layer_embeding + task_embedding) * 2] * 3,
            act_cls="gelu",
            norm_cls="layer_norm_1d",
            dropout=0.2,
        )
        self._generator = get_model(**model_kwargs)
        self._return_container = return_container
        print("generator: {:}".format(self._generator))

    def forward_raw(self, task_embed):
        """Generate the weights of every layer for one task embedding.

        ``task_embed`` is flattened to a single row and repeated once per
        layer before being concatenated with the layer embeddings.
        """
        layer_embed = self._super_layer_embed
        task_embed = task_embed.view(1, -1).expand(self._num_layers, -1)

        joint_embed = torch.cat((task_embed, layer_embed), dim=-1)
        weights = self._generator(joint_embed)
        if self._return_container:
            # One row of generated weights per layer.
            weights = torch.split(weights, 1)
            return self._shape_container.translate(weights)
        return weights

    def forward_candidate(self, input):
        """Candidate forward is not supported by this module."""
        raise NotImplementedError

    def extra_repr(self) -> str:
        """Describe the shape of the layer-embedding table."""
        return "(_super_layer_embed): {:}".format(list(self._super_layer_embed.shape))

View File

@ -225,8 +225,8 @@ def visualize_env(save_dir, version):
def compare_algs(save_dir, version, alg_dir="./outputs/lfna-synthetic"): def compare_algs(save_dir, version, alg_dir="./outputs/lfna-synthetic"):
save_dir = Path(str(save_dir)) save_dir = Path(str(save_dir))
for substr in ("pdf", "png"): for substr in ("pdf", "png"):
sub_save_dir = save_dir / substr sub_save_dir = save_dir / substr
sub_save_dir.mkdir(parents=True, exist_ok=True) sub_save_dir.mkdir(parents=True, exist_ok=True)
dpi, width, height = 30, 3200, 2000 dpi, width, height = 30, 3200, 2000
figsize = width / float(dpi), height / float(dpi) figsize = width / float(dpi), height / float(dpi)

View File

@ -2,6 +2,7 @@
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.05 # # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.05 #
##################################################### #####################################################
from .synthetic_utils import TimeStamp from .synthetic_utils import TimeStamp
from .synthetic_env import EnvSampler
from .synthetic_env import SyntheticDEnv from .synthetic_env import SyntheticDEnv
from .math_core import LinearFunc from .math_core import LinearFunc
from .math_core import DynamicLinearFunc from .math_core import DynamicLinearFunc

View File

@ -2,7 +2,7 @@
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 # # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 #
##################################################### #####################################################
import math import math
import abc import random
import numpy as np import numpy as np
from typing import List, Optional, Dict from typing import List, Optional, Dict
import torch import torch
@ -11,6 +11,28 @@ import torch.utils.data as data
from .synthetic_utils import TimeStamp from .synthetic_utils import TimeStamp
def is_list_tuple(x):
    """Return True when ``x`` is a list or a tuple (subclasses included)."""
    return isinstance(x, tuple) or isinstance(x, list)
def zip_sequence(sequence):
    """Stack a sequence of identically nested tensor structures.

    Every tensor leaf gains a new leading axis and the leaves found at the
    same position of each element are concatenated along that axis.  Nested
    lists/tuples are mirrored in the result as (nested) lists.  The whole
    operation runs under ``torch.no_grad()``.
    """

    def _add_leading_axis(item):
        # Recurse through lists/tuples; tensors get a size-1 axis in front.
        if isinstance(item, (tuple, list)):
            return [_add_leading_axis(child) for child in item]
        return item.unsqueeze(dim=0)

    def _merge(*items):
        # Matching positions are merged together, preserving the nesting.
        if isinstance(items[0], (tuple, list)):
            return [_merge(*group) for group in zip(*items)]
        return torch.cat(items, dim=0)

    with torch.no_grad():
        expanded = [_add_leading_axis(item) for item in sequence]
        return _merge(*expanded)
class SyntheticDEnv(data.Dataset): class SyntheticDEnv(data.Dataset):
"""The synethtic dynamic environment.""" """The synethtic dynamic environment."""
@ -33,7 +55,7 @@ class SyntheticDEnv(data.Dataset):
self._num_per_task = num_per_task self._num_per_task = num_per_task
if timestamp_config is None: if timestamp_config is None:
timestamp_config = dict(mode=mode) timestamp_config = dict(mode=mode)
else: elif "mode" not in timestamp_config:
timestamp_config["mode"] = mode timestamp_config["mode"] = mode
self._timestamp_generator = TimeStamp(**timestamp_config) self._timestamp_generator = TimeStamp(**timestamp_config)
@ -42,6 +64,7 @@ class SyntheticDEnv(data.Dataset):
self._cov_functors = cov_functors self._cov_functors = cov_functors
self._oracle_map = None self._oracle_map = None
self._seq_length = None
@property @property
def min_timestamp(self): def min_timestamp(self):
@ -55,9 +78,18 @@ class SyntheticDEnv(data.Dataset):
def timestamp_interval(self): def timestamp_interval(self):
return self._timestamp_generator.interval return self._timestamp_generator.interval
# Store `seq_length` on the environment as `_seq_length`.
# NOTE(review): `__getitem__` appears to return `seq_length` consecutive
# timestamps per item once this is set, while None keeps single-timestamp
# items -- confirm against the full class.
def reset_max_seq_length(self, seq_length):
self._seq_length = seq_length
def get_timestamp(self, index): def get_timestamp(self, index):
index, timestamp = self._timestamp_generator[index] if index is None:
return timestamp timestamps = []
for index in range(len(self._timestamp_generator)):
timestamps.append(self._timestamp_generator[index][1])
return tuple(timestamps)
else:
index, timestamp = self._timestamp_generator[index]
return timestamp
def set_oracle_map(self, functor): def set_oracle_map(self, functor):
self._oracle_map = functor self._oracle_map = functor
@ -75,7 +107,14 @@ class SyntheticDEnv(data.Dataset):
def __getitem__(self, index): def __getitem__(self, index):
assert 0 <= index < len(self), "{:} is not in [0, {:})".format(index, len(self)) assert 0 <= index < len(self), "{:} is not in [0, {:})".format(index, len(self))
index, timestamp = self._timestamp_generator[index] index, timestamp = self._timestamp_generator[index]
return self.__call__(timestamp) if self._seq_length is None:
return self.__call__(timestamp)
else:
timestamps = [
timestamp + i * self.timestamp_interval for i in range(self._seq_length)
]
xdata = [self.__call__(timestamp) for timestamp in timestamps]
return zip_sequence(xdata)
def __call__(self, timestamp): def __call__(self, timestamp):
mean_list = [functor(timestamp) for functor in self._mean_functors] mean_list = [functor(timestamp) for functor in self._mean_functors]
@ -88,10 +127,13 @@ class SyntheticDEnv(data.Dataset):
mean_list, cov_matrix, size=self._num_per_task mean_list, cov_matrix, size=self._num_per_task
) )
if self._oracle_map is None: if self._oracle_map is None:
return timestamp, torch.Tensor(dataset) return torch.Tensor([timestamp]), torch.Tensor(dataset)
else: else:
targets = self._oracle_map.noise_call(dataset, timestamp) targets = self._oracle_map.noise_call(dataset, timestamp)
return timestamp, (torch.Tensor(dataset), torch.Tensor(targets)) return torch.Tensor([timestamp]), (
torch.Tensor(dataset),
torch.Tensor(targets),
)
def __len__(self): def __len__(self):
return len(self._timestamp_generator) return len(self._timestamp_generator)
@ -104,3 +146,20 @@ class SyntheticDEnv(data.Dataset):
ndim=self._ndim, ndim=self._ndim,
num_per_task=self._num_per_task, num_per_task=self._num_per_task,
) )
class EnvSampler:
    """Batch sampler yielding shuffled index batches over an environment.

    The index range ``0 .. len(env) - 1`` is repeated ``enlarge`` times,
    shuffled in place at the start of every iteration and cut into
    consecutive batches of ``batch`` indexes.  Indexes left over after the
    last full batch are not yielded.
    """

    def __init__(self, env, batch, enlarge):
        self._indexes = list(range(len(env))) * enlarge
        self._batch = batch
        self._iterations = len(self._indexes) // self._batch

    def __iter__(self):
        random.shuffle(self._indexes)
        size = self._batch
        # Only full batches are produced.
        for start in range(0, self._iterations * size, size):
            yield self._indexes[start : start + size]

    def __len__(self):
        return self._iterations

View File

@ -30,6 +30,7 @@ class UnifiedSplit:
self._indexes = all_indexes[num_of_train + num_of_valid :] self._indexes = all_indexes[num_of_train + num_of_valid :]
else: else:
raise ValueError("Unkonwn mode of {:}".format(mode)) raise ValueError("Unkonwn mode of {:}".format(mode))
self._all_indexes = all_indexes
self._mode = mode self._mode = mode
@property @property

View File

@ -1,120 +0,0 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 #
#####################################################
# DISABLED / NOT-FINISHED
#####################################################
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from typing import Optional, Callable
import spaces
from .super_container import SuperSequential
from .super_linear import SuperLinear
class SuperActor(SuperModule):
    """Base actor for RL: subclasses supply the action distribution."""

    def _distribution(self, obs):
        """Return the action distribution for ``obs`` (subclass hook)."""
        raise NotImplementedError

    def _log_prob_from_distribution(self, pi, act):
        """Return the log-probability of ``act`` under ``pi`` (subclass hook)."""
        raise NotImplementedError

    def forward_candidate(self, **kwargs):
        """Delegate to ``forward_raw`` with the same keyword arguments."""
        return self.forward_raw(**kwargs)

    def forward_raw(self, obs, act=None):
        """Return ``(pi, logp_a)`` for the given observations.

        ``logp_a`` is the log likelihood of ``act`` under the produced
        distribution, or None when no action is supplied.
        """
        pi = self._distribution(obs)
        logp_a = None if act is None else self._log_prob_from_distribution(pi, act)
        return pi, logp_a
class SuperLfnaMetaMLP(SuperModule):
    """Three-layer MLP (``delta_net``) mapping an observation to one scalar.

    NOTE(review): ``SuperModule`` is not imported by the visible import block
    of this file -- confirm where it should come from.
    """

    def __init__(self, obs_dim, hidden_sizes, act_cls):
        """Build ``delta_net``.

        Args:
            obs_dim: input width of the first linear layer.
            hidden_sizes: sequence whose first two entries give the hidden
                widths.
            act_cls: zero-argument callable creating an activation module.
        """
        # The two-argument form binds ``self`` so the base initialiser really
        # runs; ``super(Class).__init__()`` only re-initialises an unbound
        # super object and leaves the module unusable.
        super(SuperLfnaMetaMLP, self).__init__()
        self.delta_net = SuperSequential(
            SuperLinear(obs_dim, hidden_sizes[0]),
            act_cls(),
            SuperLinear(hidden_sizes[0], hidden_sizes[1]),
            act_cls(),
            SuperLinear(hidden_sizes[1], 1),
        )
class SuperLfnaMetaMLP(SuperModule):
    """Gaussian MLP actor: a mean network plus a learnable log-std vector.

    NOTE(review): ``SuperModule`` is not imported by the visible import block
    of this file -- confirm where it should come from.
    """

    def __init__(self, obs_dim, act_dim, hidden_sizes, act_cls):
        """Build ``mu_net`` and the ``log_std`` parameter.

        Args:
            obs_dim: input width of the first linear layer.
            act_dim: width of the produced mean and of ``log_std``.
            hidden_sizes: sequence whose first two entries give the hidden
                widths.
            act_cls: zero-argument callable creating an activation module.
        """
        # numpy is not imported at file level, so bring it in locally.
        import numpy as np

        # The two-argument form binds ``self`` so the base initialiser really
        # runs; ``super(Class).__init__()`` only re-initialises an unbound
        # super object and leaves the module unusable.
        super(SuperLfnaMetaMLP, self).__init__()
        log_std = -0.5 * np.ones(act_dim, dtype=np.float32)
        self.log_std = torch.nn.Parameter(torch.as_tensor(log_std))
        self.mu_net = SuperSequential(
            SuperLinear(obs_dim, hidden_sizes[0]),
            act_cls(),
            SuperLinear(hidden_sizes[0], hidden_sizes[1]),
            act_cls(),
            SuperLinear(hidden_sizes[1], act_dim),
        )

    def _distribution(self, obs):
        """Return ``Normal(mu_net(obs), exp(log_std))``."""
        # Normal is not imported at file level, so bring it in locally.
        from torch.distributions import Normal

        mu = self.mu_net(obs)
        std = torch.exp(self.log_std)
        return Normal(mu, std)

    def _log_prob_from_distribution(self, pi, act):
        """Return the log-probability of ``act`` summed over the last axis."""
        return pi.log_prob(act).sum(axis=-1)

    def forward_candidate(self, **kwargs):
        """Delegate to ``forward_raw`` with the same keyword arguments."""
        return self.forward_raw(**kwargs)

    def forward_raw(self, obs, act=None):
        """Return ``(pi, logp_a)`` for the given observations.

        Produces the action distribution for ``obs`` and, when ``act`` is
        given, the log likelihood of those actions under it (None otherwise).
        """
        pi = self._distribution(obs)
        logp_a = None
        if act is not None:
            logp_a = self._log_prob_from_distribution(pi, act)
        return pi, logp_a
class SuperMLPGaussianActor(SuperModule):
    """Gaussian MLP actor: a mean network plus a learnable log-std vector.

    NOTE(review): ``SuperModule`` is not imported by the visible import block
    of this file -- confirm where it should come from.
    """

    def __init__(self, obs_dim, act_dim, hidden_sizes, act_cls):
        """Build ``mu_net`` and the ``log_std`` parameter.

        Args:
            obs_dim: input width of the first linear layer.
            act_dim: width of the produced mean and of ``log_std``.
            hidden_sizes: sequence whose first two entries give the hidden
                widths.
            act_cls: zero-argument callable creating an activation module.
        """
        # numpy is not imported at file level, so bring it in locally.
        import numpy as np

        # The two-argument form binds ``self`` so the base initialiser really
        # runs; ``super(Class).__init__()`` only re-initialises an unbound
        # super object and leaves the module unusable.
        super(SuperMLPGaussianActor, self).__init__()
        log_std = -0.5 * np.ones(act_dim, dtype=np.float32)
        self.log_std = torch.nn.Parameter(torch.as_tensor(log_std))
        self.mu_net = SuperSequential(
            SuperLinear(obs_dim, hidden_sizes[0]),
            act_cls(),
            SuperLinear(hidden_sizes[0], hidden_sizes[1]),
            act_cls(),
            SuperLinear(hidden_sizes[1], act_dim),
        )

    def _distribution(self, obs):
        """Return ``Normal(mu_net(obs), exp(log_std))``."""
        # Normal is not imported at file level, so bring it in locally.
        from torch.distributions import Normal

        mu = self.mu_net(obs)
        std = torch.exp(self.log_std)
        return Normal(mu, std)

    def _log_prob_from_distribution(self, pi, act):
        """Return the log-probability of ``act`` summed over the last axis."""
        return pi.log_prob(act).sum(axis=-1)

    def forward_candidate(self, **kwargs):
        """Delegate to ``forward_raw`` with the same keyword arguments."""
        return self.forward_raw(**kwargs)

    def forward_raw(self, obs, act=None):
        """Return ``(pi, logp_a)`` for the given observations.

        Produces the action distribution for ``obs`` and, when ``act`` is
        given, the log likelihood of those actions under it (None otherwise).
        """
        pi = self._distribution(obs)
        logp_a = None
        if act is not None:
            logp_a = self._log_prob_from_distribution(pi, act)
        return pi, logp_a

View File

@ -42,6 +42,7 @@ class SuperTransformerEncoderLayer(SuperModule):
qkv_bias: BoolSpaceType = False, qkv_bias: BoolSpaceType = False,
mlp_hidden_multiplier: IntSpaceType = 4, mlp_hidden_multiplier: IntSpaceType = 4,
drop: Optional[float] = None, drop: Optional[float] = None,
norm_affine: bool = True,
act_layer: Callable[[], nn.Module] = nn.GELU, act_layer: Callable[[], nn.Module] = nn.GELU,
order: LayerOrder = LayerOrder.PreNorm, order: LayerOrder = LayerOrder.PreNorm,
): ):
@ -62,19 +63,19 @@ class SuperTransformerEncoderLayer(SuperModule):
drop=drop, drop=drop,
) )
if order is LayerOrder.PreNorm: if order is LayerOrder.PreNorm:
self.norm1 = SuperLayerNorm1D(d_model) self.norm1 = SuperLayerNorm1D(d_model, elementwise_affine=norm_affine)
self.mha = mha self.mha = mha
self.drop1 = nn.Dropout(drop or 0.0) self.drop1 = nn.Dropout(drop or 0.0)
self.norm2 = SuperLayerNorm1D(d_model) self.norm2 = SuperLayerNorm1D(d_model, elementwise_affine=norm_affine)
self.mlp = mlp self.mlp = mlp
self.drop2 = nn.Dropout(drop or 0.0) self.drop2 = nn.Dropout(drop or 0.0)
elif order is LayerOrder.PostNorm: elif order is LayerOrder.PostNorm:
self.mha = mha self.mha = mha
self.drop1 = nn.Dropout(drop or 0.0) self.drop1 = nn.Dropout(drop or 0.0)
self.norm1 = SuperLayerNorm1D(d_model) self.norm1 = SuperLayerNorm1D(d_model, elementwise_affine=norm_affine)
self.mlp = mlp self.mlp = mlp
self.drop2 = nn.Dropout(drop or 0.0) self.drop2 = nn.Dropout(drop or 0.0)
self.norm2 = SuperLayerNorm1D(d_model) self.norm2 = SuperLayerNorm1D(d_model, elementwise_affine=norm_affine)
else: else:
raise ValueError("Unknown order: {:}".format(order)) raise ValueError("Unknown order: {:}".format(order))
self._order = order self._order = order

View File

@ -60,4 +60,7 @@ def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):
>>> w = torch.empty(3, 5) >>> w = torch.empty(3, 5)
>>> nn.init.trunc_normal_(w) >>> nn.init.trunc_normal_(w)
""" """
return _no_grad_trunc_normal_(tensor, mean, std, a, b) if isinstance(tensor, list):
return [_no_grad_trunc_normal_(x, mean, std, a, b) for x in tensor]
else:
return _no_grad_trunc_normal_(tensor, mean, std, a, b)

View File

@ -23,8 +23,16 @@ class TestSynethicEnv(unittest.TestCase):
def test_simple(self): def test_simple(self):
mean_generator = ComposedSinFunc(constant=0.1) mean_generator = ComposedSinFunc(constant=0.1)
std_generator = ConstantFunc(constant=0.5) std_generator = ConstantFunc(constant=0.5)
dataset = SyntheticDEnv([mean_generator], [[std_generator]], num_per_task=5000) dataset = SyntheticDEnv([mean_generator], [[std_generator]], num_per_task=5000)
print(dataset) print(dataset)
for timestamp, tau in dataset: for timestamp, tau in dataset:
assert tau.shape == (5000, 1) self.assertEqual(tau.shape, (5000, 1))
def test_length(self):
    # The default environment is expected to expose 100 timestamps and the
    # "train" mode 60 of them.
    mean_functor = ComposedSinFunc(constant=0.1)
    std_functor = ConstantFunc(constant=0.5)

    default_env = SyntheticDEnv([mean_functor], [[std_functor]], num_per_task=5000)
    self.assertEqual(len(default_env), 100)

    train_env = SyntheticDEnv([mean_functor], [[std_functor]], mode="train")
    self.assertEqual(len(train_env), 60)