update the main function
This commit is contained in:
		| @@ -11,9 +11,13 @@ import utils | ||||
| from datasets import dataset | ||||
| from diffusion_model import Graph_DiT | ||||
| from metrics.molecular_metrics_train import TrainMolecularMetricsDiscrete | ||||
| from metrics.molecular_metrics_train import TrainGraphMetricsDiscrete | ||||
| from metrics.molecular_metrics_sampling import SamplingMolecularMetrics | ||||
| from metrics.molecular_metrics_sampling import SamplingGraphMetrics | ||||
|  | ||||
|  | ||||
| from analysis.visualization import MolecularVisualization | ||||
| from analysis.visualization import GraphVisualization | ||||
|  | ||||
| warnings.filterwarnings("ignore", category=UserWarning) | ||||
| torch.set_float32_matmul_precision("medium") | ||||
| @@ -79,19 +83,20 @@ def main(cfg: DictConfig): | ||||
|     datamodule = dataset.DataModule(cfg) | ||||
|     datamodule.prepare_data() | ||||
|     dataset_infos = dataset.DataInfos(datamodule=datamodule, cfg=cfg, dataset=datamodule.dataset) | ||||
|     # train_smiles, reference_smiles = datamodule.get_train_smiles() | ||||
|     train_graphs, reference_graphs = datamodule.get_train_graphs() | ||||
|     train_smiles, reference_smiles = datamodule.get_train_smiles() | ||||
|     # train_graphs, reference_graphs = datamodule.get_train_graphs() | ||||
|  | ||||
|     # get input output dimensions | ||||
|     dataset_infos.compute_input_output_dims(datamodule=datamodule) | ||||
|     train_metrics = TrainMolecularMetricsDiscrete(dataset_infos) | ||||
|     # train_metrics = TrainGraphMetricsDiscrete(dataset_infos) | ||||
|  | ||||
|     # sampling_metrics = SamplingMolecularMetrics( | ||||
|     #     dataset_infos, train_smiles, reference_smiles | ||||
|     # ) | ||||
|     sampling_metrics = SamplingGraphMetrics( | ||||
|         dataset_infos, train_graphs, reference_graphs | ||||
|     sampling_metrics = SamplingMolecularMetrics( | ||||
|         dataset_infos, train_smiles, reference_smiles | ||||
|     ) | ||||
|     # sampling_metrics = SamplingGraphMetrics( | ||||
|     #     dataset_infos, train_graphs, reference_graphs | ||||
|     # ) | ||||
|     visualization_tools = MolecularVisualization(dataset_infos) | ||||
|  | ||||
|     model_kwargs = { | ||||
| @@ -149,6 +154,54 @@ def test(cfg: DictConfig): | ||||
|     train_graphs, reference_graphs = datamodule.get_train_graphs() | ||||
|  | ||||
|     dataset_infos.compute_input_output_dims(datamodule=datamodule) | ||||
|     train_metrics = TrainGraphMetricsDiscrete(dataset_infos) | ||||
|  | ||||
|     sampling_metrics = SamplingGraphMetrics( | ||||
|         dataset_infos, train_graphs, reference_graphs | ||||
|     ) | ||||
|  | ||||
|     visulization_tools = GraphVisualization(dataset_infos) | ||||
|  | ||||
|     model_kwargs = { | ||||
|         "dataset_infos": dataset_infos, | ||||
|         "train_metrics": train_metrics, | ||||
|         "sampling_metrics": sampling_metrics, | ||||
|         "visualization_tools": visulization_tools, | ||||
|     } | ||||
|  | ||||
|     if cfg.general.test_only: | ||||
|         cfg, _ = get_resume(cfg, model_kwargs) | ||||
|         os.chdir(cfg.general.test_only.split("checkpoints")[0]) | ||||
|     elif cfg.general.resume is not None: | ||||
|         cfg, _ = get_resume_adaptive(cfg, model_kwargs) | ||||
|         os.chdir(cfg.general.resume.split("checkpoints")[0]) | ||||
|     model = Graph_DiT(cfg=cfg, **model_kwargs) | ||||
|     trainer = Trainer( | ||||
|         gradient_clip_val=cfg.train.clip_grad, | ||||
|         # accelerator="cpu", | ||||
|         accelerator="gpu" | ||||
|         if torch.cuda.is_available() and cfg.general.gpus > 0 | ||||
|         else "cpu", | ||||
|         devices=cfg.general.gpus | ||||
|         if torch.cuda.is_available() and cfg.general.gpus > 0 | ||||
|         else None, | ||||
|         max_epochs=cfg.train.n_epochs, | ||||
|         enable_checkpointing=False, | ||||
|         check_val_every_n_epoch=cfg.train.check_val_every_n_epoch, | ||||
|         val_check_interval=cfg.train.val_check_interval, | ||||
|         strategy="ddp" if cfg.general.gpus > 1 else "auto", | ||||
|         enable_progress_bar=cfg.general.enable_progress_bar, | ||||
|         callbacks=[], | ||||
|         reload_dataloaders_every_n_epochs=0, | ||||
|         logger=[], | ||||
|     ) | ||||
|  | ||||
|     if not cfg.general.test_only: | ||||
|         print("start testing fit method") | ||||
|         trainer.fit(model, datamodule=datamodule, ckpt_path=cfg.general.resume) | ||||
|         if cfg.general.save_model: | ||||
|             trainer.save_checkpoint(f"checkpoints/{cfg.general.name}/last.ckpt") | ||||
|         trainer.test(model, datamodule=datamodule) | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     test() | ||||
|   | ||||
		Reference in New Issue
	
	Block a user