rewrite to graph metrics
This commit is contained in:
parent
a7f7010da7
commit
222470a43c
@ -35,7 +35,13 @@ class CEPerClass(Metric):
|
|||||||
|
|
||||||
def compute(self):
|
def compute(self):
|
||||||
return self.total_ce / self.total_samples
|
return self.total_ce / self.total_samples
|
||||||
|
class NodeCE(CEPerClass):
|
||||||
|
def __init__(self, i):
|
||||||
|
super().__init__(i)
|
||||||
|
|
||||||
|
class EdgeCE(CEPerClass):
|
||||||
|
def __init__(self, i):
|
||||||
|
super().__init__(i)
|
||||||
|
|
||||||
class AtomCE(CEPerClass):
|
class AtomCE(CEPerClass):
|
||||||
def __init__(self, i):
|
def __init__(self, i):
|
||||||
@ -65,6 +71,12 @@ class AromaticCE(CEPerClass):
|
|||||||
def __init__(self, i):
|
def __init__(self, i):
|
||||||
super().__init__(i)
|
super().__init__(i)
|
||||||
|
|
||||||
|
class NodeMetricsCE(MetricCollection):
|
||||||
|
def __init__(self, active_nodes):
|
||||||
|
metrics_list = []
|
||||||
|
|
||||||
|
for i, node_type in enumerate(active_nodes) :
|
||||||
|
metrics_list.append(type(f'{node_type}_CE', (NodeCE,), {})(i))
|
||||||
|
|
||||||
class AtomMetricsCE(MetricCollection):
|
class AtomMetricsCE(MetricCollection):
|
||||||
def __init__(self, active_atoms):
|
def __init__(self, active_atoms):
|
||||||
@ -85,6 +97,11 @@ class BondMetricsCE(MetricCollection):
|
|||||||
super().__init__([ce_no_bond, ce_SI, ce_DO, ce_TR])
|
super().__init__([ce_no_bond, ce_SI, ce_DO, ce_TR])
|
||||||
|
|
||||||
#
|
#
|
||||||
|
|
||||||
|
class TrainGraphMetricsDiscrete(nn.Module):
|
||||||
|
def __init__(self, dataset_infos):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
class TrainMolecularMetricsDiscrete(nn.Module):
|
class TrainMolecularMetricsDiscrete(nn.Module):
|
||||||
def __init__(self, dataset_infos):
|
def __init__(self, dataset_infos):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
Loading…
Reference in New Issue
Block a user