Fix bugs in new env-v1
parent 3ee0d348af
commit 3a2af8e55a
@@ -28,7 +28,7 @@ from xautodl.utils import split_str2indexes
 
 from xautodl.procedures.advanced_main import basic_train_fn, basic_eval_fn
 from xautodl.procedures.metric_utils import SaveMetric, MSEMetric, ComposeMetric
-from xautodl.datasets.synthetic_core import get_synthetic_env, EnvSampler
+from xautodl.datasets.synthetic_core import get_synthetic_env
 from xautodl.models.xcore import get_model
 from xautodl.xlayers import super_core, trunc_normal_
 
@@ -244,7 +244,7 @@ def main(args):
         args.time_dim,
         timestamps,
         seq_length=args.seq_length,
-        interval=train_env.timestamp_interval,
+        interval=train_env.time_interval,
     )
     meta_model = meta_model.to(args.device)
 
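Note: timestamp_interval was renamed to time_interval on the synthetic environment, so every call site in this file has to change in lockstep (here and in the two hunks below). A sketch of how such a rename can keep old callers alive during a transition; the SyntheticDEnv internals are not part of this diff, so the class body below is an assumption:

    class SyntheticDEnvSketch:
        # hypothetical stand-in; assumes evenly spaced timestamps
        def __init__(self, min_t, max_t, num):
            self._min_t, self._max_t, self._num = min_t, max_t, num

        @property
        def time_interval(self):
            # spacing between two consecutive timestamps
            return (self._max_t - self._min_t) / self._num

        # old spelling kept as a deprecated alias so stale callers keep working
        timestamp_interval = time_interval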
@@ -253,7 +253,7 @@ def main(args):
     logger.log("The base-model is\n{:}".format(base_model))
     logger.log("The meta-model is\n{:}".format(meta_model))
 
-    batch_sampler = EnvSampler(train_env, args.meta_batch, args.sampler_enlarge)
+    # batch_sampler = EnvSampler(train_env, args.meta_batch, args.sampler_enlarge)
     pretrain_v2(base_model, meta_model, criterion, train_env, args, logger)
 
     # try to evaluate once
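Note: with this line commented out (and the EnvSampler import dropped in the first hunk), pretrain_v2 now consumes train_env directly instead of going through a batch sampler. For orientation only, a hypothetical reading of what the disabled sampler did, inferred purely from the call signature EnvSampler(train_env, args.meta_batch, args.sampler_enlarge); the real class lives in xautodl.datasets.synthetic_core:

    import random

    class EnvSamplerSketch:
        # hypothetical stand-in for EnvSampler, not the real implementation
        def __init__(self, env, batch, enlarge):
            self._indexes = list(range(len(env)))
            self._batch = batch
            self._enlarge = enlarge  # stretch one pass over the env by this factor

        def __iter__(self):
            total = len(self._indexes) * self._enlarge
            for _ in range(total // self._batch):
                # each step yields a batch of random timestamp indexes
                yield random.sample(self._indexes, self._batch)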
@@ -387,7 +387,7 @@ def main(args):
         future_time = env_info["{:}-timestamp".format(idx)].item()
         time_seqs = []
         for iseq in range(args.seq_length):
-            time_seqs.append(future_time - iseq * eval_env.timestamp_interval)
+            time_seqs.append(future_time - iseq * eval_env.time_interval)
         time_seqs.reverse()
         with torch.no_grad():
             meta_model.eval()
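The loop builds a backward-looking window of seq_length timestamps that ends at future_time, and the reverse() puts it in chronological order. A worked example with made-up numbers:

    future_time, seq_length, time_interval = 1.0, 4, 0.1  # illustrative values

    time_seqs = []
    for iseq in range(seq_length):
        time_seqs.append(future_time - iseq * time_interval)
    time_seqs.reverse()

    print(time_seqs)  # ~[0.7, 0.8, 0.9, 1.0]: oldest first, up to float rounding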
@@ -409,7 +409,7 @@ def main(args):
 
         # creating the new meta-time-embedding
         distance = meta_model.get_closest_meta_distance(future_time)
-        if distance < eval_env.timestamp_interval:
+        if distance < eval_env.time_interval:
             continue
         #
         new_param = meta_model.create_meta_embed()
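Note: the guard skips creating a fresh meta embedding whenever the queried time falls within one environment interval of an already-known meta timestamp. A minimal sketch of what get_closest_meta_distance plausibly computes; the real method is defined on LFNA_Meta and is not shown in this diff:

    import torch

    def closest_meta_distance(meta_timestamps, future_time):
        # gap between the query time and the nearest known meta timestamp
        return torch.min(torch.abs(meta_timestamps - future_time)).item()

    meta_ts = torch.Tensor([0.0, 0.1, 0.2, 0.3])
    print(closest_meta_distance(meta_ts, 0.34))  # ~0.04 -> reuse the nearest embedding
    print(closest_meta_distance(meta_ts, 0.55))  # 0.25  -> create a new one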
@@ -16,8 +16,8 @@ class LFNA_Meta(super_core.SuperModule):
     def __init__(
         self,
         shape_container,
-        layer_embedding,
-        time_embedding,
+        layer_dim,
+        time_dim,
         meta_timestamps,
         mha_depth: int = 2,
         dropout: float = 0.1,
@@ -39,53 +39,41 @@ class LFNA_Meta(super_core.SuperModule):
 
         self.register_parameter(
             "_super_layer_embed",
-            torch.nn.Parameter(torch.Tensor(self._num_layers, layer_embedding)),
+            torch.nn.Parameter(torch.Tensor(self._num_layers, layer_dim)),
         )
         self.register_parameter(
             "_super_meta_embed",
-            torch.nn.Parameter(torch.Tensor(len(meta_timestamps), time_embedding)),
+            torch.nn.Parameter(torch.Tensor(len(meta_timestamps), time_dim)),
         )
         self.register_buffer("_meta_timestamps", torch.Tensor(meta_timestamps))
         # register a time difference buffer
         time_interval = [-i * self._interval for i in range(self._seq_length)]
         time_interval.reverse()
         self.register_buffer("_time_interval", torch.Tensor(time_interval))
-        self._time_embed_dim = time_embedding
+        self._time_embed_dim = time_dim
         self._append_meta_embed = dict(fixed=None, learnt=None)
         self._append_meta_timestamps = dict(fixed=None, learnt=None)
 
         self._tscalar_embed = super_core.SuperDynamicPositionE(
-            time_embedding, scale=500
+            time_dim, scale=1 / interval
         )
 
         # build transformer
         self._trans_att = super_core.SuperQKVAttentionV2(
-            qk_att_dim=time_embedding,
-            in_v_dim=time_embedding,
-            hidden_dim=time_embedding,
+            qk_att_dim=time_dim,
+            in_v_dim=time_dim,
+            hidden_dim=time_dim,
             num_heads=4,
-            proj_dim=time_embedding,
+            proj_dim=time_dim,
             qkv_bias=True,
             attn_drop=None,
             proj_drop=dropout,
         )
-        """
-        self._trans_att = super_core.SuperQKVAttention(
-            time_embedding,
-            time_embedding,
-            time_embedding,
-            time_embedding,
-            num_heads=4,
-            qkv_bias=True,
-            attn_drop=None,
-            proj_drop=dropout,
-        )
-        """
         layers = []
         for ilayer in range(mha_depth):
             layers.append(
                 super_core.SuperTransformerEncoderLayer(
-                    time_embedding * 2,
+                    time_dim * 2,
                     4,
                     True,
                     4,
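Note: two substantive changes hide among the renames here: the commented-out SuperQKVAttention alternative is deleted outright, and the position-embedding scale goes from the fixed 500 to 1 / interval, which makes consecutive meta timestamps sit about one unit apart regardless of how densely the environment samples time. A minimal sketch of a sinusoidal scalar-time embedding, assuming SuperDynamicPositionE multiplies its input by scale before the usual sin/cos expansion (the real implementation is in xautodl.xlayers and may differ):

    import math
    import torch

    def dynamic_position_embed(t, dim, scale):
        # with scale = 1 / interval, adjacent timestamps map to positions ~1 apart
        t = t * scale
        freqs = torch.exp(
            torch.arange(0, dim, 2, dtype=torch.float32) * (-math.log(10000.0) / dim)
        )
        angles = t.unsqueeze(-1) * freqs  # (..., dim // 2)
        return torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)  # (..., dim)

    emb = dynamic_position_embed(torch.Tensor([0.0, 0.1, 0.2]), dim=16, scale=1 / 0.1)
    print(emb.shape)  # torch.Size([3, 16])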
@@ -95,14 +83,14 @@ class LFNA_Meta(super_core.SuperModule):
                     use_mask=True,
                 )
             )
-        layers.append(super_core.SuperLinear(time_embedding * 2, time_embedding))
+        layers.append(super_core.SuperLinear(time_dim * 2, time_dim))
         self._meta_corrector = super_core.SuperSequential(*layers)
 
         model_kwargs = dict(
             config=dict(model_type="dual_norm_mlp"),
-            input_dim=layer_embedding + time_embedding,
+            input_dim=layer_dim + time_dim,
             output_dim=max(self._numel_per_layer),
-            hidden_dims=[(layer_embedding + time_embedding) * 2] * 3,
+            hidden_dims=[(layer_dim + time_dim) * 2] * 3,
             act_cls="gelu",
             norm_cls="layer_norm_1d",
             dropout=dropout,
@@ -193,11 +181,6 @@ class LFNA_Meta(super_core.SuperModule):
         # timestamps is a batch of sequence of timestamps
         batch, seq = timestamps.shape
         meta_timestamps, meta_embeds = self.meta_timestamps, self.super_meta_embed
-        """
-        timestamp_q_embed = self._tscalar_embed(timestamps)
-        timestamp_k_embed = self._tscalar_embed(meta_timestamps.view(1, -1))
-        timestamp_v_embed = meta_embeds.unsqueeze(dim=0)
-        """
         timestamp_v_embed = meta_embeds.unsqueeze(dim=0)
         timestamp_qk_att_embed = self._tscalar_embed(
             torch.unsqueeze(timestamps, dim=-1) - meta_timestamps
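Note: the deleted block embedded query and key timestamps separately; the surviving path instead embeds the pairwise gap t_query - t_meta and hands it to SuperQKVAttentionV2 as the attention-score input, with the learned meta embeddings acting as values. A rough shape-level sketch under those assumptions; the Linear layer merely stands in for the real embed-and-project step:

    import torch

    batch, seq, num_meta, time_dim = 2, 5, 10, 16
    timestamps = torch.rand(batch, seq)                        # query times
    meta_timestamps = torch.sort(torch.rand(num_meta)).values  # key times
    meta_embeds = torch.randn(num_meta, time_dim)              # values

    diff = timestamps.unsqueeze(-1) - meta_timestamps          # (batch, seq, num_meta)
    score_head = torch.nn.Linear(1, 1)                         # stand-in scorer
    scores = score_head(diff.unsqueeze(-1)).squeeze(-1)        # (batch, seq, num_meta)
    # the real forward also masks entries based on the time gap vs. self._thresh
    attn = torch.softmax(scores, dim=-1)
    out = attn @ meta_embeds                                   # (batch, seq, time_dim)
    print(out.shape)  # torch.Size([2, 5, 16])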
@@ -212,7 +195,6 @@ class LFNA_Meta(super_core.SuperModule):
             > self._thresh
         )
         timestamp_embeds = self._trans_att(
-            # timestamp_q_embed, timestamp_k_embed, timestamp_v_embed, mask
             timestamp_qk_att_embed,
             timestamp_v_embed,
             mask,
@@ -21,6 +21,8 @@ class DynamicGenerator(abc.ABC):
 
 
 class GaussianDGenerator(DynamicGenerator):
+    """Generate data from Gaussian distribution."""
+
     def __init__(self, mean_functors, cov_functors, trunc=(-1, 1)):
         super(GaussianDGenerator, self).__init__()
         self._ndim = assert_list_tuple(mean_functors)
@@ -41,6 +43,10 @@ class GaussianDGenerator(DynamicGenerator):
         assert assert_list_tuple(trunc) == 2 and trunc[0] < trunc[1]
         self._trunc = trunc
 
+    @property
+    def ndim(self):
+        return self._ndim
+
     def __call__(self, time, num):
         mean_list = [functor(time) for functor in self._mean_functors]
         cov_matrix = [
@@ -115,7 +115,7 @@ class SyntheticDEnv(data.Dataset):
             name=self.__class__.__name__,
             cur_num=len(self),
             total=len(self._time_generator),
-            ndim=self._ndim,
+            ndim=self._data_generator.ndim,
             num_per_task=self._num_per_task,
             xrange_min=self.min_timestamp,
             xrange_max=self.max_timestamp,
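Note: this pairs with the new ndim property above: SyntheticDEnv stops reading a cached self._ndim and instead asks its data generator, keeping a single source of truth for the dimensionality. A minimal sketch of the delegation pattern, with class names abbreviated and attributes inferred from the hunks:

    class GeneratorSketch:
        def __init__(self, mean_functors):
            self._ndim = len(mean_functors)

        @property
        def ndim(self):
            # read-only view of the dimensionality
            return self._ndim

    class EnvSketch:
        def __init__(self, data_generator):
            self._data_generator = data_generator

        def __repr__(self):
            return "EnvSketch(ndim={:})".format(self._data_generator.ndim)

    print(EnvSketch(GeneratorSketch([min, max])))  # EnvSketch(ndim=2)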