From 222470a43cbe6aa12634905797f52da95a480f86 Mon Sep 17 00:00:00 2001 From: mhz Date: Thu, 27 Jun 2024 20:44:04 +0200 Subject: [PATCH] rewrite to graph metrics --- graph_dit/metrics/molecular_metrics_train.py | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/graph_dit/metrics/molecular_metrics_train.py b/graph_dit/metrics/molecular_metrics_train.py index 7b4fd0f..c5fd96c 100644 --- a/graph_dit/metrics/molecular_metrics_train.py +++ b/graph_dit/metrics/molecular_metrics_train.py @@ -35,7 +35,13 @@ class CEPerClass(Metric): def compute(self): return self.total_ce / self.total_samples +class NodeCE(CEPerClass): + def __init__(self, i): + super().__init__(i) +class EdgeCE(CEPerClass): + def __init__(self, i): + super().__init__(i) class AtomCE(CEPerClass): def __init__(self, i): @@ -65,6 +71,12 @@ class AromaticCE(CEPerClass): def __init__(self, i): super().__init__(i) +class NodeMetricsCE(MetricCollection): + def __init__(self, active_nodes): + metrics_list = [] + + for i, node_type in enumerate(active_nodes) : + metrics_list.append(type(f'{node_type}_CE', (NodeCE,), {})(i)) class AtomMetricsCE(MetricCollection): def __init__(self, active_atoms): @@ -84,7 +96,12 @@ class BondMetricsCE(MetricCollection): ce_TR = TripleCE(3) super().__init__([ce_no_bond, ce_SI, ce_DO, ce_TR]) -# +# + +class TrainGraphMetricsDiscrete(nn.Module): + def __init__(self, dataset_infos): + super().__init__() + class TrainMolecularMetricsDiscrete(nn.Module): def __init__(self, dataset_infos): super().__init__()