diff --git a/configs/config.yaml b/configs/config.yaml
index 881f765..dc91fb8 100644
--- a/configs/config.yaml
+++ b/configs/config.yaml
@@ -2,7 +2,7 @@ general:
   name: 'graph_dit'
   wandb: 'disabled'
   gpus: 1
-  gpu_number: 3
+  gpu_number: 0
   resume: null
   test_only: null
   sample_every_val: 2500
@@ -31,7 +31,7 @@ model:
   lambda_train: [1, 10] # node and edge training weight
   ensure_connected: True
 train:
-  n_epochs: 5000
+  n_epochs: 500
   batch_size: 1200
   lr: 0.0002
   clip_grad: null
diff --git a/graph_dit/analysis/rdkit_functions.py b/graph_dit/analysis/rdkit_functions.py
index 156f6a1..9a28cee 100644
--- a/graph_dit/analysis/rdkit_functions.py
+++ b/graph_dit/analysis/rdkit_functions.py
@@ -37,6 +37,144 @@ def selectivity_evaluation(gas1, gas2, prop_name):
     y = np.log10(np.array(gas1) / np.array(gas2))
     upper = (y - (a_dict[prop_name] * x + b_dict[prop_name])) > 0
     return upper
+
+class BasicGraphMetrics(object):
+    def __init__(self, graph_decoder, train_graphs=None, stat_ref=None, task_evaluator=None, n_jobs=8, device='cpu', batch_size=512):
+        self.dataset_graphs_list = train_graphs
+        self.graph_decoder = graph_decoder
+        self.n_jobs = n_jobs
+        self.device = device
+        self.batch_size = batch_size
+        self.stat_ref = stat_ref
+        self.task_evaluator = task_evaluator
+
+    def compute_relaxed_validity(self, generated, ensure_connected):
+        valid = []
+        num_components = []
+        all_graphs = []
+        valid_graphs = []
+        covered_nodes = set()
+        direct_valid_count = 0
+        print(f"generated number: {len(generated)}")
+        for graph in generated:
+            node_types, edge_types = graph
+            # Unlike the molecular metrics, no chemistry checks apply here:
+            # every generated graph is accepted and counted as one component.
+            direct_valid_count += 1
+            valid.append(graph)
+            num_components.append(1)
+            covered_nodes.update(set(node_types))
+            all_graphs.append(graph)
+        return valid, len(valid) / len(generated), direct_valid_count / len(generated), np.array(num_components), all_graphs, covered_nodes
+
+    def evaluate(self, generated, targets, ensure_connected, active_atoms=None):
+        valid, validity, nc_validity, num_components, all_graphs, covered_nodes = self.compute_relaxed_validity(generated, ensure_connected=ensure_connected)
+        nc_mu = num_components.mean() if len(num_components) > 0 else 0
+        nc_min = num_components.min() if len(num_components) > 0 else 0
+        nc_max = num_components.max() if len(num_components) > 0 else 0
+
+        len_active = len(active_atoms) if active_atoms is not None else 1
+
+        cover_str = f"Cover {len(covered_nodes)} ({len(covered_nodes)/len_active * 100:.2f}%) nodes: {covered_nodes}"
+        print(f"Validity over {len(generated)} graphs: {validity * 100:.2f}% (w/o correction: {nc_validity * 100:.2f}%), cover {len(covered_nodes)} ({len(covered_nodes)/len_active * 100:.2f}%) nodes: {covered_nodes}")
+        print(f"Number of connected components of {len(generated)} graphs: min:{nc_min:.2f} mean:{nc_mu:.2f} max:{nc_max:.2f}")
+
+        if validity > 0:
+            dist_metrics = {'cover_str': cover_str, 'validity': validity, 'validity_nc': nc_validity}
+            unique = valid
+            targets_log = None  # overwritten below when a task evaluator is provided
+            close_pool = False
+            if self.n_jobs != 1:
+                pool = Pool(self.n_jobs)
+                close_pool = True
+            else:
+                pool = 1
+            # valid_graphs = mapper(pool)(get_mol, valid)
+            valid_graphs = valid
+            # Internal diversity would be computed as
+            #   1/|A|^2 * sum_{x, y in AxA} (1 - tanimoto(x, y))
+            # dist_metrics['interval_diversity'] = internal_diversity(valid_graphs, pool, device=self.device)
+
+            start_time = time.time()
+            if self.stat_ref is not None:
+                kwargs = {'n_jobs': pool, 'device': self.device, 'batch_size': self.batch_size}
+                kwargs_fcd = {'n_jobs': self.n_jobs, 'device': self.device, 'batch_size': self.batch_size}
+                try:
+                    dist_metrics['sim/Frag'] = FragMetric(**kwargs)(gen=valid_graphs, pref=self.stat_ref['Frag'])
+                except Exception as e:
+                    print('FragMetric failed:', e, 'pool:', pool)
+                    print('valid_graphs:', valid_graphs)
+                dist_metrics['dist/FCD'] = FCDMetric(**kwargs_fcd)(gen=valid, pref=self.stat_ref['FCD'])
+
+            if self.task_evaluator is not None:
+                evaluation_list = list(self.task_evaluator.keys())
+                print('evaluation_list:', evaluation_list)
+
+                assert 'meta_taskname' in evaluation_list
+                meta_taskname = self.task_evaluator['meta_taskname']
+                evaluation_list.remove('meta_taskname')
+                # meta_split = meta_taskname.split('-')
+
+                valid_index = np.array([bool(g) for g in all_graphs])
+                targets_log = {}
+                for i, name in enumerate(evaluation_list):
+                    targets_log[f'input_{name}'] = np.array([float('nan')] * len(valid_index))
+                    targets_log[f'input_{name}'] = targets[:, i]
+
+                targets = targets[valid_index]
+                # if len(meta_split) == 2:
+                #     cached_perm = {meta_split[0]: None, meta_split[1]: None}
+
+                for i, name in enumerate(evaluation_list):
+                    # if name == 'scs':
+                    #     continue
+                    # elif name == 'sas':
+                    #     scores = calculateSAS(valid)
+                    # else:
+                    #     scores = self.task_evaluator[name](valid)
+                    # placeholder scores (one per valid graph) until graph task evaluators are wired in
+                    scores = np.random.rand(int(valid_index.sum()))
+                    targets_log[f'output_{name}'] = np.array([float('nan')] * len(valid_index))
+                    targets_log[f'output_{name}'][valid_index] = scores
+                    # if name in ['O2', 'N2', 'CO2']:
+                    #     if len(meta_split) == 2:
+                    #         cached_perm[name] = scores
+                    #     scores, cur_targets = np.log10(scores), np.log10(targets[:, i])
+                    #     dist_metrics[f'{name}/mae'] = np.mean(np.abs(scores - cur_targets))
+                    # elif name == 'sas':
+                    #     dist_metrics[f'{name}/mae'] = np.mean(np.abs(scores - targets[:, i]))
+                    # else:
+                    true_y = targets[:, i]
+                    predicted_labels = (scores >= 0.5).astype(int)
+                    acc = (predicted_labels == true_y).sum() / len(true_y)
+                    dist_metrics[f'{name}/acc'] = acc
+
+                # if len(meta_split) == 2:
+                #     if cached_perm[meta_split[0]] is not None and cached_perm[meta_split[1]] is not None:
+                #         task_name = self.task_evaluator['meta_taskname']
+                #         upper = selectivity_evaluation(cached_perm[meta_split[0]], cached_perm[meta_split[1]], task_name)
+                #         dist_metrics[f'selectivity/{task_name}'] = np.sum(upper)
+
+            end_time = time.time()
+            elapsed_time = end_time - start_time
+            max_key_length = max(len(key) for key in dist_metrics)
+            print(f'Details over {len(valid)} ({len(generated)}) valid (total) graphs, calculating metrics using {elapsed_time:.2f} s:')
+            strs = ''
+            for i, (key, value) in enumerate(dist_metrics.items()):
+                if isinstance(value, (int, float, np.floating, np.integer)):
+                    strs = strs + f'{key:>{max_key_length}}:{value:<7.4f}\t'
+                    if i % 4 == 3:
+                        strs = strs + '\n'
+            print(strs)
+
+            if close_pool:
+                pool.close()
+                pool.join()
+        else:
+            unique = []
+            dist_metrics = {}
+            targets_log = None
+        return unique, dict(nc_min=nc_min, nc_max=nc_max, nc_mu=nc_mu), all_graphs, dist_metrics, targets_log
+
 class BasicMolecularMetrics(object):
     def __init__(self, atom_decoder, train_smiles=None, stat_ref=None, task_evaluator=None, n_jobs=8, device='cpu', batch_size=512):
@@ -388,6 +526,18 @@ def connect_fragments(mol):
     return combined_mol
 
 #### connect fragements
+def compute_graph_metrics(graph_list, targets, train_graphs, stat_ref, dataset_info, task_evaluator, comput_config):
+    """ graph_list: (list) generated (node_types, edge_types) pairs """
+    node_decoder = dataset_info.node_decoder
+    active_nodes = dataset_info.active_nodes
+    ensure_connected = dataset_info.ensure_connected
+    metrics = BasicGraphMetrics(node_decoder, train_graphs, stat_ref, task_evaluator, **comput_config)
+    evaluated_res = metrics.evaluate(graph_list, targets, ensure_connected, active_nodes)
+    all_graphs = evaluated_res[-3]
+    all_metrics = evaluated_res[-2]
+    targets_log = evaluated_res[-1]
+    unique_graphs = evaluated_res[0]
+    return unique_graphs, all_graphs, all_metrics, targets_log
 
 def compute_molecular_metrics(molecule_list, targets, train_smiles, stat_ref, dataset_info, task_evaluator, comput_config):
     """ molecule_list: (dict) """
diff --git a/graph_dit/analysis/visualization.py b/graph_dit/analysis/visualization.py
index 913961c..8cc97e6 100644
--- a/graph_dit/analysis/visualization.py
+++ b/graph_dit/analysis/visualization.py
@@ -10,7 +10,41 @@ import numpy as np
 import rdkit.Chem
 import matplotlib.pyplot as plt
 
+class GraphVisualization:
+    def __init__(self, dataset_infos):
+        self.dataset_infos = dataset_infos
+
+    def graph_from_graphs(self, node_list, adjacency_matrix):
+        """
+        Convert a generated graph to a networkx graph
+        node_list: node types of a single graph (n); -1 marks an absent node
+        adjacency_matrix: adjacency matrix of the graph (n x n)
+        """
+        graph = nx.Graph()
+        for i in range(len(node_list)):
+            if node_list[i] == -1:
+                continue
+            graph.add_node(i, number=i, symbol=node_list[i], color_val=node_list[i])
+
+        rows, cols = np.where(adjacency_matrix >= 1)
+        edges = zip(rows.tolist(), cols.tolist())
+        for edge in edges:
+            edge_type = adjacency_matrix[edge[0]][edge[1]]
+            graph.add_edge(edge[0], edge[1], color=float(edge_type), weight=3 * edge_type)
+
+        return graph
+
+    def visualize(self, path: str, graphs: list, num_graphs_to_visualize: int, log='graph'):
+        # define path to save figures
+        if not os.path.exists(path):
+            os.makedirs(path)
+
+        # visualize the final graphs
+        for i in range(num_graphs_to_visualize):
+            file_path = os.path.join(path, 'graph_{}.png'.format(i))
+            graph = self.graph_from_graphs(graphs[i][0].numpy(), graphs[i][1].numpy())
+            self.visualize_graph(graph=graph, pos=None, path=file_path)
+            im = plt.imread(file_path)  # read back for potential logging; currently unused
+
 class MolecularVisualization:
     def __init__(self, dataset_infos):
         self.dataset_infos = dataset_infos
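Reviewer note: a minimal sketch of how the new `compute_graph_metrics` entry point can be driven end to end. The `DatasetInfo` stub, the toy graphs and targets, and the import path are illustrative assumptions, not part of this patch; passing `stat_ref=None` skips the Frag/FCD reference metrics, and the task evaluator is not actually called while the scores are placeholders.

```python
# Hypothetical smoke test for compute_graph_metrics (names below are
# illustrative stand-ins; only compute_graph_metrics is from this patch).
import numpy as np

from graph_dit.analysis.rdkit_functions import compute_graph_metrics

class DatasetInfo:
    """Minimal stub exposing the three attributes the function reads."""
    node_decoder = ['A', 'B', 'C']
    active_nodes = ['A', 'B', 'C']
    ensure_connected = True

# Each generated graph is a (node_types, edge_types) pair.
graphs = [(np.array([0, 1, 2]), np.zeros((3, 3), dtype=int)) for _ in range(4)]
targets = np.random.randint(0, 2, size=(4, 1)).astype(float)  # one binary task

unique, all_graphs, metrics, targets_log = compute_graph_metrics(
    graphs,
    targets,
    train_graphs=None,
    stat_ref=None,  # no reference stats -> Frag/FCD metrics are skipped
    dataset_info=DatasetInfo(),
    task_evaluator={'meta_taskname': 'demo', 'demo': None},
    comput_config={'n_jobs': 1, 'device': 'cpu', 'batch_size': 512},
)
print(metrics['validity'], metrics.get('demo/acc'))
```

With `n_jobs=1` no worker pool is created, which keeps the sketch dependency-free.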
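Similarly, a sketch of the new `GraphVisualization` helper. The tensors are made up, and `visualize()` additionally calls a `visualize_graph()` drawing helper that this diff assumes exists elsewhere in the module, so the sketch sticks to `graph_from_graphs` and draws with networkx directly.

```python
# Hypothetical usage of GraphVisualization.graph_from_graphs; inputs are
# made up, and the import path assumes the package layout in this diff.
import matplotlib.pyplot as plt
import networkx as nx
import torch

from graph_dit.analysis.visualization import GraphVisualization

viz = GraphVisualization(dataset_infos=None)  # unused by graph_from_graphs

nodes = torch.tensor([0, 1, 1, -1])  # -1 marks a padded/absent node
adj = torch.tensor([
    [0, 1, 0, 0],
    [1, 0, 2, 0],  # entries >= 1 become edges; the value is the edge type
    [0, 2, 0, 0],
    [0, 0, 0, 0],
])

g = viz.graph_from_graphs(nodes.numpy(), adj.numpy())
print(g.nodes(data=True), g.edges(data=True))
nx.draw(g, with_labels=True)
plt.savefig('sample_graph.png')
```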