diff --git a/README.md b/README.md index f23a17e..c1394b5 100644 --- a/README.md +++ b/README.md @@ -14,7 +14,7 @@ This is the code for MCD: a Multi-Conditional Diffusion Model for inverse small ## Requirements All dependencies are specified in the `requirements.txt` file. -This code was developed and tested with Python 3.9.16, PyTorch 2.0.0, and PyG 2.3.0, pytorch-lightning 2.0.1. +This code was developed and tested with Python 3.9.16, PyTorch 2.0.0, and PyG 2.3.0, Pytorch-lightning 2.0.1. For molecular generation evaluation, we should first install rdkit: diff --git a/configs/__init__.py b/configs/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/configs/config.yaml b/configs/config.yaml new file mode 100644 index 0000000..33dcb95 --- /dev/null +++ b/configs/config.yaml @@ -0,0 +1,46 @@ +general: + name: 'MCD' + wandb: 'disabled' + gpus: 1 + resume: null + test_only: null + sample_every_val: 2500 + samples_to_generate: 512 + samples_to_save: 3 + chains_to_save: 1 + log_every_steps: 50 + number_chain_steps: 8 + final_model_samples_to_generate: 10000 + final_model_samples_to_save: 20 + final_model_chains_to_save: 1 + enable_progress_bar: False + save_model: False +model: + type: 'discrete' + transition: 'marginal' + model: 'MCD' + diffusion_steps: 500 + diffusion_noise_schedule: 'cosine' + guide_scale: 2 + hidden_size: 1152 + depth: 6 + num_heads: 16 + mlp_ratio: 4 + drop_condition: 0.01 + lambda_train: [1, 10] # node and edge training weight + ensure_connected: True +train: + n_epochs: 10000 + batch_size: 1200 + lr: 0.0002 + clip_grad: null + num_workers: 0 + weight_decay: 0 + seed: 0 + val_check_interval: null + check_val_every_n_epoch: 1 +dataset: + datadir: 'data/' + task_name: null + guidance_target: null + pin_memory: False diff --git a/mcd/__init__.py b/mcd/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/mcd/analysis/__init__.py b/mcd/analysis/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/mcd/analysis/rdkit_functions.py b/mcd/analysis/rdkit_functions.py new file mode 100644 index 0000000..156f6a1 --- /dev/null +++ b/mcd/analysis/rdkit_functions.py @@ -0,0 +1,411 @@ +from rdkit import Chem, RDLogger +RDLogger.DisableLog('rdApp.*') +from fcd_torch import FCD as FCDMetric +from mini_moses.metrics.metrics import FragMetric, internal_diversity +from mini_moses.metrics.utils import get_mol, mapper + +import re +import time +import random +random.seed(0) +import numpy as np +from multiprocessing import Pool + +import torch +from metrics.property_metric import calculateSAS + +bond_dict = [None, Chem.rdchem.BondType.SINGLE, Chem.rdchem.BondType.DOUBLE, Chem.rdchem.BondType.TRIPLE, + Chem.rdchem.BondType.AROMATIC] +ATOM_VALENCY = {6: 4, 7: 3, 8: 2, 9: 1, 15: 3, 16: 2, 17: 1, 35: 1, 53: 1} + +bd_dict_x = {'O2-N2': [5.00E+04, 1.00E-03]} +bd_dict_y = {'O2-N2': [5.00E+04/2.78E+04, 1.00E-03/2.43E-05]} + +selectivity = ['O2-N2'] +a_dict = {} +b_dict = {} +for prop_name in selectivity: + x1, x2 = np.log10(bd_dict_x[prop_name][0]), np.log10(bd_dict_x[prop_name][1]) + y1, y2 = np.log10(bd_dict_y[prop_name][0]), np.log10(bd_dict_y[prop_name][1]) + a = (y1-y2)/(x1-x2) + b = y1-a*x1 + a_dict[prop_name] = a + b_dict[prop_name] = b + +def selectivity_evaluation(gas1, gas2, prop_name): + x = np.log10(np.array(gas1)) + y = np.log10(np.array(gas1) / np.array(gas2)) + upper = (y - (a_dict[prop_name] * x + b_dict[prop_name])) > 0 + return upper + +class BasicMolecularMetrics(object): + def __init__(self, atom_decoder, train_smiles=None, 
stat_ref=None, task_evaluator=None, n_jobs=8, device='cpu', batch_size=512): + self.dataset_smiles_list = train_smiles + self.atom_decoder = atom_decoder + self.n_jobs = n_jobs + self.device = device + self.batch_size = batch_size + self.stat_ref = stat_ref + self.task_evaluator = task_evaluator + + def compute_relaxed_validity(self, generated, ensure_connected): + valid = [] + num_components = [] + all_smiles = [] + valid_mols = [] + covered_atoms = set() + direct_valid_count = 0 + for graph in generated: + atom_types, edge_types = graph + mol = build_molecule_with_partial_charges(atom_types, edge_types, self.atom_decoder) + direct_valid_flag = True if check_mol(mol, largest_connected_comp=True) is not None else False + if direct_valid_flag: + direct_valid_count += 1 + if not ensure_connected: + mol_conn, _ = correct_mol(mol, connection=True) + mol = mol_conn if mol_conn is not None else correct_mol(mol, connection=False)[0] + else: # ensure fully connected + mol, _ = correct_mol(mol, connection=True) + smiles = mol2smiles(mol) + mol = get_mol(smiles) + try: + mol_frags = Chem.rdmolops.GetMolFrags(mol, asMols=True, sanitizeFrags=True) + num_components.append(len(mol_frags)) + largest_mol = max(mol_frags, default=mol, key=lambda m: m.GetNumAtoms()) + smiles = mol2smiles(largest_mol) + if smiles is not None and largest_mol is not None and len(smiles) > 1 and Chem.MolFromSmiles(smiles) is not None: + valid_mols.append(largest_mol) + valid.append(smiles) + for atom in largest_mol.GetAtoms(): + covered_atoms.add(atom.GetSymbol()) + all_smiles.append(smiles) + else: + all_smiles.append(None) + except Exception as e: + # print(f"An error occurred: {e}") + all_smiles.append(None) + + return valid, len(valid) / len(generated), direct_valid_count / len(generated), np.array(num_components), all_smiles, covered_atoms + + def evaluate(self, generated, targets, ensure_connected, active_atoms=None): + """ generated: list of pairs (positions: n x 3, atom_types: n [int]) + the positions and atom types should already be masked. 
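+        Returns the unique valid SMILES, connected-component statistics (nc_min/nc_mu/nc_max),
+        all generated SMILES (None for molecules that could not be recovered), the dict of
+        distribution/property metrics, and the per-property target logs.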
""" + valid, validity, nc_validity, num_components, all_smiles, covered_atoms = self.compute_relaxed_validity(generated, ensure_connected=ensure_connected) + nc_mu = num_components.mean() if len(num_components) > 0 else 0 + nc_min = num_components.min() if len(num_components) > 0 else 0 + nc_max = num_components.max() if len(num_components) > 0 else 0 + + len_active = len(active_atoms) if active_atoms is not None else 1 + + cover_str = f"Cover {len(covered_atoms)} ({len(covered_atoms)/len_active * 100:.2f}%) atoms: {covered_atoms}" + print(f"Validity over {len(generated)} molecules: {validity * 100 :.2f}% (w/o correction: {nc_validity * 100 :.2f}%), cover {len(covered_atoms)} ({len(covered_atoms)/len_active * 100:.2f}%) atoms: {covered_atoms}") + print(f"Number of connected components of {len(generated)} molecules: min:{nc_min:.2f} mean:{nc_mu:.2f} max:{nc_max:.2f}") + + if validity > 0: + dist_metrics = {'cover_str': cover_str ,'validity': validity, 'validity_nc': nc_validity} + unique = list(set(valid)) + close_pool = False + if self.n_jobs != 1: + pool = Pool(self.n_jobs) + close_pool = True + else: + pool = 1 + valid_mols = mapper(pool)(get_mol, valid) + dist_metrics['interval_diversity'] = internal_diversity(valid_mols, pool, device=self.device) + + start_time = time.time() + if self.stat_ref is not None: + kwargs = {'n_jobs': pool, 'device': self.device, 'batch_size': self.batch_size} + kwargs_fcd = {'n_jobs': self.n_jobs, 'device': self.device, 'batch_size': self.batch_size} + try: + dist_metrics['sim/Frag'] = FragMetric(**kwargs)(gen=valid_mols, pref=self.stat_ref['Frag']) + except: + print('error: ', 'pool', pool) + print('valid_mols: ', valid_mols) + dist_metrics['dist/FCD'] = FCDMetric(**kwargs_fcd)(gen=valid, pref=self.stat_ref['FCD']) + + if self.task_evaluator is not None: + evaluation_list = list(self.task_evaluator.keys()) + evaluation_list = evaluation_list.copy() + + assert 'meta_taskname' in evaluation_list + meta_taskname = self.task_evaluator['meta_taskname'] + evaluation_list.remove('meta_taskname') + meta_split = meta_taskname.split('-') + + valid_index = np.array([True if smiles else False for smiles in all_smiles]) + targets_log = {} + for i, name in enumerate(evaluation_list): + targets_log[f'input_{name}'] = np.array([float('nan')] * len(valid_index)) + targets_log[f'input_{name}'] = targets[:, i] + + targets = targets[valid_index] + if len(meta_split) == 2: + cached_perm = {meta_split[0]: None, meta_split[1]: None} + + for i, name in enumerate(evaluation_list): + if name == 'scs': + continue + elif name == 'sas': + scores = calculateSAS(valid) + else: + scores = self.task_evaluator[name](valid) + targets_log[f'output_{name}'] = np.array([float('nan')] * len(valid_index)) + targets_log[f'output_{name}'][valid_index] = scores + if name in ['O2', 'N2', 'CO2']: + if len(meta_split) == 2: + cached_perm[name] = scores + scores, cur_targets = np.log10(scores), np.log10(targets[:, i]) + dist_metrics[f'{name}/mae'] = np.mean(np.abs(scores - cur_targets)) + elif name == 'sas': + dist_metrics[f'{name}/mae'] = np.mean(np.abs(scores - targets[:, i])) + else: + true_y = targets[:, i] + predicted_labels = (scores >= 0.5).astype(int) + acc = (predicted_labels == true_y).sum() / len(true_y) + dist_metrics[f'{name}/acc'] = acc + + if len(meta_split) == 2: + if cached_perm[meta_split[0]] is not None and cached_perm[meta_split[1]] is not None: + task_name = self.task_evaluator['meta_taskname'] + upper = selectivity_evaluation(cached_perm[meta_split[0]], cached_perm[meta_split[1]], 
task_name) + dist_metrics[f'selectivity/{task_name}'] = np.sum(upper) + + end_time = time.time() + elapsed_time = end_time - start_time + max_key_length = max(len(key) for key in dist_metrics) + print(f'Details over {len(valid)} ({len(generated)}) valid (total) molecules, calculating metrics using {elapsed_time:.2f} s:') + strs = '' + for i, (key, value) in enumerate(dist_metrics.items()): + if isinstance(value, (int, float, np.floating, np.integer)): + strs = strs + f'{key:>{max_key_length}}:{value:<7.4f}\t' + if i % 4 == 3: + strs = strs + '\n' + print(strs) + + if close_pool: + pool.close() + pool.join() + else: + unique = [] + dist_metrics = {} + targets_log = None + return unique, dict(nc_min=nc_min, nc_max=nc_max, nc_mu=nc_mu), all_smiles, dist_metrics, targets_log + +def mol2smiles(mol): + if mol is None: + return None + try: + Chem.SanitizeMol(mol) + except ValueError: + return None + return Chem.MolToSmiles(mol) + +def build_molecule_with_partial_charges(atom_types, edge_types, atom_decoder, verbose=False): + if verbose: + print("\nbuilding new molecule") + + mol = Chem.RWMol() + for atom in atom_types: + a = Chem.Atom(atom_decoder[atom.item()]) + mol.AddAtom(a) + if verbose: + print("Atom added: ", atom.item(), atom_decoder[atom.item()]) + + edge_types = torch.triu(edge_types) + all_bonds = torch.nonzero(edge_types) + + for i, bond in enumerate(all_bonds): + if bond[0].item() != bond[1].item(): + mol.AddBond(bond[0].item(), bond[1].item(), bond_dict[edge_types[bond[0], bond[1]].item()]) + if verbose: + print("bond added:", bond[0].item(), bond[1].item(), edge_types[bond[0], bond[1]].item(), + bond_dict[edge_types[bond[0], bond[1]].item()]) + # add formal charge to atom: e.g. [O+], [N+], [S+] + # not support [O-], [N-], [S-], [NH+] etc. + flag, atomid_valence = check_valency(mol) + if verbose: + print("flag, valence", flag, atomid_valence) + if flag: + continue + else: + if len(atomid_valence) == 2: + idx = atomid_valence[0] + v = atomid_valence[1] + an = mol.GetAtomWithIdx(idx).GetAtomicNum() + if verbose: + print("atomic num of atom with a large valence", an) + if an in (7, 8, 16) and (v - ATOM_VALENCY[an]) == 1: + mol.GetAtomWithIdx(idx).SetFormalCharge(1) + # print("Formal charge added") + else: + continue + return mol + +def check_valency(mol): + try: + Chem.SanitizeMol(mol, sanitizeOps=Chem.SanitizeFlags.SANITIZE_PROPERTIES) + return True, None + except ValueError as e: + e = str(e) + p = e.find('#') + e_sub = e[p:] + atomid_valence = list(map(int, re.findall(r'\d+', e_sub))) + return False, atomid_valence + + +def correct_mol(mol, connection=False): + ##### + no_correct = False + flag, _ = check_valency(mol) + if flag: + no_correct = True + + while True: + if connection: + mol_conn = connect_fragments(mol) + # if mol_conn is not None: + mol = mol_conn + if mol is None: + return None, no_correct + flag, atomid_valence = check_valency(mol) + if flag: + break + else: + try: + assert len(atomid_valence) == 2 + idx = atomid_valence[0] + v = atomid_valence[1] + queue = [] + check_idx = 0 + for b in mol.GetAtomWithIdx(idx).GetBonds(): + type = int(b.GetBondType()) + queue.append((b.GetIdx(), type, b.GetBeginAtomIdx(), b.GetEndAtomIdx())) + if type == 12: + check_idx += 1 + queue.sort(key=lambda tup: tup[1], reverse=True) + + if queue[-1][1] == 12: + return None, no_correct + elif len(queue) > 0: + start = queue[check_idx][2] + end = queue[check_idx][3] + t = queue[check_idx][1] - 1 + mol.RemoveBond(start, end) + if t >= 1: + mol.AddBond(start, end, bond_dict[t]) + except Exception 
as e: + # print(f"An error occurred in correction: {e}") + return None, no_correct + return mol, no_correct + + +def check_mol(m, largest_connected_comp=True): + if m is None: + return None + sm = Chem.MolToSmiles(m, isomericSmiles=True) + if largest_connected_comp and '.' in sm: + vsm = [(s, len(s)) for s in sm.split('.')] # 'C.CC.CCc1ccc(N)cc1CCC=O'.split('.') + vsm.sort(key=lambda tup: tup[1], reverse=True) + mol = Chem.MolFromSmiles(vsm[0][0]) + else: + mol = Chem.MolFromSmiles(sm) + return mol + + +##### connect fragements +def select_atom_with_available_valency(frag): + atoms = list(frag.GetAtoms()) + random.shuffle(atoms) + for atom in atoms: + if atom.GetAtomicNum() > 1 and atom.GetImplicitValence() > 0: + return atom + + return None + +def select_atoms_with_available_valency(frag): + return [atom for atom in frag.GetAtoms() if atom.GetAtomicNum() > 1 and atom.GetImplicitValence() > 0] + +def try_to_connect_fragments(combined_mol, frag, atom1, atom2): + # Make copies of the molecules to try the connection + trial_combined_mol = Chem.RWMol(combined_mol) + trial_frag = Chem.RWMol(frag) + + # Add the new fragment to the combined molecule with new indices + new_indices = {atom.GetIdx(): trial_combined_mol.AddAtom(atom) for atom in trial_frag.GetAtoms()} + + # Add the bond between the suitable atoms from each fragment + trial_combined_mol.AddBond(atom1.GetIdx(), new_indices[atom2.GetIdx()], Chem.BondType.SINGLE) + + # Adjust the hydrogen count of the connected atoms + for atom_idx in [atom1.GetIdx(), new_indices[atom2.GetIdx()]]: + atom = trial_combined_mol.GetAtomWithIdx(atom_idx) + num_h = atom.GetTotalNumHs() + atom.SetNumExplicitHs(max(0, num_h - 1)) + + # Add bonds for the new fragment + for bond in trial_frag.GetBonds(): + trial_combined_mol.AddBond(new_indices[bond.GetBeginAtomIdx()], new_indices[bond.GetEndAtomIdx()], bond.GetBondType()) + + # Convert to a Mol object and try to sanitize it + new_mol = Chem.Mol(trial_combined_mol) + try: + Chem.SanitizeMol(new_mol) + return new_mol # Return the new valid molecule + except Chem.MolSanitizeException: + return None # If the molecule is not valid, return None + +def connect_fragments(mol): + # Get the separate fragments + frags = Chem.GetMolFrags(mol, asMols=True, sanitizeFrags=False) + if len(frags) < 2: + return mol + + combined_mol = Chem.RWMol(frags[0]) + + for frag in frags[1:]: + # Select all atoms with available valency from both molecules + atoms1 = select_atoms_with_available_valency(combined_mol) + atoms2 = select_atoms_with_available_valency(frag) + + # Try to connect using all combinations of available valency atoms + for atom1 in atoms1: + for atom2 in atoms2: + new_mol = try_to_connect_fragments(combined_mol, frag, atom1, atom2) + if new_mol is not None: + # If a valid connection is made, update the combined molecule and break + combined_mol = new_mol + break + else: + # Continue if the inner loop didn't break (no valid connection found for atom1) + continue + # Break if the inner loop did break (valid connection found) + break + else: + # If no valid connections could be made with any of the atoms, return None + return None + + return combined_mol + +#### connect fragements + +def compute_molecular_metrics(molecule_list, targets, train_smiles, stat_ref, dataset_info, task_evaluator, comput_config): + """ molecule_list: (dict) """ + + atom_decoder = dataset_info.atom_decoder + active_atoms = dataset_info.active_atoms + ensure_connected = dataset_info.ensure_connected + metrics = BasicMolecularMetrics(atom_decoder, 
train_smiles, stat_ref, task_evaluator, **comput_config) + evaluated_res = metrics.evaluate(molecule_list, targets, ensure_connected, active_atoms) + all_smiles = evaluated_res[-3] + all_metrics = evaluated_res[-2] + targets_log = evaluated_res[-1] + unique_smiles = evaluated_res[0] + + return unique_smiles, all_smiles, all_metrics, targets_log + +if __name__ == '__main__': + smiles_mol = 'C1CCC1' + print("Smiles mol %s" % smiles_mol) + chem_mol = Chem.MolFromSmiles(smiles_mol) + print(block_mol) diff --git a/mcd/analysis/visualization.py b/mcd/analysis/visualization.py new file mode 100644 index 0000000..913961c --- /dev/null +++ b/mcd/analysis/visualization.py @@ -0,0 +1,222 @@ +import os + +from rdkit import Chem +from rdkit.Chem import Draw, AllChem +from rdkit.Geometry import Point3D +from rdkit import RDLogger +import imageio +import networkx as nx +import numpy as np +import rdkit.Chem +import matplotlib.pyplot as plt + + +class MolecularVisualization: + def __init__(self, dataset_infos): + self.dataset_infos = dataset_infos + + def mol_from_graphs(self, node_list, adjacency_matrix): + """ + Convert graphs to rdkit molecules + node_list: the nodes of a batch of nodes (bs x n) + adjacency_matrix: the adjacency_matrix of the molecule (bs x n x n) + """ + # dictionary to map integer value to the char of atom + atom_decoder = self.dataset_infos.atom_decoder + # [list(self.dataset_infos.atom_decoder.keys())[0]] + + # create empty editable mol object + mol = Chem.RWMol() + + # add atoms to mol and keep track of index + node_to_idx = {} + for i in range(len(node_list)): + if node_list[i] == -1: + continue + a = Chem.Atom(atom_decoder[int(node_list[i])]) + molIdx = mol.AddAtom(a) + node_to_idx[i] = molIdx + + for ix, row in enumerate(adjacency_matrix): + for iy, bond in enumerate(row): + # only traverse half the symmetric matrix + if iy <= ix: + continue + if bond == 1: + bond_type = Chem.rdchem.BondType.SINGLE + elif bond == 2: + bond_type = Chem.rdchem.BondType.DOUBLE + elif bond == 3: + bond_type = Chem.rdchem.BondType.TRIPLE + elif bond == 4: + bond_type = Chem.rdchem.BondType.AROMATIC + else: + continue + mol.AddBond(node_to_idx[ix], node_to_idx[iy], bond_type) + + try: + mol = mol.GetMol() + except rdkit.Chem.KekulizeException: + print("Can't kekulize molecule") + mol = None + return mol + + def visualize(self, path: str, molecules: list, num_molecules_to_visualize: int, log='graph'): + # define path to save figures + if not os.path.exists(path): + os.makedirs(path) + + # visualize the final molecules + print(f"Visualizing {num_molecules_to_visualize} of {len(molecules)}") + if num_molecules_to_visualize > len(molecules): + print(f"Shortening to {len(molecules)}") + num_molecules_to_visualize = len(molecules) + + for i in range(num_molecules_to_visualize): + file_path = os.path.join(path, 'molecule_{}.png'.format(i)) + mol = self.mol_from_graphs(molecules[i][0].numpy(), molecules[i][1].numpy()) + try: + Draw.MolToFile(mol, file_path) + except rdkit.Chem.KekulizeException: + print("Can't kekulize molecule") + + def visualize_chain(self, path, nodes_list, adjacency_matrix, trainer=None): + RDLogger.DisableLog('rdApp.*') + # convert graphs to the rdkit molecules + mols = [self.mol_from_graphs(nodes_list[i], adjacency_matrix[i]) for i in range(nodes_list.shape[0])] + + # find the coordinates of atoms in the final molecule + final_molecule = mols[-1] + AllChem.Compute2DCoords(final_molecule) + + coords = [] + for i, atom in enumerate(final_molecule.GetAtoms()): + positions = 
final_molecule.GetConformer().GetAtomPosition(i) + coords.append((positions.x, positions.y, positions.z)) + + # align all the molecules + for i, mol in enumerate(mols): + AllChem.Compute2DCoords(mol) + conf = mol.GetConformer() + for j, atom in enumerate(mol.GetAtoms()): + x, y, z = coords[j] + conf.SetAtomPosition(j, Point3D(x, y, z)) + + # draw gif + save_paths = [] + num_frams = nodes_list.shape[0] + + for frame in range(num_frams): + file_name = os.path.join(path, 'fram_{}.png'.format(frame)) + Draw.MolToFile(mols[frame], file_name, size=(300, 300), legend=f"Frame {frame}") + save_paths.append(file_name) + + imgs = [imageio.imread(fn) for fn in save_paths] + gif_path = os.path.join(os.path.dirname(path), '{}.gif'.format(path.split('/')[-1])) + imgs.extend([imgs[-1]] * 10) + imageio.mimsave(gif_path, imgs, subrectangles=True, fps=5) + + # draw grid image + try: + img = Draw.MolsToGridImage(mols, molsPerRow=10, subImgSize=(200, 200)) + img.save(os.path.join(path, '{}_grid_image.png'.format(path.split('/')[-1]))) + except Chem.rdchem.KekulizeException: + print("Can't kekulize molecule") + return mols + + def visualize_by_smiles(self, path: str, smiles_list: list, num_to_visualize: int): + os.makedirs(path, exist_ok=True) + + print(f"Visualizing corrected {num_to_visualize} of {len(smiles_list)}") + if num_to_visualize > len(smiles_list): + print(f"Shortening to {len(smiles_list)}") + num_to_visualize = len(smiles_list) + + for i in range(num_to_visualize): + file_path = os.path.join(path, 'molecule_corrected_{}.png'.format(i)) + if smiles_list[i] is None: + continue + mol = Chem.MolFromSmiles(smiles_list[i]) + try: + Draw.MolToFile(mol, file_path) + except rdkit.Chem.KekulizeException: + print("Can't kekulize molecule") + +class NonMolecularVisualization: + def to_networkx(self, node_list, adjacency_matrix): + """ + Convert graphs to networkx graphs + node_list: the nodes of a batch of nodes (bs x n) + adjacency_matrix: the adjacency_matrix of the molecule (bs x n x n) + """ + graph = nx.Graph() + + for i in range(len(node_list)): + if node_list[i] == -1: + continue + graph.add_node(i, number=i, symbol=node_list[i], color_val=node_list[i]) + + rows, cols = np.where(adjacency_matrix >= 1) + edges = zip(rows.tolist(), cols.tolist()) + for edge in edges: + edge_type = adjacency_matrix[edge[0]][edge[1]] + graph.add_edge(edge[0], edge[1], color=float(edge_type), weight=3 * edge_type) + + return graph + + def visualize_non_molecule(self, graph, pos, path, iterations=100, node_size=100, largest_component=False): + if largest_component: + CGs = [graph.subgraph(c) for c in nx.connected_components(graph)] + CGs = sorted(CGs, key=lambda x: x.number_of_nodes(), reverse=True) + graph = CGs[0] + + # Plot the graph structure with colors + if pos is None: + pos = nx.spring_layout(graph, iterations=iterations) + + # Set node colors based on the eigenvectors + w, U = np.linalg.eigh(nx.normalized_laplacian_matrix(graph).toarray()) + vmin, vmax = np.min(U[:, 1]), np.max(U[:, 1]) + m = max(np.abs(vmin), vmax) + vmin, vmax = -m, m + + plt.figure() + nx.draw(graph, pos, font_size=5, node_size=node_size, with_labels=False, node_color=U[:, 1], + cmap=plt.cm.coolwarm, vmin=vmin, vmax=vmax, edge_color='grey') + + plt.tight_layout() + plt.savefig(path) + plt.close("all") + + def visualize(self, path: str, graphs: list, num_graphs_to_visualize: int, log='graph'): + # define path to save figures + if not os.path.exists(path): + os.makedirs(path) + + # visualize the final molecules + for i in 
range(num_graphs_to_visualize): + file_path = os.path.join(path, 'graph_{}.png'.format(i)) + graph = self.to_networkx(graphs[i][0].numpy(), graphs[i][1].numpy()) + self.visualize_non_molecule(graph=graph, pos=None, path=file_path) + im = plt.imread(file_path) + + def visualize_chain(self, path, nodes_list, adjacency_matrix): + # convert graphs to networkx + graphs = [self.to_networkx(nodes_list[i], adjacency_matrix[i]) for i in range(nodes_list.shape[0])] + # find the coordinates of atoms in the final molecule + final_graph = graphs[-1] + final_pos = nx.spring_layout(final_graph, seed=0) + + # draw gif + save_paths = [] + num_frams = nodes_list.shape[0] + + for frame in range(num_frams): + file_name = os.path.join(path, 'fram_{}.png'.format(frame)) + self.visualize_non_molecule(graph=graphs[frame], pos=final_pos, path=file_name) + save_paths.append(file_name) + + imgs = [imageio.imread(fn) for fn in save_paths] + gif_path = os.path.join(os.path.dirname(path), '{}.gif'.format(path.split('/')[-1])) + imgs.extend([imgs[-1]] * 10) + imageio.mimsave(gif_path, imgs, subrectangles=True, fps=5) diff --git a/mcd/datasets/__init__.py b/mcd/datasets/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/mcd/datasets/abstract_dataset.py b/mcd/datasets/abstract_dataset.py new file mode 100644 index 0000000..0d0d9f9 --- /dev/null +++ b/mcd/datasets/abstract_dataset.py @@ -0,0 +1,126 @@ +from diffusion.distributions import DistributionNodes +import utils as utils +import torch +import pytorch_lightning as pl +from torch_geometric.loader import DataLoader + + +class AbstractDataModule(pl.LightningDataModule): + def __init__(self, cfg): + super().__init__() + self.cfg = cfg + self.dataloaders = None + self.input_dims = None + self.output_dims = None + + def prepare_data(self, datasets) -> None: + batch_size = self.cfg.train.batch_size + num_workers = self.cfg.train.num_workers + self.dataloaders = {split: DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, + shuffle='debug' not in self.cfg.general.name) + for split, dataset in datasets.items()} + + def train_dataloader(self): + return self.dataloaders["train"] + + def val_dataloader(self): + return self.dataloaders["val"] + + def test_dataloader(self): + return self.dataloaders["test"] + + def __getitem__(self, idx): + return self.dataloaders['train'][idx] + + def node_counts(self, max_nodes_possible=300): + all_counts = torch.zeros(max_nodes_possible) + for split in ['train', 'val', 'test']: + for i, data in enumerate(self.dataloaders[split]): + unique, counts = torch.unique(data.batch, return_counts=True) + for count in counts: + all_counts[count] += 1 + max_index = max(all_counts.nonzero()) + all_counts = all_counts[:max_index + 1] + all_counts = all_counts / all_counts.sum() + return all_counts + + def node_types(self): + num_classes = None + for data in self.dataloaders['train']: + num_classes = data.x.shape[1] + break + + counts = torch.zeros(num_classes) + + for split in ['train', 'val', 'test']: + for i, data in enumerate(self.dataloaders[split]): + counts += data.x.sum(dim=0) + + counts = counts / counts.sum() + return counts + + def edge_counts(self): + num_classes = None + for data in self.dataloaders['train']: + num_classes = 5 + break + + d = torch.Tensor(num_classes) + + for split in ['train', 'val', 'test']: + for i, data in enumerate(self.dataloaders[split]): + unique, counts = torch.unique(data.batch, return_counts=True) + + all_pairs = 0 + for count in counts: + all_pairs += count * (count - 1) + + num_edges 
= data.edge_index.shape[1] + num_non_edges = all_pairs - num_edges + + data_edge_attr = torch.nn.functional.one_hot(data.edge_attr, num_classes=5).float() + edge_types = data_edge_attr.sum(dim=0) + assert num_non_edges >= 0 + d[0] += num_non_edges + d[1:] += edge_types[1:] + + d = d / d.sum() + return d + + +class MolecularDataModule(AbstractDataModule): + def valency_count(self, max_n_nodes): + valencies = torch.zeros(3 * max_n_nodes - 2) # Max valency possible if everything is connected + multiplier = torch.Tensor([0, 1, 2, 3, 1.5]) + for split in ['train', 'val', 'test']: + for i, data in enumerate(self.dataloaders[split]): + n = data.x.shape[0] + for atom in range(n): + data_edge_attr = torch.nn.functional.one_hot(data.edge_attr, num_classes=5).float() + edges = data_edge_attr[data.edge_index[0] == atom] + edges_total = edges.sum(dim=0) + valency = (edges_total * multiplier).sum() + valencies[valency.long().item()] += 1 + valencies = valencies / valencies.sum() + return valencies + + +class AbstractDatasetInfos: + def complete_infos(self, n_nodes, node_types): + self.input_dims = None + self.output_dims = None + self.num_classes = len(node_types) + self.max_n_nodes = len(n_nodes) - 1 + self.nodes_dist = DistributionNodes(n_nodes) + + def compute_input_output_dims(self, datamodule): + example_batch = datamodule.example_batch() + example_batch_x = torch.nn.functional.one_hot(example_batch.x, num_classes=118).float()[:, self.active_index] + example_batch_edge_attr = torch.nn.functional.one_hot(example_batch.edge_attr, num_classes=5).float() + + self.input_dims = {'X': example_batch_x.size(1), + 'E': example_batch_edge_attr.size(1), + 'y': example_batch['y'].size(1)} + self.output_dims = {'X': example_batch_x.size(1), + 'E': example_batch_edge_attr.size(1), + 'y': example_batch['y'].size(1)} \ No newline at end of file diff --git a/mcd/datasets/dataset.py b/mcd/datasets/dataset.py new file mode 100644 index 0000000..a30d139 --- /dev/null +++ b/mcd/datasets/dataset.py @@ -0,0 +1,381 @@ + +import sys +sys.path.append('../') + +import os +import os.path as osp +import pathlib +import json + +import torch +import torch.nn.functional as F +from rdkit import Chem, RDLogger +from rdkit.Chem.rdchem import BondType as BT +from tqdm import tqdm +import numpy as np +import pandas as pd +from torch_geometric.data import Data, InMemoryDataset +from torch_geometric.loader import DataLoader +from sklearn.model_selection import train_test_split + +import utils as utils +from datasets.abstract_dataset import AbstractDatasetInfos, AbstractDataModule +from diffusion.distributions import DistributionNodes + +bonds = {BT.SINGLE: 1, BT.DOUBLE: 2, BT.TRIPLE: 3, BT.AROMATIC: 4} + +class DataModule(AbstractDataModule): + def __init__(self, cfg): + self.datadir = cfg.dataset.datadir + self.task = cfg.dataset.task_name + super().__init__(cfg) + + def prepare_data(self) -> None: + target = getattr(self.cfg.dataset, 'guidance_target', None) + base_path = pathlib.Path(os.path.realpath(__file__)).parents[2] + root_path = os.path.join(base_path, self.datadir) + self.root_path = root_path + + batch_size = self.cfg.train.batch_size + num_workers = self.cfg.train.num_workers + pin_memory = self.cfg.dataset.pin_memory + + dataset = Dataset(source=self.task, root=root_path, target_prop=target, transform=None) + + if len(self.task.split('-')) == 2: + train_index, val_index, test_index, unlabeled_index = self.fixed_split(dataset) + else: + train_index, val_index, test_index, unlabeled_index = self.random_data_split(dataset) + + 
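+        # Keep the raw split indices; unlabeled molecules (NaN targets, if any) are appended
+        # to the training split below so they are used for training without property labels.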
self.train_index, self.val_index, self.test_index, self.unlabeled_index = train_index, val_index, test_index, unlabeled_index + train_index, val_index, test_index, unlabeled_index = torch.LongTensor(train_index), torch.LongTensor(val_index), torch.LongTensor(test_index), torch.LongTensor(unlabeled_index) + if len(unlabeled_index) > 0: + train_index = torch.cat([train_index, unlabeled_index], dim=0) + + train_dataset, val_dataset, test_dataset = dataset[train_index], dataset[val_index], dataset[test_index] + self.train_dataset = train_dataset + self.train_loader = DataLoader(train_dataset, batch_size=batch_size, num_workers=num_workers, shuffle=True, pin_memory=pin_memory) + self.val_loader = DataLoader(val_dataset, batch_size=batch_size, num_workers=num_workers, shuffle=False, pin_memory=False) + self.test_loader = DataLoader(test_dataset, batch_size=batch_size, num_workers=num_workers, shuffle=False, pin_memory=False) + + training_iterations = len(train_dataset) // batch_size + self.training_iterations = training_iterations + + def random_data_split(self, dataset): + nan_count = torch.isnan(dataset.y[:, 0]).sum().item() + labeled_len = len(dataset) - nan_count + full_idx = list(range(labeled_len)) + train_ratio, valid_ratio, test_ratio = 0.6, 0.2, 0.2 + train_index, test_index, _, _ = train_test_split(full_idx, full_idx, test_size=test_ratio, random_state=42) + train_index, val_index, _, _ = train_test_split(train_index, train_index, test_size=valid_ratio/(valid_ratio+train_ratio), random_state=42) + unlabeled_index = list(range(labeled_len, len(dataset))) + print(self.task, ' dataset len', len(dataset), 'train len', len(train_index), 'val len', len(val_index), 'test len', len(test_index), 'unlabeled len', len(unlabeled_index)) + return train_index, val_index, test_index, unlabeled_index + + def fixed_split(self, dataset): + if self.task == 'O2-N2': + test_index = [42,43,92,122,197,198,251,254,257,355,511,512,549,602,603,604] + else: + raise ValueError('Invalid task name: {}'.format(self.task)) + full_idx = list(range(len(dataset))) + full_idx = list(set(full_idx) - set(test_index)) + train_ratio = 0.8 + train_index, val_index, _, _ = train_test_split(full_idx, full_idx, test_size=1-train_ratio, random_state=42) + print(self.task, ' dataset len', len(dataset), 'train len', len(train_index), 'val len', len(val_index), 'test len', len(test_index)) + return train_index, val_index, test_index, [] + + def get_train_smiles(self): + filename = f'{self.task}.csv.gz' + df = pd.read_csv(f'{self.root_path}/raw/{filename}') + df_test = df.iloc[self.test_index] + df = df.iloc[self.train_index] + smiles_list = df['smiles'].tolist() + smiles_list_test = df_test['smiles'].tolist() + smiles_list = [Chem.MolToSmiles(Chem.MolFromSmiles(smi)) for smi in smiles_list] + smiles_list_test = [Chem.MolToSmiles(Chem.MolFromSmiles(smi)) for smi in smiles_list_test] + return smiles_list, smiles_list_test + + def get_data_split(self): + filename = f'{self.task}.csv.gz' + df = pd.read_csv(f'{self.root_path}/raw/{filename}') + df_val = df.iloc[self.val_index] + df_test = df.iloc[self.test_index] + df_train = df.iloc[self.train_index] + return df_train, df_val, df_test + + def example_batch(self): + return next(iter(self.val_loader)) + + def train_dataloader(self): + return self.train_loader + + def val_dataloader(self): + return self.val_loader + + def test_dataloader(self): + return self.test_loader + + +class Dataset(InMemoryDataset): + def __init__(self, source, root, target_prop=None, + transform=None, 
pre_transform=None, pre_filter=None): + self.target_prop = target_prop + self.source = source + super().__init__(root, transform, pre_transform, pre_filter) + self.data, self.slices = torch.load(self.processed_paths[0]) + + @property + def raw_file_names(self): + return [f'{self.source}.csv.gz'] + + @property + def processed_file_names(self): + return [f'{self.source}.pt'] + + def process(self): + RDLogger.DisableLog('rdApp.*') + data_path = osp.join(self.raw_dir, self.raw_file_names[0]) + data_df = pd.read_csv(data_path) + + def mol_to_graph(mol, sa, sc, target, target2=None, target3=None, valid_atoms=None): + type_idx = [] + heavy_atom_indices, active_atoms = [], [] + for atom in mol.GetAtoms(): + if atom.GetAtomicNum() != 1: + type_idx.append(119-2) if atom.GetSymbol() == '*' else type_idx.append(atom.GetAtomicNum()-2) + heavy_atom_indices.append(atom.GetIdx()) + active_atoms.append(atom.GetSymbol()) + if valid_atoms is not None: + if not atom.GetSymbol() in valid_atoms: + return None, None + x = torch.LongTensor(type_idx) + + edges_list = [] + edge_type = [] + for bond in mol.GetBonds(): + start, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx() + if start in heavy_atom_indices and end in heavy_atom_indices: + start_new, end_new = heavy_atom_indices.index(start), heavy_atom_indices.index(end) + edges_list.append((start_new, end_new)) + edge_type.append(bonds[bond.GetBondType()]) + edges_list.append((end_new, start_new)) + edge_type.append(bonds[bond.GetBondType()]) + edge_index = torch.tensor(edges_list, dtype=torch.long).t() + edge_type = torch.tensor(edge_type, dtype=torch.long) + edge_attr = edge_type + + if target3 is not None: + y = torch.tensor([sa, sc, target, target2, target3], dtype=torch.float).view(1,-1) + elif target2 is not None: + y = torch.tensor([sa, sc, target, target2], dtype=torch.float).view(1,-1) + else: + y = torch.tensor([sa, sc, target], dtype=torch.float).view(1,-1) + + data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=y, idx=i) + if self.pre_transform is not None: + data = self.pre_transform(data) + return data, active_atoms + + # Loop through every row in the DataFrame and apply the function + data_list = [] + len_data = len(data_df) + with tqdm(total=len_data) as pbar: + # --- data processing start --- + active_atoms = set() + for i, (sms, df_row) in enumerate(data_df.iterrows()): + if i == sms: + sms = df_row['smiles'] + mol = Chem.MolFromSmiles(sms, sanitize=False) + if len(self.target_prop.split('-')) == 2: + target1, target2 = self.target_prop.split('-') + data, cur_active_atoms = mol_to_graph(mol, df_row['SA'], df_row['SC'], df_row[target1], target2=df_row[target2]) + elif len(self.target_prop.split('-')) == 3: + target1, target2, target3 = self.target_prop.split('-') + data, cur_active_atoms = mol_to_graph(mol, df_row['SA'], df_row['SC'], df_row[target1], target2=df_row[target2], target3=df_row[target3]) + else: + data, cur_active_atoms = mol_to_graph(mol, df_row['SA'], df_row['SC'], df_row[self.target_prop]) + active_atoms.update(cur_active_atoms) + data_list.append(data) + pbar.update(1) + + torch.save(self.collate(data_list), self.processed_paths[0]) + + +class DataInfos(AbstractDatasetInfos): + def __init__(self, datamodule, cfg): + tasktype_dict = { + 'hiv_b': 'classification', + 'bace_b': 'classification', + 'bbbp_b': 'classification', + 'O2': 'regression', + 'N2': 'regression', + 'CO2': 'regression', + } + task_name = cfg.dataset.task_name + self.task = task_name + self.task_type = tasktype_dict.get(task_name, "regression") + 
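+        # Task names not listed in tasktype_dict default to regression.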
self.ensure_connected = cfg.model.ensure_connected + + datadir = cfg.dataset.datadir + + base_path = pathlib.Path(os.path.realpath(__file__)).parents[2] + meta_filename = os.path.join(base_path, datadir, 'raw', f'{task_name}.meta.json') + data_root = os.path.join(base_path, datadir, 'raw') + if os.path.exists(meta_filename): + with open(meta_filename, 'r') as f: + meta_dict = json.load(f) + else: + meta_dict = compute_meta(data_root, task_name, datamodule.train_index, datamodule.test_index) + + self.base_path = base_path + self.active_atoms = meta_dict['active_atoms'] + self.max_n_nodes = meta_dict['max_node'] + self.original_max_n_nodes = meta_dict['max_node'] + self.n_nodes = torch.Tensor(meta_dict['n_atoms_per_mol_dist']) + self.edge_types = torch.Tensor(meta_dict['bond_type_dist']) + self.transition_E = torch.Tensor(meta_dict['transition_E']) + + self.atom_decoder = meta_dict['active_atoms'] + node_types = torch.Tensor(meta_dict['atom_type_dist']) + active_index = (node_types > 0).nonzero().squeeze() + self.node_types = torch.Tensor(meta_dict['atom_type_dist'])[active_index] + self.nodes_dist = DistributionNodes(self.n_nodes) + self.active_index = active_index + + val_len = 3 * self.original_max_n_nodes - 2 + meta_val = torch.Tensor(meta_dict['valencies']) + self.valency_distribution = torch.zeros(val_len) + val_len = min(val_len, len(meta_val)) + self.valency_distribution[:val_len] = meta_val[:val_len] + self.y_prior = None + self.train_ymin = [] + self.train_ymax = [] + + +def compute_meta(root, source_name, train_index, test_index): + pt = Chem.GetPeriodicTable() + atom_name_list = [] + atom_count_list = [] + for i in range(2, 119): + atom_name_list.append(pt.GetElementSymbol(i)) + atom_count_list.append(0) + atom_name_list.append('*') + atom_count_list.append(0) + n_atoms_per_mol = [0] * 500 + bond_count_list = [0, 0, 0, 0, 0] + bond_type_to_index = {BT.SINGLE: 1, BT.DOUBLE: 2, BT.TRIPLE: 3, BT.AROMATIC: 4} + valencies = [0] * 500 + tansition_E = np.zeros((118, 118, 5)) + + filename = f'{source_name}.csv.gz' + df = pd.read_csv(f'{root}/{filename}') + all_index = list(range(len(df))) + non_test_index = list(set(all_index) - set(test_index)) + df = df.iloc[non_test_index] + tot_smiles = df['smiles'].tolist() + + n_atom_list = [] + n_bond_list = [] + for i, sms in enumerate(tot_smiles): + try: + mol = Chem.MolFromSmiles(sms) + except: + continue + + n_atom = mol.GetNumHeavyAtoms() + n_bond = mol.GetNumBonds() + n_atom_list.append(n_atom) + n_bond_list.append(n_bond) + + n_atoms_per_mol[n_atom] += 1 + cur_atom_count_arr = np.zeros(118) + for atom in mol.GetAtoms(): + symbol = atom.GetSymbol() + if symbol == 'H': + continue + elif symbol == '*': + atom_count_list[-1] += 1 + cur_atom_count_arr[-1] += 1 + else: + atom_count_list[atom.GetAtomicNum()-2] += 1 + cur_atom_count_arr[atom.GetAtomicNum()-2] += 1 + try: + valencies[int(atom.GetExplicitValence())] += 1 + except: + print('src', source_name,'int(atom.GetExplicitValence())', int(atom.GetExplicitValence())) + + tansition_E_temp = np.zeros((118, 118, 5)) + for bond in mol.GetBonds(): + start_atom, end_atom = bond.GetBeginAtom(), bond.GetEndAtom() + if start_atom.GetSymbol() == 'H' or end_atom.GetSymbol() == 'H': + continue + + if start_atom.GetSymbol() == '*': + start_index = 117 + else: + start_index = start_atom.GetAtomicNum() - 2 + if end_atom.GetSymbol() == '*': + end_index = 117 + else: + end_index = end_atom.GetAtomicNum() - 2 + + bond_type = bond.GetBondType() + bond_index = bond_type_to_index[bond_type] + 
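+            # Count each bond in both directions so the totals match the ordered-pair
+            # bookkeeping used for the no-edge counts below.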
bond_count_list[bond_index] += 2 + + tansition_E[start_index, end_index, bond_index] += 2 + tansition_E[end_index, start_index, bond_index] += 2 + tansition_E_temp[start_index, end_index, bond_index] += 2 + tansition_E_temp[end_index, start_index, bond_index] += 2 + + bond_count_list[0] += n_atom * (n_atom - 1) - n_bond * 2 + cur_tot_bond = cur_atom_count_arr.reshape(-1,1) * cur_atom_count_arr.reshape(1,-1) * 2 # 118 * 118 + cur_tot_bond = cur_tot_bond - np.diag(cur_atom_count_arr) * 2 # 118 * 118 + tansition_E[:, :, 0] += cur_tot_bond - tansition_E_temp.sum(axis=-1) + assert (cur_tot_bond > tansition_E_temp.sum(axis=-1)).sum() >= 0, f'i:{i}, sms:{sms}' + + n_atoms_per_mol = np.array(n_atoms_per_mol) / np.sum(n_atoms_per_mol) + n_atoms_per_mol = n_atoms_per_mol.tolist()[:51] + + atom_count_list = np.array(atom_count_list) / np.sum(atom_count_list) + print('processed meta info: ------', filename, '------') + print('len atom_count_list', len(atom_count_list)) + print('len atom_name_list', len(atom_name_list)) + active_atoms = np.array(atom_name_list)[atom_count_list > 0] + active_atoms = active_atoms.tolist() + atom_count_list = atom_count_list.tolist() + + bond_count_list = np.array(bond_count_list) / np.sum(bond_count_list) + bond_count_list = bond_count_list.tolist() + valencies = np.array(valencies) / np.sum(valencies) + valencies = valencies.tolist() + + no_edge = np.sum(tansition_E, axis=-1) == 0 + first_elt = tansition_E[:, :, 0] + first_elt[no_edge] = 1 + tansition_E[:, :, 0] = first_elt + + tansition_E = tansition_E / np.sum(tansition_E, axis=-1, keepdims=True) + + meta_dict = { + 'source': source_name, + 'num_graph': len(n_atom_list), + 'n_atoms_per_mol_dist': n_atoms_per_mol, + 'max_node': max(n_atom_list), + 'max_bond': max(n_bond_list), + 'atom_type_dist': atom_count_list, + 'bond_type_dist': bond_count_list, + 'valencies': valencies, + 'active_atoms': active_atoms, + 'num_atom_type': len(active_atoms), + 'transition_E': tansition_E.tolist(), + } + + with open(f'{root}/{source_name}.meta.json', "w") as f: + json.dump(meta_dict, f) + + return meta_dict + + +if __name__ == "__main__": + pass \ No newline at end of file diff --git a/mcd/diffusion/__init__.py b/mcd/diffusion/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/mcd/diffusion/diffusion_utils.py b/mcd/diffusion/diffusion_utils.py new file mode 100644 index 0000000..ca1f40a --- /dev/null +++ b/mcd/diffusion/diffusion_utils.py @@ -0,0 +1,224 @@ +import torch +from torch.nn import functional as F +import numpy as np +from utils import PlaceHolder + + +def sum_except_batch(x): + return x.reshape(x.size(0), -1).sum(dim=-1) + +def assert_correctly_masked(variable, node_mask): + assert (variable * (1 - node_mask.long())).abs().max().item() < 1e-4, \ + 'Variables not masked properly.' + +def cosine_beta_schedule_discrete(timesteps, s=0.008): + """ Cosine schedule as proposed in https://openreview.net/forum?id=-NEXDKk8gZ. """ + steps = timesteps + 2 + x = np.linspace(0, steps, steps) + + alphas_cumprod = np.cos(0.5 * np.pi * ((x / steps) + s) / (1 + s)) ** 2 + alphas_cumprod = alphas_cumprod / alphas_cumprod[0] + alphas = (alphas_cumprod[1:] / alphas_cumprod[:-1]) + betas = 1 - alphas + return betas.squeeze() + +def custom_beta_schedule_discrete(timesteps, average_num_nodes=30, s=0.008): + """ Cosine schedule as proposed in https://openreview.net/forum?id=-NEXDKk8gZ. 
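+    The early betas are additionally clamped from below to beta_first = updates_per_graph / (p * num_edges),
+    so even the first diffusion steps perform roughly one edge update per graph on average.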
""" + steps = timesteps + 2 + x = np.linspace(0, steps, steps) + + alphas_cumprod = np.cos(0.5 * np.pi * ((x / steps) + s) / (1 + s)) ** 2 + alphas_cumprod = alphas_cumprod / alphas_cumprod[0] + alphas = (alphas_cumprod[1:] / alphas_cumprod[:-1]) + betas = 1 - alphas + + assert timesteps >= 100 + + p = 4 / 5 # 1 - 1 / num_edge_classes + num_edges = average_num_nodes * (average_num_nodes - 1) / 2 + + # First 100 steps: only a few updates per graph + updates_per_graph = 1.2 + beta_first = updates_per_graph / (p * num_edges) + + betas[betas < beta_first] = beta_first + return np.array(betas) + + +def check_mask_correct(variables, node_mask): + for i, variable in enumerate(variables): + if len(variable) > 0: + assert_correctly_masked(variable, node_mask) + + +def check_tensor_same_size(*args): + for i, arg in enumerate(args): + if i == 0: + continue + assert args[0].size() == arg.size() + + + +def reverse_tensor(x): + return x[torch.arange(x.size(0) - 1, -1, -1)] + + +def sample_discrete_features(probX, probE, node_mask, step=None, add_nose=True): + ''' Sample features from multinomial distribution with given probabilities (probX, probE, proby) + :param probX: bs, n, dx_out node features + :param probE: bs, n, n, de_out edge features + :param proby: bs, dy_out global features. + ''' + bs, n, _ = probX.shape + + # Noise X + # The masked rows should define probability distributions as well + probX[~node_mask] = 1 / probX.shape[-1] + + # Flatten the probability tensor to sample with multinomial + probX = probX.reshape(bs * n, -1) # (bs * n, dx_out) + + # Sample X + probX = probX + 1e-12 + probX = probX / probX.sum(dim=-1, keepdim=True) + X_t = probX.multinomial(1) # (bs * n, 1) + X_t = X_t.reshape(bs, n) # (bs, n) + + # Noise E + # The masked rows should define probability distributions as well + inverse_edge_mask = ~(node_mask.unsqueeze(1) * node_mask.unsqueeze(2)) + diag_mask = torch.eye(n).unsqueeze(0).expand(bs, -1, -1) + + probE[inverse_edge_mask] = 1 / probE.shape[-1] + probE[diag_mask.bool()] = 1 / probE.shape[-1] + probE = probE.reshape(bs * n * n, -1) # (bs * n * n, de_out) + probE = probE + 1e-12 + probE = probE / probE.sum(dim=-1, keepdim=True) + + # Sample E + E_t = probE.multinomial(1).reshape(bs, n, n) # (bs, n, n) + E_t = torch.triu(E_t, diagonal=1) + E_t = (E_t + torch.transpose(E_t, 1, 2)) + + return PlaceHolder(X=X_t, E=E_t, y=torch.zeros(bs, 0).type_as(X_t)) + + +def compute_batched_over0_posterior_distribution(X_t, Qt, Qsb, Qtb): + """ M: X or E + Compute xt @ Qt.T * x0 @ Qsb / x0 @ Qtb @ xt.T for each possible value of x0 + X_t: bs, n, dt or bs, n, n, dt + Qt: bs, d_t-1, dt + Qsb: bs, d0, d_t-1 + Qtb: bs, d0, dt. + """ + X_t = X_t.float() + Qt_T = Qt.transpose(-1, -2).float() # bs, N, dt + assert Qt.dim() == 3 + left_term = X_t @ Qt_T + left_term = left_term.unsqueeze(dim=2) # bs, N, 1, d_t-1 + right_term = Qsb.unsqueeze(1) + numerator = left_term * right_term # bs, N, d0, d_t-1 + + denominator = Qtb @ X_t.transpose(-1, -2) # bs, d0, N + denominator = denominator.transpose(-1, -2) # bs, N, d0 + denominator = denominator.unsqueeze(-1) # bs, N, d0, 1 + + denominator[denominator == 0] = 1. 
+ return numerator / denominator + + +def mask_distributions(true_X, true_E, pred_X, pred_E, node_mask): + # Add a small value everywhere to avoid nans + pred_X = pred_X.clamp_min(1e-18) + pred_X = pred_X / torch.sum(pred_X, dim=-1, keepdim=True) + + pred_E = pred_E.clamp_min(1e-18) + pred_E = pred_E / torch.sum(pred_E, dim=-1, keepdim=True) + + # Set masked rows to arbitrary distributions, so it doesn't contribute to loss + row_X = torch.ones(true_X.size(-1), dtype=true_X.dtype, device=true_X.device) + row_E = torch.zeros(true_E.size(-1), dtype=true_E.dtype, device=true_E.device).clamp_min(1e-18) + row_E[0] = 1. + + diag_mask = ~torch.eye(node_mask.size(1), device=node_mask.device, dtype=torch.bool).unsqueeze(0) + true_X[~node_mask] = row_X + true_E[~(node_mask.unsqueeze(1) * node_mask.unsqueeze(2) * diag_mask), :] = row_E + pred_X[~node_mask] = row_X + pred_E[~(node_mask.unsqueeze(1) * node_mask.unsqueeze(2) * diag_mask), :] = row_E + + return true_X, true_E, pred_X, pred_E + +def posterior_distributions(X, X_t, Qt, Qsb, Qtb, X_dim): + bs, n, d = X.shape + X = X.float() + Qt_X_T = torch.transpose(Qt.X, -2, -1).float() # (bs, d, d) + left_term = X_t @ Qt_X_T # (bs, N, d) + right_term = X @ Qsb.X # (bs, N, d) + + numerator = left_term * right_term # (bs, N, d) + denominator = X @ Qtb.X # (bs, N, d) @ (bs, d, d) = (bs, N, d) + denominator = denominator * X_t + + num_X = numerator[:, :, :X_dim] + num_E = numerator[:, :, X_dim:].reshape(bs, n*n, -1) + + deno_X = denominator[:, :, :X_dim] + deno_E = denominator[:, :, X_dim:].reshape(bs, n*n, -1) + + # denominator = (denominator * X_t).sum(dim=-1) # (bs, N, d) * (bs, N, d) + sum = (bs, N) + denominator = denominator.unsqueeze(-1) # (bs, N, 1) + + deno_X = deno_X.sum(dim=-1).unsqueeze(-1) + deno_E = deno_E.sum(dim=-1).unsqueeze(-1) + + deno_X[deno_X == 0.] = 1 + deno_E[deno_E == 0.] 
= 1 + prob_X = num_X / deno_X + prob_E = num_E / deno_E + + prob_E = prob_E / prob_E.sum(dim=-1, keepdim=True) + prob_X = prob_X / prob_X.sum(dim=-1, keepdim=True) + return PlaceHolder(X=prob_X, E=prob_E, y=None) + + +def sample_discrete_feature_noise(limit_dist, node_mask): + """ Sample from the limit distribution of the diffusion process""" + bs, n_max = node_mask.shape + x_limit = limit_dist.X[None, None, :].expand(bs, n_max, -1) + U_X = x_limit.flatten(end_dim=-2).multinomial(1).reshape(bs, n_max) + U_X = F.one_hot(U_X.long(), num_classes=x_limit.shape[-1]).float() + + e_limit = limit_dist.E[None, None, None, :].expand(bs, n_max, n_max, -1) + U_E = e_limit.flatten(end_dim=-2).multinomial(1).reshape(bs, n_max, n_max) + U_E = F.one_hot(U_E.long(), num_classes=e_limit.shape[-1]).float() + + U_X = U_X.to(node_mask.device) + U_E = U_E.to(node_mask.device) + + # Get upper triangular part of edge noise, without main diagonal + upper_triangular_mask = torch.zeros_like(U_E) + indices = torch.triu_indices(row=U_E.size(1), col=U_E.size(2), offset=1) + upper_triangular_mask[:, indices[0], indices[1], :] = 1 + + U_E = U_E * upper_triangular_mask + U_E = (U_E + torch.transpose(U_E, 1, 2)) + + assert (U_E == torch.transpose(U_E, 1, 2)).all() + return PlaceHolder(X=U_X, E=U_E, y=None).mask(node_mask) + +def index_QE(X, q_e, n_bond=5): + bs, n, n_atom = X.shape + node_indices = X.argmax(-1) # (bs, n) + + exp_ind1 = node_indices[ :, :, None, None, None].expand(bs, n, n_atom, n_bond, n_bond) + exp_ind2 = node_indices[ :, :, None, None, None].expand(bs, n, n, n_bond, n_bond) + + q_e = torch.gather(q_e, 1, exp_ind1) + q_e = torch.gather(q_e, 2, exp_ind2) # (bs, n, n, n_bond, n_bond) + + + node_mask = X.sum(-1) != 0 + no_edge = (~node_mask)[:, :, None] & (~node_mask)[:, None, :] + q_e[no_edge] = torch.tensor([1, 0, 0, 0, 0]).type_as(q_e) + + return q_e diff --git a/mcd/diffusion/distributions.py b/mcd/diffusion/distributions.py new file mode 100644 index 0000000..71d5625 --- /dev/null +++ b/mcd/diffusion/distributions.py @@ -0,0 +1,30 @@ +import torch + +class DistributionNodes: + def __init__(self, histogram): + """ Compute the distribution of the number of nodes in the dataset, and sample from this distribution. + historgram: dict. 
The keys are num_nodes, the values are counts + """ + + if type(histogram) == dict: + max_n_nodes = max(histogram.keys()) + prob = torch.zeros(max_n_nodes + 1) + for num_nodes, count in histogram.items(): + prob[num_nodes] = count + else: + prob = histogram + + self.prob = prob / prob.sum() + self.m = torch.distributions.Categorical(prob) + + def sample_n(self, n_samples, device): + idx = self.m.sample((n_samples,)) + return idx.to(device) + + def log_prob(self, batch_n_nodes): + assert len(batch_n_nodes.size()) == 1 + p = self.prob.to(batch_n_nodes.device) + + probas = p[batch_n_nodes] + log_p = torch.log(probas + 1e-30) + return log_p diff --git a/mcd/diffusion/noise_schedule.py b/mcd/diffusion/noise_schedule.py new file mode 100644 index 0000000..dd92ab3 --- /dev/null +++ b/mcd/diffusion/noise_schedule.py @@ -0,0 +1,159 @@ +import torch +import utils +from diffusion import diffusion_utils + +class PredefinedNoiseScheduleDiscrete(torch.nn.Module): + def __init__(self, noise_schedule, timesteps): + super(PredefinedNoiseScheduleDiscrete, self).__init__() + self.timesteps = timesteps + + if noise_schedule == 'cosine': + betas = diffusion_utils.cosine_beta_schedule_discrete(timesteps) + elif noise_schedule == 'custom': + betas = diffusion_utils.custom_beta_schedule_discrete(timesteps) + else: + raise NotImplementedError(noise_schedule) + + self.register_buffer('betas', torch.from_numpy(betas).float()) + + # 0.9999 + self.alphas = 1 - torch.clamp(self.betas, min=0, max=1) + + log_alpha = torch.log(self.alphas) + log_alpha_bar = torch.cumsum(log_alpha, dim=0) + self.alphas_bar = torch.exp(log_alpha_bar) + + def forward(self, t_normalized=None, t_int=None): + assert int(t_normalized is None) + int(t_int is None) == 1 + if t_int is None: + t_int = torch.round(t_normalized * self.timesteps) + return self.betas[t_int.long()] + + def get_alpha_bar(self, t_normalized=None, t_int=None): + assert int(t_normalized is None) + int(t_int is None) == 1 + if t_int is None: + t_int = torch.round(t_normalized * self.timesteps) + ### new + self.alphas_bar = self.alphas_bar.to(t_int.device) + return self.alphas_bar[t_int.long()] + + +class DiscreteUniformTransition: + def __init__(self, x_classes: int, e_classes: int, y_classes: int): + self.X_classes = x_classes + self.E_classes = e_classes + self.y_classes = y_classes + self.u_x = torch.ones(1, self.X_classes, self.X_classes) + if self.X_classes > 0: + self.u_x = self.u_x / self.X_classes + + self.u_e = torch.ones(1, self.E_classes, self.E_classes) + if self.E_classes > 0: + self.u_e = self.u_e / self.E_classes + + self.u_y = torch.ones(1, self.y_classes, self.y_classes) + if self.y_classes > 0: + self.u_y = self.u_y / self.y_classes + + def get_Qt(self, beta_t, device, X=None, flatten_e=None): + """ Returns one-step transition matrices for X and E, from step t - 1 to step t. + Qt = (1 - beta_t) * I + beta_t / K + + beta_t: (bs) noise level between 0 and 1 + returns: qx (bs, dx, dx), qe (bs, de, de), qy (bs, dy, dy). 
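+        With probability (1 - beta_t) the current class is kept, and with probability beta_t it is
+        resampled uniformly over the K classes, so every row of the returned matrices sums to 1.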
+ """ + beta_t = beta_t.unsqueeze(1) + beta_t = beta_t.to(device) + self.u_x = self.u_x.to(device) + self.u_e = self.u_e.to(device) + self.u_y = self.u_y.to(device) + + q_x = beta_t * self.u_x + (1 - beta_t) * torch.eye(self.X_classes, device=device).unsqueeze(0) + q_e = beta_t * self.u_e + (1 - beta_t) * torch.eye(self.E_classes, device=device).unsqueeze(0) + q_y = beta_t * self.u_y + (1 - beta_t) * torch.eye(self.y_classes, device=device).unsqueeze(0) + + return utils.PlaceHolder(X=q_x, E=q_e, y=q_y) + + def get_Qt_bar(self, alpha_bar_t, device, X=None, flatten_e=None): + """ Returns t-step transition matrices for X and E, from step 0 to step t. + Qt = prod(1 - beta_t) * I + (1 - prod(1 - beta_t)) / K + + alpha_bar_t: (bs) Product of the (1 - beta_t) for each time step from 0 to t. + returns: qx (bs, dx, dx), qe (bs, de, de), qy (bs, dy, dy). + """ + alpha_bar_t = alpha_bar_t.unsqueeze(1) + alpha_bar_t = alpha_bar_t.to(device) + self.u_x = self.u_x.to(device) + self.u_e = self.u_e.to(device) + self.u_y = self.u_y.to(device) + + q_x = alpha_bar_t * torch.eye(self.X_classes, device=device).unsqueeze(0) + (1 - alpha_bar_t) * self.u_x + q_e = alpha_bar_t * torch.eye(self.E_classes, device=device).unsqueeze(0) + (1 - alpha_bar_t) * self.u_e + q_y = alpha_bar_t * torch.eye(self.y_classes, device=device).unsqueeze(0) + (1 - alpha_bar_t) * self.u_y + + return utils.PlaceHolder(X=q_x, E=q_e, y=q_y) + + +class MarginalTransition: + def __init__(self, x_marginals, e_marginals, xe_conditions, ex_conditions, y_classes, n_nodes): + self.X_classes = len(x_marginals) + self.E_classes = len(e_marginals) + self.y_classes = y_classes + self.x_marginals = x_marginals # Dx + self.e_marginals = e_marginals # Dx, De + self.xe_conditions = xe_conditions + + self.u_x = x_marginals.unsqueeze(0).expand(self.X_classes, -1).unsqueeze(0) # 1, Dx, Dx + self.u_e = e_marginals.unsqueeze(0).expand(self.E_classes, -1).unsqueeze(0) # 1, De, De + self.u_xe = xe_conditions.unsqueeze(0) # 1, Dx, De + self.u_ex = ex_conditions.unsqueeze(0) # 1, De, Dx + self.u = self.get_union_transition(self.u_x, self.u_e, self.u_xe, self.u_ex, n_nodes) # 1, Dx + n*De, Dx + n*De + + def get_union_transition(self, u_x, u_e, u_xe, u_ex, n_nodes): + u_e = u_e.repeat(1, n_nodes, n_nodes) # (1, n*de, n*de) + u_xe = u_xe.repeat(1, 1, n_nodes) # (1, dx, n*de) + u_ex = u_ex.repeat(1, n_nodes, 1) # (1, n*de, dx) + u0 = torch.cat([u_x, u_xe], dim=2) # (1, dx, dx + n*de) + u1 = torch.cat([u_ex, u_e], dim=2) # (1, n*de, dx + n*de) + u = torch.cat([u0, u1], dim=1) # (1, dx + n*de, dx + n*de) + return u + + def index_edge_margin(self, X, q_e, n_bond=5): + # q_e: (bs, dx, de) --> (bs, n, de) + bs, n, n_atom = X.shape + node_indices = X.argmax(-1) # (bs, n) + ind = node_indices[ :, :, None].expand(bs, n, n_bond) + q_e = torch.gather(q_e, 1, ind) + return q_e + + def get_Qt(self, beta_t, device): + """ Returns one-step transition matrices for X and E, from step t - 1 to step t. + Qt = (1 - beta_t) * I + beta_t / K + beta_t: (bs) + returns: q (bs, d0, d0) + """ + bs = beta_t.size(0) + d0 = self.u.size(-1) + self.u = self.u.to(device) + u = self.u.expand(bs, d0, d0) + + beta_t = beta_t.to(device) + beta_t = beta_t.view(bs, 1, 1) + q = beta_t * u + (1 - beta_t) * torch.eye(d0, device=device).unsqueeze(0) + + return utils.PlaceHolder(X=q, E=None, y=None) + + def get_Qt_bar(self, alpha_bar_t, device): + """ Returns t-step transition matrices for X and E, from step 0 to step t. 
+ Qt = prod(1 - beta_t) * I + (1 - prod(1 - beta_t)) * K + alpha_bar_t: (bs, 1) roduct of the (1 - beta_t) for each time step from 0 to t. + returns: q (bs, d0, d0) + """ + bs = alpha_bar_t.size(0) + d0 = self.u.size(-1) + alpha_bar_t = alpha_bar_t.to(device) + alpha_bar_t = alpha_bar_t.view(bs, 1, 1) + self.u = self.u.to(device) + q = alpha_bar_t * torch.eye(d0, device=device).unsqueeze(0) + (1 - alpha_bar_t) * self.u + + return utils.PlaceHolder(X=q, E=None, y=None) \ No newline at end of file diff --git a/mcd/diffusion_model.py b/mcd/diffusion_model.py new file mode 100644 index 0000000..90d4e8b --- /dev/null +++ b/mcd/diffusion_model.py @@ -0,0 +1,619 @@ +import torch +import torch.nn.functional as F +import pytorch_lightning as pl +import time +import os + +from models.transformer import Denoiser +from diffusion.noise_schedule import PredefinedNoiseScheduleDiscrete, MarginalTransition + +from diffusion import diffusion_utils +from metrics.train_loss import TrainLossDiscrete +from metrics.abstract_metrics import SumExceptBatchMetric, SumExceptBatchKL, NLL +import utils + +class MCD(pl.LightningModule): + def __init__(self, cfg, dataset_infos, train_metrics, sampling_metrics, visualization_tools): + super().__init__() + self.save_hyperparameters(ignore=['train_metrics', 'sampling_metrics']) + self.test_only = cfg.general.test_only + self.guidance_target = getattr(cfg.dataset, 'guidance_target', None) + + input_dims = dataset_infos.input_dims + output_dims = dataset_infos.output_dims + nodes_dist = dataset_infos.nodes_dist + active_index = dataset_infos.active_index + + self.cfg = cfg + self.name = cfg.general.name + self.T = cfg.model.diffusion_steps + self.guide_scale = cfg.model.guide_scale + + self.Xdim = input_dims['X'] + self.Edim = input_dims['E'] + self.ydim = input_dims['y'] + self.Xdim_output = output_dims['X'] + self.Edim_output = output_dims['E'] + self.ydim_output = output_dims['y'] + self.node_dist = nodes_dist + self.active_index = active_index + self.dataset_info = dataset_infos + + self.train_loss = TrainLossDiscrete(self.cfg.model.lambda_train) + + self.val_nll = NLL() + self.val_X_kl = SumExceptBatchKL() + self.val_E_kl = SumExceptBatchKL() + self.val_X_logp = SumExceptBatchMetric() + self.val_E_logp = SumExceptBatchMetric() + self.val_y_collection = [] + + self.test_nll = NLL() + self.test_X_kl = SumExceptBatchKL() + self.test_E_kl = SumExceptBatchKL() + self.test_X_logp = SumExceptBatchMetric() + self.test_E_logp = SumExceptBatchMetric() + self.test_y_collection = [] + + self.train_metrics = train_metrics + self.sampling_metrics = sampling_metrics + + self.visualization_tools = visualization_tools + self.max_n_nodes = dataset_infos.max_n_nodes + + self.model = Denoiser(max_n_nodes=self.max_n_nodes, + hidden_size=cfg.model.hidden_size, + depth=cfg.model.depth, + num_heads=cfg.model.num_heads, + mlp_ratio=cfg.model.mlp_ratio, + drop_condition=cfg.model.drop_condition, + Xdim=self.Xdim, + Edim=self.Edim, + ydim=self.ydim, + task_type=dataset_infos.task_type) + + self.noise_schedule = PredefinedNoiseScheduleDiscrete(cfg.model.diffusion_noise_schedule, + timesteps=cfg.model.diffusion_steps) + + + x_marginals = self.dataset_info.node_types.float() / torch.sum(self.dataset_info.node_types.float()) + + e_marginals = self.dataset_info.edge_types.float() / torch.sum(self.dataset_info.edge_types.float()) + x_marginals = x_marginals / (x_marginals ).sum() + e_marginals = e_marginals / (e_marginals ).sum() + + xe_conditions = self.dataset_info.transition_E.float() + xe_conditions 
= xe_conditions[self.active_index][:, self.active_index] + + xe_conditions = xe_conditions.sum(dim=1) + ex_conditions = xe_conditions.t() + xe_conditions = xe_conditions / xe_conditions.sum(dim=-1, keepdim=True) + ex_conditions = ex_conditions / ex_conditions.sum(dim=-1, keepdim=True) + + self.transition_model = MarginalTransition(x_marginals=x_marginals, + e_marginals=e_marginals, + xe_conditions=xe_conditions, + ex_conditions=ex_conditions, + y_classes=self.ydim_output, + n_nodes=self.max_n_nodes) + + self.limit_dist = utils.PlaceHolder(X=x_marginals, E=e_marginals, y=None) + + self.start_epoch_time = None + self.train_iterations = None + self.val_iterations = None + self.log_every_steps = cfg.general.log_every_steps + self.number_chain_steps = cfg.general.number_chain_steps + + self.best_val_nll = 1e8 + self.val_counter = 0 + self.batch_size = self.cfg.train.batch_size + + + def forward(self, noisy_data, unconditioned=False): + x, e, y = noisy_data['X_t'].float(), noisy_data['E_t'].float(), noisy_data['y_t'].float().clone() + node_mask, t = noisy_data['node_mask'], noisy_data['t'] + pred = self.model(x, e, node_mask, y=y, t=t, unconditioned=unconditioned) + return pred + + def training_step(self, data, i): + data_x = F.one_hot(data.x, num_classes=118).float()[:, self.active_index] + data_edge_attr = F.one_hot(data.edge_attr, num_classes=5).float() + + dense_data, node_mask = utils.to_dense(data_x, data.edge_index, data_edge_attr, data.batch, self.max_n_nodes) + dense_data = dense_data.mask(node_mask) + X, E = dense_data.X, dense_data.E + noisy_data = self.apply_noise(X, E, data.y, node_mask) + pred = self.forward(noisy_data) + loss = self.train_loss(masked_pred_X=pred.X, masked_pred_E=pred.E, pred_y=pred.y, + true_X=X, true_E=E, true_y=data.y, node_mask=node_mask, + log=i % self.log_every_steps == 0) + + self.train_metrics(masked_pred_X=pred.X, masked_pred_E=pred.E, true_X=X, true_E=E, + log=i % self.log_every_steps == 0) + self.log(f'loss', loss, batch_size=X.size(0), sync_dist=True) + return {'loss': loss} + + + def configure_optimizers(self): + params = self.parameters() + optimizer = torch.optim.AdamW(params, lr=self.cfg.train.lr, amsgrad=True, + weight_decay=self.cfg.train.weight_decay) + return optimizer + + def on_fit_start(self) -> None: + self.train_iterations = self.trainer.datamodule.training_iterations + print('on fit train iteration:', self.train_iterations) + print("Size of the input features Xdim {}, Edim {}, ydim {}".format(self.Xdim, self.Edim, self.ydim)) + + def on_train_epoch_start(self) -> None: + if self.current_epoch / self.trainer.max_epochs in [0.25, 0.5, 0.75, 1.0]: + print("Starting train epoch {}/{}...".format(self.current_epoch, self.trainer.max_epochs)) + self.start_epoch_time = time.time() + self.train_loss.reset() + self.train_metrics.reset() + + def on_train_epoch_end(self) -> None: + if self.current_epoch / self.trainer.max_epochs in [0.25, 0.5, 0.75, 1.0]: + log = True + else: + log = False + self.train_loss.log_epoch_metrics(self.current_epoch, self.start_epoch_time, log) + self.train_metrics.log_epoch_metrics(self.current_epoch, log) + + def on_validation_epoch_start(self) -> None: + self.val_nll.reset() + self.val_X_kl.reset() + self.val_E_kl.reset() + self.val_X_logp.reset() + self.val_E_logp.reset() + self.sampling_metrics.reset() + self.val_y_collection = [] + + @torch.no_grad() + def validation_step(self, data, i): + data_x = F.one_hot(data.x, num_classes=118).float()[:, self.active_index] + data_edge_attr = F.one_hot(data.edge_attr, 
num_classes=5).float() + + dense_data, node_mask = utils.to_dense(data_x, data.edge_index, data_edge_attr, data.batch, self.max_n_nodes) + dense_data = dense_data.mask(node_mask) + noisy_data = self.apply_noise(dense_data.X, dense_data.E, data.y, node_mask) + pred = self.forward(noisy_data) + nll = self.compute_val_loss(pred, noisy_data, dense_data.X, dense_data.E, data.y, node_mask, test=False) + self.val_y_collection.append(data.y) + self.log(f'valid_nll', nll, batch_size=data.x.size(0), sync_dist=True) + return {'loss': nll} + + def on_validation_epoch_end(self) -> None: + metrics = [self.val_nll.compute(), self.val_X_kl.compute() * self.T, self.val_E_kl.compute() * self.T, + self.val_X_logp.compute(), self.val_E_logp.compute()] + + if self.current_epoch / self.trainer.max_epochs in [0.25, 0.5, 0.75, 1.0]: + print(f"Epoch {self.current_epoch}: Val NLL {metrics[0] :.2f} -- Val Atom type KL {metrics[1] :.2f} -- ", + f"Val Edge type KL: {metrics[2] :.2f}", 'Val loss: %.2f \t Best : %.2f\n' % (metrics[0], self.best_val_nll)) + + # Log val nll with default Lightning logger, so it can be monitored by checkpoint callback + self.log("val/NLL", metrics[0], sync_dist=True) + + if metrics[0] < self.best_val_nll: + self.best_val_nll = metrics[0] + + self.val_counter += 1 + + if self.val_counter % self.cfg.general.sample_every_val == 0 and self.val_counter > 1: + start = time.time() + samples_left_to_generate = self.cfg.general.samples_to_generate + samples_left_to_save = self.cfg.general.samples_to_save + chains_left_to_save = self.cfg.general.chains_to_save + + samples, all_ys, ident = [], [], 0 + + self.val_y_collection = torch.cat(self.val_y_collection, dim=0) + num_examples = self.val_y_collection.size(0) + start_index = 0 + while samples_left_to_generate > 0: + bs = 1 * self.cfg.train.batch_size + to_generate = min(samples_left_to_generate, bs) + to_save = min(samples_left_to_save, bs) + chains_save = min(chains_left_to_save, bs) + + if start_index + to_generate > num_examples: + start_index = 0 + if to_generate > num_examples: + ratio = to_generate // num_examples + self.val_y_collection = self.val_y_collection.repeat(ratio+1, 1) + num_examples = self.val_y_collection.size(0) + batch_y = self.val_y_collection[start_index:start_index + to_generate] + all_ys.append(batch_y) + samples.extend(self.sample_batch(batch_id=ident, batch_size=to_generate, y=batch_y, + save_final=to_save, + keep_chain=chains_save, + number_chain_steps=self.number_chain_steps)) + ident += to_generate + start_index += to_generate + + samples_left_to_save -= to_save + samples_left_to_generate -= to_generate + chains_left_to_save -= chains_save + + print(f"Computing sampling metrics", ' ...') + valid_smiles = self.sampling_metrics(samples, all_ys, self.name, self.current_epoch, val_counter=-1, test=False) + print(f'Done. 
Sampling took {time.time() - start:.2f} seconds\n') + current_path = os.getcwd() + result_path = os.path.join(current_path, + f'graphs/{self.name}/epoch{self.current_epoch}_b0/') + self.visualization_tools.visualize_by_smiles(result_path, valid_smiles, self.cfg.general.samples_to_save) + self.sampling_metrics.reset() + + def on_test_epoch_start(self) -> None: + print("Starting test...") + self.test_nll.reset() + self.test_X_kl.reset() + self.test_E_kl.reset() + self.test_X_logp.reset() + self.test_E_logp.reset() + self.test_y_collection = [] + + @torch.no_grad() + def test_step(self, data, i): + data_x = F.one_hot(data.x, num_classes=118).float()[:, self.active_index] + data_edge_attr = F.one_hot(data.edge_attr, num_classes=5).float() + + dense_data, node_mask = utils.to_dense(data_x, data.edge_index, data_edge_attr, data.batch, self.max_n_nodes) + dense_data = dense_data.mask(node_mask) + noisy_data = self.apply_noise(dense_data.X, dense_data.E, data.y, node_mask) + pred = self.forward(noisy_data) + nll = self.compute_val_loss(pred, noisy_data, dense_data.X, dense_data.E, data.y, node_mask, test=True) + self.test_y_collection.append(data.y) + return {'loss': nll} + + def on_test_epoch_end(self) -> None: + """ Measure likelihood on a test set and compute stability metrics. """ + metrics = [self.test_nll.compute(), self.test_X_kl.compute(), self.test_E_kl.compute(), + self.test_X_logp.compute(), self.test_E_logp.compute()] + + print(f"Epoch {self.current_epoch}: Test NLL {metrics[0] :.2f} -- Test Atom type KL {metrics[1] :.2f} -- ", + f"Test Edge type KL: {metrics[2] :.2f}") + + ## final epcoh + samples_left_to_generate = self.cfg.general.final_model_samples_to_generate + samples_left_to_save = self.cfg.general.final_model_samples_to_save + chains_left_to_save = self.cfg.general.final_model_chains_to_save + + samples, all_ys, batch_id = [], [], 0 + + test_y_collection = torch.cat(self.test_y_collection, dim=0) + num_examples = test_y_collection.size(0) + if self.cfg.general.final_model_samples_to_generate > num_examples: + ratio = self.cfg.general.final_model_samples_to_generate // num_examples + test_y_collection = test_y_collection.repeat(ratio+1, 1) + num_examples = test_y_collection.size(0) + + while samples_left_to_generate > 0: + print(f'samples left to generate: {samples_left_to_generate}/' + f'{self.cfg.general.final_model_samples_to_generate}', end='', flush=True) + bs = 1 * self.cfg.train.batch_size + to_generate = min(samples_left_to_generate, bs) + to_save = min(samples_left_to_save, bs) + chains_save = min(chains_left_to_save, bs) + batch_y = test_y_collection[batch_id : batch_id + to_generate] + + cur_sample = self.sample_batch(batch_id, to_generate, batch_y, save_final=to_save, + keep_chain=chains_save, number_chain_steps=self.number_chain_steps) + samples = samples + cur_sample + + all_ys.append(batch_y) + batch_id += to_generate + + samples_left_to_save -= to_save + samples_left_to_generate -= to_generate + chains_left_to_save -= chains_save + + print(f"final Computing sampling metrics...") + self.sampling_metrics.reset() + self.sampling_metrics(samples, all_ys, self.name, self.current_epoch, self.val_counter, test=True) + self.sampling_metrics.reset() + print(f"Done.") + + + def kl_prior(self, X, E, node_mask): + """Computes the KL between q(z1 | x) and the prior p(z1) = Normal(0, 1). + + This is essentially a lot of work for something that is in practice negligible in the loss. However, you + compute it so that you see it when you've made a mistake in your noise schedule. 
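# Minimal sketch, not part of the committed code: the prior term compares the
# t = T corruption of the data with the fixed limit distribution. In F.kl_div,
# `input` is log-probabilities and `target` is probabilities, matching the calls
# in this method. Toy numbers below are assumptions.
import torch
import torch.nn.functional as F

probX = torch.tensor([[0.7, 0.2, 0.1]])     # q(z_T | x) for a single node
limit = torch.tensor([[0.5, 0.3, 0.2]])     # limit (marginal) distribution
kl = F.kl_div(input=probX.log(), target=limit, reduction='none').sum()
# kl is a small positive number and vanishes as probX approaches the limit.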
+ """ + # Compute the last alpha value, alpha_T. + ones = torch.ones((X.size(0), 1), device=X.device) + Ts = self.T * ones + alpha_t_bar = self.noise_schedule.get_alpha_bar(t_int=Ts) # (bs, 1) + + Qtb = self.transition_model.get_Qt_bar(alpha_t_bar, self.device) + + bs, n, d = X.shape + X_all = torch.cat([X, E.reshape(bs, n, -1)], dim=-1) + prob_all = X_all @ Qtb.X + probX = prob_all[:, :, :self.Xdim_output] + probE = prob_all[:, :, self.Xdim_output:].reshape((bs, n, n, -1)) + + assert probX.shape == X.shape + + limit_X = self.limit_dist.X[None, None, :].expand(bs, n, -1).type_as(probX) + limit_E = self.limit_dist.E[None, None, None, :].expand(bs, n, n, -1).type_as(probE) + + # Make sure that masked rows do not contribute to the loss + limit_dist_X, limit_dist_E, probX, probE = diffusion_utils.mask_distributions(true_X=limit_X.clone(), + true_E=limit_E.clone(), + pred_X=probX, + pred_E=probE, + node_mask=node_mask) + + kl_distance_X = F.kl_div(input=probX.log(), target=limit_dist_X, reduction='none') + kl_distance_E = F.kl_div(input=probE.log(), target=limit_dist_E, reduction='none') + + return diffusion_utils.sum_except_batch(kl_distance_X) + \ + diffusion_utils.sum_except_batch(kl_distance_E) + + def compute_Lt(self, X, E, y, pred, noisy_data, node_mask, test): + pred_probs_X = F.softmax(pred.X, dim=-1) + pred_probs_E = F.softmax(pred.E, dim=-1) + + Qtb = self.transition_model.get_Qt_bar(noisy_data['alpha_t_bar'], self.device) + Qsb = self.transition_model.get_Qt_bar(noisy_data['alpha_s_bar'], self.device) + Qt = self.transition_model.get_Qt(noisy_data['beta_t'], self.device) + + # Compute distributions to compare with KL + bs, n, d = X.shape + X_all = torch.cat([X, E.reshape(bs, n, -1)], dim=-1).float() + Xt_all = torch.cat([noisy_data['X_t'], noisy_data['E_t'].reshape(bs, n, -1)], dim=-1).float() + pred_probs_all = torch.cat([pred_probs_X, pred_probs_E.reshape(bs, n, -1)], dim=-1).float() + + prob_true = diffusion_utils.posterior_distributions(X=X_all, X_t=Xt_all, Qt=Qt, Qsb=Qsb, Qtb=Qtb, X_dim=self.Xdim_output) + prob_true.E = prob_true.E.reshape((bs, n, n, -1)) + prob_pred = diffusion_utils.posterior_distributions(X=pred_probs_all, X_t=Xt_all, Qt=Qt, Qsb=Qsb, Qtb=Qtb, X_dim=self.Xdim_output) + prob_pred.E = prob_pred.E.reshape((bs, n, n, -1)) + + # Reshape and filter masked rows + prob_true_X, prob_true_E, prob_pred.X, prob_pred.E = diffusion_utils.mask_distributions(true_X=prob_true.X, + true_E=prob_true.E, + pred_X=prob_pred.X, + pred_E=prob_pred.E, + node_mask=node_mask) + kl_x = (self.test_X_kl if test else self.val_X_kl)(prob_true.X, torch.log(prob_pred.X)) + kl_e = (self.test_E_kl if test else self.val_E_kl)(prob_true.E, torch.log(prob_pred.E)) + + return self.T * (kl_x + kl_e) + + def reconstruction_logp(self, t, X, E, y, node_mask): + # Compute noise values for t = 0. 
+ t_zeros = torch.zeros_like(t) + beta_0 = self.noise_schedule(t_zeros) + Q0 = self.transition_model.get_Qt(beta_0, self.device) + + bs, n, d = X.shape + X_all = torch.cat([X, E.reshape(bs, n, -1)], dim=-1) + prob_all = X_all @ Q0.X + probX0 = prob_all[:, :, :self.Xdim_output] + probE0 = prob_all[:, :, self.Xdim_output:].reshape((bs, n, n, -1)) + + sampled0 = diffusion_utils.sample_discrete_features(probX=probX0, probE=probE0, node_mask=node_mask) + + X0 = F.one_hot(sampled0.X, num_classes=self.Xdim_output).float() + E0 = F.one_hot(sampled0.E, num_classes=self.Edim_output).float() + + assert (X.shape == X0.shape) and (E.shape == E0.shape) + sampled_0 = utils.PlaceHolder(X=X0, E=E0, y=y).mask(node_mask) + + # Predictions + noisy_data = {'X_t': sampled_0.X, 'E_t': sampled_0.E, 'y_t': sampled_0.y, 'node_mask': node_mask, + 't': torch.zeros(X0.shape[0], 1).type_as(y)} + pred0 = self.forward(noisy_data) + + # Normalize predictions + probX0 = F.softmax(pred0.X, dim=-1) + probE0 = F.softmax(pred0.E, dim=-1) + proby0 = None + + # Set masked rows to arbitrary values that don't contribute to loss + probX0[~node_mask] = torch.ones(self.Xdim_output).type_as(probX0) + probE0[~(node_mask.unsqueeze(1) * node_mask.unsqueeze(2))] = torch.ones(self.Edim_output).type_as(probE0) + + diag_mask = torch.eye(probE0.size(1)).type_as(probE0).bool() + diag_mask = diag_mask.unsqueeze(0).expand(probE0.size(0), -1, -1) + probE0[diag_mask] = torch.ones(self.Edim_output).type_as(probE0) + + return utils.PlaceHolder(X=probX0, E=probE0, y=proby0) + + def apply_noise(self, X, E, y, node_mask): + """ Sample noise and apply it to the data. """ + + # Sample a timestep t. + # When evaluating, the loss for t=0 is computed separately + lowest_t = 0 if self.training else 1 + t_int = torch.randint(lowest_t, self.T + 1, size=(X.size(0), 1), device=X.device).float() # (bs, 1) + s_int = t_int - 1 + + t_float = t_int / self.T + s_float = s_int / self.T + + # beta_t and alpha_s_bar are used for denoising/loss computation + beta_t = self.noise_schedule(t_normalized=t_float) # (bs, 1) + alpha_s_bar = self.noise_schedule.get_alpha_bar(t_normalized=s_float) # (bs, 1) + alpha_t_bar = self.noise_schedule.get_alpha_bar(t_normalized=t_float) # (bs, 1) + + Qtb = self.transition_model.get_Qt_bar(alpha_t_bar, self.device) # (bs, dx_in, dx_out), (bs, de_in, de_out) + + bs, n, d = X.shape + X_all = torch.cat([X, E.reshape(bs, n, -1)], dim=-1) + prob_all = X_all @ Qtb.X + probX = prob_all[:, :, :self.Xdim_output] + probE = prob_all[:, :, self.Xdim_output:].reshape(bs, n, n, -1) + + sampled_t = diffusion_utils.sample_discrete_features(probX=probX, probE=probE, node_mask=node_mask) + + X_t = F.one_hot(sampled_t.X, num_classes=self.Xdim_output) + E_t = F.one_hot(sampled_t.E, num_classes=self.Edim_output) + assert (X.shape == X_t.shape) and (E.shape == E_t.shape) + + y_t = y + z_t = utils.PlaceHolder(X=X_t, E=E_t, y=y_t).type_as(X_t).mask(node_mask) + + noisy_data = {'t_int': t_int, 't': t_float, 'beta_t': beta_t, 'alpha_s_bar': alpha_s_bar, + 'alpha_t_bar': alpha_t_bar, 'X_t': z_t.X, 'E_t': z_t.E, 'y_t': z_t.y, 'node_mask': node_mask} + return noisy_data + + def compute_val_loss(self, pred, noisy_data, X, E, y, node_mask, test=False): + """Computes an estimator for the variational lower bound. + pred: (batch_size, n, total_features) + noisy_data: dict + X, E, y : (bs, n, dx), (bs, n, n, de), (bs, dy) + node_mask : (bs, n) + Output: nll (size 1) + """ + t = noisy_data['t'] + + # 1. 
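# Log-likelihood of the number of nodes N under the empirical node-count
# distribution (log_pN just below); the remaining numbered terms are the prior
# KL, the diffusion loss, and the reconstruction loss.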
+ N = node_mask.sum(1).long() + log_pN = self.node_dist.log_prob(N) + + # 2. The KL between q(z_T | x) and p(z_T) = Uniform(1/num_classes). Should be close to zero. + kl_prior = self.kl_prior(X, E, node_mask) + + # 3. Diffusion loss + loss_all_t = self.compute_Lt(X, E, y, pred, noisy_data, node_mask, test) + + # 4. Reconstruction loss + # Compute L0 term : -log p (X, E, y | z_0) = reconstruction loss + prob0 = self.reconstruction_logp(t, X, E, y, node_mask) + + eps = 1e-8 + loss_term_0 = self.val_X_logp(X * (prob0.X+eps).log()) + self.val_E_logp(E * (prob0.E+eps).log()) + + # Combine terms + nlls = - log_pN + kl_prior + loss_all_t - loss_term_0 + assert len(nlls.shape) == 1, f'{nlls.shape} has more than only batch dim.' + + # Update NLL metric object and return batch nll + nll = (self.test_nll if test else self.val_nll)(nlls) # Average over the batch + + return nll + + @torch.no_grad() + def sample_batch(self, batch_id, batch_size, y, keep_chain, number_chain_steps, save_final, num_nodes=None): + """ + :param batch_id: int + :param batch_size: int + :param num_nodes: int, tensor (batch_size) (optional) for specifying number of nodes + :param save_final: int: number of predictions to save to file + :param keep_chain: int: number of chains to save to file (disabled) + :param keep_chain_steps: number of timesteps to save for each chain (disabled) + :return: molecule_list. Each element of this list is a tuple (atom_types, charges, positions) + """ + if num_nodes is None: + n_nodes = self.node_dist.sample_n(batch_size, self.device) + elif type(num_nodes) == int: + n_nodes = num_nodes * torch.ones(batch_size, device=self.device, dtype=torch.int) + else: + assert isinstance(num_nodes, torch.Tensor) + n_nodes = num_nodes + n_max = self.max_n_nodes + arange = torch.arange(n_max, device=self.device).unsqueeze(0).expand(batch_size, -1) + node_mask = arange < n_nodes.unsqueeze(1) + + z_T = diffusion_utils.sample_discrete_feature_noise(limit_dist=self.limit_dist, node_mask=node_mask) + X, E = z_T.X, z_T.E + + assert (E == torch.transpose(E, 1, 2)).all() + + # Iteratively sample p(z_s | z_t) for t = 1, ..., T, with s = t - 1. + for s_int in reversed(range(0, self.T)): + s_array = s_int * torch.ones((batch_size, 1)).type_as(y) + t_array = s_array + 1 + s_norm = s_array / self.T + t_norm = t_array / self.T + + # Sample z_s + sampled_s, discrete_sampled_s = self.sample_p_zs_given_zt(s_norm, t_norm, X, E, y, node_mask) + X, E, y = sampled_s.X, sampled_s.E, sampled_s.y + + # Sample + sampled_s = sampled_s.mask(node_mask, collapse=True) + X, E, y = sampled_s.X, sampled_s.E, sampled_s.y + + molecule_list = [] + for i in range(batch_size): + n = n_nodes[i] + atom_types = X[i, :n].cpu() + edge_types = E[i, :n, :n].cpu() + molecule_list.append([atom_types, edge_types]) + + return molecule_list + + def sample_p_zs_given_zt(self, s, t, X_t, E_t, y_t, node_mask): + """Samples from zs ~ p(zs | zt). Only used during sampling. 
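# Illustrative sketch, not part of the committed code: later in this method the
# conditional and unconditional predictions are combined with classifier-free
# guidance, p_guided proportional to p_uncond * (p_cond / p_uncond) ** guide_scale,
# and then renormalized. Toy values below are assumptions.
import torch

p_cond = torch.tensor([0.6, 0.3, 0.1])
p_uncond = torch.tensor([0.4, 0.4, 0.2])
guide_scale = 2.0                                  # cfg.model.guide_scale
p = p_uncond * (p_cond / p_uncond.clamp_min(1e-10)) ** guide_scale
p = p / p.sum()                                    # sharper than p_cond toward class 0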
+ if last_step, return the graph prediction as well""" + bs, n, dxs = X_t.shape + beta_t = self.noise_schedule(t_normalized=t) # (bs, 1) + alpha_s_bar = self.noise_schedule.get_alpha_bar(t_normalized=s) + alpha_t_bar = self.noise_schedule.get_alpha_bar(t_normalized=t) + + # Neural net predictions + noisy_data = {'X_t': X_t, 'E_t': E_t, 'y_t': y_t, 't': t, 'node_mask': node_mask} + + def get_prob(noisy_data, unconditioned=False): + pred = self.forward(noisy_data, unconditioned=unconditioned) + + # Normalize predictions + pred_X = F.softmax(pred.X, dim=-1) # bs, n, d0 + pred_E = F.softmax(pred.E, dim=-1) # bs, n, n, d0 + + # Retrieve transitions matrix + Qtb = self.transition_model.get_Qt_bar(alpha_t_bar, self.device) + Qsb = self.transition_model.get_Qt_bar(alpha_s_bar, self.device) + Qt = self.transition_model.get_Qt(beta_t, self.device) + + Xt_all = torch.cat([X_t, E_t.reshape(bs, n, -1)], dim=-1) + p_s_and_t_given_0 = diffusion_utils.compute_batched_over0_posterior_distribution(X_t=Xt_all, + Qt=Qt.X, + Qsb=Qsb.X, + Qtb=Qtb.X) + predX_all = torch.cat([pred_X, pred_E.reshape(bs, n, -1)], dim=-1) + weightedX_all = predX_all.unsqueeze(-1) * p_s_and_t_given_0 + unnormalized_probX_all = weightedX_all.sum(dim=2) # bs, n, d_t-1 + + unnormalized_prob_X = unnormalized_probX_all[:, :, :self.Xdim_output] + unnormalized_prob_E = unnormalized_probX_all[:, :, self.Xdim_output:].reshape(bs, n*n, -1) + + unnormalized_prob_X[torch.sum(unnormalized_prob_X, dim=-1) == 0] = 1e-5 + unnormalized_prob_E[torch.sum(unnormalized_prob_E, dim=-1) == 0] = 1e-5 + + prob_X = unnormalized_prob_X / torch.sum(unnormalized_prob_X, dim=-1, keepdim=True) # bs, n, d_t-1 + prob_E = unnormalized_prob_E / torch.sum(unnormalized_prob_E, dim=-1, keepdim=True) # bs, n, d_t-1 + prob_E = prob_E.reshape(bs, n, n, pred_E.shape[-1]) + + return prob_X, prob_E + + prob_X, prob_E = get_prob(noisy_data) + + ### Guidance + if self.guidance_target is not None and self.guide_scale is not None and self.guide_scale != 1: + uncon_prob_X, uncon_prob_E = get_prob(noisy_data, unconditioned=True) + prob_X = uncon_prob_X * (prob_X / uncon_prob_X.clamp_min(1e-10)) ** self.guide_scale + prob_E = uncon_prob_E * (prob_E / uncon_prob_E.clamp_min(1e-10)) ** self.guide_scale + prob_X = prob_X / prob_X.sum(dim=-1, keepdim=True).clamp_min(1e-10) + prob_E = prob_E / prob_E.sum(dim=-1, keepdim=True).clamp_min(1e-10) + + assert ((prob_X.sum(dim=-1) - 1).abs() < 1e-4).all() + assert ((prob_E.sum(dim=-1) - 1).abs() < 1e-4).all() + + sampled_s = diffusion_utils.sample_discrete_features(prob_X, prob_E, node_mask=node_mask, step=s[0,0].item()) + + X_s = F.one_hot(sampled_s.X, num_classes=self.Xdim_output).float() + E_s = F.one_hot(sampled_s.E, num_classes=self.Edim_output).float() + + assert (E_s == torch.transpose(E_s, 1, 2)).all() + assert (X_t.shape == X_s.shape) and (E_t.shape == E_s.shape) + + out_one_hot = utils.PlaceHolder(X=X_s, E=E_s, y=y_t) + out_discrete = utils.PlaceHolder(X=X_s, E=E_s, y=y_t) + + return out_one_hot.mask(node_mask).type_as(y_t), out_discrete.mask(node_mask, collapse=True).type_as(y_t) diff --git a/mcd/main.py b/mcd/main.py new file mode 100644 index 0000000..433377d --- /dev/null +++ b/mcd/main.py @@ -0,0 +1,138 @@ +# These imports are tricky because they use c++, do not move them +import os, shutil +import warnings + +import torch +import hydra +from omegaconf import DictConfig +from pytorch_lightning import Trainer + +import utils +from datasets import dataset +from diffusion_model import MCD +from metrics.molecular_metrics_train import 
TrainMolecularMetricsDiscrete +from metrics.molecular_metrics_sampling import SamplingMolecularMetrics + +from analysis.visualization import MolecularVisualization + +warnings.filterwarnings("ignore", category=UserWarning) +torch.set_float32_matmul_precision("medium") + +def remove_folder(folder): + for filename in os.listdir(folder): + file_path = os.path.join(folder, filename) + try: + if os.path.isfile(file_path) or os.path.islink(file_path): + os.unlink(file_path) + elif os.path.isdir(file_path): + shutil.rmtree(file_path) + except Exception as e: + print("Failed to delete %s. Reason: %s" % (file_path, e)) + + +def get_resume(cfg, model_kwargs): + """Resumes a run. It loads previous config without allowing to update keys (used for testing).""" + saved_cfg = cfg.copy() + name = cfg.general.name + "_resume" + resume = cfg.general.test_only + batch_size = cfg.train.batch_size + model = MCD.load_from_checkpoint(resume, **model_kwargs) + cfg = model.cfg + cfg.general.test_only = resume + cfg.general.name = name + cfg.train.batch_size = batch_size + cfg = utils.update_config_with_new_keys(cfg, saved_cfg) + return cfg, model + +def get_resume_adaptive(cfg, model_kwargs): + """Resumes a run. It loads previous config but allows to make some changes (used for resuming training).""" + saved_cfg = cfg.copy() + # Fetch path to this file to get base path + current_path = os.path.dirname(os.path.realpath(__file__)) + root_dir = current_path.split("outputs")[0] + + resume_path = os.path.join(root_dir, cfg.general.resume) + + if cfg.model.type == "discrete": + model = MCD.load_from_checkpoint( + resume_path, **model_kwargs + ) + else: + raise NotImplementedError("Unknown model") + + new_cfg = model.cfg + for category in cfg: + for arg in cfg[category]: + new_cfg[category][arg] = cfg[category][arg] + + new_cfg.general.resume = resume_path + new_cfg.general.name = new_cfg.general.name + "_resume" + + new_cfg = utils.update_config_with_new_keys(new_cfg, saved_cfg) + return new_cfg, model + + +@hydra.main( + version_base="1.1", config_path="../configs", config_name="config_dev" +) +def main(cfg: DictConfig): + + datamodule = dataset.DataModule(cfg) + datamodule.prepare_data() + dataset_infos = dataset.DataInfos(datamodule=datamodule, cfg=cfg) + train_smiles, reference_smiles = datamodule.get_train_smiles() + + dataset_infos.compute_input_output_dims(datamodule=datamodule) + train_metrics = TrainMolecularMetricsDiscrete(dataset_infos) + + sampling_metrics = SamplingMolecularMetrics( + dataset_infos, train_smiles, reference_smiles + ) + visualization_tools = MolecularVisualization(dataset_infos) + + model_kwargs = { + "dataset_infos": dataset_infos, + "train_metrics": train_metrics, + "sampling_metrics": sampling_metrics, + "visualization_tools": visualization_tools, + } + + if cfg.general.test_only: + # When testing, previous configuration is fully loaded + cfg, _ = get_resume(cfg, model_kwargs) + os.chdir(cfg.general.test_only.split("checkpoints")[0]) + elif cfg.general.resume is not None: + # When resuming, we can override some parts of previous configuration + cfg, _ = get_resume_adaptive(cfg, model_kwargs) + os.chdir(cfg.general.resume.split("checkpoints")[0]) + + model = MCD(cfg=cfg, **model_kwargs) + trainer = Trainer( + gradient_clip_val=cfg.train.clip_grad, + accelerator="gpu" + if torch.cuda.is_available() and cfg.general.gpus > 0 + else "cpu", + devices=cfg.general.gpus + if torch.cuda.is_available() and cfg.general.gpus > 0 + else None, + max_epochs=cfg.train.n_epochs, + 
enable_checkpointing=False, + check_val_every_n_epoch=cfg.train.check_val_every_n_epoch, + val_check_interval=cfg.train.val_check_interval, + strategy="ddp" if cfg.general.gpus > 1 else "auto", + enable_progress_bar=cfg.general.enable_progress_bar, + callbacks=[], + reload_dataloaders_every_n_epochs=0, + logger=[], + ) + + if not cfg.general.test_only: + trainer.fit(model, datamodule=datamodule, ckpt_path=cfg.general.resume) + if cfg.general.save_model: + trainer.save_checkpoint(f"checkpoints/{cfg.general.name}/last.ckpt") + trainer.test(model, datamodule=datamodule) + else: + trainer.test(model, datamodule=datamodule, ckpt_path=cfg.general.test_only) + +if __name__ == "__main__": + main() diff --git a/mcd/metrics/__init__.py b/mcd/metrics/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/mcd/metrics/abstract_metrics.py b/mcd/metrics/abstract_metrics.py new file mode 100644 index 0000000..4d47467 --- /dev/null +++ b/mcd/metrics/abstract_metrics.py @@ -0,0 +1,138 @@ +import torch +from torch import Tensor +from torch.nn import functional as F +from torchmetrics import Metric, MeanSquaredError + + +class TrainAbstractMetricsDiscrete(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, masked_pred_X, masked_pred_E, true_X, true_E, log: bool): + pass + + def reset(self): + pass + + def log_epoch_metrics(self, current_epoch): + pass + + +class TrainAbstractMetrics(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, masked_pred_epsX, masked_pred_epsE, pred_y, true_epsX, true_epsE, true_y, log): + pass + + def reset(self): + pass + + def log_epoch_metrics(self, current_epoch): + pass + + +class SumExceptBatchMetric(Metric): + def __init__(self): + super().__init__() + self.add_state('total_value', default=torch.tensor(0.), dist_reduce_fx="sum") + self.add_state('total_samples', default=torch.tensor(0.), dist_reduce_fx="sum") + + def update(self, values) -> None: + self.total_value += torch.sum(values) + self.total_samples += values.shape[0] + + def compute(self): + return self.total_value / self.total_samples + + +class SumExceptBatchMSE(MeanSquaredError): + def update(self, preds: Tensor, target: Tensor) -> None: + """Update state with predictions and targets. + + Args: + preds: Predictions from model + target: Ground truth values + """ + assert preds.shape == target.shape + sum_squared_error, n_obs = self._mean_squared_error_update(preds, target) + + self.sum_squared_error += sum_squared_error + self.total += n_obs + + def _mean_squared_error_update(self, preds: Tensor, target: Tensor): + """ Updates and returns variables required to compute Mean Squared Error. Checks for same shape of input + tensors. 
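# Minimal usage sketch, not part of the committed code: these metrics accumulate
# a running sum over batches and divide by the number of samples in compute(),
# which is what makes them safe to aggregate across devices. Toy tensors assumed.
import torch

metric = SumExceptBatchMetric()
metric.update(torch.full((4, 3), 2.0))   # batch of 4, each row sums to 6.0
metric.update(torch.full((2, 3), 1.0))   # batch of 2, each row sums to 3.0
print(metric.compute())                  # (4 * 6 + 2 * 3) / (4 + 2) = 5.0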
+ preds: Predicted tensor + target: Ground truth tensor + """ + diff = preds - target + sum_squared_error = torch.sum(diff * diff) + n_obs = preds.shape[0] + return sum_squared_error, n_obs + + +class SumExceptBatchKL(Metric): + def __init__(self): + super().__init__() + self.add_state('total_value', default=torch.tensor(0.), dist_reduce_fx="sum") + self.add_state('total_samples', default=torch.tensor(0.), dist_reduce_fx="sum") + + def update(self, p, q) -> None: + self.total_value += F.kl_div(q, p, reduction='sum') + self.total_samples += p.size(0) + + def compute(self): + return self.total_value / self.total_samples + + +class CrossEntropyMetric(Metric): + def __init__(self): + super().__init__() + self.add_state('total_ce', default=torch.tensor(0.), dist_reduce_fx="sum") + self.add_state('total_samples', default=torch.tensor(0.), dist_reduce_fx="sum") + + def update(self, preds: Tensor, target: Tensor, weight=None) -> None: + """ Update state with predictions and targets. + preds: Predictions from model (bs * n, d) or (bs * n * n, d) + target: Ground truth values (bs * n, d) or (bs * n * n, d). """ + target = torch.argmax(target, dim=-1) + if weight is not None: + weight = weight.type_as(preds) + output = F.cross_entropy(preds, target, weight = weight, reduction='sum') + else: + output = F.cross_entropy(preds, target, reduction='sum') + self.total_ce += output + self.total_samples += preds.size(0) + + def compute(self): + return self.total_ce / self.total_samples + + +class ProbabilityMetric(Metric): + def __init__(self): + """ This metric is used to track the marginal predicted probability of a class during training. """ + super().__init__() + self.add_state('prob', default=torch.tensor(0.), dist_reduce_fx="sum") + self.add_state('total', default=torch.tensor(0.), dist_reduce_fx="sum") + + def update(self, preds: Tensor) -> None: + self.prob += preds.sum() + self.total += preds.numel() + + def compute(self): + return self.prob / self.total + + +class NLL(Metric): + def __init__(self): + super().__init__() + self.add_state('total_nll', default=torch.tensor(0.), dist_reduce_fx="sum") + self.add_state('total_samples', default=torch.tensor(0.), dist_reduce_fx="sum") + + def update(self, batch_nll) -> None: + self.total_nll += torch.sum(batch_nll) + self.total_samples += batch_nll.numel() + + def compute(self): + return self.total_nll / self.total_samples \ No newline at end of file diff --git a/mcd/metrics/fpscores.pkl.gz b/mcd/metrics/fpscores.pkl.gz new file mode 100644 index 0000000..6c5738b Binary files /dev/null and b/mcd/metrics/fpscores.pkl.gz differ diff --git a/mcd/metrics/molecular_metrics_sampling.py b/mcd/metrics/molecular_metrics_sampling.py new file mode 100644 index 0000000..4eae271 --- /dev/null +++ b/mcd/metrics/molecular_metrics_sampling.py @@ -0,0 +1,138 @@ +### packages for visualization +from analysis.rdkit_functions import compute_molecular_metrics +from mini_moses.metrics.metrics import compute_intermediate_statistics +from metrics.property_metric import TaskModel + +import torch +import torch.nn as nn + +import os +import csv +import time + +def result_to_csv(path, dict_data): + file_exists = os.path.exists(path) + log_name = dict_data.pop("log_name", None) + if log_name is None: + raise ValueError("The provided dictionary must contain a 'log_name' key.") + field_names = ["log_name"] + list(dict_data.keys()) + dict_data["log_name"] = log_name + with open(path, "a", newline="") as file: + writer = csv.DictWriter(file, fieldnames=field_names) + if not file_exists: + 
writer.writeheader() + writer.writerow(dict_data) + + +class SamplingMolecularMetrics(nn.Module): + def __init__( + self, + dataset_infos, + train_smiles, + reference_smiles, + n_jobs=1, + device="cpu", + batch_size=512, + ): + super().__init__() + self.task_name = dataset_infos.task + self.dataset_infos = dataset_infos + self.active_atoms = dataset_infos.active_atoms + self.train_smiles = train_smiles + + if reference_smiles is not None: + print( + f"--- Computing intermediate statistics for training for #{len(reference_smiles)} smiles ---" + ) + start_time = time.time() + self.stat_ref = compute_intermediate_statistics( + reference_smiles, n_jobs=n_jobs, device=device, batch_size=batch_size + ) + end_time = time.time() + elapsed_time = end_time - start_time + print( + f"--- End computing intermediate statistics: using {elapsed_time:.2f}s ---" + ) + else: + self.stat_ref = None + + self.comput_config = { + "n_jobs": n_jobs, + "device": device, + "batch_size": batch_size, + } + + self.task_evaluator = {'meta_taskname': dataset_infos.task, 'sas': None, 'scs': None} + for cur_task in dataset_infos.task.split("-")[:]: + # print('loading evaluator for task', cur_task) + model_path = os.path.join( + dataset_infos.base_path, "data/evaluator", f"{cur_task}.joblib" + ) + os.makedirs(os.path.dirname(model_path), exist_ok=True) + evaluator = TaskModel(model_path, cur_task) + self.task_evaluator[cur_task] = evaluator + + def forward(self, molecules, targets, name, current_epoch, val_counter, test=False): + if isinstance(targets, list): + targets_cat = torch.cat(targets, dim=0) + targets_np = targets_cat.detach().cpu().numpy() + else: + targets_np = targets.detach().cpu().numpy() + + unique_smiles, all_smiles, all_metrics, targets_log = compute_molecular_metrics( + molecules, + targets_np, + self.train_smiles, + self.stat_ref, + self.dataset_infos, + self.task_evaluator, + self.comput_config, + ) + + if test: + file_name = "final_smiles.txt" + with open(file_name, "w") as fp: + all_tasks_name = list(self.task_evaluator.keys()) + all_tasks_name = all_tasks_name.copy() + if 'meta_taskname' in all_tasks_name: + all_tasks_name.remove('meta_taskname') + if 'scs' in all_tasks_name: + all_tasks_name.remove('scs') + + all_tasks_str = "smiles, " + ", ".join([f"input_{task}" for task in all_tasks_name] + [f"output_{task}" for task in all_tasks_name]) + fp.write(all_tasks_str + "\n") + for i, smiles in enumerate(all_smiles): + if targets_log is not None: + all_result_str = f"{smiles}, " + ", ".join([f"{targets_log['input_'+task][i]}" for task in all_tasks_name] + [f"{targets_log['output_'+task][i]}" for task in all_tasks_name]) + fp.write(all_result_str + "\n") + else: + fp.write("%s\n" % smiles) + print("All smiles saved") + else: + result_path = os.path.join(os.getcwd(), f"graphs/{name}") + os.makedirs(result_path, exist_ok=True) + text_path = os.path.join( + result_path, + f"valid_unique_molecules_e{current_epoch}_b{val_counter}.txt", + ) + textfile = open(text_path, "w") + for smiles in unique_smiles: + textfile.write(smiles + "\n") + textfile.close() + + all_logs = all_metrics + if test: + all_logs["log_name"] = "test" + else: + all_logs["log_name"] = ( + "epoch" + str(current_epoch) + "_batch" + str(val_counter) + ) + + result_to_csv("output.csv", all_logs) + return all_smiles + + def reset(self): + pass + +if __name__ == "__main__": + pass \ No newline at end of file diff --git a/mcd/metrics/molecular_metrics_train.py b/mcd/metrics/molecular_metrics_train.py new file mode 100644 index 0000000..f9f8779 --- 
/dev/null +++ b/mcd/metrics/molecular_metrics_train.py @@ -0,0 +1,126 @@ +import torch +from torchmetrics import Metric, MetricCollection +from torch import Tensor +import torch.nn as nn + +class CEPerClass(Metric): + full_state_update = False + def __init__(self, class_id): + super().__init__() + self.class_id = class_id + self.add_state('total_ce', default=torch.tensor(0.), dist_reduce_fx="sum") + self.add_state('total_samples', default=torch.tensor(0.), dist_reduce_fx="sum") + self.softmax = torch.nn.Softmax(dim=-1) + self.binary_cross_entropy = torch.nn.BCELoss(reduction='sum') + + def update(self, preds: Tensor, target: Tensor) -> None: + """Update state with predictions and targets. + Args: + preds: Predictions from model (bs, n, d) or (bs, n, n, d) + target: Ground truth values (bs, n, d) or (bs, n, n, d) + """ + target = target.reshape(-1, target.shape[-1]) + mask = (target != 0.).any(dim=-1) + + prob = self.softmax(preds)[..., self.class_id] + prob = prob.flatten()[mask] + + target = target[:, self.class_id] + target = target[mask] + + output = self.binary_cross_entropy(prob, target) + + self.total_ce += output + self.total_samples += prob.numel() + + def compute(self): + return self.total_ce / self.total_samples + + +class AtomCE(CEPerClass): + def __init__(self, i): + super().__init__(i) + +class NoBondCE(CEPerClass): + def __init__(self, i): + super().__init__(i) + + +class SingleCE(CEPerClass): + def __init__(self, i): + super().__init__(i) + + +class DoubleCE(CEPerClass): + def __init__(self, i): + super().__init__(i) + + +class TripleCE(CEPerClass): + def __init__(self, i): + super().__init__(i) + + +class AromaticCE(CEPerClass): + def __init__(self, i): + super().__init__(i) + + +class AtomMetricsCE(MetricCollection): + def __init__(self, active_atoms): + metrics_list = [] + + for i, atom_type in enumerate(active_atoms): + metrics_list.append(type(f'{atom_type}_CE', (AtomCE,), {})(i)) + + super().__init__(metrics_list) + + +class BondMetricsCE(MetricCollection): + def __init__(self): + ce_no_bond = NoBondCE(0) + ce_SI = SingleCE(1) + ce_DO = DoubleCE(2) + ce_TR = TripleCE(3) + super().__init__([ce_no_bond, ce_SI, ce_DO, ce_TR]) + + +class TrainMolecularMetricsDiscrete(nn.Module): + def __init__(self, dataset_infos): + super().__init__() + active_atoms = dataset_infos.active_atoms + self.train_atom_metrics = AtomMetricsCE(active_atoms=active_atoms) + self.train_bond_metrics = BondMetricsCE() + + def forward(self, masked_pred_X, masked_pred_E, true_X, true_E, log: bool): + self.train_atom_metrics(masked_pred_X, true_X) + self.train_bond_metrics(masked_pred_E, true_E) + if log: + to_log = {} + for key, val in self.train_atom_metrics.compute().items(): + to_log['train/' + key] = val.item() + for key, val in self.train_bond_metrics.compute().items(): + to_log['train/' + key] = val.item() + + def reset(self): + for metric in [self.train_atom_metrics, self.train_bond_metrics]: + metric.reset() + + def log_epoch_metrics(self, current_epoch, log=True): + epoch_atom_metrics = self.train_atom_metrics.compute() + epoch_bond_metrics = self.train_bond_metrics.compute() + + to_log = {} + for key, val in epoch_atom_metrics.items(): + to_log['train_epoch/' + key] = val.item() + for key, val in epoch_bond_metrics.items(): + to_log['train_epoch/' + key] = val.item() + + for key, val in epoch_atom_metrics.items(): + epoch_atom_metrics[key] = round(val.item(),4) + for key, val in epoch_bond_metrics.items(): + epoch_bond_metrics[key] = round(val.item(),4) + + if log: + print(f"Epoch 
{current_epoch}: {epoch_atom_metrics} -- {epoch_bond_metrics}") + diff --git a/mcd/metrics/property_metric.py b/mcd/metrics/property_metric.py new file mode 100644 index 0000000..e0ff0b1 --- /dev/null +++ b/mcd/metrics/property_metric.py @@ -0,0 +1,201 @@ +import math, os +import pickle +import os.path as op + +import numpy as np +import pandas as pd +from joblib import dump, load +from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor +from sklearn.metrics import mean_absolute_error, roc_auc_score + + +from rdkit import Chem +from rdkit import rdBase +from rdkit.Chem import AllChem +from rdkit import DataStructs +from rdkit.Chem import rdMolDescriptors +rdBase.DisableLog('rdApp.error') + +task_to_colname = { + 'hiv_b': 'HIV_active', + 'bace_b': 'Class', + 'bbbp_b': 'p_np', + 'O2': 'O2', + 'N2': 'N2', + 'CO2': 'CO2', +} + +tasktype_name = { + 'hiv_b': 'classification', + 'bace_b': 'classification', + 'bbbp_b': 'classification', + 'O2': 'regression', + 'N2': 'regression', + 'CO2': 'regression', +} + +class TaskModel(): + """Scores based on an ECFP classifier.""" + def __init__(self, model_path, task_name): + task_type = tasktype_name[task_name] + self.task_name = task_name + self.task_type = task_type + self.model_path = model_path + self.metric_func = roc_auc_score if 'classification' in self.task_type else mean_absolute_error + + try: + self.model = load(model_path) + print(self.task_name, ' evaluator loaded') + except: + print(self.task_name, ' evaluator not found, training new one...') + if 'classification' in task_type: + self.model = RandomForestClassifier(random_state=0) + elif 'regression' in task_type: + self.model = RandomForestRegressor(random_state=0) + perfermance = self.train() + dump(self.model, model_path) + print('Oracle peformance: ', perfermance) + + def train(self): + data_path = os.path.dirname(self.model_path) + data_path = os.path.join(os.path.dirname(self.model_path), '..', f'raw/{self.task_name}.csv.gz') + df = pd.read_csv(data_path) + col_name = task_to_colname[self.task_name] + y = df[col_name].to_numpy() + x_smiles = df['smiles'].to_numpy() + mask = ~np.isnan(y) + y = y[mask] + + if 'classification' in self.task_type: + y = y.astype(int) + + x_smiles = x_smiles[mask] + x_fps = [] + mask = [] + for i,smiles in enumerate(x_smiles): + mol = Chem.MolFromSmiles(smiles) + mask.append( int(mol is not None) ) + fp = TaskModel.fingerprints_from_mol(mol) if mol else np.zeros((1, 2048)) + x_fps.append(fp) + x_fps = np.concatenate(x_fps, axis=0) + self.model.fit(x_fps, y) + y_pred = self.model.predict(x_fps) + perf = self.metric_func(y, y_pred) + print(f'{self.task_name} performance: {perf}') + return perf + + def __call__(self, smiles_list): + fps = [] + mask = [] + for i,smiles in enumerate(smiles_list): + mol = Chem.MolFromSmiles(smiles) + mask.append( int(mol is not None) ) + fp = TaskModel.fingerprints_from_mol(mol) if mol else np.zeros((1, 2048)) + fps.append(fp) + + fps = np.concatenate(fps, axis=0) + if 'classification' in self.task_type: + scores = self.model.predict_proba(fps)[:, 1] + else: + scores = self.model.predict(fps) + scores = scores * np.array(mask) + return np.float32(scores) + + @classmethod + def fingerprints_from_mol(cls, mol): # use ECFP4 + features_vec = AllChem.GetMorganFingerprintAsBitVect(mol, 2, nBits=2048) + features = np.zeros((1,)) + DataStructs.ConvertToNumpyArray(features_vec, features) + return features.reshape(1, -1) + +###### SAS Score ###### +_fscores = None + +def readFragmentScores(name='fpscores'): + import gzip + 
global _fscores + # generate the full path filename: + if name == "fpscores": + name = op.join(op.dirname(__file__), name) + data = pickle.load(gzip.open('%s.pkl.gz' % name)) + outDict = {} + for i in data: + for j in range(1, len(i)): + outDict[i[j]] = float(i[0]) + _fscores = outDict + +def numBridgeheadsAndSpiro(mol, ri=None): + nSpiro = rdMolDescriptors.CalcNumSpiroAtoms(mol) + nBridgehead = rdMolDescriptors.CalcNumBridgeheadAtoms(mol) + return nBridgehead, nSpiro + +def calculateSAS(smiles_list): + scores = [] + for i, smiles in enumerate(smiles_list): + mol = Chem.MolFromSmiles(smiles) + score = calculateScore(mol) + scores.append(score) + return np.float32(scores) + +def calculateScore(m): + if _fscores is None: + readFragmentScores() + + # fragment score + fp = rdMolDescriptors.GetMorganFingerprint(m, + 2) # <- 2 is the *radius* of the circular fingerprint + fps = fp.GetNonzeroElements() + score1 = 0. + nf = 0 + for bitId, v in fps.items(): + nf += v + sfp = bitId + score1 += _fscores.get(sfp, -4) * v + score1 /= nf + + # features score + nAtoms = m.GetNumAtoms() + nChiralCenters = len(Chem.FindMolChiralCenters(m, includeUnassigned=True)) + ri = m.GetRingInfo() + nBridgeheads, nSpiro = numBridgeheadsAndSpiro(m, ri) + nMacrocycles = 0 + for x in ri.AtomRings(): + if len(x) > 8: + nMacrocycles += 1 + + sizePenalty = nAtoms**1.005 - nAtoms + stereoPenalty = math.log10(nChiralCenters + 1) + spiroPenalty = math.log10(nSpiro + 1) + bridgePenalty = math.log10(nBridgeheads + 1) + macrocyclePenalty = 0. + # --------------------------------------- + # This differs from the paper, which defines: + # macrocyclePenalty = math.log10(nMacrocycles+1) + # This form generates better results when 2 or more macrocycles are present + if nMacrocycles > 0: + macrocyclePenalty = math.log10(2) + + score2 = 0. - sizePenalty - stereoPenalty - spiroPenalty - bridgePenalty - macrocyclePenalty + + # correction for the fingerprint density + # not in the original publication, added in version 1.1 + # to make highly symmetrical molecules easier to synthetise + score3 = 0. + if nAtoms > len(fps): + score3 = math.log(float(nAtoms) / len(fps)) * .5 + + sascore = score1 + score2 + score3 + + # need to transform "raw" value into scale between 1 and 10 + min = -4.0 + max = 2.5 + sascore = 11. - (sascore - min + 1) / (max - min) * 9. + # smooth the 10-end + if sascore > 8.: + sascore = 8. + math.log(sascore + 1. - 9.) 
+ if sascore > 10.: + sascore = 10.0 + elif sascore < 1.: + sascore = 1.0 + + return sascore diff --git a/mcd/metrics/train_loss.py b/mcd/metrics/train_loss.py new file mode 100644 index 0000000..702405e --- /dev/null +++ b/mcd/metrics/train_loss.py @@ -0,0 +1,94 @@ +import time +import torch +import torch.nn as nn +from metrics.abstract_metrics import CrossEntropyMetric +from torchmetrics import Metric, MeanSquaredError + +# from 2:He to 119:* +valencies_check = [0, 1, 2, 3, 4, 3, 2, 1, 0, 1, 2, 6, 6, 7, 6, 1, 0, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 3, 4, 7, 6, 1, 0, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 3, 4, 7, 6, 5, 6, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 3, 4, 7, 6, 5, 0, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1] +valencies_check = torch.tensor(valencies_check) + +weight_check = [4.003, 6.941, 9.012, 10.812, 12.011, 14.007, 15.999, 18.998, 20.18, 22.99, 24.305, 26.982, 28.086, 30.974, 32.067, 35.453, 39.948, 39.098, 40.078, 44.956, 47.867, 50.942, 51.996, 54.938, 55.845, 58.933, 58.693, 63.546, 65.39, 69.723, 72.61, 74.922, 78.96, 79.904, 83.8, 85.468, 87.62, 88.906, 91.224, 92.906, 95.94, 98.0, 101.07, 102.906, 106.42, 107.868, 112.412, 114.818, 118.711, 121.76, 127.6, 126.904, 131.29, 132.905, 137.328, 138.906, 140.116, 140.908, 144.24, 145.0, 150.36, 151.964, 157.25, 158.925, 162.5, 164.93, 167.26, 168.934, 173.04, 174.967, 178.49, 180.948, 183.84, 186.207, 190.23, 192.217, 195.078, 196.967, 200.59, 204.383, 207.2, 208.98, 209.0, 210.0, 222.0, 223.0, 226.0, 227.0, 232.038, 231.036, 238.029, 237.0, 244.0, 243.0, 247.0, 247.0, 251.0, 252.0, 257.0, 258.0, 259.0, 262.0, 267.0, 268.0, 269.0, 270.0, 269.0, 278.0, 281.0, 281.0, 285.0, 284.0, 289.0, 288.0, 293.0, 292.0, 294.0, 294.0] +weight_check = torch.tensor(weight_check) + +class AtomWeightMetric(Metric): + def __init__(self): + super().__init__() + self.add_state('total_loss', default=torch.tensor(0.), dist_reduce_fx="sum") + self.add_state('total_samples', default=torch.tensor(0.), dist_reduce_fx="sum") + global weight_check + self.weight_check = weight_check + + def update(self, X, Y): + atom_pred_num = X.argmax(dim=-1) + atom_real_num = Y.argmax(dim=-1) + self.weight_check = self.weight_check.type_as(X) + + pred_weight = self.weight_check[atom_pred_num] + real_weight = self.weight_check[atom_real_num] + + lss = 0 + lss += torch.abs(pred_weight.sum(dim=-1) - real_weight.sum(dim=-1)).sum() + self.total_loss += lss + self.total_samples += X.size(0) + + def compute(self): + return self.total_loss / self.total_samples + + +class TrainLossDiscrete(nn.Module): + """ Train with Cross entropy""" + def __init__(self, lambda_train, weight_node=None, weight_edge=None): + super().__init__() + self.node_loss = CrossEntropyMetric() + self.edge_loss = CrossEntropyMetric() + self.weight_loss = AtomWeightMetric() + + self.y_loss = MeanSquaredError() + self.lambda_train = lambda_train + + def forward(self, masked_pred_X, masked_pred_E, pred_y, true_X, true_E, true_y, node_mask, log: bool): + """ Compute train metrics + masked_pred_X : tensor -- (bs, n, dx) + masked_pred_E : tensor -- (bs, n, n, de) + pred_y : tensor -- (bs, ) + true_X : tensor -- (bs, n, dx) + true_E : tensor -- (bs, n, n, de) + true_y : tensor -- (bs, ) + log : boolean. 
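# The value returned below combines three pieces: the node cross-entropy scaled
# by lambda_train[0], the edge cross-entropy scaled by lambda_train[1], and the
# atomic-weight penalty from AtomWeightMetric. With, say, loss_X = 0.42,
# loss_E = 0.07, loss_weight = 1.3 and weights (1, 10), the total would be
# 1 * 0.42 + 10 * 0.07 + 1.3 = 2.42 (numbers are illustrative assumptions).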
""" + + loss_weight = self.weight_loss(masked_pred_X, true_X) + + true_X = torch.reshape(true_X, (-1, true_X.size(-1))) # (bs * n, dx) + true_E = torch.reshape(true_E, (-1, true_E.size(-1))) # (bs * n * n, de) + masked_pred_X = torch.reshape(masked_pred_X, (-1, masked_pred_X.size(-1))) # (bs * n, dx) + masked_pred_E = torch.reshape(masked_pred_E, (-1, masked_pred_E.size(-1))) # (bs * n * n, de) + + # Remove masked rows + mask_X = (true_X != 0.).any(dim=-1) + mask_E = (true_E != 0.).any(dim=-1) + + flat_true_X = true_X[mask_X, :] + flat_pred_X = masked_pred_X[mask_X, :] + + flat_true_E = true_E[mask_E, :] + flat_pred_E = masked_pred_E[mask_E, :] + + loss_X = self.node_loss(flat_pred_X, flat_true_X) if true_X.numel() > 0 else 0.0 + loss_E = self.edge_loss(flat_pred_E, flat_true_E) if true_E.numel() > 0 else 0.0 + + return self.lambda_train[0] * loss_X + self.lambda_train[1] * loss_E + loss_weight + + def reset(self): + for metric in [self.node_loss, self.edge_loss, self.y_loss]: + metric.reset() + + def log_epoch_metrics(self, current_epoch, start_epoch_time, log=True): + epoch_node_loss = self.node_loss.compute() if self.node_loss.total_samples > 0 else -1 + epoch_edge_loss = self.edge_loss.compute() if self.edge_loss.total_samples > 0 else -1 + epoch_weight_loss = self.weight_loss.compute() if self.weight_loss.total_samples > 0 else -1 + + if log: + print(f"Epoch {current_epoch} finished: X_CE: {epoch_node_loss :.4f} -- E_CE: {epoch_edge_loss :.4f} " + f"Weight: {epoch_weight_loss :.4f} " + f"-- Time taken {time.time() - start_epoch_time:.1f}s ") \ No newline at end of file diff --git a/mcd/models/__init__.py b/mcd/models/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/mcd/models/conditions.py b/mcd/models/conditions.py new file mode 100644 index 0000000..ba2d4c6 --- /dev/null +++ b/mcd/models/conditions.py @@ -0,0 +1,119 @@ +import torch +import torch.nn as nn +import math + +class TimestepEmbedder(nn.Module): + """ + Embeds scalar timesteps into vector representations. + """ + def __init__(self, hidden_size, frequency_embedding_size=256): + super().__init__() + self.mlp = nn.Sequential( + nn.Linear(frequency_embedding_size, hidden_size, bias=True), + nn.SiLU(), + nn.Linear(hidden_size, hidden_size, bias=True), + ) + self.frequency_embedding_size = frequency_embedding_size + + @staticmethod + def timestep_embedding(t, dim, max_period=10000): + """ + Create sinusoidal timestep embeddings. + :param t: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an (N, D) Tensor of positional embeddings. + """ + # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half + ).to(device=t.device) + args = t[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + return embedding + + def forward(self, t): + t = t.view(-1) + t_freq = self.timestep_embedding(t, self.frequency_embedding_size) + t_emb = self.mlp(t_freq) + return t_emb + +class CategoricalEmbedder(nn.Module): + """ + Embeds categorical conditions such as data sources into vector representations. + Also handles label dropout for classifier-free guidance. 
+ """ + def __init__(self, num_classes, hidden_size, dropout_prob): + super().__init__() + use_cfg_embedding = dropout_prob > 0 + self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size) + self.num_classes = num_classes + self.dropout_prob = dropout_prob + + def token_drop(self, labels, force_drop_ids=None): + """ + Drops labels to enable classifier-free guidance. + """ + if force_drop_ids is None: + drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob + else: + drop_ids = force_drop_ids == 1 + labels = torch.where(drop_ids, self.num_classes, labels) + return labels + + def forward(self, labels, train, force_drop_ids=None, t=None): + labels = labels.long().view(-1) + use_dropout = self.dropout_prob > 0 + if (train and use_dropout) or (force_drop_ids is not None): + labels = self.token_drop(labels, force_drop_ids) + embeddings = self.embedding_table(labels) + if True and train: + noise = torch.randn_like(embeddings) + embeddings = embeddings + noise + return embeddings + +class ClusterContinuousEmbedder(nn.Module): + def __init__(self, input_size, hidden_size, dropout_prob): + super().__init__() + use_cfg_embedding = dropout_prob > 0 + + if use_cfg_embedding: + self.embedding_drop = nn.Embedding(1, hidden_size) + + self.mlp = nn.Sequential( + nn.Linear(input_size, hidden_size, bias=True), + nn.Softmax(dim=1), + nn.Linear(hidden_size, hidden_size, bias=False) + ) + self.hidden_size = hidden_size + self.dropout_prob = dropout_prob + + def forward(self, labels, train, force_drop_ids=None, timestep=None): + use_dropout = self.dropout_prob > 0 + if force_drop_ids is not None: + drop_ids = force_drop_ids == 1 + else: + drop_ids = None + + if (train and use_dropout): + drop_ids_rand = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob + if force_drop_ids is not None: + drop_ids = torch.logical_or(drop_ids, drop_ids_rand) + else: + drop_ids = drop_ids_rand + + if drop_ids is not None: + embeddings = torch.zeros((labels.shape[0], self.hidden_size), device=labels.device) + embeddings[~drop_ids] = self.mlp(labels[~drop_ids]) + embeddings[drop_ids] += self.embedding_drop.weight[0] + else: + embeddings = self.mlp(labels) + + if train: + noise = torch.randn_like(embeddings) + embeddings = embeddings + noise + return embeddings diff --git a/mcd/models/layers.py b/mcd/models/layers.py new file mode 100644 index 0000000..09deb42 --- /dev/null +++ b/mcd/models/layers.py @@ -0,0 +1,114 @@ +from torch.jit import Final +import torch.nn.functional as F +from itertools import repeat +import collections.abc + +import torch +import torch.nn as nn + +class Attention(nn.Module): + fast_attn: Final[bool] + + def __init__( + self, + dim, + num_heads=8, + qkv_bias=False, + qk_norm=False, + attn_drop=0, + proj_drop=0, + norm_layer=nn.LayerNorm, + ): + super().__init__() + assert dim % num_heads == 0, 'dim should be divisible by num_heads' + self.num_heads = num_heads + self.head_dim = dim // num_heads + + self.scale = self.head_dim ** -0.5 + self.fast_attn = hasattr(torch.nn.functional, 'scaled_dot_product_attention') # FIXME + assert self.fast_attn, "scaled_dot_product_attention Not implemented" + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + + self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.attn_drop = nn.Dropout(attn_drop) + + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def dot_product_attention(self, q, k, 
+        q = q * self.scale
+        attn = q @ k.transpose(-2, -1)
+        attn_sfmx = attn.softmax(dim=-1)
+        attn_sfmx = self.attn_drop(attn_sfmx)
+        x = attn_sfmx @ v
+        return x, attn
+
+    def forward(self, x, node_mask):
+        B, N, D = x.shape
+
+        # B, head, N, head_dim
+        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
+        q, k, v = qkv.unbind(0)  # B, head, N, head_dim
+        q, k = self.q_norm(q), self.k_norm(k)
+
+        attn_mask = (node_mask[:, None, :, None] & node_mask[:, None, None, :]).expand(-1, self.num_heads, N, N)
+        attn_mask[attn_mask.sum(-1) == 0] = True
+
+        x = F.scaled_dot_product_attention(
+            q, k, v,
+            dropout_p=self.attn_drop.p,
+            attn_mask=attn_mask,
+        )
+
+        x = x.transpose(1, 2).reshape(B, N, -1)
+
+        x = self.proj(x)
+        x = self.proj_drop(x)
+
+        return x
+
+
+class Mlp(nn.Module):
+    def __init__(
+            self,
+            in_features,
+            hidden_features=None,
+            out_features=None,
+            act_layer=nn.GELU,
+            bias=True,
+            drop=0.,
+    ):
+        super().__init__()
+        out_features = out_features or in_features
+        hidden_features = hidden_features or in_features
+        bias = to_2tuple(bias)
+        drop_probs = to_2tuple(drop)
+        linear_layer = nn.Linear
+
+        self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0])
+        self.act = act_layer()
+        self.drop1 = nn.Dropout(drop_probs[0])
+        self.fc2 = linear_layer(hidden_features, out_features, bias=bias[1])
+        self.drop2 = nn.Dropout(drop_probs[1])
+
+    def forward(self, x):
+        x = self.fc1(x)
+        x = self.act(x)
+        x = self.drop1(x)
+        x = self.fc2(x)
+        x = self.drop2(x)
+        return x
+
+# From PyTorch internals
+def _ntuple(n):
+    def parse(x):
+        if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
+            return tuple(x)
+        return tuple(repeat(x, n))
+    return parse
+
+to_2tuple = _ntuple(2)
+
+
diff --git a/mcd/models/transformer.py b/mcd/models/transformer.py
new file mode 100644
index 0000000..e9e95d1
--- /dev/null
+++ b/mcd/models/transformer.py
@@ -0,0 +1,184 @@
+import torch
+import torch.nn as nn
+
+import utils
+from models.layers import Attention, Mlp
+from models.conditions import TimestepEmbedder, CategoricalEmbedder, ClusterContinuousEmbedder
+
+def modulate(x, shift, scale):
+    return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
+
+class Denoiser(nn.Module):
+    def __init__(
+            self,
+            max_n_nodes,
+            hidden_size=384,
+            depth=12,
+            num_heads=16,
+            mlp_ratio=4.0,
+            drop_condition=0.1,
+            Xdim=118,
+            Edim=5,
+            ydim=3,
+            task_type='regression',
+    ):
+        super().__init__()
+        self.num_heads = num_heads
+        self.ydim = ydim
+        self.x_embedder = nn.Linear(Xdim + max_n_nodes * Edim, hidden_size, bias=False)
+
+        self.t_embedder = TimestepEmbedder(hidden_size)
+        self.y_embedding_list = torch.nn.ModuleList()
+
+        self.y_embedding_list.append(ClusterContinuousEmbedder(2, hidden_size, drop_condition))
+        for i in range(ydim - 2):
+            if task_type == 'regression':
+                self.y_embedding_list.append(ClusterContinuousEmbedder(1, hidden_size, drop_condition))
+            else:
+                self.y_embedding_list.append(CategoricalEmbedder(2, hidden_size, drop_condition))
+
+        self.encoders = nn.ModuleList(
+            [
+                SELayer(hidden_size, num_heads, mlp_ratio=mlp_ratio)
+                for _ in range(depth)
+            ]
+        )
+
+        self.decoder = Decoder(
+            max_n_nodes=max_n_nodes,
+            hidden_size=hidden_size,
+            atom_type=Xdim,
+            bond_type=Edim,
+            mlp_ratio=mlp_ratio,
+            num_heads=num_heads,
+        )
+
+        self.initialize_weights()
+
+    def initialize_weights(self):
+        # Initialize transformer layers:
+        def _basic_init(module):
+            if isinstance(module, nn.Linear):
+                torch.nn.init.xavier_uniform_(module.weight)
+                if module.bias is not None:
+                    nn.init.constant_(module.bias, 0)
+
+        def _constant_init(module, i):
+            if isinstance(module, nn.Linear):
+                nn.init.constant_(module.weight, i)
+                if module.bias is not None:
+                    nn.init.constant_(module.bias, i)
+
+        self.apply(_basic_init)
+
+        for block in self.encoders:
+            _constant_init(block.adaLN_modulation[0], 0)
+        _constant_init(self.decoder.adaLN_modulation[0], 0)
+
+    def forward(self, x, e, node_mask, y, t, unconditioned):
+
+        force_drop_id = torch.zeros_like(y.sum(-1))
+        force_drop_id[torch.isnan(y.sum(-1))] = 1
+        if unconditioned:
+            force_drop_id = torch.ones_like(y[:, 0])
+
+        x_in, e_in, y_in = x, e, y
+        bs, n, _ = x.size()
+        x = torch.cat([x, e.reshape(bs, n, -1)], dim=-1)
+        x = self.x_embedder(x)
+
+        c1 = self.t_embedder(t)
+        for i in range(1, self.ydim):
+            if i == 1:
+                c2 = self.y_embedding_list[i-1](y[:, :2], self.training, force_drop_id, t)
+            else:
+                c2 = c2 + self.y_embedding_list[i-1](y[:, i:i+1], self.training, force_drop_id, t)
+        c = c1 + c2
+
+        for i, block in enumerate(self.encoders):
+            x = block(x, c, node_mask)
+
+        # X: B * N * dx, E: B * N * N * de
+        X, E, y = self.decoder(x, x_in, e_in, c, t, node_mask)
+        return utils.PlaceHolder(X=X, E=E, y=y).mask(node_mask)
+
+
+class SELayer(nn.Module):
+    def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, **block_kwargs):
+        super().__init__()
+        self.dropout = 0.
+        self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False)
+        self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False)
+
+        self.attn = Attention(
+            hidden_size, num_heads=num_heads, qkv_bias=True, qk_norm=True, **block_kwargs
+        )
+
+        self.mlp = Mlp(
+            in_features=hidden_size,
+            hidden_features=int(hidden_size * mlp_ratio),
+            drop=self.dropout,
+        )
+
+        self.adaLN_modulation = nn.Sequential(
+            nn.Linear(hidden_size, hidden_size, bias=True),
+            nn.SiLU(),
+            nn.Linear(hidden_size, 6 * hidden_size, bias=True)
+        )
+
+    def forward(self, x, c, node_mask):
+        (
+            shift_msa,
+            scale_msa,
+            gate_msa,
+            shift_mlp,
+            scale_mlp,
+            gate_mlp,
+        ) = self.adaLN_modulation(c).chunk(6, dim=1)
+        x = x + gate_msa.unsqueeze(1) * modulate(self.norm1(self.attn(x, node_mask=node_mask)), shift_msa, scale_msa)
+        x = x + gate_mlp.unsqueeze(1) * modulate(self.norm2(self.mlp(x)), shift_mlp, scale_mlp)
+        return x
+
+
+class Decoder(nn.Module):
+    # Structure Decoder
+    def __init__(self, max_n_nodes, hidden_size, atom_type, bond_type, mlp_ratio, num_heads=None):
+        super().__init__()
+        self.atom_type = atom_type
+        self.bond_type = bond_type
+        final_size = atom_type + max_n_nodes * bond_type
+        self.xedecoder = Mlp(in_features=hidden_size,
+                             out_features=final_size, drop=0)
+
+        self.norm_final = nn.LayerNorm(final_size, elementwise_affine=False)
+        self.adaLN_modulation = nn.Sequential(
+            nn.Linear(hidden_size, hidden_size, bias=True),
+            nn.SiLU(),
+            nn.Linear(hidden_size, 2 * final_size, bias=True)
+        )
+
+    def forward(self, x, x_in, e_in, c, t, node_mask):
+        x_all = self.xedecoder(x)
+        B, N, D = x_all.size()
+        shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
+        x_all = modulate(self.norm_final(x_all), shift, scale)
+
+        atom_out = x_all[:, :, :self.atom_type]
+        atom_out = x_in + atom_out
+
+        bond_out = x_all[:, :, self.atom_type:].reshape(B, N, N, self.bond_type)
+        bond_out = e_in + bond_out
+
+        # standardize bond_out: mask padded-node pairs and self-loops, then symmetrize
+        edge_mask = (~node_mask)[:, :, None] & (~node_mask)[:, None, :]
+        diag_mask = (
+            torch.eye(N, dtype=torch.bool)
+            .unsqueeze(0)
+            .expand(B, -1, -1)
+            .type_as(edge_mask)
+        )
+        bond_out.masked_fill_(edge_mask[:, :, :, None], 0)
+        bond_out.masked_fill_(diag_mask[:, :, :, None], 0)
+        bond_out = 1 / 2 * (bond_out + torch.transpose(bond_out, 1, 2))
+
+        return atom_out, bond_out, None
diff --git a/mcd/utils.py b/mcd/utils.py
new file mode 100644
index 0000000..23776ea
--- /dev/null
+++ b/mcd/utils.py
@@ -0,0 +1,135 @@
+import os
+from omegaconf import OmegaConf, open_dict
+
+import torch
+import torch_geometric.utils
+from torch_geometric.utils import to_dense_adj, to_dense_batch
+
+def create_folders(args):
+    try:
+        os.makedirs('graphs')
+        os.makedirs('chains')
+    except OSError:
+        pass
+
+    try:
+        os.makedirs('graphs/' + args.general.name)
+        os.makedirs('chains/' + args.general.name)
+    except OSError:
+        pass
+
+def normalize(X, E, y, norm_values, norm_biases, node_mask):
+    X = (X - norm_biases[0]) / norm_values[0]
+    E = (E - norm_biases[1]) / norm_values[1]
+    y = (y - norm_biases[2]) / norm_values[2]
+
+    diag = torch.eye(E.shape[1], dtype=torch.bool).unsqueeze(0).expand(E.shape[0], -1, -1)
+    E[diag] = 0
+
+    return PlaceHolder(X=X, E=E, y=y).mask(node_mask)
+
+
+def unnormalize(X, E, y, norm_values, norm_biases, node_mask, collapse=False):
+    """
+    X : node features
+    E : edge features
+    y : global features
+    norm_values : [norm value X, norm value E, norm value y]
+    norm_biases : same order
+    node_mask
+    """
+    X = (X * norm_values[0] + norm_biases[0])
+    E = (E * norm_values[1] + norm_biases[1])
+    y = y * norm_values[2] + norm_biases[2]
+
+    return PlaceHolder(X=X, E=E, y=y).mask(node_mask, collapse)
+
+
+def to_dense(x, edge_index, edge_attr, batch, max_num_nodes=None):
+    X, node_mask = to_dense_batch(x=x, batch=batch, max_num_nodes=max_num_nodes)
+    # node_mask = node_mask.float()
+    edge_index, edge_attr = torch_geometric.utils.remove_self_loops(edge_index, edge_attr)
+    if max_num_nodes is None:
+        max_num_nodes = X.size(1)
+    E = to_dense_adj(edge_index=edge_index, batch=batch, edge_attr=edge_attr, max_num_nodes=max_num_nodes)
+    E = encode_no_edge(E)
+    return PlaceHolder(X=X, E=E, y=None), node_mask
+
+
+def encode_no_edge(E):
+    assert len(E.shape) == 4
+    if E.shape[-1] == 0:
+        return E
+    no_edge = torch.sum(E, dim=3) == 0
+    first_elt = E[:, :, :, 0]
+    first_elt[no_edge] = 1
+    E[:, :, :, 0] = first_elt
+    diag = torch.eye(E.shape[1], dtype=torch.bool).unsqueeze(0).expand(E.shape[0], -1, -1)
+    E[diag] = 0
+    return E
+
+
+def update_config_with_new_keys(cfg, saved_cfg):
+    saved_general = saved_cfg.general
+    saved_train = saved_cfg.train
+    saved_model = saved_cfg.model
+    saved_dataset = saved_cfg.dataset
+
+    for key, val in saved_dataset.items():
+        OmegaConf.set_struct(cfg.dataset, True)
+        with open_dict(cfg.dataset):
+            if key not in cfg.dataset.keys():
+                setattr(cfg.dataset, key, val)
+
+    for key, val in saved_general.items():
+        OmegaConf.set_struct(cfg.general, True)
+        with open_dict(cfg.general):
+            if key not in cfg.general.keys():
+                setattr(cfg.general, key, val)
+
+    OmegaConf.set_struct(cfg.train, True)
+    with open_dict(cfg.train):
+        for key, val in saved_train.items():
+            if key not in cfg.train.keys():
+                setattr(cfg.train, key, val)
+
+    OmegaConf.set_struct(cfg.model, True)
+    with open_dict(cfg.model):
+        for key, val in saved_model.items():
+            if key not in cfg.model.keys():
+                setattr(cfg.model, key, val)
+    return cfg
+
+
+class PlaceHolder:
+    def __init__(self, X, E, y):
+        self.X = X
+        self.E = E
+        self.y = y
+
+    def type_as(self, x: torch.Tensor, categorical: bool = False):
+        """ Changes the device and dtype of X, E, y. """
""" + self.X = self.X.type_as(x) + self.E = self.E.type_as(x) + if categorical: + self.y = self.y.type_as(x) + return self + + def mask(self, node_mask, collapse=False): + x_mask = node_mask.unsqueeze(-1) # bs, n, 1 + e_mask1 = x_mask.unsqueeze(2) # bs, n, 1, 1 + e_mask2 = x_mask.unsqueeze(1) # bs, 1, n, 1 + + if collapse: + self.X = torch.argmax(self.X, dim=-1) + self.E = torch.argmax(self.E, dim=-1) + + self.X[node_mask == 0] = - 1 + self.E[(e_mask1 * e_mask2).squeeze(-1) == 0] = - 1 + else: + self.X = self.X * x_mask + self.E = self.E * e_mask1 * e_mask2 + assert torch.allclose(self.E, torch.transpose(self.E, 1, 2)) + return self + + diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..45cd2ce --- /dev/null +++ b/requirements.txt @@ -0,0 +1,17 @@ +fcd_torch==1.0.7 +hydra-core==1.3.2 +imageio==2.26.0 +joblib==1.2.0 +matplotlib==3.7.0 +mini_moses==1.0 +networkx==3.0 +numpy==1.24.2 +omegaconf==2.3.0 +pandas==1.5.3 +pytorch_lightning==2.0.1 +rdkit==2023.9.4 +scikit_learn==1.2.1 +torch==2.0.0 +torch_geometric==2.3.0 +torchmetrics==0.11.4 +tqdm==4.64.1