init_code

This commit is contained in:
gang liu 2024-01-29 19:49:14 -05:00
parent 353d892291
commit 91727d2500
29 changed files with 3623 additions and 1 deletion


@ -14,7 +14,7 @@ This is the code for MCD: a Multi-Conditional Diffusion Model for inverse small
## Requirements
All dependencies are specified in the `requirements.txt` file.
This code was developed and tested with Python 3.9.16, PyTorch 2.0.0, PyG 2.3.0, and PyTorch Lightning 2.0.1.
For molecular generation evaluation, first install rdkit:

configs/__init__.py Normal file

configs/config.yaml Normal file

@ -0,0 +1,46 @@
general:
name: 'MCD'
wandb: 'disabled'
gpus: 1
resume: null
test_only: null
sample_every_val: 2500
samples_to_generate: 512
samples_to_save: 3
chains_to_save: 1
log_every_steps: 50
number_chain_steps: 8
final_model_samples_to_generate: 10000
final_model_samples_to_save: 20
final_model_chains_to_save: 1
enable_progress_bar: False
save_model: False
model:
type: 'discrete'
transition: 'marginal'
model: 'MCD'
diffusion_steps: 500
diffusion_noise_schedule: 'cosine'
guide_scale: 2
hidden_size: 1152
depth: 6
num_heads: 16
mlp_ratio: 4
drop_condition: 0.01
lambda_train: [1, 10] # node and edge training weight
ensure_connected: True
train:
n_epochs: 10000
batch_size: 1200
lr: 0.0002
clip_grad: null
num_workers: 0
weight_decay: 0
seed: 0
val_check_interval: null
check_val_every_n_epoch: 1
dataset:
datadir: 'data/'
task_name: null
guidance_target: null
pin_memory: False
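
The settings above are consumed elsewhere in the code as nested, attribute-style config objects (cfg.general.name, cfg.model.diffusion_steps, cfg.train.batch_size, ...). A minimal sketch of loading such a file with OmegaConf; the loader actually used by the repository is not shown in this diff, so treat this as an assumption:

from omegaconf import OmegaConf

# load configs/config.yaml into a nested config with attribute access
cfg = OmegaConf.load('configs/config.yaml')
print(cfg.general.name, cfg.model.diffusion_steps, cfg.train.batch_size)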

mcd/__init__.py Normal file

mcd/analysis/__init__.py Normal file


@ -0,0 +1,411 @@
from rdkit import Chem, RDLogger
RDLogger.DisableLog('rdApp.*')
from fcd_torch import FCD as FCDMetric
from mini_moses.metrics.metrics import FragMetric, internal_diversity
from mini_moses.metrics.utils import get_mol, mapper
import re
import time
import random
random.seed(0)
import numpy as np
from multiprocessing import Pool
import torch
from metrics.property_metric import calculateSAS
bond_dict = [None, Chem.rdchem.BondType.SINGLE, Chem.rdchem.BondType.DOUBLE, Chem.rdchem.BondType.TRIPLE,
Chem.rdchem.BondType.AROMATIC]
ATOM_VALENCY = {6: 4, 7: 3, 8: 2, 9: 1, 15: 3, 16: 2, 17: 1, 35: 1, 53: 1}
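# (atomic number -> typical valence: C 4, N 3, O 2, F 1, P 3, S 2, Cl 1, Br 1, I 1)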
bd_dict_x = {'O2-N2': [5.00E+04, 1.00E-03]}
bd_dict_y = {'O2-N2': [5.00E+04/2.78E+04, 1.00E-03/2.43E-05]}
selectivity = ['O2-N2']
a_dict = {}
b_dict = {}
for prop_name in selectivity:
x1, x2 = np.log10(bd_dict_x[prop_name][0]), np.log10(bd_dict_x[prop_name][1])
y1, y2 = np.log10(bd_dict_y[prop_name][0]), np.log10(bd_dict_y[prop_name][1])
a = (y1-y2)/(x1-x2)
b = y1-a*x1
a_dict[prop_name] = a
b_dict[prop_name] = b
def selectivity_evaluation(gas1, gas2, prop_name):
x = np.log10(np.array(gas1))
y = np.log10(np.array(gas1) / np.array(gas2))
upper = (y - (a_dict[prop_name] * x + b_dict[prop_name])) > 0
return upper
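# Rough usage sketch (illustrative numbers, not from the original code):
#   upper = selectivity_evaluation(gas1=[100.0, 2000.0], gas2=[5.0, 10.0], prop_name='O2-N2')
# 'upper' is a boolean array flagging samples whose (log permeability, log selectivity)
# point lies above the reference line defined by a_dict / b_dict.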
class BasicMolecularMetrics(object):
def __init__(self, atom_decoder, train_smiles=None, stat_ref=None, task_evaluator=None, n_jobs=8, device='cpu', batch_size=512):
self.dataset_smiles_list = train_smiles
self.atom_decoder = atom_decoder
self.n_jobs = n_jobs
self.device = device
self.batch_size = batch_size
self.stat_ref = stat_ref
self.task_evaluator = task_evaluator
def compute_relaxed_validity(self, generated, ensure_connected):
valid = []
num_components = []
all_smiles = []
valid_mols = []
covered_atoms = set()
direct_valid_count = 0
for graph in generated:
atom_types, edge_types = graph
mol = build_molecule_with_partial_charges(atom_types, edge_types, self.atom_decoder)
direct_valid_flag = True if check_mol(mol, largest_connected_comp=True) is not None else False
if direct_valid_flag:
direct_valid_count += 1
if not ensure_connected:
mol_conn, _ = correct_mol(mol, connection=True)
mol = mol_conn if mol_conn is not None else correct_mol(mol, connection=False)[0]
else: # ensure fully connected
mol, _ = correct_mol(mol, connection=True)
smiles = mol2smiles(mol)
mol = get_mol(smiles)
try:
mol_frags = Chem.rdmolops.GetMolFrags(mol, asMols=True, sanitizeFrags=True)
num_components.append(len(mol_frags))
largest_mol = max(mol_frags, default=mol, key=lambda m: m.GetNumAtoms())
smiles = mol2smiles(largest_mol)
if smiles is not None and largest_mol is not None and len(smiles) > 1 and Chem.MolFromSmiles(smiles) is not None:
valid_mols.append(largest_mol)
valid.append(smiles)
for atom in largest_mol.GetAtoms():
covered_atoms.add(atom.GetSymbol())
all_smiles.append(smiles)
else:
all_smiles.append(None)
except Exception as e:
# print(f"An error occurred: {e}")
all_smiles.append(None)
return valid, len(valid) / len(generated), direct_valid_count / len(generated), np.array(num_components), all_smiles, covered_atoms
def evaluate(self, generated, targets, ensure_connected, active_atoms=None):
""" generated: list of pairs (positions: n x 3, atom_types: n [int])
the positions and atom types should already be masked. """
valid, validity, nc_validity, num_components, all_smiles, covered_atoms = self.compute_relaxed_validity(generated, ensure_connected=ensure_connected)
nc_mu = num_components.mean() if len(num_components) > 0 else 0
nc_min = num_components.min() if len(num_components) > 0 else 0
nc_max = num_components.max() if len(num_components) > 0 else 0
len_active = len(active_atoms) if active_atoms is not None else 1
cover_str = f"Cover {len(covered_atoms)} ({len(covered_atoms)/len_active * 100:.2f}%) atoms: {covered_atoms}"
print(f"Validity over {len(generated)} molecules: {validity * 100 :.2f}% (w/o correction: {nc_validity * 100 :.2f}%), cover {len(covered_atoms)} ({len(covered_atoms)/len_active * 100:.2f}%) atoms: {covered_atoms}")
print(f"Number of connected components of {len(generated)} molecules: min:{nc_min:.2f} mean:{nc_mu:.2f} max:{nc_max:.2f}")
if validity > 0:
dist_metrics = {'cover_str': cover_str ,'validity': validity, 'validity_nc': nc_validity}
unique = list(set(valid))
close_pool = False
if self.n_jobs != 1:
pool = Pool(self.n_jobs)
close_pool = True
else:
pool = 1
valid_mols = mapper(pool)(get_mol, valid)
dist_metrics['internal_diversity'] = internal_diversity(valid_mols, pool, device=self.device)
start_time = time.time()
if self.stat_ref is not None:
kwargs = {'n_jobs': pool, 'device': self.device, 'batch_size': self.batch_size}
kwargs_fcd = {'n_jobs': self.n_jobs, 'device': self.device, 'batch_size': self.batch_size}
try:
dist_metrics['sim/Frag'] = FragMetric(**kwargs)(gen=valid_mols, pref=self.stat_ref['Frag'])
except:
print('error: ', 'pool', pool)
print('valid_mols: ', valid_mols)
dist_metrics['dist/FCD'] = FCDMetric(**kwargs_fcd)(gen=valid, pref=self.stat_ref['FCD'])
if self.task_evaluator is not None:
evaluation_list = list(self.task_evaluator.keys())
evaluation_list = evaluation_list.copy()
assert 'meta_taskname' in evaluation_list
meta_taskname = self.task_evaluator['meta_taskname']
evaluation_list.remove('meta_taskname')
meta_split = meta_taskname.split('-')
valid_index = np.array([True if smiles else False for smiles in all_smiles])
targets_log = {}
for i, name in enumerate(evaluation_list):
targets_log[f'input_{name}'] = np.array([float('nan')] * len(valid_index))
targets_log[f'input_{name}'] = targets[:, i]
targets = targets[valid_index]
if len(meta_split) == 2:
cached_perm = {meta_split[0]: None, meta_split[1]: None}
for i, name in enumerate(evaluation_list):
if name == 'scs':
continue
elif name == 'sas':
scores = calculateSAS(valid)
else:
scores = self.task_evaluator[name](valid)
targets_log[f'output_{name}'] = np.array([float('nan')] * len(valid_index))
targets_log[f'output_{name}'][valid_index] = scores
if name in ['O2', 'N2', 'CO2']:
if len(meta_split) == 2:
cached_perm[name] = scores
scores, cur_targets = np.log10(scores), np.log10(targets[:, i])
dist_metrics[f'{name}/mae'] = np.mean(np.abs(scores - cur_targets))
elif name == 'sas':
dist_metrics[f'{name}/mae'] = np.mean(np.abs(scores - targets[:, i]))
else:
true_y = targets[:, i]
predicted_labels = (scores >= 0.5).astype(int)
acc = (predicted_labels == true_y).sum() / len(true_y)
dist_metrics[f'{name}/acc'] = acc
if len(meta_split) == 2:
if cached_perm[meta_split[0]] is not None and cached_perm[meta_split[1]] is not None:
task_name = self.task_evaluator['meta_taskname']
upper = selectivity_evaluation(cached_perm[meta_split[0]], cached_perm[meta_split[1]], task_name)
dist_metrics[f'selectivity/{task_name}'] = np.sum(upper)
end_time = time.time()
elapsed_time = end_time - start_time
max_key_length = max(len(key) for key in dist_metrics)
print(f'Details over {len(valid)} ({len(generated)}) valid (total) molecules, calculating metrics using {elapsed_time:.2f} s:')
strs = ''
for i, (key, value) in enumerate(dist_metrics.items()):
if isinstance(value, (int, float, np.floating, np.integer)):
strs = strs + f'{key:>{max_key_length}}:{value:<7.4f}\t'
if i % 4 == 3:
strs = strs + '\n'
print(strs)
if close_pool:
pool.close()
pool.join()
else:
unique = []
dist_metrics = {}
targets_log = None
return unique, dict(nc_min=nc_min, nc_max=nc_max, nc_mu=nc_mu), all_smiles, dist_metrics, targets_log
def mol2smiles(mol):
if mol is None:
return None
try:
Chem.SanitizeMol(mol)
except ValueError:
return None
return Chem.MolToSmiles(mol)
def build_molecule_with_partial_charges(atom_types, edge_types, atom_decoder, verbose=False):
if verbose:
print("\nbuilding new molecule")
mol = Chem.RWMol()
for atom in atom_types:
a = Chem.Atom(atom_decoder[atom.item()])
mol.AddAtom(a)
if verbose:
print("Atom added: ", atom.item(), atom_decoder[atom.item()])
edge_types = torch.triu(edge_types)
all_bonds = torch.nonzero(edge_types)
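# edge_types was restricted to its upper triangle above, so each undirected bond is added exactly once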
for i, bond in enumerate(all_bonds):
if bond[0].item() != bond[1].item():
mol.AddBond(bond[0].item(), bond[1].item(), bond_dict[edge_types[bond[0], bond[1]].item()])
if verbose:
print("bond added:", bond[0].item(), bond[1].item(), edge_types[bond[0], bond[1]].item(),
bond_dict[edge_types[bond[0], bond[1]].item()])
# add formal charge to atom: e.g. [O+], [N+], [S+]
# not support [O-], [N-], [S-], [NH+] etc.
flag, atomid_valence = check_valency(mol)
if verbose:
print("flag, valence", flag, atomid_valence)
if flag:
continue
else:
if len(atomid_valence) == 2:
idx = atomid_valence[0]
v = atomid_valence[1]
an = mol.GetAtomWithIdx(idx).GetAtomicNum()
if verbose:
print("atomic num of atom with a large valence", an)
if an in (7, 8, 16) and (v - ATOM_VALENCY[an]) == 1:
mol.GetAtomWithIdx(idx).SetFormalCharge(1)
# print("Formal charge added")
else:
continue
return mol
def check_valency(mol):
try:
Chem.SanitizeMol(mol, sanitizeOps=Chem.SanitizeFlags.SANITIZE_PROPERTIES)
return True, None
except ValueError as e:
e = str(e)
p = e.find('#')
e_sub = e[p:]
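# RDKit's message is expected to read like 'Explicit valence for atom # 3 N, 4, is greater than permitted'
# (assumed wording), so the digits after '#' give the offending atom index and its valence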
atomid_valence = list(map(int, re.findall(r'\d+', e_sub)))
return False, atomid_valence
def correct_mol(mol, connection=False):
#####
no_correct = False
flag, _ = check_valency(mol)
if flag:
no_correct = True
while True:
if connection:
mol_conn = connect_fragments(mol)
# if mol_conn is not None:
mol = mol_conn
if mol is None:
return None, no_correct
flag, atomid_valence = check_valency(mol)
if flag:
break
else:
try:
assert len(atomid_valence) == 2
idx = atomid_valence[0]
v = atomid_valence[1]
queue = []
check_idx = 0
for b in mol.GetAtomWithIdx(idx).GetBonds():
type = int(b.GetBondType())
queue.append((b.GetIdx(), type, b.GetBeginAtomIdx(), b.GetEndAtomIdx()))
if type == 12:
check_idx += 1
queue.sort(key=lambda tup: tup[1], reverse=True)
if queue[-1][1] == 12:
return None, no_correct
elif len(queue) > 0:
start = queue[check_idx][2]
end = queue[check_idx][3]
t = queue[check_idx][1] - 1
mol.RemoveBond(start, end)
if t >= 1:
mol.AddBond(start, end, bond_dict[t])
except Exception as e:
# print(f"An error occurred in correction: {e}")
return None, no_correct
return mol, no_correct
def check_mol(m, largest_connected_comp=True):
if m is None:
return None
sm = Chem.MolToSmiles(m, isomericSmiles=True)
if largest_connected_comp and '.' in sm:
vsm = [(s, len(s)) for s in sm.split('.')] # 'C.CC.CCc1ccc(N)cc1CCC=O'.split('.')
vsm.sort(key=lambda tup: tup[1], reverse=True)
mol = Chem.MolFromSmiles(vsm[0][0])
else:
mol = Chem.MolFromSmiles(sm)
return mol
##### connect fragments
def select_atom_with_available_valency(frag):
atoms = list(frag.GetAtoms())
random.shuffle(atoms)
for atom in atoms:
if atom.GetAtomicNum() > 1 and atom.GetImplicitValence() > 0:
return atom
return None
def select_atoms_with_available_valency(frag):
return [atom for atom in frag.GetAtoms() if atom.GetAtomicNum() > 1 and atom.GetImplicitValence() > 0]
def try_to_connect_fragments(combined_mol, frag, atom1, atom2):
# Make copies of the molecules to try the connection
trial_combined_mol = Chem.RWMol(combined_mol)
trial_frag = Chem.RWMol(frag)
# Add the new fragment to the combined molecule with new indices
new_indices = {atom.GetIdx(): trial_combined_mol.AddAtom(atom) for atom in trial_frag.GetAtoms()}
# Add the bond between the suitable atoms from each fragment
trial_combined_mol.AddBond(atom1.GetIdx(), new_indices[atom2.GetIdx()], Chem.BondType.SINGLE)
# Adjust the hydrogen count of the connected atoms
for atom_idx in [atom1.GetIdx(), new_indices[atom2.GetIdx()]]:
atom = trial_combined_mol.GetAtomWithIdx(atom_idx)
num_h = atom.GetTotalNumHs()
atom.SetNumExplicitHs(max(0, num_h - 1))
# Add bonds for the new fragment
for bond in trial_frag.GetBonds():
trial_combined_mol.AddBond(new_indices[bond.GetBeginAtomIdx()], new_indices[bond.GetEndAtomIdx()], bond.GetBondType())
# Convert to a Mol object and try to sanitize it
new_mol = Chem.Mol(trial_combined_mol)
try:
Chem.SanitizeMol(new_mol)
return new_mol # Return the new valid molecule
except Chem.MolSanitizeException:
return None # If the molecule is not valid, return None
def connect_fragments(mol):
# Get the separate fragments
frags = Chem.GetMolFrags(mol, asMols=True, sanitizeFrags=False)
if len(frags) < 2:
return mol
combined_mol = Chem.RWMol(frags[0])
for frag in frags[1:]:
# Select all atoms with available valency from both molecules
atoms1 = select_atoms_with_available_valency(combined_mol)
atoms2 = select_atoms_with_available_valency(frag)
# Try to connect using all combinations of available valency atoms
for atom1 in atoms1:
for atom2 in atoms2:
new_mol = try_to_connect_fragments(combined_mol, frag, atom1, atom2)
if new_mol is not None:
# If a valid connection is made, update the combined molecule and break
combined_mol = new_mol
break
else:
# Continue if the inner loop didn't break (no valid connection found for atom1)
continue
# Break if the inner loop did break (valid connection found)
break
else:
# If no valid connections could be made with any of the atoms, return None
return None
return combined_mol
#### connect fragments
def compute_molecular_metrics(molecule_list, targets, train_smiles, stat_ref, dataset_info, task_evaluator, comput_config):
""" molecule_list: (dict) """
atom_decoder = dataset_info.atom_decoder
active_atoms = dataset_info.active_atoms
ensure_connected = dataset_info.ensure_connected
metrics = BasicMolecularMetrics(atom_decoder, train_smiles, stat_ref, task_evaluator, **comput_config)
evaluated_res = metrics.evaluate(molecule_list, targets, ensure_connected, active_atoms)
all_smiles = evaluated_res[-3]
all_metrics = evaluated_res[-2]
targets_log = evaluated_res[-1]
unique_smiles = evaluated_res[0]
return unique_smiles, all_smiles, all_metrics, targets_log
if __name__ == '__main__':
smiles_mol = 'C1CCC1'
print("Smiles mol %s" % smiles_mol)
chem_mol = Chem.MolFromSmiles(smiles_mol)
print(chem_mol)


@ -0,0 +1,222 @@
import os
from rdkit import Chem
from rdkit.Chem import Draw, AllChem
from rdkit.Geometry import Point3D
from rdkit import RDLogger
import imageio
import networkx as nx
import numpy as np
import rdkit.Chem
import matplotlib.pyplot as plt
class MolecularVisualization:
def __init__(self, dataset_infos):
self.dataset_infos = dataset_infos
def mol_from_graphs(self, node_list, adjacency_matrix):
"""
Convert graphs to rdkit molecules
node_list: the nodes of a batch of nodes (bs x n)
adjacency_matrix: the adjacency_matrix of the molecule (bs x n x n)
"""
# dictionary to map integer value to the char of atom
atom_decoder = self.dataset_infos.atom_decoder
# [list(self.dataset_infos.atom_decoder.keys())[0]]
# create empty editable mol object
mol = Chem.RWMol()
# add atoms to mol and keep track of index
node_to_idx = {}
for i in range(len(node_list)):
if node_list[i] == -1:
continue
a = Chem.Atom(atom_decoder[int(node_list[i])])
molIdx = mol.AddAtom(a)
node_to_idx[i] = molIdx
for ix, row in enumerate(adjacency_matrix):
for iy, bond in enumerate(row):
# only traverse half the symmetric matrix
if iy <= ix:
continue
if bond == 1:
bond_type = Chem.rdchem.BondType.SINGLE
elif bond == 2:
bond_type = Chem.rdchem.BondType.DOUBLE
elif bond == 3:
bond_type = Chem.rdchem.BondType.TRIPLE
elif bond == 4:
bond_type = Chem.rdchem.BondType.AROMATIC
else:
continue
mol.AddBond(node_to_idx[ix], node_to_idx[iy], bond_type)
try:
mol = mol.GetMol()
except rdkit.Chem.KekulizeException:
print("Can't kekulize molecule")
mol = None
return mol
def visualize(self, path: str, molecules: list, num_molecules_to_visualize: int, log='graph'):
# define path to save figures
if not os.path.exists(path):
os.makedirs(path)
# visualize the final molecules
print(f"Visualizing {num_molecules_to_visualize} of {len(molecules)}")
if num_molecules_to_visualize > len(molecules):
print(f"Shortening to {len(molecules)}")
num_molecules_to_visualize = len(molecules)
for i in range(num_molecules_to_visualize):
file_path = os.path.join(path, 'molecule_{}.png'.format(i))
mol = self.mol_from_graphs(molecules[i][0].numpy(), molecules[i][1].numpy())
try:
Draw.MolToFile(mol, file_path)
except rdkit.Chem.KekulizeException:
print("Can't kekulize molecule")
def visualize_chain(self, path, nodes_list, adjacency_matrix, trainer=None):
RDLogger.DisableLog('rdApp.*')
# convert graphs to the rdkit molecules
mols = [self.mol_from_graphs(nodes_list[i], adjacency_matrix[i]) for i in range(nodes_list.shape[0])]
# find the coordinates of atoms in the final molecule
final_molecule = mols[-1]
AllChem.Compute2DCoords(final_molecule)
coords = []
for i, atom in enumerate(final_molecule.GetAtoms()):
positions = final_molecule.GetConformer().GetAtomPosition(i)
coords.append((positions.x, positions.y, positions.z))
# align all the molecules
for i, mol in enumerate(mols):
AllChem.Compute2DCoords(mol)
conf = mol.GetConformer()
for j, atom in enumerate(mol.GetAtoms()):
x, y, z = coords[j]
conf.SetAtomPosition(j, Point3D(x, y, z))
# draw gif
save_paths = []
num_frams = nodes_list.shape[0]
for frame in range(num_frams):
file_name = os.path.join(path, 'fram_{}.png'.format(frame))
Draw.MolToFile(mols[frame], file_name, size=(300, 300), legend=f"Frame {frame}")
save_paths.append(file_name)
imgs = [imageio.imread(fn) for fn in save_paths]
gif_path = os.path.join(os.path.dirname(path), '{}.gif'.format(path.split('/')[-1]))
imgs.extend([imgs[-1]] * 10)
imageio.mimsave(gif_path, imgs, subrectangles=True, fps=5)
# draw grid image
try:
img = Draw.MolsToGridImage(mols, molsPerRow=10, subImgSize=(200, 200))
img.save(os.path.join(path, '{}_grid_image.png'.format(path.split('/')[-1])))
except Chem.rdchem.KekulizeException:
print("Can't kekulize molecule")
return mols
def visualize_by_smiles(self, path: str, smiles_list: list, num_to_visualize: int):
os.makedirs(path, exist_ok=True)
print(f"Visualizing corrected {num_to_visualize} of {len(smiles_list)}")
if num_to_visualize > len(smiles_list):
print(f"Shortening to {len(smiles_list)}")
num_to_visualize = len(smiles_list)
for i in range(num_to_visualize):
file_path = os.path.join(path, 'molecule_corrected_{}.png'.format(i))
if smiles_list[i] is None:
continue
mol = Chem.MolFromSmiles(smiles_list[i])
try:
Draw.MolToFile(mol, file_path)
except rdkit.Chem.KekulizeException:
print("Can't kekulize molecule")
class NonMolecularVisualization:
def to_networkx(self, node_list, adjacency_matrix):
"""
Convert graphs to networkx graphs
node_list: the nodes of a batch of nodes (bs x n)
adjacency_matrix: the adjacency_matrix of the molecule (bs x n x n)
"""
graph = nx.Graph()
for i in range(len(node_list)):
if node_list[i] == -1:
continue
graph.add_node(i, number=i, symbol=node_list[i], color_val=node_list[i])
rows, cols = np.where(adjacency_matrix >= 1)
edges = zip(rows.tolist(), cols.tolist())
for edge in edges:
edge_type = adjacency_matrix[edge[0]][edge[1]]
graph.add_edge(edge[0], edge[1], color=float(edge_type), weight=3 * edge_type)
return graph
def visualize_non_molecule(self, graph, pos, path, iterations=100, node_size=100, largest_component=False):
if largest_component:
CGs = [graph.subgraph(c) for c in nx.connected_components(graph)]
CGs = sorted(CGs, key=lambda x: x.number_of_nodes(), reverse=True)
graph = CGs[0]
# Plot the graph structure with colors
if pos is None:
pos = nx.spring_layout(graph, iterations=iterations)
# Set node colors based on the eigenvectors
w, U = np.linalg.eigh(nx.normalized_laplacian_matrix(graph).toarray())
vmin, vmax = np.min(U[:, 1]), np.max(U[:, 1])
m = max(np.abs(vmin), vmax)
vmin, vmax = -m, m
plt.figure()
nx.draw(graph, pos, font_size=5, node_size=node_size, with_labels=False, node_color=U[:, 1],
cmap=plt.cm.coolwarm, vmin=vmin, vmax=vmax, edge_color='grey')
plt.tight_layout()
plt.savefig(path)
plt.close("all")
def visualize(self, path: str, graphs: list, num_graphs_to_visualize: int, log='graph'):
# define path to save figures
if not os.path.exists(path):
os.makedirs(path)
# visualize the final molecules
for i in range(num_graphs_to_visualize):
file_path = os.path.join(path, 'graph_{}.png'.format(i))
graph = self.to_networkx(graphs[i][0].numpy(), graphs[i][1].numpy())
self.visualize_non_molecule(graph=graph, pos=None, path=file_path)
im = plt.imread(file_path)
def visualize_chain(self, path, nodes_list, adjacency_matrix):
# convert graphs to networkx
graphs = [self.to_networkx(nodes_list[i], adjacency_matrix[i]) for i in range(nodes_list.shape[0])]
# find the coordinates of atoms in the final molecule
final_graph = graphs[-1]
final_pos = nx.spring_layout(final_graph, seed=0)
# draw gif
save_paths = []
num_frams = nodes_list.shape[0]
for frame in range(num_frams):
file_name = os.path.join(path, 'fram_{}.png'.format(frame))
self.visualize_non_molecule(graph=graphs[frame], pos=final_pos, path=file_name)
save_paths.append(file_name)
imgs = [imageio.imread(fn) for fn in save_paths]
gif_path = os.path.join(os.path.dirname(path), '{}.gif'.format(path.split('/')[-1]))
imgs.extend([imgs[-1]] * 10)
imageio.mimsave(gif_path, imgs, subrectangles=True, fps=5)

mcd/datasets/__init__.py Normal file


@ -0,0 +1,126 @@
from diffusion.distributions import DistributionNodes
import utils as utils
import torch
import pytorch_lightning as pl
from torch_geometric.loader import DataLoader
class AbstractDataModule(pl.LightningDataModule):
def __init__(self, cfg):
super().__init__()
self.cfg = cfg
self.dataloaders = None
self.input_dims = None
self.output_dims = None
def prepare_data(self, datasets) -> None:
batch_size = self.cfg.train.batch_size
num_workers = self.cfg.train.num_workers
self.dataloaders = {split: DataLoader(dataset, batch_size=batch_size, num_workers=num_workers,
shuffle='debug' not in self.cfg.general.name)
for split, dataset in datasets.items()}
def train_dataloader(self):
return self.dataloaders["train"]
def val_dataloader(self):
return self.dataloaders["val"]
def test_dataloader(self):
return self.dataloaders["test"]
def __getitem__(self, idx):
return self.dataloaders['train'][idx]
def node_counts(self, max_nodes_possible=300):
all_counts = torch.zeros(max_nodes_possible)
for split in ['train', 'val', 'test']:
for i, data in enumerate(self.dataloaders[split]):
unique, counts = torch.unique(data.batch, return_counts=True)
for count in counts:
all_counts[count] += 1
max_index = max(all_counts.nonzero())
all_counts = all_counts[:max_index + 1]
all_counts = all_counts / all_counts.sum()
return all_counts
def node_types(self):
num_classes = None
for data in self.dataloaders['train']:
num_classes = data.x.shape[1]
break
counts = torch.zeros(num_classes)
for split in ['train', 'val', 'test']:
for i, data in enumerate(self.dataloaders[split]):
counts += data.x.sum(dim=0)
counts = counts / counts.sum()
return counts
def edge_counts(self):
num_classes = None
for data in self.dataloaders['train']:
num_classes = 5
break
d = torch.zeros(num_classes)
for split in ['train', 'val', 'test']:
for i, data in enumerate(self.dataloaders[split]):
unique, counts = torch.unique(data.batch, return_counts=True)
all_pairs = 0
for count in counts:
all_pairs += count * (count - 1)
num_edges = data.edge_index.shape[1]
num_non_edges = all_pairs - num_edges
data_edge_attr = torch.nn.functional.one_hot(data.edge_attr, num_classes=5).float()
edge_types = data_edge_attr.sum(dim=0)
assert num_non_edges >= 0
d[0] += num_non_edges
d[1:] += edge_types[1:]
d = d / d.sum()
return d
class MolecularDataModule(AbstractDataModule):
def valency_count(self, max_n_nodes):
valencies = torch.zeros(3 * max_n_nodes - 2) # Max valency possible if everything is connected
multiplier = torch.Tensor([0, 1, 2, 3, 1.5])
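# bond-order contribution per edge class: no bond, single, double, triple, aromatic (1.5)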
for split in ['train', 'val', 'test']:
for i, data in enumerate(self.dataloaders[split]):
n = data.x.shape[0]
for atom in range(n):
data_edge_attr = torch.nn.functional.one_hot(data.edge_attr, num_classes=5).float()
edges = data_edge_attr[data.edge_index[0] == atom]
edges_total = edges.sum(dim=0)
valency = (edges_total * multiplier).sum()
valencies[valency.long().item()] += 1
valencies = valencies / valencies.sum()
return valencies
class AbstractDatasetInfos:
def complete_infos(self, n_nodes, node_types):
self.input_dims = None
self.output_dims = None
self.num_classes = len(node_types)
self.max_n_nodes = len(n_nodes) - 1
self.nodes_dist = DistributionNodes(n_nodes)
def compute_input_output_dims(self, datamodule):
example_batch = datamodule.example_batch()
example_batch_x = torch.nn.functional.one_hot(example_batch.x, num_classes=118).float()[:, self.active_index]
example_batch_edge_attr = torch.nn.functional.one_hot(example_batch.edge_attr, num_classes=5).float()
self.input_dims = {'X': example_batch_x.size(1),
'E': example_batch_edge_attr.size(1),
'y': example_batch['y'].size(1)}
self.output_dims = {'X': example_batch_x.size(1),
'E': example_batch_edge_attr.size(1),
'y': example_batch['y'].size(1)}

mcd/datasets/dataset.py Normal file

@ -0,0 +1,381 @@
import sys
sys.path.append('../')
import os
import os.path as osp
import pathlib
import json
import torch
import torch.nn.functional as F
from rdkit import Chem, RDLogger
from rdkit.Chem.rdchem import BondType as BT
from tqdm import tqdm
import numpy as np
import pandas as pd
from torch_geometric.data import Data, InMemoryDataset
from torch_geometric.loader import DataLoader
from sklearn.model_selection import train_test_split
import utils as utils
from datasets.abstract_dataset import AbstractDatasetInfos, AbstractDataModule
from diffusion.distributions import DistributionNodes
bonds = {BT.SINGLE: 1, BT.DOUBLE: 2, BT.TRIPLE: 3, BT.AROMATIC: 4}
class DataModule(AbstractDataModule):
def __init__(self, cfg):
self.datadir = cfg.dataset.datadir
self.task = cfg.dataset.task_name
super().__init__(cfg)
def prepare_data(self) -> None:
target = getattr(self.cfg.dataset, 'guidance_target', None)
base_path = pathlib.Path(os.path.realpath(__file__)).parents[2]
root_path = os.path.join(base_path, self.datadir)
self.root_path = root_path
batch_size = self.cfg.train.batch_size
num_workers = self.cfg.train.num_workers
pin_memory = self.cfg.dataset.pin_memory
dataset = Dataset(source=self.task, root=root_path, target_prop=target, transform=None)
if len(self.task.split('-')) == 2:
train_index, val_index, test_index, unlabeled_index = self.fixed_split(dataset)
else:
train_index, val_index, test_index, unlabeled_index = self.random_data_split(dataset)
self.train_index, self.val_index, self.test_index, self.unlabeled_index = train_index, val_index, test_index, unlabeled_index
train_index, val_index, test_index, unlabeled_index = torch.LongTensor(train_index), torch.LongTensor(val_index), torch.LongTensor(test_index), torch.LongTensor(unlabeled_index)
if len(unlabeled_index) > 0:
train_index = torch.cat([train_index, unlabeled_index], dim=0)
train_dataset, val_dataset, test_dataset = dataset[train_index], dataset[val_index], dataset[test_index]
self.train_dataset = train_dataset
self.train_loader = DataLoader(train_dataset, batch_size=batch_size, num_workers=num_workers, shuffle=True, pin_memory=pin_memory)
self.val_loader = DataLoader(val_dataset, batch_size=batch_size, num_workers=num_workers, shuffle=False, pin_memory=False)
self.test_loader = DataLoader(test_dataset, batch_size=batch_size, num_workers=num_workers, shuffle=False, pin_memory=False)
training_iterations = len(train_dataset) // batch_size
self.training_iterations = training_iterations
def random_data_split(self, dataset):
nan_count = torch.isnan(dataset.y[:, 0]).sum().item()
labeled_len = len(dataset) - nan_count
full_idx = list(range(labeled_len))
train_ratio, valid_ratio, test_ratio = 0.6, 0.2, 0.2
train_index, test_index, _, _ = train_test_split(full_idx, full_idx, test_size=test_ratio, random_state=42)
train_index, val_index, _, _ = train_test_split(train_index, train_index, test_size=valid_ratio/(valid_ratio+train_ratio), random_state=42)
unlabeled_index = list(range(labeled_len, len(dataset)))
print(self.task, ' dataset len', len(dataset), 'train len', len(train_index), 'val len', len(val_index), 'test len', len(test_index), 'unlabeled len', len(unlabeled_index))
return train_index, val_index, test_index, unlabeled_index
def fixed_split(self, dataset):
if self.task == 'O2-N2':
test_index = [42,43,92,122,197,198,251,254,257,355,511,512,549,602,603,604]
else:
raise ValueError('Invalid task name: {}'.format(self.task))
full_idx = list(range(len(dataset)))
full_idx = list(set(full_idx) - set(test_index))
train_ratio = 0.8
train_index, val_index, _, _ = train_test_split(full_idx, full_idx, test_size=1-train_ratio, random_state=42)
print(self.task, ' dataset len', len(dataset), 'train len', len(train_index), 'val len', len(val_index), 'test len', len(test_index))
return train_index, val_index, test_index, []
def get_train_smiles(self):
filename = f'{self.task}.csv.gz'
df = pd.read_csv(f'{self.root_path}/raw/{filename}')
df_test = df.iloc[self.test_index]
df = df.iloc[self.train_index]
smiles_list = df['smiles'].tolist()
smiles_list_test = df_test['smiles'].tolist()
smiles_list = [Chem.MolToSmiles(Chem.MolFromSmiles(smi)) for smi in smiles_list]
smiles_list_test = [Chem.MolToSmiles(Chem.MolFromSmiles(smi)) for smi in smiles_list_test]
return smiles_list, smiles_list_test
def get_data_split(self):
filename = f'{self.task}.csv.gz'
df = pd.read_csv(f'{self.root_path}/raw/{filename}')
df_val = df.iloc[self.val_index]
df_test = df.iloc[self.test_index]
df_train = df.iloc[self.train_index]
return df_train, df_val, df_test
def example_batch(self):
return next(iter(self.val_loader))
def train_dataloader(self):
return self.train_loader
def val_dataloader(self):
return self.val_loader
def test_dataloader(self):
return self.test_loader
class Dataset(InMemoryDataset):
def __init__(self, source, root, target_prop=None,
transform=None, pre_transform=None, pre_filter=None):
self.target_prop = target_prop
self.source = source
super().__init__(root, transform, pre_transform, pre_filter)
self.data, self.slices = torch.load(self.processed_paths[0])
@property
def raw_file_names(self):
return [f'{self.source}.csv.gz']
@property
def processed_file_names(self):
return [f'{self.source}.pt']
def process(self):
RDLogger.DisableLog('rdApp.*')
data_path = osp.join(self.raw_dir, self.raw_file_names[0])
data_df = pd.read_csv(data_path)
def mol_to_graph(mol, sa, sc, target, target2=None, target3=None, valid_atoms=None):
type_idx = []
heavy_atom_indices, active_atoms = [], []
for atom in mol.GetAtoms():
if atom.GetAtomicNum() != 1:
type_idx.append(119-2) if atom.GetSymbol() == '*' else type_idx.append(atom.GetAtomicNum()-2)
heavy_atom_indices.append(atom.GetIdx())
active_atoms.append(atom.GetSymbol())
if valid_atoms is not None:
if not atom.GetSymbol() in valid_atoms:
return None, None
x = torch.LongTensor(type_idx)
edges_list = []
edge_type = []
for bond in mol.GetBonds():
start, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
if start in heavy_atom_indices and end in heavy_atom_indices:
start_new, end_new = heavy_atom_indices.index(start), heavy_atom_indices.index(end)
edges_list.append((start_new, end_new))
edge_type.append(bonds[bond.GetBondType()])
edges_list.append((end_new, start_new))
edge_type.append(bonds[bond.GetBondType()])
edge_index = torch.tensor(edges_list, dtype=torch.long).t()
edge_type = torch.tensor(edge_type, dtype=torch.long)
edge_attr = edge_type
if target3 is not None:
y = torch.tensor([sa, sc, target, target2, target3], dtype=torch.float).view(1,-1)
elif target2 is not None:
y = torch.tensor([sa, sc, target, target2], dtype=torch.float).view(1,-1)
else:
y = torch.tensor([sa, sc, target], dtype=torch.float).view(1,-1)
data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=y, idx=i)
if self.pre_transform is not None:
data = self.pre_transform(data)
return data, active_atoms
# Loop through every row in the DataFrame and apply the function
data_list = []
len_data = len(data_df)
with tqdm(total=len_data) as pbar:
# --- data processing start ---
active_atoms = set()
for i, (sms, df_row) in enumerate(data_df.iterrows()):
if i == sms:
sms = df_row['smiles']
mol = Chem.MolFromSmiles(sms, sanitize=False)
if len(self.target_prop.split('-')) == 2:
target1, target2 = self.target_prop.split('-')
data, cur_active_atoms = mol_to_graph(mol, df_row['SA'], df_row['SC'], df_row[target1], target2=df_row[target2])
elif len(self.target_prop.split('-')) == 3:
target1, target2, target3 = self.target_prop.split('-')
data, cur_active_atoms = mol_to_graph(mol, df_row['SA'], df_row['SC'], df_row[target1], target2=df_row[target2], target3=df_row[target3])
else:
data, cur_active_atoms = mol_to_graph(mol, df_row['SA'], df_row['SC'], df_row[self.target_prop])
active_atoms.update(cur_active_atoms)
data_list.append(data)
pbar.update(1)
torch.save(self.collate(data_list), self.processed_paths[0])
class DataInfos(AbstractDatasetInfos):
def __init__(self, datamodule, cfg):
tasktype_dict = {
'hiv_b': 'classification',
'bace_b': 'classification',
'bbbp_b': 'classification',
'O2': 'regression',
'N2': 'regression',
'CO2': 'regression',
}
task_name = cfg.dataset.task_name
self.task = task_name
self.task_type = tasktype_dict.get(task_name, "regression")
self.ensure_connected = cfg.model.ensure_connected
datadir = cfg.dataset.datadir
base_path = pathlib.Path(os.path.realpath(__file__)).parents[2]
meta_filename = os.path.join(base_path, datadir, 'raw', f'{task_name}.meta.json')
data_root = os.path.join(base_path, datadir, 'raw')
if os.path.exists(meta_filename):
with open(meta_filename, 'r') as f:
meta_dict = json.load(f)
else:
meta_dict = compute_meta(data_root, task_name, datamodule.train_index, datamodule.test_index)
self.base_path = base_path
self.active_atoms = meta_dict['active_atoms']
self.max_n_nodes = meta_dict['max_node']
self.original_max_n_nodes = meta_dict['max_node']
self.n_nodes = torch.Tensor(meta_dict['n_atoms_per_mol_dist'])
self.edge_types = torch.Tensor(meta_dict['bond_type_dist'])
self.transition_E = torch.Tensor(meta_dict['transition_E'])
self.atom_decoder = meta_dict['active_atoms']
node_types = torch.Tensor(meta_dict['atom_type_dist'])
active_index = (node_types > 0).nonzero().squeeze()
self.node_types = torch.Tensor(meta_dict['atom_type_dist'])[active_index]
self.nodes_dist = DistributionNodes(self.n_nodes)
self.active_index = active_index
val_len = 3 * self.original_max_n_nodes - 2
meta_val = torch.Tensor(meta_dict['valencies'])
self.valency_distribution = torch.zeros(val_len)
val_len = min(val_len, len(meta_val))
self.valency_distribution[:val_len] = meta_val[:val_len]
self.y_prior = None
self.train_ymin = []
self.train_ymax = []
def compute_meta(root, source_name, train_index, test_index):
pt = Chem.GetPeriodicTable()
atom_name_list = []
atom_count_list = []
for i in range(2, 119):
atom_name_list.append(pt.GetElementSymbol(i))
atom_count_list.append(0)
atom_name_list.append('*')
atom_count_list.append(0)
n_atoms_per_mol = [0] * 500
bond_count_list = [0, 0, 0, 0, 0]
bond_type_to_index = {BT.SINGLE: 1, BT.DOUBLE: 2, BT.TRIPLE: 3, BT.AROMATIC: 4}
valencies = [0] * 500
tansition_E = np.zeros((118, 118, 5))
filename = f'{source_name}.csv.gz'
df = pd.read_csv(f'{root}/{filename}')
all_index = list(range(len(df)))
non_test_index = list(set(all_index) - set(test_index))
df = df.iloc[non_test_index]
tot_smiles = df['smiles'].tolist()
n_atom_list = []
n_bond_list = []
for i, sms in enumerate(tot_smiles):
try:
mol = Chem.MolFromSmiles(sms)
except:
continue
n_atom = mol.GetNumHeavyAtoms()
n_bond = mol.GetNumBonds()
n_atom_list.append(n_atom)
n_bond_list.append(n_bond)
n_atoms_per_mol[n_atom] += 1
cur_atom_count_arr = np.zeros(118)
for atom in mol.GetAtoms():
symbol = atom.GetSymbol()
if symbol == 'H':
continue
elif symbol == '*':
atom_count_list[-1] += 1
cur_atom_count_arr[-1] += 1
else:
atom_count_list[atom.GetAtomicNum()-2] += 1
cur_atom_count_arr[atom.GetAtomicNum()-2] += 1
try:
valencies[int(atom.GetExplicitValence())] += 1
except:
print('src', source_name,'int(atom.GetExplicitValence())', int(atom.GetExplicitValence()))
tansition_E_temp = np.zeros((118, 118, 5))
for bond in mol.GetBonds():
start_atom, end_atom = bond.GetBeginAtom(), bond.GetEndAtom()
if start_atom.GetSymbol() == 'H' or end_atom.GetSymbol() == 'H':
continue
if start_atom.GetSymbol() == '*':
start_index = 117
else:
start_index = start_atom.GetAtomicNum() - 2
if end_atom.GetSymbol() == '*':
end_index = 117
else:
end_index = end_atom.GetAtomicNum() - 2
bond_type = bond.GetBondType()
bond_index = bond_type_to_index[bond_type]
bond_count_list[bond_index] += 2
tansition_E[start_index, end_index, bond_index] += 2
tansition_E[end_index, start_index, bond_index] += 2
tansition_E_temp[start_index, end_index, bond_index] += 2
tansition_E_temp[end_index, start_index, bond_index] += 2
bond_count_list[0] += n_atom * (n_atom - 1) - n_bond * 2
cur_tot_bond = cur_atom_count_arr.reshape(-1,1) * cur_atom_count_arr.reshape(1,-1) * 2 # 118 * 118
cur_tot_bond = cur_tot_bond - np.diag(cur_atom_count_arr) * 2 # 118 * 118
tansition_E[:, :, 0] += cur_tot_bond - tansition_E_temp.sum(axis=-1)
assert (cur_tot_bond > tansition_E_temp.sum(axis=-1)).sum() >= 0, f'i:{i}, sms:{sms}'
n_atoms_per_mol = np.array(n_atoms_per_mol) / np.sum(n_atoms_per_mol)
n_atoms_per_mol = n_atoms_per_mol.tolist()[:51]
atom_count_list = np.array(atom_count_list) / np.sum(atom_count_list)
print('processed meta info: ------', filename, '------')
print('len atom_count_list', len(atom_count_list))
print('len atom_name_list', len(atom_name_list))
active_atoms = np.array(atom_name_list)[atom_count_list > 0]
active_atoms = active_atoms.tolist()
atom_count_list = atom_count_list.tolist()
bond_count_list = np.array(bond_count_list) / np.sum(bond_count_list)
bond_count_list = bond_count_list.tolist()
valencies = np.array(valencies) / np.sum(valencies)
valencies = valencies.tolist()
no_edge = np.sum(tansition_E, axis=-1) == 0
first_elt = tansition_E[:, :, 0]
first_elt[no_edge] = 1
tansition_E[:, :, 0] = first_elt
tansition_E = tansition_E / np.sum(tansition_E, axis=-1, keepdims=True)
meta_dict = {
'source': source_name,
'num_graph': len(n_atom_list),
'n_atoms_per_mol_dist': n_atoms_per_mol,
'max_node': max(n_atom_list),
'max_bond': max(n_bond_list),
'atom_type_dist': atom_count_list,
'bond_type_dist': bond_count_list,
'valencies': valencies,
'active_atoms': active_atoms,
'num_atom_type': len(active_atoms),
'transition_E': tansition_E.tolist(),
}
with open(f'{root}/{source_name}.meta.json', "w") as f:
json.dump(meta_dict, f)
return meta_dict
if __name__ == "__main__":
pass


@ -0,0 +1,224 @@
import torch
from torch.nn import functional as F
import numpy as np
from utils import PlaceHolder
def sum_except_batch(x):
return x.reshape(x.size(0), -1).sum(dim=-1)
def assert_correctly_masked(variable, node_mask):
assert (variable * (1 - node_mask.long())).abs().max().item() < 1e-4, \
'Variables not masked properly.'
def cosine_beta_schedule_discrete(timesteps, s=0.008):
""" Cosine schedule as proposed in https://openreview.net/forum?id=-NEXDKk8gZ. """
steps = timesteps + 2
x = np.linspace(0, steps, steps)
alphas_cumprod = np.cos(0.5 * np.pi * ((x / steps) + s) / (1 + s)) ** 2
alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
alphas = (alphas_cumprod[1:] / alphas_cumprod[:-1])
betas = 1 - alphas
return betas.squeeze()
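# Quick sanity sketch (not part of the original code): for the default 500-step setup,
# cosine_beta_schedule_discrete(500) returns a NumPy array of length 501 (timesteps + 1) with values in (0, 1].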
def custom_beta_schedule_discrete(timesteps, average_num_nodes=30, s=0.008):
""" Cosine schedule as proposed in https://openreview.net/forum?id=-NEXDKk8gZ. """
steps = timesteps + 2
x = np.linspace(0, steps, steps)
alphas_cumprod = np.cos(0.5 * np.pi * ((x / steps) + s) / (1 + s)) ** 2
alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
alphas = (alphas_cumprod[1:] / alphas_cumprod[:-1])
betas = 1 - alphas
assert timesteps >= 100
p = 4 / 5 # 1 - 1 / num_edge_classes
num_edges = average_num_nodes * (average_num_nodes - 1) / 2
# First 100 steps: only a few updates per graph
updates_per_graph = 1.2
beta_first = updates_per_graph / (p * num_edges)
betas[betas < beta_first] = beta_first
return np.array(betas)
def check_mask_correct(variables, node_mask):
for i, variable in enumerate(variables):
if len(variable) > 0:
assert_correctly_masked(variable, node_mask)
def check_tensor_same_size(*args):
for i, arg in enumerate(args):
if i == 0:
continue
assert args[0].size() == arg.size()
def reverse_tensor(x):
return x[torch.arange(x.size(0) - 1, -1, -1)]
def sample_discrete_features(probX, probE, node_mask, step=None, add_nose=True):
''' Sample features from a multinomial distribution with the given probabilities (probX, probE)
:param probX: bs, n, dx_out node features
:param probE: bs, n, n, de_out edge features
'''
bs, n, _ = probX.shape
# Noise X
# The masked rows should define probability distributions as well
probX[~node_mask] = 1 / probX.shape[-1]
# Flatten the probability tensor to sample with multinomial
probX = probX.reshape(bs * n, -1) # (bs * n, dx_out)
# Sample X
probX = probX + 1e-12
probX = probX / probX.sum(dim=-1, keepdim=True)
X_t = probX.multinomial(1) # (bs * n, 1)
X_t = X_t.reshape(bs, n) # (bs, n)
# Noise E
# The masked rows should define probability distributions as well
inverse_edge_mask = ~(node_mask.unsqueeze(1) * node_mask.unsqueeze(2))
diag_mask = torch.eye(n).unsqueeze(0).expand(bs, -1, -1)
probE[inverse_edge_mask] = 1 / probE.shape[-1]
probE[diag_mask.bool()] = 1 / probE.shape[-1]
probE = probE.reshape(bs * n * n, -1) # (bs * n * n, de_out)
probE = probE + 1e-12
probE = probE / probE.sum(dim=-1, keepdim=True)
# Sample E
E_t = probE.multinomial(1).reshape(bs, n, n) # (bs, n, n)
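# keep only the strict upper triangle and mirror it so E_t is symmetric with a zero diagonal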
E_t = torch.triu(E_t, diagonal=1)
E_t = (E_t + torch.transpose(E_t, 1, 2))
return PlaceHolder(X=X_t, E=E_t, y=torch.zeros(bs, 0).type_as(X_t))
def compute_batched_over0_posterior_distribution(X_t, Qt, Qsb, Qtb):
""" M: X or E
Compute xt @ Qt.T * x0 @ Qsb / x0 @ Qtb @ xt.T for each possible value of x0
X_t: bs, n, dt or bs, n, n, dt
Qt: bs, d_t-1, dt
Qsb: bs, d0, d_t-1
Qtb: bs, d0, dt.
"""
X_t = X_t.float()
Qt_T = Qt.transpose(-1, -2).float() # bs, N, dt
assert Qt.dim() == 3
left_term = X_t @ Qt_T
left_term = left_term.unsqueeze(dim=2) # bs, N, 1, d_t-1
right_term = Qsb.unsqueeze(1)
numerator = left_term * right_term # bs, N, d0, d_t-1
denominator = Qtb @ X_t.transpose(-1, -2) # bs, d0, N
denominator = denominator.transpose(-1, -2) # bs, N, d0
denominator = denominator.unsqueeze(-1) # bs, N, d0, 1
denominator[denominator == 0] = 1.
return numerator / denominator
def mask_distributions(true_X, true_E, pred_X, pred_E, node_mask):
# Add a small value everywhere to avoid nans
pred_X = pred_X.clamp_min(1e-18)
pred_X = pred_X / torch.sum(pred_X, dim=-1, keepdim=True)
pred_E = pred_E.clamp_min(1e-18)
pred_E = pred_E / torch.sum(pred_E, dim=-1, keepdim=True)
# Set masked rows to arbitrary distributions, so it doesn't contribute to loss
row_X = torch.ones(true_X.size(-1), dtype=true_X.dtype, device=true_X.device)
row_E = torch.zeros(true_E.size(-1), dtype=true_E.dtype, device=true_E.device).clamp_min(1e-18)
row_E[0] = 1.
diag_mask = ~torch.eye(node_mask.size(1), device=node_mask.device, dtype=torch.bool).unsqueeze(0)
true_X[~node_mask] = row_X
true_E[~(node_mask.unsqueeze(1) * node_mask.unsqueeze(2) * diag_mask), :] = row_E
pred_X[~node_mask] = row_X
pred_E[~(node_mask.unsqueeze(1) * node_mask.unsqueeze(2) * diag_mask), :] = row_E
return true_X, true_E, pred_X, pred_E
def posterior_distributions(X, X_t, Qt, Qsb, Qtb, X_dim):
bs, n, d = X.shape
X = X.float()
Qt_X_T = torch.transpose(Qt.X, -2, -1).float() # (bs, d, d)
left_term = X_t @ Qt_X_T # (bs, N, d)
right_term = X @ Qsb.X # (bs, N, d)
numerator = left_term * right_term # (bs, N, d)
denominator = X @ Qtb.X # (bs, N, d) @ (bs, d, d) = (bs, N, d)
denominator = denominator * X_t
num_X = numerator[:, :, :X_dim]
num_E = numerator[:, :, X_dim:].reshape(bs, n*n, -1)
deno_X = denominator[:, :, :X_dim]
deno_E = denominator[:, :, X_dim:].reshape(bs, n*n, -1)
# denominator = (denominator * X_t).sum(dim=-1) # (bs, N, d) * (bs, N, d) + sum = (bs, N)
denominator = denominator.unsqueeze(-1) # (bs, N, 1)
deno_X = deno_X.sum(dim=-1).unsqueeze(-1)
deno_E = deno_E.sum(dim=-1).unsqueeze(-1)
deno_X[deno_X == 0.] = 1
deno_E[deno_E == 0.] = 1
prob_X = num_X / deno_X
prob_E = num_E / deno_E
prob_E = prob_E / prob_E.sum(dim=-1, keepdim=True)
prob_X = prob_X / prob_X.sum(dim=-1, keepdim=True)
return PlaceHolder(X=prob_X, E=prob_E, y=None)
def sample_discrete_feature_noise(limit_dist, node_mask):
""" Sample from the limit distribution of the diffusion process"""
bs, n_max = node_mask.shape
x_limit = limit_dist.X[None, None, :].expand(bs, n_max, -1)
U_X = x_limit.flatten(end_dim=-2).multinomial(1).reshape(bs, n_max)
U_X = F.one_hot(U_X.long(), num_classes=x_limit.shape[-1]).float()
e_limit = limit_dist.E[None, None, None, :].expand(bs, n_max, n_max, -1)
U_E = e_limit.flatten(end_dim=-2).multinomial(1).reshape(bs, n_max, n_max)
U_E = F.one_hot(U_E.long(), num_classes=e_limit.shape[-1]).float()
U_X = U_X.to(node_mask.device)
U_E = U_E.to(node_mask.device)
# Get upper triangular part of edge noise, without main diagonal
upper_triangular_mask = torch.zeros_like(U_E)
indices = torch.triu_indices(row=U_E.size(1), col=U_E.size(2), offset=1)
upper_triangular_mask[:, indices[0], indices[1], :] = 1
U_E = U_E * upper_triangular_mask
U_E = (U_E + torch.transpose(U_E, 1, 2))
assert (U_E == torch.transpose(U_E, 1, 2)).all()
return PlaceHolder(X=U_X, E=U_E, y=None).mask(node_mask)
def index_QE(X, q_e, n_bond=5):
bs, n, n_atom = X.shape
node_indices = X.argmax(-1) # (bs, n)
exp_ind1 = node_indices[ :, :, None, None, None].expand(bs, n, n_atom, n_bond, n_bond)
exp_ind2 = node_indices[ :, :, None, None, None].expand(bs, n, n, n_bond, n_bond)
q_e = torch.gather(q_e, 1, exp_ind1)
q_e = torch.gather(q_e, 2, exp_ind2) # (bs, n, n, n_bond, n_bond)
node_mask = X.sum(-1) != 0
no_edge = (~node_mask)[:, :, None] & (~node_mask)[:, None, :]
q_e[no_edge] = torch.tensor([1, 0, 0, 0, 0]).type_as(q_e)
return q_e


@ -0,0 +1,30 @@
import torch
class DistributionNodes:
def __init__(self, histogram):
""" Compute the distribution of the number of nodes in the dataset, and sample from this distribution.
histogram: dict. The keys are num_nodes, the values are counts
"""
if type(histogram) == dict:
max_n_nodes = max(histogram.keys())
prob = torch.zeros(max_n_nodes + 1)
for num_nodes, count in histogram.items():
prob[num_nodes] = count
else:
prob = histogram
self.prob = prob / prob.sum()
self.m = torch.distributions.Categorical(prob)
def sample_n(self, n_samples, device):
idx = self.m.sample((n_samples,))
return idx.to(device)
def log_prob(self, batch_n_nodes):
assert len(batch_n_nodes.size()) == 1
p = self.prob.to(batch_n_nodes.device)
probas = p[batch_n_nodes]
log_p = torch.log(probas + 1e-30)
return log_p
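
A minimal usage sketch for DistributionNodes; the histogram counts below are made up for illustration:

hist = {8: 10, 9: 25, 10: 5}            # num_nodes -> count observed in a dataset
dist = DistributionNodes(hist)
n = dist.sample_n(4, device='cpu')      # LongTensor holding 4 sampled node counts
logp = dist.log_prob(n)                 # log-probability of each sampled count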


@ -0,0 +1,159 @@
import torch
import utils
from diffusion import diffusion_utils
class PredefinedNoiseScheduleDiscrete(torch.nn.Module):
def __init__(self, noise_schedule, timesteps):
super(PredefinedNoiseScheduleDiscrete, self).__init__()
self.timesteps = timesteps
if noise_schedule == 'cosine':
betas = diffusion_utils.cosine_beta_schedule_discrete(timesteps)
elif noise_schedule == 'custom':
betas = diffusion_utils.custom_beta_schedule_discrete(timesteps)
else:
raise NotImplementedError(noise_schedule)
self.register_buffer('betas', torch.from_numpy(betas).float())
# 0.9999
self.alphas = 1 - torch.clamp(self.betas, min=0, max=1)
log_alpha = torch.log(self.alphas)
log_alpha_bar = torch.cumsum(log_alpha, dim=0)
self.alphas_bar = torch.exp(log_alpha_bar)
def forward(self, t_normalized=None, t_int=None):
assert int(t_normalized is None) + int(t_int is None) == 1
if t_int is None:
t_int = torch.round(t_normalized * self.timesteps)
return self.betas[t_int.long()]
def get_alpha_bar(self, t_normalized=None, t_int=None):
assert int(t_normalized is None) + int(t_int is None) == 1
if t_int is None:
t_int = torch.round(t_normalized * self.timesteps)
### new
self.alphas_bar = self.alphas_bar.to(t_int.device)
return self.alphas_bar[t_int.long()]
class DiscreteUniformTransition:
def __init__(self, x_classes: int, e_classes: int, y_classes: int):
self.X_classes = x_classes
self.E_classes = e_classes
self.y_classes = y_classes
self.u_x = torch.ones(1, self.X_classes, self.X_classes)
if self.X_classes > 0:
self.u_x = self.u_x / self.X_classes
self.u_e = torch.ones(1, self.E_classes, self.E_classes)
if self.E_classes > 0:
self.u_e = self.u_e / self.E_classes
self.u_y = torch.ones(1, self.y_classes, self.y_classes)
if self.y_classes > 0:
self.u_y = self.u_y / self.y_classes
def get_Qt(self, beta_t, device, X=None, flatten_e=None):
""" Returns one-step transition matrices for X and E, from step t - 1 to step t.
Qt = (1 - beta_t) * I + beta_t / K
beta_t: (bs) noise level between 0 and 1
returns: qx (bs, dx, dx), qe (bs, de, de), qy (bs, dy, dy).
"""
beta_t = beta_t.unsqueeze(1)
beta_t = beta_t.to(device)
self.u_x = self.u_x.to(device)
self.u_e = self.u_e.to(device)
self.u_y = self.u_y.to(device)
q_x = beta_t * self.u_x + (1 - beta_t) * torch.eye(self.X_classes, device=device).unsqueeze(0)
q_e = beta_t * self.u_e + (1 - beta_t) * torch.eye(self.E_classes, device=device).unsqueeze(0)
q_y = beta_t * self.u_y + (1 - beta_t) * torch.eye(self.y_classes, device=device).unsqueeze(0)
return utils.PlaceHolder(X=q_x, E=q_e, y=q_y)
def get_Qt_bar(self, alpha_bar_t, device, X=None, flatten_e=None):
""" Returns t-step transition matrices for X and E, from step 0 to step t.
Qt = prod(1 - beta_t) * I + (1 - prod(1 - beta_t)) / K
alpha_bar_t: (bs) Product of the (1 - beta_t) for each time step from 0 to t.
returns: qx (bs, dx, dx), qe (bs, de, de), qy (bs, dy, dy).
"""
alpha_bar_t = alpha_bar_t.unsqueeze(1)
alpha_bar_t = alpha_bar_t.to(device)
self.u_x = self.u_x.to(device)
self.u_e = self.u_e.to(device)
self.u_y = self.u_y.to(device)
q_x = alpha_bar_t * torch.eye(self.X_classes, device=device).unsqueeze(0) + (1 - alpha_bar_t) * self.u_x
q_e = alpha_bar_t * torch.eye(self.E_classes, device=device).unsqueeze(0) + (1 - alpha_bar_t) * self.u_e
q_y = alpha_bar_t * torch.eye(self.y_classes, device=device).unsqueeze(0) + (1 - alpha_bar_t) * self.u_y
return utils.PlaceHolder(X=q_x, E=q_e, y=q_y)
class MarginalTransition:
def __init__(self, x_marginals, e_marginals, xe_conditions, ex_conditions, y_classes, n_nodes):
self.X_classes = len(x_marginals)
self.E_classes = len(e_marginals)
self.y_classes = y_classes
self.x_marginals = x_marginals # Dx
self.e_marginals = e_marginals # Dx, De
self.xe_conditions = xe_conditions
self.u_x = x_marginals.unsqueeze(0).expand(self.X_classes, -1).unsqueeze(0) # 1, Dx, Dx
self.u_e = e_marginals.unsqueeze(0).expand(self.E_classes, -1).unsqueeze(0) # 1, De, De
self.u_xe = xe_conditions.unsqueeze(0) # 1, Dx, De
self.u_ex = ex_conditions.unsqueeze(0) # 1, De, Dx
self.u = self.get_union_transition(self.u_x, self.u_e, self.u_xe, self.u_ex, n_nodes) # 1, Dx + n*De, Dx + n*De
def get_union_transition(self, u_x, u_e, u_xe, u_ex, n_nodes):
u_e = u_e.repeat(1, n_nodes, n_nodes) # (1, n*de, n*de)
u_xe = u_xe.repeat(1, 1, n_nodes) # (1, dx, n*de)
u_ex = u_ex.repeat(1, n_nodes, 1) # (1, n*de, dx)
u0 = torch.cat([u_x, u_xe], dim=2) # (1, dx, dx + n*de)
u1 = torch.cat([u_ex, u_e], dim=2) # (1, n*de, dx + n*de)
u = torch.cat([u0, u1], dim=1) # (1, dx + n*de, dx + n*de)
return u
def index_edge_margin(self, X, q_e, n_bond=5):
# q_e: (bs, dx, de) --> (bs, n, de)
bs, n, n_atom = X.shape
node_indices = X.argmax(-1) # (bs, n)
ind = node_indices[ :, :, None].expand(bs, n, n_bond)
q_e = torch.gather(q_e, 1, ind)
return q_e
def get_Qt(self, beta_t, device):
""" Returns one-step transition matrices for X and E, from step t - 1 to step t.
Qt = (1 - beta_t) * I + beta_t / K
beta_t: (bs)
returns: q (bs, d0, d0)
"""
bs = beta_t.size(0)
d0 = self.u.size(-1)
self.u = self.u.to(device)
u = self.u.expand(bs, d0, d0)
beta_t = beta_t.to(device)
beta_t = beta_t.view(bs, 1, 1)
q = beta_t * u + (1 - beta_t) * torch.eye(d0, device=device).unsqueeze(0)
return utils.PlaceHolder(X=q, E=None, y=None)
def get_Qt_bar(self, alpha_bar_t, device):
""" Returns t-step transition matrices for X and E, from step 0 to step t.
Qt = prod(1 - beta_t) * I + (1 - prod(1 - beta_t)) * K
alpha_bar_t: (bs, 1) Product of the (1 - beta_t) for each time step from 0 to t.
returns: q (bs, d0, d0)
"""
bs = alpha_bar_t.size(0)
d0 = self.u.size(-1)
alpha_bar_t = alpha_bar_t.to(device)
alpha_bar_t = alpha_bar_t.view(bs, 1, 1)
self.u = self.u.to(device)
q = alpha_bar_t * torch.eye(d0, device=device).unsqueeze(0) + (1 - alpha_bar_t) * self.u
return utils.PlaceHolder(X=q, E=None, y=None)

mcd/diffusion_model.py Normal file

@ -0,0 +1,619 @@
import torch
import torch.nn.functional as F
import pytorch_lightning as pl
import time
import os
from models.transformer import Denoiser
from diffusion.noise_schedule import PredefinedNoiseScheduleDiscrete, MarginalTransition
from diffusion import diffusion_utils
from metrics.train_loss import TrainLossDiscrete
from metrics.abstract_metrics import SumExceptBatchMetric, SumExceptBatchKL, NLL
import utils
class MCD(pl.LightningModule):
def __init__(self, cfg, dataset_infos, train_metrics, sampling_metrics, visualization_tools):
super().__init__()
self.save_hyperparameters(ignore=['train_metrics', 'sampling_metrics'])
self.test_only = cfg.general.test_only
self.guidance_target = getattr(cfg.dataset, 'guidance_target', None)
input_dims = dataset_infos.input_dims
output_dims = dataset_infos.output_dims
nodes_dist = dataset_infos.nodes_dist
active_index = dataset_infos.active_index
self.cfg = cfg
self.name = cfg.general.name
self.T = cfg.model.diffusion_steps
self.guide_scale = cfg.model.guide_scale
self.Xdim = input_dims['X']
self.Edim = input_dims['E']
self.ydim = input_dims['y']
self.Xdim_output = output_dims['X']
self.Edim_output = output_dims['E']
self.ydim_output = output_dims['y']
self.node_dist = nodes_dist
self.active_index = active_index
self.dataset_info = dataset_infos
self.train_loss = TrainLossDiscrete(self.cfg.model.lambda_train)
self.val_nll = NLL()
self.val_X_kl = SumExceptBatchKL()
self.val_E_kl = SumExceptBatchKL()
self.val_X_logp = SumExceptBatchMetric()
self.val_E_logp = SumExceptBatchMetric()
self.val_y_collection = []
self.test_nll = NLL()
self.test_X_kl = SumExceptBatchKL()
self.test_E_kl = SumExceptBatchKL()
self.test_X_logp = SumExceptBatchMetric()
self.test_E_logp = SumExceptBatchMetric()
self.test_y_collection = []
self.train_metrics = train_metrics
self.sampling_metrics = sampling_metrics
self.visualization_tools = visualization_tools
self.max_n_nodes = dataset_infos.max_n_nodes
self.model = Denoiser(max_n_nodes=self.max_n_nodes,
hidden_size=cfg.model.hidden_size,
depth=cfg.model.depth,
num_heads=cfg.model.num_heads,
mlp_ratio=cfg.model.mlp_ratio,
drop_condition=cfg.model.drop_condition,
Xdim=self.Xdim,
Edim=self.Edim,
ydim=self.ydim,
task_type=dataset_infos.task_type)
self.noise_schedule = PredefinedNoiseScheduleDiscrete(cfg.model.diffusion_noise_schedule,
timesteps=cfg.model.diffusion_steps)
x_marginals = self.dataset_info.node_types.float() / torch.sum(self.dataset_info.node_types.float())
e_marginals = self.dataset_info.edge_types.float() / torch.sum(self.dataset_info.edge_types.float())
x_marginals = x_marginals / x_marginals.sum()
e_marginals = e_marginals / e_marginals.sum()
xe_conditions = self.dataset_info.transition_E.float()
xe_conditions = xe_conditions[self.active_index][:, self.active_index]
xe_conditions = xe_conditions.sum(dim=1)
ex_conditions = xe_conditions.t()
xe_conditions = xe_conditions / xe_conditions.sum(dim=-1, keepdim=True)
ex_conditions = ex_conditions / ex_conditions.sum(dim=-1, keepdim=True)
self.transition_model = MarginalTransition(x_marginals=x_marginals,
e_marginals=e_marginals,
xe_conditions=xe_conditions,
ex_conditions=ex_conditions,
y_classes=self.ydim_output,
n_nodes=self.max_n_nodes)
self.limit_dist = utils.PlaceHolder(X=x_marginals, E=e_marginals, y=None)
self.start_epoch_time = None
self.train_iterations = None
self.val_iterations = None
self.log_every_steps = cfg.general.log_every_steps
self.number_chain_steps = cfg.general.number_chain_steps
self.best_val_nll = 1e8
self.val_counter = 0
self.batch_size = self.cfg.train.batch_size
def forward(self, noisy_data, unconditioned=False):
x, e, y = noisy_data['X_t'].float(), noisy_data['E_t'].float(), noisy_data['y_t'].float().clone()
node_mask, t = noisy_data['node_mask'], noisy_data['t']
pred = self.model(x, e, node_mask, y=y, t=t, unconditioned=unconditioned)
return pred
def training_step(self, data, i):
data_x = F.one_hot(data.x, num_classes=118).float()[:, self.active_index]
data_edge_attr = F.one_hot(data.edge_attr, num_classes=5).float()
dense_data, node_mask = utils.to_dense(data_x, data.edge_index, data_edge_attr, data.batch, self.max_n_nodes)
dense_data = dense_data.mask(node_mask)
X, E = dense_data.X, dense_data.E
noisy_data = self.apply_noise(X, E, data.y, node_mask)
pred = self.forward(noisy_data)
loss = self.train_loss(masked_pred_X=pred.X, masked_pred_E=pred.E, pred_y=pred.y,
true_X=X, true_E=E, true_y=data.y, node_mask=node_mask,
log=i % self.log_every_steps == 0)
self.train_metrics(masked_pred_X=pred.X, masked_pred_E=pred.E, true_X=X, true_E=E,
log=i % self.log_every_steps == 0)
self.log(f'loss', loss, batch_size=X.size(0), sync_dist=True)
return {'loss': loss}
def configure_optimizers(self):
params = self.parameters()
optimizer = torch.optim.AdamW(params, lr=self.cfg.train.lr, amsgrad=True,
weight_decay=self.cfg.train.weight_decay)
return optimizer
def on_fit_start(self) -> None:
self.train_iterations = self.trainer.datamodule.training_iterations
print('on fit train iteration:', self.train_iterations)
print("Size of the input features Xdim {}, Edim {}, ydim {}".format(self.Xdim, self.Edim, self.ydim))
def on_train_epoch_start(self) -> None:
if self.current_epoch / self.trainer.max_epochs in [0.25, 0.5, 0.75, 1.0]:
print("Starting train epoch {}/{}...".format(self.current_epoch, self.trainer.max_epochs))
self.start_epoch_time = time.time()
self.train_loss.reset()
self.train_metrics.reset()
def on_train_epoch_end(self) -> None:
if self.current_epoch / self.trainer.max_epochs in [0.25, 0.5, 0.75, 1.0]:
log = True
else:
log = False
self.train_loss.log_epoch_metrics(self.current_epoch, self.start_epoch_time, log)
self.train_metrics.log_epoch_metrics(self.current_epoch, log)
def on_validation_epoch_start(self) -> None:
self.val_nll.reset()
self.val_X_kl.reset()
self.val_E_kl.reset()
self.val_X_logp.reset()
self.val_E_logp.reset()
self.sampling_metrics.reset()
self.val_y_collection = []
@torch.no_grad()
def validation_step(self, data, i):
data_x = F.one_hot(data.x, num_classes=118).float()[:, self.active_index]
data_edge_attr = F.one_hot(data.edge_attr, num_classes=5).float()
dense_data, node_mask = utils.to_dense(data_x, data.edge_index, data_edge_attr, data.batch, self.max_n_nodes)
dense_data = dense_data.mask(node_mask)
noisy_data = self.apply_noise(dense_data.X, dense_data.E, data.y, node_mask)
pred = self.forward(noisy_data)
nll = self.compute_val_loss(pred, noisy_data, dense_data.X, dense_data.E, data.y, node_mask, test=False)
self.val_y_collection.append(data.y)
self.log(f'valid_nll', nll, batch_size=data.x.size(0), sync_dist=True)
return {'loss': nll}
def on_validation_epoch_end(self) -> None:
metrics = [self.val_nll.compute(), self.val_X_kl.compute() * self.T, self.val_E_kl.compute() * self.T,
self.val_X_logp.compute(), self.val_E_logp.compute()]
if self.current_epoch / self.trainer.max_epochs in [0.25, 0.5, 0.75, 1.0]:
print(f"Epoch {self.current_epoch}: Val NLL {metrics[0] :.2f} -- Val Atom type KL {metrics[1] :.2f} -- ",
f"Val Edge type KL: {metrics[2] :.2f}", 'Val loss: %.2f \t Best : %.2f\n' % (metrics[0], self.best_val_nll))
# Log val nll with default Lightning logger, so it can be monitored by checkpoint callback
self.log("val/NLL", metrics[0], sync_dist=True)
if metrics[0] < self.best_val_nll:
self.best_val_nll = metrics[0]
self.val_counter += 1
if self.val_counter % self.cfg.general.sample_every_val == 0 and self.val_counter > 1:
start = time.time()
samples_left_to_generate = self.cfg.general.samples_to_generate
samples_left_to_save = self.cfg.general.samples_to_save
chains_left_to_save = self.cfg.general.chains_to_save
samples, all_ys, ident = [], [], 0
self.val_y_collection = torch.cat(self.val_y_collection, dim=0)
num_examples = self.val_y_collection.size(0)
start_index = 0
while samples_left_to_generate > 0:
bs = 1 * self.cfg.train.batch_size
to_generate = min(samples_left_to_generate, bs)
to_save = min(samples_left_to_save, bs)
chains_save = min(chains_left_to_save, bs)
if start_index + to_generate > num_examples:
start_index = 0
if to_generate > num_examples:
ratio = to_generate // num_examples
self.val_y_collection = self.val_y_collection.repeat(ratio+1, 1)
num_examples = self.val_y_collection.size(0)
batch_y = self.val_y_collection[start_index:start_index + to_generate]
all_ys.append(batch_y)
samples.extend(self.sample_batch(batch_id=ident, batch_size=to_generate, y=batch_y,
save_final=to_save,
keep_chain=chains_save,
number_chain_steps=self.number_chain_steps))
ident += to_generate
start_index += to_generate
samples_left_to_save -= to_save
samples_left_to_generate -= to_generate
chains_left_to_save -= chains_save
print(f"Computing sampling metrics", ' ...')
valid_smiles = self.sampling_metrics(samples, all_ys, self.name, self.current_epoch, val_counter=-1, test=False)
print(f'Done. Sampling took {time.time() - start:.2f} seconds\n')
current_path = os.getcwd()
result_path = os.path.join(current_path,
f'graphs/{self.name}/epoch{self.current_epoch}_b0/')
self.visualization_tools.visualize_by_smiles(result_path, valid_smiles, self.cfg.general.samples_to_save)
self.sampling_metrics.reset()
def on_test_epoch_start(self) -> None:
print("Starting test...")
self.test_nll.reset()
self.test_X_kl.reset()
self.test_E_kl.reset()
self.test_X_logp.reset()
self.test_E_logp.reset()
self.test_y_collection = []
@torch.no_grad()
def test_step(self, data, i):
data_x = F.one_hot(data.x, num_classes=118).float()[:, self.active_index]
data_edge_attr = F.one_hot(data.edge_attr, num_classes=5).float()
dense_data, node_mask = utils.to_dense(data_x, data.edge_index, data_edge_attr, data.batch, self.max_n_nodes)
dense_data = dense_data.mask(node_mask)
noisy_data = self.apply_noise(dense_data.X, dense_data.E, data.y, node_mask)
pred = self.forward(noisy_data)
nll = self.compute_val_loss(pred, noisy_data, dense_data.X, dense_data.E, data.y, node_mask, test=True)
self.test_y_collection.append(data.y)
return {'loss': nll}
def on_test_epoch_end(self) -> None:
""" Measure likelihood on a test set and compute stability metrics. """
metrics = [self.test_nll.compute(), self.test_X_kl.compute(), self.test_E_kl.compute(),
self.test_X_logp.compute(), self.test_E_logp.compute()]
print(f"Epoch {self.current_epoch}: Test NLL {metrics[0] :.2f} -- Test Atom type KL {metrics[1] :.2f} -- ",
f"Test Edge type KL: {metrics[2] :.2f}")
## final epoch
samples_left_to_generate = self.cfg.general.final_model_samples_to_generate
samples_left_to_save = self.cfg.general.final_model_samples_to_save
chains_left_to_save = self.cfg.general.final_model_chains_to_save
samples, all_ys, batch_id = [], [], 0
test_y_collection = torch.cat(self.test_y_collection, dim=0)
num_examples = test_y_collection.size(0)
if self.cfg.general.final_model_samples_to_generate > num_examples:
ratio = self.cfg.general.final_model_samples_to_generate // num_examples
test_y_collection = test_y_collection.repeat(ratio+1, 1)
num_examples = test_y_collection.size(0)
while samples_left_to_generate > 0:
print(f'samples left to generate: {samples_left_to_generate}/'
f'{self.cfg.general.final_model_samples_to_generate}', end='', flush=True)
bs = 1 * self.cfg.train.batch_size
to_generate = min(samples_left_to_generate, bs)
to_save = min(samples_left_to_save, bs)
chains_save = min(chains_left_to_save, bs)
batch_y = test_y_collection[batch_id : batch_id + to_generate]
cur_sample = self.sample_batch(batch_id, to_generate, batch_y, save_final=to_save,
keep_chain=chains_save, number_chain_steps=self.number_chain_steps)
samples = samples + cur_sample
all_ys.append(batch_y)
batch_id += to_generate
samples_left_to_save -= to_save
samples_left_to_generate -= to_generate
chains_left_to_save -= chains_save
print(f"final Computing sampling metrics...")
self.sampling_metrics.reset()
self.sampling_metrics(samples, all_ys, self.name, self.current_epoch, self.val_counter, test=True)
self.sampling_metrics.reset()
print(f"Done.")
def kl_prior(self, X, E, node_mask):
"""Computes the KL between q(z1 | x) and the prior p(z1) = Normal(0, 1).
This is essentially a lot of work for something that is in practice negligible in the loss. However, you
compute it so that you see it when you've made a mistake in your noise schedule.
"""
# Compute the last alpha value, alpha_T.
ones = torch.ones((X.size(0), 1), device=X.device)
Ts = self.T * ones
alpha_t_bar = self.noise_schedule.get_alpha_bar(t_int=Ts) # (bs, 1)
Qtb = self.transition_model.get_Qt_bar(alpha_t_bar, self.device)
bs, n, d = X.shape
X_all = torch.cat([X, E.reshape(bs, n, -1)], dim=-1)
prob_all = X_all @ Qtb.X
probX = prob_all[:, :, :self.Xdim_output]
probE = prob_all[:, :, self.Xdim_output:].reshape((bs, n, n, -1))
assert probX.shape == X.shape
limit_X = self.limit_dist.X[None, None, :].expand(bs, n, -1).type_as(probX)
limit_E = self.limit_dist.E[None, None, None, :].expand(bs, n, n, -1).type_as(probE)
# Make sure that masked rows do not contribute to the loss
limit_dist_X, limit_dist_E, probX, probE = diffusion_utils.mask_distributions(true_X=limit_X.clone(),
true_E=limit_E.clone(),
pred_X=probX,
pred_E=probE,
node_mask=node_mask)
kl_distance_X = F.kl_div(input=probX.log(), target=limit_dist_X, reduction='none')
kl_distance_E = F.kl_div(input=probE.log(), target=limit_dist_E, reduction='none')
return diffusion_utils.sum_except_batch(kl_distance_X) + \
diffusion_utils.sum_except_batch(kl_distance_E)
def compute_Lt(self, X, E, y, pred, noisy_data, node_mask, test):
pred_probs_X = F.softmax(pred.X, dim=-1)
pred_probs_E = F.softmax(pred.E, dim=-1)
Qtb = self.transition_model.get_Qt_bar(noisy_data['alpha_t_bar'], self.device)
Qsb = self.transition_model.get_Qt_bar(noisy_data['alpha_s_bar'], self.device)
Qt = self.transition_model.get_Qt(noisy_data['beta_t'], self.device)
# Compute distributions to compare with KL
bs, n, d = X.shape
X_all = torch.cat([X, E.reshape(bs, n, -1)], dim=-1).float()
Xt_all = torch.cat([noisy_data['X_t'], noisy_data['E_t'].reshape(bs, n, -1)], dim=-1).float()
pred_probs_all = torch.cat([pred_probs_X, pred_probs_E.reshape(bs, n, -1)], dim=-1).float()
prob_true = diffusion_utils.posterior_distributions(X=X_all, X_t=Xt_all, Qt=Qt, Qsb=Qsb, Qtb=Qtb, X_dim=self.Xdim_output)
prob_true.E = prob_true.E.reshape((bs, n, n, -1))
prob_pred = diffusion_utils.posterior_distributions(X=pred_probs_all, X_t=Xt_all, Qt=Qt, Qsb=Qsb, Qtb=Qtb, X_dim=self.Xdim_output)
prob_pred.E = prob_pred.E.reshape((bs, n, n, -1))
# Reshape and filter masked rows
prob_true_X, prob_true_E, prob_pred.X, prob_pred.E = diffusion_utils.mask_distributions(true_X=prob_true.X,
true_E=prob_true.E,
pred_X=prob_pred.X,
pred_E=prob_pred.E,
node_mask=node_mask)
kl_x = (self.test_X_kl if test else self.val_X_kl)(prob_true.X, torch.log(prob_pred.X))
kl_e = (self.test_E_kl if test else self.val_E_kl)(prob_true.E, torch.log(prob_pred.E))
return self.T * (kl_x + kl_e)
def reconstruction_logp(self, t, X, E, y, node_mask):
# Compute noise values for t = 0.
t_zeros = torch.zeros_like(t)
beta_0 = self.noise_schedule(t_zeros)
Q0 = self.transition_model.get_Qt(beta_0, self.device)
bs, n, d = X.shape
X_all = torch.cat([X, E.reshape(bs, n, -1)], dim=-1)
prob_all = X_all @ Q0.X
probX0 = prob_all[:, :, :self.Xdim_output]
probE0 = prob_all[:, :, self.Xdim_output:].reshape((bs, n, n, -1))
sampled0 = diffusion_utils.sample_discrete_features(probX=probX0, probE=probE0, node_mask=node_mask)
X0 = F.one_hot(sampled0.X, num_classes=self.Xdim_output).float()
E0 = F.one_hot(sampled0.E, num_classes=self.Edim_output).float()
assert (X.shape == X0.shape) and (E.shape == E0.shape)
sampled_0 = utils.PlaceHolder(X=X0, E=E0, y=y).mask(node_mask)
# Predictions
noisy_data = {'X_t': sampled_0.X, 'E_t': sampled_0.E, 'y_t': sampled_0.y, 'node_mask': node_mask,
't': torch.zeros(X0.shape[0], 1).type_as(y)}
pred0 = self.forward(noisy_data)
# Normalize predictions
probX0 = F.softmax(pred0.X, dim=-1)
probE0 = F.softmax(pred0.E, dim=-1)
proby0 = None
# Set masked rows to arbitrary values that don't contribute to loss
probX0[~node_mask] = torch.ones(self.Xdim_output).type_as(probX0)
probE0[~(node_mask.unsqueeze(1) * node_mask.unsqueeze(2))] = torch.ones(self.Edim_output).type_as(probE0)
diag_mask = torch.eye(probE0.size(1)).type_as(probE0).bool()
diag_mask = diag_mask.unsqueeze(0).expand(probE0.size(0), -1, -1)
probE0[diag_mask] = torch.ones(self.Edim_output).type_as(probE0)
return utils.PlaceHolder(X=probX0, E=probE0, y=proby0)
def apply_noise(self, X, E, y, node_mask):
""" Sample noise and apply it to the data. """
# Sample a timestep t.
# When evaluating, the loss for t=0 is computed separately
lowest_t = 0 if self.training else 1
t_int = torch.randint(lowest_t, self.T + 1, size=(X.size(0), 1), device=X.device).float() # (bs, 1)
s_int = t_int - 1
t_float = t_int / self.T
s_float = s_int / self.T
# beta_t and alpha_s_bar are used for denoising/loss computation
beta_t = self.noise_schedule(t_normalized=t_float) # (bs, 1)
alpha_s_bar = self.noise_schedule.get_alpha_bar(t_normalized=s_float) # (bs, 1)
alpha_t_bar = self.noise_schedule.get_alpha_bar(t_normalized=t_float) # (bs, 1)
Qtb = self.transition_model.get_Qt_bar(alpha_t_bar, self.device) # (bs, dx_in, dx_out), (bs, de_in, de_out)
bs, n, d = X.shape
X_all = torch.cat([X, E.reshape(bs, n, -1)], dim=-1)
prob_all = X_all @ Qtb.X
probX = prob_all[:, :, :self.Xdim_output]
probE = prob_all[:, :, self.Xdim_output:].reshape(bs, n, n, -1)
sampled_t = diffusion_utils.sample_discrete_features(probX=probX, probE=probE, node_mask=node_mask)
X_t = F.one_hot(sampled_t.X, num_classes=self.Xdim_output)
E_t = F.one_hot(sampled_t.E, num_classes=self.Edim_output)
assert (X.shape == X_t.shape) and (E.shape == E_t.shape)
y_t = y
z_t = utils.PlaceHolder(X=X_t, E=E_t, y=y_t).type_as(X_t).mask(node_mask)
noisy_data = {'t_int': t_int, 't': t_float, 'beta_t': beta_t, 'alpha_s_bar': alpha_s_bar,
'alpha_t_bar': alpha_t_bar, 'X_t': z_t.X, 'E_t': z_t.E, 'y_t': z_t.y, 'node_mask': node_mask}
return noisy_data
def compute_val_loss(self, pred, noisy_data, X, E, y, node_mask, test=False):
"""Computes an estimator for the variational lower bound.
pred: (batch_size, n, total_features)
noisy_data: dict
X, E, y : (bs, n, dx), (bs, n, n, de), (bs, dy)
node_mask : (bs, n)
Output: nll (size 1)
"""
t = noisy_data['t']
# 1. Log probability of the number of nodes
N = node_mask.sum(1).long()
log_pN = self.node_dist.log_prob(N)
# 2. The KL between q(z_T | x) and the limit prior p(z_T). Should be close to zero.
kl_prior = self.kl_prior(X, E, node_mask)
# 3. Diffusion loss
loss_all_t = self.compute_Lt(X, E, y, pred, noisy_data, node_mask, test)
# 4. Reconstruction loss
# Compute L0 term : -log p (X, E, y | z_0) = reconstruction loss
prob0 = self.reconstruction_logp(t, X, E, y, node_mask)
eps = 1e-8
loss_term_0 = self.val_X_logp(X * (prob0.X+eps).log()) + self.val_E_logp(E * (prob0.E+eps).log())
# Combine terms
nlls = - log_pN + kl_prior + loss_all_t - loss_term_0
assert len(nlls.shape) == 1, f'{nlls.shape} has more than only batch dim.'
# Update NLL metric object and return batch nll
nll = (self.test_nll if test else self.val_nll)(nlls) # Average over the batch
return nll
@torch.no_grad()
def sample_batch(self, batch_id, batch_size, y, keep_chain, number_chain_steps, save_final, num_nodes=None):
"""
:param batch_id: int
:param batch_size: int
:param num_nodes: int, <int>tensor (batch_size) (optional) for specifying number of nodes
:param save_final: int: number of predictions to save to file
:param keep_chain: int: number of chains to save to file (disabled)
:param number_chain_steps: number of timesteps to save for each chain (disabled)
:return: molecule_list. Each element of this list is a pair [atom_types, edge_types]
"""
if num_nodes is None:
n_nodes = self.node_dist.sample_n(batch_size, self.device)
elif type(num_nodes) == int:
n_nodes = num_nodes * torch.ones(batch_size, device=self.device, dtype=torch.int)
else:
assert isinstance(num_nodes, torch.Tensor)
n_nodes = num_nodes
n_max = self.max_n_nodes
arange = torch.arange(n_max, device=self.device).unsqueeze(0).expand(batch_size, -1)
node_mask = arange < n_nodes.unsqueeze(1)
z_T = diffusion_utils.sample_discrete_feature_noise(limit_dist=self.limit_dist, node_mask=node_mask)
X, E = z_T.X, z_T.E
assert (E == torch.transpose(E, 1, 2)).all()
# Iteratively sample p(z_s | z_t) for t = 1, ..., T, with s = t - 1.
for s_int in reversed(range(0, self.T)):
s_array = s_int * torch.ones((batch_size, 1)).type_as(y)
t_array = s_array + 1
s_norm = s_array / self.T
t_norm = t_array / self.T
# Sample z_s
sampled_s, discrete_sampled_s = self.sample_p_zs_given_zt(s_norm, t_norm, X, E, y, node_mask)
X, E, y = sampled_s.X, sampled_s.E, sampled_s.y
# Collapse the final one-hot sample to discrete node and edge types
sampled_s = sampled_s.mask(node_mask, collapse=True)
X, E, y = sampled_s.X, sampled_s.E, sampled_s.y
molecule_list = []
for i in range(batch_size):
n = n_nodes[i]
atom_types = X[i, :n].cpu()
edge_types = E[i, :n, :n].cpu()
molecule_list.append([atom_types, edge_types])
return molecule_list
def sample_p_zs_given_zt(self, s, t, X_t, E_t, y_t, node_mask):
"""Samples from zs ~ p(zs | zt). Only used during sampling.
if last_step, return the graph prediction as well"""
bs, n, dxs = X_t.shape
beta_t = self.noise_schedule(t_normalized=t) # (bs, 1)
alpha_s_bar = self.noise_schedule.get_alpha_bar(t_normalized=s)
alpha_t_bar = self.noise_schedule.get_alpha_bar(t_normalized=t)
# Neural net predictions
noisy_data = {'X_t': X_t, 'E_t': E_t, 'y_t': y_t, 't': t, 'node_mask': node_mask}
def get_prob(noisy_data, unconditioned=False):
pred = self.forward(noisy_data, unconditioned=unconditioned)
# Normalize predictions
pred_X = F.softmax(pred.X, dim=-1) # bs, n, d0
pred_E = F.softmax(pred.E, dim=-1) # bs, n, n, d0
# Retrieve transitions matrix
Qtb = self.transition_model.get_Qt_bar(alpha_t_bar, self.device)
Qsb = self.transition_model.get_Qt_bar(alpha_s_bar, self.device)
Qt = self.transition_model.get_Qt(beta_t, self.device)
Xt_all = torch.cat([X_t, E_t.reshape(bs, n, -1)], dim=-1)
p_s_and_t_given_0 = diffusion_utils.compute_batched_over0_posterior_distribution(X_t=Xt_all,
Qt=Qt.X,
Qsb=Qsb.X,
Qtb=Qtb.X)
predX_all = torch.cat([pred_X, pred_E.reshape(bs, n, -1)], dim=-1)
weightedX_all = predX_all.unsqueeze(-1) * p_s_and_t_given_0
unnormalized_probX_all = weightedX_all.sum(dim=2) # bs, n, d_t-1
unnormalized_prob_X = unnormalized_probX_all[:, :, :self.Xdim_output]
unnormalized_prob_E = unnormalized_probX_all[:, :, self.Xdim_output:].reshape(bs, n*n, -1)
unnormalized_prob_X[torch.sum(unnormalized_prob_X, dim=-1) == 0] = 1e-5
unnormalized_prob_E[torch.sum(unnormalized_prob_E, dim=-1) == 0] = 1e-5
prob_X = unnormalized_prob_X / torch.sum(unnormalized_prob_X, dim=-1, keepdim=True) # bs, n, d_t-1
prob_E = unnormalized_prob_E / torch.sum(unnormalized_prob_E, dim=-1, keepdim=True) # bs, n, d_t-1
prob_E = prob_E.reshape(bs, n, n, pred_E.shape[-1])
return prob_X, prob_E
prob_X, prob_E = get_prob(noisy_data)
### Guidance
if self.guidance_target is not None and self.guide_scale is not None and self.guide_scale != 1:
uncon_prob_X, uncon_prob_E = get_prob(noisy_data, unconditioned=True)
prob_X = uncon_prob_X * (prob_X / uncon_prob_X.clamp_min(1e-10)) ** self.guide_scale
prob_E = uncon_prob_E * (prob_E / uncon_prob_E.clamp_min(1e-10)) ** self.guide_scale
prob_X = prob_X / prob_X.sum(dim=-1, keepdim=True).clamp_min(1e-10)
prob_E = prob_E / prob_E.sum(dim=-1, keepdim=True).clamp_min(1e-10)
assert ((prob_X.sum(dim=-1) - 1).abs() < 1e-4).all()
assert ((prob_E.sum(dim=-1) - 1).abs() < 1e-4).all()
sampled_s = diffusion_utils.sample_discrete_features(prob_X, prob_E, node_mask=node_mask, step=s[0,0].item())
X_s = F.one_hot(sampled_s.X, num_classes=self.Xdim_output).float()
E_s = F.one_hot(sampled_s.E, num_classes=self.Edim_output).float()
assert (E_s == torch.transpose(E_s, 1, 2)).all()
assert (X_t.shape == X_s.shape) and (E_t.shape == E_s.shape)
out_one_hot = utils.PlaceHolder(X=X_s, E=E_s, y=y_t)
out_discrete = utils.PlaceHolder(X=X_s, E=E_s, y=y_t)
return out_one_hot.mask(node_mask).type_as(y_t), out_discrete.mask(node_mask, collapse=True).type_as(y_t)
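For reference, the guidance step in sample_p_zs_given_zt mixes the conditional and unconditional predictions as p ∝ p_uncond * (p_cond / p_uncond) ** guide_scale before renormalizing. A tiny stand-alone sketch of that mixing (the probabilities and scale below are illustrative only, not taken from the model):

import torch

guide_scale = 2.0
p_cond = torch.tensor([[0.6, 0.3, 0.1]])
p_uncond = torch.tensor([[0.4, 0.4, 0.2]])
p = p_uncond * (p_cond / p_uncond.clamp_min(1e-10)) ** guide_scale
p = p / p.sum(dim=-1, keepdim=True).clamp_min(1e-10)
print(p)   # guidance sharpens the distribution toward the conditional prediction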

138
mcd/main.py Normal file
View File

@ -0,0 +1,138 @@
# These imports are tricky because they use c++, do not move them
import os, shutil
import warnings
import torch
import hydra
from omegaconf import DictConfig
from pytorch_lightning import Trainer
import utils
from datasets import dataset
from diffusion_model import MCD
from metrics.molecular_metrics_train import TrainMolecularMetricsDiscrete
from metrics.molecular_metrics_sampling import SamplingMolecularMetrics
from analysis.visualization import MolecularVisualization
warnings.filterwarnings("ignore", category=UserWarning)
torch.set_float32_matmul_precision("medium")
def remove_folder(folder):
for filename in os.listdir(folder):
file_path = os.path.join(folder, filename)
try:
if os.path.isfile(file_path) or os.path.islink(file_path):
os.unlink(file_path)
elif os.path.isdir(file_path):
shutil.rmtree(file_path)
except Exception as e:
print("Failed to delete %s. Reason: %s" % (file_path, e))
def get_resume(cfg, model_kwargs):
"""Resumes a run. It loads previous config without allowing to update keys (used for testing)."""
saved_cfg = cfg.copy()
name = cfg.general.name + "_resume"
resume = cfg.general.test_only
batch_size = cfg.train.batch_size
model = MCD.load_from_checkpoint(resume, **model_kwargs)
cfg = model.cfg
cfg.general.test_only = resume
cfg.general.name = name
cfg.train.batch_size = batch_size
cfg = utils.update_config_with_new_keys(cfg, saved_cfg)
return cfg, model
def get_resume_adaptive(cfg, model_kwargs):
"""Resumes a run. It loads previous config but allows to make some changes (used for resuming training)."""
saved_cfg = cfg.copy()
# Fetch path to this file to get base path
current_path = os.path.dirname(os.path.realpath(__file__))
root_dir = current_path.split("outputs")[0]
resume_path = os.path.join(root_dir, cfg.general.resume)
if cfg.model.type == "discrete":
model = MCD.load_from_checkpoint(
resume_path, **model_kwargs
)
else:
raise NotImplementedError("Unknown model")
new_cfg = model.cfg
for category in cfg:
for arg in cfg[category]:
new_cfg[category][arg] = cfg[category][arg]
new_cfg.general.resume = resume_path
new_cfg.general.name = new_cfg.general.name + "_resume"
new_cfg = utils.update_config_with_new_keys(new_cfg, saved_cfg)
return new_cfg, model
@hydra.main(
version_base="1.1", config_path="../configs", config_name="config_dev"
)
def main(cfg: DictConfig):
datamodule = dataset.DataModule(cfg)
datamodule.prepare_data()
dataset_infos = dataset.DataInfos(datamodule=datamodule, cfg=cfg)
train_smiles, reference_smiles = datamodule.get_train_smiles()
dataset_infos.compute_input_output_dims(datamodule=datamodule)
train_metrics = TrainMolecularMetricsDiscrete(dataset_infos)
sampling_metrics = SamplingMolecularMetrics(
dataset_infos, train_smiles, reference_smiles
)
visualization_tools = MolecularVisualization(dataset_infos)
model_kwargs = {
"dataset_infos": dataset_infos,
"train_metrics": train_metrics,
"sampling_metrics": sampling_metrics,
"visualization_tools": visualization_tools,
}
if cfg.general.test_only:
# When testing, previous configuration is fully loaded
cfg, _ = get_resume(cfg, model_kwargs)
os.chdir(cfg.general.test_only.split("checkpoints")[0])
elif cfg.general.resume is not None:
# When resuming, we can override some parts of previous configuration
cfg, _ = get_resume_adaptive(cfg, model_kwargs)
os.chdir(cfg.general.resume.split("checkpoints")[0])
model = MCD(cfg=cfg, **model_kwargs)
trainer = Trainer(
gradient_clip_val=cfg.train.clip_grad,
accelerator="gpu"
if torch.cuda.is_available() and cfg.general.gpus > 0
else "cpu",
devices=cfg.general.gpus
if torch.cuda.is_available() and cfg.general.gpus > 0
else None,
max_epochs=cfg.train.n_epochs,
enable_checkpointing=False,
check_val_every_n_epoch=cfg.train.check_val_every_n_epoch,
val_check_interval=cfg.train.val_check_interval,
strategy="ddp" if cfg.general.gpus > 1 else "auto",
enable_progress_bar=cfg.general.enable_progress_bar,
callbacks=[],
reload_dataloaders_every_n_epochs=0,
logger=[],
)
if not cfg.general.test_only:
trainer.fit(model, datamodule=datamodule, ckpt_path=cfg.general.resume)
if cfg.general.save_model:
trainer.save_checkpoint(f"checkpoints/{cfg.general.name}/last.ckpt")
trainer.test(model, datamodule=datamodule)
else:
trainer.test(model, datamodule=datamodule, ckpt_path=cfg.general.test_only)
if __name__ == "__main__":
main()
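If the Hydra config needs to be composed outside the CLI (e.g., in a notebook sitting next to main.py), something like the sketch below works; it is illustrative only, and the override keys and values are placeholders chosen from the config fields referenced above:

from hydra import initialize, compose

with initialize(version_base="1.1", config_path="../configs"):
    cfg = compose(config_name="config_dev",
                  overrides=["general.gpus=0", "train.n_epochs=1"])
print(cfg.general.name, cfg.train.batch_size)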

0
mcd/metrics/__init__.py Normal file
View File

138
mcd/metrics/abstract_metrics.py Normal file
View File

@ -0,0 +1,138 @@
import torch
from torch import Tensor
from torch.nn import functional as F
from torchmetrics import Metric, MeanSquaredError
class TrainAbstractMetricsDiscrete(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, masked_pred_X, masked_pred_E, true_X, true_E, log: bool):
pass
def reset(self):
pass
def log_epoch_metrics(self, current_epoch):
pass
class TrainAbstractMetrics(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, masked_pred_epsX, masked_pred_epsE, pred_y, true_epsX, true_epsE, true_y, log):
pass
def reset(self):
pass
def log_epoch_metrics(self, current_epoch):
pass
class SumExceptBatchMetric(Metric):
def __init__(self):
super().__init__()
self.add_state('total_value', default=torch.tensor(0.), dist_reduce_fx="sum")
self.add_state('total_samples', default=torch.tensor(0.), dist_reduce_fx="sum")
def update(self, values) -> None:
self.total_value += torch.sum(values)
self.total_samples += values.shape[0]
def compute(self):
return self.total_value / self.total_samples
class SumExceptBatchMSE(MeanSquaredError):
def update(self, preds: Tensor, target: Tensor) -> None:
"""Update state with predictions and targets.
Args:
preds: Predictions from model
target: Ground truth values
"""
assert preds.shape == target.shape
sum_squared_error, n_obs = self._mean_squared_error_update(preds, target)
self.sum_squared_error += sum_squared_error
self.total += n_obs
def _mean_squared_error_update(self, preds: Tensor, target: Tensor):
""" Updates and returns variables required to compute Mean Squared Error. Checks for same shape of input
tensors.
preds: Predicted tensor
target: Ground truth tensor
"""
diff = preds - target
sum_squared_error = torch.sum(diff * diff)
n_obs = preds.shape[0]
return sum_squared_error, n_obs
class SumExceptBatchKL(Metric):
def __init__(self):
super().__init__()
self.add_state('total_value', default=torch.tensor(0.), dist_reduce_fx="sum")
self.add_state('total_samples', default=torch.tensor(0.), dist_reduce_fx="sum")
def update(self, p, q) -> None:
self.total_value += F.kl_div(q, p, reduction='sum')
self.total_samples += p.size(0)
def compute(self):
return self.total_value / self.total_samples
class CrossEntropyMetric(Metric):
def __init__(self):
super().__init__()
self.add_state('total_ce', default=torch.tensor(0.), dist_reduce_fx="sum")
self.add_state('total_samples', default=torch.tensor(0.), dist_reduce_fx="sum")
def update(self, preds: Tensor, target: Tensor, weight=None) -> None:
""" Update state with predictions and targets.
preds: Predictions from model (bs * n, d) or (bs * n * n, d)
target: Ground truth values (bs * n, d) or (bs * n * n, d). """
target = torch.argmax(target, dim=-1)
if weight is not None:
weight = weight.type_as(preds)
output = F.cross_entropy(preds, target, weight = weight, reduction='sum')
else:
output = F.cross_entropy(preds, target, reduction='sum')
self.total_ce += output
self.total_samples += preds.size(0)
def compute(self):
return self.total_ce / self.total_samples
class ProbabilityMetric(Metric):
def __init__(self):
""" This metric is used to track the marginal predicted probability of a class during training. """
super().__init__()
self.add_state('prob', default=torch.tensor(0.), dist_reduce_fx="sum")
self.add_state('total', default=torch.tensor(0.), dist_reduce_fx="sum")
def update(self, preds: Tensor) -> None:
self.prob += preds.sum()
self.total += preds.numel()
def compute(self):
return self.prob / self.total
class NLL(Metric):
def __init__(self):
super().__init__()
self.add_state('total_nll', default=torch.tensor(0.), dist_reduce_fx="sum")
self.add_state('total_samples', default=torch.tensor(0.), dist_reduce_fx="sum")
def update(self, batch_nll) -> None:
self.total_nll += torch.sum(batch_nll)
self.total_samples += batch_nll.numel()
def compute(self):
return self.total_nll / self.total_samples
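As a quick illustration of how these metric classes accumulate state across batches (values below are arbitrary), NLL averages per-graph negative log-likelihoods over everything it has seen:

import torch

nll_metric = NLL()
nll_metric.update(torch.tensor([2.0, 4.0]))   # first batch of per-graph NLLs
nll_metric.update(torch.tensor([6.0]))        # second batch
print(nll_metric.compute())                   # tensor(4.) -- mean over 3 graphs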

BIN
mcd/metrics/fpscores.pkl.gz Normal file

Binary file not shown.

138
mcd/metrics/molecular_metrics_sampling.py Normal file
View File

@ -0,0 +1,138 @@
### packages for visualization
from analysis.rdkit_functions import compute_molecular_metrics
from mini_moses.metrics.metrics import compute_intermediate_statistics
from metrics.property_metric import TaskModel
import torch
import torch.nn as nn
import os
import csv
import time
def result_to_csv(path, dict_data):
file_exists = os.path.exists(path)
log_name = dict_data.pop("log_name", None)
if log_name is None:
raise ValueError("The provided dictionary must contain a 'log_name' key.")
field_names = ["log_name"] + list(dict_data.keys())
dict_data["log_name"] = log_name
with open(path, "a", newline="") as file:
writer = csv.DictWriter(file, fieldnames=field_names)
if not file_exists:
writer.writeheader()
writer.writerow(dict_data)
class SamplingMolecularMetrics(nn.Module):
def __init__(
self,
dataset_infos,
train_smiles,
reference_smiles,
n_jobs=1,
device="cpu",
batch_size=512,
):
super().__init__()
self.task_name = dataset_infos.task
self.dataset_infos = dataset_infos
self.active_atoms = dataset_infos.active_atoms
self.train_smiles = train_smiles
if reference_smiles is not None:
print(
f"--- Computing intermediate statistics for training for #{len(reference_smiles)} smiles ---"
)
start_time = time.time()
self.stat_ref = compute_intermediate_statistics(
reference_smiles, n_jobs=n_jobs, device=device, batch_size=batch_size
)
end_time = time.time()
elapsed_time = end_time - start_time
print(
f"--- End computing intermediate statistics: using {elapsed_time:.2f}s ---"
)
else:
self.stat_ref = None
self.comput_config = {
"n_jobs": n_jobs,
"device": device,
"batch_size": batch_size,
}
self.task_evaluator = {'meta_taskname': dataset_infos.task, 'sas': None, 'scs': None}
for cur_task in dataset_infos.task.split("-")[:]:
# print('loading evaluator for task', cur_task)
model_path = os.path.join(
dataset_infos.base_path, "data/evaluator", f"{cur_task}.joblib"
)
os.makedirs(os.path.dirname(model_path), exist_ok=True)
evaluator = TaskModel(model_path, cur_task)
self.task_evaluator[cur_task] = evaluator
def forward(self, molecules, targets, name, current_epoch, val_counter, test=False):
if isinstance(targets, list):
targets_cat = torch.cat(targets, dim=0)
targets_np = targets_cat.detach().cpu().numpy()
else:
targets_np = targets.detach().cpu().numpy()
unique_smiles, all_smiles, all_metrics, targets_log = compute_molecular_metrics(
molecules,
targets_np,
self.train_smiles,
self.stat_ref,
self.dataset_infos,
self.task_evaluator,
self.comput_config,
)
if test:
file_name = "final_smiles.txt"
with open(file_name, "w") as fp:
all_tasks_name = list(self.task_evaluator.keys())
all_tasks_name = all_tasks_name.copy()
if 'meta_taskname' in all_tasks_name:
all_tasks_name.remove('meta_taskname')
if 'scs' in all_tasks_name:
all_tasks_name.remove('scs')
all_tasks_str = "smiles, " + ", ".join([f"input_{task}" for task in all_tasks_name] + [f"output_{task}" for task in all_tasks_name])
fp.write(all_tasks_str + "\n")
for i, smiles in enumerate(all_smiles):
if targets_log is not None:
all_result_str = f"{smiles}, " + ", ".join([f"{targets_log['input_'+task][i]}" for task in all_tasks_name] + [f"{targets_log['output_'+task][i]}" for task in all_tasks_name])
fp.write(all_result_str + "\n")
else:
fp.write("%s\n" % smiles)
print("All smiles saved")
else:
result_path = os.path.join(os.getcwd(), f"graphs/{name}")
os.makedirs(result_path, exist_ok=True)
text_path = os.path.join(
result_path,
f"valid_unique_molecules_e{current_epoch}_b{val_counter}.txt",
)
textfile = open(text_path, "w")
for smiles in unique_smiles:
textfile.write(smiles + "\n")
textfile.close()
all_logs = all_metrics
if test:
all_logs["log_name"] = "test"
else:
all_logs["log_name"] = (
"epoch" + str(current_epoch) + "_batch" + str(val_counter)
)
result_to_csv("output.csv", all_logs)
return all_smiles
def reset(self):
pass
if __name__ == "__main__":
pass
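A small usage note for result_to_csv above: each call appends one row, and the dictionary must carry a 'log_name' key, which becomes the first CSV column. The file name and values here are arbitrary, for illustration only:

result_to_csv("output.csv", {"log_name": "epoch0_batch0", "Validity": 0.95})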

126
mcd/metrics/molecular_metrics_train.py Normal file
View File

@ -0,0 +1,126 @@
import torch
from torchmetrics import Metric, MetricCollection
from torch import Tensor
import torch.nn as nn
class CEPerClass(Metric):
full_state_update = False
def __init__(self, class_id):
super().__init__()
self.class_id = class_id
self.add_state('total_ce', default=torch.tensor(0.), dist_reduce_fx="sum")
self.add_state('total_samples', default=torch.tensor(0.), dist_reduce_fx="sum")
self.softmax = torch.nn.Softmax(dim=-1)
self.binary_cross_entropy = torch.nn.BCELoss(reduction='sum')
def update(self, preds: Tensor, target: Tensor) -> None:
"""Update state with predictions and targets.
Args:
preds: Predictions from model (bs, n, d) or (bs, n, n, d)
target: Ground truth values (bs, n, d) or (bs, n, n, d)
"""
target = target.reshape(-1, target.shape[-1])
mask = (target != 0.).any(dim=-1)
prob = self.softmax(preds)[..., self.class_id]
prob = prob.flatten()[mask]
target = target[:, self.class_id]
target = target[mask]
output = self.binary_cross_entropy(prob, target)
self.total_ce += output
self.total_samples += prob.numel()
def compute(self):
return self.total_ce / self.total_samples
class AtomCE(CEPerClass):
def __init__(self, i):
super().__init__(i)
class NoBondCE(CEPerClass):
def __init__(self, i):
super().__init__(i)
class SingleCE(CEPerClass):
def __init__(self, i):
super().__init__(i)
class DoubleCE(CEPerClass):
def __init__(self, i):
super().__init__(i)
class TripleCE(CEPerClass):
def __init__(self, i):
super().__init__(i)
class AromaticCE(CEPerClass):
def __init__(self, i):
super().__init__(i)
class AtomMetricsCE(MetricCollection):
def __init__(self, active_atoms):
metrics_list = []
for i, atom_type in enumerate(active_atoms):
metrics_list.append(type(f'{atom_type}_CE', (AtomCE,), {})(i))
super().__init__(metrics_list)
class BondMetricsCE(MetricCollection):
def __init__(self):
ce_no_bond = NoBondCE(0)
ce_SI = SingleCE(1)
ce_DO = DoubleCE(2)
ce_TR = TripleCE(3)
super().__init__([ce_no_bond, ce_SI, ce_DO, ce_TR])
class TrainMolecularMetricsDiscrete(nn.Module):
def __init__(self, dataset_infos):
super().__init__()
active_atoms = dataset_infos.active_atoms
self.train_atom_metrics = AtomMetricsCE(active_atoms=active_atoms)
self.train_bond_metrics = BondMetricsCE()
def forward(self, masked_pred_X, masked_pred_E, true_X, true_E, log: bool):
self.train_atom_metrics(masked_pred_X, true_X)
self.train_bond_metrics(masked_pred_E, true_E)
if log:
to_log = {}
for key, val in self.train_atom_metrics.compute().items():
to_log['train/' + key] = val.item()
for key, val in self.train_bond_metrics.compute().items():
to_log['train/' + key] = val.item()
def reset(self):
for metric in [self.train_atom_metrics, self.train_bond_metrics]:
metric.reset()
def log_epoch_metrics(self, current_epoch, log=True):
epoch_atom_metrics = self.train_atom_metrics.compute()
epoch_bond_metrics = self.train_bond_metrics.compute()
to_log = {}
for key, val in epoch_atom_metrics.items():
to_log['train_epoch/' + key] = val.item()
for key, val in epoch_bond_metrics.items():
to_log['train_epoch/' + key] = val.item()
for key, val in epoch_atom_metrics.items():
epoch_atom_metrics[key] = round(val.item(),4)
for key, val in epoch_bond_metrics.items():
epoch_bond_metrics[key] = round(val.item(),4)
if log:
print(f"Epoch {current_epoch}: {epoch_atom_metrics} -- {epoch_bond_metrics}")

201
mcd/metrics/property_metric.py Normal file
View File

@ -0,0 +1,201 @@
import math, os
import pickle
import os.path as op
import numpy as np
import pandas as pd
from joblib import dump, load
from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor
from sklearn.metrics import mean_absolute_error, roc_auc_score
from rdkit import Chem
from rdkit import rdBase
from rdkit.Chem import AllChem
from rdkit import DataStructs
from rdkit.Chem import rdMolDescriptors
rdBase.DisableLog('rdApp.error')
task_to_colname = {
'hiv_b': 'HIV_active',
'bace_b': 'Class',
'bbbp_b': 'p_np',
'O2': 'O2',
'N2': 'N2',
'CO2': 'CO2',
}
tasktype_name = {
'hiv_b': 'classification',
'bace_b': 'classification',
'bbbp_b': 'classification',
'O2': 'regression',
'N2': 'regression',
'CO2': 'regression',
}
class TaskModel():
"""Scores based on an ECFP classifier."""
def __init__(self, model_path, task_name):
task_type = tasktype_name[task_name]
self.task_name = task_name
self.task_type = task_type
self.model_path = model_path
self.metric_func = roc_auc_score if 'classification' in self.task_type else mean_absolute_error
try:
self.model = load(model_path)
print(self.task_name, ' evaluator loaded')
except:
print(self.task_name, ' evaluator not found, training new one...')
if 'classification' in task_type:
self.model = RandomForestClassifier(random_state=0)
elif 'regression' in task_type:
self.model = RandomForestRegressor(random_state=0)
performance = self.train()
dump(self.model, model_path)
print('Oracle performance: ', performance)
def train(self):
data_path = os.path.dirname(self.model_path)
data_path = os.path.join(os.path.dirname(self.model_path), '..', f'raw/{self.task_name}.csv.gz')
df = pd.read_csv(data_path)
col_name = task_to_colname[self.task_name]
y = df[col_name].to_numpy()
x_smiles = df['smiles'].to_numpy()
mask = ~np.isnan(y)
y = y[mask]
if 'classification' in self.task_type:
y = y.astype(int)
x_smiles = x_smiles[mask]
x_fps = []
mask = []
for i,smiles in enumerate(x_smiles):
mol = Chem.MolFromSmiles(smiles)
mask.append( int(mol is not None) )
fp = TaskModel.fingerprints_from_mol(mol) if mol else np.zeros((1, 2048))
x_fps.append(fp)
x_fps = np.concatenate(x_fps, axis=0)
self.model.fit(x_fps, y)
y_pred = self.model.predict(x_fps)
perf = self.metric_func(y, y_pred)
print(f'{self.task_name} performance: {perf}')
return perf
def __call__(self, smiles_list):
fps = []
mask = []
for i,smiles in enumerate(smiles_list):
mol = Chem.MolFromSmiles(smiles)
mask.append( int(mol is not None) )
fp = TaskModel.fingerprints_from_mol(mol) if mol else np.zeros((1, 2048))
fps.append(fp)
fps = np.concatenate(fps, axis=0)
if 'classification' in self.task_type:
scores = self.model.predict_proba(fps)[:, 1]
else:
scores = self.model.predict(fps)
scores = scores * np.array(mask)
return np.float32(scores)
@classmethod
def fingerprints_from_mol(cls, mol): # use ECFP4
features_vec = AllChem.GetMorganFingerprintAsBitVect(mol, 2, nBits=2048)
features = np.zeros((1,))
DataStructs.ConvertToNumpyArray(features_vec, features)
return features.reshape(1, -1)
###### SAS Score ######
_fscores = None
def readFragmentScores(name='fpscores'):
import gzip
global _fscores
# generate the full path filename:
if name == "fpscores":
name = op.join(op.dirname(__file__), name)
data = pickle.load(gzip.open('%s.pkl.gz' % name))
outDict = {}
for i in data:
for j in range(1, len(i)):
outDict[i[j]] = float(i[0])
_fscores = outDict
def numBridgeheadsAndSpiro(mol, ri=None):
nSpiro = rdMolDescriptors.CalcNumSpiroAtoms(mol)
nBridgehead = rdMolDescriptors.CalcNumBridgeheadAtoms(mol)
return nBridgehead, nSpiro
def calculateSAS(smiles_list):
scores = []
for i, smiles in enumerate(smiles_list):
mol = Chem.MolFromSmiles(smiles)
score = calculateScore(mol)
scores.append(score)
return np.float32(scores)
def calculateScore(m):
if _fscores is None:
readFragmentScores()
# fragment score
fp = rdMolDescriptors.GetMorganFingerprint(m,
2) # <- 2 is the *radius* of the circular fingerprint
fps = fp.GetNonzeroElements()
score1 = 0.
nf = 0
for bitId, v in fps.items():
nf += v
sfp = bitId
score1 += _fscores.get(sfp, -4) * v
score1 /= nf
# features score
nAtoms = m.GetNumAtoms()
nChiralCenters = len(Chem.FindMolChiralCenters(m, includeUnassigned=True))
ri = m.GetRingInfo()
nBridgeheads, nSpiro = numBridgeheadsAndSpiro(m, ri)
nMacrocycles = 0
for x in ri.AtomRings():
if len(x) > 8:
nMacrocycles += 1
sizePenalty = nAtoms**1.005 - nAtoms
stereoPenalty = math.log10(nChiralCenters + 1)
spiroPenalty = math.log10(nSpiro + 1)
bridgePenalty = math.log10(nBridgeheads + 1)
macrocyclePenalty = 0.
# ---------------------------------------
# This differs from the paper, which defines:
# macrocyclePenalty = math.log10(nMacrocycles+1)
# This form generates better results when 2 or more macrocycles are present
if nMacrocycles > 0:
macrocyclePenalty = math.log10(2)
score2 = 0. - sizePenalty - stereoPenalty - spiroPenalty - bridgePenalty - macrocyclePenalty
# correction for the fingerprint density
# not in the original publication, added in version 1.1
# to make highly symmetrical molecules easier to synthesize
score3 = 0.
if nAtoms > len(fps):
score3 = math.log(float(nAtoms) / len(fps)) * .5
sascore = score1 + score2 + score3
# need to transform "raw" value into scale between 1 and 10
min = -4.0
max = 2.5
sascore = 11. - (sascore - min + 1) / (max - min) * 9.
# smooth the 10-end
if sascore > 8.:
sascore = 8. + math.log(sascore + 1. - 9.)
if sascore > 10.:
sascore = 10.0
elif sascore < 1.:
sascore = 1.0
return sascore
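For a quick sanity check of the SA-score utilities above (illustrative only; it assumes the fpscores.pkl.gz file from this commit sits next to the module, and the SMILES are chosen arbitrarily):

print(calculateSAS(["CCO", "c1ccccc1", "CC(=O)OC1=CC=CC=C1C(=O)O"]))
# simple molecules land near the easy end of the 1-10 scale; lower means easier to synthesize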

94
mcd/metrics/train_loss.py Normal file
View File

@ -0,0 +1,94 @@
import time
import torch
import torch.nn as nn
from metrics.abstract_metrics import CrossEntropyMetric
from torchmetrics import Metric, MeanSquaredError
# from 2:He to 119:*
valencies_check = [0, 1, 2, 3, 4, 3, 2, 1, 0, 1, 2, 6, 6, 7, 6, 1, 0, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 3, 4, 7, 6, 1, 0, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 3, 4, 7, 6, 5, 6, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 3, 4, 7, 6, 5, 0, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
valencies_check = torch.tensor(valencies_check)
weight_check = [4.003, 6.941, 9.012, 10.812, 12.011, 14.007, 15.999, 18.998, 20.18, 22.99, 24.305, 26.982, 28.086, 30.974, 32.067, 35.453, 39.948, 39.098, 40.078, 44.956, 47.867, 50.942, 51.996, 54.938, 55.845, 58.933, 58.693, 63.546, 65.39, 69.723, 72.61, 74.922, 78.96, 79.904, 83.8, 85.468, 87.62, 88.906, 91.224, 92.906, 95.94, 98.0, 101.07, 102.906, 106.42, 107.868, 112.412, 114.818, 118.711, 121.76, 127.6, 126.904, 131.29, 132.905, 137.328, 138.906, 140.116, 140.908, 144.24, 145.0, 150.36, 151.964, 157.25, 158.925, 162.5, 164.93, 167.26, 168.934, 173.04, 174.967, 178.49, 180.948, 183.84, 186.207, 190.23, 192.217, 195.078, 196.967, 200.59, 204.383, 207.2, 208.98, 209.0, 210.0, 222.0, 223.0, 226.0, 227.0, 232.038, 231.036, 238.029, 237.0, 244.0, 243.0, 247.0, 247.0, 251.0, 252.0, 257.0, 258.0, 259.0, 262.0, 267.0, 268.0, 269.0, 270.0, 269.0, 278.0, 281.0, 281.0, 285.0, 284.0, 289.0, 288.0, 293.0, 292.0, 294.0, 294.0]
weight_check = torch.tensor(weight_check)
class AtomWeightMetric(Metric):
def __init__(self):
super().__init__()
self.add_state('total_loss', default=torch.tensor(0.), dist_reduce_fx="sum")
self.add_state('total_samples', default=torch.tensor(0.), dist_reduce_fx="sum")
global weight_check
self.weight_check = weight_check
def update(self, X, Y):
atom_pred_num = X.argmax(dim=-1)
atom_real_num = Y.argmax(dim=-1)
self.weight_check = self.weight_check.type_as(X)
pred_weight = self.weight_check[atom_pred_num]
real_weight = self.weight_check[atom_real_num]
lss = 0
lss += torch.abs(pred_weight.sum(dim=-1) - real_weight.sum(dim=-1)).sum()
self.total_loss += lss
self.total_samples += X.size(0)
def compute(self):
return self.total_loss / self.total_samples
class TrainLossDiscrete(nn.Module):
""" Train with Cross entropy"""
def __init__(self, lambda_train, weight_node=None, weight_edge=None):
super().__init__()
self.node_loss = CrossEntropyMetric()
self.edge_loss = CrossEntropyMetric()
self.weight_loss = AtomWeightMetric()
self.y_loss = MeanSquaredError()
self.lambda_train = lambda_train
def forward(self, masked_pred_X, masked_pred_E, pred_y, true_X, true_E, true_y, node_mask, log: bool):
""" Compute train metrics
masked_pred_X : tensor -- (bs, n, dx)
masked_pred_E : tensor -- (bs, n, n, de)
pred_y : tensor -- (bs, )
true_X : tensor -- (bs, n, dx)
true_E : tensor -- (bs, n, n, de)
true_y : tensor -- (bs, )
log : boolean. """
loss_weight = self.weight_loss(masked_pred_X, true_X)
true_X = torch.reshape(true_X, (-1, true_X.size(-1))) # (bs * n, dx)
true_E = torch.reshape(true_E, (-1, true_E.size(-1))) # (bs * n * n, de)
masked_pred_X = torch.reshape(masked_pred_X, (-1, masked_pred_X.size(-1))) # (bs * n, dx)
masked_pred_E = torch.reshape(masked_pred_E, (-1, masked_pred_E.size(-1))) # (bs * n * n, de)
# Remove masked rows
mask_X = (true_X != 0.).any(dim=-1)
mask_E = (true_E != 0.).any(dim=-1)
flat_true_X = true_X[mask_X, :]
flat_pred_X = masked_pred_X[mask_X, :]
flat_true_E = true_E[mask_E, :]
flat_pred_E = masked_pred_E[mask_E, :]
loss_X = self.node_loss(flat_pred_X, flat_true_X) if true_X.numel() > 0 else 0.0
loss_E = self.edge_loss(flat_pred_E, flat_true_E) if true_E.numel() > 0 else 0.0
return self.lambda_train[0] * loss_X + self.lambda_train[1] * loss_E + loss_weight
def reset(self):
for metric in [self.node_loss, self.edge_loss, self.y_loss]:
metric.reset()
def log_epoch_metrics(self, current_epoch, start_epoch_time, log=True):
epoch_node_loss = self.node_loss.compute() if self.node_loss.total_samples > 0 else -1
epoch_edge_loss = self.edge_loss.compute() if self.edge_loss.total_samples > 0 else -1
epoch_weight_loss = self.weight_loss.compute() if self.weight_loss.total_samples > 0 else -1
if log:
print(f"Epoch {current_epoch} finished: X_CE: {epoch_node_loss :.4f} -- E_CE: {epoch_edge_loss :.4f} "
f"Weight: {epoch_weight_loss :.4f} "
f"-- Time taken {time.time() - start_epoch_time:.1f}s ")

0
mcd/models/__init__.py Normal file
View File

119
mcd/models/conditions.py Normal file
View File

@ -0,0 +1,119 @@
import torch
import torch.nn as nn
import math
class TimestepEmbedder(nn.Module):
"""
Embeds scalar timesteps into vector representations.
"""
def __init__(self, hidden_size, frequency_embedding_size=256):
super().__init__()
self.mlp = nn.Sequential(
nn.Linear(frequency_embedding_size, hidden_size, bias=True),
nn.SiLU(),
nn.Linear(hidden_size, hidden_size, bias=True),
)
self.frequency_embedding_size = frequency_embedding_size
@staticmethod
def timestep_embedding(t, dim, max_period=10000):
"""
Create sinusoidal timestep embeddings.
:param t: a 1-D Tensor of N indices, one per batch element.
These may be fractional.
:param dim: the dimension of the output.
:param max_period: controls the minimum frequency of the embeddings.
:return: an (N, D) Tensor of positional embeddings.
"""
# https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
half = dim // 2
freqs = torch.exp(
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
).to(device=t.device)
args = t[:, None].float() * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
return embedding
def forward(self, t):
t = t.view(-1)
t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
t_emb = self.mlp(t_freq)
return t_emb
class CategoricalEmbedder(nn.Module):
"""
Embeds categorical conditions such as data sources into vector representations.
Also handles label dropout for classifier-free guidance.
"""
def __init__(self, num_classes, hidden_size, dropout_prob):
super().__init__()
use_cfg_embedding = dropout_prob > 0
self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size)
self.num_classes = num_classes
self.dropout_prob = dropout_prob
def token_drop(self, labels, force_drop_ids=None):
"""
Drops labels to enable classifier-free guidance.
"""
if force_drop_ids is None:
drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob
else:
drop_ids = force_drop_ids == 1
labels = torch.where(drop_ids, self.num_classes, labels)
return labels
def forward(self, labels, train, force_drop_ids=None, t=None):
labels = labels.long().view(-1)
use_dropout = self.dropout_prob > 0
if (train and use_dropout) or (force_drop_ids is not None):
labels = self.token_drop(labels, force_drop_ids)
embeddings = self.embedding_table(labels)
if train:
noise = torch.randn_like(embeddings)
embeddings = embeddings + noise
return embeddings
class ClusterContinuousEmbedder(nn.Module):
def __init__(self, input_size, hidden_size, dropout_prob):
super().__init__()
use_cfg_embedding = dropout_prob > 0
if use_cfg_embedding:
self.embedding_drop = nn.Embedding(1, hidden_size)
self.mlp = nn.Sequential(
nn.Linear(input_size, hidden_size, bias=True),
nn.Softmax(dim=1),
nn.Linear(hidden_size, hidden_size, bias=False)
)
self.hidden_size = hidden_size
self.dropout_prob = dropout_prob
def forward(self, labels, train, force_drop_ids=None, timestep=None):
use_dropout = self.dropout_prob > 0
if force_drop_ids is not None:
drop_ids = force_drop_ids == 1
else:
drop_ids = None
if (train and use_dropout):
drop_ids_rand = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob
if force_drop_ids is not None:
drop_ids = torch.logical_or(drop_ids, drop_ids_rand)
else:
drop_ids = drop_ids_rand
if drop_ids is not None:
embeddings = torch.zeros((labels.shape[0], self.hidden_size), device=labels.device)
embeddings[~drop_ids] = self.mlp(labels[~drop_ids])
embeddings[drop_ids] += self.embedding_drop.weight[0]
else:
embeddings = self.mlp(labels)
if train:
noise = torch.randn_like(embeddings)
embeddings = embeddings + noise
return embeddings
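A quick shape check for the sinusoidal timestep embedder defined at the top of this file (illustrative only; hidden size and timestep values are arbitrary):

import torch

embedder = TimestepEmbedder(hidden_size=64)
t = torch.tensor([0.0, 0.1, 0.5, 1.0])            # normalized timesteps, may be fractional
print(embedder(t).shape)                          # torch.Size([4, 64])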

114
mcd/models/layers.py Normal file
View File

@ -0,0 +1,114 @@
from torch.jit import Final
import torch.nn.functional as F
from itertools import repeat
import collections.abc
import torch
import torch.nn as nn
class Attention(nn.Module):
fast_attn: Final[bool]
def __init__(
self,
dim,
num_heads=8,
qkv_bias=False,
qk_norm=False,
attn_drop=0,
proj_drop=0,
norm_layer=nn.LayerNorm,
):
super().__init__()
assert dim % num_heads == 0, 'dim should be divisible by num_heads'
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.scale = self.head_dim ** -0.5
self.fast_attn = hasattr(torch.nn.functional, 'scaled_dot_product_attention') # FIXME
assert self.fast_attn, "scaled_dot_product_attention requires PyTorch >= 2.0"
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
def dot_product_attention(self, q, k, v):
q = q * self.scale
attn = q @ k.transpose(-2, -1)
attn_sfmx = attn.softmax(dim=-1)
attn_sfmx = self.attn_drop(attn_sfmx)
x = attn_sfmx @ v
return x, attn
def forward(self, x, node_mask):
B, N, D = x.shape
# B, head, N, head_dim
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
q, k, v = qkv.unbind(0) # B, head, N, head_dim
q, k = self.q_norm(q), self.k_norm(k)
attn_mask = (node_mask[:, None, :, None] & node_mask[:, None, None, :]).expand(-1, self.num_heads, N, N)
attn_mask[attn_mask.sum(-1) == 0] = True
x = F.scaled_dot_product_attention(
q, k, v,
dropout_p=self.attn_drop.p,
attn_mask=attn_mask,
)
x = x.transpose(1, 2).reshape(B, N, -1)
x = self.proj(x)
x = self.proj_drop(x)
return x
class Mlp(nn.Module):
def __init__(
self,
in_features,
hidden_features=None,
out_features=None,
act_layer=nn.GELU,
bias=True,
drop=0.,
):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
bias = to_2tuple(bias)
drop_probs = to_2tuple(drop)
linear_layer = nn.Linear
self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0])
self.act = act_layer()
self.drop1 = nn.Dropout(drop_probs[0])
self.fc2 = linear_layer(hidden_features, out_features, bias=bias[1])
self.drop2 = nn.Dropout(drop_probs[1])
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.drop1(x)
x = self.fc2(x)
x = self.drop2(x)
return x
# From PyTorch internals
def _ntuple(n):
def parse(x):
if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
return tuple(x)
return tuple(repeat(x, n))
return parse
to_2tuple = _ntuple(2)

184
mcd/models/transformer.py Normal file
View File

@ -0,0 +1,184 @@
import torch
import torch.nn as nn
import utils
from models.layers import Attention, Mlp
from models.conditions import TimestepEmbedder, CategoricalEmbedder, ClusterContinuousEmbedder
def modulate(x, shift, scale):
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
class Denoiser(nn.Module):
def __init__(
self,
max_n_nodes,
hidden_size=384,
depth=12,
num_heads=16,
mlp_ratio=4.0,
drop_condition=0.1,
Xdim=118,
Edim=5,
ydim=3,
task_type='regression',
):
super().__init__()
self.num_heads = num_heads
self.ydim = ydim
self.x_embedder = nn.Linear(Xdim + max_n_nodes * Edim, hidden_size, bias=False)
self.t_embedder = TimestepEmbedder(hidden_size)
self.y_embedding_list = torch.nn.ModuleList()
self.y_embedding_list.append(ClusterContinuousEmbedder(2, hidden_size, drop_condition))
for i in range(ydim - 2):
if task_type == 'regression':
self.y_embedding_list.append(ClusterContinuousEmbedder(1, hidden_size, drop_condition))
else:
self.y_embedding_list.append(CategoricalEmbedder(2, hidden_size, drop_condition))
self.encoders = nn.ModuleList(
[
SELayer(hidden_size, num_heads, mlp_ratio=mlp_ratio)
for _ in range(depth)
]
)
self.decoder = Decoder(
max_n_nodes=max_n_nodes,
hidden_size=hidden_size,
atom_type=Xdim,
bond_type=Edim,
mlp_ratio=mlp_ratio,
num_heads=num_heads,
)
self.initialize_weights()
def initialize_weights(self):
# Initialize transformer layers:
def _basic_init(module):
if isinstance(module, nn.Linear):
torch.nn.init.xavier_uniform_(module.weight)
if module.bias is not None:
nn.init.constant_(module.bias, 0)
def _constant_init(module, i):
if isinstance(module, nn.Linear):
nn.init.constant_(module.weight, i)
if module.bias is not None:
nn.init.constant_(module.bias, i)
self.apply(_basic_init)
for block in self.encoders :
_constant_init(block.adaLN_modulation[0], 0)
_constant_init(self.decoder.adaLN_modulation[0], 0)
def forward(self, x, e, node_mask, y, t, unconditioned):
force_drop_id = torch.zeros_like(y.sum(-1))
force_drop_id[torch.isnan(y.sum(-1))] = 1
if unconditioned:
force_drop_id = torch.ones_like(y[:, 0])
x_in, e_in, y_in = x, e, y
bs, n, _ = x.size()
x = torch.cat([x, e.reshape(bs, n, -1)], dim=-1)
x = self.x_embedder(x)
c1 = self.t_embedder(t)
for i in range(1, self.ydim):
if i == 1:
c2 = self.y_embedding_list[i-1](y[:, :2], self.training, force_drop_id, t)
else:
c2 = c2 + self.y_embedding_list[i-1](y[:, i:i+1], self.training, force_drop_id, t)
c = c1 + c2
for i, block in enumerate(self.encoders):
x = block(x, c, node_mask)
# X: B * N * dx, E: B * N * N * de
X, E, y = self.decoder(x, x_in, e_in, c, t, node_mask)
return utils.PlaceHolder(X=X, E=E, y=y).mask(node_mask)
class SELayer(nn.Module):
def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, **block_kwargs):
super().__init__()
self.dropout = 0.
self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False)
self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False)
self.attn = Attention(
hidden_size, num_heads=num_heads, qkv_bias=True, qk_norm=True, **block_kwargs
)
self.mlp = Mlp(
in_features=hidden_size,
hidden_features=int(hidden_size * mlp_ratio),
drop=self.dropout,
)
self.adaLN_modulation = nn.Sequential(
nn.Linear(hidden_size, hidden_size, bias=True),
nn.SiLU(),
nn.Linear(hidden_size, 6 * hidden_size, bias=True)
)
def forward(self, x, c, node_mask):
(
shift_msa,
scale_msa,
gate_msa,
shift_mlp,
scale_mlp,
gate_mlp,
) = self.adaLN_modulation(c).chunk(6, dim=1)
x = x + gate_msa.unsqueeze(1) * modulate(self.norm1(self.attn(x, node_mask=node_mask)), shift_msa, scale_msa)
x = x + gate_mlp.unsqueeze(1) * modulate(self.norm2(self.mlp(x)), shift_mlp, scale_mlp)
return x
class Decoder(nn.Module):
# Structure Decoder
def __init__(self, max_n_nodes, hidden_size, atom_type, bond_type, mlp_ratio, num_heads=None):
super().__init__()
self.atom_type = atom_type
self.bond_type = bond_type
final_size = atom_type + max_n_nodes * bond_type
self.xedecoder = Mlp(in_features=hidden_size,
out_features=final_size, drop=0)
self.norm_final = nn.LayerNorm(final_size, elementwise_affine=False)
self.adaLN_modulation = nn.Sequential(
nn.Linear(hidden_size, hidden_size, bias=True),
nn.SiLU(),
nn.Linear(hidden_size, 2 * final_size, bias=True)
)
def forward(self, x, x_in, e_in, c, t, node_mask):
x_all = self.xedecoder(x)
B, N, D = x_all.size()
shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
x_all = modulate(self.norm_final(x_all), shift, scale)
atom_out = x_all[:, :, :self.atom_type]
atom_out = x_in + atom_out
bond_out = x_all[:, :, self.atom_type:].reshape(B, N, N, self.bond_type)
bond_out = e_in + bond_out
##### standardize bond_out: zero padded and diagonal entries, then symmetrize
edge_mask = (~node_mask)[:, :, None] & (~node_mask)[:, None, :]
diag_mask = (
torch.eye(N, dtype=torch.bool)
.unsqueeze(0)
.expand(B, -1, -1)
.type_as(edge_mask)
)
bond_out.masked_fill_(edge_mask[:, :, :, None], 0)
bond_out.masked_fill_(diag_mask[:, :, :, None], 0)
bond_out = 1 / 2 * (bond_out + torch.transpose(bond_out, 1, 2))
return atom_out, bond_out, None
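A minimal smoke-test sketch for the denoiser above, assuming `models.layers`, `models.conditions`, and `utils` from this repository are importable; the tensor sizes below are illustrative only and not the paper's configuration.

```python
# Hypothetical smoke test; sizes are made up for illustration.
import torch

max_n_nodes, bs, Xdim, Edim, ydim = 20, 4, 16, 5, 3
model = Denoiser(max_n_nodes=max_n_nodes, hidden_size=128, depth=2,
                 num_heads=8, Xdim=Xdim, Edim=Edim, ydim=ydim)

X = torch.randn(bs, max_n_nodes, Xdim)                 # noisy node features
E = torch.randn(bs, max_n_nodes, max_n_nodes, Edim)    # noisy edge features
E = 0.5 * (E + E.transpose(1, 2))                      # keep edges symmetric
y = torch.randn(bs, ydim)                              # property conditions
t = torch.randint(0, 500, (bs,))                       # diffusion timesteps
node_mask = torch.ones(bs, max_n_nodes, dtype=torch.bool)

out = model(X, E, node_mask, y, t, unconditioned=False)
print(out.X.shape, out.E.shape)                        # (bs, n, Xdim), (bs, n, n, Edim)
```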

135
mcd/utils.py Normal file
View File

@@ -0,0 +1,135 @@
import os
from omegaconf import OmegaConf, open_dict
import torch
import torch_geometric.utils
from torch_geometric.utils import to_dense_adj, to_dense_batch
def create_folders(args):
try:
os.makedirs('graphs')
os.makedirs('chains')
except OSError:
pass
try:
os.makedirs('graphs/' + args.general.name)
os.makedirs('chains/' + args.general.name)
except OSError:
pass
def normalize(X, E, y, norm_values, norm_biases, node_mask):
X = (X - norm_biases[0]) / norm_values[0]
E = (E - norm_biases[1]) / norm_values[1]
y = (y - norm_biases[2]) / norm_values[2]
diag = torch.eye(E.shape[1], dtype=torch.bool).unsqueeze(0).expand(E.shape[0], -1, -1)
E[diag] = 0
return PlaceHolder(X=X, E=E, y=y).mask(node_mask)
def unnormalize(X, E, y, norm_values, norm_biases, node_mask, collapse=False):
"""
X : node features
E : edge features
y : global features
norm_values : [norm value X, norm value E, norm value y]
norm_biases : same order
node_mask
"""
X = (X * norm_values[0] + norm_biases[0])
E = (E * norm_values[1] + norm_biases[1])
y = y * norm_values[2] + norm_biases[2]
return PlaceHolder(X=X, E=E, y=y).mask(node_mask, collapse)
def to_dense(x, edge_index, edge_attr, batch, max_num_nodes=None):
X, node_mask = to_dense_batch(x=x, batch=batch, max_num_nodes=max_num_nodes)
# node_mask = node_mask.float()
edge_index, edge_attr = torch_geometric.utils.remove_self_loops(edge_index, edge_attr)
if max_num_nodes is None:
max_num_nodes = X.size(1)
E = to_dense_adj(edge_index=edge_index, batch=batch, edge_attr=edge_attr, max_num_nodes=max_num_nodes)
E = encode_no_edge(E)
return PlaceHolder(X=X, E=E, y=None), node_mask
def encode_no_edge(E):
assert len(E.shape) == 4
if E.shape[-1] == 0:
return E
no_edge = torch.sum(E, dim=3) == 0
first_elt = E[:, :, :, 0]
first_elt[no_edge] = 1
E[:, :, :, 0] = first_elt
diag = torch.eye(E.shape[1], dtype=torch.bool).unsqueeze(0).expand(E.shape[0], -1, -1)
E[diag] = 0
return E
def update_config_with_new_keys(cfg, saved_cfg):
saved_general = saved_cfg.general
saved_train = saved_cfg.train
saved_model = saved_cfg.model
saved_dataset = saved_cfg.dataset
for key, val in saved_dataset.items():
OmegaConf.set_struct(cfg.dataset, True)
with open_dict(cfg.dataset):
if key not in cfg.dataset.keys():
setattr(cfg.dataset, key, val)
for key, val in saved_general.items():
OmegaConf.set_struct(cfg.general, True)
with open_dict(cfg.general):
if key not in cfg.general.keys():
setattr(cfg.general, key, val)
OmegaConf.set_struct(cfg.train, True)
with open_dict(cfg.train):
for key, val in saved_train.items():
if key not in cfg.train.keys():
setattr(cfg.train, key, val)
OmegaConf.set_struct(cfg.model, True)
with open_dict(cfg.model):
for key, val in saved_model.items():
if key not in cfg.model.keys():
setattr(cfg.model, key, val)
return cfg
class PlaceHolder:
def __init__(self, X, E, y):
self.X = X
self.E = E
self.y = y
def type_as(self, x: torch.Tensor, categorical: bool = False):
""" Changes the device and dtype of X, E, y. """
self.X = self.X.type_as(x)
self.E = self.E.type_as(x)
if categorical:
self.y = self.y.type_as(x)
return self
def mask(self, node_mask, collapse=False):
x_mask = node_mask.unsqueeze(-1) # bs, n, 1
e_mask1 = x_mask.unsqueeze(2) # bs, n, 1, 1
e_mask2 = x_mask.unsqueeze(1) # bs, 1, n, 1
if collapse:
self.X = torch.argmax(self.X, dim=-1)
self.E = torch.argmax(self.E, dim=-1)
self.X[node_mask == 0] = - 1
self.E[(e_mask1 * e_mask2).squeeze(-1) == 0] = - 1
else:
self.X = self.X * x_mask
self.E = self.E * e_mask1 * e_mask2
assert torch.allclose(self.E, torch.transpose(self.E, 1, 2))
return self
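For orientation, a hypothetical toy example of `to_dense` and `PlaceHolder.mask`; the graphs, feature sizes, and one-hot encodings below are invented for illustration (channel 0 of `edge_attr` is reserved for the "no edge" class filled in by `encode_no_edge`).

```python
# Hypothetical toy batch; assumes this utils module is importable.
import torch
from torch_geometric.data import Data, Batch

g1 = Data(x=torch.eye(3),                                  # 3 nodes, 3 one-hot atom types
          edge_index=torch.tensor([[0, 1], [1, 0]]),
          edge_attr=torch.tensor([[0., 1.], [0., 1.]]))    # one bond type, stored in channel 1
g2 = Data(x=torch.eye(3)[:2],                              # 2 nodes
          edge_index=torch.tensor([[0, 1], [1, 0]]),
          edge_attr=torch.tensor([[0., 1.], [0., 1.]]))
batch = Batch.from_data_list([g1, g2])

dense, node_mask = to_dense(batch.x, batch.edge_index, batch.edge_attr, batch.batch)
dense = dense.mask(node_mask)                              # zero out padded nodes and edges
print(dense.X.shape, dense.E.shape, node_mask.shape)       # (2, 3, 3) (2, 3, 3, 2) (2, 3)
```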

17
requirements.txt Normal file
View File

@@ -0,0 +1,17 @@
fcd_torch==1.0.7
hydra-core==1.3.2
imageio==2.26.0
joblib==1.2.0
matplotlib==3.7.0
mini_moses==1.0
networkx==3.0
numpy==1.24.2
omegaconf==2.3.0
pandas==1.5.3
pytorch_lightning==2.0.1
rdkit==2023.9.4
scikit_learn==1.2.1
torch==2.0.0
torch_geometric==2.3.0
torchmetrics==0.11.4
tqdm==4.64.1