diff --git a/exps/LFNA/lfna_models.py b/exps/LFNA/lfna_models.py index b4dbcbc..f07ffb0 100644 --- a/exps/LFNA/lfna_models.py +++ b/exps/LFNA/lfna_models.py @@ -34,7 +34,7 @@ class HyperNet(super_core.SuperModule): config=dict(model_type="dual_norm_mlp"), input_dim=layer_embeding + task_embedding, output_dim=max(self._numel_per_layer), - hidden_dims=[layer_embeding * 4] * 3, + hidden_dims=[layer_embeding * 2] * 3, act_cls="gelu", norm_cls="layer_norm_1d", dropout=0.1,