Fix bugs

- MetaModelV1: require seq_length from the caller and default thresh to
  interval * seq_length instead of the hard-coded interval * 50; drop the
  unused _seq_length attribute and the commented-out _time_interval buffer.
- MetaModelV1.forward_raw: stop returning time_seq (its computation was
  already commented out), so callers now unpack (containers, time_embeds)
  and pass timestamps as a 1-D tensor.
- UnifiedSplit: change the train/valid/test split from 65%/5%/30% to
  75%/5%/20%.
- count_parameters: convert with binary units (2**10/2**20/2**30) by
  default, keeping the old decimal divisors behind a deprecated flag used
  by count_parameters_in_MB.
@@ -20,7 +20,7 @@ class MetaModelV1(super_core.SuperModule):
         time_dim,
         meta_timestamps,
         dropout: float = 0.1,
-        seq_length: int = 10,
+        seq_length: int = None,
         interval: float = None,
         thresh: float = None,
     ):
@@ -33,8 +33,7 @@ class MetaModelV1(super_core.SuperModule):
         self._raw_meta_timestamps = meta_timestamps
         assert interval is not None
         self._interval = interval
-        self._seq_length = seq_length
-        self._thresh = interval * 50 if thresh is None else thresh
+        self._thresh = interval * seq_length if thresh is None else thresh
 
         self.register_parameter(
             "_super_layer_embed",
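With this change, an unspecified thresh scales with the caller-supplied seq_length rather than an arbitrary window of 50 intervals. A minimal standalone sketch of the new defaulting rule (resolve_thresh is a hypothetical helper, not part of the commit):

```python
def resolve_thresh(interval: float, seq_length: int, thresh: float = None) -> float:
    # mirrors the constructor: thresh defaults to one full sequence span
    assert interval is not None
    return interval * seq_length if thresh is None else thresh

assert resolve_thresh(0.5, 10) == 5.0       # derived: interval * seq_length
assert resolve_thresh(0.5, 10, 2.5) == 2.5  # an explicit thresh still wins
```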
@@ -45,10 +44,6 @@ class MetaModelV1(super_core.SuperModule):
             torch.nn.Parameter(torch.Tensor(len(meta_timestamps), time_dim)),
         )
         self.register_buffer("_meta_timestamps", torch.Tensor(meta_timestamps))
-        # register a time difference buffer
-        # time_interval = [-i * self._interval for i in range(self._seq_length)]
-        # time_interval.reverse()
-        # self.register_buffer("_time_interval", torch.Tensor(time_interval))
         self._time_embed_dim = time_dim
         self._append_meta_embed = dict(fixed=None, learnt=None)
         self._append_meta_timestamps = dict(fixed=None, learnt=None)
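The constructor leans on a PyTorch distinction worth spelling out: register_parameter creates trainable weights, while register_buffer (used for _meta_timestamps, and formerly for the now-deleted _time_interval) stores non-trainable state that is still saved in the state dict and moved by .to(device). A minimal illustration (the Demo class is hypothetical):

```python
import torch

class Demo(torch.nn.Module):
    def __init__(self, timestamps):
        super().__init__()
        # trainable: shows up in .parameters() and receives gradients
        self.register_parameter(
            "embed", torch.nn.Parameter(torch.zeros(len(timestamps), 4))
        )
        # non-trainable state: saved in state_dict(), moved by .to(), no gradients
        self.register_buffer("stamps", torch.tensor(timestamps))

demo = Demo([0.0, 1.0, 2.0])
assert sum(p.numel() for p in demo.parameters()) == 12  # only the Parameter counts
```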
@@ -186,7 +181,6 @@ class MetaModelV1(super_core.SuperModule):
 
     def forward_raw(self, timestamps, time_embeds, tembed_only=False):
         if time_embeds is None:
-            # time_seq = timestamps.view(-1, 1) + self._time_interval.view(1, -1)
             [B] = timestamps.shape
             time_embeds = self._obtain_time_embed(timestamps)
         else:  # use the hyper-net only
@@ -210,7 +204,7 @@ class MetaModelV1(super_core.SuperModule):
             batch_containers.append(
                 self._shape_container.translate(torch.split(weights.squeeze(0), 1))
             )
-        return time_seq, batch_containers, time_embeds
+        return batch_containers, time_embeds
 
     def forward_candidate(self, input):
         raise NotImplementedError
@@ -239,10 +233,10 @@ class MetaModelV1(super_core.SuperModule):
                 best_new_param = new_param.detach().clone()
             for iepoch in range(epochs):
                 optimizer.zero_grad()
-                _, [_], time_embed = self(timestamp.view(1, 1), None)
+                _, time_embed = self(timestamp.view(1), None)
                 match_loss = criterion(new_param, time_embed)
 
-                _, [container], time_embed = self(None, new_param.view(1, -1))
+                [container], time_embed = self(None, new_param.view(1, -1))
                 y_hat = base_model.forward_with_container(x, container)
                 meta_loss = criterion(y_hat, y)
                 loss = meta_loss + match_loss
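Every call site now unpacks two values, (containers, time_embeds), instead of three. A toy stand-in for the new contract (TinyHyperNet is illustrative only, not the project's model):

```python
import torch

class TinyHyperNet(torch.nn.Module):
    """Mirrors forward_raw's new shape: (timestamps, time_embeds) -> (containers, time_embeds)."""

    def forward(self, timestamps, time_embeds=None):
        if time_embeds is None:
            [B] = timestamps.shape  # requires a 1-D batch of timestamps
            time_embeds = timestamps.view(B, 1).repeat(1, 4)
        containers = [e.clone() for e in time_embeds]  # one weight "container" per item
        return containers, time_embeds

model = TinyHyperNet()
[container], time_embeds = model(torch.zeros(1))  # new-style two-value unpacking
```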
@@ -46,8 +46,8 @@ def online_evaluate(env, meta_model, base_model, criterion, args, logger, save=F
         with torch.no_grad():
             meta_model.eval()
             base_model.eval()
-            _, [future_container], time_embeds = meta_model(
-                future_time.to(args.device).view(1, 1), None, False
+            [future_container], time_embeds = meta_model(
+                future_time.to(args.device).view(-1), None, False
             )
             if save:
                 w_containers[idx] = future_container.no_grad_clone()
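The switch from .view(1, 1) to .view(-1) matters because forward_raw now destructures [B] = timestamps.shape, which only succeeds on a 1-D tensor. A quick shape check:

```python
import torch

future_time = torch.tensor(3.0)                # a scalar timestamp
assert future_time.view(-1).shape == (1,)      # new: 1-D, so [B] = shape gives B == 1
assert future_time.view(1, 1).shape == (1, 1)  # old: 2-D, would now fail to unpack
```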
@@ -117,10 +117,10 @@ def meta_train_procedure(base_model, meta_model, criterion, xenv, args, logger):
         )
         # future loss
         total_future_losses, total_present_losses = [], []
-        _, future_containers, _ = meta_model(
+        future_containers, _ = meta_model(
             None, generated_time_embeds[batch_indexes], False
         )
-        _, present_containers, _ = meta_model(
+        present_containers, _ = meta_model(
             None, meta_model.super_meta_embed[batch_indexes], False
         )
         for ibatch, time_step in enumerate(raw_time_steps.cpu().tolist()):
@@ -1,6 +1,3 @@
-#####################################################
-# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 #
-#####################################################
 import math
 import abc
 import numpy as np
@@ -13,11 +10,11 @@ class UnifiedSplit:
     """A class to unify the split strategy."""
 
     def __init__(self, total_num, mode):
-        # Training Set 65%
-        num_of_train = int(total_num * 0.65)
+        # Training Set 75%
+        num_of_train = int(total_num * 0.75)
         # Validation Set 05%
         num_of_valid = int(total_num * 0.05)
-        # Test Set 30%
+        # Test Set 20%
         num_of_set = total_num - num_of_train - num_of_valid
         all_indexes = list(range(total_num))
         if mode is None:
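Only the train and validation sizes are computed directly; the test size is the floored remainder, so it lands at roughly 20%. A worked check with an illustrative total_num of 200:

```python
total_num = 200
num_of_train = int(total_num * 0.75)                   # 150 (75%)
num_of_valid = int(total_num * 0.05)                   # 10  (5%)
num_of_test = total_num - num_of_train - num_of_valid  # 40  (the remaining 20%)
assert (num_of_train, num_of_valid, num_of_test) == (150, 10, 40)
```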
@@ -1,27 +1,34 @@
+#####################################################
+# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.01 #
+#####################################################
 import torch
 import torch.nn as nn
 import numpy as np
 
 
 def count_parameters_in_MB(model):
-    return count_parameters(model, "mb")
+    return count_parameters(model, "mb", deprecated=True)
 
 
-def count_parameters(model_or_parameters, unit="mb"):
+def count_parameters(model_or_parameters, unit="mb", deprecated=False):
     if isinstance(model_or_parameters, nn.Module):
         counts = sum(np.prod(v.size()) for v in model_or_parameters.parameters())
     elif isinstance(model_or_parameters, nn.Parameter):
-        counts = models_or_parameters.numel()
+        counts = model_or_parameters.numel()
     elif isinstance(model_or_parameters, (list, tuple)):
-        counts = sum(count_parameters(x, None) for x in models_or_parameters)
+        counts = sum(
+            count_parameters(x, None, deprecated) for x in model_or_parameters
+        )
     else:
         counts = sum(np.prod(v.size()) for v in model_or_parameters)
-    if unit.lower() == "kb" or unit.lower() == "k":
-        counts /= 2 ** 10  # changed from 1e3 to 2^10
+    if unit is None:
+        pass  # no conversion; used by the recursive list/tuple case above
+    elif unit.lower() == "kb" or unit.lower() == "k":
+        counts /= 1e3 if deprecated else 2 ** 10  # changed from 1e3 to 2^10
     elif unit.lower() == "mb" or unit.lower() == "m":
-        counts /= 2 ** 20  # changed from 1e6 to 2^20
+        counts /= 1e6 if deprecated else 2 ** 20  # changed from 1e6 to 2^20
     elif unit.lower() == "gb" or unit.lower() == "g":
-        counts /= 2 ** 30  # changed from 1e9 to 2^30
+        counts /= 1e9 if deprecated else 2 ** 30  # changed from 1e9 to 2^30
     elif unit is not None:
-        raise ValueError("Unknow unit: {:}".format(unit))
+        raise ValueError("Unknown unit: {:}".format(unit))
     return counts
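The deprecated flag preserves the old decimal (1e6) readings for count_parameters_in_MB while new callers get binary (2**20) units; the binary reading is about 4.6% smaller. A quick comparison, assuming count_parameters is in scope:

```python
import torch.nn as nn

model = nn.Linear(1000, 1000)  # 1000*1000 weights + 1000 biases = 1,001,000 params
legacy_mb = count_parameters(model, "mb", deprecated=True)  # 1_001_000 / 1e6   = 1.001
binary_mb = count_parameters(model, "mb")                   # 1_001_000 / 2**20 ≈ 0.9546
print(f"legacy: {legacy_mb:.4f} MB, binary: {binary_mb:.4f} MiB")
```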