Fix bugs
This commit is contained in:
		| @@ -20,7 +20,7 @@ class MetaModelV1(super_core.SuperModule): | ||||
|         time_dim, | ||||
|         meta_timestamps, | ||||
|         dropout: float = 0.1, | ||||
|         seq_length: int = 10, | ||||
|         seq_length: int = None, | ||||
|         interval: float = None, | ||||
|         thresh: float = None, | ||||
|     ): | ||||
| @@ -33,8 +33,7 @@ class MetaModelV1(super_core.SuperModule): | ||||
|         self._raw_meta_timestamps = meta_timestamps | ||||
|         assert interval is not None | ||||
|         self._interval = interval | ||||
|         self._seq_length = seq_length | ||||
|         self._thresh = interval * 50 if thresh is None else thresh | ||||
|         self._thresh = interval * seq_length if thresh is None else thresh | ||||
|  | ||||
|         self.register_parameter( | ||||
|             "_super_layer_embed", | ||||
| @@ -45,10 +44,6 @@ class MetaModelV1(super_core.SuperModule): | ||||
|             torch.nn.Parameter(torch.Tensor(len(meta_timestamps), time_dim)), | ||||
|         ) | ||||
|         self.register_buffer("_meta_timestamps", torch.Tensor(meta_timestamps)) | ||||
|         # register a time difference buffer | ||||
|         # time_interval = [-i * self._interval for i in range(self._seq_length)] | ||||
|         # time_interval.reverse() | ||||
|         # self.register_buffer("_time_interval", torch.Tensor(time_interval)) | ||||
|         self._time_embed_dim = time_dim | ||||
|         self._append_meta_embed = dict(fixed=None, learnt=None) | ||||
|         self._append_meta_timestamps = dict(fixed=None, learnt=None) | ||||
| @@ -186,7 +181,6 @@ class MetaModelV1(super_core.SuperModule): | ||||
|  | ||||
|     def forward_raw(self, timestamps, time_embeds, tembed_only=False): | ||||
|         if time_embeds is None: | ||||
|             # time_seq = timestamps.view(-1, 1) + self._time_interval.view(1, -1) | ||||
|             [B] = timestamps.shape | ||||
|             time_embeds = self._obtain_time_embed(timestamps) | ||||
|         else:  # use the hyper-net only | ||||
| @@ -210,7 +204,7 @@ class MetaModelV1(super_core.SuperModule): | ||||
|             batch_containers.append( | ||||
|                 self._shape_container.translate(torch.split(weights.squeeze(0), 1)) | ||||
|             ) | ||||
|         return time_seq, batch_containers, time_embeds | ||||
|         return batch_containers, time_embeds | ||||
|  | ||||
|     def forward_candidate(self, input): | ||||
|         raise NotImplementedError | ||||
| @@ -239,10 +233,10 @@ class MetaModelV1(super_core.SuperModule): | ||||
|                 best_new_param = new_param.detach().clone() | ||||
|             for iepoch in range(epochs): | ||||
|                 optimizer.zero_grad() | ||||
|                 _, [_], time_embed = self(timestamp.view(1, 1), None) | ||||
|                 _, time_embed = self(timestamp.view(1), None) | ||||
|                 match_loss = criterion(new_param, time_embed) | ||||
|  | ||||
|                 _, [container], time_embed = self(None, new_param.view(1, -1)) | ||||
|                 [container], time_embed = self(None, new_param.view(1, -1)) | ||||
|                 y_hat = base_model.forward_with_container(x, container) | ||||
|                 meta_loss = criterion(y_hat, y) | ||||
|                 loss = meta_loss + match_loss | ||||
|   | ||||
| @@ -46,8 +46,8 @@ def online_evaluate(env, meta_model, base_model, criterion, args, logger, save=F | ||||
|         with torch.no_grad(): | ||||
|             meta_model.eval() | ||||
|             base_model.eval() | ||||
|             _, [future_container], time_embeds = meta_model( | ||||
|                 future_time.to(args.device).view(1, 1), None, False | ||||
|             [future_container], time_embeds = meta_model( | ||||
|                 future_time.to(args.device).view(-1), None, False | ||||
|             ) | ||||
|             if save: | ||||
|                 w_containers[idx] = future_container.no_grad_clone() | ||||
| @@ -117,10 +117,10 @@ def meta_train_procedure(base_model, meta_model, criterion, xenv, args, logger): | ||||
|         ) | ||||
|         # future loss | ||||
|         total_future_losses, total_present_losses = [], [] | ||||
|         _, future_containers, _ = meta_model( | ||||
|         future_containers, _ = meta_model( | ||||
|             None, generated_time_embeds[batch_indexes], False | ||||
|         ) | ||||
|         _, present_containers, _ = meta_model( | ||||
|         present_containers, _ = meta_model( | ||||
|             None, meta_model.super_meta_embed[batch_indexes], False | ||||
|         ) | ||||
|         for ibatch, time_step in enumerate(raw_time_steps.cpu().tolist()): | ||||
|   | ||||
| @@ -1,6 +1,3 @@ | ||||
| ##################################################### | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 # | ||||
| ##################################################### | ||||
| import math | ||||
| import abc | ||||
| import numpy as np | ||||
| @@ -13,11 +10,11 @@ class UnifiedSplit: | ||||
|     """A class to unify the split strategy.""" | ||||
|  | ||||
|     def __init__(self, total_num, mode): | ||||
|         # Training Set 65% | ||||
|         num_of_train = int(total_num * 0.65) | ||||
|         # Training Set 75% | ||||
|         num_of_train = int(total_num * 0.75) | ||||
|         # Validation Set 05% | ||||
|         num_of_valid = int(total_num * 0.05) | ||||
|         # Test Set 30% | ||||
|         # Test Set 20% | ||||
|         num_of_set = total_num - num_of_train - num_of_valid | ||||
|         all_indexes = list(range(total_num)) | ||||
|         if mode is None: | ||||
|   | ||||
| @@ -1,27 +1,32 @@ | ||||
| ##################################################### | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.01 # | ||||
| ##################################################### | ||||
| import torch | ||||
| import torch.nn as nn | ||||
| import numpy as np | ||||
|  | ||||
|  | ||||
| def count_parameters_in_MB(model): | ||||
|     return count_parameters(model, "mb") | ||||
|     return count_parameters(model, "mb", deprecated=True) | ||||
|  | ||||
|  | ||||
| def count_parameters(model_or_parameters, unit="mb"): | ||||
| def count_parameters(model_or_parameters, unit="mb", deprecated=False): | ||||
|     if isinstance(model_or_parameters, nn.Module): | ||||
|         counts = sum(np.prod(v.size()) for v in model_or_parameters.parameters()) | ||||
|     elif isinstance(model_or_parameters, nn.Parameter): | ||||
|         counts = models_or_parameters.numel() | ||||
|     elif isinstance(model_or_parameters, (list, tuple)): | ||||
|         counts = sum(count_parameters(x, None) for x in models_or_parameters) | ||||
|         counts = sum( | ||||
|             count_parameters(x, None, deprecated) for x in models_or_parameters | ||||
|         ) | ||||
|     else: | ||||
|         counts = sum(np.prod(v.size()) for v in model_or_parameters) | ||||
|     if unit.lower() == "kb" or unit.lower() == "k": | ||||
|         counts /= 2 ** 10  # changed from 1e3 to 2^10 | ||||
|         counts /= 1e3 if deprecated else 2 ** 10  # changed from 1e3 to 2^10 | ||||
|     elif unit.lower() == "mb" or unit.lower() == "m": | ||||
|         counts /= 2 ** 20  # changed from 1e6 to 2^20 | ||||
|         counts /= 1e6 if deprecated else 2 ** 20  # changed from 1e6 to 2^20 | ||||
|     elif unit.lower() == "gb" or unit.lower() == "g": | ||||
|         counts /= 2 ** 30  # changed from 1e9 to 2^30 | ||||
|         counts /= 1e9 if deprecated else 2 ** 30  # changed from 1e9 to 2^30 | ||||
|     elif unit is not None: | ||||
|         raise ValueError("Unknow unit: {:}".format(unit)) | ||||
|     return counts | ||||
|   | ||||
		Reference in New Issue
	
	Block a user