"""Evaluation on random GIN features. Modified from https://github.com/uoguelph-mlrg/GGM-metrics""" import torch import numpy as np import sklearn import sklearn.metrics from sklearn.preprocessing import StandardScaler import time import dgl from .gin import GIN def load_feature_extractor( device, num_layers=3, hidden_dim=35, neighbor_pooling_type='sum', graph_pooling_type='sum', input_dim=1, edge_feat_dim=0, dont_concat=False, num_mlp_layers=2, output_dim=1, node_feat_loc='attr', edge_feat_loc='attr', init='orthogonal', **kwargs): model = GIN(num_layers=num_layers, hidden_dim=hidden_dim, neighbor_pooling_type=neighbor_pooling_type, graph_pooling_type=graph_pooling_type, input_dim=input_dim, edge_feat_dim=edge_feat_dim, num_mlp_layers=num_mlp_layers, output_dim=output_dim, init=init) model.node_feat_loc = node_feat_loc model.edge_feat_loc = edge_feat_loc model.eval() if dont_concat: model.forward = model.get_graph_embed_no_cat else: model.forward = model.get_graph_embed model.device = device return model.to(device) def time_function(func): def wrapper(*args, **kwargs): start = time.time() results = func(*args, **kwargs) end = time.time() return results, end - start return wrapper class GINMetric(): def __init__(self, model): self.feat_extractor = model self.get_activations = self.get_activations_gin @time_function def get_activations_gin(self, generated_dataset, reference_dataset): return self._get_activations(generated_dataset, reference_dataset) def _get_activations(self, generated_dataset, reference_dataset): gen_activations = self.__get_activations_single_dataset(generated_dataset) ref_activations = self.__get_activations_single_dataset(reference_dataset) scaler = StandardScaler() scaler.fit(ref_activations) ref_activations = scaler.transform(ref_activations) gen_activations = scaler.transform(gen_activations) return gen_activations, ref_activations def __get_activations_single_dataset(self, dataset): node_feat_loc = self.feat_extractor.node_feat_loc edge_feat_loc = self.feat_extractor.edge_feat_loc ndata = [node_feat_loc] if node_feat_loc in dataset[0].ndata else '__ALL__' edata = [edge_feat_loc] if edge_feat_loc in dataset[0].edata else '__ALL__' graphs = dgl.batch(dataset, ndata=ndata, edata=edata).to(self.feat_extractor.device) if node_feat_loc not in graphs.ndata: # Use degree as features feats = graphs.in_degrees() + graphs.out_degrees() feats = feats.unsqueeze(1).type(torch.float32) else: feats = graphs.ndata[node_feat_loc] graph_embeds = self.feat_extractor(graphs, feats) return graph_embeds.cpu().detach().numpy() def evaluate(self, *args, **kwargs): raise Exception('Must be implemented by child class') class MMDEvaluation(GINMetric): def __init__(self, model, kernel='rbf', sigma='range', multiplier='mean'): super().__init__(model) if multiplier == 'mean': self.__get_sigma_mult_factor = self.__mean_pairwise_distance elif multiplier == 'median': self.__get_sigma_mult_factor = self.__median_pairwise_distance elif multiplier is None: self.__get_sigma_mult_factor = lambda *args, **kwargs: 1 else: raise Exception(multiplier) if 'rbf' in kernel: if sigma == 'range': self.base_sigmas = np.array([0.01, 0.1, 0.25, 0.5, 0.75, 1.0, 2.5, 5.0, 7.5, 10.0]) if multiplier == 'mean': self.name = 'mmd_rbf' elif multiplier == 'median': self.name = 'mmd_rbf_adaptive_median' else: self.name = 'mmd_rbf_adaptive' elif sigma == 'one': self.base_sigmas = np.array([1]) if multiplier == 'mean': self.name = 'mmd_rbf_single_mean' elif multiplier == 'median': self.name = 'mmd_rbf_single_median' else: self.name = 'mmd_rbf_single' else: raise Exception(sigma) self.evaluate = self.calculate_MMD_rbf_quadratic elif 'linear' in kernel: self.evaluate = self.calculate_MMD_linear_kernel else: raise Exception() def __get_pairwise_distances(self, generated_dataset, reference_dataset): return sklearn.metrics.pairwise_distances(reference_dataset, generated_dataset, metric='euclidean', n_jobs=8)**2 def __mean_pairwise_distance(self, dists_GR): return np.sqrt(dists_GR.mean()) def __median_pairwise_distance(self, dists_GR): return np.sqrt(np.median(dists_GR)) def get_sigmas(self, dists_GR): mult_factor = self.__get_sigma_mult_factor(dists_GR) return self.base_sigmas * mult_factor @time_function def calculate_MMD_rbf_quadratic(self, generated_dataset=None, reference_dataset=None): # https://github.com/djsutherland/opt-mmd/blob/master/two_sample/mmd.py if not isinstance(generated_dataset, torch.Tensor) and not isinstance(generated_dataset, np.ndarray): (generated_dataset, reference_dataset), _ = self.get_activations(generated_dataset, reference_dataset) GG = self.__get_pairwise_distances(generated_dataset, generated_dataset) GR = self.__get_pairwise_distances(generated_dataset, reference_dataset) RR = self.__get_pairwise_distances(reference_dataset, reference_dataset) max_mmd = 0 sigmas = self.get_sigmas(GR) for sigma in sigmas: gamma = 1 / (2 * sigma**2) K_GR = np.exp(-gamma * GR) K_GG = np.exp(-gamma * GG) K_RR = np.exp(-gamma * RR) mmd = K_GG.mean() + K_RR.mean() - 2 * K_GR.mean() max_mmd = mmd if mmd > max_mmd else max_mmd return {self.name: max_mmd} @time_function def calculate_MMD_linear_kernel(self, generated_dataset=None, reference_dataset=None): # https://github.com/djsutherland/opt-mmd/blob/master/two_sample/mmd.py if not isinstance(generated_dataset, torch.Tensor) and not isinstance(generated_dataset, np.ndarray): (generated_dataset, reference_dataset), _ = self.get_activations(generated_dataset, reference_dataset) G_bar = generated_dataset.mean(axis=0) R_bar = reference_dataset.mean(axis=0) Z_bar = G_bar - R_bar mmd = Z_bar.dot(Z_bar) mmd = mmd if mmd >= 0 else 0 return {'mmd_linear': mmd} class prdcEvaluation(GINMetric): # From PRDC github: https://github.com/clovaai/generative-evaluation-prdc/blob/master/prdc/prdc.py#L54 def __init__(self, *args, use_pr=False, **kwargs): super().__init__(*args, **kwargs) self.use_pr = use_pr @time_function def evaluate(self, generated_dataset=None, reference_dataset=None, nearest_k=5): """ Computes precision, recall, density, and coverage given two manifolds. """ if not isinstance(generated_dataset, torch.Tensor) and not isinstance(generated_dataset, np.ndarray): (generated_dataset, reference_dataset), _ = self.get_activations(generated_dataset, reference_dataset) real_nearest_neighbour_distances = self.__compute_nearest_neighbour_distances(reference_dataset, nearest_k) distance_real_fake = self.__compute_pairwise_distance(reference_dataset, generated_dataset) if self.use_pr: fake_nearest_neighbour_distances = self.__compute_nearest_neighbour_distances(generated_dataset, nearest_k) precision = ( distance_real_fake <= np.expand_dims(real_nearest_neighbour_distances, axis=1) ).any(axis=0).mean() recall = ( distance_real_fake <= np.expand_dims(fake_nearest_neighbour_distances, axis=0) ).any(axis=1).mean() f1_pr = 2 / ((1 / (precision + 1e-8)) + (1 / (recall + 1e-8))) result = dict(precision=precision, recall=recall, f1_pr=f1_pr) else: density = (1. / float(nearest_k)) * ( distance_real_fake <= np.expand_dims(real_nearest_neighbour_distances, axis=1)).sum(axis=0).mean() coverage = (distance_real_fake.min(axis=1) <= real_nearest_neighbour_distances).mean() f1_dc = 2 / ((1 / (density + 1e-8)) + (1 / (coverage + 1e-8))) result = dict(density=density, coverage=coverage, f1_dc=f1_dc) return result def __compute_pairwise_distance(self, data_x, data_y=None): """ Args: data_x: numpy.ndarray([N, feature_dim], dtype=np.float32) data_y: numpy.ndarray([N, feature_dim], dtype=np.float32) Return: numpy.ndarray([N, N], dtype=np.float32) of pairwise distances. """ if data_y is None: data_y = data_x dists = sklearn.metrics.pairwise_distances(data_x, data_y, metric='euclidean', n_jobs=8) return dists def __get_kth_value(self, unsorted, k, axis=-1): """ Args: unsorted: numpy.ndarray of any dimensionality. k: int Return: kth values along the designated axis. """ indices = np.argpartition(unsorted, k, axis=axis)[..., :k] k_smallest = np.take_along_axis(unsorted, indices, axis=axis) kth_values = k_smallest.max(axis=axis) return kth_values def __compute_nearest_neighbour_distances(self, input_features, nearest_k): """ Args: input_features: numpy.ndarray([N, feature_dim], dtype=np.float32) nearest_k: int Return: Distances to kth nearest neighbours. """ distances = self.__compute_pairwise_distance(input_features) radii = self.__get_kth_value(distances, k=nearest_k + 1, axis=-1) return radii def nn_based_eval(graph_ref_list, graph_pred_list, N_gin=10): device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') evaluators = [] for _ in range(N_gin): gin = load_feature_extractor(device) evaluators.append(MMDEvaluation(model=gin, kernel='rbf', sigma='range', multiplier='mean')) evaluators.append(prdcEvaluation(model=gin, use_pr=True)) evaluators.append(prdcEvaluation(model=gin, use_pr=False)) ref_graphs = [dgl.from_networkx(g).to(device) for g in graph_ref_list] gen_graphs = [dgl.from_networkx(g).to(device) for g in graph_pred_list] metrics = { 'mmd_rbf': [], 'f1_pr': [], 'f1_dc': [] } for evaluator in evaluators: res, time = evaluator.evaluate(generated_dataset=gen_graphs, reference_dataset=ref_graphs) for key in list(res.keys()): if key in metrics: metrics[key].append(res[key]) results = { 'MMD_RBF': (np.mean(metrics['mmd_rbf']), np.std(metrics['mmd_rbf'])), 'F1_PR': (np.mean(metrics['f1_pr']), np.std(metrics['f1_pr'])), 'F1_DC': (np.mean(metrics['f1_dc']), np.std(metrics['f1_dc'])) } return results