diff --git a/graph_dit/main.py b/graph_dit/main.py
index 27a924d..684cb8f 100644
--- a/graph_dit/main.py
+++ b/graph_dit/main.py
@@ -11,9 +11,13 @@
 import utils
 from datasets import dataset
 from diffusion_model import Graph_DiT
 from metrics.molecular_metrics_train import TrainMolecularMetricsDiscrete
+from metrics.molecular_metrics_train import TrainGraphMetricsDiscrete
 from metrics.molecular_metrics_sampling import SamplingMolecularMetrics
+from metrics.molecular_metrics_sampling import SamplingGraphMetrics
+
 from analysis.visualization import MolecularVisualization
+from analysis.visualization import GraphVisualization
 
 warnings.filterwarnings("ignore", category=UserWarning)
 torch.set_float32_matmul_precision("medium")
@@ -79,19 +83,20 @@ def main(cfg: DictConfig):
     datamodule = dataset.DataModule(cfg)
     datamodule.prepare_data()
     dataset_infos = dataset.DataInfos(datamodule=datamodule, cfg=cfg, dataset=datamodule.dataset)
-    # train_smiles, reference_smiles = datamodule.get_train_smiles()
-    train_graphs, reference_graphs = datamodule.get_train_graphs()
+    train_smiles, reference_smiles = datamodule.get_train_smiles()
+    # train_graphs, reference_graphs = datamodule.get_train_graphs()
 
     # get input output dimensions
     dataset_infos.compute_input_output_dims(datamodule=datamodule)
     train_metrics = TrainMolecularMetricsDiscrete(dataset_infos)
+    # train_metrics = TrainGraphMetricsDiscrete(dataset_infos)
 
-    # sampling_metrics = SamplingMolecularMetrics(
-    #     dataset_infos, train_smiles, reference_smiles
-    # )
-    sampling_metrics = SamplingGraphMetrics(
-        dataset_infos, train_graphs, reference_graphs
+    sampling_metrics = SamplingMolecularMetrics(
+        dataset_infos, train_smiles, reference_smiles
     )
+    # sampling_metrics = SamplingGraphMetrics(
+    #     dataset_infos, train_graphs, reference_graphs
+    # )
     visualization_tools = MolecularVisualization(dataset_infos)
 
     model_kwargs = {
@@ -149,6 +154,54 @@ def test(cfg: DictConfig):
     train_graphs, reference_graphs = datamodule.get_train_graphs()
     dataset_infos.compute_input_output_dims(datamodule=datamodule)
 
+    train_metrics = TrainGraphMetricsDiscrete(dataset_infos)
+
+    sampling_metrics = SamplingGraphMetrics(
+        dataset_infos, train_graphs, reference_graphs
+    )
+
+    visualization_tools = GraphVisualization(dataset_infos)
+
+    model_kwargs = {
+        "dataset_infos": dataset_infos,
+        "train_metrics": train_metrics,
+        "sampling_metrics": sampling_metrics,
+        "visualization_tools": visualization_tools,
+    }
+
+    if cfg.general.test_only:
+        cfg, _ = get_resume(cfg, model_kwargs)
+        os.chdir(cfg.general.test_only.split("checkpoints")[0])
+    elif cfg.general.resume is not None:
+        cfg, _ = get_resume_adaptive(cfg, model_kwargs)
+        os.chdir(cfg.general.resume.split("checkpoints")[0])
+    model = Graph_DiT(cfg=cfg, **model_kwargs)
+    trainer = Trainer(
+        gradient_clip_val=cfg.train.clip_grad,
+        # accelerator="cpu",
+        accelerator="gpu"
+        if torch.cuda.is_available() and cfg.general.gpus > 0
+        else "cpu",
+        devices=cfg.general.gpus
+        if torch.cuda.is_available() and cfg.general.gpus > 0
+        else None,
+        max_epochs=cfg.train.n_epochs,
+        enable_checkpointing=False,
+        check_val_every_n_epoch=cfg.train.check_val_every_n_epoch,
+        val_check_interval=cfg.train.val_check_interval,
+        strategy="ddp" if cfg.general.gpus > 1 else "auto",
+        enable_progress_bar=cfg.general.enable_progress_bar,
+        callbacks=[],
+        reload_dataloaders_every_n_epochs=0,
+        logger=[],
+    )
+
+    if not cfg.general.test_only:
+        print("start testing fit method")
+        trainer.fit(model, datamodule=datamodule, ckpt_path=cfg.general.resume)
+        if cfg.general.save_model:
+            trainer.save_checkpoint(f"checkpoints/{cfg.general.name}/last.ckpt")
+    trainer.test(model, datamodule=datamodule)
 
 if __name__ == "__main__":
     test()