Fix bugs in new env-v1
This commit is contained in:
		| @@ -28,7 +28,7 @@ from xautodl.utils import split_str2indexes | ||||
|  | ||||
| from xautodl.procedures.advanced_main import basic_train_fn, basic_eval_fn | ||||
| from xautodl.procedures.metric_utils import SaveMetric, MSEMetric, ComposeMetric | ||||
| from xautodl.datasets.synthetic_core import get_synthetic_env, EnvSampler | ||||
| from xautodl.datasets.synthetic_core import get_synthetic_env | ||||
| from xautodl.models.xcore import get_model | ||||
| from xautodl.xlayers import super_core, trunc_normal_ | ||||
|  | ||||
| @@ -244,7 +244,7 @@ def main(args): | ||||
|         args.time_dim, | ||||
|         timestamps, | ||||
|         seq_length=args.seq_length, | ||||
|         interval=train_env.timestamp_interval, | ||||
|         interval=train_env.time_interval, | ||||
|     ) | ||||
|     meta_model = meta_model.to(args.device) | ||||
|  | ||||
| @@ -253,7 +253,7 @@ def main(args): | ||||
|     logger.log("The base-model is\n{:}".format(base_model)) | ||||
|     logger.log("The meta-model is\n{:}".format(meta_model)) | ||||
|  | ||||
|     batch_sampler = EnvSampler(train_env, args.meta_batch, args.sampler_enlarge) | ||||
|     # batch_sampler = EnvSampler(train_env, args.meta_batch, args.sampler_enlarge) | ||||
|     pretrain_v2(base_model, meta_model, criterion, train_env, args, logger) | ||||
|  | ||||
|     # try to evaluate once | ||||
| @@ -387,7 +387,7 @@ def main(args): | ||||
|         future_time = env_info["{:}-timestamp".format(idx)].item() | ||||
|         time_seqs = [] | ||||
|         for iseq in range(args.seq_length): | ||||
|             time_seqs.append(future_time - iseq * eval_env.timestamp_interval) | ||||
|             time_seqs.append(future_time - iseq * eval_env.time_interval) | ||||
|         time_seqs.reverse() | ||||
|         with torch.no_grad(): | ||||
|             meta_model.eval() | ||||
| @@ -409,7 +409,7 @@ def main(args): | ||||
|  | ||||
|         # creating the new meta-time-embedding | ||||
|         distance = meta_model.get_closest_meta_distance(future_time) | ||||
|         if distance < eval_env.timestamp_interval: | ||||
|         if distance < eval_env.time_interval: | ||||
|             continue | ||||
|         # | ||||
|         new_param = meta_model.create_meta_embed() | ||||
|   | ||||
| @@ -16,8 +16,8 @@ class LFNA_Meta(super_core.SuperModule): | ||||
|     def __init__( | ||||
|         self, | ||||
|         shape_container, | ||||
|         layer_embedding, | ||||
|         time_embedding, | ||||
|         layer_dim, | ||||
|         time_dim, | ||||
|         meta_timestamps, | ||||
|         mha_depth: int = 2, | ||||
|         dropout: float = 0.1, | ||||
| @@ -39,53 +39,41 @@ class LFNA_Meta(super_core.SuperModule): | ||||
|  | ||||
|         self.register_parameter( | ||||
|             "_super_layer_embed", | ||||
|             torch.nn.Parameter(torch.Tensor(self._num_layers, layer_embedding)), | ||||
|             torch.nn.Parameter(torch.Tensor(self._num_layers, layer_dim)), | ||||
|         ) | ||||
|         self.register_parameter( | ||||
|             "_super_meta_embed", | ||||
|             torch.nn.Parameter(torch.Tensor(len(meta_timestamps), time_embedding)), | ||||
|             torch.nn.Parameter(torch.Tensor(len(meta_timestamps), time_dim)), | ||||
|         ) | ||||
|         self.register_buffer("_meta_timestamps", torch.Tensor(meta_timestamps)) | ||||
|         # register a time difference buffer | ||||
|         time_interval = [-i * self._interval for i in range(self._seq_length)] | ||||
|         time_interval.reverse() | ||||
|         self.register_buffer("_time_interval", torch.Tensor(time_interval)) | ||||
|         self._time_embed_dim = time_embedding | ||||
|         self._time_embed_dim = time_dim | ||||
|         self._append_meta_embed = dict(fixed=None, learnt=None) | ||||
|         self._append_meta_timestamps = dict(fixed=None, learnt=None) | ||||
|  | ||||
|         self._tscalar_embed = super_core.SuperDynamicPositionE( | ||||
|             time_embedding, scale=500 | ||||
|             time_dim, scale=1 / interval | ||||
|         ) | ||||
|  | ||||
|         # build transformer | ||||
|         self._trans_att = super_core.SuperQKVAttentionV2( | ||||
|             qk_att_dim=time_embedding, | ||||
|             in_v_dim=time_embedding, | ||||
|             hidden_dim=time_embedding, | ||||
|             qk_att_dim=time_dim, | ||||
|             in_v_dim=time_dim, | ||||
|             hidden_dim=time_dim, | ||||
|             num_heads=4, | ||||
|             proj_dim=time_embedding, | ||||
|             proj_dim=time_dim, | ||||
|             qkv_bias=True, | ||||
|             attn_drop=None, | ||||
|             proj_drop=dropout, | ||||
|         ) | ||||
|         """ | ||||
|         self._trans_att = super_core.SuperQKVAttention( | ||||
|             time_embedding, | ||||
|             time_embedding, | ||||
|             time_embedding, | ||||
|             time_embedding, | ||||
|             num_heads=4, | ||||
|             qkv_bias=True, | ||||
|             attn_drop=None, | ||||
|             proj_drop=dropout, | ||||
|         ) | ||||
|         """ | ||||
|         layers = [] | ||||
|         for ilayer in range(mha_depth): | ||||
|             layers.append( | ||||
|                 super_core.SuperTransformerEncoderLayer( | ||||
|                     time_embedding * 2, | ||||
|                     time_dim * 2, | ||||
|                     4, | ||||
|                     True, | ||||
|                     4, | ||||
| @@ -95,14 +83,14 @@ class LFNA_Meta(super_core.SuperModule): | ||||
|                     use_mask=True, | ||||
|                 ) | ||||
|             ) | ||||
|         layers.append(super_core.SuperLinear(time_embedding * 2, time_embedding)) | ||||
|         layers.append(super_core.SuperLinear(time_dim * 2, time_dim)) | ||||
|         self._meta_corrector = super_core.SuperSequential(*layers) | ||||
|  | ||||
|         model_kwargs = dict( | ||||
|             config=dict(model_type="dual_norm_mlp"), | ||||
|             input_dim=layer_embedding + time_embedding, | ||||
|             input_dim=layer_dim + time_dim, | ||||
|             output_dim=max(self._numel_per_layer), | ||||
|             hidden_dims=[(layer_embedding + time_embedding) * 2] * 3, | ||||
|             hidden_dims=[(layer_dim + time_dim) * 2] * 3, | ||||
|             act_cls="gelu", | ||||
|             norm_cls="layer_norm_1d", | ||||
|             dropout=dropout, | ||||
| @@ -193,11 +181,6 @@ class LFNA_Meta(super_core.SuperModule): | ||||
|         # timestamps is a batch of sequence of timestamps | ||||
|         batch, seq = timestamps.shape | ||||
|         meta_timestamps, meta_embeds = self.meta_timestamps, self.super_meta_embed | ||||
|         """ | ||||
|         timestamp_q_embed = self._tscalar_embed(timestamps) | ||||
|         timestamp_k_embed = self._tscalar_embed(meta_timestamps.view(1, -1)) | ||||
|         timestamp_v_embed = meta_embeds.unsqueeze(dim=0) | ||||
|         """ | ||||
|         timestamp_v_embed = meta_embeds.unsqueeze(dim=0) | ||||
|         timestamp_qk_att_embed = self._tscalar_embed( | ||||
|             torch.unsqueeze(timestamps, dim=-1) - meta_timestamps | ||||
| @@ -212,7 +195,6 @@ class LFNA_Meta(super_core.SuperModule): | ||||
|             > self._thresh | ||||
|         ) | ||||
|         timestamp_embeds = self._trans_att( | ||||
|             # timestamp_q_embed, timestamp_k_embed, timestamp_v_embed, mask | ||||
|             timestamp_qk_att_embed, | ||||
|             timestamp_v_embed, | ||||
|             mask, | ||||
|   | ||||
| @@ -21,6 +21,8 @@ class DynamicGenerator(abc.ABC): | ||||
|  | ||||
|  | ||||
| class GaussianDGenerator(DynamicGenerator): | ||||
|     """Generate data from Gaussian distribution.""" | ||||
|  | ||||
|     def __init__(self, mean_functors, cov_functors, trunc=(-1, 1)): | ||||
|         super(GaussianDGenerator, self).__init__() | ||||
|         self._ndim = assert_list_tuple(mean_functors) | ||||
| @@ -41,6 +43,10 @@ class GaussianDGenerator(DynamicGenerator): | ||||
|             assert assert_list_tuple(trunc) == 2 and trunc[0] < trunc[1] | ||||
|         self._trunc = trunc | ||||
|  | ||||
|     @property | ||||
|     def ndim(self): | ||||
|         return self._ndim | ||||
|  | ||||
|     def __call__(self, time, num): | ||||
|         mean_list = [functor(time) for functor in self._mean_functors] | ||||
|         cov_matrix = [ | ||||
|   | ||||
| @@ -115,7 +115,7 @@ class SyntheticDEnv(data.Dataset): | ||||
|             name=self.__class__.__name__, | ||||
|             cur_num=len(self), | ||||
|             total=len(self._time_generator), | ||||
|             ndim=self._ndim, | ||||
|             ndim=self._data_generator.ndim, | ||||
|             num_per_task=self._num_per_task, | ||||
|             xrange_min=self.min_timestamp, | ||||
|             xrange_max=self.max_timestamp, | ||||
|   | ||||
		Reference in New Issue
	
	Block a user