update config and remove the randomforest judging
This commit is contained in:
parent
b47e83330e
commit
d800679b88
@ -2,7 +2,7 @@ general:
|
||||
name: 'graph_dit'
|
||||
wandb: 'disabled'
|
||||
gpus: 1
|
||||
gpu_number: 3
|
||||
gpu_number: 0
|
||||
resume: null
|
||||
test_only: null
|
||||
sample_every_val: 2500
|
||||
@ -31,7 +31,7 @@ model:
|
||||
lambda_train: [1, 10] # node and edge training weight
|
||||
ensure_connected: True
|
||||
train:
|
||||
n_epochs: 5000
|
||||
n_epochs: 500
|
||||
batch_size: 1200
|
||||
lr: 0.0002
|
||||
clip_grad: null
|
||||
|
@ -37,6 +37,144 @@ def selectivity_evaluation(gas1, gas2, prop_name):
|
||||
y = np.log10(np.array(gas1) / np.array(gas2))
|
||||
upper = (y - (a_dict[prop_name] * x + b_dict[prop_name])) > 0
|
||||
return upper
|
||||
class BasicGraphMetrics(object):
|
||||
def __init__(self, graph_decoder, train_graphs=None, stat_ref=None, task_evaluator=None, n_jobs=8, device='cpu', batch_size=512):
|
||||
self.dataset_graphs_list = train_graphs
|
||||
self.graph_decoder = graph_decoder
|
||||
self.n_jobs = n_jobs
|
||||
self.device = device
|
||||
self.batch_size = batch_size
|
||||
self.stat_ref = stat_ref
|
||||
self.task_evaluator = task_evaluator
|
||||
def compute_relaxed_validity(self, generated, ensure_connected):
|
||||
valid = []
|
||||
num_components = []
|
||||
all_graphs = []
|
||||
valid_graphs = []
|
||||
covered_nodes = set()
|
||||
direct_valid_count = 0
|
||||
print(f"generated number: {len(generated)}")
|
||||
for graph in generated:
|
||||
node_types, edge_types = graph
|
||||
direct_valid_flag = True
|
||||
direct_valid_count += 1
|
||||
valid.append(graph)
|
||||
num_components.append(1)
|
||||
covered_nodes.update(set(node_types))
|
||||
all_graphs.append(graph)
|
||||
return valid, len(valid) / len(generated), direct_valid_count / len(generated), np.array(num_components), all_graphs, covered_nodes
|
||||
|
||||
def evaluate(self, generated, targets, ensure_connected, active_atoms=None):
|
||||
valid, validity, nc_validity, num_components, all_graphs, covered_nodes = self.compute_relaxed_validity(generated, ensure_connected=ensure_connected)
|
||||
nc_mu = num_components.mean() if len(num_components) > 0 else 0
|
||||
nc_min = num_components.min() if len(num_components) > 0 else 0
|
||||
nc_max = num_components.max() if len(num_components) > 0 else 0
|
||||
|
||||
len_active = len(active_atoms) if active_atoms is not None else 1
|
||||
|
||||
cover_str = f"Cover {len(covered_nodes)} ({len(covered_nodes)/len_active * 100:.2f}%) atoms: {covered_nodes}"
|
||||
print(f"Validity over {len(generated)} graphs: {validity * 100 :.2f}% (w/o correction: {nc_validity * 100 :.2f}%), cover {len(covered_nodes)} ({len(covered_nodes)/len_active * 100:.2f}%) nodes: {covered_nodes}")
|
||||
print(f"Number of connected components of {len(generated)} graphs: min:{nc_min:.2f} mean:{nc_mu:.2f} max:{nc_max:.2f}")
|
||||
|
||||
if validity > 0:
|
||||
dist_metrics = {'cover_str': cover_str ,'validity': validity, 'validity_nc': nc_validity}
|
||||
unique = valid
|
||||
close_pool = False
|
||||
if self.n_jobs != 1:
|
||||
pool = Pool(self.n_jobs)
|
||||
close_pool = True
|
||||
else:
|
||||
pool = 1
|
||||
# valid_graphs = mapper(pool)(get_mol, valid)
|
||||
valid_graphs = valid
|
||||
"""
|
||||
Computes internal diversity as:
|
||||
1/|A|^2 sum_{x, y in AxA} (1-tanimoto(x, y))
|
||||
"""
|
||||
# dist_metrics['interval_diversity'] = internal_diversity(valid_graphs, pool, device=self.device)
|
||||
|
||||
start_time = time.time()
|
||||
if self.stat_ref is not None:
|
||||
kwargs = {'n_jobs': pool, 'device': self.device, 'batch_size': self.batch_size}
|
||||
kwargs_fcd = {'n_jobs': self.n_jobs, 'device': self.device, 'batch_size': self.batch_size}
|
||||
try:
|
||||
dist_metrics['sim/Frag'] = FragMetric(**kwargs)(gen=valid_graphs, pref=self.stat_ref['Frag'])
|
||||
except:
|
||||
print('error: ', 'pool', pool)
|
||||
print('valid_graphs: ', valid_graphs)
|
||||
dist_metrics['dist/FCD'] = FCDMetric(**kwargs_fcd)(gen=valid, pref=self.stat_ref['FCD'])
|
||||
|
||||
if self.task_evaluator is not None:
|
||||
evaluation_list = list(self.task_evaluator.keys())
|
||||
print('evaluation_list: ', evaluation_list)
|
||||
evaluation_list = evaluation_list.copy()
|
||||
|
||||
assert 'meta_taskname' in evaluation_list
|
||||
meta_taskname = self.task_evaluator['meta_taskname']
|
||||
evaluation_list.remove('meta_taskname')
|
||||
# meta_split = meta_taskname.split('-')
|
||||
|
||||
valid_index = np.array([True if graphs else False for graphs in all_graphs])
|
||||
targets_log = {}
|
||||
for i, name in enumerate(evaluation_list):
|
||||
targets_log[f'input_{name}'] = np.array([float('nan')] * len(valid_index))
|
||||
targets_log[f'input_{name}'] = targets[:, i]
|
||||
|
||||
targets = targets[valid_index]
|
||||
# if len(meta_split) == 2:
|
||||
# cached_perm = {meta_split[0]: None, meta_split[1]: None}
|
||||
|
||||
for i, name in enumerate(evaluation_list):
|
||||
# if name == 'scs':
|
||||
# continue
|
||||
# elif name == 'sas':
|
||||
# scores = calculateSAS(valid)
|
||||
# else:
|
||||
# scores = self.task_evaluator[name](valid)
|
||||
# fix the scores
|
||||
scores = np.random.rand(len(valid_index))
|
||||
targets_log[f'output_{name}'] = np.array([float('nan')] * len(valid_index))
|
||||
targets_log[f'output_{name}'][valid_index] = scores
|
||||
# if name in ['O2', 'N2', 'CO2']:
|
||||
# if len(meta_split) == 2:
|
||||
# cached_perm[name] = scores
|
||||
# scores, cur_targets = np.log10(scores), np.log10(targets[:, i])
|
||||
# dist_metrics[f'{name}/mae'] = np.mean(np.abs(scores - cur_targets))
|
||||
# elif name == 'sas':
|
||||
# dist_metrics[f'{name}/mae'] = np.mean(np.abs(scores - targets[:, i]))
|
||||
# else:
|
||||
true_y = targets[:, i]
|
||||
predicted_labels = (scores >= 0.5).astype(int)
|
||||
acc = (predicted_labels == true_y).sum() / len(true_y)
|
||||
dist_metrics[f'{name}/acc'] = acc
|
||||
|
||||
# if len(meta_split) == 2:
|
||||
# if cached_perm[meta_split[0]] is not None and cached_perm[meta_split[1]] is not None:
|
||||
# task_name = self.task_evaluator['meta_taskname']
|
||||
# upper = selectivity_evaluation(cached_perm[meta_split[0]], cached_perm[meta_split[1]], task_name)
|
||||
# dist_metrics[f'selectivity/{task_name}'] = np.sum(upper)
|
||||
|
||||
end_time = time.time()
|
||||
elapsed_time = end_time - start_time
|
||||
max_key_length = max(len(key) for key in dist_metrics)
|
||||
print(f'Details over {len(valid)} ({len(generated)}) valid (total) graphs, calculating metrics using {elapsed_time:.2f} s:')
|
||||
strs = ''
|
||||
for i, (key, value) in enumerate(dist_metrics.items()):
|
||||
if isinstance(value, (int, float, np.floating, np.integer)):
|
||||
strs = strs + f'{key:>{max_key_length}}:{value:<7.4f}\t'
|
||||
if i % 4 == 3:
|
||||
strs = strs + '\n'
|
||||
print(strs)
|
||||
|
||||
if close_pool:
|
||||
pool.close()
|
||||
pool.join()
|
||||
else:
|
||||
unique = []
|
||||
dist_metrics = {}
|
||||
targets_log = None
|
||||
return unique, dict(nc_min=nc_min, nc_max=nc_max, nc_mu=nc_mu), all_graphs, dist_metrics, targets_log
|
||||
|
||||
|
||||
class BasicMolecularMetrics(object):
|
||||
def __init__(self, atom_decoder, train_smiles=None, stat_ref=None, task_evaluator=None, n_jobs=8, device='cpu', batch_size=512):
|
||||
@ -388,6 +526,18 @@ def connect_fragments(mol):
|
||||
return combined_mol
|
||||
|
||||
#### connect fragements
|
||||
def compute_graph_metrics(graph_list, targets, train_graphs, stat_ref, dataset_info, task_evaluator, comput_config):
|
||||
""" graph_list: (dict) """
|
||||
node_decoder = dataset_info.node_decoder
|
||||
active_nodes = dataset_info.active_nodes
|
||||
ensure_connected = dataset_info.ensure_connected
|
||||
metrics = BasicGraphMetrics(node_decoder, train_graphs, stat_ref, task_evaluator, **comput_config)
|
||||
evaluated_res = metrics.evaluate(graph_list, targets, ensure_connected, active_nodes)
|
||||
all_graphs = evaluated_res[-3]
|
||||
all_metrics = evaluated_res[-2]
|
||||
targets_log = evaluated_res[-1]
|
||||
unique_graphs = evaluated_res[0]
|
||||
return unique_graphs, all_graphs, all_metrics, targets_log
|
||||
|
||||
def compute_molecular_metrics(molecule_list, targets, train_smiles, stat_ref, dataset_info, task_evaluator, comput_config):
|
||||
""" molecule_list: (dict) """
|
||||
|
@ -10,7 +10,41 @@ import numpy as np
|
||||
import rdkit.Chem
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
class GraphVisualization:
|
||||
def __init__(self, dataset_infos):
|
||||
self.dataset_infos = dataset_infos
|
||||
def graph_from_graphs(self, node_list, adjency_matrix):
|
||||
"""
|
||||
Convert graphs to networkx graphs
|
||||
node_list: the nodes of a batch of nodes (bs x n)
|
||||
adjacency_matrix: the adjacency_matrix of the molecule (bs x n x n)
|
||||
"""
|
||||
graph = nx.Graph()
|
||||
|
||||
for i in range(len(node_list)):
|
||||
if node_list[i] == -1:
|
||||
continue
|
||||
graph.add_node(i, number=i, symbol=node_list[i], color_val=node_list[i])
|
||||
|
||||
rows, cols = np.where(adjency_matrix >= 1)
|
||||
edges = zip(rows.tolist(), cols.tolist())
|
||||
for edge in edges:
|
||||
edge_type = adjency_matrix[edge[0]][edge[1]]
|
||||
graph.add_edge(edge[0], edge[1], color=float(edge_type), weight=3 * edge_type)
|
||||
|
||||
return graph
|
||||
|
||||
def visualize(self, path: str, graphs: list, num_graphs_to_visualize: int, log='graph'):
|
||||
# define path to save figures
|
||||
if not os.path.exists(path):
|
||||
os.makedirs(path)
|
||||
|
||||
# visualize the final molecules
|
||||
for i in range(num_graphs_to_visualize):
|
||||
file_path = os.path.join(path, 'graph_{}.png'.format(i))
|
||||
graph = self.graph_from_graphs(graphs[i][0].numpy(), graphs[i][1].numpy())
|
||||
self.visualize_graph(graph=graph, pos=None, path=file_path)
|
||||
im = plt.imread(file_path)
|
||||
class MolecularVisualization:
|
||||
def __init__(self, dataset_infos):
|
||||
self.dataset_infos = dataset_infos
|
||||
|
Loading…
Reference in New Issue
Block a user