import torch
import torch.nn as nn


def _copy_leading(dst, src, sizes, name):
    """Overwrite ``dst`` in place with the leading slice of ``src``.

    Args:
      dst: destination tensor (parameter or buffer); written with ``copy_``.
      src: source tensor, or None when the source module lacks it.
      sizes: one size per leading dimension; ``src`` is sliced to
        ``src[:sizes[0], :sizes[1], ...]`` before the copy.
      name: attribute name, used only in the error message.

    Raises:
      ValueError: if ``src`` is None, i.e. the destination module has a
        tensor the source module does not provide.
    """
    if src is None:
        raise ValueError('the source module has no {:} to copy from'.format(name))
    index = tuple(slice(0, size) for size in sizes)
    dst.copy_(src.detach()[index])


def copy_conv(module, init):
    """Initialize an ``nn.Conv2d`` from the leading channels of another one.

    ``module.weight`` receives ``init.weight[:out_channels, :in_channels]``
    and, when ``module`` has a bias, ``module.bias`` receives
    ``init.bias[:out_channels]``. The channel counts are taken from ``module``.
    Parameters are written in place, so call this under ``torch.no_grad()``
    (as ``init_from_model`` does).

    Raises:
      AssertionError: if either argument is not an ``nn.Conv2d``.
      ValueError: if ``module`` has a bias but ``init`` does not.
    """
    assert isinstance(module, nn.Conv2d), 'invalid module : {:}'.format(module)
    assert isinstance(init, nn.Conv2d), 'invalid module : {:}'.format(init)
    new_i, new_o = module.in_channels, module.out_channels
    # NOTE(review): slices dim 1 by in_channels, which matches the weight
    # layout only for groups == 1 convolutions; confirm grouped convs are
    # never passed here.
    _copy_leading(module.weight, init.weight, (new_o, new_i), 'weight')
    if module.bias is not None:
        _copy_leading(module.bias, init.bias, (new_o,), 'bias')


def copy_bn(module, init):
    """Initialize an ``nn.BatchNorm2d`` from the leading features of another.

    Each of ``weight``, ``bias``, ``running_mean`` and ``running_var`` that
    exists on ``module`` is overwritten with the first ``module.num_features``
    entries of the same tensor on ``init``. Call under ``torch.no_grad()``.

    Raises:
      AssertionError: if either argument is not an ``nn.BatchNorm2d``.
      ValueError: if ``module`` has one of those tensors but ``init`` does not.
    """
    assert isinstance(module, nn.BatchNorm2d), 'invalid module : {:}'.format(module)
    assert isinstance(init, nn.BatchNorm2d), 'invalid module : {:}'.format(init)
    num_features = module.num_features
    # Same order as the original explicit checks: affine params, then stats.
    for name in ('weight', 'bias', 'running_mean', 'running_var'):
        target = getattr(module, name)
        if target is not None:
            _copy_leading(target, getattr(init, name), (num_features,), name)


def copy_fc(module, init):
    """Initialize an ``nn.Linear`` from the leading rows/columns of another.

    ``module.weight`` receives ``init.weight[:out_features, :in_features]``
    and, when present, ``module.bias`` receives ``init.bias[:out_features]``,
    with sizes taken from ``module``. Call under ``torch.no_grad()``.

    Raises:
      AssertionError: if either argument is not an ``nn.Linear``.
      ValueError: if ``module`` has a bias but ``init`` does not.
    """
    assert isinstance(module, nn.Linear), 'invalid module : {:}'.format(module)
    assert isinstance(init, nn.Linear), 'invalid module : {:}'.format(init)
    new_i, new_o = module.in_features, module.out_features
    _copy_leading(module.weight, init.weight, (new_o, new_i), 'weight')
    if module.bias is not None:
        _copy_leading(module.bias, init.bias, (new_o,), 'bias')


def copy_base(module, init):
    """Copy the ``conv`` and ``bn`` sub-modules of a ConvBNReLU/Downsample.

    Both arguments must be instances of classes named ``ConvBNReLU`` or
    ``Downsample``; they are matched by class name because those classes
    are defined elsewhere. Sub-modules that are None on ``module`` are
    skipped.
    """
    assert type(module).__name__ in ['ConvBNReLU', 'Downsample'], 'invalid module : {:}'.format(module)
    assert type(init).__name__ in ['ConvBNReLU', 'Downsample'], 'invalid module : {:}'.format(init)
    if module.conv is not None:
        copy_conv(module.conv, init.conv)
    if module.bn is not None:
        copy_bn(module.bn, init.bn)


def copy_basic(module, init):
    """Copy a ResNet basic block (``conv_a``, ``conv_b``, ``downsample``).

    If ``module`` has a downsample branch but ``init`` does not, that branch
    keeps whatever values it already had.
    """
    copy_base(module.conv_a, init.conv_a)
    copy_base(module.conv_b, init.conv_b)
    if module.downsample is not None and init.downsample is not None:
        copy_base(module.downsample, init.downsample)


def init_from_model(network, init_model):
    """Initialize ``network`` in place from the leading slices of ``init_model``.

    The classifier is copied first. Then ``init_model.layers`` and
    ``network.layers`` are walked pairwise. Because ``zip`` is used, only as
    many layers as the shorter of the two sequences are copied; any extra
    layers are left untouched. All copies run under ``torch.no_grad()``.

    Raises:
      AssertionError: if a paired layer has a different class name.
      ValueError: for an unsupported layer type, or a missing source tensor.
    """
    with torch.no_grad():
        copy_fc(network.classifier, init_model.classifier)
        for base, target in zip(init_model.layers, network.layers):
            assert type(base).__name__ == type(target).__name__, 'invalid type : {:} vs {:}'.format(base, target)
            if type(base).__name__ == 'ConvBNReLU':
                copy_base(target, base)
            elif type(base).__name__ == 'ResNetBasicblock':
                copy_basic(target, base)
            else:
                raise ValueError('unknown type name : {:}'.format(type(base).__name__))