59 lines
1.9 KiB
Python
59 lines
1.9 KiB
Python
import networkx as nx
|
|
from .structure_evaluator import mmd_eval
|
|
from .gin_evaluator import nn_based_eval
|
|
from torch_geometric.utils import to_networkx
|
|
import torch
|
|
import torch.nn.functional as F
|
|
import dgl
|
|
|
|
|
|
def get_stats_eval(config):
|
|
|
|
if config.eval.mmd_distance.lower() == 'rbf':
|
|
method = [('degree', 1., 'argmax'), ('cluster', 0.1, 'argmax'),
|
|
('spectral', 1., 'argmax')]
|
|
else:
|
|
raise ValueError
|
|
|
|
def eval_stats_fn(test_dataset, pred_graph_list):
|
|
pred_G = [nx.from_numpy_matrix(pred_adj) for pred_adj in pred_graph_list]
|
|
sub_pred_G = []
|
|
if config.eval.max_subgraph:
|
|
for G in pred_G:
|
|
CGs = [G.subgraph(c) for c in nx.connected_components(G)]
|
|
CGs = sorted(CGs, key=lambda x: x.number_of_nodes(), reverse=True)
|
|
sub_pred_G += [CGs[0]]
|
|
pred_G = sub_pred_G
|
|
|
|
test_G = [to_networkx(test_dataset[i], to_undirected=True, remove_self_loops=True)
|
|
for i in range(len(test_dataset))]
|
|
results = mmd_eval(test_G, pred_G, method)
|
|
return results
|
|
|
|
return eval_stats_fn
|
|
|
|
|
|
def get_nn_eval(config):
|
|
|
|
if hasattr(config.eval, "N_gin"):
|
|
N_gin = config.eval.N_gin
|
|
else:
|
|
N_gin = 10
|
|
|
|
def nn_eval_fn(test_dataset, pred_graph_list):
|
|
pred_G = [nx.from_numpy_matrix(pred_adj) for pred_adj in pred_graph_list]
|
|
sub_pred_G = []
|
|
if config.eval.max_subgraph:
|
|
for G in pred_G:
|
|
CGs = [G.subgraph(c) for c in nx.connected_components(G)]
|
|
CGs = sorted(CGs, key=lambda x: x.number_of_nodes(), reverse=True)
|
|
sub_pred_G += [CGs[0]]
|
|
pred_G = sub_pred_G
|
|
test_G = [to_networkx(test_dataset[i], to_undirected=True, remove_self_loops=True)
|
|
for i in range(len(test_dataset))]
|
|
|
|
results = nn_based_eval(test_G, pred_G, N_gin)
|
|
return results
|
|
|
|
return nn_eval_fn
|