From 14186fa97f733c19340d569f658147f783d40ddc Mon Sep 17 00:00:00 2001 From: mhz Date: Wed, 26 Jun 2024 23:41:37 +0200 Subject: [PATCH] write test code --- graph_dit/main.py | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) diff --git a/graph_dit/main.py b/graph_dit/main.py index 24e3c39..27a924d 100644 --- a/graph_dit/main.py +++ b/graph_dit/main.py @@ -80,14 +80,18 @@ def main(cfg: DictConfig): datamodule.prepare_data() dataset_infos = dataset.DataInfos(datamodule=datamodule, cfg=cfg, dataset=datamodule.dataset) # train_smiles, reference_smiles = datamodule.get_train_smiles() + train_graphs, reference_graphs = datamodule.get_train_graphs() # get input output dimensions dataset_infos.compute_input_output_dims(datamodule=datamodule) - # train_metrics = TrainMolecularMetricsDiscrete(dataset_infos) + train_metrics = TrainMolecularMetricsDiscrete(dataset_infos) # sampling_metrics = SamplingMolecularMetrics( # dataset_infos, train_smiles, reference_smiles # ) + sampling_metrics = SamplingGraphMetrics( + dataset_infos, train_graphs, reference_graphs + ) visualization_tools = MolecularVisualization(dataset_infos) model_kwargs = { @@ -135,5 +139,16 @@ def main(cfg: DictConfig): else: trainer.test(model, datamodule=datamodule, ckpt_path=cfg.general.test_only) +@hydra.main( + version_base="1.1", config_path="../configs", config_name="config" +) +def test(cfg: DictConfig): + datamodule = dataset.DataModule(cfg) + datamodule.prepare_data() + dataset_infos = dataset.DataInfos(datamodule=datamodule, cfg=cfg, dataset=datamodule.dataset) + train_graphs, reference_graphs = datamodule.get_train_graphs() + + dataset_infos.compute_input_output_dims(datamodule=datamodule) + if __name__ == "__main__": - main() + test()