Fix bugs in new env-v1

This commit is contained in:
D-X-Y 2021-05-24 05:14:39 +00:00
parent 3ee0d348af
commit 3a2af8e55a
4 changed files with 26 additions and 38 deletions

View File

@ -28,7 +28,7 @@ from xautodl.utils import split_str2indexes
from xautodl.procedures.advanced_main import basic_train_fn, basic_eval_fn from xautodl.procedures.advanced_main import basic_train_fn, basic_eval_fn
from xautodl.procedures.metric_utils import SaveMetric, MSEMetric, ComposeMetric from xautodl.procedures.metric_utils import SaveMetric, MSEMetric, ComposeMetric
from xautodl.datasets.synthetic_core import get_synthetic_env, EnvSampler from xautodl.datasets.synthetic_core import get_synthetic_env
from xautodl.models.xcore import get_model from xautodl.models.xcore import get_model
from xautodl.xlayers import super_core, trunc_normal_ from xautodl.xlayers import super_core, trunc_normal_
@ -244,7 +244,7 @@ def main(args):
args.time_dim, args.time_dim,
timestamps, timestamps,
seq_length=args.seq_length, seq_length=args.seq_length,
interval=train_env.timestamp_interval, interval=train_env.time_interval,
) )
meta_model = meta_model.to(args.device) meta_model = meta_model.to(args.device)
@ -253,7 +253,7 @@ def main(args):
logger.log("The base-model is\n{:}".format(base_model)) logger.log("The base-model is\n{:}".format(base_model))
logger.log("The meta-model is\n{:}".format(meta_model)) logger.log("The meta-model is\n{:}".format(meta_model))
batch_sampler = EnvSampler(train_env, args.meta_batch, args.sampler_enlarge) # batch_sampler = EnvSampler(train_env, args.meta_batch, args.sampler_enlarge)
pretrain_v2(base_model, meta_model, criterion, train_env, args, logger) pretrain_v2(base_model, meta_model, criterion, train_env, args, logger)
# try to evaluate once # try to evaluate once
@ -387,7 +387,7 @@ def main(args):
future_time = env_info["{:}-timestamp".format(idx)].item() future_time = env_info["{:}-timestamp".format(idx)].item()
time_seqs = [] time_seqs = []
for iseq in range(args.seq_length): for iseq in range(args.seq_length):
time_seqs.append(future_time - iseq * eval_env.timestamp_interval) time_seqs.append(future_time - iseq * eval_env.time_interval)
time_seqs.reverse() time_seqs.reverse()
with torch.no_grad(): with torch.no_grad():
meta_model.eval() meta_model.eval()
@ -409,7 +409,7 @@ def main(args):
# creating the new meta-time-embedding # creating the new meta-time-embedding
distance = meta_model.get_closest_meta_distance(future_time) distance = meta_model.get_closest_meta_distance(future_time)
if distance < eval_env.timestamp_interval: if distance < eval_env.time_interval:
continue continue
# #
new_param = meta_model.create_meta_embed() new_param = meta_model.create_meta_embed()

View File

@ -16,8 +16,8 @@ class LFNA_Meta(super_core.SuperModule):
def __init__( def __init__(
self, self,
shape_container, shape_container,
layer_embedding, layer_dim,
time_embedding, time_dim,
meta_timestamps, meta_timestamps,
mha_depth: int = 2, mha_depth: int = 2,
dropout: float = 0.1, dropout: float = 0.1,
@ -39,53 +39,41 @@ class LFNA_Meta(super_core.SuperModule):
self.register_parameter( self.register_parameter(
"_super_layer_embed", "_super_layer_embed",
torch.nn.Parameter(torch.Tensor(self._num_layers, layer_embedding)), torch.nn.Parameter(torch.Tensor(self._num_layers, layer_dim)),
) )
self.register_parameter( self.register_parameter(
"_super_meta_embed", "_super_meta_embed",
torch.nn.Parameter(torch.Tensor(len(meta_timestamps), time_embedding)), torch.nn.Parameter(torch.Tensor(len(meta_timestamps), time_dim)),
) )
self.register_buffer("_meta_timestamps", torch.Tensor(meta_timestamps)) self.register_buffer("_meta_timestamps", torch.Tensor(meta_timestamps))
# register a time difference buffer # register a time difference buffer
time_interval = [-i * self._interval for i in range(self._seq_length)] time_interval = [-i * self._interval for i in range(self._seq_length)]
time_interval.reverse() time_interval.reverse()
self.register_buffer("_time_interval", torch.Tensor(time_interval)) self.register_buffer("_time_interval", torch.Tensor(time_interval))
self._time_embed_dim = time_embedding self._time_embed_dim = time_dim
self._append_meta_embed = dict(fixed=None, learnt=None) self._append_meta_embed = dict(fixed=None, learnt=None)
self._append_meta_timestamps = dict(fixed=None, learnt=None) self._append_meta_timestamps = dict(fixed=None, learnt=None)
self._tscalar_embed = super_core.SuperDynamicPositionE( self._tscalar_embed = super_core.SuperDynamicPositionE(
time_embedding, scale=500 time_dim, scale=1 / interval
) )
# build transformer # build transformer
self._trans_att = super_core.SuperQKVAttentionV2( self._trans_att = super_core.SuperQKVAttentionV2(
qk_att_dim=time_embedding, qk_att_dim=time_dim,
in_v_dim=time_embedding, in_v_dim=time_dim,
hidden_dim=time_embedding, hidden_dim=time_dim,
num_heads=4, num_heads=4,
proj_dim=time_embedding, proj_dim=time_dim,
qkv_bias=True, qkv_bias=True,
attn_drop=None, attn_drop=None,
proj_drop=dropout, proj_drop=dropout,
) )
"""
self._trans_att = super_core.SuperQKVAttention(
time_embedding,
time_embedding,
time_embedding,
time_embedding,
num_heads=4,
qkv_bias=True,
attn_drop=None,
proj_drop=dropout,
)
"""
layers = [] layers = []
for ilayer in range(mha_depth): for ilayer in range(mha_depth):
layers.append( layers.append(
super_core.SuperTransformerEncoderLayer( super_core.SuperTransformerEncoderLayer(
time_embedding * 2, time_dim * 2,
4, 4,
True, True,
4, 4,
@ -95,14 +83,14 @@ class LFNA_Meta(super_core.SuperModule):
use_mask=True, use_mask=True,
) )
) )
layers.append(super_core.SuperLinear(time_embedding * 2, time_embedding)) layers.append(super_core.SuperLinear(time_dim * 2, time_dim))
self._meta_corrector = super_core.SuperSequential(*layers) self._meta_corrector = super_core.SuperSequential(*layers)
model_kwargs = dict( model_kwargs = dict(
config=dict(model_type="dual_norm_mlp"), config=dict(model_type="dual_norm_mlp"),
input_dim=layer_embedding + time_embedding, input_dim=layer_dim + time_dim,
output_dim=max(self._numel_per_layer), output_dim=max(self._numel_per_layer),
hidden_dims=[(layer_embedding + time_embedding) * 2] * 3, hidden_dims=[(layer_dim + time_dim) * 2] * 3,
act_cls="gelu", act_cls="gelu",
norm_cls="layer_norm_1d", norm_cls="layer_norm_1d",
dropout=dropout, dropout=dropout,
@ -193,11 +181,6 @@ class LFNA_Meta(super_core.SuperModule):
# timestamps is a batch of sequence of timestamps # timestamps is a batch of sequence of timestamps
batch, seq = timestamps.shape batch, seq = timestamps.shape
meta_timestamps, meta_embeds = self.meta_timestamps, self.super_meta_embed meta_timestamps, meta_embeds = self.meta_timestamps, self.super_meta_embed
"""
timestamp_q_embed = self._tscalar_embed(timestamps)
timestamp_k_embed = self._tscalar_embed(meta_timestamps.view(1, -1))
timestamp_v_embed = meta_embeds.unsqueeze(dim=0)
"""
timestamp_v_embed = meta_embeds.unsqueeze(dim=0) timestamp_v_embed = meta_embeds.unsqueeze(dim=0)
timestamp_qk_att_embed = self._tscalar_embed( timestamp_qk_att_embed = self._tscalar_embed(
torch.unsqueeze(timestamps, dim=-1) - meta_timestamps torch.unsqueeze(timestamps, dim=-1) - meta_timestamps
@ -212,7 +195,6 @@ class LFNA_Meta(super_core.SuperModule):
> self._thresh > self._thresh
) )
timestamp_embeds = self._trans_att( timestamp_embeds = self._trans_att(
# timestamp_q_embed, timestamp_k_embed, timestamp_v_embed, mask
timestamp_qk_att_embed, timestamp_qk_att_embed,
timestamp_v_embed, timestamp_v_embed,
mask, mask,

View File

@ -21,6 +21,8 @@ class DynamicGenerator(abc.ABC):
class GaussianDGenerator(DynamicGenerator): class GaussianDGenerator(DynamicGenerator):
"""Generate data from Gaussian distribution."""
def __init__(self, mean_functors, cov_functors, trunc=(-1, 1)): def __init__(self, mean_functors, cov_functors, trunc=(-1, 1)):
super(GaussianDGenerator, self).__init__() super(GaussianDGenerator, self).__init__()
self._ndim = assert_list_tuple(mean_functors) self._ndim = assert_list_tuple(mean_functors)
@ -41,6 +43,10 @@ class GaussianDGenerator(DynamicGenerator):
assert assert_list_tuple(trunc) == 2 and trunc[0] < trunc[1] assert assert_list_tuple(trunc) == 2 and trunc[0] < trunc[1]
self._trunc = trunc self._trunc = trunc
@property
def ndim(self):
return self._ndim
def __call__(self, time, num): def __call__(self, time, num):
mean_list = [functor(time) for functor in self._mean_functors] mean_list = [functor(time) for functor in self._mean_functors]
cov_matrix = [ cov_matrix = [

View File

@ -115,7 +115,7 @@ class SyntheticDEnv(data.Dataset):
name=self.__class__.__name__, name=self.__class__.__name__,
cur_num=len(self), cur_num=len(self),
total=len(self._time_generator), total=len(self._time_generator),
ndim=self._ndim, ndim=self._data_generator.ndim,
num_per_task=self._num_per_task, num_per_task=self._num_per_task,
xrange_min=self.min_timestamp, xrange_min=self.min_timestamp,
xrange_max=self.max_timestamp, xrange_max=self.max_timestamp,