Fix 1-element in norm bug
parent 80ccc49d92
commit 06f4a1f1cf
@@ -58,6 +58,7 @@ def main(args):
     # build model
     model = get_model(**model_kwargs)
     print(model)
+    model.analyze_weights()
     # build optimizer
     optimizer = torch.optim.Adam(model.parameters(), lr=args.init_lr, amsgrad=True)
     criterion = torch.nn.MSELoss()
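The added analyze_weights() call logs per-parameter statistics right after model construction; the method itself is introduced in the SuperModule hunk below, with a usage sketch after it.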
@@ -168,7 +169,7 @@ if __name__ == "__main__":
     parser.add_argument(
         "--epochs",
         type=int,
-        default=1000,
+        default=300,
         help="The total number of epochs.",
     )
     parser.add_argument(
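Here only the default epoch budget changes (1000 to 300); the flag itself is untouched, so longer runs remain available from the command line. A minimal sketch of the resulting argparse behavior, using only the argument as it appears in the diff:

import argparse

# Sketch of the changed argument in isolation; the surrounding parser
# configuration is not shown in this diff.
parser = argparse.ArgumentParser()
parser.add_argument("--epochs", type=int, default=300, help="The total number of epochs.")

print(parser.parse_args([]).epochs)                    # 300, the new default
print(parser.parse_args(["--epochs", "1000"]).epochs)  # 1000, the old budget, still selectable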
@@ -40,13 +40,10 @@ def get_model(config: Dict[Text, Any], **kwargs):
     norm_cls = super_name2norm[kwargs["norm_cls"]]
     sub_layers, last_dim = [], kwargs["input_dim"]
     for i, hidden_dim in enumerate(kwargs["hidden_dims"]):
-        sub_layers.extend(
-            [
-                norm_cls(last_dim, elementwise_affine=False),
-                SuperLinear(last_dim, hidden_dim),
-                act_cls(),
-            ]
-        )
+        if last_dim > 1:
+            sub_layers.append(norm_cls(last_dim, elementwise_affine=False))
+        sub_layers.append(SuperLinear(last_dim, hidden_dim))
+        sub_layers.append(act_cls())
         last_dim = hidden_dim
     sub_layers.append(SuperLinear(last_dim, kwargs["output_dim"]))
     model = SuperSequential(*sub_layers)
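This is the fix the commit title refers to: when last_dim == 1 the normalization layer has a single value to normalize, that value equals its own mean, and the output is identically zero, so the layer erases the signal. The guard skips the norm in that case. A small sketch of the degenerate behavior, assuming norm_cls resolves to a LayerNorm-style module (suggested by the elementwise_affine keyword, not confirmed by the diff):

import torch

# Normalizing a single feature against its own mean always yields zero:
# the variance is 0, so (x - mean) / sqrt(var + eps) == 0 / sqrt(eps) == 0.
norm = torch.nn.LayerNorm(1, elementwise_affine=False)
x = torch.tensor([[3.14], [-2.0]])
print(norm(x))  # tensor([[0.], [0.]]) -- every 1-element input collapses to zero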
@@ -66,6 +66,15 @@ class SuperModule(abc.ABC, nn.Module):
             container.append(name, buf, False)
         return container
 
+    def analyze_weights(self):
+        with torch.no_grad():
+            for name, param in self.named_parameters():
+                shapestr = "[{:10s}] shape={:}".format(name, list(param.shape))
+                finalstr = shapestr + "{:.2f} +- {:.2f}".format(
+                    param.mean(), param.std()
+                )
+                print(finalstr)
+
     @property
     def abstract_search_space(self):
         raise NotImplementedError
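analyze_weights walks named_parameters() under torch.no_grad() and prints each tensor's name, shape, mean, and standard deviation. A standalone sketch of the same logic against a plain nn.Linear (SuperModule itself is not part of this diff):

import torch
import torch.nn as nn

# Standalone sketch of what analyze_weights reports, run against a plain
# nn.Linear since SuperModule is not defined here.
model = nn.Linear(4, 2)
with torch.no_grad():
    for name, param in model.named_parameters():
        shapestr = "[{:10s}] shape={:}".format(name, list(param.shape))
        print(shapestr + "{:.2f} +- {:.2f}".format(param.mean(), param.std()))

Note that torch.std is unbiased by default, so a parameter with a single element would report nan for its std.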