Fix bugs in new env-v1

This commit is contained in:
D-X-Y 2021-05-24 05:14:39 +00:00
parent 3ee0d348af
commit 3a2af8e55a
4 changed files with 26 additions and 38 deletions

View File

@ -28,7 +28,7 @@ from xautodl.utils import split_str2indexes
from xautodl.procedures.advanced_main import basic_train_fn, basic_eval_fn
from xautodl.procedures.metric_utils import SaveMetric, MSEMetric, ComposeMetric
from xautodl.datasets.synthetic_core import get_synthetic_env, EnvSampler
from xautodl.datasets.synthetic_core import get_synthetic_env
from xautodl.models.xcore import get_model
from xautodl.xlayers import super_core, trunc_normal_
@ -244,7 +244,7 @@ def main(args):
args.time_dim,
timestamps,
seq_length=args.seq_length,
interval=train_env.timestamp_interval,
interval=train_env.time_interval,
)
meta_model = meta_model.to(args.device)
@ -253,7 +253,7 @@ def main(args):
logger.log("The base-model is\n{:}".format(base_model))
logger.log("The meta-model is\n{:}".format(meta_model))
batch_sampler = EnvSampler(train_env, args.meta_batch, args.sampler_enlarge)
# batch_sampler = EnvSampler(train_env, args.meta_batch, args.sampler_enlarge)
pretrain_v2(base_model, meta_model, criterion, train_env, args, logger)
# try to evaluate once
@ -387,7 +387,7 @@ def main(args):
future_time = env_info["{:}-timestamp".format(idx)].item()
time_seqs = []
for iseq in range(args.seq_length):
time_seqs.append(future_time - iseq * eval_env.timestamp_interval)
time_seqs.append(future_time - iseq * eval_env.time_interval)
time_seqs.reverse()
with torch.no_grad():
meta_model.eval()
@ -409,7 +409,7 @@ def main(args):
# creating the new meta-time-embedding
distance = meta_model.get_closest_meta_distance(future_time)
if distance < eval_env.timestamp_interval:
if distance < eval_env.time_interval:
continue
#
new_param = meta_model.create_meta_embed()

View File

@ -16,8 +16,8 @@ class LFNA_Meta(super_core.SuperModule):
def __init__(
self,
shape_container,
layer_embedding,
time_embedding,
layer_dim,
time_dim,
meta_timestamps,
mha_depth: int = 2,
dropout: float = 0.1,
@ -39,53 +39,41 @@ class LFNA_Meta(super_core.SuperModule):
self.register_parameter(
"_super_layer_embed",
torch.nn.Parameter(torch.Tensor(self._num_layers, layer_embedding)),
torch.nn.Parameter(torch.Tensor(self._num_layers, layer_dim)),
)
self.register_parameter(
"_super_meta_embed",
torch.nn.Parameter(torch.Tensor(len(meta_timestamps), time_embedding)),
torch.nn.Parameter(torch.Tensor(len(meta_timestamps), time_dim)),
)
self.register_buffer("_meta_timestamps", torch.Tensor(meta_timestamps))
# register a time difference buffer
time_interval = [-i * self._interval for i in range(self._seq_length)]
time_interval.reverse()
self.register_buffer("_time_interval", torch.Tensor(time_interval))
self._time_embed_dim = time_embedding
self._time_embed_dim = time_dim
self._append_meta_embed = dict(fixed=None, learnt=None)
self._append_meta_timestamps = dict(fixed=None, learnt=None)
self._tscalar_embed = super_core.SuperDynamicPositionE(
time_embedding, scale=500
time_dim, scale=1 / interval
)
# build transformer
self._trans_att = super_core.SuperQKVAttentionV2(
qk_att_dim=time_embedding,
in_v_dim=time_embedding,
hidden_dim=time_embedding,
qk_att_dim=time_dim,
in_v_dim=time_dim,
hidden_dim=time_dim,
num_heads=4,
proj_dim=time_embedding,
proj_dim=time_dim,
qkv_bias=True,
attn_drop=None,
proj_drop=dropout,
)
"""
self._trans_att = super_core.SuperQKVAttention(
time_embedding,
time_embedding,
time_embedding,
time_embedding,
num_heads=4,
qkv_bias=True,
attn_drop=None,
proj_drop=dropout,
)
"""
layers = []
for ilayer in range(mha_depth):
layers.append(
super_core.SuperTransformerEncoderLayer(
time_embedding * 2,
time_dim * 2,
4,
True,
4,
@ -95,14 +83,14 @@ class LFNA_Meta(super_core.SuperModule):
use_mask=True,
)
)
layers.append(super_core.SuperLinear(time_embedding * 2, time_embedding))
layers.append(super_core.SuperLinear(time_dim * 2, time_dim))
self._meta_corrector = super_core.SuperSequential(*layers)
model_kwargs = dict(
config=dict(model_type="dual_norm_mlp"),
input_dim=layer_embedding + time_embedding,
input_dim=layer_dim + time_dim,
output_dim=max(self._numel_per_layer),
hidden_dims=[(layer_embedding + time_embedding) * 2] * 3,
hidden_dims=[(layer_dim + time_dim) * 2] * 3,
act_cls="gelu",
norm_cls="layer_norm_1d",
dropout=dropout,
@ -193,11 +181,6 @@ class LFNA_Meta(super_core.SuperModule):
# timestamps is a batch of sequence of timestamps
batch, seq = timestamps.shape
meta_timestamps, meta_embeds = self.meta_timestamps, self.super_meta_embed
"""
timestamp_q_embed = self._tscalar_embed(timestamps)
timestamp_k_embed = self._tscalar_embed(meta_timestamps.view(1, -1))
timestamp_v_embed = meta_embeds.unsqueeze(dim=0)
"""
timestamp_v_embed = meta_embeds.unsqueeze(dim=0)
timestamp_qk_att_embed = self._tscalar_embed(
torch.unsqueeze(timestamps, dim=-1) - meta_timestamps
@ -212,7 +195,6 @@ class LFNA_Meta(super_core.SuperModule):
> self._thresh
)
timestamp_embeds = self._trans_att(
# timestamp_q_embed, timestamp_k_embed, timestamp_v_embed, mask
timestamp_qk_att_embed,
timestamp_v_embed,
mask,

View File

@ -21,6 +21,8 @@ class DynamicGenerator(abc.ABC):
class GaussianDGenerator(DynamicGenerator):
"""Generate data from Gaussian distribution."""
def __init__(self, mean_functors, cov_functors, trunc=(-1, 1)):
super(GaussianDGenerator, self).__init__()
self._ndim = assert_list_tuple(mean_functors)
@ -41,6 +43,10 @@ class GaussianDGenerator(DynamicGenerator):
assert assert_list_tuple(trunc) == 2 and trunc[0] < trunc[1]
self._trunc = trunc
@property
def ndim(self):
return self._ndim
def __call__(self, time, num):
mean_list = [functor(time) for functor in self._mean_functors]
cov_matrix = [

View File

@ -115,7 +115,7 @@ class SyntheticDEnv(data.Dataset):
name=self.__class__.__name__,
cur_num=len(self),
total=len(self._time_generator),
ndim=self._ndim,
ndim=self._data_generator.ndim,
num_per_task=self._num_per_task,
xrange_min=self.min_timestamp,
xrange_max=self.max_timestamp,