From 62bedaa0946d764f3facaaf524e525ba98036e3a Mon Sep 17 00:00:00 2001 From: D-X-Y <280835372@qq.com> Date: Sun, 28 Mar 2021 22:26:06 -0700 Subject: [PATCH] Add name filters for exp-org --- exps/trading/organize_results.py | 11 ++++++++--- lib/procedures/q_exps.py | 7 ++++--- lib/trade_models/quant_transformer.py | 10 +++++----- 3 files changed, 17 insertions(+), 11 deletions(-) diff --git a/exps/trading/organize_results.py b/exps/trading/organize_results.py index 1a66b8d..f810057 100644 --- a/exps/trading/organize_results.py +++ b/exps/trading/organize_results.py @@ -3,7 +3,7 @@ ##################################################### # python exps/trading/organize_results.py # ##################################################### -import sys, argparse +import re, sys, argparse import numpy as np from typing import List, Text from collections import defaultdict, OrderedDict @@ -121,7 +121,7 @@ def filter_finished(recorders): return returned_recorders, not_finished -def query_info(save_dir, verbose): +def query_info(save_dir, verbose, name_filter): R.set_uri(save_dir) experiments = R.list_experiments() @@ -143,6 +143,8 @@ def query_info(save_dir, verbose): for idx, (key, experiment) in enumerate(experiments.items()): if experiment.id == "0": continue + if name_filter is not None and re.match(name_filter, experiment.name) is None: + continue recorders = experiment.list_recorders() recorders, not_finished = filter_finished(recorders) if verbose: @@ -205,6 +207,9 @@ if __name__ == "__main__": default=False, help="Print detailed log information or not.", ) + parser.add_argument( + "--name_filter", type=str, default=".*", help="Filter experiment names." + ) args = parser.parse_args() print("Show results of {:}".format(args.save_dir)) @@ -216,7 +221,7 @@ if __name__ == "__main__": all_info_dict = [] for save_dir in args.save_dir: - _, info_dict = query_info(save_dir, args.verbose) + _, info_dict = query_info(save_dir, args.verbose, args.name_filter) all_info_dict.append(info_dict) info_dict = QResult.merge_dict(all_info_dict) compare_results( diff --git a/lib/procedures/q_exps.py b/lib/procedures/q_exps.py index c1e606a..794a406 100644 --- a/lib/procedures/q_exps.py +++ b/lib/procedures/q_exps.py @@ -100,9 +100,9 @@ def run_exp( # Train model try: if hasattr(model, "to"): # Recoverable model - device = model.device + ori_device = model.device model = R.load_object(model_obj_name) - model.to(device) + model.to(ori_device) else: model = R.load_object(model_obj_name) logger.info("[Find existing object from {:}]".format(model_obj_name)) @@ -119,9 +119,10 @@ def run_exp( model.fit(**model_fit_kwargs) # remove model to CPU for saving if hasattr(model, "to"): + old_device = model.device model.to("cpu") R.save_objects(**{model_obj_name: model}) - model.to() + model.to(old_device) else: R.save_objects(**{model_obj_name: model}) except Exception as e: diff --git a/lib/trade_models/quant_transformer.py b/lib/trade_models/quant_transformer.py index f56bd81..4f5cdd3 100644 --- a/lib/trade_models/quant_transformer.py +++ b/lib/trade_models/quant_transformer.py @@ -114,9 +114,9 @@ class QuantTransformer(Model): def to(self, device): if device is None: - self.model.to(self.device) - else: - self.model.to("cpu") + device = "cpu" + self.device = device + self.model.to(self.device) def loss_fn(self, pred, label): mask = ~torch.isnan(label) @@ -227,7 +227,7 @@ class QuantTransformer(Model): # Pre-fetch the potential checkpoints ckp_path = os.path.join(save_dir, "{:}.pth".format(self.__class__.__name__)) if os.path.exists(ckp_path): - ckp_data = torch.load(ckp_path) + ckp_data = torch.load(ckp_path, map_location=self.device) stop_steps, best_score, best_epoch = ( ckp_data["stop_steps"], ckp_data["best_score"], @@ -298,7 +298,7 @@ class QuantTransformer(Model): results_dict=results_dict, start_epoch=iepoch + 1, ) - torch.save(save_info, ckp_path) + torch.save(save_info, ckp_path, map_location="cpu") self.logger.info( "The best score: {:.6f} @ {:02d}-th epoch".format(best_score, best_epoch) )