58 lines
1.3 KiB
Python
58 lines
1.3 KiB
Python
from dks.base.activation_getter import (
|
|
get_activation_function as _get_numpy_activation_function,
|
|
)
|
|
from dks.base.activation_transform import _get_activations_params
|
|
|
|
|
|
def subnet_max_func(x, r_fn, depth=7):
    """Return the larger of a stacked residual map and a single ``r_fn`` map.

    Starting from ``r_fn(x)``, the update ``x <- r_fn(r_fn(x)) + x`` is
    applied ``depth`` times. The result is compared against the initial
    ``r_fn(x)`` with the built-in ``max`` and the larger value is returned.

    Args:
        x: Input value handed to ``r_fn``. It must support ``+`` and ordering
            comparison with the values ``r_fn`` returns.
        r_fn: Callable taking one value and returning one value.
        depth: Number of residual updates to apply. Defaults to 7, the value
            that was previously hard-coded.

    Returns:
        ``max(stacked, r_fn(x))`` where ``stacked`` is the value after
        ``depth`` residual updates.
    """
    res_x = r_fn(x)
    # Reuse the value computed above rather than calling r_fn(x) a second
    # time; this assumes r_fn is deterministic for a given input.
    x = res_x
    for _ in range(depth):
        x = r_fn(r_fn(x)) + x
    return max(x, res_x)
|
|
|
|
|
|
def subnet_max_func_v2(x, r_fn, depth=2):
    """Return the larger of a weighted residual stack and a single ``r_fn`` map.

    Starting from ``r_fn(x)``, the update
    ``x <- 0.8 * r_fn(r_fn(x)) + 0.2 * x`` is applied ``depth`` times. The
    result is compared against the initial ``r_fn(x)`` with the built-in
    ``max`` and the larger value is returned.

    Args:
        x: Input value handed to ``r_fn``. It must support multiplication by a
            float, ``+`` and ordering comparison with the values ``r_fn``
            returns.
        r_fn: Callable taking one value and returning one value.
        depth: Number of weighted residual updates to apply. Defaults to 2,
            the value that was previously hard-coded.

    Returns:
        ``max(stacked, r_fn(x))`` where ``stacked`` is the value after
        ``depth`` weighted residual updates.
    """
    res_x = r_fn(x)
    # Reuse the value computed above rather than calling r_fn(x) a second
    # time; this assumes r_fn is deterministic for a given input.
    x = res_x
    for _ in range(depth):
        # 0.8 weights the doubly-mapped branch, 0.2 the carried-over value.
        x = 0.8 * r_fn(r_fn(x)) + 0.2 * x
    return max(x, res_x)
|
|
|
|
|
|
def get_transformed_activations(
    activation_names,
    method="TAT",
    dks_params=None,
    tat_params=None,
    max_slope_func=None,
    max_curv_func=None,
    subnet_max_func=None,
    activation_getter=_get_numpy_activation_function,
):
    """Compute and return the activation parameters for ``activation_names``.

    Every argument except ``activation_getter`` is forwarded to
    ``_get_activations_params`` and its result is returned unchanged.

    NOTE(review): ``activation_getter`` is accepted but never used in this
    body, and despite the function name the raw parameters are returned
    rather than activation callables -- confirm this is intended.

    Args:
        activation_names: Forwarded positionally to
            ``_get_activations_params``.
        method: Forwarded as ``method``. Defaults to ``"TAT"``.
        dks_params: Forwarded as ``dks_params``.
        tat_params: Forwarded as ``tat_params``.
        max_slope_func: Forwarded as ``max_slope_func``.
        max_curv_func: Forwarded as ``max_curv_func``.
        subnet_max_func: Forwarded as ``subnet_max_func``.
        activation_getter: Unused here.

    Returns:
        Whatever ``_get_activations_params`` returns for these arguments.
    """
    forwarded_options = {
        "method": method,
        "dks_params": dks_params,
        "tat_params": tat_params,
        "max_slope_func": max_slope_func,
        "max_curv_func": max_curv_func,
        "subnet_max_func": subnet_max_func,
    }
    return _get_activations_params(activation_names, **forwarded_options)
|
|
|
|
|
|
# Print the TAT parameters for swish (using subnet_max_func) and then for
# leaky_relu (using subnet_max_func_v2), in that order.
for activation_name, subnet_fn in (
    ("swish", subnet_max_func),
    ("leaky_relu", subnet_max_func_v2),
):
    params = get_transformed_activations(
        [activation_name], method="TAT", subnet_max_func=subnet_fn
    )
    print(params)
|