diff --git a/exps/GeMOSA/lfna_meta_model.py b/exps/GeMOSA/lfna_meta_model.py index c36e88b..0df20be 100644 --- a/exps/GeMOSA/lfna_meta_model.py +++ b/exps/GeMOSA/lfna_meta_model.py @@ -20,7 +20,7 @@ class MetaModelV1(super_core.SuperModule): time_dim, meta_timestamps, dropout: float = 0.1, - seq_length: int = 10, + seq_length: int = None, interval: float = None, thresh: float = None, ): @@ -33,8 +33,7 @@ class MetaModelV1(super_core.SuperModule): self._raw_meta_timestamps = meta_timestamps assert interval is not None self._interval = interval - self._seq_length = seq_length - self._thresh = interval * 50 if thresh is None else thresh + self._thresh = interval * seq_length if thresh is None else thresh self.register_parameter( "_super_layer_embed", @@ -45,10 +44,6 @@ class MetaModelV1(super_core.SuperModule): torch.nn.Parameter(torch.Tensor(len(meta_timestamps), time_dim)), ) self.register_buffer("_meta_timestamps", torch.Tensor(meta_timestamps)) - # register a time difference buffer - # time_interval = [-i * self._interval for i in range(self._seq_length)] - # time_interval.reverse() - # self.register_buffer("_time_interval", torch.Tensor(time_interval)) self._time_embed_dim = time_dim self._append_meta_embed = dict(fixed=None, learnt=None) self._append_meta_timestamps = dict(fixed=None, learnt=None) @@ -186,7 +181,6 @@ class MetaModelV1(super_core.SuperModule): def forward_raw(self, timestamps, time_embeds, tembed_only=False): if time_embeds is None: - # time_seq = timestamps.view(-1, 1) + self._time_interval.view(1, -1) [B] = timestamps.shape time_embeds = self._obtain_time_embed(timestamps) else: # use the hyper-net only @@ -210,7 +204,7 @@ class MetaModelV1(super_core.SuperModule): batch_containers.append( self._shape_container.translate(torch.split(weights.squeeze(0), 1)) ) - return time_seq, batch_containers, time_embeds + return batch_containers, time_embeds def forward_candidate(self, input): raise NotImplementedError @@ -239,10 +233,10 @@ class MetaModelV1(super_core.SuperModule): best_new_param = new_param.detach().clone() for iepoch in range(epochs): optimizer.zero_grad() - _, [_], time_embed = self(timestamp.view(1, 1), None) + _, time_embed = self(timestamp.view(1), None) match_loss = criterion(new_param, time_embed) - _, [container], time_embed = self(None, new_param.view(1, -1)) + [container], time_embed = self(None, new_param.view(1, -1)) y_hat = base_model.forward_with_container(x, container) meta_loss = criterion(y_hat, y) loss = meta_loss + match_loss diff --git a/exps/GeMOSA/main.py b/exps/GeMOSA/main.py index a78f8fd..2ba0117 100644 --- a/exps/GeMOSA/main.py +++ b/exps/GeMOSA/main.py @@ -46,8 +46,8 @@ def online_evaluate(env, meta_model, base_model, criterion, args, logger, save=F with torch.no_grad(): meta_model.eval() base_model.eval() - _, [future_container], time_embeds = meta_model( - future_time.to(args.device).view(1, 1), None, False + [future_container], time_embeds = meta_model( + future_time.to(args.device).view(-1), None, False ) if save: w_containers[idx] = future_container.no_grad_clone() @@ -117,10 +117,10 @@ def meta_train_procedure(base_model, meta_model, criterion, xenv, args, logger): ) # future loss total_future_losses, total_present_losses = [], [] - _, future_containers, _ = meta_model( + future_containers, _ = meta_model( None, generated_time_embeds[batch_indexes], False ) - _, present_containers, _ = meta_model( + present_containers, _ = meta_model( None, meta_model.super_meta_embed[batch_indexes], False ) for ibatch, time_step in enumerate(raw_time_steps.cpu().tolist()): diff --git a/xautodl/datasets/synthetic_utils.py b/xautodl/datasets/synthetic_utils.py index af353c3..9d9000e 100644 --- a/xautodl/datasets/synthetic_utils.py +++ b/xautodl/datasets/synthetic_utils.py @@ -1,6 +1,3 @@ -##################################################### -# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 # -##################################################### import math import abc import numpy as np @@ -13,11 +10,11 @@ class UnifiedSplit: """A class to unify the split strategy.""" def __init__(self, total_num, mode): - # Training Set 65% - num_of_train = int(total_num * 0.65) + # Training Set 75% + num_of_train = int(total_num * 0.75) # Validation Set 05% num_of_valid = int(total_num * 0.05) - # Test Set 30% + # Test Set 20% num_of_set = total_num - num_of_train - num_of_valid all_indexes = list(range(total_num)) if mode is None: diff --git a/xautodl/utils/flop_benchmark.py b/xautodl/utils/flop_benchmark.py index e9141f7..4cade13 100644 --- a/xautodl/utils/flop_benchmark.py +++ b/xautodl/utils/flop_benchmark.py @@ -1,27 +1,32 @@ +##################################################### +# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.01 # +##################################################### import torch import torch.nn as nn import numpy as np def count_parameters_in_MB(model): - return count_parameters(model, "mb") + return count_parameters(model, "mb", deprecated=True) -def count_parameters(model_or_parameters, unit="mb"): +def count_parameters(model_or_parameters, unit="mb", deprecated=False): if isinstance(model_or_parameters, nn.Module): counts = sum(np.prod(v.size()) for v in model_or_parameters.parameters()) elif isinstance(model_or_parameters, nn.Parameter): counts = models_or_parameters.numel() elif isinstance(model_or_parameters, (list, tuple)): - counts = sum(count_parameters(x, None) for x in models_or_parameters) + counts = sum( + count_parameters(x, None, deprecated) for x in models_or_parameters + ) else: counts = sum(np.prod(v.size()) for v in model_or_parameters) if unit.lower() == "kb" or unit.lower() == "k": - counts /= 2 ** 10 # changed from 1e3 to 2^10 + counts /= 1e3 if deprecated else 2 ** 10 # changed from 1e3 to 2^10 elif unit.lower() == "mb" or unit.lower() == "m": - counts /= 2 ** 20 # changed from 1e6 to 2^20 + counts /= 1e6 if deprecated else 2 ** 20 # changed from 1e6 to 2^20 elif unit.lower() == "gb" or unit.lower() == "g": - counts /= 2 ** 30 # changed from 1e9 to 2^30 + counts /= 1e9 if deprecated else 2 ** 30 # changed from 1e9 to 2^30 elif unit is not None: raise ValueError("Unknow unit: {:}".format(unit)) return counts