Fix bugs in new env-v1
parent 3ee0d348af
commit 3a2af8e55a
@@ -28,7 +28,7 @@ from xautodl.utils import split_str2indexes
 from xautodl.procedures.advanced_main import basic_train_fn, basic_eval_fn
 from xautodl.procedures.metric_utils import SaveMetric, MSEMetric, ComposeMetric
-from xautodl.datasets.synthetic_core import get_synthetic_env, EnvSampler
+from xautodl.datasets.synthetic_core import get_synthetic_env
 from xautodl.models.xcore import get_model
 from xautodl.xlayers import super_core, trunc_normal_
 
@@ -244,7 +244,7 @@ def main(args):
         args.time_dim,
         timestamps,
         seq_length=args.seq_length,
-        interval=train_env.timestamp_interval,
+        interval=train_env.time_interval,
     )
     meta_model = meta_model.to(args.device)
 
@@ -253,7 +253,7 @@ def main(args):
     logger.log("The base-model is\n{:}".format(base_model))
     logger.log("The meta-model is\n{:}".format(meta_model))
 
-    batch_sampler = EnvSampler(train_env, args.meta_batch, args.sampler_enlarge)
+    # batch_sampler = EnvSampler(train_env, args.meta_batch, args.sampler_enlarge)
     pretrain_v2(base_model, meta_model, criterion, train_env, args, logger)
 
     # try to evaluate once
@@ -387,7 +387,7 @@ def main(args):
         future_time = env_info["{:}-timestamp".format(idx)].item()
         time_seqs = []
         for iseq in range(args.seq_length):
-            time_seqs.append(future_time - iseq * eval_env.timestamp_interval)
+            time_seqs.append(future_time - iseq * eval_env.time_interval)
         time_seqs.reverse()
         with torch.no_grad():
             meta_model.eval()
@@ -409,7 +409,7 @@ def main(args):
 
         # creating the new meta-time-embedding
         distance = meta_model.get_closest_meta_distance(future_time)
-        if distance < eval_env.timestamp_interval:
+        if distance < eval_env.time_interval:
             continue
         #
         new_param = meta_model.create_meta_embed()
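Note: every change in this first file tracks one rename on the synthetic environment, timestamp_interval becoming time_interval, plus dropping the EnvSampler usage in favor of pretrain_v2. A minimal sketch of how such an interval attribute drives the backward-looking time sequence built in main() above (ToyEnv and the literal values are hypothetical stand-ins, not the actual xautodl API):

# Sketch only: build past timestamps ending at `future_time`,
# spaced by the environment's sampling interval.
class ToyEnv:
    def __init__(self, time_interval: float):
        self.time_interval = time_interval  # spacing between consecutive timestamps

def build_time_seq(env: ToyEnv, future_time: float, seq_length: int):
    seqs = [future_time - i * env.time_interval for i in range(seq_length)]
    seqs.reverse()  # oldest timestamp first, matching the loop in main()
    return seqs

env = ToyEnv(time_interval=0.1)
print(build_time_seq(env, future_time=1.0, seq_length=3))  # [0.8, 0.9, 1.0]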
@@ -16,8 +16,8 @@ class LFNA_Meta(super_core.SuperModule):
     def __init__(
         self,
         shape_container,
-        layer_embedding,
-        time_embedding,
+        layer_dim,
+        time_dim,
         meta_timestamps,
         mha_depth: int = 2,
         dropout: float = 0.1,
@@ -39,53 +39,41 @@ class LFNA_Meta(super_core.SuperModule):
 
         self.register_parameter(
             "_super_layer_embed",
-            torch.nn.Parameter(torch.Tensor(self._num_layers, layer_embedding)),
+            torch.nn.Parameter(torch.Tensor(self._num_layers, layer_dim)),
         )
         self.register_parameter(
             "_super_meta_embed",
-            torch.nn.Parameter(torch.Tensor(len(meta_timestamps), time_embedding)),
+            torch.nn.Parameter(torch.Tensor(len(meta_timestamps), time_dim)),
         )
         self.register_buffer("_meta_timestamps", torch.Tensor(meta_timestamps))
         # register a time difference buffer
         time_interval = [-i * self._interval for i in range(self._seq_length)]
         time_interval.reverse()
         self.register_buffer("_time_interval", torch.Tensor(time_interval))
-        self._time_embed_dim = time_embedding
+        self._time_embed_dim = time_dim
         self._append_meta_embed = dict(fixed=None, learnt=None)
         self._append_meta_timestamps = dict(fixed=None, learnt=None)
 
         self._tscalar_embed = super_core.SuperDynamicPositionE(
-            time_embedding, scale=500
+            time_dim, scale=1 / interval
         )
 
         # build transformer
         self._trans_att = super_core.SuperQKVAttentionV2(
-            qk_att_dim=time_embedding,
-            in_v_dim=time_embedding,
-            hidden_dim=time_embedding,
+            qk_att_dim=time_dim,
+            in_v_dim=time_dim,
+            hidden_dim=time_dim,
             num_heads=4,
-            proj_dim=time_embedding,
+            proj_dim=time_dim,
             qkv_bias=True,
             attn_drop=None,
             proj_drop=dropout,
         )
-        """
-        self._trans_att = super_core.SuperQKVAttention(
-            time_embedding,
-            time_embedding,
-            time_embedding,
-            time_embedding,
-            num_heads=4,
-            qkv_bias=True,
-            attn_drop=None,
-            proj_drop=dropout,
-        )
-        """
         layers = []
         for ilayer in range(mha_depth):
             layers.append(
                 super_core.SuperTransformerEncoderLayer(
-                    time_embedding * 2,
+                    time_dim * 2,
                     4,
                     True,
                     4,
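Note: besides the layer_embedding/time_embedding to layer_dim/time_dim renames, this hunk changes the dynamic positional embedding from a fixed scale=500 to scale=1 / interval, which ties the encoding frequency to the environment's sampling interval so neighboring timestamps land roughly one unit apart regardless of absolute time units. A generic sinusoidal sketch of scaled scalar encoding (this is not the actual SuperDynamicPositionE implementation, just the standard construction it presumably follows):

import torch

def scalar_position_embed(x: torch.Tensor, dim: int, scale: float) -> torch.Tensor:
    # Encode each scalar in `x` as `dim` sinusoidal features; `scale` normalizes
    # the raw value before the usual frequency ladder is applied.
    freqs = torch.exp(torch.arange(0, dim, 2).float() * (-torch.log(torch.tensor(10000.0)) / dim))
    angles = x.unsqueeze(-1) * scale * freqs  # (..., dim // 2)
    return torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)  # (..., dim)

interval = 0.1
emb = scalar_position_embed(torch.tensor([0.0, 0.1, 0.2]), dim=8, scale=1 / interval)
print(emb.shape)  # torch.Size([3, 8])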
@@ -95,14 +83,14 @@ class LFNA_Meta(super_core.SuperModule):
                     use_mask=True,
                 )
             )
-        layers.append(super_core.SuperLinear(time_embedding * 2, time_embedding))
+        layers.append(super_core.SuperLinear(time_dim * 2, time_dim))
         self._meta_corrector = super_core.SuperSequential(*layers)
 
         model_kwargs = dict(
             config=dict(model_type="dual_norm_mlp"),
-            input_dim=layer_embedding + time_embedding,
+            input_dim=layer_dim + time_dim,
             output_dim=max(self._numel_per_layer),
-            hidden_dims=[(layer_embedding + time_embedding) * 2] * 3,
+            hidden_dims=[(layer_dim + time_dim) * 2] * 3,
             act_cls="gelu",
             norm_cls="layer_norm_1d",
             dropout=dropout,
@@ -193,11 +181,6 @@ class LFNA_Meta(super_core.SuperModule):
         # timestamps is a batch of sequence of timestamps
         batch, seq = timestamps.shape
         meta_timestamps, meta_embeds = self.meta_timestamps, self.super_meta_embed
-        """
-        timestamp_q_embed = self._tscalar_embed(timestamps)
-        timestamp_k_embed = self._tscalar_embed(meta_timestamps.view(1, -1))
-        timestamp_v_embed = meta_embeds.unsqueeze(dim=0)
-        """
         timestamp_v_embed = meta_embeds.unsqueeze(dim=0)
         timestamp_qk_att_embed = self._tscalar_embed(
             torch.unsqueeze(timestamps, dim=-1) - meta_timestamps
@ -212,7 +195,6 @@ class LFNA_Meta(super_core.SuperModule):
|
||||
> self._thresh
|
||||
)
|
||||
timestamp_embeds = self._trans_att(
|
||||
# timestamp_q_embed, timestamp_k_embed, timestamp_v_embed, mask
|
||||
timestamp_qk_att_embed,
|
||||
timestamp_v_embed,
|
||||
mask,
|
||||
|
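Note: the forward-pass edits mirror the constructor. The deleted triple-quoted path embedded query and key timestamps separately, while SuperQKVAttentionV2 consumes a single embedding of the pairwise time differences, torch.unsqueeze(timestamps, dim=-1) - meta_timestamps, of shape (batch, seq, num_meta). A rough sketch of that difference tensor and its distance mask in plain PyTorch (shapes follow the diff; the threshold value is illustrative):

import torch

batch, seq, num_meta = 2, 4, 8
timestamps = torch.rand(batch, seq)      # query timestamps
meta_timestamps = torch.rand(num_meta)   # learned anchor timestamps
thresh = 0.5                             # illustrative cutoff, like self._thresh

# Pairwise differences: one scalar per (query timestamp, meta timestamp) pair.
diffs = torch.unsqueeze(timestamps, dim=-1) - meta_timestamps  # (batch, seq, num_meta)

# Mask out meta anchors whose timestamps are too far from the query time.
mask = diffs.abs() > thresh              # True entries are ignored by the attention

print(diffs.shape, mask.shape)           # torch.Size([2, 4, 8]) twice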
@@ -21,6 +21,8 @@ class DynamicGenerator(abc.ABC):
 
 
 class GaussianDGenerator(DynamicGenerator):
+    """Generate data from Gaussian distribution."""
+
     def __init__(self, mean_functors, cov_functors, trunc=(-1, 1)):
         super(GaussianDGenerator, self).__init__()
         self._ndim = assert_list_tuple(mean_functors)
@@ -41,6 +43,10 @@ class GaussianDGenerator(DynamicGenerator):
         assert assert_list_tuple(trunc) == 2 and trunc[0] < trunc[1]
         self._trunc = trunc
 
+    @property
+    def ndim(self):
+        return self._ndim
+
     def __call__(self, time, num):
         mean_list = [functor(time) for functor in self._mean_functors]
         cov_matrix = [
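Note: the generator gains a read-only ndim property so callers can query dimensionality without touching the private _ndim field. A toy illustration of the same pattern (simplified constructor; the real class derives _ndim from mean_functors via assert_list_tuple):

class ToyGenerator:
    def __init__(self, ndim: int):
        self._ndim = ndim  # private; derived from the mean functors in the real class

    @property
    def ndim(self):
        # Read-only view of the generator's dimensionality.
        return self._ndim

gen = ToyGenerator(3)
print(gen.ndim)  # 3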
@@ -115,7 +115,7 @@ class SyntheticDEnv(data.Dataset):
             name=self.__class__.__name__,
             cur_num=len(self),
             total=len(self._time_generator),
-            ndim=self._ndim,
+            ndim=self._data_generator.ndim,
             num_per_task=self._num_per_task,
             xrange_min=self.min_timestamp,
             xrange_max=self.max_timestamp,
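Note: this final hunk is the fix the new ndim property enables: SyntheticDEnv apparently no longer caches its own _ndim, so its repr delegates to the data generator. A sketch of that delegation (class and field names beyond what the diff shows are assumptions):

class ToyGenerator:
    @property
    def ndim(self):
        return 3  # fixed dimensionality, for illustration only

class ToyEnv:
    def __init__(self, data_generator):
        self._data_generator = data_generator

    def __repr__(self):
        # Delegate ndim to the generator instead of caching a copy on the env.
        return "{:}(ndim={:})".format(self.__class__.__name__, self._data_generator.ndim)

print(ToyEnv(ToyGenerator()))  # ToyEnv(ndim=3)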