138 lines
5.0 KiB
Python
138 lines
5.0 KiB
Python
### packages for visualization
|
|
from analysis.rdkit_functions import compute_molecular_metrics
|
|
from mini_moses.metrics.metrics import compute_intermediate_statistics
|
|
from metrics.property_metric import TaskModel
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
|
|
import os
|
|
import csv
|
|
import time
|
|
|
|
def result_to_csv(path, dict_data):
    """Append one row of results to the CSV file at ``path``.

    The ``log_name`` column is always written first; the remaining columns
    follow the iteration order of ``dict_data``. A header row is emitted when
    the file does not exist yet or exists but is empty.

    Args:
        path: Location of the CSV file; opened in append mode.
        dict_data: Mapping of column name to value. Must contain a
            ``"log_name"`` entry whose value is not ``None``. The mapping is
            not modified.

    Raises:
        ValueError: If ``dict_data`` has no usable ``"log_name"`` entry.
    """
    if dict_data.get("log_name") is None:
        raise ValueError("The provided dictionary must contain a 'log_name' key.")

    # Build the column order without popping/re-inserting keys, so the
    # caller's dictionary (and its key order) is left untouched.
    field_names = ["log_name"] + [key for key in dict_data if key != "log_name"]

    # A zero-length file needs a header just as much as a missing one does.
    needs_header = not os.path.exists(path) or os.path.getsize(path) == 0

    with open(path, "a", newline="") as file:
        writer = csv.DictWriter(file, fieldnames=field_names)
        if needs_header:
            writer.writeheader()
        writer.writerow(dict_data)
|
|
|
|
|
|
class SamplingMolecularMetrics(nn.Module):
    """Compute metrics for sampled molecules and write the results to disk.

    At construction the module optionally precomputes reference statistics
    and builds one ``TaskModel`` evaluator per task named in
    ``dataset_infos.task`` (task names are separated by ``-``). Calling the
    module delegates the scoring to ``compute_molecular_metrics`` and then
    stores the SMILES strings and the metric row in files relative to the
    current working directory.
    """

    def __init__(
        self,
        dataset_infos,
        train_smiles,
        reference_smiles,
        n_jobs=1,
        device="cpu",
        batch_size=512,
    ):
        """Set up reference statistics and per-task evaluators.

        Args:
            dataset_infos: Object exposing ``task``, ``active_atoms`` and
                ``base_path`` attributes.
            train_smiles: Stored as-is and forwarded to
                ``compute_molecular_metrics`` on every call.
            reference_smiles: Collection of SMILES used to compute the
                reference statistics, or ``None`` to skip that step.
            n_jobs: Forwarded to ``compute_intermediate_statistics`` and kept
                in ``self.comput_config``.
            device: Forwarded like ``n_jobs``.
            batch_size: Forwarded like ``n_jobs``.
        """
        super().__init__()
        self.task_name = dataset_infos.task
        self.dataset_infos = dataset_infos
        self.active_atoms = dataset_infos.active_atoms
        self.train_smiles = train_smiles

        if reference_smiles is not None:
            print(
                f"--- Computing intermediate statistics for training for #{len(reference_smiles)} smiles ---"
            )
            start_time = time.time()
            self.stat_ref = compute_intermediate_statistics(
                reference_smiles, n_jobs=n_jobs, device=device, batch_size=batch_size
            )
            elapsed_time = time.time() - start_time
            print(
                f"--- End computing intermediate statistics: using {elapsed_time:.2f}s ---"
            )
        else:
            self.stat_ref = None

        self.comput_config = {
            "n_jobs": n_jobs,
            "device": device,
            "batch_size": batch_size,
        }

        # 'meta_taskname', 'sas' and 'scs' are fixed entries; one evaluator
        # per task name is added below.
        self.task_evaluator = {
            "meta_taskname": dataset_infos.task,
            "sas": None,
            "scs": None,
        }
        for cur_task in dataset_infos.task.split("-"):
            model_path = os.path.join(
                dataset_infos.base_path, "data/evaluator", f"{cur_task}.joblib"
            )
            # Make sure the evaluator directory exists before TaskModel is
            # handed the path (NOTE(review): presumably TaskModel loads or
            # saves a model there - confirm against TaskModel).
            os.makedirs(os.path.dirname(model_path), exist_ok=True)
            self.task_evaluator[cur_task] = TaskModel(model_path, cur_task)

    def forward(self, molecules, targets, name, current_epoch, val_counter, test=False):
        """Score ``molecules``, save the SMILES, and append a row to ``output.csv``.

        Args:
            molecules: Passed unchanged to ``compute_molecular_metrics``.
            targets: A tensor, or a list of tensors that is concatenated
                along dimension 0; converted to a NumPy array before scoring.
            name: Sub-directory of ``graphs/`` used for validation output.
            current_epoch: Used in the validation file name and log name.
            val_counter: Used in the validation file name and log name.
            test: When true, write every SMILES (with per-task input/output
                columns when available) to ``final_smiles.txt``; otherwise
                write the unique SMILES to a per-epoch file under
                ``graphs/<name>/``.

        Returns:
            The ``all_smiles`` value produced by ``compute_molecular_metrics``.
        """
        if isinstance(targets, list):
            targets = torch.cat(targets, dim=0)
        targets_np = targets.detach().cpu().numpy()

        unique_smiles, all_smiles, all_metrics, targets_log = compute_molecular_metrics(
            molecules,
            targets_np,
            self.train_smiles,
            self.stat_ref,
            self.dataset_infos,
            self.task_evaluator,
            self.comput_config,
        )

        if test:
            # Every evaluator key except the bookkeeping entry and 'scs'
            # gets an input_/output_ column pair.
            task_names = [
                task
                for task in self.task_evaluator
                if task not in ("meta_taskname", "scs")
            ]
            header = "smiles, " + ", ".join(
                [f"input_{task}" for task in task_names]
                + [f"output_{task}" for task in task_names]
            )
            with open("final_smiles.txt", "w") as fp:
                fp.write(header + "\n")
                for i, smiles in enumerate(all_smiles):
                    if targets_log is not None:
                        # All input columns first, then all output columns,
                        # matching the header order.
                        values = [
                            targets_log[f"{prefix}_{task}"][i]
                            for prefix in ("input", "output")
                            for task in task_names
                        ]
                        row = f"{smiles}, " + ", ".join(f"{value}" for value in values)
                        fp.write(row + "\n")
                    else:
                        fp.write("%s\n" % smiles)
                print("All smiles saved")
        else:
            result_path = os.path.join(os.getcwd(), f"graphs/{name}")
            os.makedirs(result_path, exist_ok=True)
            text_path = os.path.join(
                result_path,
                f"valid_unique_molecules_e{current_epoch}_b{val_counter}.txt",
            )
            # Context manager so the handle is closed even if a write fails.
            with open(text_path, "w") as textfile:
                for smiles in unique_smiles:
                    textfile.write(smiles + "\n")

        # The metrics mapping itself is tagged with the log name and written out.
        all_logs = all_metrics
        if test:
            all_logs["log_name"] = "test"
        else:
            all_logs["log_name"] = (
                "epoch" + str(current_epoch) + "_batch" + str(val_counter)
            )

        result_to_csv("output.csv", all_logs)
        return all_smiles

    def reset(self):
        """Do nothing; this method has no state to clear."""
|
|
|
|
# Running this file directly does nothing; it is meant to be imported.
if __name__ == "__main__":
    pass