# Graph-DiT/graph_dit/metrics/molecular_metrics_sampling.py
# Snapshot metadata: 2024-05-25 15:32:36 -04:00, 138 lines, 5.0 KiB, Python
### packages for visualization
from analysis.rdkit_functions import compute_molecular_metrics
from mini_moses.metrics.metrics import compute_intermediate_statistics
from metrics.property_metric import TaskModel
import torch
import torch.nn as nn
import os
import csv
import time
def result_to_csv(path, dict_data):
    """Append one result row to a CSV file, writing the header if needed.

    Args:
        path: Path of the CSV file to append to (created if missing).
        dict_data: Mapping of column name -> value. Must contain a
            'log_name' key, which becomes the first column. The mapping is
            mutated in place ('log_name' is popped and re-inserted).

    Raises:
        ValueError: If 'log_name' is missing from ``dict_data``.
    """
    log_name = dict_data.pop("log_name", None)
    if log_name is None:
        raise ValueError("The provided dictionary must contain a 'log_name' key.")
    # Put 'log_name' first; remaining columns keep dict insertion order.
    field_names = ["log_name"] + list(dict_data.keys())
    dict_data["log_name"] = log_name
    # Write the header when the file is missing OR exists but is empty
    # (e.g. pre-created by `touch` or left behind by a crashed run) —
    # checking existence alone would silently produce a headerless CSV.
    write_header = not os.path.exists(path) or os.path.getsize(path) == 0
    with open(path, "a", newline="") as file:
        writer = csv.DictWriter(file, fieldnames=field_names)
        if write_header:
            writer.writeheader()
        writer.writerow(dict_data)
class SamplingMolecularMetrics(nn.Module):
    """Evaluates molecules sampled from the model.

    Wraps ``compute_molecular_metrics`` with per-task property evaluators
    and (optionally) precomputed reference statistics, and logs results to
    ``output.csv`` plus per-epoch SMILES text files.
    """

    def __init__(
        self,
        dataset_infos,
        train_smiles,
        reference_smiles,
        n_jobs=1,
        device="cpu",
        batch_size=512,
    ):
        """
        Args:
            dataset_infos: Dataset metadata; must expose ``task``,
                ``active_atoms`` and ``base_path``.
            train_smiles: Training-set SMILES (used for novelty checks
                downstream in ``compute_molecular_metrics``).
            reference_smiles: SMILES used to precompute intermediate
                distribution statistics, or ``None`` to skip.
            n_jobs: Worker count passed to the metric computations.
            device: Torch device string for the metric computations.
            batch_size: Batch size for the metric computations.
        """
        super().__init__()
        self.task_name = dataset_infos.task
        self.dataset_infos = dataset_infos
        self.active_atoms = dataset_infos.active_atoms
        self.train_smiles = train_smiles
        if reference_smiles is not None:
            print(
                f"--- Computing intermediate statistics for training for #{len(reference_smiles)} smiles ---"
            )
            start_time = time.time()
            self.stat_ref = compute_intermediate_statistics(
                reference_smiles, n_jobs=n_jobs, device=device, batch_size=batch_size
            )
            elapsed_time = time.time() - start_time
            print(
                f"--- End computing intermediate statistics: using {elapsed_time:.2f}s ---"
            )
        else:
            self.stat_ref = None
        self.comput_config = {
            "n_jobs": n_jobs,
            "device": device,
            "batch_size": batch_size,
        }
        # Task name -> TaskModel evaluator. 'meta_taskname' stores the raw
        # (possibly multi-part) task string; 'sas'/'scs' are placeholders
        # filled elsewhere (scs is excluded from the final-SMILES header below).
        self.task_evaluator = {'meta_taskname': dataset_infos.task, 'sas': None, 'scs': None}
        # Multi-task names are '-'-joined, e.g. "taskA-taskB".
        for cur_task in dataset_infos.task.split("-"):
            model_path = os.path.join(
                dataset_infos.base_path, "data/evaluator", f"{cur_task}.joblib"
            )
            os.makedirs(os.path.dirname(model_path), exist_ok=True)
            self.task_evaluator[cur_task] = TaskModel(model_path, cur_task)

    def forward(self, molecules, targets, name, current_epoch, val_counter, test=False):
        """Compute metrics for sampled molecules and log them.

        Args:
            molecules: Sampled molecules (format expected by
                ``compute_molecular_metrics``).
            targets: Conditioning targets; a tensor or a list of tensors
                (concatenated along dim 0).
            name: Run name used for the output directory ``graphs/{name}``.
            current_epoch: Epoch index used in log/file names.
            val_counter: Validation-batch counter used in log/file names.
            test: If True, write all SMILES (with input/output properties)
                to ``final_smiles.txt``; otherwise write unique valid SMILES
                to a per-epoch file under ``graphs/{name}``.

        Returns:
            The list of all generated SMILES.
        """
        if isinstance(targets, list):
            targets_cat = torch.cat(targets, dim=0)
            targets_np = targets_cat.detach().cpu().numpy()
        else:
            targets_np = targets.detach().cpu().numpy()
        unique_smiles, all_smiles, all_metrics, targets_log = compute_molecular_metrics(
            molecules,
            targets_np,
            self.train_smiles,
            self.stat_ref,
            self.dataset_infos,
            self.task_evaluator,
            self.comput_config,
        )
        if test:
            file_name = "final_smiles.txt"
            with open(file_name, "w") as fp:
                # Header lists input_/output_ columns per real task;
                # 'meta_taskname' and 'scs' are bookkeeping keys, not tasks.
                all_tasks_name = list(self.task_evaluator.keys())
                if 'meta_taskname' in all_tasks_name:
                    all_tasks_name.remove('meta_taskname')
                if 'scs' in all_tasks_name:
                    all_tasks_name.remove('scs')
                all_tasks_str = "smiles, " + ", ".join([f"input_{task}" for task in all_tasks_name] + [f"output_{task}" for task in all_tasks_name])
                fp.write(all_tasks_str + "\n")
                for i, smiles in enumerate(all_smiles):
                    if targets_log is not None:
                        all_result_str = f"{smiles}, " + ", ".join([f"{targets_log['input_'+task][i]}" for task in all_tasks_name] + [f"{targets_log['output_'+task][i]}" for task in all_tasks_name])
                        fp.write(all_result_str + "\n")
                    else:
                        fp.write("%s\n" % smiles)
                print("All smiles saved")
        else:
            result_path = os.path.join(os.getcwd(), f"graphs/{name}")
            os.makedirs(result_path, exist_ok=True)
            text_path = os.path.join(
                result_path,
                f"valid_unique_molecules_e{current_epoch}_b{val_counter}.txt",
            )
            # Context manager guarantees the handle closes even if a write
            # raises (the original used a bare open()/close() pair).
            with open(text_path, "w") as textfile:
                for smiles in unique_smiles:
                    textfile.write(smiles + "\n")
        all_logs = all_metrics
        if test:
            all_logs["log_name"] = "test"
        else:
            all_logs["log_name"] = (
                "epoch" + str(current_epoch) + "_batch" + str(val_counter)
            )
        result_to_csv("output.csv", all_logs)
        return all_smiles

    def reset(self):
        """No internal accumulators to clear; kept for interface parity."""
        pass
# No standalone behavior: this module is meant to be imported, not run directly.
if __name__ == "__main__":
    pass