From b1064e5a6090c4bbafb09ee07f9f226c697223ea Mon Sep 17 00:00:00 2001 From: D-X-Y <280835372@qq.com> Date: Sun, 23 May 2021 19:14:12 +0000 Subject: [PATCH] LFNA ok on the valid data --- exps/LFNA/lfna.py | 33 ++++++++++--------------------- exps/LFNA/lfna_meta_model.py | 16 +++++++++------ xautodl/datasets/synthetic_env.py | 33 +++++++------------------------ 3 files changed, 27 insertions(+), 55 deletions(-) diff --git a/exps/LFNA/lfna.py b/exps/LFNA/lfna.py index 6c827fd..c46a083 100644 --- a/exps/LFNA/lfna.py +++ b/exps/LFNA/lfna.py @@ -99,18 +99,13 @@ def online_evaluate(env, meta_model, base_model, criterion, args, logger): with torch.no_grad(): meta_model.eval() base_model.eval() - _, [future_container], _ = meta_model( + _, [future_container], time_embeds = meta_model( future_time.to(args.device).view(1, 1), None, True ) future_x, future_y = future_x.to(args.device), future_y.to(args.device) future_y_hat = base_model.forward_with_container(future_x, future_container) future_loss = criterion(future_y_hat, future_y) - logger.log( - "[ONLINE] [{:03d}/{:03d}] loss={:.4f}".format( - idx, len(env), future_loss.item() - ) - ) - refine = meta_model.adapt( + refine, post_refine_loss = meta_model.adapt( base_model, criterion, future_time.item(), @@ -118,6 +113,13 @@ def online_evaluate(env, meta_model, base_model, criterion, args, logger): future_y, args.refine_lr, args.refine_epochs, + {"param": time_embeds, "loss": future_loss.item()}, + ) + logger.log( + "[ONLINE] [{:03d}/{:03d}] loss={:.4f}".format( + idx, len(env), future_loss.item() + ) + + ", post-loss={:.4f}".format(post_refine_loss if refine else -1) ) meta_model.clear_fixed() meta_model.clear_learnt() @@ -244,21 +246,6 @@ def main(args): logger.log("The meta-model is\n{:}".format(meta_model)) batch_sampler = EnvSampler(train_env, args.meta_batch, args.sampler_enlarge) - # train_env.reset_max_seq_length(args.seq_length) - # valid_env.reset_max_seq_length(args.seq_length) - valid_env_loader = torch.utils.data.DataLoader( - valid_env, - batch_size=args.meta_batch, - shuffle=True, - num_workers=args.workers, - pin_memory=True, - ) - train_env_loader = torch.utils.data.DataLoader( - train_env, - batch_sampler=batch_sampler, - num_workers=args.workers, - pin_memory=True, - ) pretrain_v2(base_model, meta_model, criterion, train_env, args, logger) # try to evaluate once @@ -507,7 +494,7 @@ if __name__ == "__main__": help="The learning rate for the optimizer, during refine", ) parser.add_argument( - "--refine_epochs", type=int, default=50, help="The final refine #epochs." + "--refine_epochs", type=int, default=40, help="The final refine #epochs." ) parser.add_argument( "--early_stop_thresh", diff --git a/exps/LFNA/lfna_meta_model.py b/exps/LFNA/lfna_meta_model.py index 7291f76..3ec1acc 100644 --- a/exps/LFNA/lfna_meta_model.py +++ b/exps/LFNA/lfna_meta_model.py @@ -276,10 +276,10 @@ class LFNA_Meta(super_core.SuperModule): def forward_candidate(self, input): raise NotImplementedError - def adapt(self, base_model, criterion, timestamp, x, y, lr, epochs): + def adapt(self, base_model, criterion, timestamp, x, y, lr, epochs, init_info): distance = self.get_closest_meta_distance(timestamp) if distance + self._interval * 1e-2 <= self._interval: - return False + return False, None x, y = x.to(self._meta_timestamps.device), y.to(self._meta_timestamps.device) with torch.set_grad_enabled(True): new_param = self.create_meta_embed() @@ -290,7 +290,11 @@ class LFNA_Meta(super_core.SuperModule): self.replace_append_learnt(timestamp, new_param) self.train() base_model.train() - best_new_param, best_loss = None, 1e9 + if init_info is not None: + best_loss = init_info["loss"] + new_param.data.copy_(init_info["param"].data) + else: + best_new_param, best_loss = None, 1e9 for iepoch in range(epochs): optimizer.zero_grad() _, [_], time_embed = self(timestamp.view(1, 1), None, True) @@ -303,14 +307,14 @@ class LFNA_Meta(super_core.SuperModule): loss.backward() optimizer.step() # print("{:03d}/{:03d} : loss : {:.4f} = {:.4f} + {:.4f}".format(iepoch, epochs, loss.item(), meta_loss.item(), match_loss.item())) - if loss.item() < best_loss: + if meta_loss.item() < best_loss: with torch.no_grad(): - best_loss = loss.item() + best_loss = meta_loss.item() best_new_param = new_param.detach() with torch.no_grad(): self.replace_append_learnt(None, None) self.append_fixed(timestamp, best_new_param) - return True + return True, meta_loss.item() def extra_repr(self) -> str: return "(_super_layer_embed): {:}, (_super_meta_embed): {:}, (_meta_timestamps): {:}".format( diff --git a/xautodl/datasets/synthetic_env.py b/xautodl/datasets/synthetic_env.py index 944d5a5..8c7854c 100644 --- a/xautodl/datasets/synthetic_env.py +++ b/xautodl/datasets/synthetic_env.py @@ -66,11 +66,6 @@ class SyntheticDEnv(data.Dataset): self._cov_functors = cov_functors self._oracle_map = None - self._seq_length = None - - @property - def seq_length(self): - return self._seq_length @property def min_timestamp(self): @@ -84,14 +79,12 @@ class SyntheticDEnv(data.Dataset): def timestamp_interval(self): return self._timestamp_generator.interval - def random_timestamp(self): - return ( - random.random() * (self.max_timestamp - self.min_timestamp) - + self.min_timestamp - ) - - def reset_max_seq_length(self, seq_length): - self._seq_length = seq_length + def random_timestamp(self, min_timestamp=None, max_timestamp=None): + if min_timestamp is None: + min_timestamp = self.min_timestamp + if max_timestamp is None: + max_timestamp = self.max_timestamp + return random.random() * (max_timestamp - min_timestamp) + min_timestamp def get_timestamp(self, index): if index is None: @@ -119,19 +112,7 @@ class SyntheticDEnv(data.Dataset): def __getitem__(self, index): assert 0 <= index < len(self), "{:} is not in [0, {:})".format(index, len(self)) index, timestamp = self._timestamp_generator[index] - if self._seq_length is None: - return self.__call__(timestamp) - else: - noise = ( - random.random() * self.timestamp_interval * self._timestamp_noise_scale - ) - timestamps = [ - timestamp + i * self.timestamp_interval + noise - for i in range(self._seq_length) - ] - # xdata = [self.__call__(timestamp) for timestamp in timestamps] - # return zip_sequence(xdata) - return self.seq_call(timestamps) + return self.__call__(timestamp) def seq_call(self, timestamps): with torch.no_grad():