write test code
This commit is contained in:
parent
a222c514d9
commit
14186fa97f
@ -80,14 +80,18 @@ def main(cfg: DictConfig):
|
|||||||
datamodule.prepare_data()
|
datamodule.prepare_data()
|
||||||
dataset_infos = dataset.DataInfos(datamodule=datamodule, cfg=cfg, dataset=datamodule.dataset)
|
dataset_infos = dataset.DataInfos(datamodule=datamodule, cfg=cfg, dataset=datamodule.dataset)
|
||||||
# train_smiles, reference_smiles = datamodule.get_train_smiles()
|
# train_smiles, reference_smiles = datamodule.get_train_smiles()
|
||||||
|
train_graphs, reference_graphs = datamodule.get_train_graphs()
|
||||||
|
|
||||||
# get input output dimensions
|
# get input output dimensions
|
||||||
dataset_infos.compute_input_output_dims(datamodule=datamodule)
|
dataset_infos.compute_input_output_dims(datamodule=datamodule)
|
||||||
# train_metrics = TrainMolecularMetricsDiscrete(dataset_infos)
|
train_metrics = TrainMolecularMetricsDiscrete(dataset_infos)
|
||||||
|
|
||||||
# sampling_metrics = SamplingMolecularMetrics(
|
# sampling_metrics = SamplingMolecularMetrics(
|
||||||
# dataset_infos, train_smiles, reference_smiles
|
# dataset_infos, train_smiles, reference_smiles
|
||||||
# )
|
# )
|
||||||
|
sampling_metrics = SamplingGraphMetrics(
|
||||||
|
dataset_infos, train_graphs, reference_graphs
|
||||||
|
)
|
||||||
visualization_tools = MolecularVisualization(dataset_infos)
|
visualization_tools = MolecularVisualization(dataset_infos)
|
||||||
|
|
||||||
model_kwargs = {
|
model_kwargs = {
|
||||||
@ -135,5 +139,16 @@ def main(cfg: DictConfig):
|
|||||||
else:
|
else:
|
||||||
trainer.test(model, datamodule=datamodule, ckpt_path=cfg.general.test_only)
|
trainer.test(model, datamodule=datamodule, ckpt_path=cfg.general.test_only)
|
||||||
|
|
||||||
|
@hydra.main(
|
||||||
|
version_base="1.1", config_path="../configs", config_name="config"
|
||||||
|
)
|
||||||
|
def test(cfg: DictConfig):
|
||||||
|
datamodule = dataset.DataModule(cfg)
|
||||||
|
datamodule.prepare_data()
|
||||||
|
dataset_infos = dataset.DataInfos(datamodule=datamodule, cfg=cfg, dataset=datamodule.dataset)
|
||||||
|
train_graphs, reference_graphs = datamodule.get_train_graphs()
|
||||||
|
|
||||||
|
dataset_infos.compute_input_output_dims(datamodule=datamodule)
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
test()
|
||||||
|
Loading…
Reference in New Issue
Block a user