Fix 1-element in norm bug

This commit is contained in:
D-X-Y 2021-05-12 19:09:17 +08:00
parent 80ccc49d92
commit 06f4a1f1cf
3 changed files with 15 additions and 8 deletions

View File

@ -58,6 +58,7 @@ def main(args):
# build model # build model
model = get_model(**model_kwargs) model = get_model(**model_kwargs)
print(model) print(model)
model.analyze_weights()
# build optimizer # build optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=args.init_lr, amsgrad=True) optimizer = torch.optim.Adam(model.parameters(), lr=args.init_lr, amsgrad=True)
criterion = torch.nn.MSELoss() criterion = torch.nn.MSELoss()
@ -168,7 +169,7 @@ if __name__ == "__main__":
parser.add_argument( parser.add_argument(
"--epochs", "--epochs",
type=int, type=int,
default=1000, default=300,
help="The total number of epochs.", help="The total number of epochs.",
) )
parser.add_argument( parser.add_argument(

View File

@ -40,13 +40,10 @@ def get_model(config: Dict[Text, Any], **kwargs):
norm_cls = super_name2norm[kwargs["norm_cls"]] norm_cls = super_name2norm[kwargs["norm_cls"]]
sub_layers, last_dim = [], kwargs["input_dim"] sub_layers, last_dim = [], kwargs["input_dim"]
for i, hidden_dim in enumerate(kwargs["hidden_dims"]): for i, hidden_dim in enumerate(kwargs["hidden_dims"]):
sub_layers.extend( if last_dim > 1:
[ sub_layers.append(norm_cls(last_dim, elementwise_affine=False))
norm_cls(last_dim, elementwise_affine=False), sub_layers.append(SuperLinear(last_dim, hidden_dim))
SuperLinear(last_dim, hidden_dim), sub_layers.append(act_cls())
act_cls(),
]
)
last_dim = hidden_dim last_dim = hidden_dim
sub_layers.append(SuperLinear(last_dim, kwargs["output_dim"])) sub_layers.append(SuperLinear(last_dim, kwargs["output_dim"]))
model = SuperSequential(*sub_layers) model = SuperSequential(*sub_layers)

View File

@ -66,6 +66,15 @@ class SuperModule(abc.ABC, nn.Module):
container.append(name, buf, False) container.append(name, buf, False)
return container return container
def analyze_weights(self):
    """Print, for every named parameter of this module, its shape and its
    mean/std statistics — a quick debugging aid for weight initialization.

    Runs under ``torch.no_grad()`` so the statistics computation does not
    build an autograd graph. Output goes to stdout, one line per parameter.
    """
    with torch.no_grad():
        for name, param in self.named_parameters():
            shapestr = "[{:10s}] shape={:}".format(name, list(param.shape))
            # BUGFIX: a space now separates the shape from the statistics;
            # previously they were concatenated directly ("shape=[2, 3]0.01 ...").
            finalstr = shapestr + " {:.2f} +- {:.2f}".format(
                param.mean(), param.std()
            )
            print(finalstr)
@property @property
def abstract_search_space(self): def abstract_search_space(self):
raise NotImplementedError raise NotImplementedError