Add name filters for exp-org

This commit is contained in:
D-X-Y 2021-03-28 22:26:06 -07:00
parent b51320dfb1
commit 62bedaa094
3 changed files with 17 additions and 11 deletions

View File

@ -3,7 +3,7 @@
##################################################### #####################################################
# python exps/trading/organize_results.py # # python exps/trading/organize_results.py #
##################################################### #####################################################
import sys, argparse import re, sys, argparse
import numpy as np import numpy as np
from typing import List, Text from typing import List, Text
from collections import defaultdict, OrderedDict from collections import defaultdict, OrderedDict
@ -121,7 +121,7 @@ def filter_finished(recorders):
return returned_recorders, not_finished return returned_recorders, not_finished
def query_info(save_dir, verbose): def query_info(save_dir, verbose, name_filter):
R.set_uri(save_dir) R.set_uri(save_dir)
experiments = R.list_experiments() experiments = R.list_experiments()
@ -143,6 +143,8 @@ def query_info(save_dir, verbose):
for idx, (key, experiment) in enumerate(experiments.items()): for idx, (key, experiment) in enumerate(experiments.items()):
if experiment.id == "0": if experiment.id == "0":
continue continue
if name_filter is not None and re.match(name_filter, experiment.name) is None:
continue
recorders = experiment.list_recorders() recorders = experiment.list_recorders()
recorders, not_finished = filter_finished(recorders) recorders, not_finished = filter_finished(recorders)
if verbose: if verbose:
@ -205,6 +207,9 @@ if __name__ == "__main__":
default=False, default=False,
help="Print detailed log information or not.", help="Print detailed log information or not.",
) )
parser.add_argument(
"--name_filter", type=str, default=".*", help="Filter experiment names."
)
args = parser.parse_args() args = parser.parse_args()
print("Show results of {:}".format(args.save_dir)) print("Show results of {:}".format(args.save_dir))
@ -216,7 +221,7 @@ if __name__ == "__main__":
all_info_dict = [] all_info_dict = []
for save_dir in args.save_dir: for save_dir in args.save_dir:
_, info_dict = query_info(save_dir, args.verbose) _, info_dict = query_info(save_dir, args.verbose, args.name_filter)
all_info_dict.append(info_dict) all_info_dict.append(info_dict)
info_dict = QResult.merge_dict(all_info_dict) info_dict = QResult.merge_dict(all_info_dict)
compare_results( compare_results(

View File

@ -100,9 +100,9 @@ def run_exp(
# Train model # Train model
try: try:
if hasattr(model, "to"): # Recoverable model if hasattr(model, "to"): # Recoverable model
device = model.device ori_device = model.device
model = R.load_object(model_obj_name) model = R.load_object(model_obj_name)
model.to(device) model.to(ori_device)
else: else:
model = R.load_object(model_obj_name) model = R.load_object(model_obj_name)
logger.info("[Find existing object from {:}]".format(model_obj_name)) logger.info("[Find existing object from {:}]".format(model_obj_name))
@ -119,9 +119,10 @@ def run_exp(
model.fit(**model_fit_kwargs) model.fit(**model_fit_kwargs)
# remove model to CPU for saving # remove model to CPU for saving
if hasattr(model, "to"): if hasattr(model, "to"):
old_device = model.device
model.to("cpu") model.to("cpu")
R.save_objects(**{model_obj_name: model}) R.save_objects(**{model_obj_name: model})
model.to() model.to(old_device)
else: else:
R.save_objects(**{model_obj_name: model}) R.save_objects(**{model_obj_name: model})
except Exception as e: except Exception as e:

View File

@ -114,9 +114,9 @@ class QuantTransformer(Model):
def to(self, device): def to(self, device):
if device is None: if device is None:
self.model.to(self.device) device = "cpu"
else: self.device = device
self.model.to("cpu") self.model.to(self.device)
def loss_fn(self, pred, label): def loss_fn(self, pred, label):
mask = ~torch.isnan(label) mask = ~torch.isnan(label)
@ -227,7 +227,7 @@ class QuantTransformer(Model):
# Pre-fetch the potential checkpoints # Pre-fetch the potential checkpoints
ckp_path = os.path.join(save_dir, "{:}.pth".format(self.__class__.__name__)) ckp_path = os.path.join(save_dir, "{:}.pth".format(self.__class__.__name__))
if os.path.exists(ckp_path): if os.path.exists(ckp_path):
ckp_data = torch.load(ckp_path) ckp_data = torch.load(ckp_path, map_location=self.device)
stop_steps, best_score, best_epoch = ( stop_steps, best_score, best_epoch = (
ckp_data["stop_steps"], ckp_data["stop_steps"],
ckp_data["best_score"], ckp_data["best_score"],
@ -298,7 +298,7 @@ class QuantTransformer(Model):
results_dict=results_dict, results_dict=results_dict,
start_epoch=iepoch + 1, start_epoch=iepoch + 1,
) )
torch.save(save_info, ckp_path) torch.save(save_info, ckp_path, map_location="cpu")
self.logger.info( self.logger.info(
"The best score: {:.6f} @ {:02d}-th epoch".format(best_score, best_epoch) "The best score: {:.6f} @ {:02d}-th epoch".format(best_score, best_epoch)
) )