diff --git a/graph_dit/main.py b/graph_dit/main.py index d7fbe14..24e3c39 100644 --- a/graph_dit/main.py +++ b/graph_dit/main.py @@ -78,7 +78,7 @@ def main(cfg: DictConfig): datamodule = dataset.DataModule(cfg) datamodule.prepare_data() - dataset_infos = dataset.DataInfos(datamodule=datamodule, cfg=cfg) + dataset_infos = dataset.DataInfos(datamodule=datamodule, cfg=cfg, dataset=datamodule.dataset) # train_smiles, reference_smiles = datamodule.get_train_smiles() # get input output dimensions