init_code
This commit is contained in:
		
							
								
								
									
										0
									
								
								mcd/metrics/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										0
									
								
								mcd/metrics/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
								
								
									
										138
									
								
								mcd/metrics/abstract_metrics.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										138
									
								
								mcd/metrics/abstract_metrics.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,138 @@ | ||||
| import torch | ||||
| from torch import Tensor | ||||
| from torch.nn import functional as F | ||||
| from torchmetrics import Metric, MeanSquaredError | ||||
|  | ||||
|  | ||||
| class TrainAbstractMetricsDiscrete(torch.nn.Module): | ||||
|     def __init__(self): | ||||
|         super().__init__() | ||||
|  | ||||
|     def forward(self, masked_pred_X, masked_pred_E, true_X, true_E, log: bool): | ||||
|         pass | ||||
|  | ||||
|     def reset(self): | ||||
|         pass | ||||
|  | ||||
|     def log_epoch_metrics(self, current_epoch): | ||||
|         pass | ||||
|  | ||||
|  | ||||
| class TrainAbstractMetrics(torch.nn.Module): | ||||
|     def __init__(self): | ||||
|         super().__init__() | ||||
|  | ||||
|     def forward(self, masked_pred_epsX, masked_pred_epsE, pred_y, true_epsX, true_epsE, true_y, log): | ||||
|         pass | ||||
|  | ||||
|     def reset(self): | ||||
|         pass | ||||
|  | ||||
|     def log_epoch_metrics(self, current_epoch): | ||||
|         pass | ||||
|  | ||||
|  | ||||
| class SumExceptBatchMetric(Metric): | ||||
|     def __init__(self): | ||||
|         super().__init__() | ||||
|         self.add_state('total_value', default=torch.tensor(0.), dist_reduce_fx="sum") | ||||
|         self.add_state('total_samples', default=torch.tensor(0.), dist_reduce_fx="sum") | ||||
|  | ||||
|     def update(self, values) -> None: | ||||
|         self.total_value += torch.sum(values) | ||||
|         self.total_samples += values.shape[0] | ||||
|  | ||||
|     def compute(self): | ||||
|         return self.total_value / self.total_samples | ||||
|  | ||||
|  | ||||
| class SumExceptBatchMSE(MeanSquaredError): | ||||
|     def update(self, preds: Tensor, target: Tensor) -> None: | ||||
|         """Update state with predictions and targets. | ||||
|  | ||||
|         Args: | ||||
|             preds: Predictions from model | ||||
|             target: Ground truth values | ||||
|         """ | ||||
|         assert preds.shape == target.shape | ||||
|         sum_squared_error, n_obs = self._mean_squared_error_update(preds, target) | ||||
|  | ||||
|         self.sum_squared_error += sum_squared_error | ||||
|         self.total += n_obs | ||||
|  | ||||
|     def _mean_squared_error_update(self, preds: Tensor, target: Tensor): | ||||
|             """ Updates and returns variables required to compute Mean Squared Error. Checks for same shape of input | ||||
|             tensors. | ||||
|                 preds: Predicted tensor | ||||
|                 target: Ground truth tensor | ||||
|             """ | ||||
|             diff = preds - target | ||||
|             sum_squared_error = torch.sum(diff * diff) | ||||
|             n_obs = preds.shape[0] | ||||
|             return sum_squared_error, n_obs | ||||
|  | ||||
|  | ||||
| class SumExceptBatchKL(Metric): | ||||
|     def __init__(self): | ||||
|         super().__init__() | ||||
|         self.add_state('total_value', default=torch.tensor(0.), dist_reduce_fx="sum") | ||||
|         self.add_state('total_samples', default=torch.tensor(0.), dist_reduce_fx="sum") | ||||
|  | ||||
|     def update(self, p, q) -> None: | ||||
|         self.total_value += F.kl_div(q, p, reduction='sum') | ||||
|         self.total_samples += p.size(0) | ||||
|  | ||||
|     def compute(self): | ||||
|         return self.total_value / self.total_samples | ||||
|  | ||||
|  | ||||
| class CrossEntropyMetric(Metric): | ||||
|     def __init__(self): | ||||
|         super().__init__() | ||||
|         self.add_state('total_ce', default=torch.tensor(0.), dist_reduce_fx="sum") | ||||
|         self.add_state('total_samples', default=torch.tensor(0.), dist_reduce_fx="sum") | ||||
|  | ||||
|     def update(self, preds: Tensor, target: Tensor, weight=None) -> None: | ||||
|         """ Update state with predictions and targets. | ||||
|             preds: Predictions from model   (bs * n, d) or (bs * n * n, d) | ||||
|             target: Ground truth values     (bs * n, d) or (bs * n * n, d). """ | ||||
|         target = torch.argmax(target, dim=-1) | ||||
|         if weight is not None: | ||||
|             weight = weight.type_as(preds) | ||||
|             output = F.cross_entropy(preds, target, weight = weight, reduction='sum') | ||||
|         else: | ||||
|             output = F.cross_entropy(preds, target, reduction='sum') | ||||
|         self.total_ce += output | ||||
|         self.total_samples += preds.size(0) | ||||
|  | ||||
|     def compute(self): | ||||
|         return self.total_ce / self.total_samples | ||||
|  | ||||
|  | ||||
| class ProbabilityMetric(Metric): | ||||
|     def __init__(self): | ||||
|         """ This metric is used to track the marginal predicted probability of a class during training. """ | ||||
|         super().__init__() | ||||
|         self.add_state('prob', default=torch.tensor(0.), dist_reduce_fx="sum") | ||||
|         self.add_state('total', default=torch.tensor(0.), dist_reduce_fx="sum") | ||||
|  | ||||
|     def update(self, preds: Tensor) -> None: | ||||
|         self.prob += preds.sum() | ||||
|         self.total += preds.numel() | ||||
|  | ||||
|     def compute(self): | ||||
|         return self.prob / self.total | ||||
|  | ||||
|  | ||||
| class NLL(Metric): | ||||
|     def __init__(self): | ||||
|         super().__init__() | ||||
|         self.add_state('total_nll', default=torch.tensor(0.), dist_reduce_fx="sum") | ||||
|         self.add_state('total_samples', default=torch.tensor(0.), dist_reduce_fx="sum") | ||||
|  | ||||
|     def update(self, batch_nll) -> None: | ||||
|         self.total_nll += torch.sum(batch_nll) | ||||
|         self.total_samples += batch_nll.numel() | ||||
|  | ||||
|     def compute(self): | ||||
|         return self.total_nll / self.total_samples | ||||
							
								
								
									
										
											BIN
										
									
								
								mcd/metrics/fpscores.pkl.gz
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										
											BIN
										
									
								
								mcd/metrics/fpscores.pkl.gz
									
									
									
									
									
										Normal file
									
								
							
										
											Binary file not shown.
										
									
								
							
							
								
								
									
										138
									
								
								mcd/metrics/molecular_metrics_sampling.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										138
									
								
								mcd/metrics/molecular_metrics_sampling.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,138 @@ | ||||
| ### packages for visualization | ||||
| from analysis.rdkit_functions import compute_molecular_metrics | ||||
| from mini_moses.metrics.metrics import compute_intermediate_statistics | ||||
| from metrics.property_metric import TaskModel | ||||
|  | ||||
| import torch | ||||
| import torch.nn as nn | ||||
|  | ||||
| import os | ||||
| import csv | ||||
| import time | ||||
|  | ||||
| def result_to_csv(path, dict_data): | ||||
|     file_exists = os.path.exists(path) | ||||
|     log_name = dict_data.pop("log_name", None) | ||||
|     if log_name is None: | ||||
|         raise ValueError("The provided dictionary must contain a 'log_name' key.") | ||||
|     field_names = ["log_name"] + list(dict_data.keys()) | ||||
|     dict_data["log_name"] = log_name | ||||
|     with open(path, "a", newline="") as file: | ||||
|         writer = csv.DictWriter(file, fieldnames=field_names) | ||||
|         if not file_exists: | ||||
|             writer.writeheader() | ||||
|         writer.writerow(dict_data) | ||||
|  | ||||
|  | ||||
| class SamplingMolecularMetrics(nn.Module): | ||||
|     def __init__( | ||||
|         self, | ||||
|         dataset_infos, | ||||
|         train_smiles, | ||||
|         reference_smiles, | ||||
|         n_jobs=1, | ||||
|         device="cpu", | ||||
|         batch_size=512, | ||||
|     ): | ||||
|         super().__init__() | ||||
|         self.task_name = dataset_infos.task | ||||
|         self.dataset_infos = dataset_infos | ||||
|         self.active_atoms = dataset_infos.active_atoms | ||||
|         self.train_smiles = train_smiles | ||||
|  | ||||
|         if reference_smiles is not None: | ||||
|             print( | ||||
|                 f"--- Computing intermediate statistics for training for #{len(reference_smiles)} smiles ---" | ||||
|             ) | ||||
|             start_time = time.time() | ||||
|             self.stat_ref = compute_intermediate_statistics( | ||||
|                 reference_smiles, n_jobs=n_jobs, device=device, batch_size=batch_size | ||||
|             ) | ||||
|             end_time = time.time() | ||||
|             elapsed_time = end_time - start_time | ||||
|             print( | ||||
|                 f"--- End computing intermediate statistics: using {elapsed_time:.2f}s ---" | ||||
|             ) | ||||
|         else: | ||||
|             self.stat_ref = None | ||||
|      | ||||
|         self.comput_config = { | ||||
|             "n_jobs": n_jobs, | ||||
|             "device": device, | ||||
|             "batch_size": batch_size, | ||||
|         } | ||||
|  | ||||
|         self.task_evaluator = {'meta_taskname': dataset_infos.task, 'sas': None, 'scs': None} | ||||
|         for cur_task in dataset_infos.task.split("-")[:]: | ||||
|             # print('loading evaluator for task', cur_task) | ||||
|             model_path = os.path.join( | ||||
|                 dataset_infos.base_path, "data/evaluator", f"{cur_task}.joblib" | ||||
|             ) | ||||
|             os.makedirs(os.path.dirname(model_path), exist_ok=True) | ||||
|             evaluator = TaskModel(model_path, cur_task) | ||||
|             self.task_evaluator[cur_task] = evaluator | ||||
|  | ||||
|     def forward(self, molecules, targets, name, current_epoch, val_counter, test=False): | ||||
|         if isinstance(targets, list): | ||||
|             targets_cat = torch.cat(targets, dim=0) | ||||
|             targets_np = targets_cat.detach().cpu().numpy() | ||||
|         else: | ||||
|             targets_np = targets.detach().cpu().numpy() | ||||
|  | ||||
|         unique_smiles, all_smiles, all_metrics, targets_log = compute_molecular_metrics( | ||||
|             molecules, | ||||
|             targets_np, | ||||
|             self.train_smiles, | ||||
|             self.stat_ref, | ||||
|             self.dataset_infos, | ||||
|             self.task_evaluator, | ||||
|             self.comput_config, | ||||
|         ) | ||||
|  | ||||
|         if test: | ||||
|             file_name = "final_smiles.txt" | ||||
|             with open(file_name, "w") as fp: | ||||
|                 all_tasks_name = list(self.task_evaluator.keys()) | ||||
|                 all_tasks_name = all_tasks_name.copy() | ||||
|                 if 'meta_taskname' in all_tasks_name: | ||||
|                     all_tasks_name.remove('meta_taskname') | ||||
|                 if 'scs' in all_tasks_name: | ||||
|                     all_tasks_name.remove('scs') | ||||
|  | ||||
|                 all_tasks_str = "smiles, " + ", ".join([f"input_{task}" for task in all_tasks_name] + [f"output_{task}" for task in all_tasks_name]) | ||||
|                 fp.write(all_tasks_str + "\n") | ||||
|                 for i, smiles in enumerate(all_smiles): | ||||
|                     if targets_log is not None: | ||||
|                         all_result_str = f"{smiles}, " + ", ".join([f"{targets_log['input_'+task][i]}" for task in all_tasks_name] + [f"{targets_log['output_'+task][i]}" for task in all_tasks_name]) | ||||
|                         fp.write(all_result_str + "\n") | ||||
|                     else: | ||||
|                         fp.write("%s\n" % smiles) | ||||
|                 print("All smiles saved") | ||||
|         else: | ||||
|             result_path = os.path.join(os.getcwd(), f"graphs/{name}") | ||||
|             os.makedirs(result_path, exist_ok=True) | ||||
|             text_path = os.path.join( | ||||
|                 result_path, | ||||
|                 f"valid_unique_molecules_e{current_epoch}_b{val_counter}.txt", | ||||
|             ) | ||||
|             textfile = open(text_path, "w") | ||||
|             for smiles in unique_smiles: | ||||
|                 textfile.write(smiles + "\n") | ||||
|             textfile.close() | ||||
|  | ||||
|         all_logs = all_metrics | ||||
|         if test: | ||||
|             all_logs["log_name"] = "test" | ||||
|         else: | ||||
|             all_logs["log_name"] = ( | ||||
|                 "epoch" + str(current_epoch) + "_batch" + str(val_counter) | ||||
|             ) | ||||
|          | ||||
|         result_to_csv("output.csv", all_logs) | ||||
|         return all_smiles | ||||
|  | ||||
|     def reset(self): | ||||
|         pass | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     pass | ||||
							
								
								
									
										126
									
								
								mcd/metrics/molecular_metrics_train.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										126
									
								
								mcd/metrics/molecular_metrics_train.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,126 @@ | ||||
| import torch | ||||
| from torchmetrics import Metric, MetricCollection | ||||
| from torch import Tensor | ||||
| import torch.nn as nn | ||||
|  | ||||
| class CEPerClass(Metric): | ||||
|     full_state_update = False | ||||
|     def __init__(self, class_id): | ||||
|         super().__init__() | ||||
|         self.class_id = class_id | ||||
|         self.add_state('total_ce', default=torch.tensor(0.), dist_reduce_fx="sum") | ||||
|         self.add_state('total_samples', default=torch.tensor(0.), dist_reduce_fx="sum") | ||||
|         self.softmax = torch.nn.Softmax(dim=-1) | ||||
|         self.binary_cross_entropy = torch.nn.BCELoss(reduction='sum') | ||||
|      | ||||
|     def update(self, preds: Tensor, target: Tensor) -> None: | ||||
|         """Update state with predictions and targets. | ||||
|         Args: | ||||
|             preds: Predictions from model   (bs, n, d) or (bs, n, n, d) | ||||
|             target: Ground truth values     (bs, n, d) or (bs, n, n, d) | ||||
|         """ | ||||
|         target = target.reshape(-1, target.shape[-1]) | ||||
|         mask = (target != 0.).any(dim=-1) | ||||
|  | ||||
|         prob = self.softmax(preds)[..., self.class_id] | ||||
|         prob = prob.flatten()[mask] | ||||
|  | ||||
|         target = target[:, self.class_id] | ||||
|         target = target[mask] | ||||
|  | ||||
|         output = self.binary_cross_entropy(prob, target) | ||||
|  | ||||
|         self.total_ce += output | ||||
|         self.total_samples += prob.numel() | ||||
|  | ||||
|     def compute(self): | ||||
|         return self.total_ce / self.total_samples | ||||
|  | ||||
|  | ||||
| class AtomCE(CEPerClass): | ||||
|     def __init__(self, i): | ||||
|         super().__init__(i) | ||||
|  | ||||
| class NoBondCE(CEPerClass): | ||||
|     def __init__(self, i): | ||||
|         super().__init__(i) | ||||
|  | ||||
|  | ||||
| class SingleCE(CEPerClass): | ||||
|     def __init__(self, i): | ||||
|         super().__init__(i) | ||||
|  | ||||
|  | ||||
| class DoubleCE(CEPerClass): | ||||
|     def __init__(self, i): | ||||
|         super().__init__(i) | ||||
|  | ||||
|  | ||||
| class TripleCE(CEPerClass): | ||||
|     def __init__(self, i): | ||||
|         super().__init__(i) | ||||
|  | ||||
|  | ||||
| class AromaticCE(CEPerClass): | ||||
|     def __init__(self, i): | ||||
|         super().__init__(i) | ||||
|  | ||||
|  | ||||
| class AtomMetricsCE(MetricCollection): | ||||
|     def __init__(self, active_atoms): | ||||
|         metrics_list = [] | ||||
|          | ||||
|         for i, atom_type in enumerate(active_atoms): | ||||
|             metrics_list.append(type(f'{atom_type}_CE', (AtomCE,), {})(i)) | ||||
|  | ||||
|         super().__init__(metrics_list) | ||||
|  | ||||
|  | ||||
| class BondMetricsCE(MetricCollection): | ||||
|     def __init__(self): | ||||
|         ce_no_bond = NoBondCE(0) | ||||
|         ce_SI = SingleCE(1) | ||||
|         ce_DO = DoubleCE(2) | ||||
|         ce_TR = TripleCE(3) | ||||
|         super().__init__([ce_no_bond, ce_SI, ce_DO, ce_TR]) | ||||
|  | ||||
|  | ||||
| class TrainMolecularMetricsDiscrete(nn.Module): | ||||
|     def __init__(self, dataset_infos): | ||||
|         super().__init__() | ||||
|         active_atoms = dataset_infos.active_atoms | ||||
|         self.train_atom_metrics = AtomMetricsCE(active_atoms=active_atoms) | ||||
|         self.train_bond_metrics = BondMetricsCE() | ||||
|  | ||||
|     def forward(self, masked_pred_X, masked_pred_E, true_X, true_E, log: bool): | ||||
|         self.train_atom_metrics(masked_pred_X, true_X) | ||||
|         self.train_bond_metrics(masked_pred_E, true_E) | ||||
|         if log: | ||||
|             to_log = {} | ||||
|             for key, val in self.train_atom_metrics.compute().items(): | ||||
|                 to_log['train/' + key] = val.item() | ||||
|             for key, val in self.train_bond_metrics.compute().items(): | ||||
|                 to_log['train/' + key] = val.item() | ||||
|  | ||||
|     def reset(self): | ||||
|         for metric in [self.train_atom_metrics, self.train_bond_metrics]: | ||||
|             metric.reset() | ||||
|  | ||||
|     def log_epoch_metrics(self, current_epoch, log=True): | ||||
|         epoch_atom_metrics = self.train_atom_metrics.compute() | ||||
|         epoch_bond_metrics = self.train_bond_metrics.compute() | ||||
|  | ||||
|         to_log = {} | ||||
|         for key, val in epoch_atom_metrics.items(): | ||||
|             to_log['train_epoch/' + key] = val.item() | ||||
|         for key, val in epoch_bond_metrics.items(): | ||||
|             to_log['train_epoch/' + key] = val.item() | ||||
|  | ||||
|         for key, val in epoch_atom_metrics.items(): | ||||
|             epoch_atom_metrics[key] = round(val.item(),4) | ||||
|         for key, val in epoch_bond_metrics.items(): | ||||
|             epoch_bond_metrics[key] = round(val.item(),4) | ||||
|  | ||||
|         if log: | ||||
|             print(f"Epoch {current_epoch}: {epoch_atom_metrics} -- {epoch_bond_metrics}") | ||||
|  | ||||
							
								
								
									
										201
									
								
								mcd/metrics/property_metric.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										201
									
								
								mcd/metrics/property_metric.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,201 @@ | ||||
| import math, os | ||||
| import pickle | ||||
| import os.path as op | ||||
|  | ||||
| import numpy as np | ||||
| import pandas as pd | ||||
| from joblib import dump, load | ||||
| from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor | ||||
| from sklearn.metrics import mean_absolute_error, roc_auc_score | ||||
|  | ||||
|  | ||||
| from rdkit import Chem | ||||
| from rdkit import rdBase | ||||
| from rdkit.Chem import AllChem | ||||
| from rdkit import DataStructs | ||||
| from rdkit.Chem import rdMolDescriptors | ||||
| rdBase.DisableLog('rdApp.error') | ||||
|  | ||||
| task_to_colname = { | ||||
|     'hiv_b': 'HIV_active', | ||||
|     'bace_b': 'Class', | ||||
|     'bbbp_b': 'p_np', | ||||
|     'O2': 'O2', | ||||
|     'N2': 'N2', | ||||
|     'CO2': 'CO2', | ||||
| } | ||||
|  | ||||
| tasktype_name = { | ||||
|     'hiv_b': 'classification', | ||||
|     'bace_b': 'classification', | ||||
|     'bbbp_b': 'classification', | ||||
|     'O2': 'regression', | ||||
|     'N2': 'regression', | ||||
|     'CO2': 'regression', | ||||
| } | ||||
|  | ||||
| class TaskModel(): | ||||
|     """Scores based on an ECFP classifier.""" | ||||
|     def __init__(self, model_path, task_name): | ||||
|         task_type = tasktype_name[task_name] | ||||
|         self.task_name = task_name | ||||
|         self.task_type = task_type | ||||
|         self.model_path = model_path | ||||
|         self.metric_func = roc_auc_score if 'classification' in self.task_type else mean_absolute_error | ||||
|  | ||||
|         try: | ||||
|             self.model = load(model_path) | ||||
|             print(self.task_name, ' evaluator loaded') | ||||
|         except: | ||||
|             print(self.task_name, ' evaluator not found, training new one...') | ||||
|             if 'classification' in task_type: | ||||
|                 self.model = RandomForestClassifier(random_state=0) | ||||
|             elif 'regression' in task_type: | ||||
|                 self.model = RandomForestRegressor(random_state=0) | ||||
|             perfermance = self.train() | ||||
|             dump(self.model, model_path) | ||||
|             print('Oracle peformance: ', perfermance) | ||||
|  | ||||
|     def train(self): | ||||
|         data_path = os.path.dirname(self.model_path) | ||||
|         data_path = os.path.join(os.path.dirname(self.model_path), '..', f'raw/{self.task_name}.csv.gz') | ||||
|         df = pd.read_csv(data_path) | ||||
|         col_name = task_to_colname[self.task_name] | ||||
|         y = df[col_name].to_numpy() | ||||
|         x_smiles = df['smiles'].to_numpy() | ||||
|         mask = ~np.isnan(y) | ||||
|         y = y[mask] | ||||
|  | ||||
|         if 'classification' in self.task_type: | ||||
|             y = y.astype(int) | ||||
|  | ||||
|         x_smiles = x_smiles[mask] | ||||
|         x_fps = [] | ||||
|         mask = [] | ||||
|         for i,smiles in enumerate(x_smiles): | ||||
|             mol = Chem.MolFromSmiles(smiles) | ||||
|             mask.append( int(mol is not None) ) | ||||
|             fp = TaskModel.fingerprints_from_mol(mol) if mol else np.zeros((1, 2048)) | ||||
|             x_fps.append(fp) | ||||
|         x_fps = np.concatenate(x_fps, axis=0) | ||||
|         self.model.fit(x_fps, y) | ||||
|         y_pred = self.model.predict(x_fps) | ||||
|         perf = self.metric_func(y, y_pred) | ||||
|         print(f'{self.task_name} performance: {perf}') | ||||
|         return perf | ||||
|  | ||||
|     def __call__(self, smiles_list): | ||||
|         fps = [] | ||||
|         mask = [] | ||||
|         for i,smiles in enumerate(smiles_list): | ||||
|             mol = Chem.MolFromSmiles(smiles) | ||||
|             mask.append( int(mol is not None) ) | ||||
|             fp = TaskModel.fingerprints_from_mol(mol) if mol else np.zeros((1, 2048)) | ||||
|             fps.append(fp) | ||||
|  | ||||
|         fps = np.concatenate(fps, axis=0) | ||||
|         if 'classification' in self.task_type: | ||||
|             scores = self.model.predict_proba(fps)[:, 1] | ||||
|         else: | ||||
|             scores = self.model.predict(fps) | ||||
|         scores = scores * np.array(mask) | ||||
|         return np.float32(scores) | ||||
|  | ||||
|     @classmethod | ||||
|     def fingerprints_from_mol(cls, mol):  # use ECFP4 | ||||
|         features_vec = AllChem.GetMorganFingerprintAsBitVect(mol, 2, nBits=2048) | ||||
|         features = np.zeros((1,)) | ||||
|         DataStructs.ConvertToNumpyArray(features_vec, features) | ||||
|         return features.reshape(1, -1) | ||||
|  | ||||
| ###### SAS Score ###### | ||||
| _fscores = None | ||||
|  | ||||
| def readFragmentScores(name='fpscores'): | ||||
|     import gzip | ||||
|     global _fscores | ||||
|     # generate the full path filename: | ||||
|     if name == "fpscores": | ||||
|         name = op.join(op.dirname(__file__), name) | ||||
|     data = pickle.load(gzip.open('%s.pkl.gz' % name)) | ||||
|     outDict = {} | ||||
|     for i in data: | ||||
|         for j in range(1, len(i)): | ||||
|             outDict[i[j]] = float(i[0]) | ||||
|     _fscores = outDict | ||||
|  | ||||
| def numBridgeheadsAndSpiro(mol, ri=None): | ||||
|     nSpiro = rdMolDescriptors.CalcNumSpiroAtoms(mol) | ||||
|     nBridgehead = rdMolDescriptors.CalcNumBridgeheadAtoms(mol) | ||||
|     return nBridgehead, nSpiro | ||||
|  | ||||
| def calculateSAS(smiles_list): | ||||
|     scores = [] | ||||
|     for i, smiles in enumerate(smiles_list): | ||||
|         mol = Chem.MolFromSmiles(smiles) | ||||
|         score = calculateScore(mol) | ||||
|         scores.append(score) | ||||
|     return np.float32(scores) | ||||
|  | ||||
| def calculateScore(m): | ||||
|     if _fscores is None: | ||||
|         readFragmentScores() | ||||
|  | ||||
|     # fragment score | ||||
|     fp = rdMolDescriptors.GetMorganFingerprint(m, | ||||
|                                                2)  # <- 2 is the *radius* of the circular fingerprint | ||||
|     fps = fp.GetNonzeroElements() | ||||
|     score1 = 0. | ||||
|     nf = 0 | ||||
|     for bitId, v in fps.items(): | ||||
|         nf += v | ||||
|         sfp = bitId | ||||
|         score1 += _fscores.get(sfp, -4) * v | ||||
|     score1 /= nf | ||||
|  | ||||
|     # features score | ||||
|     nAtoms = m.GetNumAtoms() | ||||
|     nChiralCenters = len(Chem.FindMolChiralCenters(m, includeUnassigned=True)) | ||||
|     ri = m.GetRingInfo() | ||||
|     nBridgeheads, nSpiro = numBridgeheadsAndSpiro(m, ri) | ||||
|     nMacrocycles = 0 | ||||
|     for x in ri.AtomRings(): | ||||
|         if len(x) > 8: | ||||
|             nMacrocycles += 1 | ||||
|  | ||||
|     sizePenalty = nAtoms**1.005 - nAtoms | ||||
|     stereoPenalty = math.log10(nChiralCenters + 1) | ||||
|     spiroPenalty = math.log10(nSpiro + 1) | ||||
|     bridgePenalty = math.log10(nBridgeheads + 1) | ||||
|     macrocyclePenalty = 0. | ||||
|     # --------------------------------------- | ||||
|     # This differs from the paper, which defines: | ||||
|     #  macrocyclePenalty = math.log10(nMacrocycles+1) | ||||
|     # This form generates better results when 2 or more macrocycles are present | ||||
|     if nMacrocycles > 0: | ||||
|         macrocyclePenalty = math.log10(2) | ||||
|  | ||||
|     score2 = 0. - sizePenalty - stereoPenalty - spiroPenalty - bridgePenalty - macrocyclePenalty | ||||
|  | ||||
|     # correction for the fingerprint density | ||||
|     # not in the original publication, added in version 1.1 | ||||
|     # to make highly symmetrical molecules easier to synthetise | ||||
|     score3 = 0. | ||||
|     if nAtoms > len(fps): | ||||
|         score3 = math.log(float(nAtoms) / len(fps)) * .5 | ||||
|  | ||||
|     sascore = score1 + score2 + score3 | ||||
|  | ||||
|     # need to transform "raw" value into scale between 1 and 10 | ||||
|     min = -4.0 | ||||
|     max = 2.5 | ||||
|     sascore = 11. - (sascore - min + 1) / (max - min) * 9. | ||||
|     # smooth the 10-end | ||||
|     if sascore > 8.: | ||||
|         sascore = 8. + math.log(sascore + 1. - 9.) | ||||
|     if sascore > 10.: | ||||
|         sascore = 10.0 | ||||
|     elif sascore < 1.: | ||||
|         sascore = 1.0 | ||||
|  | ||||
|     return sascore | ||||
							
								
								
									
										94
									
								
								mcd/metrics/train_loss.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										94
									
								
								mcd/metrics/train_loss.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,94 @@ | ||||
| import time | ||||
| import torch | ||||
| import torch.nn as nn | ||||
| from metrics.abstract_metrics import CrossEntropyMetric | ||||
| from torchmetrics import Metric, MeanSquaredError | ||||
|  | ||||
| # from 2:He to 119:* | ||||
| valencies_check = [0, 1, 2, 3, 4, 3, 2, 1, 0, 1, 2, 6, 6, 7, 6, 1, 0, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 3, 4, 7, 6, 1, 0, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 3, 4, 7, 6, 5, 6, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 3, 4, 7, 6, 5, 0, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1] | ||||
| valencies_check = torch.tensor(valencies_check) | ||||
|  | ||||
| weight_check = [4.003, 6.941, 9.012, 10.812, 12.011, 14.007, 15.999, 18.998, 20.18, 22.99, 24.305, 26.982, 28.086, 30.974, 32.067, 35.453, 39.948, 39.098, 40.078, 44.956, 47.867, 50.942, 51.996, 54.938, 55.845, 58.933, 58.693, 63.546, 65.39, 69.723, 72.61, 74.922, 78.96, 79.904, 83.8, 85.468, 87.62, 88.906, 91.224, 92.906, 95.94, 98.0, 101.07, 102.906, 106.42, 107.868, 112.412, 114.818, 118.711, 121.76, 127.6, 126.904, 131.29, 132.905, 137.328, 138.906, 140.116, 140.908, 144.24, 145.0, 150.36, 151.964, 157.25, 158.925, 162.5, 164.93, 167.26, 168.934, 173.04, 174.967, 178.49, 180.948, 183.84, 186.207, 190.23, 192.217, 195.078, 196.967, 200.59, 204.383, 207.2, 208.98, 209.0, 210.0, 222.0, 223.0, 226.0, 227.0, 232.038, 231.036, 238.029, 237.0, 244.0, 243.0, 247.0, 247.0, 251.0, 252.0, 257.0, 258.0, 259.0, 262.0, 267.0, 268.0, 269.0, 270.0, 269.0, 278.0, 281.0, 281.0, 285.0, 284.0, 289.0, 288.0, 293.0, 292.0, 294.0, 294.0] | ||||
| weight_check = torch.tensor(weight_check) | ||||
|  | ||||
| class AtomWeightMetric(Metric): | ||||
|     def __init__(self): | ||||
|         super().__init__() | ||||
|         self.add_state('total_loss', default=torch.tensor(0.), dist_reduce_fx="sum") | ||||
|         self.add_state('total_samples', default=torch.tensor(0.), dist_reduce_fx="sum") | ||||
|         global weight_check | ||||
|         self.weight_check = weight_check | ||||
|  | ||||
|     def update(self, X, Y): | ||||
|         atom_pred_num = X.argmax(dim=-1) | ||||
|         atom_real_num = Y.argmax(dim=-1) | ||||
|         self.weight_check = self.weight_check.type_as(X) | ||||
|  | ||||
|         pred_weight = self.weight_check[atom_pred_num] | ||||
|         real_weight = self.weight_check[atom_real_num] | ||||
|  | ||||
|         lss = 0 | ||||
|         lss += torch.abs(pred_weight.sum(dim=-1) - real_weight.sum(dim=-1)).sum() | ||||
|         self.total_loss += lss | ||||
|         self.total_samples += X.size(0) | ||||
|  | ||||
|     def compute(self): | ||||
|         return self.total_loss / self.total_samples | ||||
|  | ||||
|  | ||||
| class TrainLossDiscrete(nn.Module): | ||||
|     """ Train with Cross entropy""" | ||||
|     def __init__(self, lambda_train, weight_node=None, weight_edge=None): | ||||
|         super().__init__() | ||||
|         self.node_loss = CrossEntropyMetric() | ||||
|         self.edge_loss = CrossEntropyMetric() | ||||
|         self.weight_loss = AtomWeightMetric() | ||||
|  | ||||
|         self.y_loss = MeanSquaredError() | ||||
|         self.lambda_train = lambda_train | ||||
|  | ||||
|     def forward(self, masked_pred_X, masked_pred_E, pred_y, true_X, true_E, true_y, node_mask, log: bool): | ||||
|         """ Compute train metrics | ||||
|         masked_pred_X : tensor -- (bs, n, dx) | ||||
|         masked_pred_E : tensor -- (bs, n, n, de) | ||||
|         pred_y : tensor -- (bs, ) | ||||
|         true_X : tensor -- (bs, n, dx) | ||||
|         true_E : tensor -- (bs, n, n, de) | ||||
|         true_y : tensor -- (bs, ) | ||||
|         log : boolean. """ | ||||
|  | ||||
|         loss_weight = self.weight_loss(masked_pred_X, true_X) | ||||
|          | ||||
|         true_X = torch.reshape(true_X, (-1, true_X.size(-1)))  # (bs * n, dx) | ||||
|         true_E = torch.reshape(true_E, (-1, true_E.size(-1)))  # (bs * n * n, de) | ||||
|         masked_pred_X = torch.reshape(masked_pred_X, (-1, masked_pred_X.size(-1)))  # (bs * n, dx) | ||||
|         masked_pred_E = torch.reshape(masked_pred_E, (-1, masked_pred_E.size(-1)))   # (bs * n * n, de) | ||||
|  | ||||
|         # Remove masked rows | ||||
|         mask_X = (true_X != 0.).any(dim=-1) | ||||
|         mask_E = (true_E != 0.).any(dim=-1) | ||||
|  | ||||
|         flat_true_X = true_X[mask_X, :] | ||||
|         flat_pred_X = masked_pred_X[mask_X, :] | ||||
|  | ||||
|         flat_true_E = true_E[mask_E, :] | ||||
|         flat_pred_E = masked_pred_E[mask_E, :] | ||||
|          | ||||
|         loss_X = self.node_loss(flat_pred_X, flat_true_X) if true_X.numel() > 0 else 0.0 | ||||
|         loss_E = self.edge_loss(flat_pred_E, flat_true_E) if true_E.numel() > 0 else 0.0 | ||||
|  | ||||
|         return self.lambda_train[0] * loss_X + self.lambda_train[1] * loss_E + loss_weight | ||||
|  | ||||
|     def reset(self): | ||||
|         for metric in [self.node_loss, self.edge_loss, self.y_loss]: | ||||
|             metric.reset() | ||||
|  | ||||
|     def log_epoch_metrics(self, current_epoch, start_epoch_time, log=True): | ||||
|         epoch_node_loss = self.node_loss.compute() if self.node_loss.total_samples > 0 else -1 | ||||
|         epoch_edge_loss = self.edge_loss.compute() if self.edge_loss.total_samples > 0 else -1 | ||||
|         epoch_weight_loss = self.weight_loss.compute() if self.weight_loss.total_samples > 0 else -1 | ||||
|  | ||||
|         if log: | ||||
|             print(f"Epoch {current_epoch} finished: X_CE: {epoch_node_loss :.4f} -- E_CE: {epoch_edge_loss :.4f} " | ||||
|                 f"Weight: {epoch_weight_loss :.4f} " | ||||
|                 f"-- Time taken {time.time() - start_epoch_time:.1f}s ") | ||||
		Reference in New Issue
	
	Block a user