Update DKS exploration
This commit is contained in:
		
							
								
								
									
										57
									
								
								exps/experimental/test-dks.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										57
									
								
								exps/experimental/test-dks.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,57 @@ | ||||
| from dks.base.activation_getter import ( | ||||
|     get_activation_function as _get_numpy_activation_function, | ||||
| ) | ||||
| from dks.base.activation_transform import _get_activations_params | ||||
|  | ||||
|  | ||||
| def subnet_max_func(x, r_fn): | ||||
|     depth = 7 | ||||
|     res_x = r_fn(x) | ||||
|     x = r_fn(x) | ||||
|     for _ in range(depth): | ||||
|         x = r_fn(r_fn(x)) + x | ||||
|     return max(x, res_x) | ||||
|  | ||||
|  | ||||
| def subnet_max_func_v2(x, r_fn): | ||||
|     depth = 2 | ||||
|     res_x = r_fn(x) | ||||
|  | ||||
|     x = r_fn(x) | ||||
|     for _ in range(depth): | ||||
|         x = 0.8 * r_fn(r_fn(x)) + 0.2 * x | ||||
|  | ||||
|     return max(x, res_x) | ||||
|  | ||||
|  | ||||
| def get_transformed_activations( | ||||
|     activation_names, | ||||
|     method="TAT", | ||||
|     dks_params=None, | ||||
|     tat_params=None, | ||||
|     max_slope_func=None, | ||||
|     max_curv_func=None, | ||||
|     subnet_max_func=None, | ||||
|     activation_getter=_get_numpy_activation_function, | ||||
| ): | ||||
|     params = _get_activations_params( | ||||
|         activation_names, | ||||
|         method=method, | ||||
|         dks_params=dks_params, | ||||
|         tat_params=tat_params, | ||||
|         max_slope_func=max_slope_func, | ||||
|         max_curv_func=max_curv_func, | ||||
|         subnet_max_func=subnet_max_func, | ||||
|     ) | ||||
|     return params | ||||
|  | ||||
|  | ||||
| params = get_transformed_activations( | ||||
|     ["swish"], method="TAT", subnet_max_func=subnet_max_func | ||||
| ) | ||||
| print(params) | ||||
|  | ||||
| params = get_transformed_activations( | ||||
|     ["leaky_relu"], method="TAT", subnet_max_func=subnet_max_func_v2 | ||||
| ) | ||||
| print(params) | ||||
		Reference in New Issue
	
	Block a user