from dks.base.activation_getter import (
    get_activation_function as _get_numpy_activation_function,
)
from dks.base.activation_transform import _get_activations_params


def subnet_max_func(x, r_fn, *, depth=7):
    """Evaluate a residual-style composition of ``r_fn`` and return its max.

    The structure is: one application of ``r_fn`` to ``x``, followed by
    ``depth`` residual blocks, each replacing the running value ``v`` with
    ``r_fn(r_fn(v)) + v``. The result is the larger of the final running
    value and the single initial application ``r_fn(x)``.

    This function is passed as ``subnet_max_func`` to
    ``get_transformed_activations`` (method ``"TAT"``). Its exact role inside
    dks is defined by that library.

    NOTE(review): the residual sum here is unweighted (``a + b``), whereas
    ``subnet_max_func_v2`` mixes the branches as ``0.8 * a + 0.2 * b``.
    Confirm the unweighted sum is the intended network structure.

    Args:
      x: Input value handed in by the caller (presumably a scalar).
      r_fn: Callable applied repeatedly. It is assumed to be side-effect free,
        because its first application is computed once and reused.
      depth: Number of residual blocks. Keyword-only; defaults to 7, the
        previously hard-coded value. A non-positive depth applies no blocks.

    Returns:
      ``max(final_value, r_fn(x))``.
    """
    # The first application doubles as the "residual branch" candidate and as
    # the starting value of the running composition.
    first = r_fn(x)
    value = first
    for _ in range(depth):
        value = r_fn(r_fn(value)) + value
    return max(value, first)


def subnet_max_func_v2(x, r_fn, *, depth=2, residual_weight=0.8, shortcut_weight=0.2):
    """Evaluate a weighted residual composition of ``r_fn`` and return its max.

    The structure is: one application of ``r_fn`` to ``x``, followed by
    ``depth`` blocks, each replacing the running value ``v`` with
    ``residual_weight * r_fn(r_fn(v)) + shortcut_weight * v``. The result is
    the larger of the final running value and the single initial application
    ``r_fn(x)``.

    This function is passed as ``subnet_max_func`` to
    ``get_transformed_activations`` (method ``"TAT"``). Its exact role inside
    dks is defined by that library.

    Args:
      x: Input value handed in by the caller (presumably a scalar).
      r_fn: Callable applied repeatedly. It is assumed to be side-effect free,
        because its first application is computed once and reused.
      depth: Number of blocks. Keyword-only; defaults to 2, the previously
        hard-coded value. A non-positive depth applies no blocks.
      residual_weight: Multiplier on the ``r_fn(r_fn(v))`` branch
        (default 0.8).
      shortcut_weight: Multiplier on the shortcut ``v`` branch (default 0.2).

    Returns:
      ``max(final_value, r_fn(x))``.
    """
    # The first application doubles as the comparison candidate and as the
    # starting value of the running composition.
    first = r_fn(x)
    value = first
    for _ in range(depth):
        # Same operation order as before (weight * branch + weight * shortcut),
        # so default results are bit-identical.
        value = residual_weight * r_fn(r_fn(value)) + shortcut_weight * value
    return max(value, first)


def get_transformed_activations(
    activation_names,
    method="TAT",
    dks_params=None,
    tat_params=None,
    max_slope_func=None,
    max_curv_func=None,
    subnet_max_func=None,
    activation_getter=_get_numpy_activation_function,
):
    """Compute the activation-transform parameters for ``activation_names``.

    Thin wrapper around ``dks.base.activation_transform._get_activations_params``.
    Every argument except ``activation_getter`` is forwarded by keyword, and
    the helper's result is returned unchanged.

    Args:
      activation_names: Names of the activations to transform
        (e.g. ``["swish"]``).
      method: Transformation method name, forwarded as-is (default ``"TAT"``).
      dks_params: Optional DKS-specific parameters, forwarded as-is.
      tat_params: Optional TAT-specific parameters, forwarded as-is.
      max_slope_func: Optional callable, forwarded as-is.
      max_curv_func: Optional callable, forwarded as-is.
      subnet_max_func: Optional callable, forwarded as-is. This parameter
        shadows the module-level function of the same name inside this body.
      activation_getter: Accepted but not referenced anywhere in this body.

    Returns:
      Whatever ``_get_activations_params`` returns. This is the parameters
      only, not transformed activation functions, despite this function's name.

    NOTE(review): ``activation_getter`` is unused. The function name suggests
    it was meant to turn the params into activation functions; confirm intent.
    """
    forwarded = dict(
        method=method,
        dks_params=dks_params,
        tat_params=tat_params,
        max_slope_func=max_slope_func,
        max_curv_func=max_curv_func,
        subnet_max_func=subnet_max_func,
    )
    return _get_activations_params(activation_names, **forwarded)


# Demo: compute TAT parameters for each activation using its paired
# subnetwork function, printing each result. ``params`` is left bound to the
# last result, and the loop temporaries are removed from the module namespace.
for _demo_names, _demo_subnet_fn in (
    (["swish"], subnet_max_func),
    (["leaky_relu"], subnet_max_func_v2),
):
    params = get_transformed_activations(
        _demo_names, method="TAT", subnet_max_func=_demo_subnet_fn
    )
    print(params)
del _demo_names, _demo_subnet_fn