naswot/scores.py

22 lines
337 B
Python
Raw Permalink Normal View History

2021-02-26 17:12:51 +01:00
import numpy as np
import torch
def hooklogdet(K, labels=None):
    """Return the log of the absolute determinant of the square matrix ``K``.

    Computed with ``np.linalg.slogdet`` rather than ``log(det(K))`` so large
    or tiny determinants do not overflow/underflow. The sign reported by
    ``slogdet`` is discarded; only ``log|det(K)|`` is returned.

    Args:
        K: Square matrix (array-like) to score.
        labels: Not used by this function; presumably kept so every score
            function can be called with the same ``(matrix, labels)`` shape.

    Returns:
        ``log|det(K)|``; ``-inf`` when ``K`` is singular (per ``slogdet``).
    """
    _, log_abs_det = np.linalg.slogdet(K)
    return log_abs_det
def random_score(jacob, label=None):
    """Return a random score drawn independently of the inputs.

    Draws one sample from the standard normal distribution using NumPy's
    global random state, so results can be reproduced with ``np.random.seed``.

    Args:
        jacob: Ignored.
        label: Ignored.

    Returns:
        A single float sampled from N(0, 1).
    """
    # loc/scale spelled out for clarity; they equal np.random.normal's defaults.
    return np.random.normal(loc=0.0, scale=1.0)
# Registry mapping public score names to the functions that compute them;
# looked up by name in get_score_func.
_scores = dict(
    hook_logdet=hooklogdet,
    random=random_score,
)
def get_score_func(score_name):
    """Look up a registered score function by name.

    Args:
        score_name: Key in the module-level ``_scores`` registry
            (e.g. ``'hook_logdet'`` or ``'random'``).

    Returns:
        The callable registered under ``score_name``.

    Raises:
        KeyError: If ``score_name`` is not registered. The message lists the
            available names. The exception type is the same as a plain dict
            lookup, so existing ``except KeyError`` handlers keep working.
    """
    try:
        return _scores[score_name]
    except KeyError:
        available = ", ".join(sorted(_scores))
        # `from None` hides the redundant inner KeyError from the traceback.
        raise KeyError(
            f"Unknown score {score_name!r}; available scores: {available}"
        ) from None