update EdgeMetricsCE class
This commit is contained in:
		| @@ -23,7 +23,104 @@ def result_to_csv(path, dict_data): | |||||||
|             writer.writeheader() |             writer.writeheader() | ||||||
|         writer.writerow(dict_data) |         writer.writerow(dict_data) | ||||||
|  |  | ||||||
|  | class SamplingGraphMetrics(nn.Module): | ||||||
|  |     def __init__( | ||||||
|  |             self, | ||||||
|  |             dataset_infos, | ||||||
|  |             train_graphs, | ||||||
|  |             reference_graphs, | ||||||
|  |             n_jobs=1, | ||||||
|  |             device="cpu", | ||||||
|  |             batch_size=512, | ||||||
|  |     ): | ||||||
|  |         super().__init__() | ||||||
|  |         self.task_name = dataset_infos.task | ||||||
|  |         self.dataset_infos = dataset_infos | ||||||
|  |         self.active_nodes = dataset_infos.active_nodes | ||||||
|  |         self.train_graphs = train_graphs | ||||||
|  |  | ||||||
|  |         self.stat_ref = None | ||||||
|  |  | ||||||
|  |         self.compute_config = { | ||||||
|  |             "n_jobs": n_jobs, | ||||||
|  |             "device": device, | ||||||
|  |             "batch_size": batch_size, | ||||||
|  |         } | ||||||
|  |  | ||||||
|  |         self.task_evaluator = { | ||||||
|  |             'meta_taskname': dataset_infos.task, | ||||||
|  |             'sas': None, | ||||||
|  |             'scs': None | ||||||
|  |         } | ||||||
|  |  | ||||||
|  |         for cur_task in dataset_infos.task.split("-")[:]: | ||||||
|  |             model_path = os.path.join( | ||||||
|  |                 dataset_infos.base_path, "data/evaluator", f"{cur_task}.joblib" | ||||||
|  |             ) | ||||||
|  |             os.makedirs(os.path.dirname(model_path), exist_ok=True) | ||||||
|  |             evaluator = TaskModel(model_path, cur_task) | ||||||
|  |             self.task_evaluator[cur_task] = evaluator | ||||||
|  |  | ||||||
|  |     def forward(self, graphs, targets, name, current_epoch, val_counter, test=False): | ||||||
|  |         if isinstance(targets, list): | ||||||
|  |             targets_cat = torch.cat(targets, dim=0) | ||||||
|  |             targets_np = targets_cat.detach().cpu().numpy() | ||||||
|  |         else: | ||||||
|  |             targets_np = targets.detach().cpu().numpy() | ||||||
|  |  | ||||||
|  |         unique_graphs, all_graphs, all_graphs, targets_log = compute_molecular_metrics( | ||||||
|  |             graphs, | ||||||
|  |             targets_np, | ||||||
|  |             self.train_graphs, | ||||||
|  |             self.stat_ref, | ||||||
|  |             self.dataset_infos, | ||||||
|  |             self.task_evaluator, | ||||||
|  |             self.compute_config, | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |         if test: | ||||||
|  |             file_name = "final_graphs.txt" | ||||||
|  |             with open(file_name, "w") as fp: | ||||||
|  |                 all_tasks_name = list(self.task_evaluator.keys()) | ||||||
|  |                 all_tasks_name = all_tasks_name.copy() | ||||||
|  |                 if 'meta_taskname' in all_tasks_name: | ||||||
|  |                     all_tasks_name.remove('meta_taskname') | ||||||
|  |  | ||||||
|  |                 all_tasks_str = "graph, " + ", ".join([f"input_{task}" for task in all_tasks_name] + [f"output_{task}" for task in all_tasks_name]) | ||||||
|  |                 fp.write(all_tasks_str + "\n") | ||||||
|  |                 for i, graph in enumerate(all_graphs): | ||||||
|  |                     if targets_log is not None: | ||||||
|  |                         all_result_str = f"{graph}, " + ", ".join([f"{targets_log['input_'+task][i]}" for task in all_tasks_name] + [f"{targets_log['output_'+task][i]}" for task in all_tasks_name]) | ||||||
|  |                         fp.write(all_result_str + "\n") | ||||||
|  |                     else: | ||||||
|  |                         fp.write("%s\n" % graph) | ||||||
|  |                 print("All graphs saved") | ||||||
|  |         else: | ||||||
|  |             result_path = os.path.join(os.getcwd(), f"graphs/{name}") | ||||||
|  |             os.makedirs(result_path, exist_ok=True) | ||||||
|  |             text_path = os.path.join( | ||||||
|  |                 result_path, | ||||||
|  |                 f"valid_unique_graphs_e{current_epoch}_b{val_counter}.txt", | ||||||
|  |             ) | ||||||
|  |             textfile = open(text_path, "w") | ||||||
|  |             for graph in unique_graphs: | ||||||
|  |                 textfile.write(graph + "\n") | ||||||
|  |             textfile.close() | ||||||
|  |          | ||||||
|  |         all_logs = all_graphs | ||||||
|  |         if test: | ||||||
|  |             all_logs["log_name"] = "test" | ||||||
|  |         else: | ||||||
|  |             all_logs["log_name"] = ( | ||||||
|  |                 "epoch" + str(current_epoch) + "_batch" + str(val_counter) | ||||||
|  |             ) | ||||||
|  |          | ||||||
|  |         result_to_csv("output.csv", all_logs) | ||||||
|  |         return all_graphs | ||||||
|  |      | ||||||
|  |     def reset(self): | ||||||
|  |         pass | ||||||
|  |              | ||||||
| class SamplingMolecularMetrics(nn.Module): | class SamplingMolecularMetrics(nn.Module): | ||||||
|     def __init__( |     def __init__( | ||||||
|         self, |         self, | ||||||
| @@ -40,21 +137,21 @@ class SamplingMolecularMetrics(nn.Module): | |||||||
|         self.active_atoms = dataset_infos.active_atoms |         self.active_atoms = dataset_infos.active_atoms | ||||||
|         self.train_smiles = train_smiles |         self.train_smiles = train_smiles | ||||||
|  |  | ||||||
|         if reference_smiles is not None: |         # if reference_smiles is not None: | ||||||
|             print( |         #     print( | ||||||
|                 f"--- Computing intermediate statistics for training for #{len(reference_smiles)} smiles ---" |         #         f"--- Computing intermediate statistics for training for #{len(reference_smiles)} smiles ---" | ||||||
|             ) |         #     ) | ||||||
|             start_time = time.time() |         #     start_time = time.time() | ||||||
|             self.stat_ref = compute_intermediate_statistics( |         #     self.stat_ref = compute_intermediate_statistics( | ||||||
|                 reference_smiles, n_jobs=n_jobs, device=device, batch_size=batch_size |         #         reference_smiles, n_jobs=n_jobs, device=device, batch_size=batch_size | ||||||
|             ) |         #     ) | ||||||
|             end_time = time.time() |         #     end_time = time.time() | ||||||
|             elapsed_time = end_time - start_time |         #     elapsed_time = end_time - start_time | ||||||
|             print( |         #     print( | ||||||
|                 f"--- End computing intermediate statistics: using {elapsed_time:.2f}s ---" |         #         f"--- End computing intermediate statistics: using {elapsed_time:.2f}s ---" | ||||||
|             ) |         #     ) | ||||||
|         else: |         # else: | ||||||
|             self.stat_ref = None |         self.stat_ref = None | ||||||
|      |      | ||||||
|         self.comput_config = { |         self.comput_config = { | ||||||
|             "n_jobs": n_jobs, |             "n_jobs": n_jobs, | ||||||
|   | |||||||
| @@ -77,6 +77,15 @@ class NodeMetricsCE(MetricCollection): | |||||||
|  |  | ||||||
|         for i, node_type in enumerate(active_nodes) : |         for i, node_type in enumerate(active_nodes) : | ||||||
|             metrics_list.append(type(f'{node_type}_CE', (NodeCE,), {})(i)) |             metrics_list.append(type(f'{node_type}_CE', (NodeCE,), {})(i)) | ||||||
|  |         super().__init__(metrics_list) | ||||||
|  |  | ||||||
|  | class EdgeMetricsCE(MetricCollection): | ||||||
|  |     def __init__(self): | ||||||
|  |         ce_no_bond = NoBondCE(0) | ||||||
|  |         ce_SI = SingleCE(1) | ||||||
|  |         ce_DO = DoubleCE(2) | ||||||
|  |         ce_TR = TripleCE(3) | ||||||
|  |         super().__init__([ce_no_bond, ce_SI, ce_DO, ce_TR]) | ||||||
|  |  | ||||||
| class AtomMetricsCE(MetricCollection): | class AtomMetricsCE(MetricCollection): | ||||||
|     def __init__(self, active_atoms): |     def __init__(self, active_atoms): | ||||||
| @@ -101,6 +110,41 @@ class BondMetricsCE(MetricCollection): | |||||||
| class TrainGraphMetricsDiscrete(nn.Module): | class TrainGraphMetricsDiscrete(nn.Module): | ||||||
|     def __init__(self, dataset_infos): |     def __init__(self, dataset_infos): | ||||||
|         super().__init__() |         super().__init__() | ||||||
|  |         active_nodes = dataset_infos.active_nodes | ||||||
|  |         self.train_node_metrics = NodeMetricsCE(active_nodes=active_nodes) | ||||||
|  |         self.train_edge_metrics = EdgeMetricsCE() | ||||||
|  |  | ||||||
|  |     def forward(self, masked_pred_X, masked_pred_E, true_X, true_E, log: bool): | ||||||
|  |         self.train_node_metrics(masked_pred_X, true_X) | ||||||
|  |         self.train_edge_metrics(masked_pred_E, true_E) | ||||||
|  |         if log: | ||||||
|  |             to_log = {} | ||||||
|  |             for key, val in self.train_node_metrics.compute().items(): | ||||||
|  |                 to_log['train/' + key] = val.item() | ||||||
|  |             for key, val in self.train_edge_metrics.compute().items(): | ||||||
|  |                 to_log['train/' + key] = val.item() | ||||||
|  |  | ||||||
|  |     def reset(self): | ||||||
|  |         for metric in [self.train_node_metrics, self.train_edge_metrics]: | ||||||
|  |             metric.reset() | ||||||
|  |  | ||||||
|  |     def log_epoch_metrics(self, current_epoch, log=True): | ||||||
|  |         epoch_node_metrics = self.train_node_metrics.compute() | ||||||
|  |         epoch_edge_metrics = self.train_edge_metrics.compute() | ||||||
|  |  | ||||||
|  |         to_log = {} | ||||||
|  |         for key, val in epoch_node_metrics.items(): | ||||||
|  |             to_log['train_epoch/' + key] = val.item() | ||||||
|  |         for key, val in epoch_edge_metrics.items(): | ||||||
|  |             to_log['train_epoch/' + key] = val.item() | ||||||
|  |  | ||||||
|  |         for key, val in epoch_node_metrics.items(): | ||||||
|  |             epoch_node_metrics[key] = round(val.item(),4) | ||||||
|  |         for key, val in epoch_edge_metrics.items(): | ||||||
|  |             epoch_edge_metrics[key] = round(val.item(),4) | ||||||
|  |  | ||||||
|  |         if log: | ||||||
|  |             print(f"Epoch {current_epoch}: {epoch_node_metrics} -- {epoch_edge_metrics}") | ||||||
|  |  | ||||||
| class TrainMolecularMetricsDiscrete(nn.Module): | class TrainMolecularMetricsDiscrete(nn.Module): | ||||||
|     def __init__(self, dataset_infos): |     def __init__(self, dataset_infos): | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user