rewrite to graph metrics
This commit is contained in:
		| @@ -35,7 +35,13 @@ class CEPerClass(Metric): | ||||
|  | ||||
|     def compute(self): | ||||
|         return self.total_ce / self.total_samples | ||||
| class NodeCE(CEPerClass): | ||||
|     def __init__(self, i): | ||||
|         super().__init__(i) | ||||
|  | ||||
| class EdgeCE(CEPerClass): | ||||
|     def __init__(self, i): | ||||
|         super().__init__(i) | ||||
|  | ||||
| class AtomCE(CEPerClass): | ||||
|     def __init__(self, i): | ||||
| @@ -65,6 +71,12 @@ class AromaticCE(CEPerClass): | ||||
|     def __init__(self, i): | ||||
|         super().__init__(i) | ||||
|  | ||||
| class NodeMetricsCE(MetricCollection): | ||||
|     def __init__(self, active_nodes): | ||||
|         metrics_list = [] | ||||
|  | ||||
|         for i, node_type in enumerate(active_nodes) : | ||||
|             metrics_list.append(type(f'{node_type}_CE', (NodeCE,), {})(i)) | ||||
|  | ||||
| class AtomMetricsCE(MetricCollection): | ||||
|     def __init__(self, active_atoms): | ||||
| @@ -84,7 +96,12 @@ class BondMetricsCE(MetricCollection): | ||||
|         ce_TR = TripleCE(3) | ||||
|         super().__init__([ce_no_bond, ce_SI, ce_DO, ce_TR]) | ||||
|  | ||||
| #  | ||||
| # | ||||
|  | ||||
| class TrainGraphMetricsDiscrete(nn.Module): | ||||
|     def __init__(self, dataset_infos): | ||||
|         super().__init__() | ||||
|  | ||||
| class TrainMolecularMetricsDiscrete(nn.Module): | ||||
|     def __init__(self, dataset_infos): | ||||
|         super().__init__() | ||||
|   | ||||
		Reference in New Issue
	
	Block a user