LFNA ok on the valid data
This commit is contained in:
		| @@ -99,18 +99,13 @@ def online_evaluate(env, meta_model, base_model, criterion, args, logger): | ||||
|         with torch.no_grad(): | ||||
|             meta_model.eval() | ||||
|             base_model.eval() | ||||
|             _, [future_container], _ = meta_model( | ||||
|             _, [future_container], time_embeds = meta_model( | ||||
|                 future_time.to(args.device).view(1, 1), None, True | ||||
|             ) | ||||
|             future_x, future_y = future_x.to(args.device), future_y.to(args.device) | ||||
|             future_y_hat = base_model.forward_with_container(future_x, future_container) | ||||
|             future_loss = criterion(future_y_hat, future_y) | ||||
|             logger.log( | ||||
|                 "[ONLINE] [{:03d}/{:03d}] loss={:.4f}".format( | ||||
|                     idx, len(env), future_loss.item() | ||||
|                 ) | ||||
|             ) | ||||
|         refine = meta_model.adapt( | ||||
|         refine, post_refine_loss = meta_model.adapt( | ||||
|             base_model, | ||||
|             criterion, | ||||
|             future_time.item(), | ||||
| @@ -118,6 +113,13 @@ def online_evaluate(env, meta_model, base_model, criterion, args, logger): | ||||
|             future_y, | ||||
|             args.refine_lr, | ||||
|             args.refine_epochs, | ||||
|             {"param": time_embeds, "loss": future_loss.item()}, | ||||
|         ) | ||||
|         logger.log( | ||||
|             "[ONLINE] [{:03d}/{:03d}] loss={:.4f}".format( | ||||
|                 idx, len(env), future_loss.item() | ||||
|             ) | ||||
|             + ", post-loss={:.4f}".format(post_refine_loss if refine else -1) | ||||
|         ) | ||||
|     meta_model.clear_fixed() | ||||
|     meta_model.clear_learnt() | ||||
| @@ -244,21 +246,6 @@ def main(args): | ||||
|     logger.log("The meta-model is\n{:}".format(meta_model)) | ||||
|  | ||||
|     batch_sampler = EnvSampler(train_env, args.meta_batch, args.sampler_enlarge) | ||||
|     # train_env.reset_max_seq_length(args.seq_length) | ||||
|     # valid_env.reset_max_seq_length(args.seq_length) | ||||
|     valid_env_loader = torch.utils.data.DataLoader( | ||||
|         valid_env, | ||||
|         batch_size=args.meta_batch, | ||||
|         shuffle=True, | ||||
|         num_workers=args.workers, | ||||
|         pin_memory=True, | ||||
|     ) | ||||
|     train_env_loader = torch.utils.data.DataLoader( | ||||
|         train_env, | ||||
|         batch_sampler=batch_sampler, | ||||
|         num_workers=args.workers, | ||||
|         pin_memory=True, | ||||
|     ) | ||||
|     pretrain_v2(base_model, meta_model, criterion, train_env, args, logger) | ||||
|  | ||||
|     # try to evaluate once | ||||
| @@ -507,7 +494,7 @@ if __name__ == "__main__": | ||||
|         help="The learning rate for the optimizer, during refine", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--refine_epochs", type=int, default=50, help="The final refine #epochs." | ||||
|         "--refine_epochs", type=int, default=40, help="The final refine #epochs." | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--early_stop_thresh", | ||||
|   | ||||
| @@ -276,10 +276,10 @@ class LFNA_Meta(super_core.SuperModule): | ||||
|     def forward_candidate(self, input): | ||||
|         raise NotImplementedError | ||||
|  | ||||
|     def adapt(self, base_model, criterion, timestamp, x, y, lr, epochs): | ||||
|     def adapt(self, base_model, criterion, timestamp, x, y, lr, epochs, init_info): | ||||
|         distance = self.get_closest_meta_distance(timestamp) | ||||
|         if distance + self._interval * 1e-2 <= self._interval: | ||||
|             return False | ||||
|             return False, None | ||||
|         x, y = x.to(self._meta_timestamps.device), y.to(self._meta_timestamps.device) | ||||
|         with torch.set_grad_enabled(True): | ||||
|             new_param = self.create_meta_embed() | ||||
| @@ -290,7 +290,11 @@ class LFNA_Meta(super_core.SuperModule): | ||||
|             self.replace_append_learnt(timestamp, new_param) | ||||
|             self.train() | ||||
|             base_model.train() | ||||
|             best_new_param, best_loss = None, 1e9 | ||||
|             if init_info is not None: | ||||
|                 best_loss = init_info["loss"] | ||||
|                 new_param.data.copy_(init_info["param"].data) | ||||
|             else: | ||||
|                 best_new_param, best_loss = None, 1e9 | ||||
|             for iepoch in range(epochs): | ||||
|                 optimizer.zero_grad() | ||||
|                 _, [_], time_embed = self(timestamp.view(1, 1), None, True) | ||||
| @@ -303,14 +307,14 @@ class LFNA_Meta(super_core.SuperModule): | ||||
|                 loss.backward() | ||||
|                 optimizer.step() | ||||
|                 # print("{:03d}/{:03d} : loss : {:.4f} = {:.4f} + {:.4f}".format(iepoch, epochs, loss.item(), meta_loss.item(), match_loss.item())) | ||||
|                 if loss.item() < best_loss: | ||||
|                 if meta_loss.item() < best_loss: | ||||
|                     with torch.no_grad(): | ||||
|                         best_loss = loss.item() | ||||
|                         best_loss = meta_loss.item() | ||||
|                         best_new_param = new_param.detach() | ||||
|         with torch.no_grad(): | ||||
|             self.replace_append_learnt(None, None) | ||||
|             self.append_fixed(timestamp, best_new_param) | ||||
|         return True | ||||
|         return True, meta_loss.item() | ||||
|  | ||||
|     def extra_repr(self) -> str: | ||||
|         return "(_super_layer_embed): {:}, (_super_meta_embed): {:}, (_meta_timestamps): {:}".format( | ||||
|   | ||||
| @@ -66,11 +66,6 @@ class SyntheticDEnv(data.Dataset): | ||||
|         self._cov_functors = cov_functors | ||||
|  | ||||
|         self._oracle_map = None | ||||
|         self._seq_length = None | ||||
|  | ||||
|     @property | ||||
|     def seq_length(self): | ||||
|         return self._seq_length | ||||
|  | ||||
|     @property | ||||
|     def min_timestamp(self): | ||||
| @@ -84,14 +79,12 @@ class SyntheticDEnv(data.Dataset): | ||||
|     def timestamp_interval(self): | ||||
|         return self._timestamp_generator.interval | ||||
|  | ||||
|     def random_timestamp(self): | ||||
|         return ( | ||||
|             random.random() * (self.max_timestamp - self.min_timestamp) | ||||
|             + self.min_timestamp | ||||
|         ) | ||||
|  | ||||
|     def reset_max_seq_length(self, seq_length): | ||||
|         self._seq_length = seq_length | ||||
|     def random_timestamp(self, min_timestamp=None, max_timestamp=None): | ||||
|         if min_timestamp is None: | ||||
|             min_timestamp = self.min_timestamp | ||||
|         if max_timestamp is None: | ||||
|             max_timestamp = self.max_timestamp | ||||
|         return random.random() * (max_timestamp - min_timestamp) + min_timestamp | ||||
|  | ||||
|     def get_timestamp(self, index): | ||||
|         if index is None: | ||||
| @@ -119,19 +112,7 @@ class SyntheticDEnv(data.Dataset): | ||||
|     def __getitem__(self, index): | ||||
|         assert 0 <= index < len(self), "{:} is not in [0, {:})".format(index, len(self)) | ||||
|         index, timestamp = self._timestamp_generator[index] | ||||
|         if self._seq_length is None: | ||||
|             return self.__call__(timestamp) | ||||
|         else: | ||||
|             noise = ( | ||||
|                 random.random() * self.timestamp_interval * self._timestamp_noise_scale | ||||
|             ) | ||||
|             timestamps = [ | ||||
|                 timestamp + i * self.timestamp_interval + noise | ||||
|                 for i in range(self._seq_length) | ||||
|             ] | ||||
|             # xdata = [self.__call__(timestamp) for timestamp in timestamps] | ||||
|             # return zip_sequence(xdata) | ||||
|             return self.seq_call(timestamps) | ||||
|         return self.__call__(timestamp) | ||||
|  | ||||
|     def seq_call(self, timestamps): | ||||
|         with torch.no_grad(): | ||||
|   | ||||
		Reference in New Issue
	
	Block a user