update_name

commit 2c00828630
parent a6bd0117d4

.gitignore (vendored, new file, 161 lines)
@@ -0,0 +1,161 @@
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+pip-wheel-metadata/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+#  Usually these files are written by a python script from a template
+#  before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+.python-version
+
+# pipenv
+#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+#   However, in case of collaboration, if having platform-specific dependencies or dependencies
+#   having no cross-platform support, pipenv may install dependencies that don't work, or not
+#   install all needed dependencies.
+#Pipfile.lock
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+
+.DS_Store
+.idea/
+__pycache__/
+dgd/configs/__pycache__/
+egnn/__pycache__/
+equivariant_diffusion/__pycache__/
+outputs/
+archives/qm9/__pycache__/
+archives/qm9/data_utils/__pycache__/
+archives/qm9/data_utils/prepare/__pycache__/
+archives/qm9/property_prediction/__pycache__/
+archives/*
+.env
+results/*.ckpt
+results/qm9_molecules_h
+results/qm9_molecules_noh
+dgd/analysis/orca/orca
+results/*
+ggg_data/
+ggg_utils/
+saved_models
+src/analysis/orca/orca
+src/analysis/orca/tmp_XMYAR426.txt
+
+
+# New
+archive.zip
+logs/
+generated/
+data/processed/

README.md
@@ -1,9 +1,9 @@
-Inverse Molecular Design with Multi-Conditional Diffusion Guidance
+Graph Diffusion Transformer for Multi-Conditional Molecular Generation
 ================================================================
 
 Paper: https://arxiv.org/abs/2401.13858
 
-This is the code for MCD: a Multi-Conditional Diffusion Model for inverse small molecule and polymer designs and generations. The denoising model architecture in `mcd/models` looks like:
+This is the code for Graph DiT. The denoising model architecture in `graph_dit/models` looks like:
 
 <div style="display: flex;" markdown="1">
 <img src="asset/reverse.png" style="width: 45%;" alt="Description of the first image">
@@ -16,7 +16,7 @@ All dependencies are specified in the `requirements.txt` file.
 
 This code was developed and tested with Python 3.9.16, PyTorch 2.0.0, and PyG 2.3.0, Pytorch-lightning 2.0.1.
 
-For molecular generation evaluation, we should first install rdkit:
+For molecular generation evaluation, we should first install rdkit.
 
 Then `fcd_torch`: `pip install fcd_torch` (https://github.com/insilicomedicine/fcd_torch).
 
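
Note: as a reading aid for the rdkit / `fcd_torch` evaluation dependencies named above, a minimal usage sketch follows. The SMILES lists are hypothetical placeholders; the repo's actual evaluation lives in its metrics modules.

from rdkit import Chem
from fcd_torch import FCD

# Hypothetical reference and generated molecules (placeholders, not repo data).
ref_smiles = ["CCO", "c1ccccc1", "CC(=O)O"]
gen_smiles = ["CCN", "c1ccncc1", "CC(=O)N"]

# rdkit parses SMILES; invalid generations parse to None and are dropped.
valid = [s for s in gen_smiles if Chem.MolFromSmiles(s) is not None]

# fcd_torch computes the ChemNet-based Frechet ChemNet Distance between sets.
fcd = FCD(device="cpu", n_jobs=1)
print(fcd(ref_smiles, valid))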

configs/config.yaml
@@ -1,5 +1,5 @@
 general:
-  name: 'MCD'
+  name: 'graph_dit'
   wandb: 'disabled'
   gpus: 1
   resume: null
@@ -14,11 +14,11 @@ general:
   final_model_samples_to_save: 20
   final_model_chains_to_save: 1
   enable_progress_bar: False
-  save_model: False
+  save_model: True
 model:
   type: 'discrete'
   transition: 'marginal'
-  model: 'MCD'
+  model: 'graph_dit'
   diffusion_steps: 500
   diffusion_noise_schedule: 'cosine'
   guide_scale: 2
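
Note: a minimal sketch of how a Hydra-style YAML like the one above is consumed, assuming it lives at `configs/config.yaml` (the path is an assumption inferred from `config_path="../configs"` in the entry script below).

from omegaconf import OmegaConf

# Load the YAML directly; @hydra.main composes the same structure at runtime.
cfg = OmegaConf.load("configs/config.yaml")

print(cfg.general.name)           # 'graph_dit' after this commit
print(cfg.general.save_model)     # True: final checkpoints are now kept
print(cfg.model.diffusion_steps)  # 500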

diffusion_model.py
@@ -12,7 +12,7 @@ from metrics.train_loss import TrainLossDiscrete
 from metrics.abstract_metrics import SumExceptBatchMetric, SumExceptBatchKL, NLL
 import utils
 
-class MCD(pl.LightningModule):
+class Graph_DiT(pl.LightningModule):
     def __init__(self, cfg, dataset_infos, train_metrics, sampling_metrics, visualization_tools):
         super().__init__()
         self.save_hyperparameters(ignore=['train_metrics', 'sampling_metrics'])
@@ -174,7 +174,6 @@ class MCD(pl.LightningModule):
     def validation_step(self, data, i):
         data_x = F.one_hot(data.x, num_classes=118).float()[:, self.active_index]
         data_edge_attr = F.one_hot(data.edge_attr, num_classes=5).float()
-
         dense_data, node_mask = utils.to_dense(data_x, data.edge_index, data_edge_attr, data.batch, self.max_n_nodes)
         dense_data = dense_data.mask(node_mask)
         noisy_data = self.apply_noise(dense_data.X, dense_data.E, data.y, node_mask)
@@ -281,7 +280,6 @@ class MCD(pl.LightningModule):
         chains_left_to_save = self.cfg.general.final_model_chains_to_save
-
         samples, all_ys, batch_id = [], [], 0
 
         test_y_collection = torch.cat(self.test_y_collection, dim=0)
         num_examples = test_y_collection.size(0)
         if self.cfg.general.final_model_samples_to_generate > num_examples:
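
Note: the `validation_step` context above one-hot encodes atoms over the full periodic table and then slices to the dataset's active atom types. A self-contained sketch (the `active_index` values are hypothetical):

import torch
import torch.nn.functional as F

atom_ids = torch.tensor([6, 7, 8])          # a tiny hypothetical graph: C, N, O
x_onehot = F.one_hot(atom_ids, num_classes=118).float()   # shape (3, 118)

active_index = [1, 6, 7, 8]                 # hypothetical dataset atom types
x_active = x_onehot[:, active_index]        # shape (3, 4), as in [:, self.active_index]

bond_ids = torch.tensor([0, 1, 2])          # bond-type indices
e_onehot = F.one_hot(bond_ids, num_classes=5).float()     # shape (3, 5)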

main.py
@@ -9,7 +9,7 @@ from pytorch_lightning import Trainer
 
 import utils
 from datasets import dataset
-from diffusion_model import MCD
+from diffusion_model import Graph_DiT
 from metrics.molecular_metrics_train import TrainMolecularMetricsDiscrete
 from metrics.molecular_metrics_sampling import SamplingMolecularMetrics
 
@@ -36,7 +36,7 @@ def get_resume(cfg, model_kwargs):
     name = cfg.general.name + "_resume"
     resume = cfg.general.test_only
     batch_size = cfg.train.batch_size
-    model = MCD.load_from_checkpoint(resume, **model_kwargs)
+    model = Graph_DiT.load_from_checkpoint(resume, **model_kwargs)
     cfg = model.cfg
     cfg.general.test_only = resume
     cfg.general.name = name
@@ -54,7 +54,7 @@ def get_resume_adaptive(cfg, model_kwargs):
     resume_path = os.path.join(root_dir, cfg.general.resume)
 
     if cfg.model.type == "discrete":
-        model = MCD.load_from_checkpoint(
+        model = Graph_DiT.load_from_checkpoint(
             resume_path, **model_kwargs
         )
     else:
@@ -73,7 +73,7 @@ def get_resume_adaptive(cfg, model_kwargs):
 
 
 @hydra.main(
-    version_base="1.1", config_path="../configs", config_name="config_dev"
+    version_base="1.1", config_path="../configs", config_name="config"
 )
 def main(cfg: DictConfig):
 
@@ -106,7 +106,7 @@ def main(cfg: DictConfig):
         cfg, _ = get_resume_adaptive(cfg, model_kwargs)
         os.chdir(cfg.general.resume.split("checkpoints")[0])
 
-    model = MCD(cfg=cfg, **model_kwargs)
+    model = Graph_DiT(cfg=cfg, **model_kwargs)
     trainer = Trainer(
         gradient_clip_val=cfg.train.clip_grad,
         accelerator="gpu"
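
Note: the `load_from_checkpoint` calls above are the standard PyTorch Lightning resume pattern: hyperparameters stored via `save_hyperparameters()` are restored from the checkpoint, and the remaining constructor kwargs are passed through. A hedged sketch (the checkpoint path and empty kwargs are placeholders):

from diffusion_model import Graph_DiT

ckpt_path = "checkpoints/last.ckpt"  # hypothetical path
model_kwargs = {}  # main.py fills this with dataset_infos, metrics, etc.
model = Graph_DiT.load_from_checkpoint(ckpt_path, **model_kwargs)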

@@ -44,7 +44,7 @@ class Denoiser(nn.Module):
             ]
         )
 
-        self.decoder = Decoder(
+        self.out_layer = OutLayer(
             max_n_nodes=max_n_nodes,
             hidden_size=hidden_size,
             atom_type=Xdim,
@@ -73,7 +73,7 @@ class Denoiser(nn.Module):
 
         for block in self.encoders :
             _constant_init(block.adaLN_modulation[0], 0)
-        _constant_init(self.decoder.adaLN_modulation[0], 0)
+        _constant_init(self.out_layer.adaLN_modulation[0], 0)
 
     def forward(self, x, e, node_mask, y, t, unconditioned):
 
@@ -99,7 +99,7 @@ class Denoiser(nn.Module):
             x = block(x, c, node_mask)
 
         # X: B * N * dx, E: B * N * N * de
-        X, E, y = self.decoder(x, x_in, e_in, c, t, node_mask)
+        X, E, y = self.out_layer(x, x_in, e_in, c, t, node_mask)
         return utils.PlaceHolder(X=X, E=E, y=y).mask(node_mask)
 
 
@@ -140,8 +140,8 @@ class SELayer(nn.Module):
         return x
 
 
-class Decoder(nn.Module):
-    # Structure Decoder
+class OutLayer(nn.Module):
+    # Structure Output Layer
     def __init__(self, max_n_nodes, hidden_size, atom_type, bond_type, mlp_ratio, num_heads=None):
         super().__init__()
         self.atom_type = atom_type
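
Note: the `_constant_init(..., 0)` loop above follows the DiT convention of zero-initializing each adaLN modulation layer so every block starts as a near-identity mapping, which stabilizes early training. `_constant_init` itself is not shown in this diff; the following is a plausible sketch, not the repo's definition.

import torch.nn as nn

def _constant_init(module: nn.Module, value: float) -> None:
    # Fill weight and bias so the modulation initially outputs `value`
    # everywhere; with value=0 the predicted scale/shift/gate start as a no-op.
    if getattr(module, "weight", None) is not None:
        nn.init.constant_(module.weight, value)
    if getattr(module, "bias", None) is not None:
        nn.init.constant_(module.bias, value)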