223 lines
8.7 KiB
Python
223 lines
8.7 KiB
Python
import os
|
|
|
|
from rdkit import Chem
|
|
from rdkit.Chem import Draw, AllChem
|
|
from rdkit.Geometry import Point3D
|
|
from rdkit import RDLogger
|
|
import imageio
|
|
import networkx as nx
|
|
import numpy as np
|
|
import rdkit.Chem
|
|
import matplotlib.pyplot as plt
|
|
|
|
|
|
class MolecularVisualization:
    """Draw molecular graphs (atom-type ids + bond-code adjacency) with RDKit."""

    def __init__(self, dataset_infos):
        # Only ``dataset_infos.atom_decoder`` is read by this class (in
        # mol_from_graphs), to map integer atom ids to element symbols.
        self.dataset_infos = dataset_infos

    @staticmethod
    def _gif_path(path):
        """Return the GIF path that sits next to the frame directory ``path``.

        ``"out/chain_0"`` and ``"out/chain_0/"`` both map to
        ``"out/chain_0.gif"``; normalizing first avoids producing
        ``"out/chain_0/.gif"`` when ``path`` has a trailing separator.
        """
        normalized = os.path.normpath(path)
        return os.path.join(os.path.dirname(normalized),
                            '{}.gif'.format(os.path.basename(normalized)))

    @staticmethod
    def _bond_type(bond):
        """Map an adjacency-matrix entry to an RDKit bond type, or None for "no bond"."""
        # Plain equality comparisons (not a dict lookup) so any scalar type
        # that supports ``== int`` keeps working exactly as before.
        if bond == 1:
            return Chem.rdchem.BondType.SINGLE
        if bond == 2:
            return Chem.rdchem.BondType.DOUBLE
        if bond == 3:
            return Chem.rdchem.BondType.TRIPLE
        if bond == 4:
            return Chem.rdchem.BondType.AROMATIC
        return None

    def mol_from_graphs(self, node_list, adjacency_matrix):
        """Convert one graph to an RDKit molecule.

        Args:
            node_list: sequence of n integer atom-type ids for a single graph;
                entries equal to -1 are padding and are not added as atoms.
            adjacency_matrix: n x n matrix of bond codes (1 single, 2 double,
                3 triple, 4 aromatic; any other value means "no bond"). Only
                the strict upper triangle is read, so it is assumed symmetric.

        Returns:
            The resulting molecule, or None if ``GetMol`` raises a
            KekulizeException.
        """
        atom_decoder = self.dataset_infos.atom_decoder

        mol = Chem.RWMol()
        # graph node index -> RDKit atom index (they diverge once padding is skipped)
        node_to_idx = {}
        for i, node in enumerate(node_list):
            if node == -1:
                continue
            node_to_idx[i] = mol.AddAtom(Chem.Atom(atom_decoder[int(node)]))

        for ix, row in enumerate(adjacency_matrix):
            for iy in range(ix + 1, len(row)):
                bond_type = self._bond_type(row[iy])
                if bond_type is None:
                    continue
                # A bond touching a padding node has no atom to attach to;
                # skip it instead of failing with a KeyError.
                if ix not in node_to_idx or iy not in node_to_idx:
                    continue
                mol.AddBond(node_to_idx[ix], node_to_idx[iy], bond_type)

        try:
            mol = mol.GetMol()
        except Chem.KekulizeException:
            print("Can't kekulize molecule")
            mol = None
        return mol

    def visualize(self, path: str, molecules: list, num_molecules_to_visualize: int, log='graph'):
        """Save ``molecule_<i>.png`` drawings of the first molecules into ``path``.

        Args:
            path: output directory; created if missing.
            molecules: list of (node_types, adjacency_matrix) pairs whose items
                expose ``.numpy()``.
            num_molecules_to_visualize: how many to draw; clamped to
                ``len(molecules)``.
            log: unused; kept for interface compatibility.
        """
        os.makedirs(path, exist_ok=True)

        print(f"Visualizing {num_molecules_to_visualize} of {len(molecules)}")
        if num_molecules_to_visualize > len(molecules):
            print(f"Shortening to {len(molecules)}")
            num_molecules_to_visualize = len(molecules)

        for i in range(num_molecules_to_visualize):
            file_path = os.path.join(path, 'molecule_{}.png'.format(i))
            mol = self.mol_from_graphs(molecules[i][0].numpy(), molecules[i][1].numpy())
            if mol is None:
                # mol_from_graphs already reported the failure; nothing to draw.
                continue
            try:
                Draw.MolToFile(mol, file_path)
            except Chem.KekulizeException:
                print("Can't kekulize molecule")

    @staticmethod
    def _align_to_final(mols):
        """Give every molecule in ``mols`` the 2D atom coordinates of the last one.

        Atom j of each molecule is moved to the position of atom j of the final
        molecule; atoms beyond the final molecule's atom count keep their own
        computed coordinates. None entries are skipped, and nothing is done if
        the list is empty or the final molecule is None.
        """
        final_molecule = mols[-1] if mols else None
        if final_molecule is None:
            return

        AllChem.Compute2DCoords(final_molecule)
        final_conf = final_molecule.GetConformer()
        # Copy to plain tuples: the final molecule's conformer is recomputed
        # below when the loop reaches it.
        coords = []
        for i in range(final_molecule.GetNumAtoms()):
            position = final_conf.GetAtomPosition(i)
            coords.append((position.x, position.y, position.z))

        for mol in mols:
            if mol is None:
                continue
            AllChem.Compute2DCoords(mol)
            conf = mol.GetConformer()
            for j in range(min(mol.GetNumAtoms(), len(coords))):
                x, y, z = coords[j]
                conf.SetAtomPosition(j, Point3D(x, y, z))

    def visualize_chain(self, path, nodes_list, adjacency_matrix, trainer=None):
        """Render a chain of graphs as PNG frames, an animated GIF and a grid image.

        Every frame reuses the 2D coordinates of the final molecule so atoms
        stay in place across the animation.

        Args:
            path: directory receiving ``fram_<k>.png`` and
                ``<dirname>_grid_image.png``; created if missing. The GIF is
                written next to it as ``<dirname>.gif``.
            nodes_list: per-frame atom-type ids; ``nodes_list.shape[0]`` is the
                number of frames.
            adjacency_matrix: per-frame bond-code matrices, indexed like
                ``nodes_list``.
            trainer: unused; kept for interface compatibility.

        Returns:
            The list of per-frame molecules (an entry is None if its
            conversion failed; such frames are not drawn as PNG/GIF frames).

        Note: disables all RDKit log output (``rdApp.*``) via RDLogger.
        """
        RDLogger.DisableLog('rdApp.*')
        os.makedirs(path, exist_ok=True)

        mols = [self.mol_from_graphs(nodes_list[i], adjacency_matrix[i])
                for i in range(nodes_list.shape[0])]

        self._align_to_final(mols)

        save_paths = []
        for frame, mol in enumerate(mols):
            if mol is None:
                continue
            file_name = os.path.join(path, 'fram_{}.png'.format(frame))
            Draw.MolToFile(mol, file_name, size=(300, 300), legend=f"Frame {frame}")
            save_paths.append(file_name)

        if save_paths:
            imgs = [imageio.imread(fn) for fn in save_paths]
            # Repeat the last frame 10 more times (2 s at fps=5).
            imgs.extend([imgs[-1]] * 10)
            imageio.mimsave(self._gif_path(path), imgs, subrectangles=True, fps=5)

        try:
            img = Draw.MolsToGridImage(mols, molsPerRow=10, subImgSize=(200, 200))
            grid_name = '{}_grid_image.png'.format(os.path.basename(os.path.normpath(path)))
            img.save(os.path.join(path, grid_name))
        except Chem.KekulizeException:
            print("Can't kekulize molecule")

        return mols

    def visualize_by_smiles(self, path: str, smiles_list: list, num_to_visualize: int):
        """Save ``molecule_corrected_<i>.png`` drawings for the first SMILES strings.

        Args:
            path: output directory; created if missing.
            smiles_list: SMILES strings; None entries are skipped.
            num_to_visualize: how many entries to process; clamped to
                ``len(smiles_list)``.

        Strings that RDKit cannot parse are reported and skipped. The index
        ``i`` in each file name always matches the position in ``smiles_list``.
        """
        os.makedirs(path, exist_ok=True)

        print(f"Visualizing corrected {num_to_visualize} of {len(smiles_list)}")
        if num_to_visualize > len(smiles_list):
            print(f"Shortening to {len(smiles_list)}")
            num_to_visualize = len(smiles_list)

        for i in range(num_to_visualize):
            smiles = smiles_list[i]
            if smiles is None:
                continue
            mol = Chem.MolFromSmiles(smiles)
            if mol is None:
                # MolFromSmiles returns None for unparsable input; drawing it would fail.
                print(f"Can't parse SMILES: {smiles}")
                continue
            file_path = os.path.join(path, 'molecule_corrected_{}.png'.format(i))
            try:
                Draw.MolToFile(mol, file_path)
            except Chem.KekulizeException:
                print("Can't kekulize molecule")
|
|
|
|
class NonMolecularVisualization:
    """Draw generic (non-molecular) graphs with networkx and matplotlib."""

    @staticmethod
    def _gif_path(path):
        """Return the GIF path that sits next to the frame directory ``path``.

        ``"out/chain_0"`` and ``"out/chain_0/"`` both map to
        ``"out/chain_0.gif"``; normalizing first avoids producing
        ``"out/chain_0/.gif"`` when ``path`` has a trailing separator.
        """
        normalized = os.path.normpath(path)
        return os.path.join(os.path.dirname(normalized),
                            '{}.gif'.format(os.path.basename(normalized)))

    def to_networkx(self, node_list, adjacency_matrix):
        """Convert one graph to an undirected ``networkx.Graph``.

        Args:
            node_list: sequence of n node labels for a single graph; entries
                equal to -1 are padding and are not added.
            adjacency_matrix: n x n array of edge types; every entry >= 1
                becomes an edge.

        Returns:
            A graph whose nodes carry ``number``, ``symbol`` and ``color_val``
            attributes and whose edges carry ``color`` (edge type as float)
            and ``weight`` (3 * edge type).
        """
        graph = nx.Graph()

        for i, node in enumerate(node_list):
            if node == -1:
                continue
            graph.add_node(i, number=i, symbol=node, color_val=node)

        rows, cols = np.where(adjacency_matrix >= 1)
        for row, col in zip(rows.tolist(), cols.tolist()):
            # add_edge would silently re-create padding nodes without their
            # attributes; keep only edges between real nodes.
            if row not in graph or col not in graph:
                continue
            edge_type = adjacency_matrix[row][col]
            graph.add_edge(row, col, color=float(edge_type), weight=3 * edge_type)

        return graph

    def visualize_non_molecule(self, graph, pos, path, iterations=100, node_size=100, largest_component=False):
        """Draw ``graph`` and save the figure to ``path``.

        Nodes are coloured by the eigenvector of the second-smallest eigenvalue
        of the normalized Laplacian, on a colour scale symmetric around 0.
        Graphs with fewer than two nodes (no such eigenvector) use a neutral
        colour.

        Args:
            graph: networkx graph to draw.
            pos: mapping node -> position, or None to compute a spring layout.
            path: output image path passed to ``plt.savefig``.
            iterations: spring-layout iterations, used only when ``pos`` is None.
            node_size: node marker size.
            largest_component: if True, draw only the largest connected component.
        """
        if largest_component and graph.number_of_nodes() > 0:
            # max() returns the first largest component, matching a stable
            # descending sort followed by taking element 0.
            graph = graph.subgraph(max(nx.connected_components(graph), key=len))

        if pos is None:
            pos = nx.spring_layout(graph, iterations=iterations)

        if graph.number_of_nodes() >= 2:
            # eigh returns eigenvalues in ascending order, so column 1 belongs
            # to the second-smallest eigenvalue.
            _, eigvecs = np.linalg.eigh(nx.normalized_laplacian_matrix(graph).toarray())
            node_color = eigvecs[:, 1]
            m = max(np.abs(np.min(node_color)), np.max(node_color))
        else:
            node_color = np.zeros(graph.number_of_nodes())
            m = 1.0
        vmin, vmax = -m, m

        plt.figure()
        nx.draw(graph, pos, font_size=5, node_size=node_size, with_labels=False, node_color=node_color,
                cmap=plt.cm.coolwarm, vmin=vmin, vmax=vmax, edge_color='grey')

        plt.tight_layout()
        plt.savefig(path)
        plt.close("all")

    def visualize(self, path: str, graphs: list, num_graphs_to_visualize: int, log='graph'):
        """Save ``graph_<i>.png`` drawings of the first graphs into ``path``.

        Args:
            path: output directory; created if missing.
            graphs: list of (node_labels, adjacency_matrix) pairs whose items
                expose ``.numpy()``.
            num_graphs_to_visualize: how many to draw; clamped to ``len(graphs)``.
            log: unused; kept for interface compatibility.
        """
        os.makedirs(path, exist_ok=True)

        num_graphs_to_visualize = min(num_graphs_to_visualize, len(graphs))
        for i in range(num_graphs_to_visualize):
            file_path = os.path.join(path, 'graph_{}.png'.format(i))
            graph = self.to_networkx(graphs[i][0].numpy(), graphs[i][1].numpy())
            self.visualize_non_molecule(graph=graph, pos=None, path=file_path)

    def visualize_chain(self, path, nodes_list, adjacency_matrix):
        """Render a chain of graphs as PNG frames plus an animated GIF.

        All frames use the spring layout (seed 0) of the final graph, so nodes
        stay in place across the animation.

        Args:
            path: directory receiving ``fram_<k>.png``; created if missing.
                The GIF is written next to it as ``<dirname>.gif``.
            nodes_list: per-frame node labels; ``nodes_list.shape[0]`` is the
                number of frames. Nothing is drawn if it is 0.
            adjacency_matrix: per-frame edge-type matrices, indexed like
                ``nodes_list``.
        """
        graphs = [self.to_networkx(nodes_list[i], adjacency_matrix[i]) for i in range(nodes_list.shape[0])]
        if not graphs:
            return
        os.makedirs(path, exist_ok=True)

        final_pos = nx.spring_layout(graphs[-1], seed=0)

        save_paths = []
        for frame, graph in enumerate(graphs):
            file_name = os.path.join(path, 'fram_{}.png'.format(frame))
            self.visualize_non_molecule(graph=graph, pos=final_pos, path=file_name)
            save_paths.append(file_name)

        imgs = [imageio.imread(fn) for fn in save_paths]
        # Repeat the last frame 10 more times (2 s at fps=5).
        imgs.extend([imgs[-1]] * 10)
        imageio.mimsave(self._gif_path(path), imgs, subrectangles=True, fps=5)
|