Sync qlib

D-X-Y 2021-03-17 09:10:45 +00:00
parent 1ba1585f20
commit a9093e41e1
4 changed files with 20 additions and 12 deletions

@@ -1 +1 @@
Subproject commit aa552fdb2089cf5b4396a6b75191d2c13211b42d
Subproject commit 6ef204f1905602d60ba47b3e47f31d482df9f21d

View File

@@ -111,6 +111,7 @@ def query_info(save_dir, verbose):
experiments = R.list_experiments()
key_map = {
"RMSE": "RMSE",
"IC": "IC",
"ICIR": "ICIR",
"Rank IC": "Rank_IC",

View File

@@ -68,7 +68,7 @@ def run_exp(task_config, dataset, experiment_name, recorder_name, uri):
model_fit_kwargs = dict(dataset=dataset)
# Let's start the experiment.
with R.start(experiment_name=experiment_name, recorder_name=recorder_name, uri=uri):
with R.start(experiment_name=experiment_name, recorder_name=recorder_name, uri=uri, resume=True):
# Setup log
recorder_root_dir = R.get_recorder().get_local_dir()
log_file = os.path.join(recorder_root_dir, "{:}.log".format(experiment_name))
@@ -81,7 +81,9 @@ def run_exp(task_config, dataset, experiment_name, recorder_name, uri):
# Train model
R.log_params(**flatten_dict(task_config))
if "save_path" in inspect.getfullargspec(model.fit).args:
model_fit_kwargs["save_path"] = os.path.join(recorder_root_dir, "model.ckps")
model_fit_kwargs["save_path"] = os.path.join(recorder_root_dir, "model.ckp")
elif "save_dir" in inspect.getfullargspec(model.fit).args:
model_fit_kwargs["save_dir"] = os.path.join(recorder_root_dir, "model-ckps")
model.fit(**model_fit_kwargs)
# Get the recorder
recorder = R.get_recorder()
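
The run_exp change above resumes an existing recorder (resume=True) and points the model's own checkpointing at the recorder directory, picking "save_path" or "save_dir" depending on the model's fit signature. A minimal sketch of that dispatch under stated assumptions: DummyModel stands in for the real qlib model and the paths are placeholders.

import inspect
import os

class DummyModel:
    # Stand-in for a qlib model whose fit() accepts a checkpoint directory.
    def fit(self, dataset, save_dir=None):
        print("checkpoints go to", save_dir)

model, recorder_root_dir = DummyModel(), "/tmp/recorder"
model_fit_kwargs = dict(dataset="dataset-placeholder")
fit_args = inspect.getfullargspec(model.fit).args
if "save_path" in fit_args:      # interface expecting a single checkpoint file
    model_fit_kwargs["save_path"] = os.path.join(recorder_root_dir, "model.ckp")
elif "save_dir" in fit_args:     # interface expecting a checkpoint directory
    model_fit_kwargs["save_dir"] = os.path.join(recorder_root_dir, "model-ckps")
model.fit(**model_fit_kwargs)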

View File

@@ -138,7 +138,7 @@ class QuantTransformer(Model):
def fit(
self,
dataset: DatasetH,
save_path: Optional[Text] = None,
save_dir: Optional[Text] = None,
):
def _prepare_dataset(df_data):
return th_data.TensorDataset(
@@ -172,8 +172,8 @@ class QuantTransformer(Model):
_prepare_loader(test_dataset, False),
)
save_path = get_or_create_path(save_path, return_dir=True)
self.logger.info("Fit procedure for [{:}] with save path={:}".format(self.__class__.__name__, save_path))
save_dir = get_or_create_path(save_dir, return_dir=True)
self.logger.info("Fit procedure for [{:}] with save path={:}".format(self.__class__.__name__, save_dir))
def _internal_test(ckp_epoch=None, results_dict=None):
with torch.no_grad():
@@ -196,15 +196,18 @@ class QuantTransformer(Model):
return dict(train=train_score, valid=valid_score, test=test_score), xstr
# Pre-fetch the potential checkpoints
ckp_path = os.path.join(save_path, "{:}.pth".format(self.__class__.__name__))
ckp_path = os.path.join(save_dir, "{:}.pth".format(self.__class__.__name__))
if os.path.exists(ckp_path):
ckp_data = torch.load(ckp_path)
import pdb
pdb.set_trace()
stop_steps, best_score, best_epoch = ckp_data['stop_steps'], ckp_data['best_score'], ckp_data['best_epoch']
start_epoch, best_param = ckp_data['start_epoch'], ckp_data['best_param']
results_dict = ckp_data['results_dict']
self.model.load_state_dict(ckp_data['net_state_dict'])
self.train_optimizer.load_state_dict(ckp_data['opt_state_dict'])
self.logger.info("Resume from existing checkpoint: {:}".format(ckp_path))
else:
stop_steps, best_score, best_epoch = 0, -np.inf, -1
start_epoch = 0
start_epoch, best_param = 0, None
results_dict = dict(train=OrderedDict(), valid=OrderedDict(), test=OrderedDict())
_, eval_str = _internal_test(-1, results_dict)
self.logger.info("Training from scratch, metrics@start: {:}".format(eval_str))
@@ -215,7 +218,6 @@ class QuantTransformer(Model):
iepoch, self.opt_config["epochs"], best_epoch, best_score
)
)
train_loss, train_score = self.train_or_test_epoch(
train_loader, self.model, self.loss_fn, self.metric_fn, True, self.train_optimizer
)
@@ -241,11 +243,14 @@ class QuantTransformer(Model):
stop_steps=stop_steps,
best_score=best_score,
best_epoch=best_epoch,
results_dict=results_dict,
start_epoch=iepoch + 1,
)
torch.save(save_info, ckp_path)
self.logger.info("The best score: {:.6f} @ {:02d}-th epoch".format(best_score, best_epoch))
self.model.load_state_dict(best_param)
_, eval_str = _internal_test('final', results_dict)
self.logger.info("Reload the best parameter :: {:}".format(eval_str))
if self.use_gpu:
torch.cuda.empty_cache()
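
Taken together, the QuantTransformer.fit changes make training resumable: each epoch a checkpoint dict is written into save_dir, a later call with the same save_dir restores model and optimizer state and continues from start_epoch, and the best parameters are reloaded at the end. A minimal, self-contained sketch of that pattern with a toy model; the file name and field names mirror the diff, everything else (the toy data, the score definition) is illustrative.

import os
import torch
import torch.nn as nn

def fit(save_dir, total_epochs=5):
    os.makedirs(save_dir, exist_ok=True)
    model = nn.Linear(4, 1)
    optimizer = torch.optim.Adam(model.parameters())
    ckp_path = os.path.join(save_dir, "Toy.pth")

    if os.path.exists(ckp_path):
        # Resume: restore counters, best result so far, and model/optimizer state.
        ckp_data = torch.load(ckp_path)
        best_score, best_epoch = ckp_data["best_score"], ckp_data["best_epoch"]
        start_epoch, best_param = ckp_data["start_epoch"], ckp_data["best_param"]
        model.load_state_dict(ckp_data["net_state_dict"])
        optimizer.load_state_dict(ckp_data["opt_state_dict"])
    else:
        best_score, best_epoch, start_epoch, best_param = float("-inf"), -1, 0, None

    x, y = torch.randn(32, 4), torch.randn(32, 1)
    for iepoch in range(start_epoch, total_epochs):
        loss = nn.functional.mse_loss(model(x), y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        score = -loss.item()  # higher is better, as in the diff's best_score bookkeeping
        if score > best_score:
            best_score, best_epoch = score, iepoch
            best_param = {k: v.clone() for k, v in model.state_dict().items()}
        # Persist everything needed to pick the loop up at iepoch + 1.
        torch.save(dict(best_score=best_score, best_epoch=best_epoch,
                        start_epoch=iepoch + 1, best_param=best_param,
                        net_state_dict=model.state_dict(),
                        opt_state_dict=optimizer.state_dict()), ckp_path)

    model.load_state_dict(best_param)  # reload the best parameters before returning
    return best_score, best_epoch

Calling fit("/tmp/toy-ckps") a second time with the same directory picks up where the first call stopped, which is the behavior run_exp relies on when it opens the recorder with resume=True.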