update_name

gang liu 2024-05-25 15:32:36 -04:00
parent a6bd0117d4
commit 2c00828630
28 changed files with 178 additions and 19 deletions

.gitignore (new file)

@@ -0,0 +1,161 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
.python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
.DS_Store
.idea/
__pycache__/
dgd/configs/__pycache__/
egnn/__pycache__/
equivariant_diffusion/__pycache__/
outputs/
archives/qm9/__pycache__/
archives/qm9/data_utils/__pycache__/
archives/qm9/data_utils/prepare/__pycache__/
archives/qm9/property_prediction/__pycache__/
archives/*
.env
results/*.ckpt
results/qm9_molecules_h
results/qm9_molecules_noh
dgd/analysis/orca/orca
results/*
ggg_data/
ggg_utils/
saved_models
src/analysis/orca/orca
src/analysis/orca/tmp_XMYAR426.txt
# New
archive.zip
logs/
generated/
data/processed/

README.md

@@ -1,9 +1,9 @@
-Inverse Molecular Design with Multi-Conditional Diffusion Guidance
+Graph Diffusion Transformer for Multi-Conditional Molecular Generation
 ================================================================
 Paper: https://arxiv.org/abs/2401.13858
-This is the code for MCD: a Multi-Conditional Diffusion Model for inverse small molecule and polymer designs and generations. The denoising model architecture in `mcd/models` looks like:
+This is the code for Graph DiT. The denoising model architecture in `graph_dit/models` looks like:
 <div style="display: flex;" markdown="1">
 <img src="asset/reverse.png" style="width: 45%;" alt="Description of the first image">
@@ -16,7 +16,7 @@ All dependencies are specified in the `requirements.txt` file.
 This code was developed and tested with Python 3.9.16, PyTorch 2.0.0, and PyG 2.3.0, Pytorch-lightning 2.0.1.
-For molecular generation evaluation, we should first install rdkit:
+For molecular generation evaluation, we should first install rdkit.
 Then `fcd_torch`: `pip install fcd_torch` (https://github.com/insilicomedicine/fcd_torch).
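
A quick way to confirm the evaluation dependencies named above are importable; a minimal check, assuming the package names in the README and the FCD usage documented in fcd_torch:

```python
# Minimal import check for the evaluation dependencies named above.
from rdkit import Chem
from fcd_torch import FCD

assert Chem.MolFromSmiles("CCO") is not None   # RDKit can parse SMILES
fcd = FCD(device="cpu", n_jobs=1)              # Fréchet ChemNet Distance metric
# fcd(generated_smiles, reference_smiles) would return the distance between
# two sets of molecules; it approaches 0 for near-identical distributions.
```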

configs/config.yaml

@@ -1,5 +1,5 @@
 general:
-  name: 'MCD'
+  name: 'graph_dit'
   wandb: 'disabled'
   gpus: 1
   resume: null
@@ -14,11 +14,11 @@ general:
   final_model_samples_to_save: 20
   final_model_chains_to_save: 1
   enable_progress_bar: False
-  save_model: False
+  save_model: True
 model:
   type: 'discrete'
   transition: 'marginal'
-  model: 'MCD'
+  model: 'graph_dit'
   diffusion_steps: 500
   diffusion_noise_schedule: 'cosine'
   guide_scale: 2
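
For reference, `diffusion_noise_schedule: 'cosine'` conventionally refers to the Nichol & Dhariwal (2021) schedule; a minimal sketch under that assumption, with `timesteps` matching `diffusion_steps` above (the repo's own implementation may differ in details):

```python
import torch

def cosine_alpha_bar(timesteps: int = 500, s: float = 0.008) -> torch.Tensor:
    # Cumulative signal level alpha_bar(t) from Nichol & Dhariwal (2021).
    t = torch.arange(timesteps + 1)
    f = torch.cos((t / timesteps + s) / (1 + s) * torch.pi / 2) ** 2
    return f / f[0]                              # normalize so alpha_bar(0) = 1

alpha_bar = cosine_alpha_bar(500)
betas = 1 - alpha_bar[1:] / alpha_bar[:-1]       # per-step noise rates beta_t
```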

graph_dit/diffusion_model.py

@@ -12,7 +12,7 @@ from metrics.train_loss import TrainLossDiscrete
 from metrics.abstract_metrics import SumExceptBatchMetric, SumExceptBatchKL, NLL
 import utils
-class MCD(pl.LightningModule):
+class Graph_DiT(pl.LightningModule):
     def __init__(self, cfg, dataset_infos, train_metrics, sampling_metrics, visualization_tools):
         super().__init__()
         self.save_hyperparameters(ignore=['train_metrics', 'sampling_metrics'])
@@ -174,7 +174,6 @@ class MCD(pl.LightningModule):
     def validation_step(self, data, i):
         data_x = F.one_hot(data.x, num_classes=118).float()[:, self.active_index]
         data_edge_attr = F.one_hot(data.edge_attr, num_classes=5).float()
-
         dense_data, node_mask = utils.to_dense(data_x, data.edge_index, data_edge_attr, data.batch, self.max_n_nodes)
         dense_data = dense_data.mask(node_mask)
         noisy_data = self.apply_noise(dense_data.X, dense_data.E, data.y, node_mask)
@@ -281,7 +280,6 @@
         chains_left_to_save = self.cfg.general.final_model_chains_to_save
         samples, all_ys, batch_id = [], [], 0
         test_y_collection = torch.cat(self.test_y_collection, dim=0)
-
         num_examples = test_y_collection.size(0)
         if self.cfg.general.final_model_samples_to_generate > num_examples:
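
`validation_step` one-hot encodes atoms (118 element types) and bonds (5 bond types), then densifies the sparse PyG batch. A minimal sketch of what `utils.to_dense` plausibly does, built from standard PyG helpers (an assumption; the repo's own `utils` may differ):

```python
import torch
import torch.nn.functional as F
from torch_geometric.utils import to_dense_batch, to_dense_adj

def to_dense(x, edge_index, edge_attr, batch, max_num_nodes):
    # Pad per-graph node features to B x N x dx and build a dense
    # B x N x N x de edge tensor; node_mask marks real (non-padded) nodes.
    X, node_mask = to_dense_batch(x, batch, max_num_nodes=max_num_nodes)
    E = to_dense_adj(edge_index, batch, edge_attr, max_num_nodes=max_num_nodes)
    return X, E, node_mask

# One 3-atom molecule (C, C, O) with a single bond, padded to 9 nodes:
x = F.one_hot(torch.tensor([6, 6, 8]), num_classes=118).float()
edge_index = torch.tensor([[0, 1], [1, 0]])                  # undirected edge
edge_attr = F.one_hot(torch.tensor([1, 1]), num_classes=5).float()
batch = torch.zeros(3, dtype=torch.long)
X, E, node_mask = to_dense(x, edge_index, edge_attr, batch, max_num_nodes=9)
print(X.shape, E.shape, node_mask.shape)   # (1, 9, 118) (1, 9, 9, 5) (1, 9)
```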

graph_dit/main.py

@@ -9,7 +9,7 @@ from pytorch_lightning import Trainer
 import utils
 from datasets import dataset
-from diffusion_model import MCD
+from diffusion_model import Graph_DiT
 from metrics.molecular_metrics_train import TrainMolecularMetricsDiscrete
 from metrics.molecular_metrics_sampling import SamplingMolecularMetrics
@@ -36,7 +36,7 @@ def get_resume(cfg, model_kwargs):
     name = cfg.general.name + "_resume"
     resume = cfg.general.test_only
     batch_size = cfg.train.batch_size
-    model = MCD.load_from_checkpoint(resume, **model_kwargs)
+    model = Graph_DiT.load_from_checkpoint(resume, **model_kwargs)
     cfg = model.cfg
     cfg.general.test_only = resume
     cfg.general.name = name
@@ -54,7 +54,7 @@ def get_resume_adaptive(cfg, model_kwargs):
     resume_path = os.path.join(root_dir, cfg.general.resume)
     if cfg.model.type == "discrete":
-        model = MCD.load_from_checkpoint(
+        model = Graph_DiT.load_from_checkpoint(
             resume_path, **model_kwargs
         )
     else:
@@ -73,7 +73,7 @@ def get_resume_adaptive(cfg, model_kwargs):
 @hydra.main(
-    version_base="1.1", config_path="../configs", config_name="config_dev"
+    version_base="1.1", config_path="../configs", config_name="config"
 )
 def main(cfg: DictConfig):
@@ -106,7 +106,7 @@ def main(cfg: DictConfig):
         cfg, _ = get_resume_adaptive(cfg, model_kwargs)
         os.chdir(cfg.general.resume.split("checkpoints")[0])
-    model = MCD(cfg=cfg, **model_kwargs)
+    model = Graph_DiT(cfg=cfg, **model_kwargs)
     trainer = Trainer(
         gradient_clip_val=cfg.train.clip_grad,
         accelerator="gpu"
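
Both `get_resume` and `get_resume_adaptive` go through Lightning's checkpoint API; a hypothetical stand-alone sketch of that flow (the checkpoint path and the contents of `model_kwargs` are placeholders, not the repo's actual values):

```python
import pytorch_lightning as pl
from diffusion_model import Graph_DiT

# Placeholder kwargs; in main() these carry dataset_infos, train/sampling
# metrics, and visualization tools, which were excluded from
# save_hyperparameters() and so must be passed in again.
model_kwargs = {}

# Restores saved hyperparameters plus model weights from the checkpoint.
model = Graph_DiT.load_from_checkpoint("outputs/checkpoints/last.ckpt", **model_kwargs)

trainer = pl.Trainer(accelerator="gpu", devices=1)
trainer.test(model)
```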

graph_dit/models/transformer.py

@@ -44,7 +44,7 @@ class Denoiser(nn.Module):
             ]
         )
-        self.decoder = Decoder(
+        self.out_layer = OutLayer(
             max_n_nodes=max_n_nodes,
             hidden_size=hidden_size,
             atom_type=Xdim,
@@ -73,7 +73,7 @@ class Denoiser(nn.Module):
         for block in self.encoders :
             _constant_init(block.adaLN_modulation[0], 0)
-        _constant_init(self.decoder.adaLN_modulation[0], 0)
+        _constant_init(self.out_layer.adaLN_modulation[0], 0)
     def forward(self, x, e, node_mask, y, t, unconditioned):
@@ -99,7 +99,7 @@ class Denoiser(nn.Module):
             x = block(x, c, node_mask)
         # X: B * N * dx, E: B * N * N * de
-        X, E, y = self.decoder(x, x_in, e_in, c, t, node_mask)
+        X, E, y = self.out_layer(x, x_in, e_in, c, t, node_mask)
         return utils.PlaceHolder(X=X, E=E, y=y).mask(node_mask)
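
The `unconditioned` flag in `forward`, together with `guide_scale: 2` in the config, points to classifier-free guidance. A sketch of the usual combination rule (an assumption about how the repo uses the flag; it may instead mix probabilities after a softmax):

```python
def guided_prediction(denoiser, x, e, node_mask, y, t, guide_scale=2.0):
    # Classifier-free guidance: run the denoiser with and without the
    # condition y, then extrapolate toward the conditional prediction.
    cond = denoiser(x, e, node_mask, y, t, unconditioned=False)
    uncond = denoiser(x, e, node_mask, y, t, unconditioned=True)
    X = uncond.X + guide_scale * (cond.X - uncond.X)
    E = uncond.E + guide_scale * (cond.E - uncond.E)
    return X, E
```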
@@ -140,8 +140,8 @@ class SELayer(nn.Module):
         return x
-class Decoder(nn.Module):
-    # Structure Decoder
+class OutLayer(nn.Module):
+    # Structure Output Layer
     def __init__(self, max_n_nodes, hidden_size, atom_type, bond_type, mlp_ratio, num_heads=None):
         super().__init__()
         self.atom_type = atom_type
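
The `_constant_init(..., 0)` calls above zero the first layer of each block's `adaLN_modulation`, the DiT-style "adaLN-Zero" initialization: every block starts as an identity map, which stabilizes early training. A minimal sketch of the pattern (the block layout is illustrative, not the repo's exact Denoiser/OutLayer code):

```python
import torch
import torch.nn as nn

def _constant_init(module, value=0.0):
    # Zero weights and bias so shift, scale, and gate all start at 0.
    nn.init.constant_(module.weight, value)
    if module.bias is not None:
        nn.init.constant_(module.bias, value)

class AdaLNZeroBlock(nn.Module):
    def __init__(self, hidden_size):
        super().__init__()
        self.norm = nn.LayerNorm(hidden_size, elementwise_affine=False)
        self.mlp = nn.Linear(hidden_size, hidden_size)
        # Conditioning vector c -> per-feature shift, scale, and gate.
        self.adaLN_modulation = nn.Sequential(nn.Linear(hidden_size, 3 * hidden_size))
        _constant_init(self.adaLN_modulation[0], 0)

    def forward(self, x, c):
        shift, scale, gate = self.adaLN_modulation(c).chunk(3, dim=-1)
        # With zero init, gate = 0, so the block initially returns x unchanged.
        return x + gate * self.mlp(self.norm(x) * (1 + scale) + shift)
```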