init_code

commit 91727d2500 (parent 353d892291)
@@ -14,7 +14,7 @@ This is the code for MCD: a Multi-Conditional Diffusion Model for inverse small

 ## Requirements

 All dependencies are specified in the `requirements.txt` file.

-This code was developed and tested with Python 3.9.16, PyTorch 2.0.0, and PyG 2.3.0, pytorch-lightning 2.0.1.
+This code was developed and tested with Python 3.9.16, PyTorch 2.0.0, and PyG 2.3.0, Pytorch-lightning 2.0.1.

 For molecular generation evaluation, we should first install rdkit:
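The install command itself falls outside this hunk, so it is not shown here. As a hedged aside, a quick way to confirm RDKit is importable after installing it (for example via `pip install rdkit` or a conda package, an assumption rather than the project's documented command) is:

```python
# Minimal sanity check that RDKit is available; the exact install command is an
# assumption, not part of this commit.
from rdkit import Chem

print(Chem.MolToSmiles(Chem.MolFromSmiles("c1ccccc1")))  # prints a canonical benzene SMILES
```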
configs/__init__.py (new file, 0 lines)
configs/config.yaml (new file, 46 lines)
@@ -0,0 +1,46 @@
general:
  name: 'MCD'
  wandb: 'disabled'
  gpus: 1
  resume: null
  test_only: null
  sample_every_val: 2500
  samples_to_generate: 512
  samples_to_save: 3
  chains_to_save: 1
  log_every_steps: 50
  number_chain_steps: 8
  final_model_samples_to_generate: 10000
  final_model_samples_to_save: 20
  final_model_chains_to_save: 1
  enable_progress_bar: False
  save_model: False
model:
  type: 'discrete'
  transition: 'marginal'
  model: 'MCD'
  diffusion_steps: 500
  diffusion_noise_schedule: 'cosine'
  guide_scale: 2
  hidden_size: 1152
  depth: 6
  num_heads: 16
  mlp_ratio: 4
  drop_condition: 0.01
  lambda_train: [1, 10]  # node and edge training weight
  ensure_connected: True
train:
  n_epochs: 10000
  batch_size: 1200
  lr: 0.0002
  clip_grad: null
  num_workers: 0
  weight_decay: 0
  seed: 0
  val_check_interval: null
  check_val_every_n_epoch: 1
dataset:
  datadir: 'data/'
  task_name: null
  guidance_target: null
  pin_memory: False
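For orientation, the nested keys above are read in the code as attributes (for example `cfg.train.batch_size`, `cfg.model.ensure_connected`, `cfg.dataset.datadir`). The sketch below loads the file with OmegaConf; whether the project actually wires this up through Hydra or plain OmegaConf is an assumption, not something this commit shows.

```python
# Hedged sketch: load configs/config.yaml into an attribute-style config object.
from omegaconf import OmegaConf

cfg = OmegaConf.load("configs/config.yaml")
print(cfg.model.diffusion_steps)   # 500
print(cfg.train.batch_size)        # 1200
print(cfg.general.name)            # 'MCD'
```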
mcd/__init__.py (new file, 0 lines)
mcd/analysis/__init__.py (new file, 0 lines)
mcd/analysis/rdkit_functions.py (new file, 411 lines)
@@ -0,0 +1,411 @@
from rdkit import Chem, RDLogger
RDLogger.DisableLog('rdApp.*')
from fcd_torch import FCD as FCDMetric
from mini_moses.metrics.metrics import FragMetric, internal_diversity
from mini_moses.metrics.utils import get_mol, mapper

import re
import time
import random
random.seed(0)
import numpy as np
from multiprocessing import Pool

import torch
from metrics.property_metric import calculateSAS

bond_dict = [None, Chem.rdchem.BondType.SINGLE, Chem.rdchem.BondType.DOUBLE, Chem.rdchem.BondType.TRIPLE,
             Chem.rdchem.BondType.AROMATIC]
ATOM_VALENCY = {6: 4, 7: 3, 8: 2, 9: 1, 15: 3, 16: 2, 17: 1, 35: 1, 53: 1}

bd_dict_x = {'O2-N2': [5.00E+04, 1.00E-03]}
bd_dict_y = {'O2-N2': [5.00E+04/2.78E+04, 1.00E-03/2.43E-05]}

selectivity = ['O2-N2']
a_dict = {}
b_dict = {}
for prop_name in selectivity:
    x1, x2 = np.log10(bd_dict_x[prop_name][0]), np.log10(bd_dict_x[prop_name][1])
    y1, y2 = np.log10(bd_dict_y[prop_name][0]), np.log10(bd_dict_y[prop_name][1])
    a = (y1 - y2) / (x1 - x2)
    b = y1 - a * x1
    a_dict[prop_name] = a
    b_dict[prop_name] = b


def selectivity_evaluation(gas1, gas2, prop_name):
    x = np.log10(np.array(gas1))
    y = np.log10(np.array(gas1) / np.array(gas2))
    upper = (y - (a_dict[prop_name] * x + b_dict[prop_name])) > 0
    return upper
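`selectivity_evaluation` checks, in log-log space, whether each candidate lies above the straight line through the two reference points in `bd_dict_x`/`bd_dict_y` (a Robeson-style upper bound for the O2/N2 permeability trade-off). A hedged usage sketch with made-up permeability values:

```python
# Illustrative values only; real permeabilities come from the task evaluators.
o2 = [100.0, 5000.0]   # O2 permeability of two hypothetical candidates
n2 = [20.0, 100.0]     # N2 permeability of the same candidates
above_bound = selectivity_evaluation(o2, n2, 'O2-N2')
print(above_bound)     # boolean array, True where a candidate beats the upper bound
```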
|
||||
|
||||
class BasicMolecularMetrics(object):
|
||||
def __init__(self, atom_decoder, train_smiles=None, stat_ref=None, task_evaluator=None, n_jobs=8, device='cpu', batch_size=512):
|
||||
self.dataset_smiles_list = train_smiles
|
||||
self.atom_decoder = atom_decoder
|
||||
self.n_jobs = n_jobs
|
||||
self.device = device
|
||||
self.batch_size = batch_size
|
||||
self.stat_ref = stat_ref
|
||||
self.task_evaluator = task_evaluator
|
||||
|
||||
def compute_relaxed_validity(self, generated, ensure_connected):
|
||||
valid = []
|
||||
num_components = []
|
||||
all_smiles = []
|
||||
valid_mols = []
|
||||
covered_atoms = set()
|
||||
direct_valid_count = 0
|
||||
for graph in generated:
|
||||
atom_types, edge_types = graph
|
||||
mol = build_molecule_with_partial_charges(atom_types, edge_types, self.atom_decoder)
|
||||
direct_valid_flag = True if check_mol(mol, largest_connected_comp=True) is not None else False
|
||||
if direct_valid_flag:
|
||||
direct_valid_count += 1
|
||||
if not ensure_connected:
|
||||
mol_conn, _ = correct_mol(mol, connection=True)
|
||||
mol = mol_conn if mol_conn is not None else correct_mol(mol, connection=False)[0]
|
||||
else: # ensure fully connected
|
||||
mol, _ = correct_mol(mol, connection=True)
|
||||
smiles = mol2smiles(mol)
|
||||
mol = get_mol(smiles)
|
||||
try:
|
||||
mol_frags = Chem.rdmolops.GetMolFrags(mol, asMols=True, sanitizeFrags=True)
|
||||
num_components.append(len(mol_frags))
|
||||
largest_mol = max(mol_frags, default=mol, key=lambda m: m.GetNumAtoms())
|
||||
smiles = mol2smiles(largest_mol)
|
||||
if smiles is not None and largest_mol is not None and len(smiles) > 1 and Chem.MolFromSmiles(smiles) is not None:
|
||||
valid_mols.append(largest_mol)
|
||||
valid.append(smiles)
|
||||
for atom in largest_mol.GetAtoms():
|
||||
covered_atoms.add(atom.GetSymbol())
|
||||
all_smiles.append(smiles)
|
||||
else:
|
||||
all_smiles.append(None)
|
||||
except Exception as e:
|
||||
# print(f"An error occurred: {e}")
|
||||
all_smiles.append(None)
|
||||
|
||||
return valid, len(valid) / len(generated), direct_valid_count / len(generated), np.array(num_components), all_smiles, covered_atoms
|
||||
|
||||
def evaluate(self, generated, targets, ensure_connected, active_atoms=None):
|
||||
""" generated: list of pairs (positions: n x 3, atom_types: n [int])
|
||||
the positions and atom types should already be masked. """
|
||||
valid, validity, nc_validity, num_components, all_smiles, covered_atoms = self.compute_relaxed_validity(generated, ensure_connected=ensure_connected)
|
||||
nc_mu = num_components.mean() if len(num_components) > 0 else 0
|
||||
nc_min = num_components.min() if len(num_components) > 0 else 0
|
||||
nc_max = num_components.max() if len(num_components) > 0 else 0
|
||||
|
||||
len_active = len(active_atoms) if active_atoms is not None else 1
|
||||
|
||||
cover_str = f"Cover {len(covered_atoms)} ({len(covered_atoms)/len_active * 100:.2f}%) atoms: {covered_atoms}"
|
||||
print(f"Validity over {len(generated)} molecules: {validity * 100 :.2f}% (w/o correction: {nc_validity * 100 :.2f}%), cover {len(covered_atoms)} ({len(covered_atoms)/len_active * 100:.2f}%) atoms: {covered_atoms}")
|
||||
print(f"Number of connected components of {len(generated)} molecules: min:{nc_min:.2f} mean:{nc_mu:.2f} max:{nc_max:.2f}")
|
||||
|
||||
if validity > 0:
|
||||
dist_metrics = {'cover_str': cover_str ,'validity': validity, 'validity_nc': nc_validity}
|
||||
unique = list(set(valid))
|
||||
close_pool = False
|
||||
if self.n_jobs != 1:
|
||||
pool = Pool(self.n_jobs)
|
||||
close_pool = True
|
||||
else:
|
||||
pool = 1
|
||||
valid_mols = mapper(pool)(get_mol, valid)
|
||||
dist_metrics['internal_diversity'] = internal_diversity(valid_mols, pool, device=self.device)
|
||||
|
||||
start_time = time.time()
|
||||
if self.stat_ref is not None:
|
||||
kwargs = {'n_jobs': pool, 'device': self.device, 'batch_size': self.batch_size}
|
||||
kwargs_fcd = {'n_jobs': self.n_jobs, 'device': self.device, 'batch_size': self.batch_size}
|
||||
try:
|
||||
dist_metrics['sim/Frag'] = FragMetric(**kwargs)(gen=valid_mols, pref=self.stat_ref['Frag'])
|
||||
except:
|
||||
print('error: ', 'pool', pool)
|
||||
print('valid_mols: ', valid_mols)
|
||||
dist_metrics['dist/FCD'] = FCDMetric(**kwargs_fcd)(gen=valid, pref=self.stat_ref['FCD'])
|
||||
|
||||
if self.task_evaluator is not None:
|
||||
evaluation_list = list(self.task_evaluator.keys())
|
||||
evaluation_list = evaluation_list.copy()
|
||||
|
||||
assert 'meta_taskname' in evaluation_list
|
||||
meta_taskname = self.task_evaluator['meta_taskname']
|
||||
evaluation_list.remove('meta_taskname')
|
||||
meta_split = meta_taskname.split('-')
|
||||
|
||||
valid_index = np.array([True if smiles else False for smiles in all_smiles])
|
||||
targets_log = {}
|
||||
for i, name in enumerate(evaluation_list):
|
||||
targets_log[f'input_{name}'] = np.array([float('nan')] * len(valid_index))
|
||||
targets_log[f'input_{name}'] = targets[:, i]
|
||||
|
||||
targets = targets[valid_index]
|
||||
if len(meta_split) == 2:
|
||||
cached_perm = {meta_split[0]: None, meta_split[1]: None}
|
||||
|
||||
for i, name in enumerate(evaluation_list):
|
||||
if name == 'scs':
|
||||
continue
|
||||
elif name == 'sas':
|
||||
scores = calculateSAS(valid)
|
||||
else:
|
||||
scores = self.task_evaluator[name](valid)
|
||||
targets_log[f'output_{name}'] = np.array([float('nan')] * len(valid_index))
|
||||
targets_log[f'output_{name}'][valid_index] = scores
|
||||
if name in ['O2', 'N2', 'CO2']:
|
||||
if len(meta_split) == 2:
|
||||
cached_perm[name] = scores
|
||||
scores, cur_targets = np.log10(scores), np.log10(targets[:, i])
|
||||
dist_metrics[f'{name}/mae'] = np.mean(np.abs(scores - cur_targets))
|
||||
elif name == 'sas':
|
||||
dist_metrics[f'{name}/mae'] = np.mean(np.abs(scores - targets[:, i]))
|
||||
else:
|
||||
true_y = targets[:, i]
|
||||
predicted_labels = (scores >= 0.5).astype(int)
|
||||
acc = (predicted_labels == true_y).sum() / len(true_y)
|
||||
dist_metrics[f'{name}/acc'] = acc
|
||||
|
||||
if len(meta_split) == 2:
|
||||
if cached_perm[meta_split[0]] is not None and cached_perm[meta_split[1]] is not None:
|
||||
task_name = self.task_evaluator['meta_taskname']
|
||||
upper = selectivity_evaluation(cached_perm[meta_split[0]], cached_perm[meta_split[1]], task_name)
|
||||
dist_metrics[f'selectivity/{task_name}'] = np.sum(upper)
|
||||
|
||||
end_time = time.time()
|
||||
elapsed_time = end_time - start_time
|
||||
max_key_length = max(len(key) for key in dist_metrics)
|
||||
print(f'Details over {len(valid)} ({len(generated)}) valid (total) molecules, calculating metrics using {elapsed_time:.2f} s:')
|
||||
strs = ''
|
||||
for i, (key, value) in enumerate(dist_metrics.items()):
|
||||
if isinstance(value, (int, float, np.floating, np.integer)):
|
||||
strs = strs + f'{key:>{max_key_length}}:{value:<7.4f}\t'
|
||||
if i % 4 == 3:
|
||||
strs = strs + '\n'
|
||||
print(strs)
|
||||
|
||||
if close_pool:
|
||||
pool.close()
|
||||
pool.join()
|
||||
else:
|
||||
unique = []
|
||||
dist_metrics = {}
|
||||
targets_log = None
|
||||
return unique, dict(nc_min=nc_min, nc_max=nc_max, nc_mu=nc_mu), all_smiles, dist_metrics, targets_log
|
||||
|
||||
def mol2smiles(mol):
|
||||
if mol is None:
|
||||
return None
|
||||
try:
|
||||
Chem.SanitizeMol(mol)
|
||||
except ValueError:
|
||||
return None
|
||||
return Chem.MolToSmiles(mol)
|
||||
|
||||
def build_molecule_with_partial_charges(atom_types, edge_types, atom_decoder, verbose=False):
|
||||
if verbose:
|
||||
print("\nbuilding new molecule")
|
||||
|
||||
mol = Chem.RWMol()
|
||||
for atom in atom_types:
|
||||
a = Chem.Atom(atom_decoder[atom.item()])
|
||||
mol.AddAtom(a)
|
||||
if verbose:
|
||||
print("Atom added: ", atom.item(), atom_decoder[atom.item()])
|
||||
|
||||
edge_types = torch.triu(edge_types)
|
||||
all_bonds = torch.nonzero(edge_types)
|
||||
|
||||
for i, bond in enumerate(all_bonds):
|
||||
if bond[0].item() != bond[1].item():
|
||||
mol.AddBond(bond[0].item(), bond[1].item(), bond_dict[edge_types[bond[0], bond[1]].item()])
|
||||
if verbose:
|
||||
print("bond added:", bond[0].item(), bond[1].item(), edge_types[bond[0], bond[1]].item(),
|
||||
bond_dict[edge_types[bond[0], bond[1]].item()])
|
||||
# add formal charge to atom: e.g. [O+], [N+], [S+]
|
||||
# not support [O-], [N-], [S-], [NH+] etc.
|
||||
flag, atomid_valence = check_valency(mol)
|
||||
if verbose:
|
||||
print("flag, valence", flag, atomid_valence)
|
||||
if flag:
|
||||
continue
|
||||
else:
|
||||
if len(atomid_valence) == 2:
|
||||
idx = atomid_valence[0]
|
||||
v = atomid_valence[1]
|
||||
an = mol.GetAtomWithIdx(idx).GetAtomicNum()
|
||||
if verbose:
|
||||
print("atomic num of atom with a large valence", an)
|
||||
if an in (7, 8, 16) and (v - ATOM_VALENCY[an]) == 1:
|
||||
mol.GetAtomWithIdx(idx).SetFormalCharge(1)
|
||||
# print("Formal charge added")
|
||||
else:
|
||||
continue
|
||||
return mol
|
||||
|
||||
def check_valency(mol):
|
||||
try:
|
||||
Chem.SanitizeMol(mol, sanitizeOps=Chem.SanitizeFlags.SANITIZE_PROPERTIES)
|
||||
return True, None
|
||||
except ValueError as e:
|
||||
e = str(e)
|
||||
p = e.find('#')
|
||||
e_sub = e[p:]
|
||||
atomid_valence = list(map(int, re.findall(r'\d+', e_sub)))
|
||||
return False, atomid_valence
|
||||
|
||||
|
||||
def correct_mol(mol, connection=False):
|
||||
#####
|
||||
no_correct = False
|
||||
flag, _ = check_valency(mol)
|
||||
if flag:
|
||||
no_correct = True
|
||||
|
||||
while True:
|
||||
if connection:
|
||||
mol_conn = connect_fragments(mol)
|
||||
# if mol_conn is not None:
|
||||
mol = mol_conn
|
||||
if mol is None:
|
||||
return None, no_correct
|
||||
flag, atomid_valence = check_valency(mol)
|
||||
if flag:
|
||||
break
|
||||
else:
|
||||
try:
|
||||
assert len(atomid_valence) == 2
|
||||
idx = atomid_valence[0]
|
||||
v = atomid_valence[1]
|
||||
queue = []
|
||||
check_idx = 0
|
||||
for b in mol.GetAtomWithIdx(idx).GetBonds():
|
||||
type = int(b.GetBondType())
|
||||
queue.append((b.GetIdx(), type, b.GetBeginAtomIdx(), b.GetEndAtomIdx()))
|
||||
if type == 12:
|
||||
check_idx += 1
|
||||
queue.sort(key=lambda tup: tup[1], reverse=True)
|
||||
|
||||
if queue[-1][1] == 12:
|
||||
return None, no_correct
|
||||
elif len(queue) > 0:
|
||||
start = queue[check_idx][2]
|
||||
end = queue[check_idx][3]
|
||||
t = queue[check_idx][1] - 1
|
||||
mol.RemoveBond(start, end)
|
||||
if t >= 1:
|
||||
mol.AddBond(start, end, bond_dict[t])
|
||||
except Exception as e:
|
||||
# print(f"An error occurred in correction: {e}")
|
||||
return None, no_correct
|
||||
return mol, no_correct
|
||||
|
||||
|
||||
def check_mol(m, largest_connected_comp=True):
|
||||
if m is None:
|
||||
return None
|
||||
sm = Chem.MolToSmiles(m, isomericSmiles=True)
|
||||
if largest_connected_comp and '.' in sm:
|
||||
vsm = [(s, len(s)) for s in sm.split('.')] # 'C.CC.CCc1ccc(N)cc1CCC=O'.split('.')
|
||||
vsm.sort(key=lambda tup: tup[1], reverse=True)
|
||||
mol = Chem.MolFromSmiles(vsm[0][0])
|
||||
else:
|
||||
mol = Chem.MolFromSmiles(sm)
|
||||
return mol
|
||||
|
||||
|
||||
##### connect fragments
|
||||
def select_atom_with_available_valency(frag):
|
||||
atoms = list(frag.GetAtoms())
|
||||
random.shuffle(atoms)
|
||||
for atom in atoms:
|
||||
if atom.GetAtomicNum() > 1 and atom.GetImplicitValence() > 0:
|
||||
return atom
|
||||
|
||||
return None
|
||||
|
||||
def select_atoms_with_available_valency(frag):
|
||||
return [atom for atom in frag.GetAtoms() if atom.GetAtomicNum() > 1 and atom.GetImplicitValence() > 0]
|
||||
|
||||
def try_to_connect_fragments(combined_mol, frag, atom1, atom2):
|
||||
# Make copies of the molecules to try the connection
|
||||
trial_combined_mol = Chem.RWMol(combined_mol)
|
||||
trial_frag = Chem.RWMol(frag)
|
||||
|
||||
# Add the new fragment to the combined molecule with new indices
|
||||
new_indices = {atom.GetIdx(): trial_combined_mol.AddAtom(atom) for atom in trial_frag.GetAtoms()}
|
||||
|
||||
# Add the bond between the suitable atoms from each fragment
|
||||
trial_combined_mol.AddBond(atom1.GetIdx(), new_indices[atom2.GetIdx()], Chem.BondType.SINGLE)
|
||||
|
||||
# Adjust the hydrogen count of the connected atoms
|
||||
for atom_idx in [atom1.GetIdx(), new_indices[atom2.GetIdx()]]:
|
||||
atom = trial_combined_mol.GetAtomWithIdx(atom_idx)
|
||||
num_h = atom.GetTotalNumHs()
|
||||
atom.SetNumExplicitHs(max(0, num_h - 1))
|
||||
|
||||
# Add bonds for the new fragment
|
||||
for bond in trial_frag.GetBonds():
|
||||
trial_combined_mol.AddBond(new_indices[bond.GetBeginAtomIdx()], new_indices[bond.GetEndAtomIdx()], bond.GetBondType())
|
||||
|
||||
# Convert to a Mol object and try to sanitize it
|
||||
new_mol = Chem.Mol(trial_combined_mol)
|
||||
try:
|
||||
Chem.SanitizeMol(new_mol)
|
||||
return new_mol # Return the new valid molecule
|
||||
except Chem.MolSanitizeException:
|
||||
return None # If the molecule is not valid, return None
|
||||
|
||||
def connect_fragments(mol):
|
||||
# Get the separate fragments
|
||||
frags = Chem.GetMolFrags(mol, asMols=True, sanitizeFrags=False)
|
||||
if len(frags) < 2:
|
||||
return mol
|
||||
|
||||
combined_mol = Chem.RWMol(frags[0])
|
||||
|
||||
for frag in frags[1:]:
|
||||
# Select all atoms with available valency from both molecules
|
||||
atoms1 = select_atoms_with_available_valency(combined_mol)
|
||||
atoms2 = select_atoms_with_available_valency(frag)
|
||||
|
||||
# Try to connect using all combinations of available valency atoms
|
||||
for atom1 in atoms1:
|
||||
for atom2 in atoms2:
|
||||
new_mol = try_to_connect_fragments(combined_mol, frag, atom1, atom2)
|
||||
if new_mol is not None:
|
||||
# If a valid connection is made, update the combined molecule and break
|
||||
combined_mol = new_mol
|
||||
break
|
||||
else:
|
||||
# Continue if the inner loop didn't break (no valid connection found for atom1)
|
||||
continue
|
||||
# Break if the inner loop did break (valid connection found)
|
||||
break
|
||||
else:
|
||||
# If no valid connections could be made with any of the atoms, return None
|
||||
return None
|
||||
|
||||
return combined_mol
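A hedged usage sketch of the fragment-repair path: `correct_mol(..., connection=True)` calls `connect_fragments` to join disconnected pieces with single bonds at atoms that still have free valence. The SMILES below is a toy input, not from the dataset.

```python
from rdkit import Chem

mol = Chem.RWMol(Chem.MolFromSmiles('C.CC'))      # two disconnected fragments
fixed, was_already_valid = correct_mol(mol, connection=True)
if fixed is not None:
    print(Chem.MolToSmiles(fixed))                # e.g. 'CCC' after joining
```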
|
||||
|
||||
#### connect fragments
|
||||
|
||||
def compute_molecular_metrics(molecule_list, targets, train_smiles, stat_ref, dataset_info, task_evaluator, comput_config):
|
||||
""" molecule_list: (dict) """
|
||||
|
||||
atom_decoder = dataset_info.atom_decoder
|
||||
active_atoms = dataset_info.active_atoms
|
||||
ensure_connected = dataset_info.ensure_connected
|
||||
metrics = BasicMolecularMetrics(atom_decoder, train_smiles, stat_ref, task_evaluator, **comput_config)
|
||||
evaluated_res = metrics.evaluate(molecule_list, targets, ensure_connected, active_atoms)
|
||||
all_smiles = evaluated_res[-3]
|
||||
all_metrics = evaluated_res[-2]
|
||||
targets_log = evaluated_res[-1]
|
||||
unique_smiles = evaluated_res[0]
|
||||
|
||||
return unique_smiles, all_smiles, all_metrics, targets_log
|
||||
|
||||
if __name__ == '__main__':
|
||||
smiles_mol = 'C1CCC1'
|
||||
print("Smiles mol %s" % smiles_mol)
|
||||
chem_mol = Chem.MolFromSmiles(smiles_mol)
|
||||
print(chem_mol)
|
mcd/analysis/visualization.py (new file, 222 lines)
@@ -0,0 +1,222 @@
|
||||
import os
|
||||
|
||||
from rdkit import Chem
|
||||
from rdkit.Chem import Draw, AllChem
|
||||
from rdkit.Geometry import Point3D
|
||||
from rdkit import RDLogger
|
||||
import imageio
|
||||
import networkx as nx
|
||||
import numpy as np
|
||||
import rdkit.Chem
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
|
||||
class MolecularVisualization:
|
||||
def __init__(self, dataset_infos):
|
||||
self.dataset_infos = dataset_infos
|
||||
|
||||
def mol_from_graphs(self, node_list, adjacency_matrix):
|
||||
"""
|
||||
Convert graphs to rdkit molecules
|
||||
node_list: the nodes of a batch of nodes (bs x n)
|
||||
adjacency_matrix: the adjacency_matrix of the molecule (bs x n x n)
|
||||
"""
|
||||
# dictionary to map integer value to the char of atom
|
||||
atom_decoder = self.dataset_infos.atom_decoder
|
||||
# [list(self.dataset_infos.atom_decoder.keys())[0]]
|
||||
|
||||
# create empty editable mol object
|
||||
mol = Chem.RWMol()
|
||||
|
||||
# add atoms to mol and keep track of index
|
||||
node_to_idx = {}
|
||||
for i in range(len(node_list)):
|
||||
if node_list[i] == -1:
|
||||
continue
|
||||
a = Chem.Atom(atom_decoder[int(node_list[i])])
|
||||
molIdx = mol.AddAtom(a)
|
||||
node_to_idx[i] = molIdx
|
||||
|
||||
for ix, row in enumerate(adjacency_matrix):
|
||||
for iy, bond in enumerate(row):
|
||||
# only traverse half the symmetric matrix
|
||||
if iy <= ix:
|
||||
continue
|
||||
if bond == 1:
|
||||
bond_type = Chem.rdchem.BondType.SINGLE
|
||||
elif bond == 2:
|
||||
bond_type = Chem.rdchem.BondType.DOUBLE
|
||||
elif bond == 3:
|
||||
bond_type = Chem.rdchem.BondType.TRIPLE
|
||||
elif bond == 4:
|
||||
bond_type = Chem.rdchem.BondType.AROMATIC
|
||||
else:
|
||||
continue
|
||||
mol.AddBond(node_to_idx[ix], node_to_idx[iy], bond_type)
|
||||
|
||||
try:
|
||||
mol = mol.GetMol()
|
||||
except rdkit.Chem.KekulizeException:
|
||||
print("Can't kekulize molecule")
|
||||
mol = None
|
||||
return mol
|
||||
|
||||
def visualize(self, path: str, molecules: list, num_molecules_to_visualize: int, log='graph'):
|
||||
# define path to save figures
|
||||
if not os.path.exists(path):
|
||||
os.makedirs(path)
|
||||
|
||||
# visualize the final molecules
|
||||
print(f"Visualizing {num_molecules_to_visualize} of {len(molecules)}")
|
||||
if num_molecules_to_visualize > len(molecules):
|
||||
print(f"Shortening to {len(molecules)}")
|
||||
num_molecules_to_visualize = len(molecules)
|
||||
|
||||
for i in range(num_molecules_to_visualize):
|
||||
file_path = os.path.join(path, 'molecule_{}.png'.format(i))
|
||||
mol = self.mol_from_graphs(molecules[i][0].numpy(), molecules[i][1].numpy())
|
||||
try:
|
||||
Draw.MolToFile(mol, file_path)
|
||||
except rdkit.Chem.KekulizeException:
|
||||
print("Can't kekulize molecule")
|
||||
|
||||
def visualize_chain(self, path, nodes_list, adjacency_matrix, trainer=None):
|
||||
RDLogger.DisableLog('rdApp.*')
|
||||
# convert graphs to the rdkit molecules
|
||||
mols = [self.mol_from_graphs(nodes_list[i], adjacency_matrix[i]) for i in range(nodes_list.shape[0])]
|
||||
|
||||
# find the coordinates of atoms in the final molecule
|
||||
final_molecule = mols[-1]
|
||||
AllChem.Compute2DCoords(final_molecule)
|
||||
|
||||
coords = []
|
||||
for i, atom in enumerate(final_molecule.GetAtoms()):
|
||||
positions = final_molecule.GetConformer().GetAtomPosition(i)
|
||||
coords.append((positions.x, positions.y, positions.z))
|
||||
|
||||
# align all the molecules
|
||||
for i, mol in enumerate(mols):
|
||||
AllChem.Compute2DCoords(mol)
|
||||
conf = mol.GetConformer()
|
||||
for j, atom in enumerate(mol.GetAtoms()):
|
||||
x, y, z = coords[j]
|
||||
conf.SetAtomPosition(j, Point3D(x, y, z))
|
||||
|
||||
# draw gif
|
||||
save_paths = []
|
||||
num_frams = nodes_list.shape[0]
|
||||
|
||||
for frame in range(num_frams):
|
||||
file_name = os.path.join(path, 'fram_{}.png'.format(frame))
|
||||
Draw.MolToFile(mols[frame], file_name, size=(300, 300), legend=f"Frame {frame}")
|
||||
save_paths.append(file_name)
|
||||
|
||||
imgs = [imageio.imread(fn) for fn in save_paths]
|
||||
gif_path = os.path.join(os.path.dirname(path), '{}.gif'.format(path.split('/')[-1]))
|
||||
imgs.extend([imgs[-1]] * 10)
|
||||
imageio.mimsave(gif_path, imgs, subrectangles=True, fps=5)
|
||||
|
||||
# draw grid image
|
||||
try:
|
||||
img = Draw.MolsToGridImage(mols, molsPerRow=10, subImgSize=(200, 200))
|
||||
img.save(os.path.join(path, '{}_grid_image.png'.format(path.split('/')[-1])))
|
||||
except Chem.rdchem.KekulizeException:
|
||||
print("Can't kekulize molecule")
|
||||
return mols
|
||||
|
||||
def visualize_by_smiles(self, path: str, smiles_list: list, num_to_visualize: int):
|
||||
os.makedirs(path, exist_ok=True)
|
||||
|
||||
print(f"Visualizing corrected {num_to_visualize} of {len(smiles_list)}")
|
||||
if num_to_visualize > len(smiles_list):
|
||||
print(f"Shortening to {len(smiles_list)}")
|
||||
num_to_visualize = len(smiles_list)
|
||||
|
||||
for i in range(num_to_visualize):
|
||||
file_path = os.path.join(path, 'molecule_corrected_{}.png'.format(i))
|
||||
if smiles_list[i] is None:
|
||||
continue
|
||||
mol = Chem.MolFromSmiles(smiles_list[i])
|
||||
try:
|
||||
Draw.MolToFile(mol, file_path)
|
||||
except rdkit.Chem.KekulizeException:
|
||||
print("Can't kekulize molecule")
|
||||
|
||||
class NonMolecularVisualization:
|
||||
def to_networkx(self, node_list, adjacency_matrix):
|
||||
"""
|
||||
Convert graphs to networkx graphs
|
||||
node_list: the nodes of a batch of nodes (bs x n)
|
||||
adjacency_matrix: the adjacency_matrix of the molecule (bs x n x n)
|
||||
"""
|
||||
graph = nx.Graph()
|
||||
|
||||
for i in range(len(node_list)):
|
||||
if node_list[i] == -1:
|
||||
continue
|
||||
graph.add_node(i, number=i, symbol=node_list[i], color_val=node_list[i])
|
||||
|
||||
rows, cols = np.where(adjacency_matrix >= 1)
|
||||
edges = zip(rows.tolist(), cols.tolist())
|
||||
for edge in edges:
|
||||
edge_type = adjacency_matrix[edge[0]][edge[1]]
|
||||
graph.add_edge(edge[0], edge[1], color=float(edge_type), weight=3 * edge_type)
|
||||
|
||||
return graph
|
||||
|
||||
def visualize_non_molecule(self, graph, pos, path, iterations=100, node_size=100, largest_component=False):
|
||||
if largest_component:
|
||||
CGs = [graph.subgraph(c) for c in nx.connected_components(graph)]
|
||||
CGs = sorted(CGs, key=lambda x: x.number_of_nodes(), reverse=True)
|
||||
graph = CGs[0]
|
||||
|
||||
# Plot the graph structure with colors
|
||||
if pos is None:
|
||||
pos = nx.spring_layout(graph, iterations=iterations)
|
||||
|
||||
# Set node colors based on the eigenvectors
|
||||
w, U = np.linalg.eigh(nx.normalized_laplacian_matrix(graph).toarray())
|
||||
vmin, vmax = np.min(U[:, 1]), np.max(U[:, 1])
|
||||
m = max(np.abs(vmin), vmax)
|
||||
vmin, vmax = -m, m
|
||||
|
||||
plt.figure()
|
||||
nx.draw(graph, pos, font_size=5, node_size=node_size, with_labels=False, node_color=U[:, 1],
|
||||
cmap=plt.cm.coolwarm, vmin=vmin, vmax=vmax, edge_color='grey')
|
||||
|
||||
plt.tight_layout()
|
||||
plt.savefig(path)
|
||||
plt.close("all")
|
||||
|
||||
def visualize(self, path: str, graphs: list, num_graphs_to_visualize: int, log='graph'):
|
||||
# define path to save figures
|
||||
if not os.path.exists(path):
|
||||
os.makedirs(path)
|
||||
|
||||
# visualize the final molecules
|
||||
for i in range(num_graphs_to_visualize):
|
||||
file_path = os.path.join(path, 'graph_{}.png'.format(i))
|
||||
graph = self.to_networkx(graphs[i][0].numpy(), graphs[i][1].numpy())
|
||||
self.visualize_non_molecule(graph=graph, pos=None, path=file_path)
|
||||
im = plt.imread(file_path)
|
||||
|
||||
def visualize_chain(self, path, nodes_list, adjacency_matrix):
|
||||
# convert graphs to networkx
|
||||
graphs = [self.to_networkx(nodes_list[i], adjacency_matrix[i]) for i in range(nodes_list.shape[0])]
|
||||
# find the coordinates of atoms in the final molecule
|
||||
final_graph = graphs[-1]
|
||||
final_pos = nx.spring_layout(final_graph, seed=0)
|
||||
|
||||
# draw gif
|
||||
save_paths = []
|
||||
num_frams = nodes_list.shape[0]
|
||||
|
||||
for frame in range(num_frams):
|
||||
file_name = os.path.join(path, 'fram_{}.png'.format(frame))
|
||||
self.visualize_non_molecule(graph=graphs[frame], pos=final_pos, path=file_name)
|
||||
save_paths.append(file_name)
|
||||
|
||||
imgs = [imageio.imread(fn) for fn in save_paths]
|
||||
gif_path = os.path.join(os.path.dirname(path), '{}.gif'.format(path.split('/')[-1]))
|
||||
imgs.extend([imgs[-1]] * 10)
|
||||
imageio.mimsave(gif_path, imgs, subrectangles=True, fps=5)
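A hedged sketch of `MolecularVisualization.mol_from_graphs` on a hand-built three-node graph. The tiny `atom_decoder` below (class 0 is carbon, class 1 is oxygen) is an assumption for illustration, not the project's real decoder.

```python
import numpy as np
from rdkit import Chem

class _Infos:
    atom_decoder = ['C', 'O']   # assumed toy decoder

viz = MolecularVisualization(_Infos())
nodes = np.array([0, 0, 1])                 # C, C, O
adj = np.array([[0, 1, 0],
                [1, 0, 2],
                [0, 2, 0]])                 # C-C single bond, C=O double bond
mol = viz.mol_from_graphs(nodes, adj)
print(Chem.MolToSmiles(mol))                # e.g. 'CC=O'
```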
|
mcd/datasets/__init__.py (new file, 0 lines)
mcd/datasets/abstract_dataset.py (new file, 126 lines)
@@ -0,0 +1,126 @@
|
||||
from diffusion.distributions import DistributionNodes
|
||||
import utils as utils
|
||||
import torch
|
||||
import pytorch_lightning as pl
|
||||
from torch_geometric.loader import DataLoader
|
||||
|
||||
|
||||
class AbstractDataModule(pl.LightningDataModule):
|
||||
def __init__(self, cfg):
|
||||
super().__init__()
|
||||
self.cfg = cfg
|
||||
self.dataloaders = None
|
||||
self.input_dims = None
|
||||
self.output_dims = None
|
||||
|
||||
def prepare_data(self, datasets) -> None:
|
||||
batch_size = self.cfg.train.batch_size
|
||||
num_workers = self.cfg.train.num_workers
|
||||
self.dataloaders = {split: DataLoader(dataset, batch_size=batch_size, num_workers=num_workers,
|
||||
shuffle='debug' not in self.cfg.general.name)
|
||||
for split, dataset in datasets.items()}
|
||||
|
||||
def train_dataloader(self):
|
||||
return self.dataloaders["train"]
|
||||
|
||||
def val_dataloader(self):
|
||||
return self.dataloaders["val"]
|
||||
|
||||
def test_dataloader(self):
|
||||
return self.dataloaders["test"]
|
||||
|
||||
def __getitem__(self, idx):
|
||||
return self.dataloaders['train'][idx]
|
||||
|
||||
def node_counts(self, max_nodes_possible=300):
|
||||
all_counts = torch.zeros(max_nodes_possible)
|
||||
for split in ['train', 'val', 'test']:
|
||||
for i, data in enumerate(self.dataloaders[split]):
|
||||
unique, counts = torch.unique(data.batch, return_counts=True)
|
||||
for count in counts:
|
||||
all_counts[count] += 1
|
||||
max_index = max(all_counts.nonzero())
|
||||
all_counts = all_counts[:max_index + 1]
|
||||
all_counts = all_counts / all_counts.sum()
|
||||
return all_counts
|
||||
|
||||
def node_types(self):
|
||||
num_classes = None
|
||||
for data in self.dataloaders['train']:
|
||||
num_classes = data.x.shape[1]
|
||||
break
|
||||
|
||||
counts = torch.zeros(num_classes)
|
||||
|
||||
for split in ['train', 'val', 'test']:
|
||||
for i, data in enumerate(self.dataloaders[split]):
|
||||
counts += data.x.sum(dim=0)
|
||||
|
||||
counts = counts / counts.sum()
|
||||
return counts
|
||||
|
||||
def edge_counts(self):
|
||||
num_classes = None
|
||||
for data in self.dataloaders['train']:
|
||||
num_classes = 5
|
||||
break
|
||||
|
||||
d = torch.zeros(num_classes)
|
||||
|
||||
for split in ['train', 'val', 'test']:
|
||||
for i, data in enumerate(self.dataloaders[split]):
|
||||
unique, counts = torch.unique(data.batch, return_counts=True)
|
||||
|
||||
all_pairs = 0
|
||||
for count in counts:
|
||||
all_pairs += count * (count - 1)
|
||||
|
||||
num_edges = data.edge_index.shape[1]
|
||||
num_non_edges = all_pairs - num_edges
|
||||
|
||||
data_edge_attr = torch.nn.functional.one_hot(data.edge_attr, num_classes=5).float()
|
||||
edge_types = data_edge_attr.sum(dim=0)
|
||||
assert num_non_edges >= 0
|
||||
d[0] += num_non_edges
|
||||
d[1:] += edge_types[1:]
|
||||
|
||||
d = d / d.sum()
|
||||
return d
|
||||
|
||||
|
||||
class MolecularDataModule(AbstractDataModule):
|
||||
def valency_count(self, max_n_nodes):
|
||||
valencies = torch.zeros(3 * max_n_nodes - 2) # Max valency possible if everything is connected
|
||||
multiplier = torch.Tensor([0, 1, 2, 3, 1.5])
|
||||
for split in ['train', 'val', 'test']:
|
||||
for i, data in enumerate(self.dataloaders[split]):
|
||||
n = data.x.shape[0]
|
||||
for atom in range(n):
|
||||
data_edge_attr = torch.nn.functional.one_hot(data.edge_attr, num_classes=5).float()
|
||||
edges = data_edge_attr[data.edge_index[0] == atom]
|
||||
edges_total = edges.sum(dim=0)
|
||||
valency = (edges_total * multiplier).sum()
|
||||
valencies[valency.long().item()] += 1
|
||||
valencies = valencies / valencies.sum()
|
||||
return valencies
|
||||
|
||||
|
||||
class AbstractDatasetInfos:
|
||||
def complete_infos(self, n_nodes, node_types):
|
||||
self.input_dims = None
|
||||
self.output_dims = None
|
||||
self.num_classes = len(node_types)
|
||||
self.max_n_nodes = len(n_nodes) - 1
|
||||
self.nodes_dist = DistributionNodes(n_nodes)
|
||||
|
||||
def compute_input_output_dims(self, datamodule):
|
||||
example_batch = datamodule.example_batch()
|
||||
example_batch_x = torch.nn.functional.one_hot(example_batch.x, num_classes=118).float()[:, self.active_index]
|
||||
example_batch_edge_attr = torch.nn.functional.one_hot(example_batch.edge_attr, num_classes=5).float()
|
||||
|
||||
self.input_dims = {'X': example_batch_x.size(1),
|
||||
'E': example_batch_edge_attr.size(1),
|
||||
'y': example_batch['y'].size(1)}
|
||||
self.output_dims = {'X': example_batch_x.size(1),
|
||||
'E': example_batch_edge_attr.size(1),
|
||||
'y': example_batch['y'].size(1)}
|
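`node_types()` and `edge_counts()` above return marginal class distributions over atom and bond types. Under the `transition: 'marginal'` setting in `config.yaml`, such marginals typically parameterize the discrete noising kernel; the snippet below is a hedged sketch of that construction (the project's actual transition-matrix code is not part of this excerpt).

```python
import torch

def marginal_transition(marginals: torch.Tensor, alpha_bar: float) -> torch.Tensor:
    """Keep the current class with probability alpha_bar, otherwise resample from the marginals."""
    d = marginals.numel()
    return alpha_bar * torch.eye(d) + (1 - alpha_bar) * marginals.unsqueeze(0).expand(d, d)

q = marginal_transition(torch.tensor([0.7, 0.2, 0.1]), alpha_bar=0.9)
print(q.sum(dim=-1))   # every row sums to 1, so q is a valid transition matrix
```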
mcd/datasets/dataset.py (new file, 381 lines)
@@ -0,0 +1,381 @@
|
||||
|
||||
import sys
|
||||
sys.path.append('../')
|
||||
|
||||
import os
|
||||
import os.path as osp
|
||||
import pathlib
|
||||
import json
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from rdkit import Chem, RDLogger
|
||||
from rdkit.Chem.rdchem import BondType as BT
|
||||
from tqdm import tqdm
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from torch_geometric.data import Data, InMemoryDataset
|
||||
from torch_geometric.loader import DataLoader
|
||||
from sklearn.model_selection import train_test_split
|
||||
|
||||
import utils as utils
|
||||
from datasets.abstract_dataset import AbstractDatasetInfos, AbstractDataModule
|
||||
from diffusion.distributions import DistributionNodes
|
||||
|
||||
bonds = {BT.SINGLE: 1, BT.DOUBLE: 2, BT.TRIPLE: 3, BT.AROMATIC: 4}
|
||||
|
||||
class DataModule(AbstractDataModule):
|
||||
def __init__(self, cfg):
|
||||
self.datadir = cfg.dataset.datadir
|
||||
self.task = cfg.dataset.task_name
|
||||
super().__init__(cfg)
|
||||
|
||||
def prepare_data(self) -> None:
|
||||
target = getattr(self.cfg.dataset, 'guidance_target', None)
|
||||
base_path = pathlib.Path(os.path.realpath(__file__)).parents[2]
|
||||
root_path = os.path.join(base_path, self.datadir)
|
||||
self.root_path = root_path
|
||||
|
||||
batch_size = self.cfg.train.batch_size
|
||||
num_workers = self.cfg.train.num_workers
|
||||
pin_memory = self.cfg.dataset.pin_memory
|
||||
|
||||
dataset = Dataset(source=self.task, root=root_path, target_prop=target, transform=None)
|
||||
|
||||
if len(self.task.split('-')) == 2:
|
||||
train_index, val_index, test_index, unlabeled_index = self.fixed_split(dataset)
|
||||
else:
|
||||
train_index, val_index, test_index, unlabeled_index = self.random_data_split(dataset)
|
||||
|
||||
self.train_index, self.val_index, self.test_index, self.unlabeled_index = train_index, val_index, test_index, unlabeled_index
|
||||
train_index, val_index, test_index, unlabeled_index = torch.LongTensor(train_index), torch.LongTensor(val_index), torch.LongTensor(test_index), torch.LongTensor(unlabeled_index)
|
||||
if len(unlabeled_index) > 0:
|
||||
train_index = torch.cat([train_index, unlabeled_index], dim=0)
|
||||
|
||||
train_dataset, val_dataset, test_dataset = dataset[train_index], dataset[val_index], dataset[test_index]
|
||||
self.train_dataset = train_dataset
|
||||
self.train_loader = DataLoader(train_dataset, batch_size=batch_size, num_workers=num_workers, shuffle=True, pin_memory=pin_memory)
|
||||
self.val_loader = DataLoader(val_dataset, batch_size=batch_size, num_workers=num_workers, shuffle=False, pin_memory=False)
|
||||
self.test_loader = DataLoader(test_dataset, batch_size=batch_size, num_workers=num_workers, shuffle=False, pin_memory=False)
|
||||
|
||||
training_iterations = len(train_dataset) // batch_size
|
||||
self.training_iterations = training_iterations
|
||||
|
||||
def random_data_split(self, dataset):
|
||||
nan_count = torch.isnan(dataset.y[:, 0]).sum().item()
|
||||
labeled_len = len(dataset) - nan_count
|
||||
full_idx = list(range(labeled_len))
|
||||
train_ratio, valid_ratio, test_ratio = 0.6, 0.2, 0.2
|
||||
train_index, test_index, _, _ = train_test_split(full_idx, full_idx, test_size=test_ratio, random_state=42)
|
||||
train_index, val_index, _, _ = train_test_split(train_index, train_index, test_size=valid_ratio/(valid_ratio+train_ratio), random_state=42)
|
||||
unlabeled_index = list(range(labeled_len, len(dataset)))
|
||||
print(self.task, ' dataset len', len(dataset), 'train len', len(train_index), 'val len', len(val_index), 'test len', len(test_index), 'unlabeled len', len(unlabeled_index))
|
||||
return train_index, val_index, test_index, unlabeled_index
|
||||
|
||||
def fixed_split(self, dataset):
|
||||
if self.task == 'O2-N2':
|
||||
test_index = [42,43,92,122,197,198,251,254,257,355,511,512,549,602,603,604]
|
||||
else:
|
||||
raise ValueError('Invalid task name: {}'.format(self.task))
|
||||
full_idx = list(range(len(dataset)))
|
||||
full_idx = list(set(full_idx) - set(test_index))
|
||||
train_ratio = 0.8
|
||||
train_index, val_index, _, _ = train_test_split(full_idx, full_idx, test_size=1-train_ratio, random_state=42)
|
||||
print(self.task, ' dataset len', len(dataset), 'train len', len(train_index), 'val len', len(val_index), 'test len', len(test_index))
|
||||
return train_index, val_index, test_index, []
|
||||
|
||||
def get_train_smiles(self):
|
||||
filename = f'{self.task}.csv.gz'
|
||||
df = pd.read_csv(f'{self.root_path}/raw/{filename}')
|
||||
df_test = df.iloc[self.test_index]
|
||||
df = df.iloc[self.train_index]
|
||||
smiles_list = df['smiles'].tolist()
|
||||
smiles_list_test = df_test['smiles'].tolist()
|
||||
smiles_list = [Chem.MolToSmiles(Chem.MolFromSmiles(smi)) for smi in smiles_list]
|
||||
smiles_list_test = [Chem.MolToSmiles(Chem.MolFromSmiles(smi)) for smi in smiles_list_test]
|
||||
return smiles_list, smiles_list_test
|
||||
|
||||
def get_data_split(self):
|
||||
filename = f'{self.task}.csv.gz'
|
||||
df = pd.read_csv(f'{self.root_path}/raw/{filename}')
|
||||
df_val = df.iloc[self.val_index]
|
||||
df_test = df.iloc[self.test_index]
|
||||
df_train = df.iloc[self.train_index]
|
||||
return df_train, df_val, df_test
|
||||
|
||||
def example_batch(self):
|
||||
return next(iter(self.val_loader))
|
||||
|
||||
def train_dataloader(self):
|
||||
return self.train_loader
|
||||
|
||||
def val_dataloader(self):
|
||||
return self.val_loader
|
||||
|
||||
def test_dataloader(self):
|
||||
return self.test_loader
|
||||
|
||||
|
||||
class Dataset(InMemoryDataset):
|
||||
def __init__(self, source, root, target_prop=None,
|
||||
transform=None, pre_transform=None, pre_filter=None):
|
||||
self.target_prop = target_prop
|
||||
self.source = source
|
||||
super().__init__(root, transform, pre_transform, pre_filter)
|
||||
self.data, self.slices = torch.load(self.processed_paths[0])
|
||||
|
||||
@property
|
||||
def raw_file_names(self):
|
||||
return [f'{self.source}.csv.gz']
|
||||
|
||||
@property
|
||||
def processed_file_names(self):
|
||||
return [f'{self.source}.pt']
|
||||
|
||||
def process(self):
|
||||
RDLogger.DisableLog('rdApp.*')
|
||||
data_path = osp.join(self.raw_dir, self.raw_file_names[0])
|
||||
data_df = pd.read_csv(data_path)
|
||||
|
||||
def mol_to_graph(mol, sa, sc, target, target2=None, target3=None, valid_atoms=None):
|
||||
type_idx = []
|
||||
heavy_atom_indices, active_atoms = [], []
|
||||
for atom in mol.GetAtoms():
|
||||
if atom.GetAtomicNum() != 1:
|
||||
type_idx.append(119-2) if atom.GetSymbol() == '*' else type_idx.append(atom.GetAtomicNum()-2)
|
||||
heavy_atom_indices.append(atom.GetIdx())
|
||||
active_atoms.append(atom.GetSymbol())
|
||||
if valid_atoms is not None:
|
||||
if not atom.GetSymbol() in valid_atoms:
|
||||
return None, None
|
||||
x = torch.LongTensor(type_idx)
|
||||
|
||||
edges_list = []
|
||||
edge_type = []
|
||||
for bond in mol.GetBonds():
|
||||
start, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
|
||||
if start in heavy_atom_indices and end in heavy_atom_indices:
|
||||
start_new, end_new = heavy_atom_indices.index(start), heavy_atom_indices.index(end)
|
||||
edges_list.append((start_new, end_new))
|
||||
edge_type.append(bonds[bond.GetBondType()])
|
||||
edges_list.append((end_new, start_new))
|
||||
edge_type.append(bonds[bond.GetBondType()])
|
||||
edge_index = torch.tensor(edges_list, dtype=torch.long).t()
|
||||
edge_type = torch.tensor(edge_type, dtype=torch.long)
|
||||
edge_attr = edge_type
|
||||
|
||||
if target3 is not None:
|
||||
y = torch.tensor([sa, sc, target, target2, target3], dtype=torch.float).view(1,-1)
|
||||
elif target2 is not None:
|
||||
y = torch.tensor([sa, sc, target, target2], dtype=torch.float).view(1,-1)
|
||||
else:
|
||||
y = torch.tensor([sa, sc, target], dtype=torch.float).view(1,-1)
|
||||
|
||||
data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=y, idx=i)
|
||||
if self.pre_transform is not None:
|
||||
data = self.pre_transform(data)
|
||||
return data, active_atoms
|
||||
|
||||
# Loop through every row in the DataFrame and apply the function
|
||||
data_list = []
|
||||
len_data = len(data_df)
|
||||
with tqdm(total=len_data) as pbar:
|
||||
# --- data processing start ---
|
||||
active_atoms = set()
|
||||
for i, (sms, df_row) in enumerate(data_df.iterrows()):
|
||||
if i == sms:
|
||||
sms = df_row['smiles']
|
||||
mol = Chem.MolFromSmiles(sms, sanitize=False)
|
||||
if len(self.target_prop.split('-')) == 2:
|
||||
target1, target2 = self.target_prop.split('-')
|
||||
data, cur_active_atoms = mol_to_graph(mol, df_row['SA'], df_row['SC'], df_row[target1], target2=df_row[target2])
|
||||
elif len(self.target_prop.split('-')) == 3:
|
||||
target1, target2, target3 = self.target_prop.split('-')
|
||||
data, cur_active_atoms = mol_to_graph(mol, df_row['SA'], df_row['SC'], df_row[target1], target2=df_row[target2], target3=df_row[target3])
|
||||
else:
|
||||
data, cur_active_atoms = mol_to_graph(mol, df_row['SA'], df_row['SC'], df_row[self.target_prop])
|
||||
active_atoms.update(cur_active_atoms)
|
||||
data_list.append(data)
|
||||
pbar.update(1)
|
||||
|
||||
torch.save(self.collate(data_list), self.processed_paths[0])
|
||||
|
||||
|
||||
class DataInfos(AbstractDatasetInfos):
|
||||
def __init__(self, datamodule, cfg):
|
||||
tasktype_dict = {
|
||||
'hiv_b': 'classification',
|
||||
'bace_b': 'classification',
|
||||
'bbbp_b': 'classification',
|
||||
'O2': 'regression',
|
||||
'N2': 'regression',
|
||||
'CO2': 'regression',
|
||||
}
|
||||
task_name = cfg.dataset.task_name
|
||||
self.task = task_name
|
||||
self.task_type = tasktype_dict.get(task_name, "regression")
|
||||
self.ensure_connected = cfg.model.ensure_connected
|
||||
|
||||
datadir = cfg.dataset.datadir
|
||||
|
||||
base_path = pathlib.Path(os.path.realpath(__file__)).parents[2]
|
||||
meta_filename = os.path.join(base_path, datadir, 'raw', f'{task_name}.meta.json')
|
||||
data_root = os.path.join(base_path, datadir, 'raw')
|
||||
if os.path.exists(meta_filename):
|
||||
with open(meta_filename, 'r') as f:
|
||||
meta_dict = json.load(f)
|
||||
else:
|
||||
meta_dict = compute_meta(data_root, task_name, datamodule.train_index, datamodule.test_index)
|
||||
|
||||
self.base_path = base_path
|
||||
self.active_atoms = meta_dict['active_atoms']
|
||||
self.max_n_nodes = meta_dict['max_node']
|
||||
self.original_max_n_nodes = meta_dict['max_node']
|
||||
self.n_nodes = torch.Tensor(meta_dict['n_atoms_per_mol_dist'])
|
||||
self.edge_types = torch.Tensor(meta_dict['bond_type_dist'])
|
||||
self.transition_E = torch.Tensor(meta_dict['transition_E'])
|
||||
|
||||
self.atom_decoder = meta_dict['active_atoms']
|
||||
node_types = torch.Tensor(meta_dict['atom_type_dist'])
|
||||
active_index = (node_types > 0).nonzero().squeeze()
|
||||
self.node_types = torch.Tensor(meta_dict['atom_type_dist'])[active_index]
|
||||
self.nodes_dist = DistributionNodes(self.n_nodes)
|
||||
self.active_index = active_index
|
||||
|
||||
val_len = 3 * self.original_max_n_nodes - 2
|
||||
meta_val = torch.Tensor(meta_dict['valencies'])
|
||||
self.valency_distribution = torch.zeros(val_len)
|
||||
val_len = min(val_len, len(meta_val))
|
||||
self.valency_distribution[:val_len] = meta_val[:val_len]
|
||||
self.y_prior = None
|
||||
self.train_ymin = []
|
||||
self.train_ymax = []
|
||||
|
||||
|
||||
def compute_meta(root, source_name, train_index, test_index):
|
||||
pt = Chem.GetPeriodicTable()
|
||||
atom_name_list = []
|
||||
atom_count_list = []
|
||||
for i in range(2, 119):
|
||||
atom_name_list.append(pt.GetElementSymbol(i))
|
||||
atom_count_list.append(0)
|
||||
atom_name_list.append('*')
|
||||
atom_count_list.append(0)
|
||||
n_atoms_per_mol = [0] * 500
|
||||
bond_count_list = [0, 0, 0, 0, 0]
|
||||
bond_type_to_index = {BT.SINGLE: 1, BT.DOUBLE: 2, BT.TRIPLE: 3, BT.AROMATIC: 4}
|
||||
valencies = [0] * 500
|
||||
tansition_E = np.zeros((118, 118, 5))
|
||||
|
||||
filename = f'{source_name}.csv.gz'
|
||||
df = pd.read_csv(f'{root}/{filename}')
|
||||
all_index = list(range(len(df)))
|
||||
non_test_index = list(set(all_index) - set(test_index))
|
||||
df = df.iloc[non_test_index]
|
||||
tot_smiles = df['smiles'].tolist()
|
||||
|
||||
n_atom_list = []
|
||||
n_bond_list = []
|
||||
for i, sms in enumerate(tot_smiles):
|
||||
try:
|
||||
mol = Chem.MolFromSmiles(sms)
|
||||
except:
|
||||
continue
|
||||
|
||||
n_atom = mol.GetNumHeavyAtoms()
|
||||
n_bond = mol.GetNumBonds()
|
||||
n_atom_list.append(n_atom)
|
||||
n_bond_list.append(n_bond)
|
||||
|
||||
n_atoms_per_mol[n_atom] += 1
|
||||
cur_atom_count_arr = np.zeros(118)
|
||||
for atom in mol.GetAtoms():
|
||||
symbol = atom.GetSymbol()
|
||||
if symbol == 'H':
|
||||
continue
|
||||
elif symbol == '*':
|
||||
atom_count_list[-1] += 1
|
||||
cur_atom_count_arr[-1] += 1
|
||||
else:
|
||||
atom_count_list[atom.GetAtomicNum()-2] += 1
|
||||
cur_atom_count_arr[atom.GetAtomicNum()-2] += 1
|
||||
try:
|
||||
valencies[int(atom.GetExplicitValence())] += 1
|
||||
except:
|
||||
print('src', source_name,'int(atom.GetExplicitValence())', int(atom.GetExplicitValence()))
|
||||
|
||||
tansition_E_temp = np.zeros((118, 118, 5))
|
||||
for bond in mol.GetBonds():
|
||||
start_atom, end_atom = bond.GetBeginAtom(), bond.GetEndAtom()
|
||||
if start_atom.GetSymbol() == 'H' or end_atom.GetSymbol() == 'H':
|
||||
continue
|
||||
|
||||
if start_atom.GetSymbol() == '*':
|
||||
start_index = 117
|
||||
else:
|
||||
start_index = start_atom.GetAtomicNum() - 2
|
||||
if end_atom.GetSymbol() == '*':
|
||||
end_index = 117
|
||||
else:
|
||||
end_index = end_atom.GetAtomicNum() - 2
|
||||
|
||||
bond_type = bond.GetBondType()
|
||||
bond_index = bond_type_to_index[bond_type]
|
||||
bond_count_list[bond_index] += 2
|
||||
|
||||
tansition_E[start_index, end_index, bond_index] += 2
|
||||
tansition_E[end_index, start_index, bond_index] += 2
|
||||
tansition_E_temp[start_index, end_index, bond_index] += 2
|
||||
tansition_E_temp[end_index, start_index, bond_index] += 2
|
||||
|
||||
bond_count_list[0] += n_atom * (n_atom - 1) - n_bond * 2
|
||||
cur_tot_bond = cur_atom_count_arr.reshape(-1,1) * cur_atom_count_arr.reshape(1,-1) * 2 # 118 * 118
|
||||
cur_tot_bond = cur_tot_bond - np.diag(cur_atom_count_arr) * 2 # 118 * 118
|
||||
tansition_E[:, :, 0] += cur_tot_bond - tansition_E_temp.sum(axis=-1)
|
||||
assert (cur_tot_bond > tansition_E_temp.sum(axis=-1)).sum() >= 0, f'i:{i}, sms:{sms}'
|
||||
|
||||
n_atoms_per_mol = np.array(n_atoms_per_mol) / np.sum(n_atoms_per_mol)
|
||||
n_atoms_per_mol = n_atoms_per_mol.tolist()[:51]
|
||||
|
||||
atom_count_list = np.array(atom_count_list) / np.sum(atom_count_list)
|
||||
print('processed meta info: ------', filename, '------')
|
||||
print('len atom_count_list', len(atom_count_list))
|
||||
print('len atom_name_list', len(atom_name_list))
|
||||
active_atoms = np.array(atom_name_list)[atom_count_list > 0]
|
||||
active_atoms = active_atoms.tolist()
|
||||
atom_count_list = atom_count_list.tolist()
|
||||
|
||||
bond_count_list = np.array(bond_count_list) / np.sum(bond_count_list)
|
||||
bond_count_list = bond_count_list.tolist()
|
||||
valencies = np.array(valencies) / np.sum(valencies)
|
||||
valencies = valencies.tolist()
|
||||
|
||||
no_edge = np.sum(tansition_E, axis=-1) == 0
|
||||
first_elt = tansition_E[:, :, 0]
|
||||
first_elt[no_edge] = 1
|
||||
tansition_E[:, :, 0] = first_elt
|
||||
|
||||
tansition_E = tansition_E / np.sum(tansition_E, axis=-1, keepdims=True)
|
||||
|
||||
meta_dict = {
|
||||
'source': source_name,
|
||||
'num_graph': len(n_atom_list),
|
||||
'n_atoms_per_mol_dist': n_atoms_per_mol,
|
||||
'max_node': max(n_atom_list),
|
||||
'max_bond': max(n_bond_list),
|
||||
'atom_type_dist': atom_count_list,
|
||||
'bond_type_dist': bond_count_list,
|
||||
'valencies': valencies,
|
||||
'active_atoms': active_atoms,
|
||||
'num_atom_type': len(active_atoms),
|
||||
'transition_E': tansition_E.tolist(),
|
||||
}
|
||||
|
||||
with open(f'{root}/{source_name}.meta.json', "w") as f:
|
||||
json.dump(meta_dict, f)
|
||||
|
||||
return meta_dict
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pass
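Both `process()` and `compute_meta()` above encode heavy atoms as `atomic_number - 2` and map the dummy atom `'*'` to index `117` (`119 - 2`). A small hedged illustration of that convention:

```python
from rdkit import Chem

mol = Chem.MolFromSmiles('*CCO')   # toy molecule with a dummy atom
for atom in mol.GetAtoms():
    idx = 119 - 2 if atom.GetSymbol() == '*' else atom.GetAtomicNum() - 2
    print(atom.GetSymbol(), idx)   # '*' -> 117, C -> 4, O -> 6
```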
|
mcd/diffusion/__init__.py (new file, 0 lines)
mcd/diffusion/diffusion_utils.py (new file, 224 lines)
@@ -0,0 +1,224 @@
|
||||
import torch
|
||||
from torch.nn import functional as F
|
||||
import numpy as np
|
||||
from utils import PlaceHolder
|
||||
|
||||
|
||||
def sum_except_batch(x):
|
||||
return x.reshape(x.size(0), -1).sum(dim=-1)
|
||||
|
||||
def assert_correctly_masked(variable, node_mask):
|
||||
assert (variable * (1 - node_mask.long())).abs().max().item() < 1e-4, \
|
||||
'Variables not masked properly.'
|
||||
|
||||
def cosine_beta_schedule_discrete(timesteps, s=0.008):
|
||||
""" Cosine schedule as proposed in https://openreview.net/forum?id=-NEXDKk8gZ. """
|
||||
steps = timesteps + 2
|
||||
x = np.linspace(0, steps, steps)
|
||||
|
||||
alphas_cumprod = np.cos(0.5 * np.pi * ((x / steps) + s) / (1 + s)) ** 2
|
||||
alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
|
||||
alphas = (alphas_cumprod[1:] / alphas_cumprod[:-1])
|
||||
betas = 1 - alphas
|
||||
return betas.squeeze()
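A hedged check of the schedule's shape: for `timesteps` steps the function returns `timesteps + 1` beta values derived from the cosine cumulative-alpha curve. The call below assumes the module is importable as written.

```python
betas = cosine_beta_schedule_discrete(timesteps=500)
print(len(betas))                 # 501 values (timesteps + 1)
print(betas.min(), betas.max())   # small betas early on, approaching 1 at the end
```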
|
||||
|
||||
def custom_beta_schedule_discrete(timesteps, average_num_nodes=30, s=0.008):
|
||||
""" Cosine schedule as proposed in https://openreview.net/forum?id=-NEXDKk8gZ. """
|
||||
steps = timesteps + 2
|
||||
x = np.linspace(0, steps, steps)
|
||||
|
||||
alphas_cumprod = np.cos(0.5 * np.pi * ((x / steps) + s) / (1 + s)) ** 2
|
||||
alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
|
||||
alphas = (alphas_cumprod[1:] / alphas_cumprod[:-1])
|
||||
betas = 1 - alphas
|
||||
|
||||
assert timesteps >= 100
|
||||
|
||||
p = 4 / 5 # 1 - 1 / num_edge_classes
|
||||
num_edges = average_num_nodes * (average_num_nodes - 1) / 2
|
||||
|
||||
# First 100 steps: only a few updates per graph
|
||||
updates_per_graph = 1.2
|
||||
beta_first = updates_per_graph / (p * num_edges)
|
||||
|
||||
betas[betas < beta_first] = beta_first
|
||||
return np.array(betas)
|
||||
|
||||
|
||||
def check_mask_correct(variables, node_mask):
|
||||
for i, variable in enumerate(variables):
|
||||
if len(variable) > 0:
|
||||
assert_correctly_masked(variable, node_mask)
|
||||
|
||||
|
||||
def check_tensor_same_size(*args):
|
||||
for i, arg in enumerate(args):
|
||||
if i == 0:
|
||||
continue
|
||||
assert args[0].size() == arg.size()
|
||||
|
||||
|
||||
|
||||
def reverse_tensor(x):
|
||||
return x[torch.arange(x.size(0) - 1, -1, -1)]
|
||||
|
||||
|
||||
def sample_discrete_features(probX, probE, node_mask, step=None, add_nose=True):
|
||||
''' Sample features from multinomial distribution with given probabilities (probX, probE, proby)
|
||||
:param probX: bs, n, dx_out node features
|
||||
:param probE: bs, n, n, de_out edge features
|
||||
:param proby: bs, dy_out global features.
|
||||
'''
|
||||
bs, n, _ = probX.shape
|
||||
|
||||
# Noise X
|
||||
# The masked rows should define probability distributions as well
|
||||
probX[~node_mask] = 1 / probX.shape[-1]
|
||||
|
||||
# Flatten the probability tensor to sample with multinomial
|
||||
probX = probX.reshape(bs * n, -1) # (bs * n, dx_out)
|
||||
|
||||
# Sample X
|
||||
probX = probX + 1e-12
|
||||
probX = probX / probX.sum(dim=-1, keepdim=True)
|
||||
X_t = probX.multinomial(1) # (bs * n, 1)
|
||||
X_t = X_t.reshape(bs, n) # (bs, n)
|
||||
|
||||
# Noise E
|
||||
# The masked rows should define probability distributions as well
|
||||
inverse_edge_mask = ~(node_mask.unsqueeze(1) * node_mask.unsqueeze(2))
|
||||
diag_mask = torch.eye(n).unsqueeze(0).expand(bs, -1, -1)
|
||||
|
||||
probE[inverse_edge_mask] = 1 / probE.shape[-1]
|
||||
probE[diag_mask.bool()] = 1 / probE.shape[-1]
|
||||
probE = probE.reshape(bs * n * n, -1) # (bs * n * n, de_out)
|
||||
probE = probE + 1e-12
|
||||
probE = probE / probE.sum(dim=-1, keepdim=True)
|
||||
|
||||
# Sample E
|
||||
E_t = probE.multinomial(1).reshape(bs, n, n) # (bs, n, n)
|
||||
E_t = torch.triu(E_t, diagonal=1)
|
||||
E_t = (E_t + torch.transpose(E_t, 1, 2))
|
||||
|
||||
return PlaceHolder(X=X_t, E=E_t, y=torch.zeros(bs, 0).type_as(X_t))
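A hedged toy call of `sample_discrete_features` on a single two-node graph, mainly to show the expected tensor shapes (values are arbitrary):

```python
import torch

probX = torch.tensor([[[0.2, 0.5, 0.3], [1.0, 0.0, 0.0]]])   # (bs=1, n=2, dx=3)
probE = torch.full((1, 2, 2, 2), 0.5)                         # (bs, n, n, de=2)
mask = torch.tensor([[True, True]])
sample = sample_discrete_features(probX.clone(), probE.clone(), mask)
print(sample.X.shape, sample.E.shape)                         # torch.Size([1, 2]) torch.Size([1, 2, 2])
```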
|
||||
|
||||
|
||||
def compute_batched_over0_posterior_distribution(X_t, Qt, Qsb, Qtb):
|
||||
""" M: X or E
|
||||
Compute xt @ Qt.T * x0 @ Qsb / x0 @ Qtb @ xt.T for each possible value of x0
|
||||
X_t: bs, n, dt or bs, n, n, dt
|
||||
Qt: bs, d_t-1, dt
|
||||
Qsb: bs, d0, d_t-1
|
||||
Qtb: bs, d0, dt.
|
||||
"""
|
||||
X_t = X_t.float()
|
||||
Qt_T = Qt.transpose(-1, -2).float() # bs, N, dt
|
||||
assert Qt.dim() == 3
|
||||
left_term = X_t @ Qt_T
|
||||
left_term = left_term.unsqueeze(dim=2) # bs, N, 1, d_t-1
|
||||
right_term = Qsb.unsqueeze(1)
|
||||
numerator = left_term * right_term # bs, N, d0, d_t-1
|
||||
|
||||
denominator = Qtb @ X_t.transpose(-1, -2) # bs, d0, N
|
||||
denominator = denominator.transpose(-1, -2) # bs, N, d0
|
||||
denominator = denominator.unsqueeze(-1) # bs, N, d0, 1
|
||||
|
||||
denominator[denominator == 0] = 1.
|
||||
return numerator / denominator
|
||||
|
||||
|
||||
def mask_distributions(true_X, true_E, pred_X, pred_E, node_mask):
|
||||
# Add a small value everywhere to avoid nans
|
||||
pred_X = pred_X.clamp_min(1e-18)
|
||||
pred_X = pred_X / torch.sum(pred_X, dim=-1, keepdim=True)
|
||||
|
||||
pred_E = pred_E.clamp_min(1e-18)
|
||||
pred_E = pred_E / torch.sum(pred_E, dim=-1, keepdim=True)
|
||||
|
||||
# Set masked rows to arbitrary distributions, so it doesn't contribute to loss
|
||||
row_X = torch.ones(true_X.size(-1), dtype=true_X.dtype, device=true_X.device)
|
||||
row_E = torch.zeros(true_E.size(-1), dtype=true_E.dtype, device=true_E.device).clamp_min(1e-18)
|
||||
row_E[0] = 1.
|
||||
|
||||
diag_mask = ~torch.eye(node_mask.size(1), device=node_mask.device, dtype=torch.bool).unsqueeze(0)
|
||||
true_X[~node_mask] = row_X
|
||||
true_E[~(node_mask.unsqueeze(1) * node_mask.unsqueeze(2) * diag_mask), :] = row_E
|
||||
pred_X[~node_mask] = row_X
|
||||
pred_E[~(node_mask.unsqueeze(1) * node_mask.unsqueeze(2) * diag_mask), :] = row_E
|
||||
|
||||
return true_X, true_E, pred_X, pred_E
|
||||
|
||||
def posterior_distributions(X, X_t, Qt, Qsb, Qtb, X_dim):
|
||||
bs, n, d = X.shape
|
||||
X = X.float()
|
||||
Qt_X_T = torch.transpose(Qt.X, -2, -1).float() # (bs, d, d)
|
||||
left_term = X_t @ Qt_X_T # (bs, N, d)
|
||||
right_term = X @ Qsb.X # (bs, N, d)
|
||||
|
||||
numerator = left_term * right_term # (bs, N, d)
|
||||
denominator = X @ Qtb.X # (bs, N, d) @ (bs, d, d) = (bs, N, d)
|
||||
denominator = denominator * X_t
|
||||
|
||||
num_X = numerator[:, :, :X_dim]
|
||||
num_E = numerator[:, :, X_dim:].reshape(bs, n*n, -1)
|
||||
|
||||
deno_X = denominator[:, :, :X_dim]
|
||||
deno_E = denominator[:, :, X_dim:].reshape(bs, n*n, -1)
|
||||
|
||||
# denominator = (denominator * X_t).sum(dim=-1) # (bs, N, d) * (bs, N, d) + sum = (bs, N)
|
||||
denominator = denominator.unsqueeze(-1) # (bs, N, 1)
|
||||
|
||||
deno_X = deno_X.sum(dim=-1).unsqueeze(-1)
|
||||
deno_E = deno_E.sum(dim=-1).unsqueeze(-1)
|
||||
|
||||
deno_X[deno_X == 0.] = 1
|
||||
deno_E[deno_E == 0.] = 1
|
||||
prob_X = num_X / deno_X
|
||||
prob_E = num_E / deno_E
|
||||
|
||||
prob_E = prob_E / prob_E.sum(dim=-1, keepdim=True)
|
||||
prob_X = prob_X / prob_X.sum(dim=-1, keepdim=True)
|
||||
return PlaceHolder(X=prob_X, E=prob_E, y=None)
|
||||
|
||||
|
||||
def sample_discrete_feature_noise(limit_dist, node_mask):
|
||||
""" Sample from the limit distribution of the diffusion process"""
|
||||
bs, n_max = node_mask.shape
|
||||
x_limit = limit_dist.X[None, None, :].expand(bs, n_max, -1)
|
||||
U_X = x_limit.flatten(end_dim=-2).multinomial(1).reshape(bs, n_max)
|
||||
U_X = F.one_hot(U_X.long(), num_classes=x_limit.shape[-1]).float()
|
||||
|
||||
e_limit = limit_dist.E[None, None, None, :].expand(bs, n_max, n_max, -1)
|
||||
U_E = e_limit.flatten(end_dim=-2).multinomial(1).reshape(bs, n_max, n_max)
|
||||
U_E = F.one_hot(U_E.long(), num_classes=e_limit.shape[-1]).float()
|
||||
|
||||
U_X = U_X.to(node_mask.device)
|
||||
U_E = U_E.to(node_mask.device)
|
||||
|
||||
# Get upper triangular part of edge noise, without main diagonal
|
||||
upper_triangular_mask = torch.zeros_like(U_E)
|
||||
indices = torch.triu_indices(row=U_E.size(1), col=U_E.size(2), offset=1)
|
||||
upper_triangular_mask[:, indices[0], indices[1], :] = 1
|
||||
|
||||
U_E = U_E * upper_triangular_mask
|
||||
U_E = (U_E + torch.transpose(U_E, 1, 2))
|
||||
|
||||
assert (U_E == torch.transpose(U_E, 1, 2)).all()
|
||||
return PlaceHolder(X=U_X, E=U_E, y=None).mask(node_mask)
|
||||
|
||||
def index_QE(X, q_e, n_bond=5):
|
||||
bs, n, n_atom = X.shape
|
||||
node_indices = X.argmax(-1) # (bs, n)
|
||||
|
||||
exp_ind1 = node_indices[ :, :, None, None, None].expand(bs, n, n_atom, n_bond, n_bond)
|
||||
exp_ind2 = node_indices[ :, :, None, None, None].expand(bs, n, n, n_bond, n_bond)
|
||||
|
||||
q_e = torch.gather(q_e, 1, exp_ind1)
|
||||
q_e = torch.gather(q_e, 2, exp_ind2) # (bs, n, n, n_bond, n_bond)
|
||||
|
||||
|
||||
node_mask = X.sum(-1) != 0
|
||||
no_edge = (~node_mask)[:, :, None] & (~node_mask)[:, None, :]
|
||||
q_e[no_edge] = torch.tensor([1, 0, 0, 0, 0]).type_as(q_e)
|
||||
|
||||
return q_e
|
30
mcd/diffusion/distributions.py
Normal file
@ -0,0 +1,30 @@
|
||||
import torch
|
||||
|
||||
class DistributionNodes:
|
||||
def __init__(self, histogram):
|
||||
""" Compute the distribution of the number of nodes in the dataset, and sample from this distribution.
|
||||
histogram: dict. The keys are num_nodes, the values are counts
|
||||
"""
|
||||
|
||||
if isinstance(histogram, dict):
|
||||
max_n_nodes = max(histogram.keys())
|
||||
prob = torch.zeros(max_n_nodes + 1)
|
||||
for num_nodes, count in histogram.items():
|
||||
prob[num_nodes] = count
|
||||
else:
|
||||
prob = histogram
|
||||
|
||||
self.prob = prob / prob.sum()
|
||||
self.m = torch.distributions.Categorical(prob)
|
||||
|
||||
def sample_n(self, n_samples, device):
|
||||
idx = self.m.sample((n_samples,))
|
||||
return idx.to(device)
|
||||
|
||||
def log_prob(self, batch_n_nodes):
|
||||
assert len(batch_n_nodes.size()) == 1
|
||||
p = self.prob.to(batch_n_nodes.device)
|
||||
|
||||
probas = p[batch_n_nodes]
|
||||
log_p = torch.log(probas + 1e-30)
|
||||
return log_p
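# --- Illustrative usage sketch (not part of the committed code): sampling graph
# sizes from a made-up node-count histogram. The values below are hypothetical.
def _distribution_nodes_example():
    hist = {3: 10, 4: 25, 5: 5}                  # num_nodes -> count
    dist = DistributionNodes(hist)
    n_nodes = dist.sample_n(8, device="cpu")     # (8,) sampled graph sizes
    return n_nodes, dist.log_prob(n_nodes)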
|
159
mcd/diffusion/noise_schedule.py
Normal file
@ -0,0 +1,159 @@
|
||||
import torch
|
||||
import utils
|
||||
from diffusion import diffusion_utils
|
||||
|
||||
class PredefinedNoiseScheduleDiscrete(torch.nn.Module):
|
||||
def __init__(self, noise_schedule, timesteps):
|
||||
super(PredefinedNoiseScheduleDiscrete, self).__init__()
|
||||
self.timesteps = timesteps
|
||||
|
||||
if noise_schedule == 'cosine':
|
||||
betas = diffusion_utils.cosine_beta_schedule_discrete(timesteps)
|
||||
elif noise_schedule == 'custom':
|
||||
betas = diffusion_utils.custom_beta_schedule_discrete(timesteps)
|
||||
else:
|
||||
raise NotImplementedError(noise_schedule)
|
||||
|
||||
self.register_buffer('betas', torch.from_numpy(betas).float())
|
||||
|
||||
# clamp betas to [0, 1] (schedules typically cap beta near 0.9999) so alphas stay non-negative
|
||||
self.alphas = 1 - torch.clamp(self.betas, min=0, max=1)
|
||||
|
||||
log_alpha = torch.log(self.alphas)
|
||||
log_alpha_bar = torch.cumsum(log_alpha, dim=0)
|
||||
self.alphas_bar = torch.exp(log_alpha_bar)
|
||||
|
||||
def forward(self, t_normalized=None, t_int=None):
|
||||
assert int(t_normalized is None) + int(t_int is None) == 1
|
||||
if t_int is None:
|
||||
t_int = torch.round(t_normalized * self.timesteps)
|
||||
return self.betas[t_int.long()]
|
||||
|
||||
def get_alpha_bar(self, t_normalized=None, t_int=None):
|
||||
assert int(t_normalized is None) + int(t_int is None) == 1
|
||||
if t_int is None:
|
||||
t_int = torch.round(t_normalized * self.timesteps)
|
||||
# move alphas_bar to the same device as t_int before indexing
|
||||
self.alphas_bar = self.alphas_bar.to(t_int.device)
|
||||
return self.alphas_bar[t_int.long()]
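# --- Illustrative sketch (not part of the committed code): alphas_bar is the
# cumulative product of (1 - beta_t), so it should decay monotonically from about 1
# towards 0 over the T steps. Quick check using the class defined above:
def _check_alphas_bar_monotonic(timesteps=500):
    sched = PredefinedNoiseScheduleDiscrete('cosine', timesteps=timesteps)
    ab = sched.alphas_bar
    assert (ab[1:] <= ab[:-1] + 1e-6).all(), "alphas_bar should be non-increasing"
    return float(ab[0]), float(ab[-1])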
|
||||
|
||||
|
||||
class DiscreteUniformTransition:
|
||||
def __init__(self, x_classes: int, e_classes: int, y_classes: int):
|
||||
self.X_classes = x_classes
|
||||
self.E_classes = e_classes
|
||||
self.y_classes = y_classes
|
||||
self.u_x = torch.ones(1, self.X_classes, self.X_classes)
|
||||
if self.X_classes > 0:
|
||||
self.u_x = self.u_x / self.X_classes
|
||||
|
||||
self.u_e = torch.ones(1, self.E_classes, self.E_classes)
|
||||
if self.E_classes > 0:
|
||||
self.u_e = self.u_e / self.E_classes
|
||||
|
||||
self.u_y = torch.ones(1, self.y_classes, self.y_classes)
|
||||
if self.y_classes > 0:
|
||||
self.u_y = self.u_y / self.y_classes
|
||||
|
||||
def get_Qt(self, beta_t, device, X=None, flatten_e=None):
|
||||
""" Returns one-step transition matrices for X and E, from step t - 1 to step t.
|
||||
Qt = (1 - beta_t) * I + beta_t / K
|
||||
|
||||
beta_t: (bs) noise level between 0 and 1
|
||||
returns: qx (bs, dx, dx), qe (bs, de, de), qy (bs, dy, dy).
|
||||
"""
|
||||
beta_t = beta_t.unsqueeze(1)
|
||||
beta_t = beta_t.to(device)
|
||||
self.u_x = self.u_x.to(device)
|
||||
self.u_e = self.u_e.to(device)
|
||||
self.u_y = self.u_y.to(device)
|
||||
|
||||
q_x = beta_t * self.u_x + (1 - beta_t) * torch.eye(self.X_classes, device=device).unsqueeze(0)
|
||||
q_e = beta_t * self.u_e + (1 - beta_t) * torch.eye(self.E_classes, device=device).unsqueeze(0)
|
||||
q_y = beta_t * self.u_y + (1 - beta_t) * torch.eye(self.y_classes, device=device).unsqueeze(0)
|
||||
|
||||
return utils.PlaceHolder(X=q_x, E=q_e, y=q_y)
|
||||
|
||||
def get_Qt_bar(self, alpha_bar_t, device, X=None, flatten_e=None):
|
||||
""" Returns t-step transition matrices for X and E, from step 0 to step t.
|
||||
Qt = prod(1 - beta_t) * I + (1 - prod(1 - beta_t)) / K
|
||||
|
||||
alpha_bar_t: (bs) Product of the (1 - beta_t) for each time step from 0 to t.
|
||||
returns: qx (bs, dx, dx), qe (bs, de, de), qy (bs, dy, dy).
|
||||
"""
|
||||
alpha_bar_t = alpha_bar_t.unsqueeze(1)
|
||||
alpha_bar_t = alpha_bar_t.to(device)
|
||||
self.u_x = self.u_x.to(device)
|
||||
self.u_e = self.u_e.to(device)
|
||||
self.u_y = self.u_y.to(device)
|
||||
|
||||
q_x = alpha_bar_t * torch.eye(self.X_classes, device=device).unsqueeze(0) + (1 - alpha_bar_t) * self.u_x
|
||||
q_e = alpha_bar_t * torch.eye(self.E_classes, device=device).unsqueeze(0) + (1 - alpha_bar_t) * self.u_e
|
||||
q_y = alpha_bar_t * torch.eye(self.y_classes, device=device).unsqueeze(0) + (1 - alpha_bar_t) * self.u_y
|
||||
|
||||
return utils.PlaceHolder(X=q_x, E=q_e, y=q_y)
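# --- Illustrative sketch (not part of the committed code): every row of a
# transition matrix is a categorical distribution and must sum to 1. Note that
# beta_t arrives with shape (bs, 1), as produced by the noise schedule above.
def _check_uniform_transition_rows(x_classes=4, e_classes=5, y_classes=1):
    trans = DiscreteUniformTransition(x_classes, e_classes, y_classes)
    beta_t = torch.tensor([[0.1], [0.5]])        # (bs, 1) toy noise levels
    Qt = trans.get_Qt(beta_t, device='cpu')
    assert torch.allclose(Qt.X.sum(dim=-1), torch.ones_like(Qt.X.sum(dim=-1)))
    assert torch.allclose(Qt.E.sum(dim=-1), torch.ones_like(Qt.E.sum(dim=-1)))
    return Qt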
|
||||
|
||||
|
||||
class MarginalTransition:
|
||||
def __init__(self, x_marginals, e_marginals, xe_conditions, ex_conditions, y_classes, n_nodes):
|
||||
self.X_classes = len(x_marginals)
|
||||
self.E_classes = len(e_marginals)
|
||||
self.y_classes = y_classes
|
||||
self.x_marginals = x_marginals # Dx
|
||||
self.e_marginals = e_marginals # De
|
||||
self.xe_conditions = xe_conditions
|
||||
|
||||
self.u_x = x_marginals.unsqueeze(0).expand(self.X_classes, -1).unsqueeze(0) # 1, Dx, Dx
|
||||
self.u_e = e_marginals.unsqueeze(0).expand(self.E_classes, -1).unsqueeze(0) # 1, De, De
|
||||
self.u_xe = xe_conditions.unsqueeze(0) # 1, Dx, De
|
||||
self.u_ex = ex_conditions.unsqueeze(0) # 1, De, Dx
|
||||
self.u = self.get_union_transition(self.u_x, self.u_e, self.u_xe, self.u_ex, n_nodes) # 1, Dx + n*De, Dx + n*De
|
||||
|
||||
def get_union_transition(self, u_x, u_e, u_xe, u_ex, n_nodes):
|
||||
u_e = u_e.repeat(1, n_nodes, n_nodes) # (1, n*de, n*de)
|
||||
u_xe = u_xe.repeat(1, 1, n_nodes) # (1, dx, n*de)
|
||||
u_ex = u_ex.repeat(1, n_nodes, 1) # (1, n*de, dx)
|
||||
u0 = torch.cat([u_x, u_xe], dim=2) # (1, dx, dx + n*de)
|
||||
u1 = torch.cat([u_ex, u_e], dim=2) # (1, n*de, dx + n*de)
|
||||
u = torch.cat([u0, u1], dim=1) # (1, dx + n*de, dx + n*de)
|
||||
return u
|
||||
|
||||
def index_edge_margin(self, X, q_e, n_bond=5):
|
||||
# q_e: (bs, dx, de) --> (bs, n, de)
|
||||
bs, n, n_atom = X.shape
|
||||
node_indices = X.argmax(-1) # (bs, n)
|
||||
ind = node_indices[ :, :, None].expand(bs, n, n_bond)
|
||||
q_e = torch.gather(q_e, 1, ind)
|
||||
return q_e
|
||||
|
||||
def get_Qt(self, beta_t, device):
|
||||
""" Returns one-step transition matrices for X and E, from step t - 1 to step t.
|
||||
Qt = (1 - beta_t) * I + beta_t * m, where m is the union marginal transition matrix
|
||||
beta_t: (bs)
|
||||
returns: q (bs, d0, d0)
|
||||
"""
|
||||
bs = beta_t.size(0)
|
||||
d0 = self.u.size(-1)
|
||||
self.u = self.u.to(device)
|
||||
u = self.u.expand(bs, d0, d0)
|
||||
|
||||
beta_t = beta_t.to(device)
|
||||
beta_t = beta_t.view(bs, 1, 1)
|
||||
q = beta_t * u + (1 - beta_t) * torch.eye(d0, device=device).unsqueeze(0)
|
||||
|
||||
return utils.PlaceHolder(X=q, E=None, y=None)
|
||||
|
||||
def get_Qt_bar(self, alpha_bar_t, device):
|
||||
""" Returns t-step transition matrices for X and E, from step 0 to step t.
|
||||
Qt_bar = prod(1 - beta_t) * I + (1 - prod(1 - beta_t)) * m, where m is the union marginal transition matrix
|
||||
alpha_bar_t: (bs, 1) Product of the (1 - beta_t) for each time step from 0 to t.
|
||||
returns: q (bs, d0, d0)
|
||||
"""
|
||||
bs = alpha_bar_t.size(0)
|
||||
d0 = self.u.size(-1)
|
||||
alpha_bar_t = alpha_bar_t.to(device)
|
||||
alpha_bar_t = alpha_bar_t.view(bs, 1, 1)
|
||||
self.u = self.u.to(device)
|
||||
q = alpha_bar_t * torch.eye(d0, device=device).unsqueeze(0) + (1 - alpha_bar_t) * self.u
|
||||
|
||||
return utils.PlaceHolder(X=q, E=None, y=None)
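# --- Illustrative sketch (not part of the committed code): MarginalTransition
# stacks node and edge states into one union matrix laid out as
#   [ u_x  (Dx x Dx)    u_xe (Dx x n*De)  ]
#   [ u_ex (n*De x Dx)  u_e  (n*De x n*De)]
# so its side length is Dx + n_nodes * De, and get_Qt_bar interpolates between the
# identity (alpha_bar = 1, no noise) and this matrix (alpha_bar = 0, pure noise).
# The marginals and conditions below are made up.
def _marginal_transition_example(n_nodes=3, dx=4, de=5):
    x_m = torch.full((dx,), 1.0 / dx)
    e_m = torch.full((de,), 1.0 / de)
    xe = torch.full((dx, de), 1.0 / de)          # per node type: a distribution over edge types
    ex = torch.full((de, dx), 1.0 / dx)          # per edge type: a distribution over node types
    trans = MarginalTransition(x_m, e_m, xe, ex, y_classes=1, n_nodes=n_nodes)
    d0 = dx + n_nodes * de
    assert trans.u.shape == (1, d0, d0)
    q_clean = trans.get_Qt_bar(torch.tensor([[1.0]]), device='cpu').X
    q_noise = trans.get_Qt_bar(torch.tensor([[0.0]]), device='cpu').X
    assert torch.allclose(q_clean[0], torch.eye(d0))
    assert torch.allclose(q_noise[0], trans.u[0])
    return trans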
|
619
mcd/diffusion_model.py
Normal file
@ -0,0 +1,619 @@
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import pytorch_lightning as pl
|
||||
import time
|
||||
import os
|
||||
|
||||
from models.transformer import Denoiser
|
||||
from diffusion.noise_schedule import PredefinedNoiseScheduleDiscrete, MarginalTransition
|
||||
|
||||
from diffusion import diffusion_utils
|
||||
from metrics.train_loss import TrainLossDiscrete
|
||||
from metrics.abstract_metrics import SumExceptBatchMetric, SumExceptBatchKL, NLL
|
||||
import utils
|
||||
|
||||
class MCD(pl.LightningModule):
|
||||
def __init__(self, cfg, dataset_infos, train_metrics, sampling_metrics, visualization_tools):
|
||||
super().__init__()
|
||||
self.save_hyperparameters(ignore=['train_metrics', 'sampling_metrics'])
|
||||
self.test_only = cfg.general.test_only
|
||||
self.guidance_target = getattr(cfg.dataset, 'guidance_target', None)
|
||||
|
||||
input_dims = dataset_infos.input_dims
|
||||
output_dims = dataset_infos.output_dims
|
||||
nodes_dist = dataset_infos.nodes_dist
|
||||
active_index = dataset_infos.active_index
|
||||
|
||||
self.cfg = cfg
|
||||
self.name = cfg.general.name
|
||||
self.T = cfg.model.diffusion_steps
|
||||
self.guide_scale = cfg.model.guide_scale
|
||||
|
||||
self.Xdim = input_dims['X']
|
||||
self.Edim = input_dims['E']
|
||||
self.ydim = input_dims['y']
|
||||
self.Xdim_output = output_dims['X']
|
||||
self.Edim_output = output_dims['E']
|
||||
self.ydim_output = output_dims['y']
|
||||
self.node_dist = nodes_dist
|
||||
self.active_index = active_index
|
||||
self.dataset_info = dataset_infos
|
||||
|
||||
self.train_loss = TrainLossDiscrete(self.cfg.model.lambda_train)
|
||||
|
||||
self.val_nll = NLL()
|
||||
self.val_X_kl = SumExceptBatchKL()
|
||||
self.val_E_kl = SumExceptBatchKL()
|
||||
self.val_X_logp = SumExceptBatchMetric()
|
||||
self.val_E_logp = SumExceptBatchMetric()
|
||||
self.val_y_collection = []
|
||||
|
||||
self.test_nll = NLL()
|
||||
self.test_X_kl = SumExceptBatchKL()
|
||||
self.test_E_kl = SumExceptBatchKL()
|
||||
self.test_X_logp = SumExceptBatchMetric()
|
||||
self.test_E_logp = SumExceptBatchMetric()
|
||||
self.test_y_collection = []
|
||||
|
||||
self.train_metrics = train_metrics
|
||||
self.sampling_metrics = sampling_metrics
|
||||
|
||||
self.visualization_tools = visualization_tools
|
||||
self.max_n_nodes = dataset_infos.max_n_nodes
|
||||
|
||||
self.model = Denoiser(max_n_nodes=self.max_n_nodes,
|
||||
hidden_size=cfg.model.hidden_size,
|
||||
depth=cfg.model.depth,
|
||||
num_heads=cfg.model.num_heads,
|
||||
mlp_ratio=cfg.model.mlp_ratio,
|
||||
drop_condition=cfg.model.drop_condition,
|
||||
Xdim=self.Xdim,
|
||||
Edim=self.Edim,
|
||||
ydim=self.ydim,
|
||||
task_type=dataset_infos.task_type)
|
||||
|
||||
self.noise_schedule = PredefinedNoiseScheduleDiscrete(cfg.model.diffusion_noise_schedule,
|
||||
timesteps=cfg.model.diffusion_steps)
|
||||
|
||||
|
||||
x_marginals = self.dataset_info.node_types.float() / torch.sum(self.dataset_info.node_types.float())
|
||||
|
||||
e_marginals = self.dataset_info.edge_types.float() / torch.sum(self.dataset_info.edge_types.float())
|
||||
x_marginals = x_marginals / (x_marginals ).sum()
|
||||
e_marginals = e_marginals / (e_marginals ).sum()
|
||||
|
||||
xe_conditions = self.dataset_info.transition_E.float()
|
||||
xe_conditions = xe_conditions[self.active_index][:, self.active_index]
|
||||
|
||||
xe_conditions = xe_conditions.sum(dim=1)
|
||||
ex_conditions = xe_conditions.t()
|
||||
xe_conditions = xe_conditions / xe_conditions.sum(dim=-1, keepdim=True)
|
||||
ex_conditions = ex_conditions / ex_conditions.sum(dim=-1, keepdim=True)
|
||||
|
||||
self.transition_model = MarginalTransition(x_marginals=x_marginals,
|
||||
e_marginals=e_marginals,
|
||||
xe_conditions=xe_conditions,
|
||||
ex_conditions=ex_conditions,
|
||||
y_classes=self.ydim_output,
|
||||
n_nodes=self.max_n_nodes)
|
||||
|
||||
self.limit_dist = utils.PlaceHolder(X=x_marginals, E=e_marginals, y=None)
|
||||
|
||||
self.start_epoch_time = None
|
||||
self.train_iterations = None
|
||||
self.val_iterations = None
|
||||
self.log_every_steps = cfg.general.log_every_steps
|
||||
self.number_chain_steps = cfg.general.number_chain_steps
|
||||
|
||||
self.best_val_nll = 1e8
|
||||
self.val_counter = 0
|
||||
self.batch_size = self.cfg.train.batch_size
|
||||
|
||||
|
||||
def forward(self, noisy_data, unconditioned=False):
|
||||
x, e, y = noisy_data['X_t'].float(), noisy_data['E_t'].float(), noisy_data['y_t'].float().clone()
|
||||
node_mask, t = noisy_data['node_mask'], noisy_data['t']
|
||||
pred = self.model(x, e, node_mask, y=y, t=t, unconditioned=unconditioned)
|
||||
return pred
|
||||
|
||||
def training_step(self, data, i):
|
||||
data_x = F.one_hot(data.x, num_classes=118).float()[:, self.active_index]
|
||||
data_edge_attr = F.one_hot(data.edge_attr, num_classes=5).float()
|
||||
|
||||
dense_data, node_mask = utils.to_dense(data_x, data.edge_index, data_edge_attr, data.batch, self.max_n_nodes)
|
||||
dense_data = dense_data.mask(node_mask)
|
||||
X, E = dense_data.X, dense_data.E
|
||||
noisy_data = self.apply_noise(X, E, data.y, node_mask)
|
||||
pred = self.forward(noisy_data)
|
||||
loss = self.train_loss(masked_pred_X=pred.X, masked_pred_E=pred.E, pred_y=pred.y,
|
||||
true_X=X, true_E=E, true_y=data.y, node_mask=node_mask,
|
||||
log=i % self.log_every_steps == 0)
|
||||
|
||||
self.train_metrics(masked_pred_X=pred.X, masked_pred_E=pred.E, true_X=X, true_E=E,
|
||||
log=i % self.log_every_steps == 0)
|
||||
self.log(f'loss', loss, batch_size=X.size(0), sync_dist=True)
|
||||
return {'loss': loss}
|
||||
|
||||
|
||||
def configure_optimizers(self):
|
||||
params = self.parameters()
|
||||
optimizer = torch.optim.AdamW(params, lr=self.cfg.train.lr, amsgrad=True,
|
||||
weight_decay=self.cfg.train.weight_decay)
|
||||
return optimizer
|
||||
|
||||
def on_fit_start(self) -> None:
|
||||
self.train_iterations = self.trainer.datamodule.training_iterations
|
||||
print('on fit train iteration:', self.train_iterations)
|
||||
print("Size of the input features Xdim {}, Edim {}, ydim {}".format(self.Xdim, self.Edim, self.ydim))
|
||||
|
||||
def on_train_epoch_start(self) -> None:
|
||||
if self.current_epoch / self.trainer.max_epochs in [0.25, 0.5, 0.75, 1.0]:
|
||||
print("Starting train epoch {}/{}...".format(self.current_epoch, self.trainer.max_epochs))
|
||||
self.start_epoch_time = time.time()
|
||||
self.train_loss.reset()
|
||||
self.train_metrics.reset()
|
||||
|
||||
def on_train_epoch_end(self) -> None:
|
||||
if self.current_epoch / self.trainer.max_epochs in [0.25, 0.5, 0.75, 1.0]:
|
||||
log = True
|
||||
else:
|
||||
log = False
|
||||
self.train_loss.log_epoch_metrics(self.current_epoch, self.start_epoch_time, log)
|
||||
self.train_metrics.log_epoch_metrics(self.current_epoch, log)
|
||||
|
||||
def on_validation_epoch_start(self) -> None:
|
||||
self.val_nll.reset()
|
||||
self.val_X_kl.reset()
|
||||
self.val_E_kl.reset()
|
||||
self.val_X_logp.reset()
|
||||
self.val_E_logp.reset()
|
||||
self.sampling_metrics.reset()
|
||||
self.val_y_collection = []
|
||||
|
||||
@torch.no_grad()
|
||||
def validation_step(self, data, i):
|
||||
data_x = F.one_hot(data.x, num_classes=118).float()[:, self.active_index]
|
||||
data_edge_attr = F.one_hot(data.edge_attr, num_classes=5).float()
|
||||
|
||||
dense_data, node_mask = utils.to_dense(data_x, data.edge_index, data_edge_attr, data.batch, self.max_n_nodes)
|
||||
dense_data = dense_data.mask(node_mask)
|
||||
noisy_data = self.apply_noise(dense_data.X, dense_data.E, data.y, node_mask)
|
||||
pred = self.forward(noisy_data)
|
||||
nll = self.compute_val_loss(pred, noisy_data, dense_data.X, dense_data.E, data.y, node_mask, test=False)
|
||||
self.val_y_collection.append(data.y)
|
||||
self.log(f'valid_nll', nll, batch_size=data.x.size(0), sync_dist=True)
|
||||
return {'loss': nll}
|
||||
|
||||
def on_validation_epoch_end(self) -> None:
|
||||
metrics = [self.val_nll.compute(), self.val_X_kl.compute() * self.T, self.val_E_kl.compute() * self.T,
|
||||
self.val_X_logp.compute(), self.val_E_logp.compute()]
|
||||
|
||||
if self.current_epoch / self.trainer.max_epochs in [0.25, 0.5, 0.75, 1.0]:
|
||||
print(f"Epoch {self.current_epoch}: Val NLL {metrics[0] :.2f} -- Val Atom type KL {metrics[1] :.2f} -- ",
|
||||
f"Val Edge type KL: {metrics[2] :.2f}", 'Val loss: %.2f \t Best : %.2f\n' % (metrics[0], self.best_val_nll))
|
||||
|
||||
# Log val nll with default Lightning logger, so it can be monitored by checkpoint callback
|
||||
self.log("val/NLL", metrics[0], sync_dist=True)
|
||||
|
||||
if metrics[0] < self.best_val_nll:
|
||||
self.best_val_nll = metrics[0]
|
||||
|
||||
self.val_counter += 1
|
||||
|
||||
if self.val_counter % self.cfg.general.sample_every_val == 0 and self.val_counter > 1:
|
||||
start = time.time()
|
||||
samples_left_to_generate = self.cfg.general.samples_to_generate
|
||||
samples_left_to_save = self.cfg.general.samples_to_save
|
||||
chains_left_to_save = self.cfg.general.chains_to_save
|
||||
|
||||
samples, all_ys, ident = [], [], 0
|
||||
|
||||
self.val_y_collection = torch.cat(self.val_y_collection, dim=0)
|
||||
num_examples = self.val_y_collection.size(0)
|
||||
start_index = 0
|
||||
while samples_left_to_generate > 0:
|
||||
bs = 1 * self.cfg.train.batch_size
|
||||
to_generate = min(samples_left_to_generate, bs)
|
||||
to_save = min(samples_left_to_save, bs)
|
||||
chains_save = min(chains_left_to_save, bs)
|
||||
|
||||
if start_index + to_generate > num_examples:
|
||||
start_index = 0
|
||||
if to_generate > num_examples:
|
||||
ratio = to_generate // num_examples
|
||||
self.val_y_collection = self.val_y_collection.repeat(ratio+1, 1)
|
||||
num_examples = self.val_y_collection.size(0)
|
||||
batch_y = self.val_y_collection[start_index:start_index + to_generate]
|
||||
all_ys.append(batch_y)
|
||||
samples.extend(self.sample_batch(batch_id=ident, batch_size=to_generate, y=batch_y,
|
||||
save_final=to_save,
|
||||
keep_chain=chains_save,
|
||||
number_chain_steps=self.number_chain_steps))
|
||||
ident += to_generate
|
||||
start_index += to_generate
|
||||
|
||||
samples_left_to_save -= to_save
|
||||
samples_left_to_generate -= to_generate
|
||||
chains_left_to_save -= chains_save
|
||||
|
||||
print(f"Computing sampling metrics", ' ...')
|
||||
valid_smiles = self.sampling_metrics(samples, all_ys, self.name, self.current_epoch, val_counter=-1, test=False)
|
||||
print(f'Done. Sampling took {time.time() - start:.2f} seconds\n')
|
||||
current_path = os.getcwd()
|
||||
result_path = os.path.join(current_path,
|
||||
f'graphs/{self.name}/epoch{self.current_epoch}_b0/')
|
||||
self.visualization_tools.visualize_by_smiles(result_path, valid_smiles, self.cfg.general.samples_to_save)
|
||||
self.sampling_metrics.reset()
|
||||
|
||||
def on_test_epoch_start(self) -> None:
|
||||
print("Starting test...")
|
||||
self.test_nll.reset()
|
||||
self.test_X_kl.reset()
|
||||
self.test_E_kl.reset()
|
||||
self.test_X_logp.reset()
|
||||
self.test_E_logp.reset()
|
||||
self.test_y_collection = []
|
||||
|
||||
@torch.no_grad()
|
||||
def test_step(self, data, i):
|
||||
data_x = F.one_hot(data.x, num_classes=118).float()[:, self.active_index]
|
||||
data_edge_attr = F.one_hot(data.edge_attr, num_classes=5).float()
|
||||
|
||||
dense_data, node_mask = utils.to_dense(data_x, data.edge_index, data_edge_attr, data.batch, self.max_n_nodes)
|
||||
dense_data = dense_data.mask(node_mask)
|
||||
noisy_data = self.apply_noise(dense_data.X, dense_data.E, data.y, node_mask)
|
||||
pred = self.forward(noisy_data)
|
||||
nll = self.compute_val_loss(pred, noisy_data, dense_data.X, dense_data.E, data.y, node_mask, test=True)
|
||||
self.test_y_collection.append(data.y)
|
||||
return {'loss': nll}
|
||||
|
||||
def on_test_epoch_end(self) -> None:
|
||||
""" Measure likelihood on a test set and compute stability metrics. """
|
||||
metrics = [self.test_nll.compute(), self.test_X_kl.compute(), self.test_E_kl.compute(),
|
||||
self.test_X_logp.compute(), self.test_E_logp.compute()]
|
||||
|
||||
print(f"Epoch {self.current_epoch}: Test NLL {metrics[0] :.2f} -- Test Atom type KL {metrics[1] :.2f} -- ",
|
||||
f"Test Edge type KL: {metrics[2] :.2f}")
|
||||
|
||||
## final epoch
|
||||
samples_left_to_generate = self.cfg.general.final_model_samples_to_generate
|
||||
samples_left_to_save = self.cfg.general.final_model_samples_to_save
|
||||
chains_left_to_save = self.cfg.general.final_model_chains_to_save
|
||||
|
||||
samples, all_ys, batch_id = [], [], 0
|
||||
|
||||
test_y_collection = torch.cat(self.test_y_collection, dim=0)
|
||||
num_examples = test_y_collection.size(0)
|
||||
if self.cfg.general.final_model_samples_to_generate > num_examples:
|
||||
ratio = self.cfg.general.final_model_samples_to_generate // num_examples
|
||||
test_y_collection = test_y_collection.repeat(ratio+1, 1)
|
||||
num_examples = test_y_collection.size(0)
|
||||
|
||||
while samples_left_to_generate > 0:
|
||||
print(f'samples left to generate: {samples_left_to_generate}/'
|
||||
f'{self.cfg.general.final_model_samples_to_generate}', end='', flush=True)
|
||||
bs = 1 * self.cfg.train.batch_size
|
||||
to_generate = min(samples_left_to_generate, bs)
|
||||
to_save = min(samples_left_to_save, bs)
|
||||
chains_save = min(chains_left_to_save, bs)
|
||||
batch_y = test_y_collection[batch_id : batch_id + to_generate]
|
||||
|
||||
cur_sample = self.sample_batch(batch_id, to_generate, batch_y, save_final=to_save,
|
||||
keep_chain=chains_save, number_chain_steps=self.number_chain_steps)
|
||||
samples = samples + cur_sample
|
||||
|
||||
all_ys.append(batch_y)
|
||||
batch_id += to_generate
|
||||
|
||||
samples_left_to_save -= to_save
|
||||
samples_left_to_generate -= to_generate
|
||||
chains_left_to_save -= chains_save
|
||||
|
||||
print(f"final Computing sampling metrics...")
|
||||
self.sampling_metrics.reset()
|
||||
self.sampling_metrics(samples, all_ys, self.name, self.current_epoch, self.val_counter, test=True)
|
||||
self.sampling_metrics.reset()
|
||||
print(f"Done.")
|
||||
|
||||
|
||||
def kl_prior(self, X, E, node_mask):
|
||||
"""Computes the KL between q(z1 | x) and the prior p(z1) = Normal(0, 1).
|
||||
|
||||
This is essentially a lot of work for something that is in practice negligible in the loss. However, you
|
||||
compute it so that you see it when you've made a mistake in your noise schedule.
|
||||
"""
|
||||
# Compute the last alpha value, alpha_T.
|
||||
ones = torch.ones((X.size(0), 1), device=X.device)
|
||||
Ts = self.T * ones
|
||||
alpha_t_bar = self.noise_schedule.get_alpha_bar(t_int=Ts) # (bs, 1)
|
||||
|
||||
Qtb = self.transition_model.get_Qt_bar(alpha_t_bar, self.device)
|
||||
|
||||
bs, n, d = X.shape
|
||||
X_all = torch.cat([X, E.reshape(bs, n, -1)], dim=-1)
|
||||
prob_all = X_all @ Qtb.X
|
||||
probX = prob_all[:, :, :self.Xdim_output]
|
||||
probE = prob_all[:, :, self.Xdim_output:].reshape((bs, n, n, -1))
|
||||
|
||||
assert probX.shape == X.shape
|
||||
|
||||
limit_X = self.limit_dist.X[None, None, :].expand(bs, n, -1).type_as(probX)
|
||||
limit_E = self.limit_dist.E[None, None, None, :].expand(bs, n, n, -1).type_as(probE)
|
||||
|
||||
# Make sure that masked rows do not contribute to the loss
|
||||
limit_dist_X, limit_dist_E, probX, probE = diffusion_utils.mask_distributions(true_X=limit_X.clone(),
|
||||
true_E=limit_E.clone(),
|
||||
pred_X=probX,
|
||||
pred_E=probE,
|
||||
node_mask=node_mask)
|
||||
|
||||
kl_distance_X = F.kl_div(input=probX.log(), target=limit_dist_X, reduction='none')
|
||||
kl_distance_E = F.kl_div(input=probE.log(), target=limit_dist_E, reduction='none')
|
||||
|
||||
return diffusion_utils.sum_except_batch(kl_distance_X) + \
|
||||
diffusion_utils.sum_except_batch(kl_distance_E)
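# --- Illustrative sketch (not part of the committed code): kl_prior measures
# KL(q(z_T | x) || limit distribution); with enough diffusion steps q(z_T | x)
# collapses onto the limit distribution and this term goes to ~0. Toy single-node
# example with made-up distributions, using the same F.kl_div convention as above:
def _toy_prior_kl():
    prob_T = torch.tensor([0.26, 0.24, 0.50])    # q(z_T | x) for one node
    limit = torch.tensor([0.25, 0.25, 0.50])     # limit / marginal distribution
    return F.kl_div(input=prob_T.log(), target=limit, reduction='sum')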
|
||||
|
||||
def compute_Lt(self, X, E, y, pred, noisy_data, node_mask, test):
|
||||
pred_probs_X = F.softmax(pred.X, dim=-1)
|
||||
pred_probs_E = F.softmax(pred.E, dim=-1)
|
||||
|
||||
Qtb = self.transition_model.get_Qt_bar(noisy_data['alpha_t_bar'], self.device)
|
||||
Qsb = self.transition_model.get_Qt_bar(noisy_data['alpha_s_bar'], self.device)
|
||||
Qt = self.transition_model.get_Qt(noisy_data['beta_t'], self.device)
|
||||
|
||||
# Compute distributions to compare with KL
|
||||
bs, n, d = X.shape
|
||||
X_all = torch.cat([X, E.reshape(bs, n, -1)], dim=-1).float()
|
||||
Xt_all = torch.cat([noisy_data['X_t'], noisy_data['E_t'].reshape(bs, n, -1)], dim=-1).float()
|
||||
pred_probs_all = torch.cat([pred_probs_X, pred_probs_E.reshape(bs, n, -1)], dim=-1).float()
|
||||
|
||||
prob_true = diffusion_utils.posterior_distributions(X=X_all, X_t=Xt_all, Qt=Qt, Qsb=Qsb, Qtb=Qtb, X_dim=self.Xdim_output)
|
||||
prob_true.E = prob_true.E.reshape((bs, n, n, -1))
|
||||
prob_pred = diffusion_utils.posterior_distributions(X=pred_probs_all, X_t=Xt_all, Qt=Qt, Qsb=Qsb, Qtb=Qtb, X_dim=self.Xdim_output)
|
||||
prob_pred.E = prob_pred.E.reshape((bs, n, n, -1))
|
||||
|
||||
# Reshape and filter masked rows
|
||||
prob_true_X, prob_true_E, prob_pred.X, prob_pred.E = diffusion_utils.mask_distributions(true_X=prob_true.X,
|
||||
true_E=prob_true.E,
|
||||
pred_X=prob_pred.X,
|
||||
pred_E=prob_pred.E,
|
||||
node_mask=node_mask)
|
||||
kl_x = (self.test_X_kl if test else self.val_X_kl)(prob_true.X, torch.log(prob_pred.X))
|
||||
kl_e = (self.test_E_kl if test else self.val_E_kl)(prob_true.E, torch.log(prob_pred.E))
|
||||
|
||||
return self.T * (kl_x + kl_e)
|
||||
|
||||
def reconstruction_logp(self, t, X, E, y, node_mask):
|
||||
# Compute noise values for t = 0.
|
||||
t_zeros = torch.zeros_like(t)
|
||||
beta_0 = self.noise_schedule(t_zeros)
|
||||
Q0 = self.transition_model.get_Qt(beta_0, self.device)
|
||||
|
||||
bs, n, d = X.shape
|
||||
X_all = torch.cat([X, E.reshape(bs, n, -1)], dim=-1)
|
||||
prob_all = X_all @ Q0.X
|
||||
probX0 = prob_all[:, :, :self.Xdim_output]
|
||||
probE0 = prob_all[:, :, self.Xdim_output:].reshape((bs, n, n, -1))
|
||||
|
||||
sampled0 = diffusion_utils.sample_discrete_features(probX=probX0, probE=probE0, node_mask=node_mask)
|
||||
|
||||
X0 = F.one_hot(sampled0.X, num_classes=self.Xdim_output).float()
|
||||
E0 = F.one_hot(sampled0.E, num_classes=self.Edim_output).float()
|
||||
|
||||
assert (X.shape == X0.shape) and (E.shape == E0.shape)
|
||||
sampled_0 = utils.PlaceHolder(X=X0, E=E0, y=y).mask(node_mask)
|
||||
|
||||
# Predictions
|
||||
noisy_data = {'X_t': sampled_0.X, 'E_t': sampled_0.E, 'y_t': sampled_0.y, 'node_mask': node_mask,
|
||||
't': torch.zeros(X0.shape[0], 1).type_as(y)}
|
||||
pred0 = self.forward(noisy_data)
|
||||
|
||||
# Normalize predictions
|
||||
probX0 = F.softmax(pred0.X, dim=-1)
|
||||
probE0 = F.softmax(pred0.E, dim=-1)
|
||||
proby0 = None
|
||||
|
||||
# Set masked rows to arbitrary values that don't contribute to loss
|
||||
probX0[~node_mask] = torch.ones(self.Xdim_output).type_as(probX0)
|
||||
probE0[~(node_mask.unsqueeze(1) * node_mask.unsqueeze(2))] = torch.ones(self.Edim_output).type_as(probE0)
|
||||
|
||||
diag_mask = torch.eye(probE0.size(1)).type_as(probE0).bool()
|
||||
diag_mask = diag_mask.unsqueeze(0).expand(probE0.size(0), -1, -1)
|
||||
probE0[diag_mask] = torch.ones(self.Edim_output).type_as(probE0)
|
||||
|
||||
return utils.PlaceHolder(X=probX0, E=probE0, y=proby0)
|
||||
|
||||
def apply_noise(self, X, E, y, node_mask):
|
||||
""" Sample noise and apply it to the data. """
|
||||
|
||||
# Sample a timestep t.
|
||||
# When evaluating, the loss for t=0 is computed separately
|
||||
lowest_t = 0 if self.training else 1
|
||||
t_int = torch.randint(lowest_t, self.T + 1, size=(X.size(0), 1), device=X.device).float() # (bs, 1)
|
||||
s_int = t_int - 1
|
||||
|
||||
t_float = t_int / self.T
|
||||
s_float = s_int / self.T
|
||||
|
||||
# beta_t and alpha_s_bar are used for denoising/loss computation
|
||||
beta_t = self.noise_schedule(t_normalized=t_float) # (bs, 1)
|
||||
alpha_s_bar = self.noise_schedule.get_alpha_bar(t_normalized=s_float) # (bs, 1)
|
||||
alpha_t_bar = self.noise_schedule.get_alpha_bar(t_normalized=t_float) # (bs, 1)
|
||||
|
||||
Qtb = self.transition_model.get_Qt_bar(alpha_t_bar, self.device) # (bs, dx_in, dx_out), (bs, de_in, de_out)
|
||||
|
||||
bs, n, d = X.shape
|
||||
X_all = torch.cat([X, E.reshape(bs, n, -1)], dim=-1)
|
||||
prob_all = X_all @ Qtb.X
|
||||
probX = prob_all[:, :, :self.Xdim_output]
|
||||
probE = prob_all[:, :, self.Xdim_output:].reshape(bs, n, n, -1)
|
||||
|
||||
sampled_t = diffusion_utils.sample_discrete_features(probX=probX, probE=probE, node_mask=node_mask)
|
||||
|
||||
X_t = F.one_hot(sampled_t.X, num_classes=self.Xdim_output)
|
||||
E_t = F.one_hot(sampled_t.E, num_classes=self.Edim_output)
|
||||
assert (X.shape == X_t.shape) and (E.shape == E_t.shape)
|
||||
|
||||
y_t = y
|
||||
z_t = utils.PlaceHolder(X=X_t, E=E_t, y=y_t).type_as(X_t).mask(node_mask)
|
||||
|
||||
noisy_data = {'t_int': t_int, 't': t_float, 'beta_t': beta_t, 'alpha_s_bar': alpha_s_bar,
|
||||
'alpha_t_bar': alpha_t_bar, 'X_t': z_t.X, 'E_t': z_t.E, 'y_t': z_t.y, 'node_mask': node_mask}
|
||||
return noisy_data
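# --- Illustrative sketch (not part of the committed code): the noising above is
# just prob = x_0 @ Qt_bar followed by categorical sampling of each entry. Toy
# single-node, 3-class version with a made-up cumulative transition matrix:
def _toy_forward_noising():
    d = 3
    x0 = F.one_hot(torch.tensor([2]), d).float()         # (1, d) clean one-hot state
    Qtb = torch.tensor([[0.8, 0.1, 0.1],
                        [0.1, 0.8, 0.1],
                        [0.1, 0.1, 0.8]])                 # toy cumulative transition
    prob = x0 @ Qtb                                       # (1, d), rows sum to 1
    x_t = torch.multinomial(prob, 1)                      # sampled noisy class index
    return prob, x_t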
|
||||
|
||||
def compute_val_loss(self, pred, noisy_data, X, E, y, node_mask, test=False):
|
||||
"""Computes an estimator for the variational lower bound.
|
||||
pred: (batch_size, n, total_features)
|
||||
noisy_data: dict
|
||||
X, E, y : (bs, n, dx), (bs, n, n, de), (bs, dy)
|
||||
node_mask : (bs, n)
|
||||
Output: nll (size 1)
|
||||
"""
|
||||
t = noisy_data['t']
|
||||
|
||||
# 1.
|
||||
N = node_mask.sum(1).long()
|
||||
log_pN = self.node_dist.log_prob(N)
|
||||
|
||||
# 2. The KL between q(z_T | x) and the prior p(z_T) (the marginal limit distribution). Should be close to zero.
|
||||
kl_prior = self.kl_prior(X, E, node_mask)
|
||||
|
||||
# 3. Diffusion loss
|
||||
loss_all_t = self.compute_Lt(X, E, y, pred, noisy_data, node_mask, test)
|
||||
|
||||
# 4. Reconstruction loss
|
||||
# Compute L0 term : -log p (X, E, y | z_0) = reconstruction loss
|
||||
prob0 = self.reconstruction_logp(t, X, E, y, node_mask)
|
||||
|
||||
eps = 1e-8
|
||||
loss_term_0 = self.val_X_logp(X * (prob0.X+eps).log()) + self.val_E_logp(E * (prob0.E+eps).log())
|
||||
|
||||
# Combine terms
|
||||
nlls = - log_pN + kl_prior + loss_all_t - loss_term_0
|
||||
assert len(nlls.shape) == 1, f'{nlls.shape} has more than only batch dim.'
|
||||
|
||||
# Update NLL metric object and return batch nll
|
||||
nll = (self.test_nll if test else self.val_nll)(nlls) # Average over the batch
|
||||
|
||||
return nll
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_batch(self, batch_id, batch_size, y, keep_chain, number_chain_steps, save_final, num_nodes=None):
|
||||
"""
|
||||
:param batch_id: int
|
||||
:param batch_size: int
|
||||
:param num_nodes: int, <int>tensor (batch_size) (optional) for specifying number of nodes
|
||||
:param save_final: int: number of predictions to save to file
|
||||
:param keep_chain: int: number of chains to save to file (disabled)
|
||||
:param number_chain_steps: number of timesteps to save for each chain (disabled)
|
||||
:return: molecule_list. Each element is a pair [atom_types, edge_types]
|
||||
"""
|
||||
if num_nodes is None:
|
||||
n_nodes = self.node_dist.sample_n(batch_size, self.device)
|
||||
elif type(num_nodes) == int:
|
||||
n_nodes = num_nodes * torch.ones(batch_size, device=self.device, dtype=torch.int)
|
||||
else:
|
||||
assert isinstance(num_nodes, torch.Tensor)
|
||||
n_nodes = num_nodes
|
||||
n_max = self.max_n_nodes
|
||||
arange = torch.arange(n_max, device=self.device).unsqueeze(0).expand(batch_size, -1)
|
||||
node_mask = arange < n_nodes.unsqueeze(1)
|
||||
|
||||
z_T = diffusion_utils.sample_discrete_feature_noise(limit_dist=self.limit_dist, node_mask=node_mask)
|
||||
X, E = z_T.X, z_T.E
|
||||
|
||||
assert (E == torch.transpose(E, 1, 2)).all()
|
||||
|
||||
# Iteratively sample p(z_s | z_t) for t = 1, ..., T, with s = t - 1.
|
||||
for s_int in reversed(range(0, self.T)):
|
||||
s_array = s_int * torch.ones((batch_size, 1)).type_as(y)
|
||||
t_array = s_array + 1
|
||||
s_norm = s_array / self.T
|
||||
t_norm = t_array / self.T
|
||||
|
||||
# Sample z_s
|
||||
sampled_s, discrete_sampled_s = self.sample_p_zs_given_zt(s_norm, t_norm, X, E, y, node_mask)
|
||||
X, E, y = sampled_s.X, sampled_s.E, sampled_s.y
|
||||
|
||||
# Sample
|
||||
sampled_s = sampled_s.mask(node_mask, collapse=True)
|
||||
X, E, y = sampled_s.X, sampled_s.E, sampled_s.y
|
||||
|
||||
molecule_list = []
|
||||
for i in range(batch_size):
|
||||
n = n_nodes[i]
|
||||
atom_types = X[i, :n].cpu()
|
||||
edge_types = E[i, :n, :n].cpu()
|
||||
molecule_list.append([atom_types, edge_types])
|
||||
|
||||
return molecule_list
|
||||
|
||||
def sample_p_zs_given_zt(self, s, t, X_t, E_t, y_t, node_mask):
|
||||
"""Samples from zs ~ p(zs | zt). Only used during sampling.
|
||||
Returns both the one-hot sample and its collapsed (discrete) version."""
|
||||
bs, n, dxs = X_t.shape
|
||||
beta_t = self.noise_schedule(t_normalized=t) # (bs, 1)
|
||||
alpha_s_bar = self.noise_schedule.get_alpha_bar(t_normalized=s)
|
||||
alpha_t_bar = self.noise_schedule.get_alpha_bar(t_normalized=t)
|
||||
|
||||
# Neural net predictions
|
||||
noisy_data = {'X_t': X_t, 'E_t': E_t, 'y_t': y_t, 't': t, 'node_mask': node_mask}
|
||||
|
||||
def get_prob(noisy_data, unconditioned=False):
|
||||
pred = self.forward(noisy_data, unconditioned=unconditioned)
|
||||
|
||||
# Normalize predictions
|
||||
pred_X = F.softmax(pred.X, dim=-1) # bs, n, d0
|
||||
pred_E = F.softmax(pred.E, dim=-1) # bs, n, n, d0
|
||||
|
||||
# Retrieve transitions matrix
|
||||
Qtb = self.transition_model.get_Qt_bar(alpha_t_bar, self.device)
|
||||
Qsb = self.transition_model.get_Qt_bar(alpha_s_bar, self.device)
|
||||
Qt = self.transition_model.get_Qt(beta_t, self.device)
|
||||
|
||||
Xt_all = torch.cat([X_t, E_t.reshape(bs, n, -1)], dim=-1)
|
||||
p_s_and_t_given_0 = diffusion_utils.compute_batched_over0_posterior_distribution(X_t=Xt_all,
|
||||
Qt=Qt.X,
|
||||
Qsb=Qsb.X,
|
||||
Qtb=Qtb.X)
|
||||
predX_all = torch.cat([pred_X, pred_E.reshape(bs, n, -1)], dim=-1)
|
||||
weightedX_all = predX_all.unsqueeze(-1) * p_s_and_t_given_0
|
||||
unnormalized_probX_all = weightedX_all.sum(dim=2) # bs, n, d_t-1
|
||||
|
||||
unnormalized_prob_X = unnormalized_probX_all[:, :, :self.Xdim_output]
|
||||
unnormalized_prob_E = unnormalized_probX_all[:, :, self.Xdim_output:].reshape(bs, n*n, -1)
|
||||
|
||||
unnormalized_prob_X[torch.sum(unnormalized_prob_X, dim=-1) == 0] = 1e-5
|
||||
unnormalized_prob_E[torch.sum(unnormalized_prob_E, dim=-1) == 0] = 1e-5
|
||||
|
||||
prob_X = unnormalized_prob_X / torch.sum(unnormalized_prob_X, dim=-1, keepdim=True) # bs, n, d_t-1
|
||||
prob_E = unnormalized_prob_E / torch.sum(unnormalized_prob_E, dim=-1, keepdim=True) # bs, n*n, d_t-1
|
||||
prob_E = prob_E.reshape(bs, n, n, pred_E.shape[-1])
|
||||
|
||||
return prob_X, prob_E
|
||||
|
||||
prob_X, prob_E = get_prob(noisy_data)
|
||||
|
||||
### Guidance
|
||||
if self.guidance_target is not None and self.guide_scale is not None and self.guide_scale != 1:
|
||||
uncon_prob_X, uncon_prob_E = get_prob(noisy_data, unconditioned=True)
|
||||
prob_X = uncon_prob_X * (prob_X / uncon_prob_X.clamp_min(1e-10)) ** self.guide_scale
|
||||
prob_E = uncon_prob_E * (prob_E / uncon_prob_E.clamp_min(1e-10)) ** self.guide_scale
|
||||
prob_X = prob_X / prob_X.sum(dim=-1, keepdim=True).clamp_min(1e-10)
|
||||
prob_E = prob_E / prob_E.sum(dim=-1, keepdim=True).clamp_min(1e-10)
|
||||
|
||||
assert ((prob_X.sum(dim=-1) - 1).abs() < 1e-4).all()
|
||||
assert ((prob_E.sum(dim=-1) - 1).abs() < 1e-4).all()
|
||||
|
||||
sampled_s = diffusion_utils.sample_discrete_features(prob_X, prob_E, node_mask=node_mask, step=s[0,0].item())
|
||||
|
||||
X_s = F.one_hot(sampled_s.X, num_classes=self.Xdim_output).float()
|
||||
E_s = F.one_hot(sampled_s.E, num_classes=self.Edim_output).float()
|
||||
|
||||
assert (E_s == torch.transpose(E_s, 1, 2)).all()
|
||||
assert (X_t.shape == X_s.shape) and (E_t.shape == E_s.shape)
|
||||
|
||||
out_one_hot = utils.PlaceHolder(X=X_s, E=E_s, y=y_t)
|
||||
out_discrete = utils.PlaceHolder(X=X_s, E=E_s, y=y_t)
|
||||
|
||||
return out_one_hot.mask(node_mask).type_as(y_t), out_discrete.mask(node_mask, collapse=True).type_as(y_t)
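# --- Illustrative sketch (not part of the committed code): the guidance step in
# sample_p_zs_given_zt combines conditional and unconditional predictions as
# p_guided proportional to p_uncond * (p_cond / p_uncond) ** guide_scale, which
# reduces to p_cond for guide_scale = 1 and sharpens towards the condition for
# guide_scale > 1. Toy 3-class example with made-up probabilities:
def _toy_guidance(guide_scale=2.0):
    p_cond = torch.tensor([0.7, 0.2, 0.1])
    p_uncond = torch.tensor([0.4, 0.3, 0.3])
    p = p_uncond * (p_cond / p_uncond.clamp_min(1e-10)) ** guide_scale
    return p / p.sum()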
|
138
mcd/main.py
Normal file
@ -0,0 +1,138 @@
|
||||
# These imports are tricky because they use c++, do not move them
|
||||
import os, shutil
|
||||
import warnings
|
||||
|
||||
import torch
|
||||
import hydra
|
||||
from omegaconf import DictConfig
|
||||
from pytorch_lightning import Trainer
|
||||
|
||||
import utils
|
||||
from datasets import dataset
|
||||
from diffusion_model import MCD
|
||||
from metrics.molecular_metrics_train import TrainMolecularMetricsDiscrete
|
||||
from metrics.molecular_metrics_sampling import SamplingMolecularMetrics
|
||||
|
||||
from analysis.visualization import MolecularVisualization
|
||||
|
||||
warnings.filterwarnings("ignore", category=UserWarning)
|
||||
torch.set_float32_matmul_precision("medium")
|
||||
|
||||
def remove_folder(folder):
|
||||
for filename in os.listdir(folder):
|
||||
file_path = os.path.join(folder, filename)
|
||||
try:
|
||||
if os.path.isfile(file_path) or os.path.islink(file_path):
|
||||
os.unlink(file_path)
|
||||
elif os.path.isdir(file_path):
|
||||
shutil.rmtree(file_path)
|
||||
except Exception as e:
|
||||
print("Failed to delete %s. Reason: %s" % (file_path, e))
|
||||
|
||||
|
||||
def get_resume(cfg, model_kwargs):
|
||||
"""Resumes a run. It loads previous config without allowing to update keys (used for testing)."""
|
||||
saved_cfg = cfg.copy()
|
||||
name = cfg.general.name + "_resume"
|
||||
resume = cfg.general.test_only
|
||||
batch_size = cfg.train.batch_size
|
||||
model = MCD.load_from_checkpoint(resume, **model_kwargs)
|
||||
cfg = model.cfg
|
||||
cfg.general.test_only = resume
|
||||
cfg.general.name = name
|
||||
cfg.train.batch_size = batch_size
|
||||
cfg = utils.update_config_with_new_keys(cfg, saved_cfg)
|
||||
return cfg, model
|
||||
|
||||
def get_resume_adaptive(cfg, model_kwargs):
|
||||
"""Resumes a run. It loads previous config but allows to make some changes (used for resuming training)."""
|
||||
saved_cfg = cfg.copy()
|
||||
# Fetch path to this file to get base path
|
||||
current_path = os.path.dirname(os.path.realpath(__file__))
|
||||
root_dir = current_path.split("outputs")[0]
|
||||
|
||||
resume_path = os.path.join(root_dir, cfg.general.resume)
|
||||
|
||||
if cfg.model.type == "discrete":
|
||||
model = MCD.load_from_checkpoint(
|
||||
resume_path, **model_kwargs
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError("Unknown model")
|
||||
|
||||
new_cfg = model.cfg
|
||||
for category in cfg:
|
||||
for arg in cfg[category]:
|
||||
new_cfg[category][arg] = cfg[category][arg]
|
||||
|
||||
new_cfg.general.resume = resume_path
|
||||
new_cfg.general.name = new_cfg.general.name + "_resume"
|
||||
|
||||
new_cfg = utils.update_config_with_new_keys(new_cfg, saved_cfg)
|
||||
return new_cfg, model
|
||||
|
||||
|
||||
@hydra.main(
|
||||
version_base="1.1", config_path="../configs", config_name="config_dev"
|
||||
)
|
||||
def main(cfg: DictConfig):
|
||||
|
||||
datamodule = dataset.DataModule(cfg)
|
||||
datamodule.prepare_data()
|
||||
dataset_infos = dataset.DataInfos(datamodule=datamodule, cfg=cfg)
|
||||
train_smiles, reference_smiles = datamodule.get_train_smiles()
|
||||
|
||||
dataset_infos.compute_input_output_dims(datamodule=datamodule)
|
||||
train_metrics = TrainMolecularMetricsDiscrete(dataset_infos)
|
||||
|
||||
sampling_metrics = SamplingMolecularMetrics(
|
||||
dataset_infos, train_smiles, reference_smiles
|
||||
)
|
||||
visualization_tools = MolecularVisualization(dataset_infos)
|
||||
|
||||
model_kwargs = {
|
||||
"dataset_infos": dataset_infos,
|
||||
"train_metrics": train_metrics,
|
||||
"sampling_metrics": sampling_metrics,
|
||||
"visualization_tools": visualization_tools,
|
||||
}
|
||||
|
||||
if cfg.general.test_only:
|
||||
# When testing, previous configuration is fully loaded
|
||||
cfg, _ = get_resume(cfg, model_kwargs)
|
||||
os.chdir(cfg.general.test_only.split("checkpoints")[0])
|
||||
elif cfg.general.resume is not None:
|
||||
# When resuming, we can override some parts of previous configuration
|
||||
cfg, _ = get_resume_adaptive(cfg, model_kwargs)
|
||||
os.chdir(cfg.general.resume.split("checkpoints")[0])
|
||||
|
||||
model = MCD(cfg=cfg, **model_kwargs)
|
||||
trainer = Trainer(
|
||||
gradient_clip_val=cfg.train.clip_grad,
|
||||
accelerator="gpu"
|
||||
if torch.cuda.is_available() and cfg.general.gpus > 0
|
||||
else "cpu",
|
||||
devices=cfg.general.gpus
|
||||
if torch.cuda.is_available() and cfg.general.gpus > 0
|
||||
else None,
|
||||
max_epochs=cfg.train.n_epochs,
|
||||
enable_checkpointing=False,
|
||||
check_val_every_n_epoch=cfg.train.check_val_every_n_epoch,
|
||||
val_check_interval=cfg.train.val_check_interval,
|
||||
strategy="ddp" if cfg.general.gpus > 1 else "auto",
|
||||
enable_progress_bar=cfg.general.enable_progress_bar,
|
||||
callbacks=[],
|
||||
reload_dataloaders_every_n_epochs=0,
|
||||
logger=[],
|
||||
)
|
||||
|
||||
if not cfg.general.test_only:
|
||||
trainer.fit(model, datamodule=datamodule, ckpt_path=cfg.general.resume)
|
||||
if cfg.general.save_model:
|
||||
trainer.save_checkpoint(f"checkpoints/{cfg.general.name}/last.ckpt")
|
||||
trainer.test(model, datamodule=datamodule)
|
||||
else:
|
||||
trainer.test(model, datamodule=datamodule, ckpt_path=cfg.general.test_only)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
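# --- Illustrative usage note (not part of the committed code): main() is a Hydra
# entry point, so config values can be overridden from the command line, e.g.
# (hypothetical values; valid task names depend on the dataset configs):
#   python mcd/main.py dataset.task_name=O2-N2 general.gpus=1 model.guide_scale=2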
|
0
mcd/metrics/__init__.py
Normal file
138
mcd/metrics/abstract_metrics.py
Normal file
@ -0,0 +1,138 @@
|
||||
import torch
|
||||
from torch import Tensor
|
||||
from torch.nn import functional as F
|
||||
from torchmetrics import Metric, MeanSquaredError
|
||||
|
||||
|
||||
class TrainAbstractMetricsDiscrete(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def forward(self, masked_pred_X, masked_pred_E, true_X, true_E, log: bool):
|
||||
pass
|
||||
|
||||
def reset(self):
|
||||
pass
|
||||
|
||||
def log_epoch_metrics(self, current_epoch):
|
||||
pass
|
||||
|
||||
|
||||
class TrainAbstractMetrics(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def forward(self, masked_pred_epsX, masked_pred_epsE, pred_y, true_epsX, true_epsE, true_y, log):
|
||||
pass
|
||||
|
||||
def reset(self):
|
||||
pass
|
||||
|
||||
def log_epoch_metrics(self, current_epoch):
|
||||
pass
|
||||
|
||||
|
||||
class SumExceptBatchMetric(Metric):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.add_state('total_value', default=torch.tensor(0.), dist_reduce_fx="sum")
|
||||
self.add_state('total_samples', default=torch.tensor(0.), dist_reduce_fx="sum")
|
||||
|
||||
def update(self, values) -> None:
|
||||
self.total_value += torch.sum(values)
|
||||
self.total_samples += values.shape[0]
|
||||
|
||||
def compute(self):
|
||||
return self.total_value / self.total_samples
|
||||
|
||||
|
||||
class SumExceptBatchMSE(MeanSquaredError):
|
||||
def update(self, preds: Tensor, target: Tensor) -> None:
|
||||
"""Update state with predictions and targets.
|
||||
|
||||
Args:
|
||||
preds: Predictions from model
|
||||
target: Ground truth values
|
||||
"""
|
||||
assert preds.shape == target.shape
|
||||
sum_squared_error, n_obs = self._mean_squared_error_update(preds, target)
|
||||
|
||||
self.sum_squared_error += sum_squared_error
|
||||
self.total += n_obs
|
||||
|
||||
def _mean_squared_error_update(self, preds: Tensor, target: Tensor):
|
||||
""" Updates and returns variables required to compute Mean Squared Error. Checks for same shape of input
|
||||
tensors.
|
||||
preds: Predicted tensor
|
||||
target: Ground truth tensor
|
||||
"""
|
||||
diff = preds - target
|
||||
sum_squared_error = torch.sum(diff * diff)
|
||||
n_obs = preds.shape[0]
|
||||
return sum_squared_error, n_obs
|
||||
|
||||
|
||||
class SumExceptBatchKL(Metric):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.add_state('total_value', default=torch.tensor(0.), dist_reduce_fx="sum")
|
||||
self.add_state('total_samples', default=torch.tensor(0.), dist_reduce_fx="sum")
|
||||
|
||||
def update(self, p, q) -> None:
|
||||
self.total_value += F.kl_div(q, p, reduction='sum')
|
||||
self.total_samples += p.size(0)
|
||||
|
||||
def compute(self):
|
||||
return self.total_value / self.total_samples
|
||||
|
||||
|
||||
class CrossEntropyMetric(Metric):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.add_state('total_ce', default=torch.tensor(0.), dist_reduce_fx="sum")
|
||||
self.add_state('total_samples', default=torch.tensor(0.), dist_reduce_fx="sum")
|
||||
|
||||
def update(self, preds: Tensor, target: Tensor, weight=None) -> None:
|
||||
""" Update state with predictions and targets.
|
||||
preds: Predictions from model (bs * n, d) or (bs * n * n, d)
|
||||
target: Ground truth values (bs * n, d) or (bs * n * n, d). """
|
||||
target = torch.argmax(target, dim=-1)
|
||||
if weight is not None:
|
||||
weight = weight.type_as(preds)
|
||||
output = F.cross_entropy(preds, target, weight = weight, reduction='sum')
|
||||
else:
|
||||
output = F.cross_entropy(preds, target, reduction='sum')
|
||||
self.total_ce += output
|
||||
self.total_samples += preds.size(0)
|
||||
|
||||
def compute(self):
|
||||
return self.total_ce / self.total_samples
|
||||
|
||||
|
||||
class ProbabilityMetric(Metric):
|
||||
def __init__(self):
|
||||
""" This metric is used to track the marginal predicted probability of a class during training. """
|
||||
super().__init__()
|
||||
self.add_state('prob', default=torch.tensor(0.), dist_reduce_fx="sum")
|
||||
self.add_state('total', default=torch.tensor(0.), dist_reduce_fx="sum")
|
||||
|
||||
def update(self, preds: Tensor) -> None:
|
||||
self.prob += preds.sum()
|
||||
self.total += preds.numel()
|
||||
|
||||
def compute(self):
|
||||
return self.prob / self.total
|
||||
|
||||
|
||||
class NLL(Metric):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.add_state('total_nll', default=torch.tensor(0.), dist_reduce_fx="sum")
|
||||
self.add_state('total_samples', default=torch.tensor(0.), dist_reduce_fx="sum")
|
||||
|
||||
def update(self, batch_nll) -> None:
|
||||
self.total_nll += torch.sum(batch_nll)
|
||||
self.total_samples += batch_nll.numel()
|
||||
|
||||
def compute(self):
|
||||
return self.total_nll / self.total_samples
|
BIN
mcd/metrics/fpscores.pkl.gz
Normal file
Binary file not shown.
138
mcd/metrics/molecular_metrics_sampling.py
Normal file
@ -0,0 +1,138 @@
|
||||
### packages for visualization
|
||||
from analysis.rdkit_functions import compute_molecular_metrics
|
||||
from mini_moses.metrics.metrics import compute_intermediate_statistics
|
||||
from metrics.property_metric import TaskModel
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
import os
|
||||
import csv
|
||||
import time
|
||||
|
||||
def result_to_csv(path, dict_data):
|
||||
file_exists = os.path.exists(path)
|
||||
log_name = dict_data.pop("log_name", None)
|
||||
if log_name is None:
|
||||
raise ValueError("The provided dictionary must contain a 'log_name' key.")
|
||||
field_names = ["log_name"] + list(dict_data.keys())
|
||||
dict_data["log_name"] = log_name
|
||||
with open(path, "a", newline="") as file:
|
||||
writer = csv.DictWriter(file, fieldnames=field_names)
|
||||
if not file_exists:
|
||||
writer.writeheader()
|
||||
writer.writerow(dict_data)
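# --- Illustrative usage sketch (not part of the committed code): result_to_csv
# writes the header only when the file does not exist yet, then appends one row
# per call. The path and metric names below are hypothetical.
def _result_to_csv_example(path="example_output.csv"):
    result_to_csv(path, {"log_name": "epoch3_batch1", "validity": 0.97})
    result_to_csv(path, {"log_name": "epoch4_batch1", "validity": 0.98})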
|
||||
|
||||
|
||||
class SamplingMolecularMetrics(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dataset_infos,
|
||||
train_smiles,
|
||||
reference_smiles,
|
||||
n_jobs=1,
|
||||
device="cpu",
|
||||
batch_size=512,
|
||||
):
|
||||
super().__init__()
|
||||
self.task_name = dataset_infos.task
|
||||
self.dataset_infos = dataset_infos
|
||||
self.active_atoms = dataset_infos.active_atoms
|
||||
self.train_smiles = train_smiles
|
||||
|
||||
if reference_smiles is not None:
|
||||
print(
|
||||
f"--- Computing intermediate statistics for training for #{len(reference_smiles)} smiles ---"
|
||||
)
|
||||
start_time = time.time()
|
||||
self.stat_ref = compute_intermediate_statistics(
|
||||
reference_smiles, n_jobs=n_jobs, device=device, batch_size=batch_size
|
||||
)
|
||||
end_time = time.time()
|
||||
elapsed_time = end_time - start_time
|
||||
print(
|
||||
f"--- End computing intermediate statistics: using {elapsed_time:.2f}s ---"
|
||||
)
|
||||
else:
|
||||
self.stat_ref = None
|
||||
|
||||
self.comput_config = {
|
||||
"n_jobs": n_jobs,
|
||||
"device": device,
|
||||
"batch_size": batch_size,
|
||||
}
|
||||
|
||||
self.task_evaluator = {'meta_taskname': dataset_infos.task, 'sas': None, 'scs': None}
|
||||
for cur_task in dataset_infos.task.split("-")[:]:
|
||||
# print('loading evaluator for task', cur_task)
|
||||
model_path = os.path.join(
|
||||
dataset_infos.base_path, "data/evaluator", f"{cur_task}.joblib"
|
||||
)
|
||||
os.makedirs(os.path.dirname(model_path), exist_ok=True)
|
||||
evaluator = TaskModel(model_path, cur_task)
|
||||
self.task_evaluator[cur_task] = evaluator
|
||||
|
||||
def forward(self, molecules, targets, name, current_epoch, val_counter, test=False):
|
||||
if isinstance(targets, list):
|
||||
targets_cat = torch.cat(targets, dim=0)
|
||||
targets_np = targets_cat.detach().cpu().numpy()
|
||||
else:
|
||||
targets_np = targets.detach().cpu().numpy()
|
||||
|
||||
unique_smiles, all_smiles, all_metrics, targets_log = compute_molecular_metrics(
|
||||
molecules,
|
||||
targets_np,
|
||||
self.train_smiles,
|
||||
self.stat_ref,
|
||||
self.dataset_infos,
|
||||
self.task_evaluator,
|
||||
self.comput_config,
|
||||
)
|
||||
|
||||
if test:
|
||||
file_name = "final_smiles.txt"
|
||||
with open(file_name, "w") as fp:
|
||||
all_tasks_name = list(self.task_evaluator.keys())
|
||||
all_tasks_name = all_tasks_name.copy()
|
||||
if 'meta_taskname' in all_tasks_name:
|
||||
all_tasks_name.remove('meta_taskname')
|
||||
if 'scs' in all_tasks_name:
|
||||
all_tasks_name.remove('scs')
|
||||
|
||||
all_tasks_str = "smiles, " + ", ".join([f"input_{task}" for task in all_tasks_name] + [f"output_{task}" for task in all_tasks_name])
|
||||
fp.write(all_tasks_str + "\n")
|
||||
for i, smiles in enumerate(all_smiles):
|
||||
if targets_log is not None:
|
||||
all_result_str = f"{smiles}, " + ", ".join([f"{targets_log['input_'+task][i]}" for task in all_tasks_name] + [f"{targets_log['output_'+task][i]}" for task in all_tasks_name])
|
||||
fp.write(all_result_str + "\n")
|
||||
else:
|
||||
fp.write("%s\n" % smiles)
|
||||
print("All smiles saved")
|
||||
else:
|
||||
result_path = os.path.join(os.getcwd(), f"graphs/{name}")
|
||||
os.makedirs(result_path, exist_ok=True)
|
||||
text_path = os.path.join(
|
||||
result_path,
|
||||
f"valid_unique_molecules_e{current_epoch}_b{val_counter}.txt",
|
||||
)
|
||||
textfile = open(text_path, "w")
|
||||
for smiles in unique_smiles:
|
||||
textfile.write(smiles + "\n")
|
||||
textfile.close()
|
||||
|
||||
all_logs = all_metrics
|
||||
if test:
|
||||
all_logs["log_name"] = "test"
|
||||
else:
|
||||
all_logs["log_name"] = (
|
||||
"epoch" + str(current_epoch) + "_batch" + str(val_counter)
|
||||
)
|
||||
|
||||
result_to_csv("output.csv", all_logs)
|
||||
return all_smiles
|
||||
|
||||
def reset(self):
|
||||
pass
|
||||
|
||||
if __name__ == "__main__":
|
||||
pass
|
126
mcd/metrics/molecular_metrics_train.py
Normal file
@ -0,0 +1,126 @@
|
||||
import torch
|
||||
from torchmetrics import Metric, MetricCollection
|
||||
from torch import Tensor
|
||||
import torch.nn as nn
|
||||
|
||||
class CEPerClass(Metric):
|
||||
full_state_update = False
|
||||
def __init__(self, class_id):
|
||||
super().__init__()
|
||||
self.class_id = class_id
|
||||
self.add_state('total_ce', default=torch.tensor(0.), dist_reduce_fx="sum")
|
||||
self.add_state('total_samples', default=torch.tensor(0.), dist_reduce_fx="sum")
|
||||
self.softmax = torch.nn.Softmax(dim=-1)
|
||||
self.binary_cross_entropy = torch.nn.BCELoss(reduction='sum')
|
||||
|
||||
def update(self, preds: Tensor, target: Tensor) -> None:
|
||||
"""Update state with predictions and targets.
|
||||
Args:
|
||||
preds: Predictions from model (bs, n, d) or (bs, n, n, d)
|
||||
target: Ground truth values (bs, n, d) or (bs, n, n, d)
|
||||
"""
|
||||
target = target.reshape(-1, target.shape[-1])
|
||||
mask = (target != 0.).any(dim=-1)
|
||||
|
||||
prob = self.softmax(preds)[..., self.class_id]
|
||||
prob = prob.flatten()[mask]
|
||||
|
||||
target = target[:, self.class_id]
|
||||
target = target[mask]
|
||||
|
||||
output = self.binary_cross_entropy(prob, target)
|
||||
|
||||
self.total_ce += output
|
||||
self.total_samples += prob.numel()
|
||||
|
||||
def compute(self):
|
||||
return self.total_ce / self.total_samples
|
||||
|
||||
|
||||
class AtomCE(CEPerClass):
|
||||
def __init__(self, i):
|
||||
super().__init__(i)
|
||||
|
||||
class NoBondCE(CEPerClass):
|
||||
def __init__(self, i):
|
||||
super().__init__(i)
|
||||
|
||||
|
||||
class SingleCE(CEPerClass):
|
||||
def __init__(self, i):
|
||||
super().__init__(i)
|
||||
|
||||
|
||||
class DoubleCE(CEPerClass):
|
||||
def __init__(self, i):
|
||||
super().__init__(i)
|
||||
|
||||
|
||||
class TripleCE(CEPerClass):
|
||||
def __init__(self, i):
|
||||
super().__init__(i)
|
||||
|
||||
|
||||
class AromaticCE(CEPerClass):
|
||||
def __init__(self, i):
|
||||
super().__init__(i)
|
||||
|
||||
|
||||
class AtomMetricsCE(MetricCollection):
|
||||
def __init__(self, active_atoms):
|
||||
metrics_list = []
|
||||
|
||||
for i, atom_type in enumerate(active_atoms):
|
||||
metrics_list.append(type(f'{atom_type}_CE', (AtomCE,), {})(i))
|
||||
|
||||
super().__init__(metrics_list)
|
||||
|
||||
|
||||
class BondMetricsCE(MetricCollection):
|
||||
def __init__(self):
|
||||
ce_no_bond = NoBondCE(0)
|
||||
ce_SI = SingleCE(1)
|
||||
ce_DO = DoubleCE(2)
|
||||
ce_TR = TripleCE(3)
|
||||
super().__init__([ce_no_bond, ce_SI, ce_DO, ce_TR])
|
||||
|
||||
|
||||
class TrainMolecularMetricsDiscrete(nn.Module):
    def __init__(self, dataset_infos):
        super().__init__()
        active_atoms = dataset_infos.active_atoms
        self.train_atom_metrics = AtomMetricsCE(active_atoms=active_atoms)
        self.train_bond_metrics = BondMetricsCE()

    def forward(self, masked_pred_X, masked_pred_E, true_X, true_E, log: bool):
        self.train_atom_metrics(masked_pred_X, true_X)
        self.train_bond_metrics(masked_pred_E, true_E)
        if log:
            to_log = {}
            for key, val in self.train_atom_metrics.compute().items():
                to_log['train/' + key] = val.item()
            for key, val in self.train_bond_metrics.compute().items():
                to_log['train/' + key] = val.item()

    def reset(self):
        for metric in [self.train_atom_metrics, self.train_bond_metrics]:
            metric.reset()

    def log_epoch_metrics(self, current_epoch, log=True):
        epoch_atom_metrics = self.train_atom_metrics.compute()
        epoch_bond_metrics = self.train_bond_metrics.compute()

        to_log = {}
        for key, val in epoch_atom_metrics.items():
            to_log['train_epoch/' + key] = val.item()
        for key, val in epoch_bond_metrics.items():
            to_log['train_epoch/' + key] = val.item()

        for key, val in epoch_atom_metrics.items():
            epoch_atom_metrics[key] = round(val.item(), 4)
        for key, val in epoch_bond_metrics.items():
            epoch_bond_metrics[key] = round(val.item(), 4)

        if log:
            print(f"Epoch {current_epoch}: {epoch_atom_metrics} -- {epoch_bond_metrics}")
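# Usage sketch (hypothetical driver code, not part of this module): the training loop would
# typically call
#     train_metrics = TrainMolecularMetricsDiscrete(dataset_infos)
#     train_metrics(masked_pred_X, masked_pred_E, true_X, true_E, log=...)   # every step
#     train_metrics.log_epoch_metrics(current_epoch)                         # once per epoch
#     train_metrics.reset()
# The to_log dicts above are assembled but not emitted here; an external logger would have
# to pick them up.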
201 mcd/metrics/property_metric.py Normal file
@ -0,0 +1,201 @@
import math, os
import pickle
import os.path as op

import numpy as np
import pandas as pd
from joblib import dump, load
from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor
from sklearn.metrics import mean_absolute_error, roc_auc_score

from rdkit import Chem
from rdkit import rdBase
from rdkit.Chem import AllChem
from rdkit import DataStructs
from rdkit.Chem import rdMolDescriptors
rdBase.DisableLog('rdApp.error')


task_to_colname = {
    'hiv_b': 'HIV_active',
    'bace_b': 'Class',
    'bbbp_b': 'p_np',
    'O2': 'O2',
    'N2': 'N2',
    'CO2': 'CO2',
}

tasktype_name = {
    'hiv_b': 'classification',
    'bace_b': 'classification',
    'bbbp_b': 'classification',
    'O2': 'regression',
    'N2': 'regression',
    'CO2': 'regression',
}

class TaskModel():
    """Scores based on an ECFP classifier."""
    def __init__(self, model_path, task_name):
        task_type = tasktype_name[task_name]
        self.task_name = task_name
        self.task_type = task_type
        self.model_path = model_path
        self.metric_func = roc_auc_score if 'classification' in self.task_type else mean_absolute_error

        try:
            self.model = load(model_path)
            print(self.task_name, ' evaluator loaded')
        except:
            print(self.task_name, ' evaluator not found, training new one...')
            if 'classification' in task_type:
                self.model = RandomForestClassifier(random_state=0)
            elif 'regression' in task_type:
                self.model = RandomForestRegressor(random_state=0)
            performance = self.train()
            dump(self.model, model_path)
            print('Oracle performance: ', performance)

    def train(self):
        data_path = os.path.join(os.path.dirname(self.model_path), '..', f'raw/{self.task_name}.csv.gz')
        df = pd.read_csv(data_path)
        col_name = task_to_colname[self.task_name]
        y = df[col_name].to_numpy()
        x_smiles = df['smiles'].to_numpy()
        mask = ~np.isnan(y)
        y = y[mask]

        if 'classification' in self.task_type:
            y = y.astype(int)

        x_smiles = x_smiles[mask]
        x_fps = []
        mask = []
        for i, smiles in enumerate(x_smiles):
            mol = Chem.MolFromSmiles(smiles)
            mask.append(int(mol is not None))
            fp = TaskModel.fingerprints_from_mol(mol) if mol else np.zeros((1, 2048))
            x_fps.append(fp)
        x_fps = np.concatenate(x_fps, axis=0)
        self.model.fit(x_fps, y)
        y_pred = self.model.predict(x_fps)
        perf = self.metric_func(y, y_pred)
        print(f'{self.task_name} performance: {perf}')
        return perf

    def __call__(self, smiles_list):
        fps = []
        mask = []
        for i, smiles in enumerate(smiles_list):
            mol = Chem.MolFromSmiles(smiles)
            mask.append(int(mol is not None))
            fp = TaskModel.fingerprints_from_mol(mol) if mol else np.zeros((1, 2048))
            fps.append(fp)

        fps = np.concatenate(fps, axis=0)
        if 'classification' in self.task_type:
            scores = self.model.predict_proba(fps)[:, 1]
        else:
            scores = self.model.predict(fps)
        scores = scores * np.array(mask)
        return np.float32(scores)

    @classmethod
    def fingerprints_from_mol(cls, mol):  # use ECFP4
        features_vec = AllChem.GetMorganFingerprintAsBitVect(mol, 2, nBits=2048)
        features = np.zeros((1,))
        DataStructs.ConvertToNumpyArray(features_vec, features)
        return features.reshape(1, -1)

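# Hypothetical usage sketch (the paths below are assumptions, not fixed by this file): the
# oracle is loaded from / dumped to model_path, and train() expects ../raw/<task_name>.csv.gz
# relative to that file's directory, with a 'smiles' column plus the label column listed in
# task_to_colname.
#     oracle = TaskModel('data/O2/evaluator/O2.joblib', 'O2')
#     scores = oracle(['CCO', 'c1ccccc1'])   # np.float32 array; 0 for unparsable SMILES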
###### SAS Score ######
_fscores = None


def readFragmentScores(name='fpscores'):
    import gzip
    global _fscores
    # generate the full path filename:
    if name == "fpscores":
        name = op.join(op.dirname(__file__), name)
    data = pickle.load(gzip.open('%s.pkl.gz' % name))
    outDict = {}
    for i in data:
        for j in range(1, len(i)):
            outDict[i[j]] = float(i[0])
    _fscores = outDict


def numBridgeheadsAndSpiro(mol, ri=None):
    nSpiro = rdMolDescriptors.CalcNumSpiroAtoms(mol)
    nBridgehead = rdMolDescriptors.CalcNumBridgeheadAtoms(mol)
    return nBridgehead, nSpiro

def calculateSAS(smiles_list):
    scores = []
    for i, smiles in enumerate(smiles_list):
        mol = Chem.MolFromSmiles(smiles)
        score = calculateScore(mol)
        scores.append(score)
    return np.float32(scores)

def calculateScore(m):
    if _fscores is None:
        readFragmentScores()

    # fragment score
    fp = rdMolDescriptors.GetMorganFingerprint(m, 2)  # <- 2 is the *radius* of the circular fingerprint
    fps = fp.GetNonzeroElements()
    score1 = 0.
    nf = 0
    for bitId, v in fps.items():
        nf += v
        sfp = bitId
        score1 += _fscores.get(sfp, -4) * v
    score1 /= nf

    # features score
    nAtoms = m.GetNumAtoms()
    nChiralCenters = len(Chem.FindMolChiralCenters(m, includeUnassigned=True))
    ri = m.GetRingInfo()
    nBridgeheads, nSpiro = numBridgeheadsAndSpiro(m, ri)
    nMacrocycles = 0
    for x in ri.AtomRings():
        if len(x) > 8:
            nMacrocycles += 1

    sizePenalty = nAtoms**1.005 - nAtoms
    stereoPenalty = math.log10(nChiralCenters + 1)
    spiroPenalty = math.log10(nSpiro + 1)
    bridgePenalty = math.log10(nBridgeheads + 1)
    macrocyclePenalty = 0.
    # ---------------------------------------
    # This differs from the paper, which defines:
    # macrocyclePenalty = math.log10(nMacrocycles + 1)
    # This form generates better results when 2 or more macrocycles are present
    if nMacrocycles > 0:
        macrocyclePenalty = math.log10(2)

    score2 = 0. - sizePenalty - stereoPenalty - spiroPenalty - bridgePenalty - macrocyclePenalty

    # correction for the fingerprint density
    # not in the original publication, added in version 1.1
    # to make highly symmetrical molecules easier to synthetise
    score3 = 0.
    if nAtoms > len(fps):
        score3 = math.log(float(nAtoms) / len(fps)) * .5

    sascore = score1 + score2 + score3

    # need to transform "raw" value into scale between 1 and 10
    min = -4.0
    max = 2.5
    sascore = 11. - (sascore - min + 1) / (max - min) * 9.
    # smooth the 10-end
    if sascore > 8.:
        sascore = 8. + math.log(sascore + 1. - 9.)
    if sascore > 10.:
        sascore = 10.0
    elif sascore < 1.:
        sascore = 1.0

    return sascore
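# Note: calculateScore closely follows the standard RDKit SA_Score recipe; the final value is
# clipped to the [1, 10] range, where lower values are conventionally read as easier to
# synthesize. calculateSAS assumes every SMILES in the list parses (a None mol would raise
# inside calculateScore).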
94 mcd/metrics/train_loss.py Normal file
@ -0,0 +1,94 @@
import time
import torch
import torch.nn as nn
from metrics.abstract_metrics import CrossEntropyMetric
from torchmetrics import Metric, MeanSquaredError

# from 2:He to 119:*
valencies_check = [0, 1, 2, 3, 4, 3, 2, 1, 0, 1, 2, 6, 6, 7, 6, 1, 0, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 3, 4, 7, 6, 1, 0, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 3, 4, 7, 6, 5, 6, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 3, 4, 7, 6, 5, 0, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
valencies_check = torch.tensor(valencies_check)

weight_check = [4.003, 6.941, 9.012, 10.812, 12.011, 14.007, 15.999, 18.998, 20.18, 22.99, 24.305, 26.982, 28.086, 30.974, 32.067, 35.453, 39.948, 39.098, 40.078, 44.956, 47.867, 50.942, 51.996, 54.938, 55.845, 58.933, 58.693, 63.546, 65.39, 69.723, 72.61, 74.922, 78.96, 79.904, 83.8, 85.468, 87.62, 88.906, 91.224, 92.906, 95.94, 98.0, 101.07, 102.906, 106.42, 107.868, 112.412, 114.818, 118.711, 121.76, 127.6, 126.904, 131.29, 132.905, 137.328, 138.906, 140.116, 140.908, 144.24, 145.0, 150.36, 151.964, 157.25, 158.925, 162.5, 164.93, 167.26, 168.934, 173.04, 174.967, 178.49, 180.948, 183.84, 186.207, 190.23, 192.217, 195.078, 196.967, 200.59, 204.383, 207.2, 208.98, 209.0, 210.0, 222.0, 223.0, 226.0, 227.0, 232.038, 231.036, 238.029, 237.0, 244.0, 243.0, 247.0, 247.0, 251.0, 252.0, 257.0, 258.0, 259.0, 262.0, 267.0, 268.0, 269.0, 270.0, 269.0, 278.0, 281.0, 281.0, 285.0, 284.0, 289.0, 288.0, 293.0, 292.0, 294.0, 294.0]
weight_check = torch.tensor(weight_check)

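# The metric below uses the atomic-weight table above: for each graph it sums the weights of
# the predicted atom types and of the true atom types, accumulates the absolute difference,
# and reports the batch-averaged gap in compute().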
class AtomWeightMetric(Metric):
    def __init__(self):
        super().__init__()
        self.add_state('total_loss', default=torch.tensor(0.), dist_reduce_fx="sum")
        self.add_state('total_samples', default=torch.tensor(0.), dist_reduce_fx="sum")
        global weight_check
        self.weight_check = weight_check

    def update(self, X, Y):
        atom_pred_num = X.argmax(dim=-1)
        atom_real_num = Y.argmax(dim=-1)
        self.weight_check = self.weight_check.type_as(X)

        pred_weight = self.weight_check[atom_pred_num]
        real_weight = self.weight_check[atom_real_num]

        lss = 0
        lss += torch.abs(pred_weight.sum(dim=-1) - real_weight.sum(dim=-1)).sum()
        self.total_loss += lss
        self.total_samples += X.size(0)

    def compute(self):
        return self.total_loss / self.total_samples

class TrainLossDiscrete(nn.Module):
    """ Train with Cross entropy"""
    def __init__(self, lambda_train, weight_node=None, weight_edge=None):
        super().__init__()
        self.node_loss = CrossEntropyMetric()
        self.edge_loss = CrossEntropyMetric()
        self.weight_loss = AtomWeightMetric()

        self.y_loss = MeanSquaredError()
        self.lambda_train = lambda_train

    def forward(self, masked_pred_X, masked_pred_E, pred_y, true_X, true_E, true_y, node_mask, log: bool):
        """ Compute train metrics
        masked_pred_X : tensor -- (bs, n, dx)
        masked_pred_E : tensor -- (bs, n, n, de)
        pred_y : tensor -- (bs, )
        true_X : tensor -- (bs, n, dx)
        true_E : tensor -- (bs, n, n, de)
        true_y : tensor -- (bs, )
        log : boolean. """

        loss_weight = self.weight_loss(masked_pred_X, true_X)

        true_X = torch.reshape(true_X, (-1, true_X.size(-1)))  # (bs * n, dx)
        true_E = torch.reshape(true_E, (-1, true_E.size(-1)))  # (bs * n * n, de)
        masked_pred_X = torch.reshape(masked_pred_X, (-1, masked_pred_X.size(-1)))  # (bs * n, dx)
        masked_pred_E = torch.reshape(masked_pred_E, (-1, masked_pred_E.size(-1)))  # (bs * n * n, de)

        # Remove masked rows
        mask_X = (true_X != 0.).any(dim=-1)
        mask_E = (true_E != 0.).any(dim=-1)

        flat_true_X = true_X[mask_X, :]
        flat_pred_X = masked_pred_X[mask_X, :]

        flat_true_E = true_E[mask_E, :]
        flat_pred_E = masked_pred_E[mask_E, :]

        loss_X = self.node_loss(flat_pred_X, flat_true_X) if true_X.numel() > 0 else 0.0
        loss_E = self.edge_loss(flat_pred_E, flat_true_E) if true_E.numel() > 0 else 0.0

        return self.lambda_train[0] * loss_X + self.lambda_train[1] * loss_E + loss_weight

    def reset(self):
        for metric in [self.node_loss, self.edge_loss, self.y_loss]:
            metric.reset()

    def log_epoch_metrics(self, current_epoch, start_epoch_time, log=True):
        epoch_node_loss = self.node_loss.compute() if self.node_loss.total_samples > 0 else -1
        epoch_edge_loss = self.edge_loss.compute() if self.edge_loss.total_samples > 0 else -1
        epoch_weight_loss = self.weight_loss.compute() if self.weight_loss.total_samples > 0 else -1

        if log:
            print(f"Epoch {current_epoch} finished: X_CE: {epoch_node_loss :.4f} -- E_CE: {epoch_edge_loss :.4f} "
                  f"Weight: {epoch_weight_loss :.4f} "
                  f"-- Time taken {time.time() - start_epoch_time:.1f}s ")
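# Overall training objective as composed in forward() above:
#     loss = lambda_train[0] * CE(X) + lambda_train[1] * CE(E) + AtomWeightMetric(X_pred, X_true)
# where the cross-entropy terms are computed only over unmasked node and edge rows.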
0 mcd/models/__init__.py Normal file
119 mcd/models/conditions.py Normal file
@ -0,0 +1,119 @@
import torch
import torch.nn as nn
import math


class TimestepEmbedder(nn.Module):
    """
    Embeds scalar timesteps into vector representations.
    """
    def __init__(self, hidden_size, frequency_embedding_size=256):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        )
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        """
        Create sinusoidal timestep embeddings.
        :param t: a 1-D Tensor of N indices, one per batch element. These may be fractional.
        :param dim: the dimension of the output.
        :param max_period: controls the minimum frequency of the embeddings.
        :return: an (N, D) Tensor of positional embeddings.
        """
        # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
        half = dim // 2
        freqs = torch.exp(
            -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
        ).to(device=t.device)
        args = t[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
        return embedding

    def forward(self, t):
        t = t.view(-1)
        t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
        t_emb = self.mlp(t_freq)
        return t_emb

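# For reference, timestep_embedding concatenates [cos(t * f_i), sin(t * f_i)] with
# f_i = max_period ** (-i / half) for i = 0..half-1 (half = dim // 2), i.e. the usual
# transformer-style sinusoidal encoding, which the MLP above then projects to hidden_size.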
class CategoricalEmbedder(nn.Module):
    """
    Embeds categorical conditions such as data sources into vector representations.
    Also handles label dropout for classifier-free guidance.
    """
    def __init__(self, num_classes, hidden_size, dropout_prob):
        super().__init__()
        use_cfg_embedding = dropout_prob > 0
        self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size)
        self.num_classes = num_classes
        self.dropout_prob = dropout_prob

    def token_drop(self, labels, force_drop_ids=None):
        """
        Drops labels to enable classifier-free guidance.
        """
        if force_drop_ids is None:
            drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob
        else:
            drop_ids = force_drop_ids == 1
        labels = torch.where(drop_ids, self.num_classes, labels)
        return labels

    def forward(self, labels, train, force_drop_ids=None, t=None):
        labels = labels.long().view(-1)
        use_dropout = self.dropout_prob > 0
        if (train and use_dropout) or (force_drop_ids is not None):
            labels = self.token_drop(labels, force_drop_ids)
        embeddings = self.embedding_table(labels)
        if train:
            noise = torch.randn_like(embeddings)
            embeddings = embeddings + noise
        return embeddings

class ClusterContinuousEmbedder(nn.Module):
    def __init__(self, input_size, hidden_size, dropout_prob):
        super().__init__()
        use_cfg_embedding = dropout_prob > 0

        if use_cfg_embedding:
            self.embedding_drop = nn.Embedding(1, hidden_size)

        self.mlp = nn.Sequential(
            nn.Linear(input_size, hidden_size, bias=True),
            nn.Softmax(dim=1),
            nn.Linear(hidden_size, hidden_size, bias=False)
        )
        self.hidden_size = hidden_size
        self.dropout_prob = dropout_prob

    def forward(self, labels, train, force_drop_ids=None, timestep=None):
        use_dropout = self.dropout_prob > 0
        if force_drop_ids is not None:
            drop_ids = force_drop_ids == 1
        else:
            drop_ids = None

        if train and use_dropout:
            drop_ids_rand = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob
            if force_drop_ids is not None:
                drop_ids = torch.logical_or(drop_ids, drop_ids_rand)
            else:
                drop_ids = drop_ids_rand

        if drop_ids is not None:
            embeddings = torch.zeros((labels.shape[0], self.hidden_size), device=labels.device)
            embeddings[~drop_ids] = self.mlp(labels[~drop_ids])
            embeddings[drop_ids] += self.embedding_drop.weight[0]
        else:
            embeddings = self.mlp(labels)

        if train:
            noise = torch.randn_like(embeddings)
            embeddings = embeddings + noise
        return embeddings
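# In both embedders, dropped conditions are routed to a learned "null" embedding (the extra
# row of embedding_table in CategoricalEmbedder, embedding_drop here); presumably this is the
# unconditional branch used for classifier-free guidance at sampling time.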
114 mcd/models/layers.py Normal file
@ -0,0 +1,114 @@
from torch.jit import Final
import torch.nn.functional as F
from itertools import repeat
import collections.abc

import torch
import torch.nn as nn


class Attention(nn.Module):
    fast_attn: Final[bool]

    def __init__(
            self,
            dim,
            num_heads=8,
            qkv_bias=False,
            qk_norm=False,
            attn_drop=0,
            proj_drop=0,
            norm_layer=nn.LayerNorm,
    ):
        super().__init__()
        assert dim % num_heads == 0, 'dim should be divisible by num_heads'
        self.num_heads = num_heads
        self.head_dim = dim // num_heads

        self.scale = self.head_dim ** -0.5
        self.fast_attn = hasattr(torch.nn.functional, 'scaled_dot_product_attention')  # FIXME
        assert self.fast_attn, "scaled_dot_product_attention Not implemented"

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)

        self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.attn_drop = nn.Dropout(attn_drop)

        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def dot_product_attention(self, q, k, v):
        q = q * self.scale
        attn = q @ k.transpose(-2, -1)
        attn_sfmx = attn.softmax(dim=-1)
        attn_sfmx = self.attn_drop(attn_sfmx)
        x = attn_sfmx @ v
        return x, attn

    def forward(self, x, node_mask):
        B, N, D = x.shape

        # B, head, N, head_dim
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)  # B, head, N, head_dim
        q, k = self.q_norm(q), self.k_norm(k)

        attn_mask = (node_mask[:, None, :, None] & node_mask[:, None, None, :]).expand(-1, self.num_heads, N, N)
        # Query rows with no valid key (padding nodes) would otherwise softmax over an
        # all-masked row; re-enable them so scaled_dot_product_attention does not produce NaNs.
        attn_mask[attn_mask.sum(-1) == 0] = True

        x = F.scaled_dot_product_attention(
            q, k, v,
            dropout_p=self.attn_drop.p,
            attn_mask=attn_mask,
        )

        x = x.transpose(1, 2).reshape(B, N, -1)

        x = self.proj(x)
        x = self.proj_drop(x)

        return x

class Mlp(nn.Module):
    def __init__(
            self,
            in_features,
            hidden_features=None,
            out_features=None,
            act_layer=nn.GELU,
            bias=True,
            drop=0.,
    ):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        bias = to_2tuple(bias)
        drop_probs = to_2tuple(drop)
        linear_layer = nn.Linear

        self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0])
        self.act = act_layer()
        self.drop1 = nn.Dropout(drop_probs[0])
        self.fc2 = linear_layer(hidden_features, out_features, bias=bias[1])
        self.drop2 = nn.Dropout(drop_probs[1])

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop1(x)
        x = self.fc2(x)
        x = self.drop2(x)
        return x


# From PyTorch internals
def _ntuple(n):
    def parse(x):
        if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
            return tuple(x)
        return tuple(repeat(x, n))
    return parse


to_2tuple = _ntuple(2)

184 mcd/models/transformer.py Normal file
@ -0,0 +1,184 @@
import torch
import torch.nn as nn

import utils
from models.layers import Attention, Mlp
from models.conditions import TimestepEmbedder, CategoricalEmbedder, ClusterContinuousEmbedder


def modulate(x, shift, scale):
    # adaLN-style modulation: per-sample shift and scale, predicted from the conditioning
    # vector, broadcast over the node dimension.
    return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)


class Denoiser(nn.Module):
    def __init__(
            self,
            max_n_nodes,
            hidden_size=384,
            depth=12,
            num_heads=16,
            mlp_ratio=4.0,
            drop_condition=0.1,
            Xdim=118,
            Edim=5,
            ydim=3,
            task_type='regression',
    ):
        super().__init__()
        self.num_heads = num_heads
        self.ydim = ydim
        self.x_embedder = nn.Linear(Xdim + max_n_nodes * Edim, hidden_size, bias=False)

        self.t_embedder = TimestepEmbedder(hidden_size)
        self.y_embedding_list = torch.nn.ModuleList()

        self.y_embedding_list.append(ClusterContinuousEmbedder(2, hidden_size, drop_condition))
        for i in range(ydim - 2):
            if task_type == 'regression':
                self.y_embedding_list.append(ClusterContinuousEmbedder(1, hidden_size, drop_condition))
            else:
                self.y_embedding_list.append(CategoricalEmbedder(2, hidden_size, drop_condition))

        self.encoders = nn.ModuleList(
            [
                SELayer(hidden_size, num_heads, mlp_ratio=mlp_ratio)
                for _ in range(depth)
            ]
        )

        self.decoder = Decoder(
            max_n_nodes=max_n_nodes,
            hidden_size=hidden_size,
            atom_type=Xdim,
            bond_type=Edim,
            mlp_ratio=mlp_ratio,
            num_heads=num_heads,
        )

        self.initialize_weights()

    def initialize_weights(self):
        # Initialize transformer layers:
        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)

        def _constant_init(module, i):
            if isinstance(module, nn.Linear):
                nn.init.constant_(module.weight, i)
                if module.bias is not None:
                    nn.init.constant_(module.bias, i)

        self.apply(_basic_init)

        for block in self.encoders:
            _constant_init(block.adaLN_modulation[0], 0)
        _constant_init(self.decoder.adaLN_modulation[0], 0)

    def forward(self, x, e, node_mask, y, t, unconditioned):
        # Samples whose condition vector contains NaNs are treated as unconditional
        force_drop_id = torch.zeros_like(y.sum(-1))
        force_drop_id[torch.isnan(y.sum(-1))] = 1
        if unconditioned:
            force_drop_id = torch.ones_like(y[:, 0])

        x_in, e_in, y_in = x, e, y
        bs, n, _ = x.size()
        x = torch.cat([x, e.reshape(bs, n, -1)], dim=-1)
        x = self.x_embedder(x)

        c1 = self.t_embedder(t)
        for i in range(1, self.ydim):
            if i == 1:
                c2 = self.y_embedding_list[i-1](y[:, :2], self.training, force_drop_id, t)
            else:
                c2 = c2 + self.y_embedding_list[i-1](y[:, i:i+1], self.training, force_drop_id, t)
        c = c1 + c2

        for i, block in enumerate(self.encoders):
            x = block(x, c, node_mask)

        # X: B * N * dx, E: B * N * N * de
        X, E, y = self.decoder(x, x_in, e_in, c, t, node_mask)
        return utils.PlaceHolder(X=X, E=E, y=y).mask(node_mask)

class SELayer(nn.Module):
    def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, **block_kwargs):
        super().__init__()
        self.dropout = 0.
        self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False)
        self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False)

        self.attn = Attention(
            hidden_size, num_heads=num_heads, qkv_bias=True, qk_norm=True, **block_kwargs
        )

        self.mlp = Mlp(
            in_features=hidden_size,
            hidden_features=int(hidden_size * mlp_ratio),
            drop=self.dropout,
        )

        self.adaLN_modulation = nn.Sequential(
            nn.Linear(hidden_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, 6 * hidden_size, bias=True)
        )

    def forward(self, x, c, node_mask):
        (
            shift_msa,
            scale_msa,
            gate_msa,
            shift_mlp,
            scale_mlp,
            gate_mlp,
        ) = self.adaLN_modulation(c).chunk(6, dim=1)
        x = x + gate_msa.unsqueeze(1) * modulate(self.norm1(self.attn(x, node_mask=node_mask)), shift_msa, scale_msa)
        x = x + gate_mlp.unsqueeze(1) * modulate(self.norm2(self.mlp(x)), shift_mlp, scale_mlp)
        return x

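# Note: Denoiser.initialize_weights() zeroes the first Linear of every adaLN_modulation MLP,
# so shift, scale, and gate all start at zero and each SELayer initially acts as the identity;
# this matches the adaLN-Zero initialization popularized by DiT, on which this block structure
# appears to be based.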
class Decoder(nn.Module):
    # Structure Decoder
    def __init__(self, max_n_nodes, hidden_size, atom_type, bond_type, mlp_ratio, num_heads=None):
        super().__init__()
        self.atom_type = atom_type
        self.bond_type = bond_type
        final_size = atom_type + max_n_nodes * bond_type
        self.xedecoder = Mlp(in_features=hidden_size, out_features=final_size, drop=0)

        self.norm_final = nn.LayerNorm(final_size, elementwise_affine=False)
        self.adaLN_modulation = nn.Sequential(
            nn.Linear(hidden_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, 2 * final_size, bias=True)
        )

    def forward(self, x, x_in, e_in, c, t, node_mask):
        x_all = self.xedecoder(x)
        B, N, D = x_all.size()
        shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
        x_all = modulate(self.norm_final(x_all), shift, scale)

        atom_out = x_all[:, :, :self.atom_type]
        atom_out = x_in + atom_out

        bond_out = x_all[:, :, self.atom_type:].reshape(B, N, N, self.bond_type)
        bond_out = e_in + bond_out

        ##### standardize adj_out
        edge_mask = (~node_mask)[:, :, None] & (~node_mask)[:, None, :]
        diag_mask = (
            torch.eye(N, dtype=torch.bool)
            .unsqueeze(0)
            .expand(B, -1, -1)
            .type_as(edge_mask)
        )
        bond_out.masked_fill_(edge_mask[:, :, :, None], 0)
        bond_out.masked_fill_(diag_mask[:, :, :, None], 0)
        # symmetrize so the predicted adjacency matches the expected symmetry of E
        bond_out = 1 / 2 * (bond_out + torch.transpose(bond_out, 1, 2))

        return atom_out, bond_out, None
135 mcd/utils.py Normal file
@ -0,0 +1,135 @@
import os
from omegaconf import OmegaConf, open_dict

import torch
import torch_geometric.utils
from torch_geometric.utils import to_dense_adj, to_dense_batch


def create_folders(args):
    try:
        os.makedirs('graphs')
        os.makedirs('chains')
    except OSError:
        pass

    try:
        os.makedirs('graphs/' + args.general.name)
        os.makedirs('chains/' + args.general.name)
    except OSError:
        pass


def normalize(X, E, y, norm_values, norm_biases, node_mask):
    X = (X - norm_biases[0]) / norm_values[0]
    E = (E - norm_biases[1]) / norm_values[1]
    y = (y - norm_biases[2]) / norm_values[2]

    diag = torch.eye(E.shape[1], dtype=torch.bool).unsqueeze(0).expand(E.shape[0], -1, -1)
    E[diag] = 0

    return PlaceHolder(X=X, E=E, y=y).mask(node_mask)


def unnormalize(X, E, y, norm_values, norm_biases, node_mask, collapse=False):
    """
    X : node features
    E : edge features
    y : global features
    norm_values : [norm value X, norm value E, norm value y]
    norm_biases : same order
    node_mask
    """
    X = (X * norm_values[0] + norm_biases[0])
    E = (E * norm_values[1] + norm_biases[1])
    y = y * norm_values[2] + norm_biases[2]

    return PlaceHolder(X=X, E=E, y=y).mask(node_mask, collapse)


def to_dense(x, edge_index, edge_attr, batch, max_num_nodes=None):
    X, node_mask = to_dense_batch(x=x, batch=batch, max_num_nodes=max_num_nodes)
    # node_mask = node_mask.float()
    edge_index, edge_attr = torch_geometric.utils.remove_self_loops(edge_index, edge_attr)
    if max_num_nodes is None:
        max_num_nodes = X.size(1)
    E = to_dense_adj(edge_index=edge_index, batch=batch, edge_attr=edge_attr, max_num_nodes=max_num_nodes)
    E = encode_no_edge(E)
    return PlaceHolder(X=X, E=E, y=None), node_mask


def encode_no_edge(E):
    assert len(E.shape) == 4
    if E.shape[-1] == 0:
        return E
    no_edge = torch.sum(E, dim=3) == 0
    first_elt = E[:, :, :, 0]
    first_elt[no_edge] = 1
    E[:, :, :, 0] = first_elt
    diag = torch.eye(E.shape[1], dtype=torch.bool).unsqueeze(0).expand(E.shape[0], -1, -1)
    E[diag] = 0
    return E

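# encode_no_edge makes the dense edge tensor explicitly categorical: positions with no bond
# get channel 0 ("no edge") set to 1, and the diagonal is zeroed so self-loops carry no class.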
def update_config_with_new_keys(cfg, saved_cfg):
    saved_general = saved_cfg.general
    saved_train = saved_cfg.train
    saved_model = saved_cfg.model
    saved_dataset = saved_cfg.dataset

    for key, val in saved_dataset.items():
        OmegaConf.set_struct(cfg.dataset, True)
        with open_dict(cfg.dataset):
            if key not in cfg.dataset.keys():
                setattr(cfg.dataset, key, val)

    for key, val in saved_general.items():
        OmegaConf.set_struct(cfg.general, True)
        with open_dict(cfg.general):
            if key not in cfg.general.keys():
                setattr(cfg.general, key, val)

    OmegaConf.set_struct(cfg.train, True)
    with open_dict(cfg.train):
        for key, val in saved_train.items():
            if key not in cfg.train.keys():
                setattr(cfg.train, key, val)

    OmegaConf.set_struct(cfg.model, True)
    with open_dict(cfg.model):
        for key, val in saved_model.items():
            if key not in cfg.model.keys():
                setattr(cfg.model, key, val)
    return cfg

class PlaceHolder:
    def __init__(self, X, E, y):
        self.X = X
        self.E = E
        self.y = y

    def type_as(self, x: torch.Tensor, categorical: bool = False):
        """ Changes the device and dtype of X, E, y. """
        self.X = self.X.type_as(x)
        self.E = self.E.type_as(x)
        if categorical:
            self.y = self.y.type_as(x)
        return self

    def mask(self, node_mask, collapse=False):
        x_mask = node_mask.unsqueeze(-1)   # bs, n, 1
        e_mask1 = x_mask.unsqueeze(2)      # bs, n, 1, 1
        e_mask2 = x_mask.unsqueeze(1)      # bs, 1, n, 1

        if collapse:
            self.X = torch.argmax(self.X, dim=-1)
            self.E = torch.argmax(self.E, dim=-1)

            self.X[node_mask == 0] = -1
            self.E[(e_mask1 * e_mask2).squeeze(-1) == 0] = -1
        else:
            self.X = self.X * x_mask
            self.E = self.E * e_mask1 * e_mask2
            assert torch.allclose(self.E, torch.transpose(self.E, 1, 2))
        return self
17 requirements.txt Normal file
@ -0,0 +1,17 @@
fcd_torch==1.0.7
hydra-core==1.3.2
imageio==2.26.0
joblib==1.2.0
matplotlib==3.7.0
mini_moses==1.0
networkx==3.0
numpy==1.24.2
omegaconf==2.3.0
pandas==1.5.3
pytorch_lightning==2.0.1
rdkit==2023.9.4
scikit_learn==1.2.1
torch==2.0.0
torch_geometric==2.3.0
torchmetrics==0.11.4
tqdm==4.64.1