diffusionNAG/MobileNetV3/evaluation/evaluator.py
2024-03-15 14:38:51 +00:00

59 lines
1.9 KiB
Python

import networkx as nx
from .structure_evaluator import mmd_eval
from .gin_evaluator import nn_based_eval
from torch_geometric.utils import to_networkx
import torch
import torch.nn.functional as F
import dgl
def get_stats_eval(config):
    """Build an evaluator that scores generated graphs with ``mmd_eval``.

    Args:
        config: Configuration object. ``config.eval.mmd_distance`` is read
            once, when this factory is called; only ``'rbf'``
            (case-insensitive) is accepted. ``config.eval.max_subgraph`` is
            read every time the returned function is called.

    Returns:
        A function ``eval_stats_fn(test_dataset, pred_graph_list)`` that
        returns the result of ``mmd_eval`` on the reference and predicted
        graph lists.

    Raises:
        ValueError: If ``config.eval.mmd_distance`` is not ``'rbf'``.
    """
    mmd_distance = config.eval.mmd_distance
    if mmd_distance.lower() != 'rbf':
        raise ValueError(
            "Unsupported eval.mmd_distance {!r}; only 'rbf' is implemented"
            .format(mmd_distance)
        )

    # Triples forwarded unchanged to mmd_eval; the meaning of each field is
    # defined there.
    method = [
        ('degree', 1., 'argmax'),
        ('cluster', 0.1, 'argmax'),
        ('spectral', 1., 'argmax'),
    ]

    def eval_stats_fn(test_dataset, pred_graph_list):
        """Compare predicted adjacency matrices against ``test_dataset``.

        Args:
            test_dataset: Indexable collection supporting ``len()`` whose
                items are accepted by ``torch_geometric.utils.to_networkx``.
            pred_graph_list: Iterable of 2-D adjacency arrays accepted by
                ``networkx.from_numpy_array``.

        Returns:
            Whatever ``mmd_eval(test_G, pred_G, method)`` returns.
        """
        # from_numpy_matrix was removed in NetworkX 3.0; from_numpy_array is
        # its replacement and has been available since NetworkX 2.0.
        pred_G = [nx.from_numpy_array(pred_adj) for pred_adj in pred_graph_list]

        if config.eval.max_subgraph:
            largest_only = []
            for graph in pred_G:
                # max() returns the first largest component, the same one a
                # stable descending sort would put at index 0.
                biggest = max(nx.connected_components(graph), key=len, default=None)
                # A graph without nodes has no components; keep it unchanged
                # instead of failing with IndexError.
                largest_only.append(
                    graph if biggest is None else graph.subgraph(biggest)
                )
            pred_G = largest_only

        test_G = [
            to_networkx(test_dataset[i], to_undirected=True, remove_self_loops=True)
            for i in range(len(test_dataset))
        ]
        return mmd_eval(test_G, pred_G, method)

    return eval_stats_fn
def get_nn_eval(config):
    """Build an evaluator that scores generated graphs with ``nn_based_eval``.

    Args:
        config: Configuration object. ``config.eval.N_gin`` is read once,
            when this factory is called, and defaults to 10 when the
            attribute is absent. ``config.eval.max_subgraph`` is read every
            time the returned function is called.

    Returns:
        A function ``nn_eval_fn(test_dataset, pred_graph_list)`` that returns
        the result of ``nn_based_eval`` on the reference and predicted graph
        lists.
    """
    # Forwarded unchanged to nn_based_eval as its third argument.
    N_gin = getattr(config.eval, "N_gin", 10)

    def nn_eval_fn(test_dataset, pred_graph_list):
        """Compare predicted adjacency matrices against ``test_dataset``.

        Args:
            test_dataset: Indexable collection supporting ``len()`` whose
                items are accepted by ``torch_geometric.utils.to_networkx``.
            pred_graph_list: Iterable of 2-D adjacency arrays accepted by
                ``networkx.from_numpy_array``.

        Returns:
            Whatever ``nn_based_eval(test_G, pred_G, N_gin)`` returns.
        """
        # from_numpy_matrix was removed in NetworkX 3.0; from_numpy_array is
        # its replacement and has been available since NetworkX 2.0.
        pred_G = [nx.from_numpy_array(pred_adj) for pred_adj in pred_graph_list]

        if config.eval.max_subgraph:
            largest_only = []
            for graph in pred_G:
                # max() returns the first largest component, the same one a
                # stable descending sort would put at index 0.
                biggest = max(nx.connected_components(graph), key=len, default=None)
                # A graph without nodes has no components; keep it unchanged
                # instead of failing with IndexError.
                largest_only.append(
                    graph if biggest is None else graph.subgraph(biggest)
                )
            pred_G = largest_only

        test_G = [
            to_networkx(test_dataset[i], to_undirected=True, remove_self_loops=True)
            for i in range(len(test_dataset))
        ]
        return nn_based_eval(test_G, pred_G, N_gin)

    return nn_eval_fn