2019-09-28 10:24:47 +02:00
|
|
|
import torch
|
|
|
|
import torch.nn as nn
|
|
|
|
|
|
|
|
|
|
|
|
def initialize_resnet(m: nn.Module) -> None:
    """Initialize a single module's parameters in place, ResNet-style.

    Designed to be called once per submodule, e.g. via
    ``model.apply(initialize_resnet)``.

    * ``nn.Conv2d``: Kaiming-normal weights (``mode="fan_out"``,
      ``nonlinearity="relu"``); bias, if present, set to 0.
    * ``nn.BatchNorm2d``: weight set to 1 and bias set to 0, when those
      parameters exist.
    * ``nn.Linear``: weights drawn from N(0, 0.01**2); bias, if present,
      set to 0.
    * Any other module is left untouched.

    Args:
        m: The module to initialize. Its parameters are modified in place.
    """
    if isinstance(m, nn.Conv2d):
        nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
    elif isinstance(m, nn.BatchNorm2d):
        # BatchNorm2d(affine=False) has weight and bias set to None, so guard
        # both instead of crashing inside nn.init.constant_.
        if m.weight is not None:
            nn.init.constant_(m.weight, 1)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
    elif isinstance(m, nn.Linear):
        nn.init.normal_(m.weight, 0, 0.01)
        # Linear(bias=False) has bias=None; guard it like the other branches.
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
|