update_name
Files changed:

    graph_dit/__init__.py                    (new file, 0 lines)
    graph_dit/analysis/__init__.py           (new file, 0 lines)
    graph_dit/analysis/rdkit_functions.py    (new file, 411 lines)
    graph_dit/analysis/visualization.py      (new file, 222 lines)
    graph_dit/datasets/__init__.py           (new file, 0 lines)
    graph_dit/datasets/abstract_dataset.py   (new file, 126 lines)
    graph_dit/datasets/dataset.py            (new file, 381 lines)

graph_dit/analysis/rdkit_functions.py (new file, 411 lines)
							| @@ -0,0 +1,411 @@ | ||||
| from rdkit import Chem, RDLogger | ||||
| RDLogger.DisableLog('rdApp.*') | ||||
| from fcd_torch import FCD as FCDMetric | ||||
| from mini_moses.metrics.metrics import FragMetric, internal_diversity | ||||
| from mini_moses.metrics.utils import get_mol, mapper | ||||
|  | ||||
| import re | ||||
| import time | ||||
| import random | ||||
| random.seed(0) | ||||
| import numpy as np | ||||
| from multiprocessing import Pool | ||||
|  | ||||
| import torch | ||||
| from metrics.property_metric import calculateSAS | ||||
|  | ||||
| bond_dict = [None, Chem.rdchem.BondType.SINGLE, Chem.rdchem.BondType.DOUBLE, Chem.rdchem.BondType.TRIPLE, | ||||
|                  Chem.rdchem.BondType.AROMATIC] | ||||
| ATOM_VALENCY = {6: 4, 7: 3, 8: 2, 9: 1, 15: 3, 16: 2, 17: 1, 35: 1, 53: 1} | ||||
|  | ||||
| bd_dict_x = {'O2-N2': [5.00E+04, 1.00E-03]} | ||||
| bd_dict_y = {'O2-N2': [5.00E+04/2.78E+04, 1.00E-03/2.43E-05]} | ||||
|  | ||||
| selectivity = ['O2-N2'] | ||||
| a_dict = {} | ||||
| b_dict = {} | ||||
| for prop_name in selectivity: | ||||
|     x1, x2 = np.log10(bd_dict_x[prop_name][0]), np.log10(bd_dict_x[prop_name][1]) | ||||
|     y1, y2 = np.log10(bd_dict_y[prop_name][0]), np.log10(bd_dict_y[prop_name][1]) | ||||
|     a = (y1-y2)/(x1-x2) | ||||
|     b = y1-a*x1 | ||||
|     a_dict[prop_name] = a | ||||
|     b_dict[prop_name] = b | ||||
|  | ||||
| def selectivity_evaluation(gas1, gas2, prop_name): | ||||
|     x = np.log10(np.array(gas1)) | ||||
|     y = np.log10(np.array(gas1) / np.array(gas2)) | ||||
|     upper = (y - (a_dict[prop_name] * x + b_dict[prop_name])) > 0 | ||||
|     return upper | ||||
|  | ||||
| class BasicMolecularMetrics(object): | ||||
|     def __init__(self, atom_decoder, train_smiles=None, stat_ref=None, task_evaluator=None, n_jobs=8, device='cpu', batch_size=512): | ||||
|         self.dataset_smiles_list = train_smiles | ||||
|         self.atom_decoder = atom_decoder | ||||
|         self.n_jobs = n_jobs | ||||
|         self.device = device | ||||
|         self.batch_size = batch_size | ||||
|         self.stat_ref = stat_ref | ||||
|         self.task_evaluator = task_evaluator | ||||
|  | ||||
|     def compute_relaxed_validity(self, generated, ensure_connected): | ||||
|         valid = [] | ||||
|         num_components = [] | ||||
|         all_smiles = [] | ||||
|         valid_mols = [] | ||||
|         covered_atoms = set() | ||||
|         direct_valid_count = 0 | ||||
|         for graph in generated: | ||||
|             atom_types, edge_types = graph | ||||
|             mol = build_molecule_with_partial_charges(atom_types, edge_types, self.atom_decoder) | ||||
|             direct_valid_flag = True if check_mol(mol, largest_connected_comp=True) is not None else False | ||||
|             if direct_valid_flag: | ||||
|                 direct_valid_count += 1 | ||||
|             if not ensure_connected: | ||||
|                 mol_conn, _ = correct_mol(mol, connection=True) | ||||
|                 mol = mol_conn if mol_conn is not None else correct_mol(mol, connection=False)[0] | ||||
|             else: # ensure fully connected | ||||
|                 mol, _ = correct_mol(mol, connection=True) | ||||
|             smiles = mol2smiles(mol) | ||||
|             mol = get_mol(smiles) | ||||
|             try: | ||||
|                 mol_frags = Chem.rdmolops.GetMolFrags(mol, asMols=True, sanitizeFrags=True) | ||||
|                 num_components.append(len(mol_frags)) | ||||
|                 largest_mol = max(mol_frags, default=mol, key=lambda m: m.GetNumAtoms()) | ||||
|                 smiles = mol2smiles(largest_mol) | ||||
|                 if smiles is not None and largest_mol is not None and len(smiles) > 1 and Chem.MolFromSmiles(smiles) is not None: | ||||
|                     valid_mols.append(largest_mol) | ||||
|                     valid.append(smiles) | ||||
|                     for atom in largest_mol.GetAtoms(): | ||||
|                         covered_atoms.add(atom.GetSymbol()) | ||||
|                     all_smiles.append(smiles) | ||||
|                 else: | ||||
|                     all_smiles.append(None) | ||||
|             except Exception as e:  | ||||
|                 # print(f"An error occurred: {e}") | ||||
|                 all_smiles.append(None) | ||||
|                  | ||||
|         return valid, len(valid) / len(generated), direct_valid_count / len(generated), np.array(num_components), all_smiles, covered_atoms | ||||
|  | ||||
|     def evaluate(self, generated, targets, ensure_connected, active_atoms=None): | ||||
|         """ generated: list of pairs (atom_types: n [int], edge_types: n x n) | ||||
|             the atom types and edge types should already be masked. """ | ||||
|         valid, validity, nc_validity, num_components, all_smiles, covered_atoms = self.compute_relaxed_validity(generated, ensure_connected=ensure_connected) | ||||
|         nc_mu = num_components.mean() if len(num_components) > 0 else 0 | ||||
|         nc_min = num_components.min() if len(num_components) > 0 else 0 | ||||
|         nc_max = num_components.max() if len(num_components) > 0 else 0 | ||||
|  | ||||
|         len_active = len(active_atoms) if active_atoms is not None else 1 | ||||
|          | ||||
|         cover_str = f"Cover {len(covered_atoms)} ({len(covered_atoms)/len_active * 100:.2f}%) atoms: {covered_atoms}" | ||||
|         print(f"Validity over {len(generated)} molecules: {validity * 100 :.2f}% (w/o correction: {nc_validity * 100 :.2f}%), cover {len(covered_atoms)} ({len(covered_atoms)/len_active * 100:.2f}%) atoms: {covered_atoms}") | ||||
|         print(f"Number of connected components of {len(generated)} molecules: min:{nc_min:.2f} mean:{nc_mu:.2f} max:{nc_max:.2f}") | ||||
|  | ||||
|         if validity > 0:  | ||||
|             dist_metrics = {'cover_str': cover_str ,'validity': validity, 'validity_nc': nc_validity} | ||||
|             unique = list(set(valid)) | ||||
|             close_pool = False | ||||
|             if self.n_jobs != 1: | ||||
|                 pool = Pool(self.n_jobs) | ||||
|                 close_pool = True | ||||
|             else: | ||||
|                 pool = 1 | ||||
|             valid_mols = mapper(pool)(get_mol, valid)  | ||||
|             dist_metrics['interval_diversity'] = internal_diversity(valid_mols, pool, device=self.device) | ||||
|              | ||||
|             start_time = time.time() | ||||
|             if self.stat_ref is not None: | ||||
|                 kwargs = {'n_jobs': pool, 'device': self.device, 'batch_size': self.batch_size} | ||||
|                 kwargs_fcd = {'n_jobs': self.n_jobs, 'device': self.device, 'batch_size': self.batch_size} | ||||
|                 try: | ||||
|                     dist_metrics['sim/Frag'] = FragMetric(**kwargs)(gen=valid_mols, pref=self.stat_ref['Frag']) | ||||
|                 except: | ||||
|                     print('error: ', 'pool', pool) | ||||
|                     print('valid_mols: ', valid_mols) | ||||
|                 dist_metrics['dist/FCD'] = FCDMetric(**kwargs_fcd)(gen=valid, pref=self.stat_ref['FCD']) | ||||
|  | ||||
|             if self.task_evaluator is not None: | ||||
|                 evaluation_list = list(self.task_evaluator.keys()) | ||||
|                 evaluation_list = evaluation_list.copy() | ||||
|  | ||||
|                 assert 'meta_taskname' in evaluation_list | ||||
|                 meta_taskname = self.task_evaluator['meta_taskname'] | ||||
|                 evaluation_list.remove('meta_taskname') | ||||
|                 meta_split = meta_taskname.split('-') | ||||
|  | ||||
|                 valid_index = np.array([True if smiles else False for smiles in all_smiles]) | ||||
|                 targets_log = {} | ||||
|                 for i, name in enumerate(evaluation_list): | ||||
|                     targets_log[f'input_{name}'] = np.array([float('nan')] * len(valid_index)) | ||||
|                     targets_log[f'input_{name}'] = targets[:, i] | ||||
|                  | ||||
|                 targets = targets[valid_index] | ||||
|                 if len(meta_split) == 2: | ||||
|                     cached_perm = {meta_split[0]: None, meta_split[1]: None} | ||||
|                  | ||||
|                 for i, name in enumerate(evaluation_list): | ||||
|                     if name == 'scs': | ||||
|                         continue | ||||
|                     elif name == 'sas': | ||||
|                         scores = calculateSAS(valid) | ||||
|                     else: | ||||
|                         scores = self.task_evaluator[name](valid) | ||||
|                     targets_log[f'output_{name}'] = np.array([float('nan')] * len(valid_index)) | ||||
|                     targets_log[f'output_{name}'][valid_index] = scores | ||||
|                     if name in ['O2', 'N2', 'CO2']: | ||||
|                         if len(meta_split) == 2: | ||||
|                             cached_perm[name] = scores | ||||
|                         scores, cur_targets = np.log10(scores), np.log10(targets[:, i]) | ||||
|                         dist_metrics[f'{name}/mae'] = np.mean(np.abs(scores - cur_targets)) | ||||
|                     elif name == 'sas': | ||||
|                         dist_metrics[f'{name}/mae'] = np.mean(np.abs(scores - targets[:, i])) | ||||
|                     else: | ||||
|                         true_y = targets[:, i] | ||||
|                         predicted_labels = (scores >= 0.5).astype(int) | ||||
|                         acc = (predicted_labels == true_y).sum() / len(true_y) | ||||
|                         dist_metrics[f'{name}/acc'] = acc | ||||
|  | ||||
|                 if len(meta_split) == 2: | ||||
|                     if cached_perm[meta_split[0]] is not None and cached_perm[meta_split[1]] is not None: | ||||
|                         task_name = self.task_evaluator['meta_taskname'] | ||||
|                         upper = selectivity_evaluation(cached_perm[meta_split[0]], cached_perm[meta_split[1]], task_name) | ||||
|                         dist_metrics[f'selectivity/{task_name}'] = np.sum(upper) | ||||
|  | ||||
|             end_time = time.time() | ||||
|             elapsed_time = end_time - start_time | ||||
|             max_key_length = max(len(key) for key in dist_metrics) | ||||
|             print(f'Details over {len(valid)} ({len(generated)}) valid (total) molecules, calculating metrics using {elapsed_time:.2f} s:') | ||||
|             strs = '' | ||||
|             for i, (key, value) in enumerate(dist_metrics.items()): | ||||
|                 if isinstance(value, (int, float, np.floating, np.integer)): | ||||
|                     strs = strs + f'{key:>{max_key_length}}:{value:<7.4f}\t' | ||||
|                 if i % 4 == 3: | ||||
|                     strs = strs + '\n' | ||||
|             print(strs) | ||||
|  | ||||
|             if close_pool: | ||||
|                 pool.close() | ||||
|                 pool.join() | ||||
|         else: | ||||
|             unique = [] | ||||
|             dist_metrics = {} | ||||
|             targets_log = None | ||||
|         return unique, dict(nc_min=nc_min, nc_max=nc_max, nc_mu=nc_mu), all_smiles, dist_metrics, targets_log | ||||
|  | ||||
| def mol2smiles(mol): | ||||
|     if mol is None: | ||||
|         return None | ||||
|     try: | ||||
|         Chem.SanitizeMol(mol) | ||||
|     except ValueError: | ||||
|         return None | ||||
|     return Chem.MolToSmiles(mol) | ||||
|  | ||||
| def build_molecule_with_partial_charges(atom_types, edge_types, atom_decoder, verbose=False): | ||||
|     if verbose: | ||||
|         print("\nbuilding new molecule") | ||||
|  | ||||
|     mol = Chem.RWMol() | ||||
|     for atom in atom_types: | ||||
|         a = Chem.Atom(atom_decoder[atom.item()]) | ||||
|         mol.AddAtom(a) | ||||
|         if verbose: | ||||
|             print("Atom added: ", atom.item(), atom_decoder[atom.item()]) | ||||
|  | ||||
|     edge_types = torch.triu(edge_types) | ||||
|     all_bonds = torch.nonzero(edge_types) | ||||
|  | ||||
|     for i, bond in enumerate(all_bonds): | ||||
|         if bond[0].item() != bond[1].item(): | ||||
|             mol.AddBond(bond[0].item(), bond[1].item(), bond_dict[edge_types[bond[0], bond[1]].item()]) | ||||
|             if verbose: | ||||
|                 print("bond added:", bond[0].item(), bond[1].item(), edge_types[bond[0], bond[1]].item(), | ||||
|                       bond_dict[edge_types[bond[0], bond[1]].item()]) | ||||
|             # add formal charge to atom: e.g. [O+], [N+], [S+] | ||||
|             # not support [O-], [N-], [S-], [NH+] etc. | ||||
|             flag, atomid_valence = check_valency(mol) | ||||
|             if verbose: | ||||
|                 print("flag, valence", flag, atomid_valence) | ||||
|             if flag: | ||||
|                 continue | ||||
|             else: | ||||
|                 if len(atomid_valence) == 2: | ||||
|                     idx = atomid_valence[0] | ||||
|                     v = atomid_valence[1] | ||||
|                     an = mol.GetAtomWithIdx(idx).GetAtomicNum() | ||||
|                     if verbose: | ||||
|                         print("atomic num of atom with a large valence", an) | ||||
|                     if an in (7, 8, 16) and (v - ATOM_VALENCY[an]) == 1: | ||||
|                         mol.GetAtomWithIdx(idx).SetFormalCharge(1) | ||||
|                         # print("Formal charge added") | ||||
|                 else: | ||||
|                     continue | ||||
|     return mol | ||||
|  | ||||
| def check_valency(mol): | ||||
|     try: | ||||
|         Chem.SanitizeMol(mol, sanitizeOps=Chem.SanitizeFlags.SANITIZE_PROPERTIES) | ||||
|         return True, None | ||||
|     except ValueError as e: | ||||
|         e = str(e) | ||||
|         p = e.find('#') | ||||
|         e_sub = e[p:] | ||||
|         atomid_valence = list(map(int, re.findall(r'\d+', e_sub))) | ||||
|         return False, atomid_valence | ||||
|  | ||||
|  | ||||
| def correct_mol(mol, connection=False): | ||||
|     ##### | ||||
|     no_correct = False | ||||
|     flag, _ = check_valency(mol) | ||||
|     if flag: | ||||
|         no_correct = True | ||||
|  | ||||
|     while True: | ||||
|         if connection: | ||||
|             mol_conn = connect_fragments(mol) | ||||
|             # if mol_conn is not None: | ||||
|             mol = mol_conn | ||||
|             if mol is None: | ||||
|                 return None, no_correct | ||||
|         flag, atomid_valence = check_valency(mol) | ||||
|         if flag: | ||||
|             break | ||||
|         else: | ||||
|             try: | ||||
|                 assert len(atomid_valence) == 2 | ||||
|                 idx = atomid_valence[0] | ||||
|                 v = atomid_valence[1] | ||||
|                 queue = [] | ||||
|                 check_idx = 0 | ||||
|                 for b in mol.GetAtomWithIdx(idx).GetBonds(): | ||||
|                     type = int(b.GetBondType()) | ||||
|                     queue.append((b.GetIdx(), type, b.GetBeginAtomIdx(), b.GetEndAtomIdx())) | ||||
|                     if type == 12: | ||||
|                         check_idx += 1 | ||||
|                 queue.sort(key=lambda tup: tup[1], reverse=True) | ||||
|  | ||||
|                 if queue[-1][1] == 12: | ||||
|                     return None, no_correct | ||||
|                 elif len(queue) > 0: | ||||
|                     start = queue[check_idx][2] | ||||
|                     end = queue[check_idx][3] | ||||
|                     t = queue[check_idx][1] - 1 | ||||
|                     mol.RemoveBond(start, end) | ||||
|                     if t >= 1: | ||||
|                         mol.AddBond(start, end, bond_dict[t]) | ||||
|             except Exception as e: | ||||
|                 # print(f"An error occurred in correction: {e}") | ||||
|                 return None, no_correct | ||||
|     return mol, no_correct | ||||
|  | ||||
|  | ||||
| def check_mol(m, largest_connected_comp=True): | ||||
|     if m is None: | ||||
|         return None | ||||
|     sm = Chem.MolToSmiles(m, isomericSmiles=True) | ||||
|     if largest_connected_comp and '.' in sm: | ||||
|         vsm = [(s, len(s)) for s in sm.split('.')]  # 'C.CC.CCc1ccc(N)cc1CCC=O'.split('.') | ||||
|         vsm.sort(key=lambda tup: tup[1], reverse=True) | ||||
|         mol = Chem.MolFromSmiles(vsm[0][0]) | ||||
|     else: | ||||
|         mol = Chem.MolFromSmiles(sm) | ||||
|     return mol | ||||
|  | ||||
|  | ||||
| ##### connect fragments | ||||
| def select_atom_with_available_valency(frag): | ||||
|     atoms = list(frag.GetAtoms()) | ||||
|     random.shuffle(atoms) | ||||
|     for atom in atoms: | ||||
|         if atom.GetAtomicNum() > 1 and atom.GetImplicitValence() > 0: | ||||
|             return atom | ||||
|  | ||||
|     return None | ||||
|  | ||||
| def select_atoms_with_available_valency(frag): | ||||
|     return [atom for atom in frag.GetAtoms() if atom.GetAtomicNum() > 1 and atom.GetImplicitValence() > 0] | ||||
|  | ||||
| def try_to_connect_fragments(combined_mol, frag, atom1, atom2): | ||||
|     # Make copies of the molecules to try the connection | ||||
|     trial_combined_mol = Chem.RWMol(combined_mol) | ||||
|     trial_frag = Chem.RWMol(frag) | ||||
|      | ||||
|     # Add the new fragment to the combined molecule with new indices | ||||
|     new_indices = {atom.GetIdx(): trial_combined_mol.AddAtom(atom) for atom in trial_frag.GetAtoms()} | ||||
|      | ||||
|     # Add the bond between the suitable atoms from each fragment | ||||
|     trial_combined_mol.AddBond(atom1.GetIdx(), new_indices[atom2.GetIdx()], Chem.BondType.SINGLE) | ||||
|      | ||||
|     # Adjust the hydrogen count of the connected atoms | ||||
|     for atom_idx in [atom1.GetIdx(), new_indices[atom2.GetIdx()]]: | ||||
|         atom = trial_combined_mol.GetAtomWithIdx(atom_idx) | ||||
|         num_h = atom.GetTotalNumHs() | ||||
|         atom.SetNumExplicitHs(max(0, num_h - 1)) | ||||
|          | ||||
|     # Add bonds for the new fragment | ||||
|     for bond in trial_frag.GetBonds(): | ||||
|         trial_combined_mol.AddBond(new_indices[bond.GetBeginAtomIdx()], new_indices[bond.GetEndAtomIdx()], bond.GetBondType()) | ||||
|      | ||||
|     # Convert to a Mol object and try to sanitize it | ||||
|     new_mol = Chem.Mol(trial_combined_mol) | ||||
|     try: | ||||
|         Chem.SanitizeMol(new_mol) | ||||
|         return new_mol  # Return the new valid molecule | ||||
|     except Chem.MolSanitizeException: | ||||
|         return None  # If the molecule is not valid, return None | ||||
|  | ||||
| def connect_fragments(mol): | ||||
|     # Get the separate fragments | ||||
|     frags = Chem.GetMolFrags(mol, asMols=True, sanitizeFrags=False) | ||||
|     if len(frags) < 2: | ||||
|         return mol | ||||
|  | ||||
|     combined_mol = Chem.RWMol(frags[0]) | ||||
|  | ||||
|     for frag in frags[1:]: | ||||
|         # Select all atoms with available valency from both molecules | ||||
|         atoms1 = select_atoms_with_available_valency(combined_mol) | ||||
|         atoms2 = select_atoms_with_available_valency(frag) | ||||
|          | ||||
|         # Try to connect using all combinations of available valency atoms | ||||
|         for atom1 in atoms1: | ||||
|             for atom2 in atoms2: | ||||
|                 new_mol = try_to_connect_fragments(combined_mol, frag, atom1, atom2) | ||||
|                 if new_mol is not None: | ||||
|                     # If a valid connection is made, update the combined molecule and break | ||||
|                     combined_mol = new_mol | ||||
|                     break | ||||
|             else: | ||||
|                 # Continue if the inner loop didn't break (no valid connection found for atom1) | ||||
|                 continue | ||||
|             # Break if the inner loop did break (valid connection found) | ||||
|             break | ||||
|         else: | ||||
|             # If no valid connections could be made with any of the atoms, return None | ||||
|             return None | ||||
|  | ||||
|     return combined_mol | ||||
|  | ||||
| #### connect fragments | ||||
|  | ||||
| def compute_molecular_metrics(molecule_list, targets, train_smiles, stat_ref, dataset_info, task_evaluator, comput_config): | ||||
|     """ molecule_list: (dict) """ | ||||
|  | ||||
|     atom_decoder = dataset_info.atom_decoder | ||||
|     active_atoms = dataset_info.active_atoms | ||||
|     ensure_connected = dataset_info.ensure_connected | ||||
|     metrics = BasicMolecularMetrics(atom_decoder, train_smiles, stat_ref, task_evaluator, **comput_config) | ||||
|     evaluated_res = metrics.evaluate(molecule_list, targets, ensure_connected, active_atoms) | ||||
|     all_smiles = evaluated_res[-3] | ||||
|     all_metrics = evaluated_res[-2] | ||||
|     targets_log = evaluated_res[-1] | ||||
|     unique_smiles = evaluated_res[0] | ||||
|  | ||||
|     return unique_smiles, all_smiles, all_metrics, targets_log | ||||
|  | ||||
| if __name__ == '__main__': | ||||
|     smiles_mol = 'C1CCC1' | ||||
|     print("Smiles mol %s" % smiles_mol) | ||||
|     chem_mol = Chem.MolFromSmiles(smiles_mol) | ||||
|     print(chem_mol) | ||||
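A minimal usage sketch for the helpers above (not part of the commit), assuming the interpreter runs from the graph_dit/ directory so the package-relative import works as in the rest of the commit; the atom_decoder and the two toy tensors are illustrative values, not anything taken from the training pipeline:

import torch
from analysis.rdkit_functions import (build_molecule_with_partial_charges,
                                      correct_mol, mol2smiles)

atom_decoder = ['C', 'N', 'O']           # illustrative index -> element mapping
atom_types = torch.tensor([0, 0, 2])     # C, C, O
edge_types = torch.tensor([[0, 1, 0],    # bond-order matrix; values index bond_dict
                           [1, 0, 2],
                           [0, 2, 0]])

mol = build_molecule_with_partial_charges(atom_types, edge_types, atom_decoder)
mol, already_valid = correct_mol(mol, connection=True)  # reconnect fragments, fix valencies
print(mol2smiles(mol))                                  # expected: CC=O

correct_mol(mol, connection=True) both reconnects disjoint fragments and downgrades bonds until check_valency passes, so the printed SMILES is the sanitized, corrected molecule.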
							
								
								
									
graph_dit/analysis/visualization.py (new file, 222 lines)
| @@ -0,0 +1,222 @@ | ||||
| import os | ||||
|  | ||||
| from rdkit import Chem | ||||
| from rdkit.Chem import Draw, AllChem | ||||
| from rdkit.Geometry import Point3D | ||||
| from rdkit import RDLogger | ||||
| import imageio | ||||
| import networkx as nx | ||||
| import numpy as np | ||||
| import rdkit.Chem | ||||
| import matplotlib.pyplot as plt | ||||
|  | ||||
|  | ||||
| class MolecularVisualization: | ||||
|     def __init__(self, dataset_infos): | ||||
|         self.dataset_infos = dataset_infos | ||||
|  | ||||
|     def mol_from_graphs(self, node_list, adjacency_matrix): | ||||
|         """ | ||||
|         Convert graphs to rdkit molecules | ||||
|         node_list: the nodes of a batch of nodes (bs x n) | ||||
|         adjacency_matrix: the adjacency_matrix of the molecule (bs x n x n) | ||||
|         """ | ||||
|         # dictionary to map integer value to the char of atom | ||||
|         atom_decoder = self.dataset_infos.atom_decoder | ||||
|         # [list(self.dataset_infos.atom_decoder.keys())[0]] | ||||
|  | ||||
|         # create empty editable mol object | ||||
|         mol = Chem.RWMol() | ||||
|  | ||||
|         # add atoms to mol and keep track of index | ||||
|         node_to_idx = {} | ||||
|         for i in range(len(node_list)): | ||||
|             if node_list[i] == -1: | ||||
|                 continue | ||||
|             a = Chem.Atom(atom_decoder[int(node_list[i])]) | ||||
|             molIdx = mol.AddAtom(a) | ||||
|             node_to_idx[i] = molIdx | ||||
|  | ||||
|         for ix, row in enumerate(adjacency_matrix): | ||||
|             for iy, bond in enumerate(row): | ||||
|                 # only traverse half the symmetric matrix | ||||
|                 if iy <= ix: | ||||
|                     continue | ||||
|                 if bond == 1: | ||||
|                     bond_type = Chem.rdchem.BondType.SINGLE | ||||
|                 elif bond == 2: | ||||
|                     bond_type = Chem.rdchem.BondType.DOUBLE | ||||
|                 elif bond == 3: | ||||
|                     bond_type = Chem.rdchem.BondType.TRIPLE | ||||
|                 elif bond == 4: | ||||
|                     bond_type = Chem.rdchem.BondType.AROMATIC | ||||
|                 else: | ||||
|                     continue | ||||
|                 mol.AddBond(node_to_idx[ix], node_to_idx[iy], bond_type) | ||||
|  | ||||
|         try: | ||||
|             mol = mol.GetMol() | ||||
|         except rdkit.Chem.KekulizeException: | ||||
|             print("Can't kekulize molecule") | ||||
|             mol = None | ||||
|         return mol | ||||
|  | ||||
|     def visualize(self, path: str, molecules: list, num_molecules_to_visualize: int, log='graph'): | ||||
|         # define path to save figures | ||||
|         if not os.path.exists(path): | ||||
|             os.makedirs(path) | ||||
|  | ||||
|         # visualize the final molecules | ||||
|         print(f"Visualizing {num_molecules_to_visualize} of {len(molecules)}") | ||||
|         if num_molecules_to_visualize > len(molecules): | ||||
|             print(f"Shortening to {len(molecules)}") | ||||
|             num_molecules_to_visualize = len(molecules) | ||||
|          | ||||
|         for i in range(num_molecules_to_visualize): | ||||
|             file_path = os.path.join(path, 'molecule_{}.png'.format(i)) | ||||
|             mol = self.mol_from_graphs(molecules[i][0].numpy(), molecules[i][1].numpy()) | ||||
|             try: | ||||
|                 Draw.MolToFile(mol, file_path) | ||||
|             except rdkit.Chem.KekulizeException: | ||||
|                 print("Can't kekulize molecule") | ||||
|  | ||||
|     def visualize_chain(self, path, nodes_list, adjacency_matrix, trainer=None): | ||||
|         RDLogger.DisableLog('rdApp.*') | ||||
|         # convert graphs to the rdkit molecules | ||||
|         mols = [self.mol_from_graphs(nodes_list[i], adjacency_matrix[i]) for i in range(nodes_list.shape[0])] | ||||
|  | ||||
|         # find the coordinates of atoms in the final molecule | ||||
|         final_molecule = mols[-1] | ||||
|         AllChem.Compute2DCoords(final_molecule) | ||||
|  | ||||
|         coords = [] | ||||
|         for i, atom in enumerate(final_molecule.GetAtoms()): | ||||
|             positions = final_molecule.GetConformer().GetAtomPosition(i) | ||||
|             coords.append((positions.x, positions.y, positions.z)) | ||||
|  | ||||
|         # align all the molecules | ||||
|         for i, mol in enumerate(mols): | ||||
|             AllChem.Compute2DCoords(mol) | ||||
|             conf = mol.GetConformer() | ||||
|             for j, atom in enumerate(mol.GetAtoms()): | ||||
|                 x, y, z = coords[j] | ||||
|                 conf.SetAtomPosition(j, Point3D(x, y, z)) | ||||
|  | ||||
|         # draw gif | ||||
|         save_paths = [] | ||||
|         num_frams = nodes_list.shape[0] | ||||
|  | ||||
|         for frame in range(num_frams): | ||||
|             file_name = os.path.join(path, 'fram_{}.png'.format(frame)) | ||||
|             Draw.MolToFile(mols[frame], file_name, size=(300, 300), legend=f"Frame {frame}") | ||||
|             save_paths.append(file_name) | ||||
|  | ||||
|         imgs = [imageio.imread(fn) for fn in save_paths] | ||||
|         gif_path = os.path.join(os.path.dirname(path), '{}.gif'.format(path.split('/')[-1])) | ||||
|         imgs.extend([imgs[-1]] * 10) | ||||
|         imageio.mimsave(gif_path, imgs, subrectangles=True, fps=5) | ||||
|  | ||||
|         # draw grid image | ||||
|         try: | ||||
|             img = Draw.MolsToGridImage(mols, molsPerRow=10, subImgSize=(200, 200)) | ||||
|             img.save(os.path.join(path, '{}_grid_image.png'.format(path.split('/')[-1]))) | ||||
|         except Chem.rdchem.KekulizeException: | ||||
|             print("Can't kekulize molecule") | ||||
|         return mols | ||||
|      | ||||
|     def visualize_by_smiles(self, path: str, smiles_list: list, num_to_visualize: int): | ||||
|         os.makedirs(path, exist_ok=True) | ||||
|  | ||||
|         print(f"Visualizing corrected {num_to_visualize} of {len(smiles_list)}") | ||||
|         if num_to_visualize > len(smiles_list): | ||||
|             print(f"Shortening to {len(smiles_list)}") | ||||
|             num_to_visualize = len(smiles_list) | ||||
|  | ||||
|         for i in range(num_to_visualize): | ||||
|             file_path = os.path.join(path, 'molecule_corrected_{}.png'.format(i)) | ||||
|             if smiles_list[i] is None: | ||||
|                 continue | ||||
|             mol = Chem.MolFromSmiles(smiles_list[i]) | ||||
|             try: | ||||
|                 Draw.MolToFile(mol, file_path) | ||||
|             except rdkit.Chem.KekulizeException: | ||||
|                 print("Can't kekulize molecule") | ||||
|  | ||||
| class NonMolecularVisualization: | ||||
|     def to_networkx(self, node_list, adjacency_matrix): | ||||
|         """ | ||||
|         Convert graphs to networkx graphs | ||||
|         node_list: the nodes of a batch of nodes (bs x n) | ||||
|         adjacency_matrix: the adjacency_matrix of the molecule (bs x n x n) | ||||
|         """ | ||||
|         graph = nx.Graph() | ||||
|  | ||||
|         for i in range(len(node_list)): | ||||
|             if node_list[i] == -1: | ||||
|                 continue | ||||
|             graph.add_node(i, number=i, symbol=node_list[i], color_val=node_list[i]) | ||||
|  | ||||
|         rows, cols = np.where(adjacency_matrix >= 1) | ||||
|         edges = zip(rows.tolist(), cols.tolist()) | ||||
|         for edge in edges: | ||||
|             edge_type = adjacency_matrix[edge[0]][edge[1]] | ||||
|             graph.add_edge(edge[0], edge[1], color=float(edge_type), weight=3 * edge_type) | ||||
|  | ||||
|         return graph | ||||
|  | ||||
|     def visualize_non_molecule(self, graph, pos, path, iterations=100, node_size=100, largest_component=False): | ||||
|         if largest_component: | ||||
|             CGs = [graph.subgraph(c) for c in nx.connected_components(graph)] | ||||
|             CGs = sorted(CGs, key=lambda x: x.number_of_nodes(), reverse=True) | ||||
|             graph = CGs[0] | ||||
|  | ||||
|         # Plot the graph structure with colors | ||||
|         if pos is None: | ||||
|             pos = nx.spring_layout(graph, iterations=iterations) | ||||
|  | ||||
|         # Set node colors based on the eigenvectors | ||||
|         w, U = np.linalg.eigh(nx.normalized_laplacian_matrix(graph).toarray()) | ||||
|         vmin, vmax = np.min(U[:, 1]), np.max(U[:, 1]) | ||||
|         m = max(np.abs(vmin), vmax) | ||||
|         vmin, vmax = -m, m | ||||
|  | ||||
|         plt.figure() | ||||
|         nx.draw(graph, pos, font_size=5, node_size=node_size, with_labels=False, node_color=U[:, 1], | ||||
|                 cmap=plt.cm.coolwarm, vmin=vmin, vmax=vmax, edge_color='grey') | ||||
|  | ||||
|         plt.tight_layout() | ||||
|         plt.savefig(path) | ||||
|         plt.close("all") | ||||
|  | ||||
|     def visualize(self, path: str, graphs: list, num_graphs_to_visualize: int, log='graph'): | ||||
|         # define path to save figures | ||||
|         if not os.path.exists(path): | ||||
|             os.makedirs(path) | ||||
|  | ||||
|         # visualize the final molecules | ||||
|         for i in range(num_graphs_to_visualize): | ||||
|             file_path = os.path.join(path, 'graph_{}.png'.format(i)) | ||||
|             graph = self.to_networkx(graphs[i][0].numpy(), graphs[i][1].numpy()) | ||||
|             self.visualize_non_molecule(graph=graph, pos=None, path=file_path) | ||||
|             im = plt.imread(file_path) | ||||
|  | ||||
|     def visualize_chain(self, path, nodes_list, adjacency_matrix): | ||||
|         # convert graphs to networkx | ||||
|         graphs = [self.to_networkx(nodes_list[i], adjacency_matrix[i]) for i in range(nodes_list.shape[0])] | ||||
|         # find the coordinates of atoms in the final molecule | ||||
|         final_graph = graphs[-1] | ||||
|         final_pos = nx.spring_layout(final_graph, seed=0) | ||||
|  | ||||
|         # draw gif | ||||
|         save_paths = [] | ||||
|         num_frams = nodes_list.shape[0] | ||||
|  | ||||
|         for frame in range(num_frams): | ||||
|             file_name = os.path.join(path, 'fram_{}.png'.format(frame)) | ||||
|             self.visualize_non_molecule(graph=graphs[frame], pos=final_pos, path=file_name) | ||||
|             save_paths.append(file_name) | ||||
|  | ||||
|         imgs = [imageio.imread(fn) for fn in save_paths] | ||||
|         gif_path = os.path.join(os.path.dirname(path), '{}.gif'.format(path.split('/')[-1])) | ||||
|         imgs.extend([imgs[-1]] * 10) | ||||
|         imageio.mimsave(gif_path, imgs, subrectangles=True, fps=5) | ||||
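A small sketch (not part of the commit) of driving MolecularVisualization outside the training loop, again assuming the graph_dit/ working directory; the SimpleNamespace stand-in for dataset_infos and the toy arrays are assumptions for illustration only:

from types import SimpleNamespace
import numpy as np
from rdkit import Chem
from analysis.visualization import MolecularVisualization

dataset_infos = SimpleNamespace(atom_decoder=['C', 'N', 'O'])  # only attribute used here
viz = MolecularVisualization(dataset_infos)

node_list = np.array([0, 0, 2])            # C, C, O
adjacency = np.array([[0, 1, 0],
                      [1, 0, 2],           # 2 -> double bond in mol_from_graphs
                      [0, 2, 0]])

mol = viz.mol_from_graphs(node_list, adjacency)
Chem.SanitizeMol(mol)
print(Chem.MolToSmiles(mol))               # expected: CC=O

The visualize() method applies the same per-graph conversion and then writes molecule_{i}.png files via Draw.MolToFile.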
							
								
								
									
graph_dit/datasets/__init__.py (new file, 0 lines)

graph_dit/datasets/abstract_dataset.py (new file, 126 lines)
| @@ -0,0 +1,126 @@ | ||||
| from diffusion.distributions import DistributionNodes | ||||
| import utils as utils | ||||
| import torch | ||||
| import pytorch_lightning as pl | ||||
| from torch_geometric.loader import DataLoader | ||||
|  | ||||
|  | ||||
| class AbstractDataModule(pl.LightningDataModule): | ||||
|     def __init__(self, cfg): | ||||
|         super().__init__() | ||||
|         self.cfg = cfg | ||||
|         self.dataloaders = None | ||||
|         self.input_dims = None | ||||
|         self.output_dims = None | ||||
|  | ||||
|     def prepare_data(self, datasets) -> None: | ||||
|         batch_size = self.cfg.train.batch_size | ||||
|         num_workers = self.cfg.train.num_workers | ||||
|         self.dataloaders = {split: DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, | ||||
|                                               shuffle='debug' not in self.cfg.general.name) | ||||
|                             for split, dataset in datasets.items()} | ||||
|  | ||||
|     def train_dataloader(self): | ||||
|         return self.dataloaders["train"] | ||||
|  | ||||
|     def val_dataloader(self): | ||||
|         return self.dataloaders["val"] | ||||
|  | ||||
|     def test_dataloader(self): | ||||
|         return self.dataloaders["test"] | ||||
|  | ||||
|     def __getitem__(self, idx): | ||||
|         return self.dataloaders['train'][idx] | ||||
|  | ||||
|     def node_counts(self, max_nodes_possible=300): | ||||
|         all_counts = torch.zeros(max_nodes_possible) | ||||
|         for split in ['train', 'val', 'test']: | ||||
|             for i, data in enumerate(self.dataloaders[split]): | ||||
|                 unique, counts = torch.unique(data.batch, return_counts=True) | ||||
|                 for count in counts: | ||||
|                     all_counts[count] += 1 | ||||
|         max_index = max(all_counts.nonzero()) | ||||
|         all_counts = all_counts[:max_index + 1] | ||||
|         all_counts = all_counts / all_counts.sum() | ||||
|         return all_counts | ||||
|  | ||||
|     def node_types(self): | ||||
|         num_classes = None | ||||
|         for data in self.dataloaders['train']: | ||||
|             num_classes = data.x.shape[1] | ||||
|             break | ||||
|  | ||||
|         counts = torch.zeros(num_classes) | ||||
|  | ||||
|         for split in ['train', 'val', 'test']: | ||||
|             for i, data in enumerate(self.dataloaders[split]): | ||||
|                 counts += data.x.sum(dim=0) | ||||
|  | ||||
|         counts = counts / counts.sum() | ||||
|         return counts | ||||
|  | ||||
|     def edge_counts(self): | ||||
|         num_classes = None | ||||
|         for data in self.dataloaders['train']: | ||||
|             num_classes = 5 | ||||
|             break | ||||
|  | ||||
|         d = torch.zeros(num_classes) | ||||
|  | ||||
|         for split in ['train', 'val', 'test']: | ||||
|             for i, data in enumerate(self.dataloaders[split]): | ||||
|                 unique, counts = torch.unique(data.batch, return_counts=True) | ||||
|  | ||||
|                 all_pairs = 0 | ||||
|                 for count in counts: | ||||
|                     all_pairs += count * (count - 1) | ||||
|  | ||||
|                 num_edges = data.edge_index.shape[1] | ||||
|                 num_non_edges = all_pairs - num_edges | ||||
|  | ||||
|                 data_edge_attr = torch.nn.functional.one_hot(data.edge_attr, num_classes=5).float() | ||||
|                 edge_types = data_edge_attr.sum(dim=0) | ||||
|                 assert num_non_edges >= 0 | ||||
|                 d[0] += num_non_edges | ||||
|                 d[1:] += edge_types[1:] | ||||
|  | ||||
|         d = d / d.sum() | ||||
|         return d | ||||
|  | ||||
|  | ||||
| class MolecularDataModule(AbstractDataModule): | ||||
|     def valency_count(self, max_n_nodes): | ||||
|         valencies = torch.zeros(3 * max_n_nodes - 2)   # Max valency possible if everything is connected | ||||
|         multiplier = torch.Tensor([0, 1, 2, 3, 1.5]) | ||||
|         for split in ['train', 'val', 'test']: | ||||
|             for i, data in enumerate(self.dataloaders[split]): | ||||
|                 n = data.x.shape[0] | ||||
|                 for atom in range(n): | ||||
|                     data_edge_attr = torch.nn.functional.one_hot(data.edge_attr, num_classes=5).float() | ||||
|                     edges = data_edge_attr[data.edge_index[0] == atom] | ||||
|                     edges_total = edges.sum(dim=0) | ||||
|                     valency = (edges_total * multiplier).sum() | ||||
|                     valencies[valency.long().item()] += 1 | ||||
|         valencies = valencies / valencies.sum() | ||||
|         return valencies | ||||
|  | ||||
|  | ||||
| class AbstractDatasetInfos: | ||||
|     def complete_infos(self, n_nodes, node_types): | ||||
|         self.input_dims = None | ||||
|         self.output_dims = None | ||||
|         self.num_classes = len(node_types) | ||||
|         self.max_n_nodes = len(n_nodes) - 1 | ||||
|         self.nodes_dist = DistributionNodes(n_nodes) | ||||
|  | ||||
|     def compute_input_output_dims(self, datamodule): | ||||
|         example_batch = datamodule.example_batch() | ||||
|         example_batch_x = torch.nn.functional.one_hot(example_batch.x, num_classes=118).float()[:, self.active_index] | ||||
|         example_batch_edge_attr = torch.nn.functional.one_hot(example_batch.edge_attr, num_classes=5).float() | ||||
|  | ||||
|         self.input_dims = {'X': example_batch_x.size(1),  | ||||
|                            'E': example_batch_edge_attr.size(1),  | ||||
|                            'y': example_batch['y'].size(1)} | ||||
|         self.output_dims = {'X': example_batch_x.size(1), | ||||
|                             'E': example_batch_edge_attr.size(1), | ||||
|                             'y': example_batch['y'].size(1)} | ||||
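To exercise the statistics helpers above without the full Dataset class, a toy sketch along these lines should work (not part of the commit, graph_dit/ working directory assumed); the cfg namespaces only mimic the fields AbstractDataModule actually reads (train.batch_size, train.num_workers, general.name), and toy_graph is a made-up helper:

from types import SimpleNamespace
import torch
from torch_geometric.data import Data
from datasets.abstract_dataset import AbstractDataModule

cfg = SimpleNamespace(train=SimpleNamespace(batch_size=2, num_workers=0),
                      general=SimpleNamespace(name='debug_toy'))  # 'debug' disables shuffling

def toy_graph():
    # two nodes with one-hot features and a single bond stored in both directions
    x = torch.tensor([[1., 0.], [0., 1.]])
    edge_index = torch.tensor([[0, 1], [1, 0]])
    edge_attr = torch.tensor([1, 1])
    return Data(x=x, edge_index=edge_index, edge_attr=edge_attr)

datasets = {split: [toy_graph() for _ in range(4)] for split in ['train', 'val', 'test']}
dm = AbstractDataModule(cfg)
dm.prepare_data(datasets)
print(dm.node_counts())   # distribution over graph sizes
print(dm.node_types())    # marginal distribution over node (atom) types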
							
								
								
									
graph_dit/datasets/dataset.py (new file, 381 lines)
| @@ -0,0 +1,381 @@ | ||||
|  | ||||
| import sys | ||||
| sys.path.append('../')  | ||||
|  | ||||
| import os | ||||
| import os.path as osp | ||||
| import pathlib | ||||
| import json | ||||
|  | ||||
| import torch | ||||
| import torch.nn.functional as F | ||||
| from rdkit import Chem, RDLogger | ||||
| from rdkit.Chem.rdchem import BondType as BT | ||||
| from tqdm import tqdm | ||||
| import numpy as np | ||||
| import pandas as pd | ||||
| from torch_geometric.data import Data, InMemoryDataset | ||||
| from torch_geometric.loader import DataLoader | ||||
| from sklearn.model_selection import train_test_split | ||||
|  | ||||
| import utils as utils | ||||
| from datasets.abstract_dataset import AbstractDatasetInfos, AbstractDataModule | ||||
| from diffusion.distributions import DistributionNodes | ||||
|  | ||||
| bonds = {BT.SINGLE: 1, BT.DOUBLE: 2, BT.TRIPLE: 3, BT.AROMATIC: 4} | ||||
|  | ||||
| class DataModule(AbstractDataModule): | ||||
|     def __init__(self, cfg): | ||||
|         self.datadir = cfg.dataset.datadir | ||||
|         self.task = cfg.dataset.task_name | ||||
|         super().__init__(cfg) | ||||
|  | ||||
|     def prepare_data(self) -> None: | ||||
|         target = getattr(self.cfg.dataset, 'guidance_target', None) | ||||
|         base_path = pathlib.Path(os.path.realpath(__file__)).parents[2] | ||||
|         root_path = os.path.join(base_path, self.datadir) | ||||
|         self.root_path = root_path | ||||
|  | ||||
|         batch_size = self.cfg.train.batch_size | ||||
|         num_workers = self.cfg.train.num_workers | ||||
|         pin_memory = self.cfg.dataset.pin_memory | ||||
|  | ||||
|         dataset = Dataset(source=self.task, root=root_path, target_prop=target, transform=None) | ||||
|  | ||||
|         if len(self.task.split('-')) == 2: | ||||
|             train_index, val_index, test_index, unlabeled_index = self.fixed_split(dataset) | ||||
|         else: | ||||
|             train_index, val_index, test_index, unlabeled_index = self.random_data_split(dataset) | ||||
|  | ||||
|         self.train_index, self.val_index, self.test_index, self.unlabeled_index = train_index, val_index, test_index, unlabeled_index | ||||
|         train_index, val_index, test_index, unlabeled_index = torch.LongTensor(train_index), torch.LongTensor(val_index), torch.LongTensor(test_index), torch.LongTensor(unlabeled_index) | ||||
|         if len(unlabeled_index) > 0: | ||||
|             train_index = torch.cat([train_index, unlabeled_index], dim=0) | ||||
|          | ||||
|         train_dataset, val_dataset, test_dataset = dataset[train_index], dataset[val_index], dataset[test_index] | ||||
|         self.train_dataset = train_dataset | ||||
|         self.train_loader = DataLoader(train_dataset, batch_size=batch_size, num_workers=num_workers, shuffle=True, pin_memory=pin_memory) | ||||
|         self.val_loader = DataLoader(val_dataset, batch_size=batch_size, num_workers=num_workers, shuffle=False, pin_memory=False) | ||||
|         self.test_loader = DataLoader(test_dataset, batch_size=batch_size, num_workers=num_workers, shuffle=False, pin_memory=False) | ||||
|  | ||||
|         training_iterations = len(train_dataset) // batch_size | ||||
|         self.training_iterations = training_iterations | ||||
|      | ||||
|     def random_data_split(self, dataset): | ||||
|         nan_count = torch.isnan(dataset.y[:, 0]).sum().item() | ||||
|         labeled_len = len(dataset) - nan_count | ||||
|         full_idx = list(range(labeled_len)) | ||||
|         train_ratio, valid_ratio, test_ratio = 0.6, 0.2, 0.2 | ||||
|         train_index, test_index, _, _ = train_test_split(full_idx, full_idx, test_size=test_ratio, random_state=42) | ||||
|         train_index, val_index, _, _ = train_test_split(train_index, train_index, test_size=valid_ratio/(valid_ratio+train_ratio), random_state=42) | ||||
|         unlabeled_index = list(range(labeled_len, len(dataset))) | ||||
|         print(self.task, ' dataset len', len(dataset), 'train len', len(train_index), 'val len', len(val_index), 'test len', len(test_index), 'unlabeled len', len(unlabeled_index)) | ||||
|         return train_index, val_index, test_index, unlabeled_index | ||||
|      | ||||
|     def fixed_split(self, dataset): | ||||
|         if self.task == 'O2-N2': | ||||
|             test_index = [42,43,92,122,197,198,251,254,257,355,511,512,549,602,603,604] | ||||
|         else: | ||||
|             raise ValueError('Invalid task name: {}'.format(self.task)) | ||||
|         full_idx = list(range(len(dataset))) | ||||
|         full_idx = list(set(full_idx) - set(test_index)) | ||||
|         train_ratio = 0.8 | ||||
|         train_index, val_index, _, _ = train_test_split(full_idx, full_idx, test_size=1-train_ratio, random_state=42) | ||||
|         print(self.task, ' dataset len', len(dataset), 'train len', len(train_index), 'val len', len(val_index), 'test len', len(test_index)) | ||||
|         return train_index, val_index, test_index, [] | ||||
|  | ||||
|     def get_train_smiles(self): | ||||
|         filename = f'{self.task}.csv.gz' | ||||
|         df = pd.read_csv(f'{self.root_path}/raw/{filename}') | ||||
|         df_test = df.iloc[self.test_index] | ||||
|         df = df.iloc[self.train_index] | ||||
|         smiles_list = df['smiles'].tolist() | ||||
|         smiles_list_test = df_test['smiles'].tolist() | ||||
|         smiles_list = [Chem.MolToSmiles(Chem.MolFromSmiles(smi)) for smi in smiles_list] | ||||
|         smiles_list_test = [Chem.MolToSmiles(Chem.MolFromSmiles(smi)) for smi in smiles_list_test] | ||||
|         return smiles_list, smiles_list_test | ||||
|      | ||||
|     def get_data_split(self): | ||||
|         filename = f'{self.task}.csv.gz' | ||||
|         df = pd.read_csv(f'{self.root_path}/raw/{filename}') | ||||
|         df_val = df.iloc[self.val_index] | ||||
|         df_test = df.iloc[self.test_index] | ||||
|         df_train = df.iloc[self.train_index] | ||||
|         return df_train, df_val, df_test | ||||
|  | ||||
|     def example_batch(self): | ||||
|         return next(iter(self.val_loader)) | ||||
|      | ||||
|     def train_dataloader(self): | ||||
|         return self.train_loader | ||||
|  | ||||
|     def val_dataloader(self): | ||||
|         return self.val_loader | ||||
|      | ||||
|     def test_dataloader(self): | ||||
|         return self.test_loader | ||||
|  | ||||
|  | ||||
| class Dataset(InMemoryDataset): | ||||
|     def __init__(self, source, root, target_prop=None, | ||||
|                  transform=None, pre_transform=None, pre_filter=None): | ||||
|         self.target_prop = target_prop | ||||
|         self.source = source | ||||
|         super().__init__(root, transform, pre_transform, pre_filter) | ||||
|         self.data, self.slices = torch.load(self.processed_paths[0]) | ||||
|  | ||||
|     @property | ||||
|     def raw_file_names(self): | ||||
|         return [f'{self.source}.csv.gz'] | ||||
|      | ||||
|     @property | ||||
|     def processed_file_names(self): | ||||
|         return [f'{self.source}.pt'] | ||||
|  | ||||
|     def process(self): | ||||
|         RDLogger.DisableLog('rdApp.*') | ||||
|         data_path = osp.join(self.raw_dir, self.raw_file_names[0]) | ||||
|         data_df = pd.read_csv(data_path) | ||||
|         | ||||
|         def mol_to_graph(mol, sa, sc, target, target2=None, target3=None, valid_atoms=None): | ||||
|             type_idx = [] | ||||
|             heavy_atom_indices, active_atoms = [], [] | ||||
|             for atom in mol.GetAtoms(): | ||||
|                 if atom.GetAtomicNum() != 1: | ||||
|                     type_idx.append(119-2) if atom.GetSymbol() == '*' else type_idx.append(atom.GetAtomicNum()-2) | ||||
|                     heavy_atom_indices.append(atom.GetIdx()) | ||||
|                     active_atoms.append(atom.GetSymbol()) | ||||
|                     if valid_atoms is not None: | ||||
|                         if not atom.GetSymbol() in valid_atoms: | ||||
|                             return None, None | ||||
|             x = torch.LongTensor(type_idx) | ||||
|  | ||||
|             edges_list = [] | ||||
|             edge_type = [] | ||||
|             for bond in mol.GetBonds(): | ||||
|                 start, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx() | ||||
|                 if start in heavy_atom_indices and end in heavy_atom_indices: | ||||
|                     start_new, end_new = heavy_atom_indices.index(start), heavy_atom_indices.index(end) | ||||
|                     edges_list.append((start_new, end_new)) | ||||
|                     edge_type.append(bonds[bond.GetBondType()]) | ||||
|                     edges_list.append((end_new, start_new)) | ||||
|                     edge_type.append(bonds[bond.GetBondType()]) | ||||
|             edge_index = torch.tensor(edges_list, dtype=torch.long).t() | ||||
|             edge_type = torch.tensor(edge_type, dtype=torch.long) | ||||
|             edge_attr = edge_type | ||||
|  | ||||
|             if target3 is not None: | ||||
|                 y = torch.tensor([sa, sc, target, target2, target3], dtype=torch.float).view(1,-1) | ||||
|             elif target2 is not None: | ||||
|                 y = torch.tensor([sa, sc, target, target2], dtype=torch.float).view(1,-1) | ||||
|             else: | ||||
|                 y = torch.tensor([sa, sc, target], dtype=torch.float).view(1,-1) | ||||
|  | ||||
|             data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=y, idx=i) | ||||
|             if self.pre_transform is not None: | ||||
|                 data = self.pre_transform(data) | ||||
|             return data, active_atoms | ||||
|          | ||||
|         # Loop through every row in the DataFrame and apply the function | ||||
|         data_list = [] | ||||
|         len_data = len(data_df) | ||||
|         with tqdm(total=len_data) as pbar: | ||||
|             # --- data processing start --- | ||||
|             active_atoms = set() | ||||
|             for i, (sms, df_row) in enumerate(data_df.iterrows()): | ||||
|                 if i == sms: | ||||
|                     sms = df_row['smiles'] | ||||
|                 mol = Chem.MolFromSmiles(sms, sanitize=False) | ||||
|                 if len(self.target_prop.split('-')) == 2: | ||||
|                     target1, target2 = self.target_prop.split('-') | ||||
|                     data, cur_active_atoms = mol_to_graph(mol, df_row['SA'], df_row['SC'], df_row[target1], target2=df_row[target2]) | ||||
|                 elif len(self.target_prop.split('-')) == 3: | ||||
|                     target1, target2, target3 = self.target_prop.split('-') | ||||
|                     data, cur_active_atoms = mol_to_graph(mol, df_row['SA'], df_row['SC'], df_row[target1], target2=df_row[target2], target3=df_row[target3]) | ||||
|                 else: | ||||
|                     data, cur_active_atoms = mol_to_graph(mol, df_row['SA'], df_row['SC'], df_row[self.target_prop]) | ||||
|                 active_atoms.update(cur_active_atoms) | ||||
|                 data_list.append(data) | ||||
|                 pbar.update(1) | ||||
|  | ||||
|         torch.save(self.collate(data_list), self.processed_paths[0]) | ||||
|  | ||||
|  | ||||
| class DataInfos(AbstractDatasetInfos): | ||||
|     def __init__(self, datamodule, cfg): | ||||
|         tasktype_dict = { | ||||
|             'hiv_b': 'classification', | ||||
|             'bace_b': 'classification', | ||||
|             'bbbp_b': 'classification', | ||||
|             'O2': 'regression', | ||||
|             'N2': 'regression', | ||||
|             'CO2': 'regression', | ||||
|         } | ||||
|         task_name = cfg.dataset.task_name | ||||
|         self.task = task_name | ||||
|         self.task_type = tasktype_dict.get(task_name, "regression") | ||||
|         self.ensure_connected = cfg.model.ensure_connected | ||||
|  | ||||
|         datadir = cfg.dataset.datadir | ||||
|  | ||||
|         base_path = pathlib.Path(os.path.realpath(__file__)).parents[2] | ||||
|         meta_filename = os.path.join(base_path, datadir, 'raw', f'{task_name}.meta.json') | ||||
|         data_root = os.path.join(base_path, datadir, 'raw') | ||||
|         if os.path.exists(meta_filename): | ||||
|             with open(meta_filename, 'r') as f: | ||||
|                 meta_dict = json.load(f) | ||||
|         else: | ||||
|             meta_dict = compute_meta(data_root, task_name, datamodule.train_index, datamodule.test_index) | ||||
|  | ||||
|         self.base_path = base_path | ||||
|         self.active_atoms = meta_dict['active_atoms'] | ||||
|         self.max_n_nodes = meta_dict['max_node'] | ||||
|         self.original_max_n_nodes = meta_dict['max_node'] | ||||
|         self.n_nodes = torch.Tensor(meta_dict['n_atoms_per_mol_dist']) | ||||
|         self.edge_types = torch.Tensor(meta_dict['bond_type_dist']) | ||||
|         self.transition_E = torch.Tensor(meta_dict['transition_E']) | ||||
|  | ||||
|         self.atom_decoder = meta_dict['active_atoms'] | ||||
|         node_types = torch.Tensor(meta_dict['atom_type_dist']) | ||||
|         active_index = (node_types > 0).nonzero().squeeze() | ||||
|         self.node_types = torch.Tensor(meta_dict['atom_type_dist'])[active_index] | ||||
|         self.nodes_dist = DistributionNodes(self.n_nodes) | ||||
|         self.active_index = active_index | ||||
|  | ||||
|         val_len = 3 * self.original_max_n_nodes - 2 | ||||
|         meta_val = torch.Tensor(meta_dict['valencies']) | ||||
|         self.valency_distribution = torch.zeros(val_len) | ||||
|         val_len = min(val_len, len(meta_val)) | ||||
|         self.valency_distribution[:val_len] = meta_val[:val_len] | ||||
|         self.y_prior = None | ||||
|         self.train_ymin = [] | ||||
|         self.train_ymax = [] | ||||
|  | ||||
|  | ||||
| def compute_meta(root, source_name, train_index, test_index): | ||||
|     pt = Chem.GetPeriodicTable() | ||||
|     atom_name_list = [] | ||||
|     atom_count_list = [] | ||||
|     for i in range(2, 119): | ||||
|         atom_name_list.append(pt.GetElementSymbol(i)) | ||||
|         atom_count_list.append(0) | ||||
|     atom_name_list.append('*') | ||||
|     atom_count_list.append(0) | ||||
|     n_atoms_per_mol = [0] * 500 | ||||
|     bond_count_list = [0, 0, 0, 0, 0] | ||||
|     bond_type_to_index =  {BT.SINGLE: 1, BT.DOUBLE: 2, BT.TRIPLE: 3, BT.AROMATIC: 4} | ||||
|     valencies = [0] * 500 | ||||
|     tansition_E = np.zeros((118, 118, 5)) | ||||
|      | ||||
|     filename = f'{source_name}.csv.gz' | ||||
|     df = pd.read_csv(f'{root}/{filename}') | ||||
|     all_index = list(range(len(df))) | ||||
|     non_test_index = list(set(all_index) - set(test_index)) | ||||
|     df = df.iloc[non_test_index] | ||||
|     tot_smiles = df['smiles'].tolist() | ||||
|  | ||||
|     n_atom_list = [] | ||||
|     n_bond_list = [] | ||||
|     for i, sms in enumerate(tot_smiles): | ||||
|         try: | ||||
|             mol = Chem.MolFromSmiles(sms) | ||||
|         except: | ||||
|             continue | ||||
|  | ||||
|         n_atom = mol.GetNumHeavyAtoms() | ||||
|         n_bond = mol.GetNumBonds() | ||||
|         n_atom_list.append(n_atom) | ||||
|         n_bond_list.append(n_bond) | ||||
|  | ||||
|         n_atoms_per_mol[n_atom] += 1 | ||||
|         cur_atom_count_arr = np.zeros(118) | ||||
|         for atom in mol.GetAtoms(): | ||||
|             symbol = atom.GetSymbol() | ||||
|             if symbol == 'H': | ||||
|                 continue | ||||
|             elif symbol == '*': | ||||
|                 atom_count_list[-1] += 1 | ||||
|                 cur_atom_count_arr[-1] += 1 | ||||
|             else: | ||||
|                 atom_count_list[atom.GetAtomicNum()-2] += 1 | ||||
|                 cur_atom_count_arr[atom.GetAtomicNum()-2] += 1 | ||||
|                 try: | ||||
|                     valencies[int(atom.GetExplicitValence())] += 1 | ||||
|                 except: | ||||
|                     print('src', source_name,'int(atom.GetExplicitValence())', int(atom.GetExplicitValence())) | ||||
|          | ||||
|         transition_E_temp = np.zeros((118, 118, 5)) | ||||
|         for bond in mol.GetBonds(): | ||||
|             start_atom, end_atom = bond.GetBeginAtom(), bond.GetEndAtom() | ||||
|             if start_atom.GetSymbol() == 'H' or end_atom.GetSymbol() == 'H': | ||||
|                 continue | ||||
|              | ||||
|             if start_atom.GetSymbol() == '*': | ||||
|                 start_index = 117 | ||||
|             else: | ||||
|                 start_index = start_atom.GetAtomicNum() - 2 | ||||
|             if end_atom.GetSymbol() == '*': | ||||
|                 end_index = 117 | ||||
|             else: | ||||
|                 end_index = end_atom.GetAtomicNum() - 2 | ||||
|  | ||||
|             bond_type = bond.GetBondType() | ||||
|             bond_index = bond_type_to_index[bond_type] | ||||
|             bond_count_list[bond_index] += 2 | ||||
|  | ||||
|             transition_E[start_index, end_index, bond_index] += 2 | ||||
|             transition_E[end_index, start_index, bond_index] += 2 | ||||
|             transition_E_temp[start_index, end_index, bond_index] += 2 | ||||
|             transition_E_temp[end_index, start_index, bond_index] += 2 | ||||
|  | ||||
|         bond_count_list[0] += n_atom * (n_atom - 1) - n_bond * 2 | ||||
|         cur_tot_bond = cur_atom_count_arr.reshape(-1,1) * cur_atom_count_arr.reshape(1,-1) * 2 # 118 * 118 | ||||
|         cur_tot_bond = cur_tot_bond - np.diag(cur_atom_count_arr) * 2 # 118 * 118 | ||||
|         transition_E[:, :, 0] += cur_tot_bond - transition_E_temp.sum(axis=-1) | ||||
|         assert (cur_tot_bond >= transition_E_temp.sum(axis=-1)).all(), f'i:{i}, sms:{sms}' | ||||
|      | ||||
|     n_atoms_per_mol = np.array(n_atoms_per_mol) / np.sum(n_atoms_per_mol) | ||||
|     n_atoms_per_mol = n_atoms_per_mol.tolist()[:51] | ||||
|  | ||||
|     atom_count_list = np.array(atom_count_list) / np.sum(atom_count_list) | ||||
|     print('processed meta info: ------', filename, '------') | ||||
|     print('len atom_count_list', len(atom_count_list)) | ||||
|     print('len atom_name_list', len(atom_name_list)) | ||||
|     active_atoms = np.array(atom_name_list)[atom_count_list > 0] | ||||
|     active_atoms = active_atoms.tolist() | ||||
|     atom_count_list = atom_count_list.tolist() | ||||
|  | ||||
|     bond_count_list = np.array(bond_count_list) / np.sum(bond_count_list) | ||||
|     bond_count_list = bond_count_list.tolist() | ||||
|     valencies = np.array(valencies) / np.sum(valencies) | ||||
|     valencies = valencies.tolist() | ||||
|  | ||||
|     no_edge = np.sum(transition_E, axis=-1) == 0 | ||||
|     first_elt = transition_E[:, :, 0] | ||||
|     first_elt[no_edge] = 1 | ||||
|     transition_E[:, :, 0] = first_elt | ||||
|  | ||||
|     transition_E = transition_E / np.sum(transition_E, axis=-1, keepdims=True) | ||||
|      | ||||
|     meta_dict = { | ||||
|         'source': source_name,  | ||||
|         'num_graph': len(n_atom_list),  | ||||
|         'n_atoms_per_mol_dist': n_atoms_per_mol, | ||||
|         'max_node': max(n_atom_list),  | ||||
|         'max_bond': max(n_bond_list),  | ||||
|         'atom_type_dist': atom_count_list, | ||||
|         'bond_type_dist': bond_count_list, | ||||
|         'valencies': valencies, | ||||
|         'active_atoms': active_atoms, | ||||
|         'num_atom_type': len(active_atoms), | ||||
|         'transition_E': transition_E.tolist(), | ||||
|         } | ||||
|  | ||||
|     with open(f'{root}/{source_name}.meta.json', "w") as f: | ||||
|         json.dump(meta_dict, f) | ||||
|      | ||||
|     return meta_dict | ||||
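|  | ||||
| # --- Editor's usage sketch (not part of the original module) --- | ||||
| # compute_meta reads <root>/<source_name>.csv.gz (expects a 'smiles' column), skips the | ||||
| # held-out test rows, and caches the statistics to <root>/<source_name>.meta.json. | ||||
| # The dataset name and index splits below are hypothetical placeholders. | ||||
| def _compute_meta_example(): | ||||
|     train_index = list(range(0, 800))       # hypothetical train split | ||||
|     test_index = list(range(800, 1000))     # hypothetical held-out split | ||||
|     meta = compute_meta('./data/raw', 'polymer', train_index, test_index) | ||||
|     print(meta['max_node'], meta['num_atom_type'], meta['active_atoms']) | ||||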
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     pass | ||||
							
								
								
									
0		graph_dit/diffusion/__init__.py		Normal file
224		graph_dit/diffusion/diffusion_utils.py		Normal file
							| @@ -0,0 +1,224 @@ | ||||
| import torch | ||||
| from torch.nn import functional as F | ||||
| import numpy as np | ||||
| from utils import PlaceHolder | ||||
|  | ||||
|  | ||||
| def sum_except_batch(x): | ||||
|     return x.reshape(x.size(0), -1).sum(dim=-1) | ||||
|  | ||||
| def assert_correctly_masked(variable, node_mask): | ||||
|     assert (variable * (1 - node_mask.long())).abs().max().item() < 1e-4, \ | ||||
|         'Variables not masked properly.' | ||||
|  | ||||
| def cosine_beta_schedule_discrete(timesteps, s=0.008): | ||||
|     """ Cosine schedule as proposed in https://openreview.net/forum?id=-NEXDKk8gZ. """ | ||||
|     steps = timesteps + 2 | ||||
|     x = np.linspace(0, steps, steps) | ||||
|  | ||||
|     alphas_cumprod = np.cos(0.5 * np.pi * ((x / steps) + s) / (1 + s)) ** 2 | ||||
|     alphas_cumprod = alphas_cumprod / alphas_cumprod[0] | ||||
|     alphas = (alphas_cumprod[1:] / alphas_cumprod[:-1]) | ||||
|     betas = 1 - alphas | ||||
|     return betas.squeeze() | ||||
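|  | ||||
| # Editor's sketch (an illustrative check, not part of the original module): the discrete | ||||
| # cosine schedule returns `timesteps + 1` betas in (0, 1] that grow towards 1 as t -> T. | ||||
| def _cosine_schedule_example(): | ||||
|     betas = cosine_beta_schedule_discrete(timesteps=500) | ||||
|     assert betas.shape[0] == 501 | ||||
|     assert (betas > 0).all() and (betas <= 1).all() | ||||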
|  | ||||
| def custom_beta_schedule_discrete(timesteps, average_num_nodes=30, s=0.008): | ||||
|     """ Cosine schedule as proposed in https://openreview.net/forum?id=-NEXDKk8gZ. """ | ||||
|     steps = timesteps + 2 | ||||
|     x = np.linspace(0, steps, steps) | ||||
|  | ||||
|     alphas_cumprod = np.cos(0.5 * np.pi * ((x / steps) + s) / (1 + s)) ** 2 | ||||
|     alphas_cumprod = alphas_cumprod / alphas_cumprod[0] | ||||
|     alphas = (alphas_cumprod[1:] / alphas_cumprod[:-1]) | ||||
|     betas = 1 - alphas | ||||
|  | ||||
|     assert timesteps >= 100 | ||||
|  | ||||
|     p = 4 / 5       # 1 - 1 / num_edge_classes | ||||
|     num_edges = average_num_nodes * (average_num_nodes - 1) / 2 | ||||
|  | ||||
|     # First 100 steps: only a few updates per graph | ||||
|     updates_per_graph = 1.2 | ||||
|     beta_first = updates_per_graph / (p * num_edges) | ||||
|  | ||||
|     betas[betas < beta_first] = beta_first | ||||
|     return np.array(betas) | ||||
|  | ||||
|  | ||||
| def check_mask_correct(variables, node_mask): | ||||
|     for i, variable in enumerate(variables): | ||||
|         if len(variable) > 0: | ||||
|             assert_correctly_masked(variable, node_mask) | ||||
|  | ||||
|  | ||||
| def check_tensor_same_size(*args): | ||||
|     for i, arg in enumerate(args): | ||||
|         if i == 0: | ||||
|             continue | ||||
|         assert args[0].size() == arg.size() | ||||
|  | ||||
|  | ||||
|  | ||||
| def reverse_tensor(x): | ||||
|     return x[torch.arange(x.size(0) - 1, -1, -1)] | ||||
|  | ||||
|  | ||||
| def sample_discrete_features(probX, probE, node_mask, step=None, add_noise=True): | ||||
|     ''' Sample node and edge features from the multinomial distributions given by probX and probE. | ||||
|         :param probX: bs, n, dx_out        node features | ||||
|         :param probE: bs, n, n, de_out     edge features | ||||
|         :param node_mask: bs, n            mask of valid nodes | ||||
|     ''' | ||||
|     bs, n, _ = probX.shape | ||||
|  | ||||
|     # Noise X | ||||
|     # The masked rows should define probability distributions as well | ||||
|     probX[~node_mask] = 1 / probX.shape[-1] | ||||
|  | ||||
|     # Flatten the probability tensor to sample with multinomial | ||||
|     probX = probX.reshape(bs * n, -1)       # (bs * n, dx_out) | ||||
|  | ||||
|     # Sample X | ||||
|     probX = probX + 1e-12 | ||||
|     probX = probX / probX.sum(dim=-1, keepdim=True) | ||||
|     X_t = probX.multinomial(1)      # (bs * n, 1) | ||||
|     X_t = X_t.reshape(bs, n)        # (bs, n) | ||||
|  | ||||
|     # Noise E | ||||
|     # The masked rows should define probability distributions as well | ||||
|     inverse_edge_mask = ~(node_mask.unsqueeze(1) * node_mask.unsqueeze(2)) | ||||
|     diag_mask = torch.eye(n).unsqueeze(0).expand(bs, -1, -1) | ||||
|  | ||||
|     probE[inverse_edge_mask] = 1 / probE.shape[-1] | ||||
|     probE[diag_mask.bool()] = 1 / probE.shape[-1] | ||||
|     probE = probE.reshape(bs * n * n, -1)           # (bs * n * n, de_out) | ||||
|     probE = probE + 1e-12 | ||||
|     probE = probE / probE.sum(dim=-1, keepdim=True) | ||||
|  | ||||
|     # Sample E | ||||
|     E_t = probE.multinomial(1).reshape(bs, n, n)    # (bs, n, n) | ||||
|     E_t = torch.triu(E_t, diagonal=1) | ||||
|     E_t = (E_t + torch.transpose(E_t, 1, 2)) | ||||
|  | ||||
|     return PlaceHolder(X=X_t, E=E_t, y=torch.zeros(bs, 0).type_as(X_t)) | ||||
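|  | ||||
| # Editor's sketch (toy shapes, not part of the original module): sample_discrete_features | ||||
| # turns per-node / per-edge categorical probabilities into integer class indices; the | ||||
| # sampled E is symmetric with a zero diagonal. | ||||
| def _sample_discrete_features_example(): | ||||
|     bs, n, dx, de = 2, 4, 3, 5 | ||||
|     probX = torch.softmax(torch.randn(bs, n, dx), dim=-1) | ||||
|     probE = torch.softmax(torch.randn(bs, n, n, de), dim=-1) | ||||
|     node_mask = torch.ones(bs, n, dtype=torch.bool) | ||||
|     z = sample_discrete_features(probX, probE, node_mask) | ||||
|     assert z.X.shape == (bs, n) and z.E.shape == (bs, n, n) | ||||
|     assert (z.E == z.E.transpose(1, 2)).all() | ||||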
|  | ||||
|  | ||||
| def compute_batched_over0_posterior_distribution(X_t, Qt, Qsb, Qtb): | ||||
|     """ M: X or E | ||||
|         Compute xt @ Qt.T * x0 @ Qsb / x0 @ Qtb @ xt.T for each possible value of x0 | ||||
|         X_t: bs, n, dt          or bs, n, n, dt | ||||
|         Qt: bs, d_t-1, dt | ||||
|         Qsb: bs, d0, d_t-1 | ||||
|         Qtb: bs, d0, dt. | ||||
|     """ | ||||
|     X_t = X_t.float() | ||||
|     Qt_T = Qt.transpose(-1, -2).float()                                       # bs, N, dt | ||||
|     assert Qt.dim() == 3 | ||||
|     left_term = X_t @ Qt_T | ||||
|     left_term = left_term.unsqueeze(dim=2)                    # bs, N, 1, d_t-1 | ||||
|     right_term = Qsb.unsqueeze(1)  | ||||
|     numerator = left_term * right_term                        # bs, N, d0, d_t-1 | ||||
|      | ||||
|     denominator = Qtb @ X_t.transpose(-1, -2)                 # bs, d0, N | ||||
|     denominator = denominator.transpose(-1, -2)               # bs, N, d0 | ||||
|     denominator = denominator.unsqueeze(-1)                   # bs, N, d0, 1 | ||||
|  | ||||
|     denominator[denominator == 0] = 1. | ||||
|     return numerator / denominator | ||||
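|  | ||||
| # Editor's sketch (toy shapes, not part of the original module): for every possible clean | ||||
| # value x0, the function above returns the unnormalized posterior q(x_{t-1} | x_t, x0), | ||||
| # stacked into a (bs, N, d0, d_t-1) tensor. | ||||
| def _posterior_over0_example(): | ||||
|     bs, N, d = 2, 3, 4 | ||||
|     X_t = F.one_hot(torch.randint(0, d, (bs, N)), num_classes=d).float() | ||||
|     Q = torch.softmax(torch.randn(bs, d, d), dim=-1)   # row-stochastic toy transition | ||||
|     out = compute_batched_over0_posterior_distribution(X_t, Qt=Q, Qsb=Q, Qtb=Q) | ||||
|     assert out.shape == (bs, N, d, d) | ||||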
|  | ||||
|  | ||||
| def mask_distributions(true_X, true_E, pred_X, pred_E, node_mask): | ||||
|     # Add a small value everywhere to avoid nans | ||||
|     pred_X = pred_X.clamp_min(1e-18) | ||||
|     pred_X = pred_X / torch.sum(pred_X, dim=-1, keepdim=True) | ||||
|  | ||||
|     pred_E = pred_E.clamp_min(1e-18) | ||||
|     pred_E = pred_E / torch.sum(pred_E, dim=-1, keepdim=True) | ||||
|  | ||||
|     # Set masked rows to arbitrary distributions, so it doesn't contribute to loss | ||||
|     row_X = torch.ones(true_X.size(-1), dtype=true_X.dtype, device=true_X.device) | ||||
|     row_E = torch.zeros(true_E.size(-1), dtype=true_E.dtype, device=true_E.device).clamp_min(1e-18) | ||||
|     row_E[0] = 1. | ||||
|  | ||||
|     diag_mask = ~torch.eye(node_mask.size(1), device=node_mask.device, dtype=torch.bool).unsqueeze(0) | ||||
|     true_X[~node_mask] = row_X | ||||
|     true_E[~(node_mask.unsqueeze(1) * node_mask.unsqueeze(2) * diag_mask), :] = row_E | ||||
|     pred_X[~node_mask] = row_X | ||||
|     pred_E[~(node_mask.unsqueeze(1) * node_mask.unsqueeze(2) * diag_mask), :] = row_E | ||||
|  | ||||
|     return true_X, true_E, pred_X, pred_E | ||||
|  | ||||
| def posterior_distributions(X, X_t, Qt, Qsb, Qtb, X_dim): | ||||
|     bs, n, d = X.shape | ||||
|     X = X.float() | ||||
|     Qt_X_T = torch.transpose(Qt.X, -2, -1).float()                  # (bs, d, d) | ||||
|     left_term = X_t @ Qt_X_T                                        # (bs, N, d) | ||||
|     right_term = X @ Qsb.X                                          # (bs, N, d) | ||||
|      | ||||
|     numerator = left_term * right_term                              # (bs, N, d) | ||||
|     denominator = X @ Qtb.X                                         # (bs, N, d) @ (bs, d, d) = (bs, N, d) | ||||
|     denominator = denominator * X_t  | ||||
|      | ||||
|     num_X = numerator[:, :, :X_dim] | ||||
|     num_E = numerator[:, :, X_dim:].reshape(bs, n*n, -1) | ||||
|  | ||||
|     deno_X = denominator[:, :, :X_dim] | ||||
|     deno_E = denominator[:, :, X_dim:].reshape(bs, n*n, -1) | ||||
|  | ||||
|     # denominator = (denominator * X_t).sum(dim=-1)                   # (bs, N, d) * (bs, N, d) + sum = (bs, N) | ||||
|     denominator = denominator.unsqueeze(-1)                         # (bs, N, d, 1); not used further below | ||||
|  | ||||
|     deno_X = deno_X.sum(dim=-1).unsqueeze(-1) | ||||
|     deno_E = deno_E.sum(dim=-1).unsqueeze(-1) | ||||
|  | ||||
|     deno_X[deno_X == 0.] = 1 | ||||
|     deno_E[deno_E == 0.] = 1 | ||||
|     prob_X = num_X / deno_X | ||||
|     prob_E = num_E / deno_E | ||||
|      | ||||
|     prob_E = prob_E / prob_E.sum(dim=-1, keepdim=True) | ||||
|     prob_X = prob_X / prob_X.sum(dim=-1, keepdim=True) | ||||
|     return PlaceHolder(X=prob_X, E=prob_E, y=None) | ||||
|  | ||||
|  | ||||
| def sample_discrete_feature_noise(limit_dist, node_mask): | ||||
|     """ Sample from the limit distribution of the diffusion process""" | ||||
|     bs, n_max = node_mask.shape | ||||
|     x_limit = limit_dist.X[None, None, :].expand(bs, n_max, -1) | ||||
|     U_X = x_limit.flatten(end_dim=-2).multinomial(1).reshape(bs, n_max) | ||||
|     U_X = F.one_hot(U_X.long(), num_classes=x_limit.shape[-1]).float() | ||||
|  | ||||
|     e_limit = limit_dist.E[None, None, None, :].expand(bs, n_max, n_max, -1) | ||||
|     U_E = e_limit.flatten(end_dim=-2).multinomial(1).reshape(bs, n_max, n_max) | ||||
|     U_E = F.one_hot(U_E.long(), num_classes=e_limit.shape[-1]).float() | ||||
|  | ||||
|     U_X = U_X.to(node_mask.device) | ||||
|     U_E = U_E.to(node_mask.device) | ||||
|  | ||||
|     # Get upper triangular part of edge noise, without main diagonal | ||||
|     upper_triangular_mask = torch.zeros_like(U_E) | ||||
|     indices = torch.triu_indices(row=U_E.size(1), col=U_E.size(2), offset=1) | ||||
|     upper_triangular_mask[:, indices[0], indices[1], :] = 1 | ||||
|  | ||||
|     U_E = U_E * upper_triangular_mask | ||||
|     U_E = (U_E + torch.transpose(U_E, 1, 2)) | ||||
|  | ||||
|     assert (U_E == torch.transpose(U_E, 1, 2)).all() | ||||
|     return PlaceHolder(X=U_X, E=U_E, y=None).mask(node_mask) | ||||
|  | ||||
| def index_QE(X, q_e, n_bond=5): | ||||
|     bs, n, n_atom = X.shape | ||||
|     node_indices = X.argmax(-1)  # (bs, n) | ||||
|  | ||||
|     exp_ind1 = node_indices[ :, :, None, None, None].expand(bs, n, n_atom, n_bond, n_bond) | ||||
|     exp_ind2 = node_indices[ :, :, None, None, None].expand(bs, n, n, n_bond, n_bond) | ||||
|      | ||||
|     q_e = torch.gather(q_e, 1, exp_ind1) | ||||
|     q_e = torch.gather(q_e, 2, exp_ind2) # (bs, n, n, n_bond, n_bond) | ||||
|  | ||||
|  | ||||
|     node_mask = X.sum(-1) != 0 | ||||
|     no_edge = (~node_mask)[:, :, None] & (~node_mask)[:, None, :] | ||||
|     q_e[no_edge] = torch.tensor([1, 0, 0, 0, 0]).type_as(q_e) | ||||
|  | ||||
|     return q_e | ||||
							
								
								
									
30		graph_dit/diffusion/distributions.py		Normal file
							| @@ -0,0 +1,30 @@ | ||||
| import torch | ||||
|  | ||||
| class DistributionNodes: | ||||
|     def __init__(self, histogram): | ||||
|         """ Compute the distribution of the number of nodes in the dataset, and sample from this distribution. | ||||
|             histogram: dict. The keys are num_nodes, the values are counts | ||||
|         """ | ||||
|  | ||||
|         if type(histogram) == dict: | ||||
|             max_n_nodes = max(histogram.keys()) | ||||
|             prob = torch.zeros(max_n_nodes + 1) | ||||
|             for num_nodes, count in histogram.items(): | ||||
|                 prob[num_nodes] = count | ||||
|         else: | ||||
|             prob = histogram | ||||
|  | ||||
|         self.prob = prob / prob.sum() | ||||
|         self.m = torch.distributions.Categorical(prob) | ||||
|  | ||||
|     def sample_n(self, n_samples, device): | ||||
|         idx = self.m.sample((n_samples,)) | ||||
|         return idx.to(device) | ||||
|  | ||||
|     def log_prob(self, batch_n_nodes): | ||||
|         assert len(batch_n_nodes.size()) == 1 | ||||
|         p = self.prob.to(batch_n_nodes.device) | ||||
|  | ||||
|         probas = p[batch_n_nodes] | ||||
|         log_p = torch.log(probas + 1e-30) | ||||
|         return log_p | ||||
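|  | ||||
| # Editor's usage sketch (not part of the original module): build the node-count | ||||
| # distribution from a toy histogram, then sample graph sizes and score them. | ||||
| def _distribution_nodes_example(): | ||||
|     dist = DistributionNodes({8: 10, 9: 25, 10: 5})   # made-up molecule-size counts | ||||
|     sizes = dist.sample_n(4, device='cpu') | ||||
|     log_p = dist.log_prob(sizes) | ||||
|     print(sizes, log_p) | ||||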
							
								
								
									
159		graph_dit/diffusion/noise_schedule.py		Normal file
							| @@ -0,0 +1,159 @@ | ||||
| import torch | ||||
| import utils | ||||
| from diffusion import diffusion_utils | ||||
|      | ||||
| class PredefinedNoiseScheduleDiscrete(torch.nn.Module): | ||||
|     def __init__(self, noise_schedule, timesteps): | ||||
|         super(PredefinedNoiseScheduleDiscrete, self).__init__() | ||||
|         self.timesteps = timesteps | ||||
|  | ||||
|         if noise_schedule == 'cosine': | ||||
|             betas = diffusion_utils.cosine_beta_schedule_discrete(timesteps) | ||||
|         elif noise_schedule == 'custom': | ||||
|             betas = diffusion_utils.custom_beta_schedule_discrete(timesteps) | ||||
|         else: | ||||
|             raise NotImplementedError(noise_schedule) | ||||
|  | ||||
|         self.register_buffer('betas', torch.from_numpy(betas).float()) | ||||
|  | ||||
|         # Clamp betas to [0, 1] so the corresponding alphas stay in [0, 1]. | ||||
|         self.alphas = 1 - torch.clamp(self.betas, min=0, max=1) | ||||
|  | ||||
|         log_alpha = torch.log(self.alphas) | ||||
|         log_alpha_bar = torch.cumsum(log_alpha, dim=0) | ||||
|         self.alphas_bar = torch.exp(log_alpha_bar) | ||||
|  | ||||
|     def forward(self, t_normalized=None, t_int=None): | ||||
|         assert int(t_normalized is None) + int(t_int is None) == 1 | ||||
|         if t_int is None: | ||||
|             t_int = torch.round(t_normalized * self.timesteps) | ||||
|         return self.betas[t_int.long()] | ||||
|  | ||||
|     def get_alpha_bar(self, t_normalized=None, t_int=None): | ||||
|         assert int(t_normalized is None) + int(t_int is None) == 1 | ||||
|         if t_int is None: | ||||
|             t_int = torch.round(t_normalized * self.timesteps) | ||||
|         # Make sure alphas_bar is on the same device as t_int before indexing. | ||||
|         self.alphas_bar = self.alphas_bar.to(t_int.device) | ||||
|         return self.alphas_bar[t_int.long()] | ||||
|  | ||||
|  | ||||
| class DiscreteUniformTransition: | ||||
|     def __init__(self, x_classes: int, e_classes: int, y_classes: int): | ||||
|         self.X_classes = x_classes | ||||
|         self.E_classes = e_classes | ||||
|         self.y_classes = y_classes | ||||
|         self.u_x = torch.ones(1, self.X_classes, self.X_classes) | ||||
|         if self.X_classes > 0: | ||||
|             self.u_x = self.u_x / self.X_classes | ||||
|  | ||||
|         self.u_e = torch.ones(1, self.E_classes, self.E_classes) | ||||
|         if self.E_classes > 0: | ||||
|             self.u_e = self.u_e / self.E_classes | ||||
|  | ||||
|         self.u_y = torch.ones(1, self.y_classes, self.y_classes) | ||||
|         if self.y_classes > 0: | ||||
|             self.u_y = self.u_y / self.y_classes | ||||
|  | ||||
|     def get_Qt(self, beta_t, device, X=None, flatten_e=None): | ||||
|         """ Returns one-step transition matrices for X and E, from step t - 1 to step t. | ||||
|         Qt = (1 - beta_t) * I + beta_t / K | ||||
|  | ||||
|         beta_t: (bs)                         noise level between 0 and 1 | ||||
|         returns: qx (bs, dx, dx), qe (bs, de, de), qy (bs, dy, dy). | ||||
|         """ | ||||
|         beta_t = beta_t.unsqueeze(1) | ||||
|         beta_t = beta_t.to(device) | ||||
|         self.u_x = self.u_x.to(device) | ||||
|         self.u_e = self.u_e.to(device) | ||||
|         self.u_y = self.u_y.to(device) | ||||
|  | ||||
|         q_x = beta_t * self.u_x + (1 - beta_t) * torch.eye(self.X_classes, device=device).unsqueeze(0) | ||||
|         q_e = beta_t * self.u_e + (1 - beta_t) * torch.eye(self.E_classes, device=device).unsqueeze(0) | ||||
|         q_y = beta_t * self.u_y + (1 - beta_t) * torch.eye(self.y_classes, device=device).unsqueeze(0) | ||||
|  | ||||
|         return utils.PlaceHolder(X=q_x, E=q_e, y=q_y) | ||||
|  | ||||
|     def get_Qt_bar(self, alpha_bar_t, device, X=None, flatten_e=None): | ||||
|         """ Returns t-step transition matrices for X and E, from step 0 to step t. | ||||
|         Qt = prod(1 - beta_t) * I + (1 - prod(1 - beta_t)) / K | ||||
|  | ||||
|         alpha_bar_t: (bs)         Product of the (1 - beta_t) for each time step from 0 to t. | ||||
|         returns: qx (bs, dx, dx), qe (bs, de, de), qy (bs, dy, dy). | ||||
|         """ | ||||
|         alpha_bar_t = alpha_bar_t.unsqueeze(1) | ||||
|         alpha_bar_t = alpha_bar_t.to(device) | ||||
|         self.u_x = self.u_x.to(device) | ||||
|         self.u_e = self.u_e.to(device) | ||||
|         self.u_y = self.u_y.to(device) | ||||
|  | ||||
|         q_x = alpha_bar_t * torch.eye(self.X_classes, device=device).unsqueeze(0) + (1 - alpha_bar_t) * self.u_x | ||||
|         q_e = alpha_bar_t * torch.eye(self.E_classes, device=device).unsqueeze(0) + (1 - alpha_bar_t) * self.u_e | ||||
|         q_y = alpha_bar_t * torch.eye(self.y_classes, device=device).unsqueeze(0) + (1 - alpha_bar_t) * self.u_y | ||||
|  | ||||
|         return utils.PlaceHolder(X=q_x, E=q_e, y=q_y) | ||||
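|  | ||||
| # Editor's sketch (not part of the original module): every row of the uniform Qt_bar is a | ||||
| # valid categorical distribution, since alpha_bar + (1 - alpha_bar) * K * (1/K) = 1. | ||||
| # Note that beta_t / alpha_bar_t enter with shape (bs, 1). | ||||
| def _uniform_transition_example(): | ||||
|     m = DiscreteUniformTransition(x_classes=5, e_classes=4, y_classes=1) | ||||
|     Qtb = m.get_Qt_bar(torch.tensor([[0.2], [0.9]]), device='cpu') | ||||
|     assert torch.allclose(Qtb.X.sum(dim=-1), torch.ones(2, 5)) | ||||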
|  | ||||
|      | ||||
| class MarginalTransition: | ||||
|     def __init__(self, x_marginals, e_marginals, xe_conditions, ex_conditions, y_classes, n_nodes): | ||||
|         self.X_classes = len(x_marginals) | ||||
|         self.E_classes = len(e_marginals) | ||||
|         self.y_classes = y_classes | ||||
|         self.x_marginals = x_marginals # Dx | ||||
|         self.e_marginals = e_marginals # Dx, De | ||||
|         self.xe_conditions = xe_conditions | ||||
|  | ||||
|         self.u_x = x_marginals.unsqueeze(0).expand(self.X_classes, -1).unsqueeze(0) # 1, Dx, Dx | ||||
|         self.u_e = e_marginals.unsqueeze(0).expand(self.E_classes, -1).unsqueeze(0) # 1, De, De | ||||
|         self.u_xe = xe_conditions.unsqueeze(0) # 1, Dx, De | ||||
|         self.u_ex = ex_conditions.unsqueeze(0) # 1, De, Dx | ||||
|         self.u = self.get_union_transition(self.u_x, self.u_e, self.u_xe, self.u_ex, n_nodes) # 1, Dx + n*De, Dx + n*De | ||||
|  | ||||
|     def get_union_transition(self, u_x, u_e, u_xe, u_ex, n_nodes): | ||||
|         u_e = u_e.repeat(1, n_nodes, n_nodes) # (1, n*de, n*de) | ||||
|         u_xe = u_xe.repeat(1, 1, n_nodes) # (1, dx, n*de) | ||||
|         u_ex = u_ex.repeat(1, n_nodes, 1) # (1, n*de, dx) | ||||
|         u0 = torch.cat([u_x, u_xe], dim=2) # (1, dx, dx + n*de) | ||||
|         u1 = torch.cat([u_ex, u_e], dim=2) # (1, n*de, dx + n*de) | ||||
|         u = torch.cat([u0, u1], dim=1) # (1, dx + n*de, dx + n*de) | ||||
|         return u | ||||
|  | ||||
|     def index_edge_margin(self, X, q_e, n_bond=5): | ||||
|         # q_e: (bs, dx, de) --> (bs, n, de) | ||||
|         bs, n, n_atom = X.shape | ||||
|         node_indices = X.argmax(-1)  # (bs, n) | ||||
|         ind = node_indices[ :, :, None].expand(bs, n, n_bond) | ||||
|         q_e = torch.gather(q_e, 1, ind) | ||||
|         return q_e | ||||
|      | ||||
|     def get_Qt(self, beta_t, device): | ||||
|         """ Returns one-step transition matrices for X and E, from step t - 1 to step t. | ||||
|         Qt = (1 - beta_t) * I + beta_t * m, where m is the marginal (union) transition matrix | ||||
|         beta_t: (bs) | ||||
|         returns: q (bs, d0, d0) | ||||
|         """ | ||||
|         bs = beta_t.size(0) | ||||
|         d0 = self.u.size(-1) | ||||
|         self.u = self.u.to(device) | ||||
|         u = self.u.expand(bs, d0, d0) | ||||
|  | ||||
|         beta_t = beta_t.to(device) | ||||
|         beta_t = beta_t.view(bs, 1, 1) | ||||
|         q = beta_t * u + (1 - beta_t) * torch.eye(d0, device=device).unsqueeze(0) | ||||
|  | ||||
|         return utils.PlaceHolder(X=q, E=None, y=None) | ||||
|      | ||||
|     def get_Qt_bar(self, alpha_bar_t, device): | ||||
|         """ Returns t-step transition matrices for X and E, from step 0 to step t. | ||||
|         Qt_bar = prod(1 - beta_t) * I + (1 - prod(1 - beta_t)) * m, where m is the marginal (union) transition matrix | ||||
|         alpha_bar_t: (bs, 1) Product of the (1 - beta_t) for each time step from 0 to t. | ||||
|         returns: q (bs, d0, d0) | ||||
|         """ | ||||
|         bs = alpha_bar_t.size(0) | ||||
|         d0 = self.u.size(-1) | ||||
|         alpha_bar_t = alpha_bar_t.to(device) | ||||
|         alpha_bar_t = alpha_bar_t.view(bs, 1, 1) | ||||
|         self.u = self.u.to(device) | ||||
|         q = alpha_bar_t * torch.eye(d0, device=device).unsqueeze(0) + (1 - alpha_bar_t) * self.u | ||||
|  | ||||
|         return utils.PlaceHolder(X=q, E=None, y=None) | ||||
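|  | ||||
| # Editor's sketch (toy values, not part of the original module): the marginal "union" | ||||
| # transition acts on the concatenation of one node's type distribution (Dx entries) and | ||||
| # its n rows of edge-type distributions (n * De entries), so Q is (Dx + n*De) square. | ||||
| def _marginal_transition_example(): | ||||
|     x_m = torch.tensor([0.7, 0.2, 0.1])             # Dx = 3 node-type marginals | ||||
|     e_m = torch.tensor([0.9, 0.05, 0.03, 0.02])     # De = 4 edge-type marginals | ||||
|     xe = torch.full((3, 4), 0.25)                   # toy node-to-edge conditions | ||||
|     ex = torch.full((4, 3), 1.0 / 3)                # toy edge-to-node conditions | ||||
|     tm = MarginalTransition(x_m, e_m, xe, ex, y_classes=1, n_nodes=5) | ||||
|     Qtb = tm.get_Qt_bar(torch.tensor([[0.3]]), device='cpu') | ||||
|     assert Qtb.X.shape == (1, 3 + 5 * 4, 3 + 5 * 4) | ||||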
							
								
								
									
617		graph_dit/diffusion_model.py		Normal file
							| @@ -0,0 +1,617 @@ | ||||
| import torch | ||||
| import torch.nn.functional as F | ||||
| import pytorch_lightning as pl | ||||
| import time | ||||
| import os | ||||
|  | ||||
| from models.transformer import Denoiser | ||||
| from diffusion.noise_schedule import PredefinedNoiseScheduleDiscrete, MarginalTransition | ||||
|  | ||||
| from diffusion import diffusion_utils | ||||
| from metrics.train_loss import TrainLossDiscrete | ||||
| from metrics.abstract_metrics import SumExceptBatchMetric, SumExceptBatchKL, NLL | ||||
| import utils | ||||
|  | ||||
| class Graph_DiT(pl.LightningModule): | ||||
|     def __init__(self, cfg, dataset_infos, train_metrics, sampling_metrics, visualization_tools): | ||||
|         super().__init__() | ||||
|         self.save_hyperparameters(ignore=['train_metrics', 'sampling_metrics']) | ||||
|         self.test_only = cfg.general.test_only | ||||
|         self.guidance_target = getattr(cfg.dataset, 'guidance_target', None) | ||||
|  | ||||
|         input_dims = dataset_infos.input_dims | ||||
|         output_dims = dataset_infos.output_dims | ||||
|         nodes_dist = dataset_infos.nodes_dist | ||||
|         active_index = dataset_infos.active_index | ||||
|  | ||||
|         self.cfg = cfg | ||||
|         self.name = cfg.general.name | ||||
|         self.T = cfg.model.diffusion_steps | ||||
|         self.guide_scale = cfg.model.guide_scale | ||||
|  | ||||
|         self.Xdim = input_dims['X'] | ||||
|         self.Edim = input_dims['E'] | ||||
|         self.ydim = input_dims['y'] | ||||
|         self.Xdim_output = output_dims['X'] | ||||
|         self.Edim_output = output_dims['E'] | ||||
|         self.ydim_output = output_dims['y'] | ||||
|         self.node_dist = nodes_dist | ||||
|         self.active_index = active_index | ||||
|         self.dataset_info = dataset_infos | ||||
|  | ||||
|         self.train_loss = TrainLossDiscrete(self.cfg.model.lambda_train) | ||||
|  | ||||
|         self.val_nll = NLL() | ||||
|         self.val_X_kl = SumExceptBatchKL() | ||||
|         self.val_E_kl = SumExceptBatchKL() | ||||
|         self.val_X_logp = SumExceptBatchMetric() | ||||
|         self.val_E_logp = SumExceptBatchMetric() | ||||
|         self.val_y_collection = [] | ||||
|  | ||||
|         self.test_nll = NLL() | ||||
|         self.test_X_kl = SumExceptBatchKL() | ||||
|         self.test_E_kl = SumExceptBatchKL() | ||||
|         self.test_X_logp = SumExceptBatchMetric() | ||||
|         self.test_E_logp = SumExceptBatchMetric() | ||||
|         self.test_y_collection = [] | ||||
|  | ||||
|         self.train_metrics = train_metrics | ||||
|         self.sampling_metrics = sampling_metrics | ||||
|  | ||||
|         self.visualization_tools = visualization_tools | ||||
|         self.max_n_nodes = dataset_infos.max_n_nodes | ||||
|  | ||||
|         self.model = Denoiser(max_n_nodes=self.max_n_nodes, | ||||
|                         hidden_size=cfg.model.hidden_size, | ||||
|                         depth=cfg.model.depth, | ||||
|                         num_heads=cfg.model.num_heads, | ||||
|                         mlp_ratio=cfg.model.mlp_ratio, | ||||
|                         drop_condition=cfg.model.drop_condition, | ||||
|                         Xdim=self.Xdim,  | ||||
|                         Edim=self.Edim, | ||||
|                         ydim=self.ydim, | ||||
|                         task_type=dataset_infos.task_type) | ||||
|          | ||||
|         self.noise_schedule = PredefinedNoiseScheduleDiscrete(cfg.model.diffusion_noise_schedule, | ||||
|                                                               timesteps=cfg.model.diffusion_steps) | ||||
|  | ||||
|  | ||||
|         x_marginals = self.dataset_info.node_types.float() / torch.sum(self.dataset_info.node_types.float()) | ||||
|          | ||||
|         e_marginals = self.dataset_info.edge_types.float() / torch.sum(self.dataset_info.edge_types.float()) | ||||
|         x_marginals = x_marginals / x_marginals.sum() | ||||
|         e_marginals = e_marginals / e_marginals.sum() | ||||
|  | ||||
|         xe_conditions = self.dataset_info.transition_E.float() | ||||
|         xe_conditions = xe_conditions[self.active_index][:, self.active_index]  | ||||
|          | ||||
|         xe_conditions = xe_conditions.sum(dim=1)  | ||||
|         ex_conditions = xe_conditions.t() | ||||
|         xe_conditions = xe_conditions / xe_conditions.sum(dim=-1, keepdim=True) | ||||
|         ex_conditions = ex_conditions / ex_conditions.sum(dim=-1, keepdim=True) | ||||
|          | ||||
|         self.transition_model = MarginalTransition(x_marginals=x_marginals,  | ||||
|                                                           e_marginals=e_marginals,  | ||||
|                                                           xe_conditions=xe_conditions, | ||||
|                                                           ex_conditions=ex_conditions, | ||||
|                                                           y_classes=self.ydim_output, | ||||
|                                                           n_nodes=self.max_n_nodes) | ||||
|  | ||||
|         self.limit_dist = utils.PlaceHolder(X=x_marginals, E=e_marginals, y=None) | ||||
|  | ||||
|         self.start_epoch_time = None | ||||
|         self.train_iterations = None | ||||
|         self.val_iterations = None | ||||
|         self.log_every_steps = cfg.general.log_every_steps | ||||
|         self.number_chain_steps = cfg.general.number_chain_steps | ||||
|  | ||||
|         self.best_val_nll = 1e8 | ||||
|         self.val_counter = 0 | ||||
|         self.batch_size = self.cfg.train.batch_size | ||||
|     | ||||
|  | ||||
|     def forward(self, noisy_data, unconditioned=False): | ||||
|         x, e, y = noisy_data['X_t'].float(), noisy_data['E_t'].float(), noisy_data['y_t'].float().clone() | ||||
|         node_mask, t =  noisy_data['node_mask'], noisy_data['t'] | ||||
|         pred = self.model(x, e, node_mask, y=y, t=t, unconditioned=unconditioned) | ||||
|         return pred | ||||
|          | ||||
|     def training_step(self, data, i): | ||||
|         data_x = F.one_hot(data.x, num_classes=118).float()[:, self.active_index] | ||||
|         data_edge_attr = F.one_hot(data.edge_attr, num_classes=5).float() | ||||
|  | ||||
|         dense_data, node_mask = utils.to_dense(data_x, data.edge_index, data_edge_attr, data.batch, self.max_n_nodes) | ||||
|         dense_data = dense_data.mask(node_mask) | ||||
|         X, E = dense_data.X, dense_data.E | ||||
|         noisy_data = self.apply_noise(X, E, data.y, node_mask) | ||||
|         pred = self.forward(noisy_data) | ||||
|         loss = self.train_loss(masked_pred_X=pred.X, masked_pred_E=pred.E, pred_y=pred.y, | ||||
|                             true_X=X, true_E=E, true_y=data.y, node_mask=node_mask, | ||||
|                             log=i % self.log_every_steps == 0) | ||||
|  | ||||
|         self.train_metrics(masked_pred_X=pred.X, masked_pred_E=pred.E, true_X=X, true_E=E, | ||||
|                         log=i % self.log_every_steps == 0) | ||||
|         self.log(f'loss', loss, batch_size=X.size(0), sync_dist=True) | ||||
|         return {'loss': loss} | ||||
|  | ||||
|  | ||||
|     def configure_optimizers(self): | ||||
|         params = self.parameters() | ||||
|         optimizer = torch.optim.AdamW(params, lr=self.cfg.train.lr, amsgrad=True, | ||||
|                                  weight_decay=self.cfg.train.weight_decay) | ||||
|         return optimizer | ||||
|      | ||||
|     def on_fit_start(self) -> None: | ||||
|         self.train_iterations = self.trainer.datamodule.training_iterations | ||||
|         print('on fit train iteration:', self.train_iterations) | ||||
|         print("Size of the input features Xdim {}, Edim {}, ydim {}".format(self.Xdim, self.Edim, self.ydim)) | ||||
|  | ||||
|     def on_train_epoch_start(self) -> None: | ||||
|         if self.current_epoch / self.trainer.max_epochs in [0.25, 0.5, 0.75, 1.0]: | ||||
|             print("Starting train epoch {}/{}...".format(self.current_epoch, self.trainer.max_epochs)) | ||||
|         self.start_epoch_time = time.time() | ||||
|         self.train_loss.reset() | ||||
|         self.train_metrics.reset() | ||||
|  | ||||
|     def on_train_epoch_end(self) -> None: | ||||
|         if self.current_epoch / self.trainer.max_epochs in [0.25, 0.5, 0.75, 1.0]: | ||||
|             log = True | ||||
|         else: | ||||
|             log = False | ||||
|         self.train_loss.log_epoch_metrics(self.current_epoch, self.start_epoch_time, log) | ||||
|         self.train_metrics.log_epoch_metrics(self.current_epoch, log) | ||||
|  | ||||
|     def on_validation_epoch_start(self) -> None: | ||||
|         self.val_nll.reset() | ||||
|         self.val_X_kl.reset() | ||||
|         self.val_E_kl.reset() | ||||
|         self.val_X_logp.reset() | ||||
|         self.val_E_logp.reset() | ||||
|         self.sampling_metrics.reset() | ||||
|         self.val_y_collection = [] | ||||
|  | ||||
|     @torch.no_grad() | ||||
|     def validation_step(self, data, i): | ||||
|         data_x = F.one_hot(data.x, num_classes=118).float()[:, self.active_index] | ||||
|         data_edge_attr = F.one_hot(data.edge_attr, num_classes=5).float() | ||||
|         dense_data, node_mask = utils.to_dense(data_x, data.edge_index, data_edge_attr, data.batch, self.max_n_nodes) | ||||
|         dense_data = dense_data.mask(node_mask) | ||||
|         noisy_data = self.apply_noise(dense_data.X, dense_data.E, data.y, node_mask) | ||||
|         pred = self.forward(noisy_data) | ||||
|         nll = self.compute_val_loss(pred, noisy_data, dense_data.X, dense_data.E, data.y, node_mask, test=False) | ||||
|         self.val_y_collection.append(data.y) | ||||
|         self.log(f'valid_nll', nll, batch_size=data.x.size(0), sync_dist=True) | ||||
|         return {'loss': nll} | ||||
|  | ||||
|     def on_validation_epoch_end(self) -> None: | ||||
|         metrics = [self.val_nll.compute(), self.val_X_kl.compute() * self.T, self.val_E_kl.compute() * self.T, | ||||
|                    self.val_X_logp.compute(), self.val_E_logp.compute()] | ||||
|          | ||||
|         if self.current_epoch / self.trainer.max_epochs in [0.25, 0.5, 0.75, 1.0]: | ||||
|             print(f"Epoch {self.current_epoch}: Val NLL {metrics[0] :.2f} -- Val Atom type KL {metrics[1] :.2f} -- ", | ||||
|                 f"Val Edge type KL: {metrics[2] :.2f}", 'Val loss: %.2f \t Best :  %.2f\n' % (metrics[0], self.best_val_nll)) | ||||
|  | ||||
|         # Log val nll with default Lightning logger, so it can be monitored by checkpoint callback | ||||
|         self.log("val/NLL",  metrics[0], sync_dist=True) | ||||
|  | ||||
|         if metrics[0] < self.best_val_nll: | ||||
|             self.best_val_nll = metrics[0] | ||||
|  | ||||
|         self.val_counter += 1 | ||||
|          | ||||
|         if self.val_counter % self.cfg.general.sample_every_val == 0 and self.val_counter > 1: | ||||
|             start = time.time() | ||||
|             samples_left_to_generate = self.cfg.general.samples_to_generate | ||||
|             samples_left_to_save = self.cfg.general.samples_to_save | ||||
|             chains_left_to_save = self.cfg.general.chains_to_save | ||||
|  | ||||
|             samples, all_ys, ident = [], [], 0 | ||||
|  | ||||
|             self.val_y_collection = torch.cat(self.val_y_collection, dim=0) | ||||
|             num_examples = self.val_y_collection.size(0) | ||||
|             start_index = 0 | ||||
|             while samples_left_to_generate > 0:                 | ||||
|                 bs = 1 * self.cfg.train.batch_size | ||||
|                 to_generate = min(samples_left_to_generate, bs) | ||||
|                 to_save = min(samples_left_to_save, bs) | ||||
|                 chains_save = min(chains_left_to_save, bs) | ||||
|  | ||||
|                 if start_index + to_generate > num_examples: | ||||
|                     start_index = 0 | ||||
|                 if to_generate > num_examples: | ||||
|                     ratio = to_generate // num_examples | ||||
|                     self.val_y_collection = self.val_y_collection.repeat(ratio+1, 1) | ||||
|                     num_examples = self.val_y_collection.size(0) | ||||
|                 batch_y = self.val_y_collection[start_index:start_index + to_generate]                 | ||||
|                 all_ys.append(batch_y) | ||||
|                 samples.extend(self.sample_batch(batch_id=ident, batch_size=to_generate, y=batch_y, | ||||
|                                                 save_final=to_save, | ||||
|                                                 keep_chain=chains_save, | ||||
|                                                 number_chain_steps=self.number_chain_steps)) | ||||
|                 ident += to_generate | ||||
|                 start_index += to_generate | ||||
|  | ||||
|                 samples_left_to_save -= to_save | ||||
|                 samples_left_to_generate -= to_generate | ||||
|                 chains_left_to_save -= chains_save | ||||
|  | ||||
|             print(f"Computing sampling metrics", ' ...') | ||||
|             valid_smiles = self.sampling_metrics(samples, all_ys, self.name, self.current_epoch, val_counter=-1, test=False) | ||||
|             print(f'Done. Sampling took {time.time() - start:.2f} seconds\n') | ||||
|             current_path = os.getcwd() | ||||
|             result_path = os.path.join(current_path, | ||||
|                                        f'graphs/{self.name}/epoch{self.current_epoch}_b0/') | ||||
|             self.visualization_tools.visualize_by_smiles(result_path, valid_smiles, self.cfg.general.samples_to_save) | ||||
|             self.sampling_metrics.reset() | ||||
|  | ||||
|     def on_test_epoch_start(self) -> None: | ||||
|         print("Starting test...") | ||||
|         self.test_nll.reset() | ||||
|         self.test_X_kl.reset() | ||||
|         self.test_E_kl.reset() | ||||
|         self.test_X_logp.reset() | ||||
|         self.test_E_logp.reset() | ||||
|         self.test_y_collection = [] | ||||
|      | ||||
|     @torch.no_grad() | ||||
|     def test_step(self, data, i): | ||||
|         data_x = F.one_hot(data.x, num_classes=118).float()[:, self.active_index] | ||||
|         data_edge_attr = F.one_hot(data.edge_attr, num_classes=5).float() | ||||
|  | ||||
|         dense_data, node_mask = utils.to_dense(data_x, data.edge_index, data_edge_attr, data.batch, self.max_n_nodes) | ||||
|         dense_data = dense_data.mask(node_mask) | ||||
|         noisy_data = self.apply_noise(dense_data.X, dense_data.E, data.y, node_mask) | ||||
|         pred = self.forward(noisy_data) | ||||
|         nll = self.compute_val_loss(pred, noisy_data, dense_data.X, dense_data.E, data.y, node_mask, test=True) | ||||
|         self.test_y_collection.append(data.y) | ||||
|         return {'loss': nll} | ||||
|  | ||||
|     def on_test_epoch_end(self) -> None: | ||||
|         """ Measure likelihood on a test set and compute stability metrics. """ | ||||
|         metrics = [self.test_nll.compute(), self.test_X_kl.compute(), self.test_E_kl.compute(), | ||||
|                    self.test_X_logp.compute(), self.test_E_logp.compute()] | ||||
|  | ||||
|         print(f"Epoch {self.current_epoch}: Test NLL {metrics[0] :.2f} -- Test Atom type KL {metrics[1] :.2f} -- ", | ||||
|               f"Test Edge type KL: {metrics[2] :.2f}") | ||||
|  | ||||
|         ## final epoch | ||||
|         samples_left_to_generate = self.cfg.general.final_model_samples_to_generate | ||||
|         samples_left_to_save = self.cfg.general.final_model_samples_to_save | ||||
|         chains_left_to_save = self.cfg.general.final_model_chains_to_save | ||||
|  | ||||
|         samples, all_ys, batch_id = [], [], 0 | ||||
|         test_y_collection = torch.cat(self.test_y_collection, dim=0) | ||||
|         num_examples = test_y_collection.size(0) | ||||
|         if self.cfg.general.final_model_samples_to_generate > num_examples: | ||||
|             ratio = self.cfg.general.final_model_samples_to_generate // num_examples | ||||
|             test_y_collection = test_y_collection.repeat(ratio+1, 1) | ||||
|             num_examples = test_y_collection.size(0) | ||||
|          | ||||
|         while samples_left_to_generate > 0: | ||||
|             print(f'samples left to generate: {samples_left_to_generate}/' | ||||
|                 f'{self.cfg.general.final_model_samples_to_generate}', end='', flush=True) | ||||
|             bs = 1 * self.cfg.train.batch_size | ||||
|             to_generate = min(samples_left_to_generate, bs) | ||||
|             to_save = min(samples_left_to_save, bs) | ||||
|             chains_save = min(chains_left_to_save, bs) | ||||
|             batch_y = test_y_collection[batch_id : batch_id + to_generate] | ||||
|  | ||||
|             cur_sample = self.sample_batch(batch_id, to_generate, batch_y, save_final=to_save, | ||||
|                                             keep_chain=chains_save, number_chain_steps=self.number_chain_steps) | ||||
|             samples = samples + cur_sample | ||||
|              | ||||
|             all_ys.append(batch_y) | ||||
|             batch_id += to_generate | ||||
|  | ||||
|             samples_left_to_save -= to_save | ||||
|             samples_left_to_generate -= to_generate | ||||
|             chains_left_to_save -= chains_save | ||||
|              | ||||
|         print(f"final Computing sampling metrics...") | ||||
|         self.sampling_metrics.reset() | ||||
|         self.sampling_metrics(samples, all_ys, self.name, self.current_epoch, self.val_counter, test=True) | ||||
|         self.sampling_metrics.reset() | ||||
|         print(f"Done.") | ||||
|              | ||||
|  | ||||
|     def kl_prior(self, X, E, node_mask): | ||||
|         """Computes the KL between q(z1 | x) and the prior p(z1) = Normal(0, 1). | ||||
|  | ||||
|         This is essentially a lot of work for something that is in practice negligible in the loss. However, you | ||||
|         compute it so that you see it when you've made a mistake in your noise schedule. | ||||
|         """ | ||||
|         # Compute the last alpha value, alpha_T. | ||||
|         ones = torch.ones((X.size(0), 1), device=X.device) | ||||
|         Ts = self.T * ones | ||||
|         alpha_t_bar = self.noise_schedule.get_alpha_bar(t_int=Ts)  # (bs, 1) | ||||
|          | ||||
|         Qtb = self.transition_model.get_Qt_bar(alpha_t_bar, self.device) | ||||
|          | ||||
|         bs, n, d = X.shape | ||||
|         X_all = torch.cat([X, E.reshape(bs, n, -1)], dim=-1) | ||||
|         prob_all = X_all @ Qtb.X | ||||
|         probX = prob_all[:, :, :self.Xdim_output] | ||||
|         probE = prob_all[:, :, self.Xdim_output:].reshape((bs, n, n, -1)) | ||||
|  | ||||
|         assert probX.shape == X.shape | ||||
|  | ||||
|         limit_X = self.limit_dist.X[None, None, :].expand(bs, n, -1).type_as(probX) | ||||
|         limit_E = self.limit_dist.E[None, None, None, :].expand(bs, n, n, -1).type_as(probE) | ||||
|  | ||||
|         # Make sure that masked rows do not contribute to the loss | ||||
|         limit_dist_X, limit_dist_E, probX, probE = diffusion_utils.mask_distributions(true_X=limit_X.clone(), | ||||
|                                                                                       true_E=limit_E.clone(), | ||||
|                                                                                       pred_X=probX, | ||||
|                                                                                       pred_E=probE, | ||||
|                                                                                       node_mask=node_mask) | ||||
|  | ||||
|         kl_distance_X = F.kl_div(input=probX.log(), target=limit_dist_X, reduction='none') | ||||
|         kl_distance_E = F.kl_div(input=probE.log(), target=limit_dist_E, reduction='none') | ||||
|  | ||||
|         return diffusion_utils.sum_except_batch(kl_distance_X) + \ | ||||
|                diffusion_utils.sum_except_batch(kl_distance_E) | ||||
|  | ||||
|     def compute_Lt(self, X, E, y, pred, noisy_data, node_mask, test): | ||||
|         pred_probs_X = F.softmax(pred.X, dim=-1) | ||||
|         pred_probs_E = F.softmax(pred.E, dim=-1) | ||||
|  | ||||
|         Qtb = self.transition_model.get_Qt_bar(noisy_data['alpha_t_bar'], self.device) | ||||
|         Qsb = self.transition_model.get_Qt_bar(noisy_data['alpha_s_bar'], self.device) | ||||
|         Qt = self.transition_model.get_Qt(noisy_data['beta_t'], self.device) | ||||
|  | ||||
|         # Compute distributions to compare with KL | ||||
|         bs, n, d = X.shape | ||||
|         X_all = torch.cat([X, E.reshape(bs, n, -1)], dim=-1).float() | ||||
|         Xt_all = torch.cat([noisy_data['X_t'], noisy_data['E_t'].reshape(bs, n, -1)], dim=-1).float() | ||||
|         pred_probs_all = torch.cat([pred_probs_X, pred_probs_E.reshape(bs, n, -1)], dim=-1).float() | ||||
|  | ||||
|         prob_true = diffusion_utils.posterior_distributions(X=X_all, X_t=Xt_all, Qt=Qt, Qsb=Qsb, Qtb=Qtb, X_dim=self.Xdim_output) | ||||
|         prob_true.E = prob_true.E.reshape((bs, n, n, -1)) | ||||
|         prob_pred = diffusion_utils.posterior_distributions(X=pred_probs_all, X_t=Xt_all, Qt=Qt, Qsb=Qsb, Qtb=Qtb, X_dim=self.Xdim_output) | ||||
|         prob_pred.E = prob_pred.E.reshape((bs, n, n, -1)) | ||||
|  | ||||
|         # Reshape and filter masked rows | ||||
|         prob_true_X, prob_true_E, prob_pred.X, prob_pred.E = diffusion_utils.mask_distributions(true_X=prob_true.X, | ||||
|                                                                                                 true_E=prob_true.E, | ||||
|                                                                                                 pred_X=prob_pred.X, | ||||
|                                                                                                 pred_E=prob_pred.E, | ||||
|                                                                                                 node_mask=node_mask) | ||||
|         kl_x = (self.test_X_kl if test else self.val_X_kl)(prob_true.X, torch.log(prob_pred.X)) | ||||
|         kl_e = (self.test_E_kl if test else self.val_E_kl)(prob_true.E, torch.log(prob_pred.E)) | ||||
|  | ||||
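|         # Editor's note: t is drawn uniformly at random for each example, so scaling the | ||||
|         # single-step KL by T gives an unbiased estimate of the sum over all timesteps. | ||||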
|         return self.T * (kl_x + kl_e) | ||||
|  | ||||
|     def reconstruction_logp(self, t, X, E, y, node_mask): | ||||
|         # Compute noise values for t = 0. | ||||
|         t_zeros = torch.zeros_like(t) | ||||
|         beta_0 = self.noise_schedule(t_zeros) | ||||
|         Q0 = self.transition_model.get_Qt(beta_0, self.device) | ||||
|          | ||||
|         bs, n, d = X.shape | ||||
|         X_all = torch.cat([X, E.reshape(bs, n, -1)], dim=-1) | ||||
|         prob_all = X_all @ Q0.X | ||||
|         probX0 = prob_all[:, :, :self.Xdim_output] | ||||
|         probE0 = prob_all[:, :, self.Xdim_output:].reshape((bs, n, n, -1)) | ||||
|  | ||||
|         sampled0 = diffusion_utils.sample_discrete_features(probX=probX0, probE=probE0, node_mask=node_mask) | ||||
|  | ||||
|         X0 = F.one_hot(sampled0.X, num_classes=self.Xdim_output).float() | ||||
|         E0 = F.one_hot(sampled0.E, num_classes=self.Edim_output).float() | ||||
|  | ||||
|         assert (X.shape == X0.shape) and (E.shape == E0.shape) | ||||
|         sampled_0 = utils.PlaceHolder(X=X0, E=E0, y=y).mask(node_mask) | ||||
|  | ||||
|         # Predictions | ||||
|         noisy_data = {'X_t': sampled_0.X, 'E_t': sampled_0.E, 'y_t': sampled_0.y, 'node_mask': node_mask, | ||||
|                       't': torch.zeros(X0.shape[0], 1).type_as(y)} | ||||
|         pred0 = self.forward(noisy_data) | ||||
|  | ||||
|         # Normalize predictions | ||||
|         probX0 = F.softmax(pred0.X, dim=-1) | ||||
|         probE0 = F.softmax(pred0.E, dim=-1) | ||||
|         proby0 = None | ||||
|  | ||||
|         # Set masked rows to arbitrary values that don't contribute to loss | ||||
|         probX0[~node_mask] = torch.ones(self.Xdim_output).type_as(probX0) | ||||
|         probE0[~(node_mask.unsqueeze(1) * node_mask.unsqueeze(2))] = torch.ones(self.Edim_output).type_as(probE0) | ||||
|  | ||||
|         diag_mask = torch.eye(probE0.size(1)).type_as(probE0).bool() | ||||
|         diag_mask = diag_mask.unsqueeze(0).expand(probE0.size(0), -1, -1) | ||||
|         probE0[diag_mask] = torch.ones(self.Edim_output).type_as(probE0) | ||||
|  | ||||
|         return utils.PlaceHolder(X=probX0, E=probE0, y=proby0) | ||||
|  | ||||
|     def apply_noise(self, X, E, y, node_mask): | ||||
|         """ Sample noise and apply it to the data. """ | ||||
|  | ||||
|         # Sample a timestep t. | ||||
|         # When evaluating, the loss for t=0 is computed separately | ||||
|         lowest_t = 0 if self.training else 1 | ||||
|         t_int = torch.randint(lowest_t, self.T + 1, size=(X.size(0), 1), device=X.device).float()  # (bs, 1) | ||||
|         s_int = t_int - 1 | ||||
|  | ||||
|         t_float = t_int / self.T | ||||
|         s_float = s_int / self.T | ||||
|  | ||||
|         # beta_t and alpha_s_bar are used for denoising/loss computation | ||||
|         beta_t = self.noise_schedule(t_normalized=t_float)                         # (bs, 1) | ||||
|         alpha_s_bar = self.noise_schedule.get_alpha_bar(t_normalized=s_float)      # (bs, 1) | ||||
|         alpha_t_bar = self.noise_schedule.get_alpha_bar(t_normalized=t_float)      # (bs, 1) | ||||
|  | ||||
|         Qtb = self.transition_model.get_Qt_bar(alpha_t_bar, self.device)  # (bs, dx_in, dx_out), (bs, de_in, de_out) | ||||
|          | ||||
|         bs, n, d = X.shape | ||||
|         X_all = torch.cat([X, E.reshape(bs, n, -1)], dim=-1) | ||||
|         prob_all = X_all @ Qtb.X | ||||
|         probX = prob_all[:, :, :self.Xdim_output] | ||||
|         probE = prob_all[:, :, self.Xdim_output:].reshape(bs, n, n, -1) | ||||
|  | ||||
|         sampled_t = diffusion_utils.sample_discrete_features(probX=probX, probE=probE, node_mask=node_mask) | ||||
|  | ||||
|         X_t = F.one_hot(sampled_t.X, num_classes=self.Xdim_output) | ||||
|         E_t = F.one_hot(sampled_t.E, num_classes=self.Edim_output) | ||||
|         assert (X.shape == X_t.shape) and (E.shape == E_t.shape) | ||||
|  | ||||
|         y_t = y | ||||
|         z_t = utils.PlaceHolder(X=X_t, E=E_t, y=y_t).type_as(X_t).mask(node_mask) | ||||
|  | ||||
|         noisy_data = {'t_int': t_int, 't': t_float, 'beta_t': beta_t, 'alpha_s_bar': alpha_s_bar, | ||||
|                       'alpha_t_bar': alpha_t_bar, 'X_t': z_t.X, 'E_t': z_t.E, 'y_t': z_t.y, 'node_mask': node_mask} | ||||
|         return noisy_data | ||||
|  | ||||
|     def compute_val_loss(self, pred, noisy_data, X, E, y, node_mask, test=False): | ||||
|         """Computes an estimator for the variational lower bound. | ||||
|            pred: (batch_size, n, total_features) | ||||
|            noisy_data: dict | ||||
|            X, E, y : (bs, n, dx),  (bs, n, n, de), (bs, dy) | ||||
|            node_mask : (bs, n) | ||||
|            Output: nll (size 1) | ||||
|        """ | ||||
|         t = noisy_data['t'] | ||||
|  | ||||
|         # 1. Size term: log-probability of the number of nodes N. | ||||
|         N = node_mask.sum(1).long() | ||||
|         log_pN = self.node_dist.log_prob(N) | ||||
|  | ||||
|         # 2. The KL between q(z_T | x) and the limit distribution p(z_T). Should be close to zero. | ||||
|         kl_prior = self.kl_prior(X, E, node_mask) | ||||
|  | ||||
|         # 3. Diffusion loss | ||||
|         loss_all_t = self.compute_Lt(X, E, y, pred, noisy_data, node_mask, test) | ||||
|  | ||||
|         # 4. Reconstruction loss | ||||
|         # Compute L0 term : -log p (X, E, y | z_0) = reconstruction loss | ||||
|         prob0 = self.reconstruction_logp(t, X, E, y, node_mask) | ||||
|  | ||||
|         eps = 1e-8 | ||||
|         loss_term_0 = self.val_X_logp(X * (prob0.X+eps).log()) + self.val_E_logp(E * (prob0.E+eps).log()) | ||||
|  | ||||
|         # Combine terms | ||||
|         nlls = - log_pN + kl_prior + loss_all_t - loss_term_0 | ||||
|         assert len(nlls.shape) == 1, f'{nlls.shape} has more than only batch dim.' | ||||
|  | ||||
|         # Update NLL metric object and return batch nll | ||||
|         nll = (self.test_nll if test else self.val_nll)(nlls)        # Average over the batch | ||||
|          | ||||
|         return nll | ||||
|      | ||||
|     @torch.no_grad() | ||||
|     def sample_batch(self, batch_id, batch_size, y, keep_chain, number_chain_steps, save_final, num_nodes=None): | ||||
|         """ | ||||
|         :param batch_id: int | ||||
|         :param batch_size: int | ||||
|         :param num_nodes: int, <int>tensor (batch_size) (optional) for specifying number of nodes | ||||
|         :param save_final: int: number of predictions to save to file | ||||
|         :param keep_chain: int: number of chains to save to file (disabled) | ||||
|         :param number_chain_steps: number of timesteps to save for each chain (disabled) | ||||
|         :return: molecule_list. Each element of this list is a tuple (atom_types, charges, positions) | ||||
|         """ | ||||
|         if num_nodes is None: | ||||
|             n_nodes = self.node_dist.sample_n(batch_size, self.device) | ||||
|         elif type(num_nodes) == int: | ||||
|             n_nodes = num_nodes * torch.ones(batch_size, device=self.device, dtype=torch.int) | ||||
|         else: | ||||
|             assert isinstance(num_nodes, torch.Tensor) | ||||
|             n_nodes = num_nodes | ||||
|         n_max = self.max_n_nodes | ||||
|         arange = torch.arange(n_max, device=self.device).unsqueeze(0).expand(batch_size, -1) | ||||
|         node_mask = arange < n_nodes.unsqueeze(1) | ||||
|          | ||||
|         z_T = diffusion_utils.sample_discrete_feature_noise(limit_dist=self.limit_dist, node_mask=node_mask) | ||||
|         X, E = z_T.X, z_T.E | ||||
|  | ||||
|         assert (E == torch.transpose(E, 1, 2)).all() | ||||
|  | ||||
|         # Iteratively sample p(z_s | z_t) for t = 1, ..., T, with s = t - 1. | ||||
|         for s_int in reversed(range(0, self.T)): | ||||
|             s_array = s_int * torch.ones((batch_size, 1)).type_as(y) | ||||
|             t_array = s_array + 1 | ||||
|             s_norm = s_array / self.T | ||||
|             t_norm = t_array / self.T | ||||
|  | ||||
|             # Sample z_s | ||||
|             sampled_s, discrete_sampled_s = self.sample_p_zs_given_zt(s_norm, t_norm, X, E, y, node_mask) | ||||
|             X, E, y = sampled_s.X, sampled_s.E, sampled_s.y | ||||
|  | ||||
|         # Collapse the final one-hot sample to discrete node and edge types under the node mask | ||||
|         sampled_s = sampled_s.mask(node_mask, collapse=True) | ||||
|         X, E, y = sampled_s.X, sampled_s.E, sampled_s.y | ||||
|          | ||||
|         molecule_list = [] | ||||
|         for i in range(batch_size): | ||||
|             n = n_nodes[i] | ||||
|             atom_types = X[i, :n].cpu() | ||||
|             edge_types = E[i, :n, :n].cpu() | ||||
|             molecule_list.append([atom_types, edge_types]) | ||||
|          | ||||
|         return molecule_list | ||||
|  | ||||
|     def sample_p_zs_given_zt(self, s, t, X_t, E_t, y_t, node_mask): | ||||
|         """Samples from zs ~ p(zs | zt). Only used during sampling. | ||||
|            Returns the one-hot sample and its mask-collapsed discrete version.""" | ||||
|         bs, n, dxs = X_t.shape | ||||
|         beta_t = self.noise_schedule(t_normalized=t)  # (bs, 1) | ||||
|         alpha_s_bar = self.noise_schedule.get_alpha_bar(t_normalized=s) | ||||
|         alpha_t_bar = self.noise_schedule.get_alpha_bar(t_normalized=t) | ||||
|  | ||||
|         # Neural net predictions | ||||
|         noisy_data = {'X_t': X_t, 'E_t': E_t, 'y_t': y_t, 't': t, 'node_mask': node_mask} | ||||
|          | ||||
|         def get_prob(noisy_data, unconditioned=False): | ||||
|             pred = self.forward(noisy_data, unconditioned=unconditioned) | ||||
|  | ||||
|             # Normalize predictions | ||||
|             pred_X = F.softmax(pred.X, dim=-1)               # bs, n, d0 | ||||
|             pred_E = F.softmax(pred.E, dim=-1)               # bs, n, n, d0 | ||||
|  | ||||
|             # Retrieve transition matrices | ||||
|             Qtb = self.transition_model.get_Qt_bar(alpha_t_bar, self.device) | ||||
|             Qsb = self.transition_model.get_Qt_bar(alpha_s_bar, self.device) | ||||
|             Qt = self.transition_model.get_Qt(beta_t, self.device) | ||||
|  | ||||
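|             # Concatenate node and edge one-hots so the posterior terms for both are computed in one batched call | ||||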
|             Xt_all = torch.cat([X_t, E_t.reshape(bs, n, -1)], dim=-1) | ||||
|             p_s_and_t_given_0 = diffusion_utils.compute_batched_over0_posterior_distribution(X_t=Xt_all, | ||||
|                                                                                             Qt=Qt.X, | ||||
|                                                                                             Qsb=Qsb.X, | ||||
|                                                                                             Qtb=Qtb.X) | ||||
|             predX_all = torch.cat([pred_X, pred_E.reshape(bs, n, -1)], dim=-1) | ||||
|             weightedX_all = predX_all.unsqueeze(-1) * p_s_and_t_given_0 | ||||
|             unnormalized_probX_all = weightedX_all.sum(dim=2)                     # bs, n, d_t-1 | ||||
|  | ||||
|             unnormalized_prob_X = unnormalized_probX_all[:, :, :self.Xdim_output] | ||||
|             unnormalized_prob_E = unnormalized_probX_all[:, :, self.Xdim_output:].reshape(bs, n*n, -1) | ||||
|  | ||||
|             unnormalized_prob_X[torch.sum(unnormalized_prob_X, dim=-1) == 0] = 1e-5 | ||||
|             unnormalized_prob_E[torch.sum(unnormalized_prob_E, dim=-1) == 0] = 1e-5 | ||||
|  | ||||
|             prob_X = unnormalized_prob_X / torch.sum(unnormalized_prob_X, dim=-1, keepdim=True)  # bs, n, d_t-1 | ||||
|             prob_E = unnormalized_prob_E / torch.sum(unnormalized_prob_E, dim=-1, keepdim=True)  # bs, n * n, d_t-1 | ||||
|             prob_E = prob_E.reshape(bs, n, n, pred_E.shape[-1]) | ||||
|  | ||||
|             return prob_X, prob_E | ||||
|  | ||||
|         prob_X, prob_E = get_prob(noisy_data) | ||||
|  | ||||
|         ### Guidance | ||||
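|         # Classifier-free guidance: p_guided = p_uncond * (p_cond / p_uncond) ** guide_scale, renormalized below | ||||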
|         if self.guidance_target is not None and self.guide_scale is not None and self.guide_scale != 1: | ||||
|             uncon_prob_X, uncon_prob_E = get_prob(noisy_data, unconditioned=True) | ||||
|             prob_X = uncon_prob_X *  (prob_X / uncon_prob_X.clamp_min(1e-10)) ** self.guide_scale   | ||||
|             prob_E = uncon_prob_E * (prob_E / uncon_prob_E.clamp_min(1e-10)) ** self.guide_scale   | ||||
|             prob_X = prob_X / prob_X.sum(dim=-1, keepdim=True).clamp_min(1e-10) | ||||
|             prob_E = prob_E / prob_E.sum(dim=-1, keepdim=True).clamp_min(1e-10) | ||||
|  | ||||
|         assert ((prob_X.sum(dim=-1) - 1).abs() < 1e-4).all() | ||||
|         assert ((prob_E.sum(dim=-1) - 1).abs() < 1e-4).all() | ||||
|  | ||||
|         sampled_s = diffusion_utils.sample_discrete_features(prob_X, prob_E, node_mask=node_mask, step=s[0,0].item()) | ||||
|  | ||||
|         X_s = F.one_hot(sampled_s.X, num_classes=self.Xdim_output).float() | ||||
|         E_s = F.one_hot(sampled_s.E, num_classes=self.Edim_output).float() | ||||
|  | ||||
|         assert (E_s == torch.transpose(E_s, 1, 2)).all() | ||||
|         assert (X_t.shape == X_s.shape) and (E_t.shape == E_s.shape) | ||||
|  | ||||
|         out_one_hot = utils.PlaceHolder(X=X_s, E=E_s, y=y_t) | ||||
|         out_discrete = utils.PlaceHolder(X=X_s, E=E_s, y=y_t) | ||||
|  | ||||
|         return out_one_hot.mask(node_mask).type_as(y_t), out_discrete.mask(node_mask, collapse=True).type_as(y_t) | ||||
							
								
								
									
138  graph_dit/main.py  Normal file
| @@ -0,0 +1,138 @@ | ||||
| # These imports are tricky because they use C++; do not move them | ||||
| import os, shutil | ||||
| import warnings | ||||
|  | ||||
| import torch | ||||
| import hydra | ||||
| from omegaconf import DictConfig | ||||
| from pytorch_lightning import Trainer | ||||
|  | ||||
| import utils | ||||
| from datasets import dataset | ||||
| from diffusion_model import Graph_DiT | ||||
| from metrics.molecular_metrics_train import TrainMolecularMetricsDiscrete | ||||
| from metrics.molecular_metrics_sampling import SamplingMolecularMetrics | ||||
|  | ||||
| from analysis.visualization import MolecularVisualization | ||||
|  | ||||
| warnings.filterwarnings("ignore", category=UserWarning) | ||||
| torch.set_float32_matmul_precision("medium") | ||||
|  | ||||
| def remove_folder(folder): | ||||
|     for filename in os.listdir(folder): | ||||
|         file_path = os.path.join(folder, filename) | ||||
|         try: | ||||
|             if os.path.isfile(file_path) or os.path.islink(file_path): | ||||
|                 os.unlink(file_path) | ||||
|             elif os.path.isdir(file_path): | ||||
|                 shutil.rmtree(file_path) | ||||
|         except Exception as e: | ||||
|             print("Failed to delete %s. Reason: %s" % (file_path, e)) | ||||
|  | ||||
|  | ||||
| def get_resume(cfg, model_kwargs): | ||||
|     """Resumes a run. It loads the previous config without allowing keys to be updated (used for testing).""" | ||||
|     saved_cfg = cfg.copy() | ||||
|     name = cfg.general.name + "_resume" | ||||
|     resume = cfg.general.test_only | ||||
|     batch_size = cfg.train.batch_size | ||||
|     model = Graph_DiT.load_from_checkpoint(resume, **model_kwargs) | ||||
|     cfg = model.cfg | ||||
|     cfg.general.test_only = resume | ||||
|     cfg.general.name = name | ||||
|     cfg.train.batch_size = batch_size | ||||
|     cfg = utils.update_config_with_new_keys(cfg, saved_cfg) | ||||
|     return cfg, model | ||||
|  | ||||
| def get_resume_adaptive(cfg, model_kwargs): | ||||
|     """Resumes a run. It loads the previous config but allows some keys to be overridden (used for resuming training).""" | ||||
|     saved_cfg = cfg.copy() | ||||
|     # Fetch path to this file to get base path | ||||
|     current_path = os.path.dirname(os.path.realpath(__file__)) | ||||
|     root_dir = current_path.split("outputs")[0] | ||||
|  | ||||
|     resume_path = os.path.join(root_dir, cfg.general.resume) | ||||
|  | ||||
|     if cfg.model.type == "discrete": | ||||
|         model = Graph_DiT.load_from_checkpoint( | ||||
|             resume_path, **model_kwargs | ||||
|         ) | ||||
|     else: | ||||
|         raise NotImplementedError("Unknown model") | ||||
|  | ||||
|     new_cfg = model.cfg | ||||
|     for category in cfg: | ||||
|         for arg in cfg[category]: | ||||
|             new_cfg[category][arg] = cfg[category][arg] | ||||
|  | ||||
|     new_cfg.general.resume = resume_path | ||||
|     new_cfg.general.name = new_cfg.general.name + "_resume" | ||||
|  | ||||
|     new_cfg = utils.update_config_with_new_keys(new_cfg, saved_cfg) | ||||
|     return new_cfg, model | ||||
|  | ||||
|  | ||||
| @hydra.main( | ||||
|     version_base="1.1", config_path="../configs", config_name="config" | ||||
| ) | ||||
| def main(cfg: DictConfig): | ||||
|  | ||||
|     datamodule = dataset.DataModule(cfg) | ||||
|     datamodule.prepare_data() | ||||
|     dataset_infos = dataset.DataInfos(datamodule=datamodule, cfg=cfg) | ||||
|     train_smiles, reference_smiles = datamodule.get_train_smiles() | ||||
|  | ||||
|     dataset_infos.compute_input_output_dims(datamodule=datamodule) | ||||
|     train_metrics = TrainMolecularMetricsDiscrete(dataset_infos) | ||||
|  | ||||
|     sampling_metrics = SamplingMolecularMetrics( | ||||
|         dataset_infos, train_smiles, reference_smiles | ||||
|     ) | ||||
|     visualization_tools = MolecularVisualization(dataset_infos) | ||||
|  | ||||
|     model_kwargs = { | ||||
|         "dataset_infos": dataset_infos, | ||||
|         "train_metrics": train_metrics, | ||||
|         "sampling_metrics": sampling_metrics, | ||||
|         "visualization_tools": visualization_tools, | ||||
|     } | ||||
|  | ||||
|     if cfg.general.test_only: | ||||
|         # When testing, previous configuration is fully loaded | ||||
|         cfg, _ = get_resume(cfg, model_kwargs) | ||||
|         os.chdir(cfg.general.test_only.split("checkpoints")[0]) | ||||
|     elif cfg.general.resume is not None: | ||||
|         # When resuming, we can override some parts of previous configuration | ||||
|         cfg, _ = get_resume_adaptive(cfg, model_kwargs) | ||||
|         os.chdir(cfg.general.resume.split("checkpoints")[0]) | ||||
|  | ||||
|     model = Graph_DiT(cfg=cfg, **model_kwargs) | ||||
|     trainer = Trainer( | ||||
|         gradient_clip_val=cfg.train.clip_grad, | ||||
|         accelerator="gpu" | ||||
|         if torch.cuda.is_available() and cfg.general.gpus > 0 | ||||
|         else "cpu", | ||||
|         devices=cfg.general.gpus | ||||
|         if torch.cuda.is_available() and cfg.general.gpus > 0 | ||||
|         else None, | ||||
|         max_epochs=cfg.train.n_epochs, | ||||
|         enable_checkpointing=False, | ||||
|         check_val_every_n_epoch=cfg.train.check_val_every_n_epoch, | ||||
|         val_check_interval=cfg.train.val_check_interval, | ||||
|         strategy="ddp" if cfg.general.gpus > 1 else "auto", | ||||
|         enable_progress_bar=cfg.general.enable_progress_bar, | ||||
|         callbacks=[], | ||||
|         reload_dataloaders_every_n_epochs=0, | ||||
|         logger=[], | ||||
|     ) | ||||
|  | ||||
|     if not cfg.general.test_only: | ||||
|         trainer.fit(model, datamodule=datamodule, ckpt_path=cfg.general.resume) | ||||
|         if cfg.general.save_model: | ||||
|             trainer.save_checkpoint(f"checkpoints/{cfg.general.name}/last.ckpt") | ||||
|         trainer.test(model, datamodule=datamodule) | ||||
|     else: | ||||
|         trainer.test(model, datamodule=datamodule, ckpt_path=cfg.general.test_only) | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     main() | ||||
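|  | ||||
| # Example launch (overrides are illustrative; Hydra composes the config from ../configs/config.yaml): | ||||
| #   python graph_dit/main.py general.gpus=1 train.n_epochs=100 | ||||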
							
								
								
									
0  graph_dit/metrics/__init__.py  Normal file

138  graph_dit/metrics/abstract_metrics.py  Normal file
| @@ -0,0 +1,138 @@ | ||||
| import torch | ||||
| from torch import Tensor | ||||
| from torch.nn import functional as F | ||||
| from torchmetrics import Metric, MeanSquaredError | ||||
|  | ||||
|  | ||||
| class TrainAbstractMetricsDiscrete(torch.nn.Module): | ||||
|     def __init__(self): | ||||
|         super().__init__() | ||||
|  | ||||
|     def forward(self, masked_pred_X, masked_pred_E, true_X, true_E, log: bool): | ||||
|         pass | ||||
|  | ||||
|     def reset(self): | ||||
|         pass | ||||
|  | ||||
|     def log_epoch_metrics(self, current_epoch): | ||||
|         pass | ||||
|  | ||||
|  | ||||
| class TrainAbstractMetrics(torch.nn.Module): | ||||
|     def __init__(self): | ||||
|         super().__init__() | ||||
|  | ||||
|     def forward(self, masked_pred_epsX, masked_pred_epsE, pred_y, true_epsX, true_epsE, true_y, log): | ||||
|         pass | ||||
|  | ||||
|     def reset(self): | ||||
|         pass | ||||
|  | ||||
|     def log_epoch_metrics(self, current_epoch): | ||||
|         pass | ||||
|  | ||||
|  | ||||
| class SumExceptBatchMetric(Metric): | ||||
|     def __init__(self): | ||||
|         super().__init__() | ||||
|         self.add_state('total_value', default=torch.tensor(0.), dist_reduce_fx="sum") | ||||
|         self.add_state('total_samples', default=torch.tensor(0.), dist_reduce_fx="sum") | ||||
|  | ||||
|     def update(self, values) -> None: | ||||
|         self.total_value += torch.sum(values) | ||||
|         self.total_samples += values.shape[0] | ||||
|  | ||||
|     def compute(self): | ||||
|         return self.total_value / self.total_samples | ||||
|  | ||||
|  | ||||
| class SumExceptBatchMSE(MeanSquaredError): | ||||
|     def update(self, preds: Tensor, target: Tensor) -> None: | ||||
|         """Update state with predictions and targets. | ||||
|  | ||||
|         Args: | ||||
|             preds: Predictions from model | ||||
|             target: Ground truth values | ||||
|         """ | ||||
|         assert preds.shape == target.shape | ||||
|         sum_squared_error, n_obs = self._mean_squared_error_update(preds, target) | ||||
|  | ||||
|         self.sum_squared_error += sum_squared_error | ||||
|         self.total += n_obs | ||||
|  | ||||
|     def _mean_squared_error_update(self, preds: Tensor, target: Tensor): | ||||
|         """Updates and returns variables required to compute Mean Squared Error. | ||||
|         Checks for same shape of input tensors. | ||||
|             preds: Predicted tensor | ||||
|             target: Ground truth tensor | ||||
|         """ | ||||
|         diff = preds - target | ||||
|         sum_squared_error = torch.sum(diff * diff) | ||||
|         n_obs = preds.shape[0] | ||||
|         return sum_squared_error, n_obs | ||||
|  | ||||
|  | ||||
| class SumExceptBatchKL(Metric): | ||||
|     def __init__(self): | ||||
|         super().__init__() | ||||
|         self.add_state('total_value', default=torch.tensor(0.), dist_reduce_fx="sum") | ||||
|         self.add_state('total_samples', default=torch.tensor(0.), dist_reduce_fx="sum") | ||||
|  | ||||
|     def update(self, p, q) -> None: | ||||
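|         # F.kl_div(input, target) expects input (here q) as log-probabilities and target (here p) as probabilities | ||||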
|         self.total_value += F.kl_div(q, p, reduction='sum') | ||||
|         self.total_samples += p.size(0) | ||||
|  | ||||
|     def compute(self): | ||||
|         return self.total_value / self.total_samples | ||||
|  | ||||
|  | ||||
| class CrossEntropyMetric(Metric): | ||||
|     def __init__(self): | ||||
|         super().__init__() | ||||
|         self.add_state('total_ce', default=torch.tensor(0.), dist_reduce_fx="sum") | ||||
|         self.add_state('total_samples', default=torch.tensor(0.), dist_reduce_fx="sum") | ||||
|  | ||||
|     def update(self, preds: Tensor, target: Tensor, weight=None) -> None: | ||||
|         """ Update state with predictions and targets. | ||||
|             preds: Predictions from model   (bs * n, d) or (bs * n * n, d) | ||||
|             target: Ground truth values     (bs * n, d) or (bs * n * n, d). """ | ||||
|         target = torch.argmax(target, dim=-1) | ||||
|         if weight is not None: | ||||
|             weight = weight.type_as(preds) | ||||
|             output = F.cross_entropy(preds, target, weight = weight, reduction='sum') | ||||
|         else: | ||||
|             output = F.cross_entropy(preds, target, reduction='sum') | ||||
|         self.total_ce += output | ||||
|         self.total_samples += preds.size(0) | ||||
|  | ||||
|     def compute(self): | ||||
|         return self.total_ce / self.total_samples | ||||
|  | ||||
|  | ||||
| class ProbabilityMetric(Metric): | ||||
|     def __init__(self): | ||||
|         """ This metric is used to track the marginal predicted probability of a class during training. """ | ||||
|         super().__init__() | ||||
|         self.add_state('prob', default=torch.tensor(0.), dist_reduce_fx="sum") | ||||
|         self.add_state('total', default=torch.tensor(0.), dist_reduce_fx="sum") | ||||
|  | ||||
|     def update(self, preds: Tensor) -> None: | ||||
|         self.prob += preds.sum() | ||||
|         self.total += preds.numel() | ||||
|  | ||||
|     def compute(self): | ||||
|         return self.prob / self.total | ||||
|  | ||||
|  | ||||
| class NLL(Metric): | ||||
|     def __init__(self): | ||||
|         super().__init__() | ||||
|         self.add_state('total_nll', default=torch.tensor(0.), dist_reduce_fx="sum") | ||||
|         self.add_state('total_samples', default=torch.tensor(0.), dist_reduce_fx="sum") | ||||
|  | ||||
|     def update(self, batch_nll) -> None: | ||||
|         self.total_nll += torch.sum(batch_nll) | ||||
|         self.total_samples += batch_nll.numel() | ||||
|  | ||||
|     def compute(self): | ||||
|         return self.total_nll / self.total_samples | ||||
							
								
								
									
										
BIN  graph_dit/metrics/fpscores.pkl.gz  Normal file  (Binary file not shown.)

138  graph_dit/metrics/molecular_metrics_sampling.py  Normal file
| @@ -0,0 +1,138 @@ | ||||
| ### packages for visualization | ||||
| from analysis.rdkit_functions import compute_molecular_metrics | ||||
| from mini_moses.metrics.metrics import compute_intermediate_statistics | ||||
| from metrics.property_metric import TaskModel | ||||
|  | ||||
| import torch | ||||
| import torch.nn as nn | ||||
|  | ||||
| import os | ||||
| import csv | ||||
| import time | ||||
|  | ||||
| def result_to_csv(path, dict_data): | ||||
|     file_exists = os.path.exists(path) | ||||
|     log_name = dict_data.pop("log_name", None) | ||||
|     if log_name is None: | ||||
|         raise ValueError("The provided dictionary must contain a 'log_name' key.") | ||||
|     field_names = ["log_name"] + list(dict_data.keys()) | ||||
|     dict_data["log_name"] = log_name | ||||
|     with open(path, "a", newline="") as file: | ||||
|         writer = csv.DictWriter(file, fieldnames=field_names) | ||||
|         if not file_exists: | ||||
|             writer.writeheader() | ||||
|         writer.writerow(dict_data) | ||||
|  | ||||
|  | ||||
| class SamplingMolecularMetrics(nn.Module): | ||||
|     def __init__( | ||||
|         self, | ||||
|         dataset_infos, | ||||
|         train_smiles, | ||||
|         reference_smiles, | ||||
|         n_jobs=1, | ||||
|         device="cpu", | ||||
|         batch_size=512, | ||||
|     ): | ||||
|         super().__init__() | ||||
|         self.task_name = dataset_infos.task | ||||
|         self.dataset_infos = dataset_infos | ||||
|         self.active_atoms = dataset_infos.active_atoms | ||||
|         self.train_smiles = train_smiles | ||||
|  | ||||
|         if reference_smiles is not None: | ||||
|             print( | ||||
|                 f"--- Computing intermediate statistics over {len(reference_smiles)} reference SMILES ---" | ||||
|             ) | ||||
|             start_time = time.time() | ||||
|             self.stat_ref = compute_intermediate_statistics( | ||||
|                 reference_smiles, n_jobs=n_jobs, device=device, batch_size=batch_size | ||||
|             ) | ||||
|             end_time = time.time() | ||||
|             elapsed_time = end_time - start_time | ||||
|             print( | ||||
|                 f"--- Finished computing intermediate statistics in {elapsed_time:.2f}s ---" | ||||
|             ) | ||||
|         else: | ||||
|             self.stat_ref = None | ||||
|      | ||||
|         self.comput_config = { | ||||
|             "n_jobs": n_jobs, | ||||
|             "device": device, | ||||
|             "batch_size": batch_size, | ||||
|         } | ||||
|  | ||||
|         self.task_evaluator = {'meta_taskname': dataset_infos.task, 'sas': None, 'scs': None} | ||||
|         for cur_task in dataset_infos.task.split("-")[:]: | ||||
|             # print('loading evaluator for task', cur_task) | ||||
|             model_path = os.path.join( | ||||
|                 dataset_infos.base_path, "data/evaluator", f"{cur_task}.joblib" | ||||
|             ) | ||||
|             os.makedirs(os.path.dirname(model_path), exist_ok=True) | ||||
|             evaluator = TaskModel(model_path, cur_task) | ||||
|             self.task_evaluator[cur_task] = evaluator | ||||
|  | ||||
|     def forward(self, molecules, targets, name, current_epoch, val_counter, test=False): | ||||
|         if isinstance(targets, list): | ||||
|             targets_cat = torch.cat(targets, dim=0) | ||||
|             targets_np = targets_cat.detach().cpu().numpy() | ||||
|         else: | ||||
|             targets_np = targets.detach().cpu().numpy() | ||||
|  | ||||
|         unique_smiles, all_smiles, all_metrics, targets_log = compute_molecular_metrics( | ||||
|             molecules, | ||||
|             targets_np, | ||||
|             self.train_smiles, | ||||
|             self.stat_ref, | ||||
|             self.dataset_infos, | ||||
|             self.task_evaluator, | ||||
|             self.comput_config, | ||||
|         ) | ||||
|  | ||||
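|         # Test mode: write every generated SMILES with its input/output property values; otherwise save the unique valid molecules of this validation round | ||||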
|         if test: | ||||
|             file_name = "final_smiles.txt" | ||||
|             with open(file_name, "w") as fp: | ||||
|                 all_tasks_name = list(self.task_evaluator.keys()) | ||||
|                 all_tasks_name = all_tasks_name.copy() | ||||
|                 if 'meta_taskname' in all_tasks_name: | ||||
|                     all_tasks_name.remove('meta_taskname') | ||||
|                 if 'scs' in all_tasks_name: | ||||
|                     all_tasks_name.remove('scs') | ||||
|  | ||||
|                 all_tasks_str = "smiles, " + ", ".join([f"input_{task}" for task in all_tasks_name] + [f"output_{task}" for task in all_tasks_name]) | ||||
|                 fp.write(all_tasks_str + "\n") | ||||
|                 for i, smiles in enumerate(all_smiles): | ||||
|                     if targets_log is not None: | ||||
|                         all_result_str = f"{smiles}, " + ", ".join([f"{targets_log['input_'+task][i]}" for task in all_tasks_name] + [f"{targets_log['output_'+task][i]}" for task in all_tasks_name]) | ||||
|                         fp.write(all_result_str + "\n") | ||||
|                     else: | ||||
|                         fp.write("%s\n" % smiles) | ||||
|                 print("All smiles saved") | ||||
|         else: | ||||
|             result_path = os.path.join(os.getcwd(), f"graphs/{name}") | ||||
|             os.makedirs(result_path, exist_ok=True) | ||||
|             text_path = os.path.join( | ||||
|                 result_path, | ||||
|                 f"valid_unique_molecules_e{current_epoch}_b{val_counter}.txt", | ||||
|             ) | ||||
|             textfile = open(text_path, "w") | ||||
|             for smiles in unique_smiles: | ||||
|                 textfile.write(smiles + "\n") | ||||
|             textfile.close() | ||||
|  | ||||
|         all_logs = all_metrics | ||||
|         if test: | ||||
|             all_logs["log_name"] = "test" | ||||
|         else: | ||||
|             all_logs["log_name"] = ( | ||||
|                 "epoch" + str(current_epoch) + "_batch" + str(val_counter) | ||||
|             ) | ||||
|          | ||||
|         result_to_csv("output.csv", all_logs) | ||||
|         return all_smiles | ||||
|  | ||||
|     def reset(self): | ||||
|         pass | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     pass | ||||
							
								
								
									
126  graph_dit/metrics/molecular_metrics_train.py  Normal file
| @@ -0,0 +1,126 @@ | ||||
| import torch | ||||
| from torchmetrics import Metric, MetricCollection | ||||
| from torch import Tensor | ||||
| import torch.nn as nn | ||||
|  | ||||
| class CEPerClass(Metric): | ||||
|     full_state_update = False | ||||
|     def __init__(self, class_id): | ||||
|         super().__init__() | ||||
|         self.class_id = class_id | ||||
|         self.add_state('total_ce', default=torch.tensor(0.), dist_reduce_fx="sum") | ||||
|         self.add_state('total_samples', default=torch.tensor(0.), dist_reduce_fx="sum") | ||||
|         self.softmax = torch.nn.Softmax(dim=-1) | ||||
|         self.binary_cross_entropy = torch.nn.BCELoss(reduction='sum') | ||||
|      | ||||
|     def update(self, preds: Tensor, target: Tensor) -> None: | ||||
|         """Update state with predictions and targets. | ||||
|         Args: | ||||
|             preds: Predictions from model   (bs, n, d) or (bs, n, n, d) | ||||
|             target: Ground truth values     (bs, n, d) or (bs, n, n, d) | ||||
|         """ | ||||
|         target = target.reshape(-1, target.shape[-1]) | ||||
|         mask = (target != 0.).any(dim=-1) | ||||
|  | ||||
|         prob = self.softmax(preds)[..., self.class_id] | ||||
|         prob = prob.flatten()[mask] | ||||
|  | ||||
|         target = target[:, self.class_id] | ||||
|         target = target[mask] | ||||
|  | ||||
|         output = self.binary_cross_entropy(prob, target) | ||||
|  | ||||
|         self.total_ce += output | ||||
|         self.total_samples += prob.numel() | ||||
|  | ||||
|     def compute(self): | ||||
|         return self.total_ce / self.total_samples | ||||
|  | ||||
|  | ||||
| class AtomCE(CEPerClass): | ||||
|     def __init__(self, i): | ||||
|         super().__init__(i) | ||||
|  | ||||
| class NoBondCE(CEPerClass): | ||||
|     def __init__(self, i): | ||||
|         super().__init__(i) | ||||
|  | ||||
|  | ||||
| class SingleCE(CEPerClass): | ||||
|     def __init__(self, i): | ||||
|         super().__init__(i) | ||||
|  | ||||
|  | ||||
| class DoubleCE(CEPerClass): | ||||
|     def __init__(self, i): | ||||
|         super().__init__(i) | ||||
|  | ||||
|  | ||||
| class TripleCE(CEPerClass): | ||||
|     def __init__(self, i): | ||||
|         super().__init__(i) | ||||
|  | ||||
|  | ||||
| class AromaticCE(CEPerClass): | ||||
|     def __init__(self, i): | ||||
|         super().__init__(i) | ||||
|  | ||||
|  | ||||
| class AtomMetricsCE(MetricCollection): | ||||
|     def __init__(self, active_atoms): | ||||
|         metrics_list = [] | ||||
|          | ||||
|         for i, atom_type in enumerate(active_atoms): | ||||
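|             # Dynamically subclass AtomCE so each atom type gets a distinctly named metric in the collection | ||||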
|             metrics_list.append(type(f'{atom_type}_CE', (AtomCE,), {})(i)) | ||||
|  | ||||
|         super().__init__(metrics_list) | ||||
|  | ||||
|  | ||||
| class BondMetricsCE(MetricCollection): | ||||
|     def __init__(self): | ||||
|         ce_no_bond = NoBondCE(0) | ||||
|         ce_SI = SingleCE(1) | ||||
|         ce_DO = DoubleCE(2) | ||||
|         ce_TR = TripleCE(3) | ||||
|         super().__init__([ce_no_bond, ce_SI, ce_DO, ce_TR]) | ||||
|  | ||||
|  | ||||
| class TrainMolecularMetricsDiscrete(nn.Module): | ||||
|     def __init__(self, dataset_infos): | ||||
|         super().__init__() | ||||
|         active_atoms = dataset_infos.active_atoms | ||||
|         self.train_atom_metrics = AtomMetricsCE(active_atoms=active_atoms) | ||||
|         self.train_bond_metrics = BondMetricsCE() | ||||
|  | ||||
|     def forward(self, masked_pred_X, masked_pred_E, true_X, true_E, log: bool): | ||||
|         self.train_atom_metrics(masked_pred_X, true_X) | ||||
|         self.train_bond_metrics(masked_pred_E, true_E) | ||||
|         if log: | ||||
|             to_log = {} | ||||
|             for key, val in self.train_atom_metrics.compute().items(): | ||||
|                 to_log['train/' + key] = val.item() | ||||
|             for key, val in self.train_bond_metrics.compute().items(): | ||||
|                 to_log['train/' + key] = val.item() | ||||
|  | ||||
|     def reset(self): | ||||
|         for metric in [self.train_atom_metrics, self.train_bond_metrics]: | ||||
|             metric.reset() | ||||
|  | ||||
|     def log_epoch_metrics(self, current_epoch, log=True): | ||||
|         epoch_atom_metrics = self.train_atom_metrics.compute() | ||||
|         epoch_bond_metrics = self.train_bond_metrics.compute() | ||||
|  | ||||
|         to_log = {} | ||||
|         for key, val in epoch_atom_metrics.items(): | ||||
|             to_log['train_epoch/' + key] = val.item() | ||||
|         for key, val in epoch_bond_metrics.items(): | ||||
|             to_log['train_epoch/' + key] = val.item() | ||||
|  | ||||
|         for key, val in epoch_atom_metrics.items(): | ||||
|             epoch_atom_metrics[key] = round(val.item(),4) | ||||
|         for key, val in epoch_bond_metrics.items(): | ||||
|             epoch_bond_metrics[key] = round(val.item(),4) | ||||
|  | ||||
|         if log: | ||||
|             print(f"Epoch {current_epoch}: {epoch_atom_metrics} -- {epoch_bond_metrics}") | ||||
|  | ||||
							
								
								
									
201  graph_dit/metrics/property_metric.py  Normal file
| @@ -0,0 +1,201 @@ | ||||
| import math, os | ||||
| import pickle | ||||
| import os.path as op | ||||
|  | ||||
| import numpy as np | ||||
| import pandas as pd | ||||
| from joblib import dump, load | ||||
| from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor | ||||
| from sklearn.metrics import mean_absolute_error, roc_auc_score | ||||
|  | ||||
|  | ||||
| from rdkit import Chem | ||||
| from rdkit import rdBase | ||||
| from rdkit.Chem import AllChem | ||||
| from rdkit import DataStructs | ||||
| from rdkit.Chem import rdMolDescriptors | ||||
| rdBase.DisableLog('rdApp.error') | ||||
|  | ||||
| task_to_colname = { | ||||
|     'hiv_b': 'HIV_active', | ||||
|     'bace_b': 'Class', | ||||
|     'bbbp_b': 'p_np', | ||||
|     'O2': 'O2', | ||||
|     'N2': 'N2', | ||||
|     'CO2': 'CO2', | ||||
| } | ||||
|  | ||||
| tasktype_name = { | ||||
|     'hiv_b': 'classification', | ||||
|     'bace_b': 'classification', | ||||
|     'bbbp_b': 'classification', | ||||
|     'O2': 'regression', | ||||
|     'N2': 'regression', | ||||
|     'CO2': 'regression', | ||||
| } | ||||
|  | ||||
| class TaskModel(): | ||||
|     """Scores based on an ECFP classifier.""" | ||||
|     def __init__(self, model_path, task_name): | ||||
|         task_type = tasktype_name[task_name] | ||||
|         self.task_name = task_name | ||||
|         self.task_type = task_type | ||||
|         self.model_path = model_path | ||||
|         self.metric_func = roc_auc_score if 'classification' in self.task_type else mean_absolute_error | ||||
|  | ||||
|         try: | ||||
|             self.model = load(model_path) | ||||
|             print(self.task_name, ' evaluator loaded') | ||||
|         except Exception: | ||||
|             print(self.task_name, ' evaluator not found, training new one...') | ||||
|             if 'classification' in task_type: | ||||
|                 self.model = RandomForestClassifier(random_state=0) | ||||
|             elif 'regression' in task_type: | ||||
|                 self.model = RandomForestRegressor(random_state=0) | ||||
|             performance = self.train() | ||||
|             dump(self.model, model_path) | ||||
|             print('Oracle performance: ', performance) | ||||
|  | ||||
|     def train(self): | ||||
|         data_path = os.path.dirname(self.model_path) | ||||
|         data_path = os.path.join(os.path.dirname(self.model_path), '..', f'raw/{self.task_name}.csv.gz') | ||||
|         df = pd.read_csv(data_path) | ||||
|         col_name = task_to_colname[self.task_name] | ||||
|         y = df[col_name].to_numpy() | ||||
|         x_smiles = df['smiles'].to_numpy() | ||||
|         mask = ~np.isnan(y) | ||||
|         y = y[mask] | ||||
|  | ||||
|         if 'classification' in self.task_type: | ||||
|             y = y.astype(int) | ||||
|  | ||||
|         x_smiles = x_smiles[mask] | ||||
|         x_fps = [] | ||||
|         mask = [] | ||||
|         for i,smiles in enumerate(x_smiles): | ||||
|             mol = Chem.MolFromSmiles(smiles) | ||||
|             mask.append( int(mol is not None) ) | ||||
|             fp = TaskModel.fingerprints_from_mol(mol) if mol else np.zeros((1, 2048)) | ||||
|             x_fps.append(fp) | ||||
|         x_fps = np.concatenate(x_fps, axis=0) | ||||
|         self.model.fit(x_fps, y) | ||||
|         y_pred = self.model.predict(x_fps) | ||||
|         perf = self.metric_func(y, y_pred) | ||||
|         print(f'{self.task_name} performance: {perf}') | ||||
|         return perf | ||||
|  | ||||
|     def __call__(self, smiles_list): | ||||
|         fps = [] | ||||
|         mask = [] | ||||
|         for i,smiles in enumerate(smiles_list): | ||||
|             mol = Chem.MolFromSmiles(smiles) | ||||
|             mask.append( int(mol is not None) ) | ||||
|             fp = TaskModel.fingerprints_from_mol(mol) if mol else np.zeros((1, 2048)) | ||||
|             fps.append(fp) | ||||
|  | ||||
|         fps = np.concatenate(fps, axis=0) | ||||
|         if 'classification' in self.task_type: | ||||
|             scores = self.model.predict_proba(fps)[:, 1] | ||||
|         else: | ||||
|             scores = self.model.predict(fps) | ||||
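|         # Zero out scores for SMILES that failed to parse (their mask entry is 0) | ||||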
|         scores = scores * np.array(mask) | ||||
|         return np.float32(scores) | ||||
|  | ||||
|     @classmethod | ||||
|     def fingerprints_from_mol(cls, mol):  # use ECFP4 | ||||
|         features_vec = AllChem.GetMorganFingerprintAsBitVect(mol, 2, nBits=2048) | ||||
|         features = np.zeros((1,)) | ||||
|         DataStructs.ConvertToNumpyArray(features_vec, features) | ||||
|         return features.reshape(1, -1) | ||||
|  | ||||
| ###### SAS Score ###### | ||||
| _fscores = None | ||||
|  | ||||
| def readFragmentScores(name='fpscores'): | ||||
|     import gzip | ||||
|     global _fscores | ||||
|     # generate the full path filename: | ||||
|     if name == "fpscores": | ||||
|         name = op.join(op.dirname(__file__), name) | ||||
|     data = pickle.load(gzip.open('%s.pkl.gz' % name)) | ||||
|     outDict = {} | ||||
|     for i in data: | ||||
|         for j in range(1, len(i)): | ||||
|             outDict[i[j]] = float(i[0]) | ||||
|     _fscores = outDict | ||||
|  | ||||
| def numBridgeheadsAndSpiro(mol, ri=None): | ||||
|     nSpiro = rdMolDescriptors.CalcNumSpiroAtoms(mol) | ||||
|     nBridgehead = rdMolDescriptors.CalcNumBridgeheadAtoms(mol) | ||||
|     return nBridgehead, nSpiro | ||||
|  | ||||
| def calculateSAS(smiles_list): | ||||
|     scores = [] | ||||
|     for i, smiles in enumerate(smiles_list): | ||||
|         mol = Chem.MolFromSmiles(smiles) | ||||
|         score = calculateScore(mol) | ||||
|         scores.append(score) | ||||
|     return np.float32(scores) | ||||
|  | ||||
| def calculateScore(m): | ||||
|     if _fscores is None: | ||||
|         readFragmentScores() | ||||
|  | ||||
|     # fragment score | ||||
|     fp = rdMolDescriptors.GetMorganFingerprint(m, | ||||
|                                                2)  # <- 2 is the *radius* of the circular fingerprint | ||||
|     fps = fp.GetNonzeroElements() | ||||
|     score1 = 0. | ||||
|     nf = 0 | ||||
|     for bitId, v in fps.items(): | ||||
|         nf += v | ||||
|         sfp = bitId | ||||
|         score1 += _fscores.get(sfp, -4) * v | ||||
|     score1 /= nf | ||||
|  | ||||
|     # features score | ||||
|     nAtoms = m.GetNumAtoms() | ||||
|     nChiralCenters = len(Chem.FindMolChiralCenters(m, includeUnassigned=True)) | ||||
|     ri = m.GetRingInfo() | ||||
|     nBridgeheads, nSpiro = numBridgeheadsAndSpiro(m, ri) | ||||
|     nMacrocycles = 0 | ||||
|     for x in ri.AtomRings(): | ||||
|         if len(x) > 8: | ||||
|             nMacrocycles += 1 | ||||
|  | ||||
|     sizePenalty = nAtoms**1.005 - nAtoms | ||||
|     stereoPenalty = math.log10(nChiralCenters + 1) | ||||
|     spiroPenalty = math.log10(nSpiro + 1) | ||||
|     bridgePenalty = math.log10(nBridgeheads + 1) | ||||
|     macrocyclePenalty = 0. | ||||
|     # --------------------------------------- | ||||
|     # This differs from the paper, which defines: | ||||
|     #  macrocyclePenalty = math.log10(nMacrocycles+1) | ||||
|     # This form generates better results when 2 or more macrocycles are present | ||||
|     if nMacrocycles > 0: | ||||
|         macrocyclePenalty = math.log10(2) | ||||
|  | ||||
|     score2 = 0. - sizePenalty - stereoPenalty - spiroPenalty - bridgePenalty - macrocyclePenalty | ||||
|  | ||||
|     # correction for the fingerprint density | ||||
|     # not in the original publication, added in version 1.1 | ||||
|     # to make highly symmetrical molecules easier to synthesize | ||||
|     score3 = 0. | ||||
|     if nAtoms > len(fps): | ||||
|         score3 = math.log(float(nAtoms) / len(fps)) * .5 | ||||
|  | ||||
|     sascore = score1 + score2 + score3 | ||||
|  | ||||
|     # need to transform "raw" value into scale between 1 and 10 | ||||
|     min = -4.0 | ||||
|     max = 2.5 | ||||
|     sascore = 11. - (sascore - min + 1) / (max - min) * 9. | ||||
|     # smooth the 10-end | ||||
|     if sascore > 8.: | ||||
|         sascore = 8. + math.log(sascore + 1. - 9.) | ||||
|     if sascore > 10.: | ||||
|         sascore = 10.0 | ||||
|     elif sascore < 1.: | ||||
|         sascore = 1.0 | ||||
|  | ||||
|     return sascore | ||||
							
								
								
									
94  graph_dit/metrics/train_loss.py  Normal file
| @@ -0,0 +1,94 @@ | ||||
| import time | ||||
| import torch | ||||
| import torch.nn as nn | ||||
| from metrics.abstract_metrics import CrossEntropyMetric | ||||
| from torchmetrics import Metric, MeanSquaredError | ||||
|  | ||||
| # from 2:He to 119:* | ||||
| valencies_check = [0, 1, 2, 3, 4, 3, 2, 1, 0, 1, 2, 6, 6, 7, 6, 1, 0, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 3, 4, 7, 6, 1, 0, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 3, 4, 7, 6, 5, 6, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 3, 4, 7, 6, 5, 0, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1] | ||||
| valencies_check = torch.tensor(valencies_check) | ||||
|  | ||||
| weight_check = [4.003, 6.941, 9.012, 10.812, 12.011, 14.007, 15.999, 18.998, 20.18, 22.99, 24.305, 26.982, 28.086, 30.974, 32.067, 35.453, 39.948, 39.098, 40.078, 44.956, 47.867, 50.942, 51.996, 54.938, 55.845, 58.933, 58.693, 63.546, 65.39, 69.723, 72.61, 74.922, 78.96, 79.904, 83.8, 85.468, 87.62, 88.906, 91.224, 92.906, 95.94, 98.0, 101.07, 102.906, 106.42, 107.868, 112.412, 114.818, 118.711, 121.76, 127.6, 126.904, 131.29, 132.905, 137.328, 138.906, 140.116, 140.908, 144.24, 145.0, 150.36, 151.964, 157.25, 158.925, 162.5, 164.93, 167.26, 168.934, 173.04, 174.967, 178.49, 180.948, 183.84, 186.207, 190.23, 192.217, 195.078, 196.967, 200.59, 204.383, 207.2, 208.98, 209.0, 210.0, 222.0, 223.0, 226.0, 227.0, 232.038, 231.036, 238.029, 237.0, 244.0, 243.0, 247.0, 247.0, 251.0, 252.0, 257.0, 258.0, 259.0, 262.0, 267.0, 268.0, 269.0, 270.0, 269.0, 278.0, 281.0, 281.0, 285.0, 284.0, 289.0, 288.0, 293.0, 292.0, 294.0, 294.0] | ||||
| weight_check = torch.tensor(weight_check) | ||||
|  | ||||
| class AtomWeightMetric(Metric): | ||||
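|     """Mean absolute difference between predicted and true total atomic weight per graph.""" | ||||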
|     def __init__(self): | ||||
|         super().__init__() | ||||
|         self.add_state('total_loss', default=torch.tensor(0.), dist_reduce_fx="sum") | ||||
|         self.add_state('total_samples', default=torch.tensor(0.), dist_reduce_fx="sum") | ||||
|         global weight_check | ||||
|         self.weight_check = weight_check | ||||
|  | ||||
|     def update(self, X, Y): | ||||
|         atom_pred_num = X.argmax(dim=-1) | ||||
|         atom_real_num = Y.argmax(dim=-1) | ||||
|         self.weight_check = self.weight_check.type_as(X) | ||||
|  | ||||
|         pred_weight = self.weight_check[atom_pred_num] | ||||
|         real_weight = self.weight_check[atom_real_num] | ||||
|  | ||||
|         lss = 0 | ||||
|         lss += torch.abs(pred_weight.sum(dim=-1) - real_weight.sum(dim=-1)).sum() | ||||
|         self.total_loss += lss | ||||
|         self.total_samples += X.size(0) | ||||
|  | ||||
|     def compute(self): | ||||
|         return self.total_loss / self.total_samples | ||||
|  | ||||
|  | ||||
| class TrainLossDiscrete(nn.Module): | ||||
|     """ Train with Cross entropy""" | ||||
|     def __init__(self, lambda_train, weight_node=None, weight_edge=None): | ||||
|         super().__init__() | ||||
|         self.node_loss = CrossEntropyMetric() | ||||
|         self.edge_loss = CrossEntropyMetric() | ||||
|         self.weight_loss = AtomWeightMetric() | ||||
|  | ||||
|         self.y_loss = MeanSquaredError() | ||||
|         self.lambda_train = lambda_train | ||||
|  | ||||
|     def forward(self, masked_pred_X, masked_pred_E, pred_y, true_X, true_E, true_y, node_mask, log: bool): | ||||
|         """ Compute train metrics | ||||
|         masked_pred_X : tensor -- (bs, n, dx) | ||||
|         masked_pred_E : tensor -- (bs, n, n, de) | ||||
|         pred_y : tensor -- (bs, ) | ||||
|         true_X : tensor -- (bs, n, dx) | ||||
|         true_E : tensor -- (bs, n, n, de) | ||||
|         true_y : tensor -- (bs, ) | ||||
|         log : boolean. """ | ||||
|  | ||||
|         loss_weight = self.weight_loss(masked_pred_X, true_X) | ||||
|          | ||||
|         true_X = torch.reshape(true_X, (-1, true_X.size(-1)))  # (bs * n, dx) | ||||
|         true_E = torch.reshape(true_E, (-1, true_E.size(-1)))  # (bs * n * n, de) | ||||
|         masked_pred_X = torch.reshape(masked_pred_X, (-1, masked_pred_X.size(-1)))  # (bs * n, dx) | ||||
|         masked_pred_E = torch.reshape(masked_pred_E, (-1, masked_pred_E.size(-1)))   # (bs * n * n, de) | ||||
|  | ||||
|         # Remove masked rows | ||||
|         mask_X = (true_X != 0.).any(dim=-1) | ||||
|         mask_E = (true_E != 0.).any(dim=-1) | ||||
|  | ||||
|         flat_true_X = true_X[mask_X, :] | ||||
|         flat_pred_X = masked_pred_X[mask_X, :] | ||||
|  | ||||
|         flat_true_E = true_E[mask_E, :] | ||||
|         flat_pred_E = masked_pred_E[mask_E, :] | ||||
|          | ||||
|         loss_X = self.node_loss(flat_pred_X, flat_true_X) if true_X.numel() > 0 else 0.0 | ||||
|         loss_E = self.edge_loss(flat_pred_E, flat_true_E) if true_E.numel() > 0 else 0.0 | ||||
|  | ||||
|         return self.lambda_train[0] * loss_X + self.lambda_train[1] * loss_E + loss_weight | ||||
|  | ||||
|     def reset(self): | ||||
|         for metric in [self.node_loss, self.edge_loss, self.y_loss]: | ||||
|             metric.reset() | ||||
|  | ||||
|     def log_epoch_metrics(self, current_epoch, start_epoch_time, log=True): | ||||
|         epoch_node_loss = self.node_loss.compute() if self.node_loss.total_samples > 0 else -1 | ||||
|         epoch_edge_loss = self.edge_loss.compute() if self.edge_loss.total_samples > 0 else -1 | ||||
|         epoch_weight_loss = self.weight_loss.compute() if self.weight_loss.total_samples > 0 else -1 | ||||
|  | ||||
|         if log: | ||||
|             print(f"Epoch {current_epoch} finished: X_CE: {epoch_node_loss :.4f} -- E_CE: {epoch_edge_loss :.4f} " | ||||
|                 f"Weight: {epoch_weight_loss :.4f} " | ||||
|                 f"-- Time taken {time.time() - start_epoch_time:.1f}s ") | ||||
							
								
								
									
0  graph_dit/models/__init__.py  Normal file

119  graph_dit/models/conditions.py  Normal file
| @@ -0,0 +1,119 @@ | ||||
| import torch | ||||
| import torch.nn as nn | ||||
| import math | ||||
|  | ||||
| class TimestepEmbedder(nn.Module): | ||||
|     """ | ||||
|     Embeds scalar timesteps into vector representations. | ||||
|     """ | ||||
|     def __init__(self, hidden_size, frequency_embedding_size=256): | ||||
|         super().__init__() | ||||
|         self.mlp = nn.Sequential( | ||||
|             nn.Linear(frequency_embedding_size, hidden_size, bias=True), | ||||
|             nn.SiLU(), | ||||
|             nn.Linear(hidden_size, hidden_size, bias=True), | ||||
|         ) | ||||
|         self.frequency_embedding_size = frequency_embedding_size | ||||
|  | ||||
|     @staticmethod | ||||
|     def timestep_embedding(t, dim, max_period=10000): | ||||
|         """ | ||||
|         Create sinusoidal timestep embeddings. | ||||
|         :param t: a 1-D Tensor of N indices, one per batch element. | ||||
|                           These may be fractional. | ||||
|         :param dim: the dimension of the output. | ||||
|         :param max_period: controls the minimum frequency of the embeddings. | ||||
|         :return: an (N, D) Tensor of positional embeddings. | ||||
|         """ | ||||
|         # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py | ||||
|         half = dim // 2 | ||||
|         freqs = torch.exp( | ||||
|             -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half | ||||
|         ).to(device=t.device) | ||||
|         args = t[:, None].float() * freqs[None] | ||||
|         embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) | ||||
|         if dim % 2: | ||||
|             embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) | ||||
|         return embedding | ||||
|  | ||||
|     def forward(self, t): | ||||
|         t = t.view(-1) | ||||
|         t_freq = self.timestep_embedding(t, self.frequency_embedding_size) | ||||
|         t_emb = self.mlp(t_freq) | ||||
|         return t_emb | ||||
|  | ||||
| class CategoricalEmbedder(nn.Module): | ||||
|     """ | ||||
|     Embeds categorical conditions such as data sources into vector representations.  | ||||
|     Also handles label dropout for classifier-free guidance. | ||||
|     """ | ||||
|     def __init__(self, num_classes, hidden_size, dropout_prob): | ||||
|         super().__init__() | ||||
|         use_cfg_embedding = dropout_prob > 0 | ||||
|         self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size) | ||||
|         self.num_classes = num_classes | ||||
|         self.dropout_prob = dropout_prob | ||||
|  | ||||
|     def token_drop(self, labels, force_drop_ids=None): | ||||
|         """ | ||||
|         Drops labels to enable classifier-free guidance. | ||||
|         """ | ||||
|         if force_drop_ids is None: | ||||
|             drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob | ||||
|         else: | ||||
|             drop_ids = force_drop_ids == 1 | ||||
|         labels = torch.where(drop_ids, self.num_classes, labels) | ||||
|         return labels | ||||
|  | ||||
|     def forward(self, labels, train, force_drop_ids=None, t=None): | ||||
|         labels = labels.long().view(-1) | ||||
|         use_dropout = self.dropout_prob > 0 | ||||
|         if (train and use_dropout) or (force_drop_ids is not None): | ||||
|             labels = self.token_drop(labels, force_drop_ids) | ||||
|         embeddings = self.embedding_table(labels) | ||||
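|         # During training, Gaussian noise is added to the label embeddings (a light regularization of the condition) | ||||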
|         if train: | ||||
|             noise = torch.randn_like(embeddings) | ||||
|             embeddings = embeddings + noise | ||||
|         return embeddings | ||||
|      | ||||
| class ClusterContinuousEmbedder(nn.Module): | ||||
|     def __init__(self, input_size, hidden_size, dropout_prob): | ||||
|         super().__init__() | ||||
|         use_cfg_embedding = dropout_prob > 0 | ||||
|  | ||||
|         if use_cfg_embedding: | ||||
|             self.embedding_drop = nn.Embedding(1, hidden_size) | ||||
|  | ||||
|         self.mlp = nn.Sequential( | ||||
|             nn.Linear(input_size, hidden_size, bias=True), | ||||
|             nn.Softmax(dim=1), | ||||
|             nn.Linear(hidden_size, hidden_size, bias=False) | ||||
|         ) | ||||
|         self.hidden_size = hidden_size | ||||
|         self.dropout_prob = dropout_prob | ||||
|  | ||||
|     def forward(self, labels, train, force_drop_ids=None, timestep=None): | ||||
|         use_dropout = self.dropout_prob > 0 | ||||
|         if force_drop_ids is not None: | ||||
|             drop_ids = force_drop_ids == 1 | ||||
|         else: | ||||
|             drop_ids = None | ||||
|  | ||||
|         if (train and use_dropout): | ||||
|             drop_ids_rand = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob | ||||
|             if force_drop_ids is not None: | ||||
|                 drop_ids = torch.logical_or(drop_ids, drop_ids_rand) | ||||
|             else: | ||||
|                 drop_ids = drop_ids_rand | ||||
|          | ||||
|         if drop_ids is not None: | ||||
|             embeddings = torch.zeros((labels.shape[0], self.hidden_size), device=labels.device) | ||||
|             embeddings[~drop_ids] = self.mlp(labels[~drop_ids]) | ||||
|             embeddings[drop_ids] += self.embedding_drop.weight[0] | ||||
|         else: | ||||
|             embeddings = self.mlp(labels) | ||||
|  | ||||
|         if train: | ||||
|             noise = torch.randn_like(embeddings) | ||||
|             embeddings = embeddings + noise | ||||
|         return embeddings | ||||
							
								
								
									
114  graph_dit/models/layers.py  Normal file
| @@ -0,0 +1,114 @@ | ||||
| from torch.jit import Final | ||||
| import torch.nn.functional as F | ||||
| from itertools import repeat | ||||
| import collections.abc | ||||
|  | ||||
| import torch | ||||
| import torch.nn as nn | ||||
|  | ||||
| class Attention(nn.Module): | ||||
|     fast_attn: Final[bool] | ||||
|  | ||||
|     def __init__( | ||||
|             self, | ||||
|             dim, | ||||
|             num_heads=8, | ||||
|             qkv_bias=False, | ||||
|             qk_norm=False, | ||||
|             attn_drop=0, | ||||
|             proj_drop=0, | ||||
|             norm_layer=nn.LayerNorm, | ||||
|     ): | ||||
|         super().__init__() | ||||
|         assert dim % num_heads == 0, 'dim should be divisible by num_heads' | ||||
|         self.num_heads = num_heads | ||||
|         self.head_dim = dim // num_heads | ||||
|  | ||||
|         self.scale = self.head_dim ** -0.5 | ||||
|         self.fast_attn = hasattr(torch.nn.functional, 'scaled_dot_product_attention')  # FIXME | ||||
|         assert self.fast_attn, "scaled_dot_product_attention Not implemented" | ||||
|  | ||||
|         self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) | ||||
|  | ||||
|         self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() | ||||
|         self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() | ||||
|         self.attn_drop = nn.Dropout(attn_drop) | ||||
|  | ||||
|         self.proj = nn.Linear(dim, dim) | ||||
|         self.proj_drop = nn.Dropout(proj_drop) | ||||
|  | ||||
|     def dot_product_attention(self, q, k, v): | ||||
|         q = q * self.scale | ||||
|         attn = q @ k.transpose(-2, -1) | ||||
|         attn_sfmx = attn.softmax(dim=-1) | ||||
|         attn_sfmx = self.attn_drop(attn_sfmx) | ||||
|         x = attn_sfmx @ v | ||||
|         return x, attn | ||||
|      | ||||
|     def forward(self, x, node_mask): | ||||
|         B, N, D = x.shape | ||||
|          | ||||
|         # B, head, N, head_dim | ||||
|         qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4) | ||||
|         q, k, v = qkv.unbind(0) # B, head, N, head_dim | ||||
|         q, k = self.q_norm(q), self.k_norm(k) | ||||
|          | ||||
|         attn_mask = (node_mask[:, None, :, None] & node_mask[:, None, None, :]).expand(-1, self.num_heads, N, N) | ||||
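|         # Rows of padded nodes would be fully masked; unmask them so the attention softmax does not produce NaNs | ||||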
|         attn_mask[attn_mask.sum(-1) == 0] = True | ||||
|  | ||||
|         x = F.scaled_dot_product_attention( | ||||
|             q, k, v, | ||||
|             dropout_p=self.attn_drop.p, | ||||
|             attn_mask=attn_mask, | ||||
|         ) | ||||
|  | ||||
|         x = x.transpose(1, 2).reshape(B, N, -1) | ||||
|  | ||||
|         x = self.proj(x) | ||||
|         x = self.proj_drop(x) | ||||
|  | ||||
|         return x | ||||
|  | ||||
|  | ||||
| class Mlp(nn.Module): | ||||
|     def __init__( | ||||
|             self, | ||||
|             in_features, | ||||
|             hidden_features=None, | ||||
|             out_features=None, | ||||
|             act_layer=nn.GELU, | ||||
|             bias=True, | ||||
|             drop=0., | ||||
|     ): | ||||
|         super().__init__() | ||||
|         out_features = out_features or in_features | ||||
|         hidden_features = hidden_features or in_features | ||||
|         bias = to_2tuple(bias) | ||||
|         drop_probs = to_2tuple(drop) | ||||
|         linear_layer = nn.Linear | ||||
|  | ||||
|         self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0]) | ||||
|         self.act = act_layer() | ||||
|         self.drop1 = nn.Dropout(drop_probs[0]) | ||||
|         self.fc2 = linear_layer(hidden_features, out_features, bias=bias[1]) | ||||
|         self.drop2 = nn.Dropout(drop_probs[1]) | ||||
|  | ||||
|     def forward(self, x): | ||||
|         x = self.fc1(x) | ||||
|         x = self.act(x) | ||||
|         x = self.drop1(x) | ||||
|         x = self.fc2(x) | ||||
|         x = self.drop2(x) | ||||
|         return x | ||||
|  | ||||
| # From PyTorch internals | ||||
| def _ntuple(n): | ||||
|     def parse(x): | ||||
|         if isinstance(x, collections.abc.Iterable) and not isinstance(x, str): | ||||
|             return tuple(x) | ||||
|         return tuple(repeat(x, n)) | ||||
|     return parse | ||||
|  | ||||
| to_2tuple = _ntuple(2) | ||||
|  | ||||
|  | ||||
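A minimal usage sketch for the Attention and Mlp blocks above (not part of the commit). It assumes graph_dit/ is on the Python path so models.layers resolves, and PyTorch >= 2.0 for scaled_dot_product_attention; batch size, node count, and hidden size are illustrative only.

    import torch
    from models.layers import Attention, Mlp   # assumed import path, mirroring transformer.py

    B, N, D = 2, 16, 384                        # illustrative batch, max nodes, hidden size
    x = torch.randn(B, N, D)
    node_mask = torch.ones(B, N, dtype=torch.bool)
    node_mask[1, 10:] = False                   # second graph has only 10 real nodes

    attn = Attention(dim=D, num_heads=8, qkv_bias=True, qk_norm=True)
    mlp = Mlp(in_features=D, hidden_features=int(4 * D))

    h = attn(x, node_mask)                      # (B, N, D); padded nodes are excluded from attention
    h = mlp(h)                                  # position-wise feed-forward, same shape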
							
								
								
									
										184
									graph_dit/models/transformer.py
										Normal file
							| @@ -0,0 +1,184 @@ | ||||
| import torch | ||||
| import torch.nn as nn | ||||
|  | ||||
| import utils | ||||
| from models.layers import Attention, Mlp | ||||
| from models.conditions import TimestepEmbedder, CategoricalEmbedder, ClusterContinuousEmbedder | ||||
|  | ||||
| def modulate(x, shift, scale): | ||||
|     return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) | ||||
|  | ||||
| class Denoiser(nn.Module): | ||||
|     def __init__( | ||||
|         self, | ||||
|         max_n_nodes, | ||||
|         hidden_size=384, | ||||
|         depth=12, | ||||
|         num_heads=16, | ||||
|         mlp_ratio=4.0, | ||||
|         drop_condition=0.1, | ||||
|         Xdim=118, | ||||
|         Edim=5, | ||||
|         ydim=3, | ||||
|         task_type='regression', | ||||
|     ): | ||||
|         super().__init__() | ||||
|         self.num_heads = num_heads | ||||
|         self.ydim = ydim | ||||
|         self.x_embedder = nn.Linear(Xdim + max_n_nodes * Edim, hidden_size, bias=False) | ||||
|  | ||||
|         self.t_embedder = TimestepEmbedder(hidden_size) | ||||
|         self.y_embedding_list = torch.nn.ModuleList() | ||||
|  | ||||
|         self.y_embedding_list.append(ClusterContinuousEmbedder(2, hidden_size, drop_condition)) | ||||
|         for i in range(ydim - 2): | ||||
|             if task_type == 'regression': | ||||
|                 self.y_embedding_list.append(ClusterContinuousEmbedder(1, hidden_size, drop_condition)) | ||||
|             else: | ||||
|                 self.y_embedding_list.append(CategoricalEmbedder(2, hidden_size, drop_condition)) | ||||
|  | ||||
|         self.encoders = nn.ModuleList( | ||||
|             [ | ||||
|                 SELayer(hidden_size, num_heads, mlp_ratio=mlp_ratio) | ||||
|                 for _ in range(depth) | ||||
|             ] | ||||
|         ) | ||||
|  | ||||
|         self.out_layer = OutLayer( | ||||
|             max_n_nodes=max_n_nodes, | ||||
|             hidden_size=hidden_size, | ||||
|             atom_type=Xdim, | ||||
|             bond_type=Edim, | ||||
|             mlp_ratio=mlp_ratio, | ||||
|             num_heads=num_heads, | ||||
|         ) | ||||
|  | ||||
|         self.initialize_weights() | ||||
|  | ||||
|     def initialize_weights(self): | ||||
|         # Initialize transformer layers: | ||||
|         def _basic_init(module): | ||||
|             if isinstance(module, nn.Linear): | ||||
|                 torch.nn.init.xavier_uniform_(module.weight) | ||||
|                 if module.bias is not None: | ||||
|                     nn.init.constant_(module.bias, 0) | ||||
|  | ||||
|         def _constant_init(module, i): | ||||
|             if isinstance(module, nn.Linear): | ||||
|                 nn.init.constant_(module.weight, i) | ||||
|                 if module.bias is not None: | ||||
|                     nn.init.constant_(module.bias, i) | ||||
|  | ||||
|         self.apply(_basic_init) | ||||
|  | ||||
|         for block in self.encoders: | ||||
|             _constant_init(block.adaLN_modulation[0], 0) | ||||
|         _constant_init(self.out_layer.adaLN_modulation[0], 0) | ||||
|  | ||||
|     def forward(self, x, e, node_mask, y, t, unconditioned): | ||||
|         # treat the condition as dropped for samples whose property labels are NaN (unlabeled), | ||||
|         # or for the whole batch when sampling unconditionally | ||||
|         force_drop_id = torch.zeros_like(y.sum(-1)) | ||||
|         force_drop_id[torch.isnan(y.sum(-1))] = 1 | ||||
|         if unconditioned: | ||||
|             force_drop_id = torch.ones_like(y[:, 0]) | ||||
|          | ||||
|         x_in, e_in, y_in = x, e, y | ||||
|         bs, n, _ = x.size() | ||||
|         x = torch.cat([x, e.reshape(bs, n, -1)], dim=-1) | ||||
|         x = self.x_embedder(x) | ||||
|  | ||||
|         c1 = self.t_embedder(t) | ||||
|         # the first embedder consumes the first two columns of y jointly; each remaining | ||||
|         # column gets its own embedder, and all contributions are summed into c2 | ||||
|         for i in range(1, self.ydim): | ||||
|             if i == 1: | ||||
|                 c2 = self.y_embedding_list[i-1](y[:, :2], self.training, force_drop_id, t) | ||||
|             else: | ||||
|                 c2 = c2 + self.y_embedding_list[i-1](y[:, i:i+1], self.training, force_drop_id, t) | ||||
|         c = c1 + c2 | ||||
|          | ||||
|         for i, block in enumerate(self.encoders): | ||||
|             x = block(x, c, node_mask) | ||||
|  | ||||
|         # X: B * N * dx, E: B * N * N * de | ||||
|         X, E, y = self.out_layer(x, x_in, e_in, c, t, node_mask) | ||||
|         return utils.PlaceHolder(X=X, E=E, y=y).mask(node_mask) | ||||
|  | ||||
|  | ||||
| class SELayer(nn.Module): | ||||
|     def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, **block_kwargs): | ||||
|         super().__init__() | ||||
|         self.dropout = 0. | ||||
|         self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False) | ||||
|         self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False) | ||||
|  | ||||
|         self.attn = Attention( | ||||
|             hidden_size, num_heads=num_heads, qkv_bias=True, qk_norm=True, **block_kwargs | ||||
|         ) | ||||
|  | ||||
|         self.mlp = Mlp( | ||||
|             in_features=hidden_size, | ||||
|             hidden_features=int(hidden_size * mlp_ratio), | ||||
|             drop=self.dropout, | ||||
|         ) | ||||
|  | ||||
|         self.adaLN_modulation = nn.Sequential( | ||||
|             nn.Linear(hidden_size, hidden_size, bias=True), | ||||
|             nn.SiLU(), | ||||
|             nn.Linear(hidden_size, 6 * hidden_size, bias=True) | ||||
|         ) | ||||
|          | ||||
|     def forward(self, x, c, node_mask): | ||||
|         ( | ||||
|             shift_msa, | ||||
|             scale_msa, | ||||
|             gate_msa, | ||||
|             shift_mlp, | ||||
|             scale_mlp, | ||||
|             gate_mlp, | ||||
|         ) = self.adaLN_modulation(c).chunk(6, dim=1) | ||||
|         x = x + gate_msa.unsqueeze(1) * modulate(self.norm1(self.attn(x, node_mask=node_mask)), shift_msa, scale_msa) | ||||
|         x = x + gate_mlp.unsqueeze(1) * modulate(self.norm2(self.mlp(x)), shift_mlp, scale_mlp) | ||||
|         return x | ||||
|  | ||||
|  | ||||
| class OutLayer(nn.Module): | ||||
|     # Structure Output Layer | ||||
|     def __init__(self, max_n_nodes, hidden_size, atom_type, bond_type, mlp_ratio, num_heads=None): | ||||
|         super().__init__() | ||||
|         self.atom_type = atom_type | ||||
|         self.bond_type = bond_type | ||||
|         final_size = atom_type + max_n_nodes * bond_type | ||||
|         self.xedecoder = Mlp(in_features=hidden_size,  | ||||
|                             out_features=final_size, drop=0) | ||||
|  | ||||
|         self.norm_final = nn.LayerNorm(final_size, elementwise_affine=False) | ||||
|         self.adaLN_modulation = nn.Sequential( | ||||
|             nn.Linear(hidden_size, hidden_size, bias=True), | ||||
|             nn.SiLU(), | ||||
|             nn.Linear(hidden_size, 2 * final_size, bias=True) | ||||
|         ) | ||||
|  | ||||
|     def forward(self, x, x_in, e_in, c, t, node_mask): | ||||
|         x_all = self.xedecoder(x) | ||||
|         B, N, D = x_all.size() | ||||
|         shift, scale = self.adaLN_modulation(c).chunk(2, dim=1) | ||||
|         x_all = modulate(self.norm_final(x_all), shift, scale) | ||||
|          | ||||
|         atom_out = x_all[:, :, :self.atom_type] | ||||
|         atom_out = x_in + atom_out | ||||
|  | ||||
|         bond_out = x_all[:, :, self.atom_type:].reshape(B, N, N, self.bond_type) | ||||
|         bond_out = e_in + bond_out | ||||
|  | ||||
|         # zero edges between padded nodes and self-loop entries, then symmetrize the bond logits | ||||
|         edge_mask = (~node_mask)[:, :, None] & (~node_mask)[:, None, :] | ||||
|         diag_mask = ( | ||||
|             torch.eye(N, dtype=torch.bool) | ||||
|             .unsqueeze(0) | ||||
|             .expand(B, -1, -1) | ||||
|             .type_as(edge_mask) | ||||
|         ) | ||||
|         bond_out.masked_fill_(edge_mask[:, :, :, None], 0) | ||||
|         bond_out.masked_fill_(diag_mask[:, :, :, None], 0) | ||||
|         bond_out = 1 / 2 * (bond_out + torch.transpose(bond_out, 1, 2)) | ||||
|  | ||||
|         return atom_out, bond_out, None | ||||
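A rough end-to-end shape check for the Denoiser defined above (a sketch, not part of the commit). It assumes graph_dit/ is on the Python path so utils, models.layers, and models.conditions resolve, that t is a 1-D tensor of diffusion timesteps as TimestepEmbedder expects, and that the condition embedders accept the float tensors shown; all sizes are illustrative.

    import torch
    from models.transformer import Denoiser    # assumed import path

    B, N, Xdim, Edim, ydim = 2, 10, 118, 5, 3
    model = Denoiser(max_n_nodes=N, hidden_size=384, depth=2, num_heads=16,
                     Xdim=Xdim, Edim=Edim, ydim=ydim, task_type='regression')

    x = torch.randn(B, N, Xdim)                 # dense node features
    e = torch.randn(B, N, N, Edim)              # dense edge features
    e = 0.5 * (e + e.transpose(1, 2))           # keep E symmetric, as the final mask() asserts
    node_mask = torch.ones(B, N, dtype=torch.bool)
    y = torch.randn(B, ydim)                    # condition vector; NaN entries would drop the condition
    t = torch.randint(0, 1000, (B,))            # diffusion timesteps

    out = model(x, e, node_mask, y, t, unconditioned=False)
    print(out.X.shape, out.E.shape)             # expected: (B, N, Xdim) and (B, N, N, Edim)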
							
								
								
									
										135
									graph_dit/utils.py
										Normal file
							| @@ -0,0 +1,135 @@ | ||||
| import os | ||||
| from omegaconf import OmegaConf, open_dict | ||||
|  | ||||
| import torch | ||||
| import torch_geometric.utils | ||||
| from torch_geometric.utils import to_dense_adj, to_dense_batch | ||||
|  | ||||
| def create_folders(args): | ||||
|     try: | ||||
|         os.makedirs('graphs') | ||||
|         os.makedirs('chains') | ||||
|     except OSError: | ||||
|         pass | ||||
|  | ||||
|     try: | ||||
|         os.makedirs('graphs/' + args.general.name) | ||||
|         os.makedirs('chains/' + args.general.name) | ||||
|     except OSError: | ||||
|         pass | ||||
|  | ||||
| def normalize(X, E, y, norm_values, norm_biases, node_mask): | ||||
|     X = (X - norm_biases[0]) / norm_values[0] | ||||
|     E = (E - norm_biases[1]) / norm_values[1] | ||||
|     y = (y - norm_biases[2]) / norm_values[2] | ||||
|  | ||||
|     diag = torch.eye(E.shape[1], dtype=torch.bool).unsqueeze(0).expand(E.shape[0], -1, -1) | ||||
|     E[diag] = 0 | ||||
|  | ||||
|     return PlaceHolder(X=X, E=E, y=y).mask(node_mask) | ||||
|  | ||||
|  | ||||
| def unnormalize(X, E, y, norm_values, norm_biases, node_mask, collapse=False): | ||||
|     """ | ||||
|     X : node features | ||||
|     E : edge features | ||||
|     y : global features | ||||
|     norm_values : [norm value X, norm value E, norm value y] | ||||
|     norm_biases : same order | ||||
|     node_mask | ||||
|     """ | ||||
|     X = (X * norm_values[0] + norm_biases[0]) | ||||
|     E = (E * norm_values[1] + norm_biases[1]) | ||||
|     y = y * norm_values[2] + norm_biases[2] | ||||
|  | ||||
|     return PlaceHolder(X=X, E=E, y=y).mask(node_mask, collapse) | ||||
|  | ||||
|  | ||||
| def to_dense(x, edge_index, edge_attr, batch, max_num_nodes=None): | ||||
|     X, node_mask = to_dense_batch(x=x, batch=batch, max_num_nodes=max_num_nodes) | ||||
|     # node_mask = node_mask.float() | ||||
|     edge_index, edge_attr = torch_geometric.utils.remove_self_loops(edge_index, edge_attr) | ||||
|     if max_num_nodes is None: | ||||
|         max_num_nodes = X.size(1) | ||||
|     E = to_dense_adj(edge_index=edge_index, batch=batch, edge_attr=edge_attr, max_num_nodes=max_num_nodes) | ||||
|     E = encode_no_edge(E) | ||||
|     return PlaceHolder(X=X, E=E, y=None), node_mask | ||||
|  | ||||
|  | ||||
| def encode_no_edge(E): | ||||
|     assert len(E.shape) == 4 | ||||
|     if E.shape[-1] == 0: | ||||
|         return E | ||||
|     no_edge = torch.sum(E, dim=3) == 0 | ||||
|     first_elt = E[:, :, :, 0] | ||||
|     first_elt[no_edge] = 1 | ||||
|     E[:, :, :, 0] = first_elt | ||||
|     diag = torch.eye(E.shape[1], dtype=torch.bool).unsqueeze(0).expand(E.shape[0], -1, -1) | ||||
|     E[diag] = 0 | ||||
|     return E | ||||
|  | ||||
|  | ||||
| def update_config_with_new_keys(cfg, saved_cfg): | ||||
|     saved_general = saved_cfg.general | ||||
|     saved_train = saved_cfg.train | ||||
|     saved_model = saved_cfg.model | ||||
|     saved_dataset = saved_cfg.dataset | ||||
|      | ||||
|     for key, val in saved_dataset.items(): | ||||
|         OmegaConf.set_struct(cfg.dataset, True) | ||||
|         with open_dict(cfg.dataset): | ||||
|             if key not in cfg.dataset.keys(): | ||||
|                 setattr(cfg.dataset, key, val) | ||||
|  | ||||
|     for key, val in saved_general.items(): | ||||
|         OmegaConf.set_struct(cfg.general, True) | ||||
|         with open_dict(cfg.general): | ||||
|             if key not in cfg.general.keys(): | ||||
|                 setattr(cfg.general, key, val) | ||||
|  | ||||
|     OmegaConf.set_struct(cfg.train, True) | ||||
|     with open_dict(cfg.train): | ||||
|         for key, val in saved_train.items(): | ||||
|             if key not in cfg.train.keys(): | ||||
|                 setattr(cfg.train, key, val) | ||||
|  | ||||
|     OmegaConf.set_struct(cfg.model, True) | ||||
|     with open_dict(cfg.model): | ||||
|         for key, val in saved_model.items(): | ||||
|             if key not in cfg.model.keys(): | ||||
|                 setattr(cfg.model, key, val) | ||||
|     return cfg | ||||
|  | ||||
|  | ||||
| class PlaceHolder: | ||||
|     def __init__(self, X, E, y): | ||||
|         self.X = X | ||||
|         self.E = E | ||||
|         self.y = y | ||||
|  | ||||
|     def type_as(self, x: torch.Tensor, categorical: bool = False): | ||||
|         """ Changes the device and dtype of X, E, y. """ | ||||
|         self.X = self.X.type_as(x) | ||||
|         self.E = self.E.type_as(x) | ||||
|         if categorical: | ||||
|             self.y = self.y.type_as(x) | ||||
|         return self | ||||
|  | ||||
|     def mask(self, node_mask, collapse=False): | ||||
|         x_mask = node_mask.unsqueeze(-1)          # bs, n, 1 | ||||
|         e_mask1 = x_mask.unsqueeze(2)             # bs, n, 1, 1 | ||||
|         e_mask2 = x_mask.unsqueeze(1)             # bs, 1, n, 1 | ||||
|  | ||||
|         if collapse: | ||||
|             self.X = torch.argmax(self.X, dim=-1) | ||||
|             self.E = torch.argmax(self.E, dim=-1) | ||||
|  | ||||
|             self.X[node_mask == 0] = - 1 | ||||
|             self.E[(e_mask1 * e_mask2).squeeze(-1) == 0] = - 1 | ||||
|         else: | ||||
|             self.X = self.X * x_mask | ||||
|             self.E = self.E * e_mask1 * e_mask2 | ||||
|             assert torch.allclose(self.E, torch.transpose(self.E, 1, 2)) | ||||
|         return self | ||||
|  | ||||
|  | ||||
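A small sketch of how the helpers above are typically combined: convert a torch_geometric batch to dense tensors, then mask padding (not part of the commit; the toy graphs are illustrative, and the sketch assumes graph_dit/ is on the Python path so the module is importable as utils).

    import torch
    from torch_geometric.data import Data, Batch
    from utils import to_dense                  # assumed import path

    # two tiny graphs: one-hot node features (4 atom types) and one-hot edge features
    # (channel 0 is reserved for "no edge", which is what encode_no_edge fills in)
    g1 = Data(x=torch.eye(4)[:3],
              edge_index=torch.tensor([[0, 1], [1, 0]]),
              edge_attr=torch.tensor([[0., 1.], [0., 1.]]))
    g2 = Data(x=torch.eye(4)[:2],
              edge_index=torch.tensor([[0, 1], [1, 0]]),
              edge_attr=torch.tensor([[0., 1.], [0., 1.]]))
    batch = Batch.from_data_list([g1, g2])

    dense, node_mask = to_dense(batch.x, batch.edge_index, batch.edge_attr, batch.batch)
    dense = dense.mask(node_mask)               # zero features of padded nodes and edges
    print(dense.X.shape, dense.E.shape)         # (2, 3, 4) and (2, 3, 3, 2)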