class SamplingGraphMetrics(nn.Module):
    """Score sampled graphs and persist the generated graphs and metrics.

    Each call to :meth:`forward` runs ``compute_molecular_metrics`` on the
    sampled graphs, writes the graphs to a text file (one file layout for
    validation, another for the final test run) and appends the metric row to
    ``output.csv`` via ``result_to_csv``.
    """

    def __init__(
        self,
        dataset_infos,
        train_graphs,
        reference_graphs,
        n_jobs=1,
        device="cpu",
        batch_size=512,
    ):
        """Store the dataset description and build one evaluator per task.

        Args:
            dataset_infos: Object exposing ``task`` (a ``-``-separated string
                of task names), ``active_nodes`` and ``base_path``.
            train_graphs: Training graphs forwarded to the metric computation.
            reference_graphs: Accepted for interface compatibility; it is not
                used here and ``stat_ref`` is always ``None``.
            n_jobs: Stored in ``compute_config["n_jobs"]``.
            device: Stored in ``compute_config["device"]``.
            batch_size: Stored in ``compute_config["batch_size"]``.
        """
        super().__init__()
        self.task_name = dataset_infos.task
        self.dataset_infos = dataset_infos
        self.active_nodes = dataset_infos.active_nodes
        self.train_graphs = train_graphs
        # No reference statistics are computed for graphs.
        self.stat_ref = None

        self.compute_config = {
            "n_jobs": n_jobs,
            "device": device,
            "batch_size": batch_size,
        }

        self.task_evaluator = {
            'meta_taskname': dataset_infos.task,
            'sas': None,
            'scs': None
        }

        for cur_task in dataset_infos.task.split("-"):
            model_path = os.path.join(
                dataset_infos.base_path, "data/evaluator", f"{cur_task}.joblib"
            )
            # Make sure the evaluator directory exists before TaskModel uses it.
            os.makedirs(os.path.dirname(model_path), exist_ok=True)
            self.task_evaluator[cur_task] = TaskModel(model_path, cur_task)

    def forward(self, graphs, targets, name, current_epoch, val_counter, test=False):
        """Compute metrics for ``graphs`` and write graphs and metrics to disk.

        Args:
            graphs: Sampled graphs, passed through to
                ``compute_molecular_metrics``.
            targets: A tensor, or a list of tensors that is concatenated along
                dim 0, holding the conditioning targets.
            name: Run name; validation graphs go to ``graphs/<name>/`` under
                the current working directory.
            current_epoch: Epoch index used in file and log names.
            val_counter: Validation batch counter used in file and log names.
            test: When true, write ``final_graphs.txt`` with per-task
                input/output columns instead of the validation file.

        Returns:
            The second element returned by ``compute_molecular_metrics``
            (the sequence of all generated graphs).

        NOTE(review): the previous code bound both the second and third
        returned values to ``all_graphs``, so the metrics object overwrote the
        graph sequence and was what got iterated and returned. The positions
        used below (graphs second, metrics third) are inferred from how each
        value is consumed — confirm against ``compute_molecular_metrics``.
        """
        if isinstance(targets, list):
            targets = torch.cat(targets, dim=0)
        targets_np = targets.detach().cpu().numpy()

        unique_graphs, all_graphs, all_metrics, targets_log = compute_molecular_metrics(
            graphs,
            targets_np,
            self.train_graphs,
            self.stat_ref,
            self.dataset_infos,
            self.task_evaluator,
            self.compute_config,
        )

        if test:
            # 'meta_taskname' holds the combined task string, not an evaluator.
            task_names = [
                task for task in self.task_evaluator if task != 'meta_taskname'
            ]
            columns = [f"input_{task}" for task in task_names] + [
                f"output_{task}" for task in task_names
            ]
            with open("final_graphs.txt", "w") as fp:
                fp.write("graph, " + ", ".join(columns) + "\n")
                for i, graph in enumerate(all_graphs):
                    if targets_log is not None:
                        values = [f"{targets_log[column][i]}" for column in columns]
                        fp.write(f"{graph}, " + ", ".join(values) + "\n")
                    else:
                        fp.write("%s\n" % graph)
                print("All graphs saved")
            log_name = "test"
        else:
            result_path = os.path.join(os.getcwd(), f"graphs/{name}")
            os.makedirs(result_path, exist_ok=True)
            text_path = os.path.join(
                result_path,
                f"valid_unique_graphs_e{current_epoch}_b{val_counter}.txt",
            )
            # Context manager closes the file even if a write fails.
            with open(text_path, "w") as textfile:
                for graph in unique_graphs:
                    textfile.write(graph + "\n")
            log_name = "epoch" + str(current_epoch) + "_batch" + str(val_counter)

        all_metrics["log_name"] = log_name
        result_to_csv("output.csv", all_metrics)
        return all_graphs

    def reset(self):
        """No-op; this module keeps no state between calls."""
        pass
class EdgeMetricsCE(MetricCollection):
    """Collection of the four per-edge-type cross-entropy metrics."""

    def __init__(self):
        # Each metric receives its integer index (0-3) as sole argument;
        # presumably the class index in the edge prediction — TODO confirm
        # against the NoBondCE/SingleCE/DoubleCE/TripleCE definitions.
        ce_no_bond = NoBondCE(0)
        ce_SI = SingleCE(1)
        ce_DO = DoubleCE(2)
        ce_TR = TripleCE(3)
        super().__init__([ce_no_bond, ce_SI, ce_DO, ce_TR])


class TrainGraphMetricsDiscrete(nn.Module):
    """Track node and edge cross-entropy metrics during training."""

    def __init__(self, dataset_infos):
        """Create the node and edge metric collections.

        Args:
            dataset_infos: Object exposing ``active_nodes``, which is passed
                to ``NodeMetricsCE``.
        """
        super().__init__()
        active_nodes = dataset_infos.active_nodes
        self.train_node_metrics = NodeMetricsCE(active_nodes=active_nodes)
        self.train_edge_metrics = EdgeMetricsCE()

    def forward(self, masked_pred_X, masked_pred_E, true_X, true_E, log: bool):
        """Update both metric collections with one batch.

        Args:
            masked_pred_X: Node predictions passed to the node metrics.
            masked_pred_E: Edge predictions passed to the edge metrics.
            true_X: Node targets passed to the node metrics.
            true_E: Edge targets passed to the edge metrics.
            log: When true, also compute the current metric values.

        Returns:
            A dict mapping ``'train/<metric name>'`` to a Python scalar when
            ``log`` is true, otherwise ``None``. The dict used to be built and
            then discarded; it is now returned so callers can log it.
        """
        self.train_node_metrics(masked_pred_X, true_X)
        self.train_edge_metrics(masked_pred_E, true_E)
        if not log:
            return None

        to_log = {}
        for key, val in self.train_node_metrics.compute().items():
            to_log['train/' + key] = val.item()
        for key, val in self.train_edge_metrics.compute().items():
            to_log['train/' + key] = val.item()
        return to_log

    def reset(self):
        """Reset both metric collections."""
        for metric in (self.train_node_metrics, self.train_edge_metrics):
            metric.reset()

    def log_epoch_metrics(self, current_epoch, log=True):
        """Compute the metrics and optionally print a rounded summary.

        Args:
            current_epoch: Epoch index shown in the printed summary.
            log: When true, print the node and edge metrics rounded to four
                decimals.

        Returns:
            A dict mapping ``'train_epoch/<metric name>'`` to the unrounded
            Python scalar. This dict was previously built and discarded.
        """
        epoch_node_metrics = self.train_node_metrics.compute()
        epoch_edge_metrics = self.train_edge_metrics.compute()

        to_log = {}
        for key, val in epoch_node_metrics.items():
            to_log['train_epoch/' + key] = val.item()
        for key, val in epoch_edge_metrics.items():
            to_log['train_epoch/' + key] = val.item()

        # Build rounded copies for display instead of overwriting the dicts
        # returned by compute().
        rounded_node_metrics = {
            key: round(val.item(), 4) for key, val in epoch_node_metrics.items()
        }
        rounded_edge_metrics = {
            key: round(val.item(), 4) for key, val in epoch_edge_metrics.items()
        }

        if log:
            print(f"Epoch {current_epoch}: {rounded_node_metrics} -- {rounded_edge_metrics}")
        return to_log