import torch
import torch.nn as nn

# Helpers to initialize a network from a (possibly wider) pre-trained model by
# copying the leading slices of each layer's parameters, so that a slimmer
# network inherits the first in/out channels of every tensor.


def copy_conv(module, init):
  # Copy the leading [out, in] slice of a pre-trained Conv2d into `module`.
  assert isinstance(module, nn.Conv2d), 'invalid module : {:}'.format(module)
  assert isinstance(init  , nn.Conv2d), 'invalid module : {:}'.format(init)
  new_i, new_o = module.in_channels, module.out_channels
  module.weight.copy_( init.weight.detach()[:new_o, :new_i] )
  if module.bias is not None:
    module.bias.copy_( init.bias.detach()[:new_o] )


def copy_bn(module, init):
  # Copy the leading `num_features` entries of the affine parameters and the
  # running statistics of a pre-trained BatchNorm2d into `module`.
  assert isinstance(module, nn.BatchNorm2d), 'invalid module : {:}'.format(module)
  assert isinstance(init  , nn.BatchNorm2d), 'invalid module : {:}'.format(init)
  num_features = module.num_features
  if module.weight is not None:
    module.weight.copy_( init.weight.detach()[:num_features] )
  if module.bias is not None:
    module.bias.copy_( init.bias.detach()[:num_features] )
  if module.running_mean is not None:
    module.running_mean.copy_( init.running_mean.detach()[:num_features] )
  if module.running_var is not None:
    module.running_var.copy_( init.running_var.detach()[:num_features] )


def copy_fc(module, init):
  # Copy the leading [out, in] slice of a pre-trained Linear layer into `module`.
  assert isinstance(module, nn.Linear), 'invalid module : {:}'.format(module)
  assert isinstance(init  , nn.Linear), 'invalid module : {:}'.format(init)
  new_i, new_o = module.in_features, module.out_features
  module.weight.copy_( init.weight.detach()[:new_o, :new_i] )
  if module.bias is not None:
    module.bias.copy_( init.bias.detach()[:new_o] )


def copy_base(module, init):
  # Copy a ConvBNReLU / Downsample block: its convolution and batch-norm, if present.
  assert type(module).__name__ in ['ConvBNReLU', 'Downsample'], 'invalid module : {:}'.format(module)
  assert type( init).__name__ in ['ConvBNReLU', 'Downsample'], 'invalid module : {:}'.format( init)
  if module.conv is not None:
    copy_conv(module.conv, init.conv)
  if module.bn is not None:
    copy_bn(module.bn, init.bn)


def copy_basic(module, init):
  # Copy a ResNet basic block: both ConvBNReLU sub-blocks and, when both the
  # target and the source have one, the downsample branch.
  copy_base(module.conv_a, init.conv_a)
  copy_base(module.conv_b, init.conv_b)
  if module.downsample is not None and init.downsample is not None:
    copy_base(module.downsample, init.downsample)


def init_from_model(network, init_model):
  # Initialize `network` layer-by-layer from `init_model`; both models must
  # contain the same layer types in the same order.
  with torch.no_grad():
    copy_fc(network.classifier, init_model.classifier)
    for base, target in zip(init_model.layers, network.layers):
      assert type(base).__name__ == type(target).__name__, 'invalid type : {:} vs {:}'.format(base, target)
      if type(base).__name__ == 'ConvBNReLU':
        copy_base(target, base)
      elif type(base).__name__ == 'ResNetBasicblock':
        copy_basic(target, base)
      else:
        raise ValueError('unknown type name : {:}'.format( type(base).__name__ ))
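

# Minimal, self-contained usage sketch (illustrative only, not part of the
# original module): `ConvBNReLU` and `TinyNet` below are hypothetical
# stand-ins for the project's real architectures, built just so that
# `init_from_model` has something to copy between.
if __name__ == '__main__':

  class ConvBNReLU(nn.Module):
    def __init__(self, c_in, c_out):
      super().__init__()
      self.conv = nn.Conv2d(c_in, c_out, 3, padding=1, bias=False)
      self.bn   = nn.BatchNorm2d(c_out)
    def forward(self, x):
      return torch.relu(self.bn(self.conv(x)))

  class TinyNet(nn.Module):
    def __init__(self, width):
      super().__init__()
      self.layers     = nn.ModuleList([ConvBNReLU(3, width), ConvBNReLU(width, width)])
      self.classifier = nn.Linear(width, 10)

  wide, slim = TinyNet(16), TinyNet(8)
  init_from_model(slim, wide)                # slim receives the leading 8 channels of wide
  print(slim.layers[0].conv.weight.shape)    # torch.Size([8, 3, 3, 3])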