Sync qlib
commit a9093e41e1
parent 1ba1585f20
@@ -1 +1 @@
-Subproject commit aa552fdb2089cf5b4396a6b75191d2c13211b42d
+Subproject commit 6ef204f1905602d60ba47b3e47f31d482df9f21d
@@ -111,6 +111,7 @@ def query_info(save_dir, verbose):
     experiments = R.list_experiments()
 
     key_map = {
+        "RMSE": "RMSE",
         "IC": "IC",
         "ICIR": "ICIR",
         "Rank IC": "Rank_IC",
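For context, a minimal sketch (not part of this diff) of how a key_map like the one above can translate display names into the keys under which metrics were stored; the recorded metrics dictionary here is hypothetical.

recorded = {"RMSE": 0.012, "IC": 0.041, "ICIR": 0.35, "Rank_IC": 0.039}

key_map = {
    "RMSE": "RMSE",
    "IC": "IC",
    "ICIR": "ICIR",
    "Rank IC": "Rank_IC",
}

# Build a row of display-name -> value, skipping metrics that were not logged.
row = {shown: recorded[stored] for shown, stored in key_map.items() if stored in recorded}
print(row)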
@@ -68,7 +68,7 @@ def run_exp(task_config, dataset, experiment_name, recorder_name, uri):
     model_fit_kwargs = dict(dataset=dataset)
 
     # Let's start the experiment.
-    with R.start(experiment_name=experiment_name, recorder_name=recorder_name, uri=uri):
+    with R.start(experiment_name=experiment_name, recorder_name=recorder_name, uri=uri, resume=True):
         # Setup log
         recorder_root_dir = R.get_recorder().get_local_dir()
         log_file = os.path.join(recorder_root_dir, "{:}.log".format(experiment_name))
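A minimal sketch (not part of this diff) of the log-file setup pattern the hunk shows: fetch the active recorder's local directory and build a per-experiment log path inside it. Only the path construction appears in the diff; attaching a FileHandler is an assumption added for illustration.

import logging
import os


def setup_experiment_log(recorder, experiment_name):
    # The recorder's local directory doubles as the home for run artifacts.
    recorder_root_dir = recorder.get_local_dir()
    log_file = os.path.join(recorder_root_dir, "{:}.log".format(experiment_name))
    # Assumed detail: attach a plain FileHandler so later messages land in that file.
    logger = logging.getLogger(experiment_name)
    logger.addHandler(logging.FileHandler(log_file))
    return logger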
@@ -81,7 +81,9 @@ def run_exp(task_config, dataset, experiment_name, recorder_name, uri):
         # Train model
         R.log_params(**flatten_dict(task_config))
         if "save_path" in inspect.getfullargspec(model.fit).args:
-            model_fit_kwargs["save_path"] = os.path.join(recorder_root_dir, "model.ckps")
+            model_fit_kwargs["save_path"] = os.path.join(recorder_root_dir, "model.ckp")
+        elif "save_dir" in inspect.getfullargspec(model.fit).args:
+            model_fit_kwargs["save_dir"] = os.path.join(recorder_root_dir, "model-ckps")
         model.fit(**model_fit_kwargs)
         # Get the recorder
         recorder = R.get_recorder()
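The hunk routes the checkpoint location to whichever keyword a model's fit signature accepts. A minimal, self-contained sketch of that dispatch; the two model classes are hypothetical stand-ins, not classes from the repository.

import inspect
import os


class PathModel:
    def fit(self, dataset, save_path=None):
        pass


class DirModel:
    def fit(self, dataset, save_dir=None):
        pass


def build_fit_kwargs(model, dataset, recorder_root_dir):
    kwargs = dict(dataset=dataset)
    fit_args = inspect.getfullargspec(model.fit).args
    if "save_path" in fit_args:      # model checkpoints to a single file
        kwargs["save_path"] = os.path.join(recorder_root_dir, "model.ckp")
    elif "save_dir" in fit_args:     # model checkpoints into a directory
        kwargs["save_dir"] = os.path.join(recorder_root_dir, "model-ckps")
    return kwargs


print(build_fit_kwargs(PathModel(), None, "/tmp/run"))
print(build_fit_kwargs(DirModel(), None, "/tmp/run"))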
@@ -138,7 +138,7 @@ class QuantTransformer(Model):
     def fit(
         self,
         dataset: DatasetH,
-        save_path: Optional[Text] = None,
+        save_dir: Optional[Text] = None,
     ):
         def _prepare_dataset(df_data):
             return th_data.TensorDataset(
@@ -172,8 +172,8 @@ class QuantTransformer(Model):
             _prepare_loader(test_dataset, False),
         )
 
-        save_path = get_or_create_path(save_path, return_dir=True)
-        self.logger.info("Fit procedure for [{:}] with save path={:}".format(self.__class__.__name__, save_path))
+        save_dir = get_or_create_path(save_dir, return_dir=True)
+        self.logger.info("Fit procedure for [{:}] with save path={:}".format(self.__class__.__name__, save_dir))
 
         def _internal_test(ckp_epoch=None, results_dict=None):
             with torch.no_grad():
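For reference, a rough sketch of what a helper like get_or_create_path(..., return_dir=True) is expected to do here: ensure a usable directory exists, falling back to a fresh temporary one when None is passed. This is an assumption for illustration, not qlib's actual implementation, and the helper name below is made up to avoid confusion.

import os
import tempfile


def get_or_create_dir(path=None):
    # Assumed behaviour: reuse the given directory (creating it if needed);
    # with no path at all, fall back to a new temporary directory.
    if path is None:
        return tempfile.mkdtemp(prefix="quant-transformer-")
    os.makedirs(path, exist_ok=True)
    return path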
@@ -196,15 +196,18 @@ class QuantTransformer(Model):
             return dict(train=train_score, valid=valid_score, test=test_score), xstr
 
         # Pre-fetch the potential checkpoints
-        ckp_path = os.path.join(save_path, "{:}.pth".format(self.__class__.__name__))
+        ckp_path = os.path.join(save_dir, "{:}.pth".format(self.__class__.__name__))
         if os.path.exists(ckp_path):
             ckp_data = torch.load(ckp_path)
+            import pdb
+
+            pdb.set_trace()
             stop_steps, best_score, best_epoch = ckp_data['stop_steps'], ckp_data['best_score'], ckp_data['best_epoch']
+            start_epoch, best_param = ckp_data['start_epoch'], ckp_data['best_param']
             results_dict = ckp_data['results_dict']
             self.model.load_state_dict(ckp_data['net_state_dict'])
             self.train_optimizer.load_state_dict(ckp_data['opt_state_dict'])
             self.logger.info("Resume from existing checkpoint: {:}".format(ckp_path))
         else:
             stop_steps, best_score, best_epoch = 0, -np.inf, -1
-            start_epoch = 0
+            start_epoch, best_param = 0, None
             results_dict = dict(train=OrderedDict(), valid=OrderedDict(), test=OrderedDict())
             _, eval_str = _internal_test(-1, results_dict)
             self.logger.info("Training from scratch, metrics@start: {:}".format(eval_str))
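The hunk above makes fit resumable: a single .pth file carries both the network/optimizer state and the bookkeeping needed to restart the epoch loop. A compact, generic sketch of that restore step in plain PyTorch, with field names mirroring the ones in the diff; the function itself is illustrative, not code from the repository.

import os

import numpy as np
import torch


def restore_or_init(ckp_path, model, optimizer):
    """Return (start_epoch, stop_steps, best_score, best_epoch, best_param)."""
    if os.path.exists(ckp_path):
        ckp_data = torch.load(ckp_path, map_location="cpu")
        model.load_state_dict(ckp_data["net_state_dict"])
        optimizer.load_state_dict(ckp_data["opt_state_dict"])
        return (ckp_data["start_epoch"], ckp_data["stop_steps"],
                ckp_data["best_score"], ckp_data["best_epoch"], ckp_data["best_param"])
    # Fresh run: nothing to restore yet.
    return 0, 0, -np.inf, -1, None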
@@ -215,7 +218,6 @@ class QuantTransformer(Model):
                     iepoch, self.opt_config["epochs"], best_epoch, best_score
                 )
             )
-
             train_loss, train_score = self.train_or_test_epoch(
                 train_loader, self.model, self.loss_fn, self.metric_fn, True, self.train_optimizer
             )
@@ -241,11 +243,14 @@ class QuantTransformer(Model):
                 stop_steps=stop_steps,
                 best_score=best_score,
                 best_epoch=best_epoch,
+                results_dict=results_dict,
+                start_epoch=iepoch + 1,
             )
             torch.save(save_info, ckp_path)
         self.logger.info("The best score: {:.6f} @ {:02d}-th epoch".format(best_score, best_epoch))
+        self.model.load_state_dict(best_param)
         _, eval_str = _internal_test('final', results_dict)
         self.logger.info("Reload the best parameter :: {:}".format(eval_str))
 
         if self.use_gpu:
             torch.cuda.empty_cache()
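To close the loop with the restore sketch earlier, a generic per-epoch checkpoint write in the same spirit as the save_info dict above. Field names mirror the diff; the function signature and the surrounding epoch loop are placeholders, not the repository's code.

import torch


def save_checkpoint(ckp_path, model, optimizer, iepoch, stop_steps,
                    best_score, best_epoch, best_param, results_dict):
    # Bundle everything needed to resume: weights, optimizer state, and loop bookkeeping.
    save_info = dict(
        net_state_dict=model.state_dict(),
        opt_state_dict=optimizer.state_dict(),
        best_param=best_param,
        best_score=best_score,
        best_epoch=best_epoch,
        stop_steps=stop_steps,
        results_dict=results_dict,
        start_epoch=iepoch + 1,  # resume at the next epoch, not the one just finished
    )
    torch.save(save_info, ckp_path)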