Update DKS exploration
This commit is contained in:
parent
b557a22928
commit
5bf036a763
57
exps/experimental/test-dks.py
Normal file
57
exps/experimental/test-dks.py
Normal file
@ -0,0 +1,57 @@
|
||||
from dks.base.activation_getter import (
|
||||
get_activation_function as _get_numpy_activation_function,
|
||||
)
|
||||
from dks.base.activation_transform import _get_activations_params
|
||||
|
||||
|
||||
def subnet_max_func(x, r_fn):
|
||||
depth = 7
|
||||
res_x = r_fn(x)
|
||||
x = r_fn(x)
|
||||
for _ in range(depth):
|
||||
x = r_fn(r_fn(x)) + x
|
||||
return max(x, res_x)
|
||||
|
||||
|
||||
def subnet_max_func_v2(x, r_fn):
|
||||
depth = 2
|
||||
res_x = r_fn(x)
|
||||
|
||||
x = r_fn(x)
|
||||
for _ in range(depth):
|
||||
x = 0.8 * r_fn(r_fn(x)) + 0.2 * x
|
||||
|
||||
return max(x, res_x)
|
||||
|
||||
|
||||
def get_transformed_activations(
|
||||
activation_names,
|
||||
method="TAT",
|
||||
dks_params=None,
|
||||
tat_params=None,
|
||||
max_slope_func=None,
|
||||
max_curv_func=None,
|
||||
subnet_max_func=None,
|
||||
activation_getter=_get_numpy_activation_function,
|
||||
):
|
||||
params = _get_activations_params(
|
||||
activation_names,
|
||||
method=method,
|
||||
dks_params=dks_params,
|
||||
tat_params=tat_params,
|
||||
max_slope_func=max_slope_func,
|
||||
max_curv_func=max_curv_func,
|
||||
subnet_max_func=subnet_max_func,
|
||||
)
|
||||
return params
|
||||
|
||||
|
||||
params = get_transformed_activations(
|
||||
["swish"], method="TAT", subnet_max_func=subnet_max_func
|
||||
)
|
||||
print(params)
|
||||
|
||||
params = get_transformed_activations(
|
||||
["leaky_relu"], method="TAT", subnet_max_func=subnet_max_func_v2
|
||||
)
|
||||
print(params)
|
Loading…
Reference in New Issue
Block a user