LFNA ok on the valid data

@@ -99,18 +99,13 @@ def online_evaluate(env, meta_model, base_model, criterion, args, logger):
         with torch.no_grad():
             meta_model.eval()
             base_model.eval()
-            _, [future_container], _ = meta_model(
+            _, [future_container], time_embeds = meta_model(
                 future_time.to(args.device).view(1, 1), None, True
             )
             future_x, future_y = future_x.to(args.device), future_y.to(args.device)
             future_y_hat = base_model.forward_with_container(future_x, future_container)
             future_loss = criterion(future_y_hat, future_y)
-            logger.log(
-                "[ONLINE] [{:03d}/{:03d}] loss={:.4f}".format(
-                    idx, len(env), future_loss.item()
-                )
-            )
-        refine = meta_model.adapt(
+        refine, post_refine_loss = meta_model.adapt(
             base_model,
             criterion,
             future_time.item(),
@@ -118,6 +113,13 @@ def online_evaluate(env, meta_model, base_model, criterion, args, logger):
             future_y,
             args.refine_lr,
             args.refine_epochs,
+            {"param": time_embeds, "loss": future_loss.item()},
         )
+        logger.log(
+            "[ONLINE] [{:03d}/{:03d}] loss={:.4f}".format(
+                idx, len(env), future_loss.item()
+            )
+            + ", post-loss={:.4f}".format(post_refine_loss if refine else -1)
+        )
     meta_model.clear_fixed()
     meta_model.clear_learnt()
@@ -244,21 +246,6 @@ def main(args):
     logger.log("The meta-model is\n{:}".format(meta_model))
 
     batch_sampler = EnvSampler(train_env, args.meta_batch, args.sampler_enlarge)
-    # train_env.reset_max_seq_length(args.seq_length)
-    # valid_env.reset_max_seq_length(args.seq_length)
-    valid_env_loader = torch.utils.data.DataLoader(
-        valid_env,
-        batch_size=args.meta_batch,
-        shuffle=True,
-        num_workers=args.workers,
-        pin_memory=True,
-    )
-    train_env_loader = torch.utils.data.DataLoader(
-        train_env,
-        batch_sampler=batch_sampler,
-        num_workers=args.workers,
-        pin_memory=True,
-    )
     pretrain_v2(base_model, meta_model, criterion, train_env, args, logger)
 
     # try to evaluate once
@@ -507,7 +494,7 @@ if __name__ == "__main__":
         help="The learning rate for the optimizer, during refine",
     )
     parser.add_argument(
-        "--refine_epochs", type=int, default=50, help="The final refine #epochs."
+        "--refine_epochs", type=int, default=40, help="The final refine #epochs."
     )
     parser.add_argument(
         "--early_stop_thresh",
@@ -276,10 +276,10 @@ class LFNA_Meta(super_core.SuperModule):
     def forward_candidate(self, input):
         raise NotImplementedError
 
-    def adapt(self, base_model, criterion, timestamp, x, y, lr, epochs):
+    def adapt(self, base_model, criterion, timestamp, x, y, lr, epochs, init_info):
         distance = self.get_closest_meta_distance(timestamp)
         if distance + self._interval * 1e-2 <= self._interval:
-            return False
+            return False, None
         x, y = x.to(self._meta_timestamps.device), y.to(self._meta_timestamps.device)
         with torch.set_grad_enabled(True):
             new_param = self.create_meta_embed()
@@ -290,7 +290,11 @@ class LFNA_Meta(super_core.SuperModule):
             self.replace_append_learnt(timestamp, new_param)
             self.train()
             base_model.train()
-            best_new_param, best_loss = None, 1e9
+            if init_info is not None:
+                best_loss = init_info["loss"]
+                new_param.data.copy_(init_info["param"].data)
+            else:
+                best_new_param, best_loss = None, 1e9
             for iepoch in range(epochs):
                 optimizer.zero_grad()
                 _, [_], time_embed = self(timestamp.view(1, 1), None, True)
@@ -303,14 +307,14 @@ class LFNA_Meta(super_core.SuperModule):
                 loss.backward()
                 optimizer.step()
                 # print("{:03d}/{:03d} : loss : {:.4f} = {:.4f} + {:.4f}".format(iepoch, epochs, loss.item(), meta_loss.item(), match_loss.item()))
-                if loss.item() < best_loss:
+                if meta_loss.item() < best_loss:
                     with torch.no_grad():
-                        best_loss = loss.item()
+                        best_loss = meta_loss.item()
                         best_new_param = new_param.detach()
         with torch.no_grad():
             self.replace_append_learnt(None, None)
             self.append_fixed(timestamp, best_new_param)
-        return True
+        return True, meta_loss.item()
 
     def extra_repr(self) -> str:
         return "(_super_layer_embed): {:}, (_super_meta_embed): {:}, (_meta_timestamps): {:}".format(
@@ -66,11 +66,6 @@ class SyntheticDEnv(data.Dataset):
         self._cov_functors = cov_functors
 
         self._oracle_map = None
-        self._seq_length = None
-
-    @property
-    def seq_length(self):
-        return self._seq_length
 
     @property
     def min_timestamp(self):
@@ -84,14 +79,12 @@ class SyntheticDEnv(data.Dataset):
     def timestamp_interval(self):
         return self._timestamp_generator.interval
 
-    def random_timestamp(self):
-        return (
-            random.random() * (self.max_timestamp - self.min_timestamp)
-            + self.min_timestamp
-        )
-
-    def reset_max_seq_length(self, seq_length):
-        self._seq_length = seq_length
+    def random_timestamp(self, min_timestamp=None, max_timestamp=None):
+        if min_timestamp is None:
+            min_timestamp = self.min_timestamp
+        if max_timestamp is None:
+            max_timestamp = self.max_timestamp
+        return random.random() * (max_timestamp - min_timestamp) + min_timestamp
 
     def get_timestamp(self, index):
         if index is None:
@@ -119,19 +112,7 @@ class SyntheticDEnv(data.Dataset):
     def __getitem__(self, index):
         assert 0 <= index < len(self), "{:} is not in [0, {:})".format(index, len(self))
         index, timestamp = self._timestamp_generator[index]
-        if self._seq_length is None:
-            return self.__call__(timestamp)
-        else:
-            noise = (
-                random.random() * self.timestamp_interval * self._timestamp_noise_scale
-            )
-            timestamps = [
-                timestamp + i * self.timestamp_interval + noise
-                for i in range(self._seq_length)
-            ]
-            # xdata = [self.__call__(timestamp) for timestamp in timestamps]
-            # return zip_sequence(xdata)
-            return self.seq_call(timestamps)
+        return self.__call__(timestamp)
 
     def seq_call(self, timestamps):
         with torch.no_grad():
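
random_timestamp() is generalized so callers can sample from a sub-range of the environment's time span, while the no-argument call keeps the old full-range behavior. A standalone sketch of the same defaulting logic, with lo/hi standing in for the environment's min/max timestamps:

    import random

    def random_timestamp(lo, hi, min_timestamp=None, max_timestamp=None):
        # Default to the full range; explicit bounds restrict the sample.
        if min_timestamp is None:
            min_timestamp = lo
        if max_timestamp is None:
            max_timestamp = hi
        return random.random() * (max_timestamp - min_timestamp) + min_timestamp

    t_full = random_timestamp(0.0, 1.0)             # anywhere in [0, 1)
    t_sub = random_timestamp(0.0, 1.0, 0.5, 0.75)   # restricted to [0.5, 0.75)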