Add name filters for exp-org
This commit is contained in:
		| @@ -3,7 +3,7 @@ | |||||||
| ##################################################### | ##################################################### | ||||||
| # python exps/trading/organize_results.py           # | # python exps/trading/organize_results.py           # | ||||||
| ##################################################### | ##################################################### | ||||||
| import sys, argparse | import re, sys, argparse | ||||||
| import numpy as np | import numpy as np | ||||||
| from typing import List, Text | from typing import List, Text | ||||||
| from collections import defaultdict, OrderedDict | from collections import defaultdict, OrderedDict | ||||||
| @@ -121,7 +121,7 @@ def filter_finished(recorders): | |||||||
|     return returned_recorders, not_finished |     return returned_recorders, not_finished | ||||||
|  |  | ||||||
|  |  | ||||||
| def query_info(save_dir, verbose): | def query_info(save_dir, verbose, name_filter): | ||||||
|     R.set_uri(save_dir) |     R.set_uri(save_dir) | ||||||
|     experiments = R.list_experiments() |     experiments = R.list_experiments() | ||||||
|  |  | ||||||
| @@ -143,6 +143,8 @@ def query_info(save_dir, verbose): | |||||||
|     for idx, (key, experiment) in enumerate(experiments.items()): |     for idx, (key, experiment) in enumerate(experiments.items()): | ||||||
|         if experiment.id == "0": |         if experiment.id == "0": | ||||||
|             continue |             continue | ||||||
|  |         if name_filter is not None and re.match(name_filter, experiment.name) is None: | ||||||
|  |             continue | ||||||
|         recorders = experiment.list_recorders() |         recorders = experiment.list_recorders() | ||||||
|         recorders, not_finished = filter_finished(recorders) |         recorders, not_finished = filter_finished(recorders) | ||||||
|         if verbose: |         if verbose: | ||||||
| @@ -205,6 +207,9 @@ if __name__ == "__main__": | |||||||
|         default=False, |         default=False, | ||||||
|         help="Print detailed log information or not.", |         help="Print detailed log information or not.", | ||||||
|     ) |     ) | ||||||
|  |     parser.add_argument( | ||||||
|  |         "--name_filter", type=str, default=".*", help="Filter experiment names." | ||||||
|  |     ) | ||||||
|     args = parser.parse_args() |     args = parser.parse_args() | ||||||
|  |  | ||||||
|     print("Show results of {:}".format(args.save_dir)) |     print("Show results of {:}".format(args.save_dir)) | ||||||
| @@ -216,7 +221,7 @@ if __name__ == "__main__": | |||||||
|  |  | ||||||
|     all_info_dict = [] |     all_info_dict = [] | ||||||
|     for save_dir in args.save_dir: |     for save_dir in args.save_dir: | ||||||
|         _, info_dict = query_info(save_dir, args.verbose) |         _, info_dict = query_info(save_dir, args.verbose, args.name_filter) | ||||||
|         all_info_dict.append(info_dict) |         all_info_dict.append(info_dict) | ||||||
|     info_dict = QResult.merge_dict(all_info_dict) |     info_dict = QResult.merge_dict(all_info_dict) | ||||||
|     compare_results( |     compare_results( | ||||||
|   | |||||||
| @@ -100,9 +100,9 @@ def run_exp( | |||||||
|         # Train model |         # Train model | ||||||
|         try: |         try: | ||||||
|             if hasattr(model, "to"):  # Recoverable model |             if hasattr(model, "to"):  # Recoverable model | ||||||
|                 device = model.device |                 ori_device = model.device | ||||||
|                 model = R.load_object(model_obj_name) |                 model = R.load_object(model_obj_name) | ||||||
|                 model.to(device) |                 model.to(ori_device) | ||||||
|             else: |             else: | ||||||
|                 model = R.load_object(model_obj_name) |                 model = R.load_object(model_obj_name) | ||||||
|             logger.info("[Find existing object from {:}]".format(model_obj_name)) |             logger.info("[Find existing object from {:}]".format(model_obj_name)) | ||||||
| @@ -119,9 +119,10 @@ def run_exp( | |||||||
|             model.fit(**model_fit_kwargs) |             model.fit(**model_fit_kwargs) | ||||||
|             # remove model to CPU for saving |             # remove model to CPU for saving | ||||||
|             if hasattr(model, "to"): |             if hasattr(model, "to"): | ||||||
|  |                 old_device = model.device | ||||||
|                 model.to("cpu") |                 model.to("cpu") | ||||||
|                 R.save_objects(**{model_obj_name: model}) |                 R.save_objects(**{model_obj_name: model}) | ||||||
|                 model.to() |                 model.to(old_device) | ||||||
|             else: |             else: | ||||||
|                 R.save_objects(**{model_obj_name: model}) |                 R.save_objects(**{model_obj_name: model}) | ||||||
|         except Exception as e: |         except Exception as e: | ||||||
|   | |||||||
| @@ -114,9 +114,9 @@ class QuantTransformer(Model): | |||||||
|  |  | ||||||
|     def to(self, device): |     def to(self, device): | ||||||
|         if device is None: |         if device is None: | ||||||
|             self.model.to(self.device) |             device = "cpu" | ||||||
|         else: |         self.device = device | ||||||
|             self.model.to("cpu") |         self.model.to(self.device) | ||||||
|  |  | ||||||
|     def loss_fn(self, pred, label): |     def loss_fn(self, pred, label): | ||||||
|         mask = ~torch.isnan(label) |         mask = ~torch.isnan(label) | ||||||
| @@ -227,7 +227,7 @@ class QuantTransformer(Model): | |||||||
|         # Pre-fetch the potential checkpoints |         # Pre-fetch the potential checkpoints | ||||||
|         ckp_path = os.path.join(save_dir, "{:}.pth".format(self.__class__.__name__)) |         ckp_path = os.path.join(save_dir, "{:}.pth".format(self.__class__.__name__)) | ||||||
|         if os.path.exists(ckp_path): |         if os.path.exists(ckp_path): | ||||||
|             ckp_data = torch.load(ckp_path) |             ckp_data = torch.load(ckp_path, map_location=self.device) | ||||||
|             stop_steps, best_score, best_epoch = ( |             stop_steps, best_score, best_epoch = ( | ||||||
|                 ckp_data["stop_steps"], |                 ckp_data["stop_steps"], | ||||||
|                 ckp_data["best_score"], |                 ckp_data["best_score"], | ||||||
| @@ -298,7 +298,7 @@ class QuantTransformer(Model): | |||||||
|                 results_dict=results_dict, |                 results_dict=results_dict, | ||||||
|                 start_epoch=iepoch + 1, |                 start_epoch=iepoch + 1, | ||||||
|             ) |             ) | ||||||
|             torch.save(save_info, ckp_path) |             torch.save(save_info, ckp_path, map_location="cpu") | ||||||
|         self.logger.info( |         self.logger.info( | ||||||
|             "The best score: {:.6f} @ {:02d}-th epoch".format(best_score, best_epoch) |             "The best score: {:.6f} @ {:02d}-th epoch".format(best_score, best_epoch) | ||||||
|         ) |         ) | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user