#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.03 #
#####################################################
# Reformulate the codes in https://github.com/CalculatedContent/WeightWatcher
#####################################################
import numpy as np
from typing import List
import torch.nn as nn
from collections import OrderedDict
from sklearn.decomposition import TruncatedSVD


def available_module_types():
    return (nn.Conv2d, nn.Linear)


def get_conv2D_Wmats(tensor: np.ndarray) -> List[np.ndarray]:
    """
    Extract W slices from a 4 index conv2D tensor of shape: (N,M,i,j) or (M,N,i,j).
    Return ij (N x M) matrices
    """
    mats = []
    N, M, imax, jmax = tensor.shape
    assert (
        N + M >= imax + jmax
    ), "invalid tensor shape detected: {}x{} (NxM), {}x{} (i,j)".format(
        N, M, imax, jmax
    )
    for i in range(imax):
        for j in range(jmax):
            w = tensor[:, :, i, j]
            if N < M:
                w = w.T
            mats.append(w)
    return mats


def glorot_norm_check(W, N, M, rf_size, lower=0.5, upper=1.5):
    """Check if this layer needs Glorot Normalization Fix"""

    kappa = np.sqrt(2 / ((N + M) * rf_size))
    norm = np.linalg.norm(W)

    check1 = norm / np.sqrt(N * M)
    check2 = norm / (kappa * np.sqrt(N * M))

    if (rf_size > 1) and (check2 > lower) and (check2 < upper):
        return check2, True
    elif (check1 > lower) & (check1 < upper):
        return check1, True
    else:
        if rf_size > 1:
            return check2, False
        else:
            return check1, False


def glorot_norm_fix(w, n, m, rf_size):
    """Apply Glorot Normalization Fix."""
    kappa = np.sqrt(2 / ((n + m) * rf_size))
    w = w / kappa
    return w


def analyze_weights(
    weights,
    min_size,
    max_size,
    alphas,
    lognorms,
    spectralnorms,
    softranks,
    normalize,
    glorot_fix,
):
    results = OrderedDict()
    count = len(weights)
    if count == 0:
        return results

    for i, weight in enumerate(weights):
        M, N = np.min(weight.shape), np.max(weight.shape)
        Q = N / M
        results[i] = cur_res = OrderedDict(N=N, M=M, Q=Q)
        check, checkTF = glorot_norm_check(weight, N, M, count)
        cur_res["check"] = check
        cur_res["checkTF"] = checkTF
        # assume receptive field size is count
        if glorot_fix:
            weight = glorot_norm_fix(weight, N, M, count)
        else:
            # probably never needed since we always fix for glorot
            weight = weight * np.sqrt(count / 2.0)

        if spectralnorms:  # spectralnorm is the max eigenvalues
            svd = TruncatedSVD(n_components=1, n_iter=7, random_state=10)
            svd.fit(weight)
            sv = svd.singular_values_
            sv_max = np.max(sv)
            if normalize:
                evals = sv * sv / N
            else:
                evals = sv * sv
            lambda0 = evals[0]
            cur_res["spectralnorm"] = lambda0
            cur_res["logspectralnorm"] = np.log10(lambda0)
        else:
            lambda0 = None

        if M < min_size:
            summary = "Weight matrix {}/{} ({},{}): Skipping: too small (<{})".format(
                i + 1, count, M, N, min_size
            )
            cur_res["summary"] = summary
            continue
        elif max_size > 0 and M > max_size:
            summary = (
                "Weight matrix {}/{} ({},{}): Skipping: too big (testing) (>{})".format(
                    i + 1, count, M, N, max_size
                )
            )
            cur_res["summary"] = summary
            continue
        else:
            summary = []
        if alphas:
            import powerlaw

            svd = TruncatedSVD(n_components=M - 1, n_iter=7, random_state=10)
            svd.fit(weight.astype(float))
            sv = svd.singular_values_
            if normalize:
                evals = sv * sv / N
            else:
                evals = sv * sv

            lambda_max = np.max(evals)
            fit = powerlaw.Fit(evals, xmax=lambda_max, verbose=False)
            alpha = fit.alpha
            cur_res["alpha"] = alpha
            D = fit.D
            cur_res["D"] = D
            cur_res["lambda_min"] = np.min(evals)
            cur_res["lambda_max"] = lambda_max
            alpha_weighted = alpha * np.log10(lambda_max)
            cur_res["alpha_weighted"] = alpha_weighted
            tolerance = lambda_max * M * np.finfo(np.max(sv)).eps
            cur_res["rank_loss"] = np.count_nonzero(sv > tolerance, axis=-1)

            logpnorm = np.log10(np.sum([ev ** alpha for ev in evals]))
            cur_res["logpnorm"] = logpnorm

            summary.append(
                "Weight matrix {}/{} ({},{}): Alpha: {}, Alpha Weighted: {}, D: {}, pNorm {}".format(
                    i + 1, count, M, N, alpha, alpha_weighted, D, logpnorm
                )
            )

        if lognorms:
            norm = np.linalg.norm(weight)  # Frobenius Norm
            cur_res["norm"] = norm
            lognorm = np.log10(norm)
            cur_res["lognorm"] = lognorm

            X = np.dot(weight.T, weight)
            if normalize:
                X = X / N
            normX = np.linalg.norm(X)  # Frobenius Norm
            cur_res["normX"] = normX
            lognormX = np.log10(normX)
            cur_res["lognormX"] = lognormX

            summary.append(
                "Weight matrix {}/{} ({},{}): LogNorm: {} ; LogNormX: {}".format(
                    i + 1, count, M, N, lognorm, lognormX
                )
            )

            if softranks:
                softrank = norm ** 2 / sv_max ** 2
                softranklog = np.log10(softrank)
                softranklogratio = lognorm / np.log10(sv_max)
                cur_res["softrank"] = softrank
                cur_res["softranklog"] = softranklog
                cur_res["softranklogratio"] = softranklogratio
                summary += (
                    "{}. Softrank: {}. Softrank log: {}. Softrank log ratio: {}".format(
                        summary, softrank, softranklog, softranklogratio
                    )
                )
        cur_res["summary"] = "\n".join(summary)
    return results


def compute_details(results):
    """
    Return a pandas data frame.
    """
    final_summary = OrderedDict()

    metrics = {
        # key in "results" : pretty print name
        "check": "Check",
        "checkTF": "CheckTF",
        "norm": "Norm",
        "lognorm": "LogNorm",
        "normX": "Norm X",
        "lognormX": "LogNorm X",
        "alpha": "Alpha",
        "alpha_weighted": "Alpha Weighted",
        "spectralnorm": "Spectral Norm",
        "logspectralnorm": "Log Spectral Norm",
        "softrank": "Softrank",
        "softranklog": "Softrank Log",
        "softranklogratio": "Softrank Log Ratio",
        "sigma_mp": "Marchenko-Pastur (MP) fit sigma",
        "numofSpikes": "Number of spikes per MP fit",
        "ratio_numofSpikes": "aka, percent_mass, Number of spikes / total number of evals",
        "softrank_mp": "Softrank for MP fit",
        "logpnorm": "alpha pNorm",
    }

    metrics_stats = []
    for metric in metrics:
        metrics_stats.append("{}_min".format(metric))
        metrics_stats.append("{}_max".format(metric))
        metrics_stats.append("{}_avg".format(metric))

        metrics_stats.append("{}_compound_min".format(metric))
        metrics_stats.append("{}_compound_max".format(metric))
        metrics_stats.append("{}_compound_avg".format(metric))

    columns = (
        [
            "layer_id",
            "layer_type",
            "N",
            "M",
            "layer_count",
            "slice",
            "slice_count",
            "level",
            "comment",
        ]
        + [*metrics]
        + metrics_stats
    )

    metrics_values = {}
    metrics_values_compound = {}

    for metric in metrics:
        metrics_values[metric] = []
        metrics_values_compound[metric] = []

    layer_count = 0
    for layer_id, result in results.items():
        layer_count += 1

        layer_type = np.NAN
        if "layer_type" in result:
            layer_type = str(result["layer_type"]).replace("LAYER_TYPE.", "")

        compounds = {}  # temp var
        for metric in metrics:
            compounds[metric] = []

        slice_count, Ntotal, Mtotal = 0, 0, 0
        for slice_id, summary in result.items():
            if not str(slice_id).isdigit():
                continue
            slice_count += 1

            N = np.NAN
            if "N" in summary:
                N = summary["N"]
                Ntotal += N

            M = np.NAN
            if "M" in summary:
                M = summary["M"]
                Mtotal += M

            data = {
                "layer_id": layer_id,
                "layer_type": layer_type,
                "N": N,
                "M": M,
                "slice": slice_id,
                "level": "SLICE",
                "comment": "Slice level",
            }
            for metric in metrics:
                if metric in summary:
                    value = summary[metric]
                    if value is not None:
                        metrics_values[metric].append(value)
                        compounds[metric].append(value)
                        data[metric] = value

        data = {
            "layer_id": layer_id,
            "layer_type": layer_type,
            "N": Ntotal,
            "M": Mtotal,
            "slice_count": slice_count,
            "level": "LAYER",
            "comment": "Layer level",
        }
        # Compute the compound value over the slices
        for metric, value in compounds.items():
            count = len(value)
            if count == 0:
                continue

            compound = np.mean(value)
            metrics_values_compound[metric].append(compound)
            data[metric] = compound

    data = {"layer_count": layer_count, "level": "NETWORK", "comment": "Network Level"}
    for metric, metric_name in metrics.items():
        if metric not in metrics_values or len(metrics_values[metric]) == 0:
            continue

        values = metrics_values[metric]
        minimum = min(values)
        maximum = max(values)
        avg = np.mean(values)
        final_summary[metric] = avg
        # print("{}: min: {}, max: {}, avg: {}".format(metric_name, minimum, maximum, avg))
        data["{}_min".format(metric)] = minimum
        data["{}_max".format(metric)] = maximum
        data["{}_avg".format(metric)] = avg

        values = metrics_values_compound[metric]
        minimum = min(values)
        maximum = max(values)
        avg = np.mean(values)
        final_summary["{}_compound".format(metric)] = avg
        # print("{} compound: min: {}, max: {}, avg: {}".format(metric_name, minimum, maximum, avg))
        data["{}_compound_min".format(metric)] = minimum
        data["{}_compound_max".format(metric)] = maximum
        data["{}_compound_avg".format(metric)] = avg

    return final_summary


def analyze(
    model: nn.Module,
    min_size=50,
    max_size=0,
    alphas: bool = False,
    lognorms: bool = True,
    spectralnorms: bool = False,
    softranks: bool = False,
    normalize: bool = False,
    glorot_fix: bool = False,
):
    """
    Analyze the weight matrices of a model.
    :param model: A PyTorch model
    :param min_size: The minimum weight matrix size to analyze.
    :param max_size: The maximum weight matrix size to analyze (0 = no limit).
    :param alphas: Compute the power laws (alpha) of the weight matrices.
      Time consuming so disabled by default (use lognorm if you want speed)
    :param lognorms: Compute the log norms of the weight matrices.
    :param spectralnorms: Compute the spectral norm (max eigenvalue) of the weight matrices.
    :param softranks: Compute the soft norm (i.e. StableRank) of the weight matrices.
    :param normalize: Normalize or not.
    :param glorot_fix:
    :return: (a dict of all layers' results, a dict of the summarized info)
    """
    names, modules = [], []
    for name, module in model.named_modules():
        if isinstance(module, available_module_types()):
            names.append(name)
            modules.append(module)
    # print('There are {:} layers to be analyzed in this model.'.format(len(modules)))
    all_results = OrderedDict()
    for index, module in enumerate(modules):
        if isinstance(module, nn.Linear):
            weights = [module.weight.cpu().detach().numpy()]
        else:
            weights = get_conv2D_Wmats(module.weight.cpu().detach().numpy())
        results = analyze_weights(
            weights,
            min_size,
            max_size,
            alphas,
            lognorms,
            spectralnorms,
            softranks,
            normalize,
            glorot_fix,
        )
        results["id"] = index
        results["type"] = type(module)
        all_results[index] = results
    summary = compute_details(all_results)
    return all_results, summary